Compare commits

...

9 Commits

Author SHA1 Message Date
meili-bors[bot]
4e1ac9b0b4
Merge #5051
Some checks failed
Test suite / Tests on ${{ matrix.os }} (macos-12) (push) Waiting to run
Test suite / Tests almost all features (push) Has been skipped
Test suite / Test disabled tokenization (push) Has been skipped
Test suite / Tests on ubuntu-20.04 (push) Failing after 28s
Test suite / Run tests in debug (push) Failing after 31s
Test suite / Run Rustfmt (push) Successful in 2m6s
Test suite / Tests on ${{ matrix.os }} (windows-2022) (push) Failing after 8m36s
Test suite / Run Clippy (push) Failing after 21m21s
5051: Add timeout on read and write operations. r=irevoire a=dureuill

# Pull Request

## Related issue
Addresses #5054 

## What does this PR do?
- Add a timeout for read and write operations in the REST embedder. This might address some issues about tasks that get "stuck" while embedding documents.


Co-authored-by: Louis Dureuil <louis@meilisearch.com>
2024-11-14 08:34:43 +00:00
meili-bors[bot]
8a18e37a3d
Merge #5055
5055: Update version for the next release (v1.11.2) in Cargo.toml r=dureuill a=meili-bot

⚠️ This PR is automatically generated. Check the new version is the expected one and Cargo.lock has been updated before merging.

Co-authored-by: dureuill <dureuill@users.noreply.github.com>
2024-11-14 07:53:51 +00:00
dureuill
36375ea326 Update version for the next release (v1.11.2) in Cargo.toml 2024-11-13 16:24:23 +00:00
Louis Dureuil
bca2974266
Add timeout on read and write operations. 2024-11-13 17:01:23 +01:00
meili-bors[bot]
13025594a8
Merge #5041
5041: Update version for the next release (v1.11.1) in Cargo.toml r=dureuill a=meili-bot

⚠️ This PR is automatically generated. Check the new version is the expected one and Cargo.lock has been updated before merging.

Co-authored-by: dureuill <dureuill@users.noreply.github.com>
2024-11-06 11:35:26 +00:00
meili-bors[bot]
2c1c33166d
Merge #5039
5039: Add 3s timeout to embedding requests made during search r=irevoire a=dureuill

# Pull Request

## Related issue
Fixes #5032 

## What does this PR do?
- Add a 3-second timeout to embedding requests against a remote embedder made in the context of search. The timeout triggers when there are failing requests due to rate-limiting.
- Add a test of that timeout.

Co-authored-by: Louis Dureuil <louis@meilisearch.com>
2024-11-06 10:56:50 +00:00
dureuill
cdb6e3f45a Update version for the next release (v1.11.1) in Cargo.toml 2024-11-06 08:35:51 +00:00
Louis Dureuil
1d574bd443
Add test 2024-11-06 09:25:41 +01:00
Louis Dureuil
37a4fd7f99
Add deadline of 3 seconds to embedding requests made in the context of hybrid search 2024-11-06 09:25:24 +01:00
10 changed files with 230 additions and 58 deletions

34
Cargo.lock generated
View File

@ -472,7 +472,7 @@ checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
[[package]] [[package]]
name = "benchmarks" name = "benchmarks"
version = "1.11.0" version = "1.11.2"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"bytes", "bytes",
@ -653,7 +653,7 @@ dependencies = [
[[package]] [[package]]
name = "build-info" name = "build-info"
version = "1.11.0" version = "1.11.2"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"time", "time",
@ -1623,7 +1623,7 @@ dependencies = [
[[package]] [[package]]
name = "dump" name = "dump"
version = "1.11.0" version = "1.11.2"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"big_s", "big_s",
@ -1835,7 +1835,7 @@ checksum = "9fc0510504f03c51ada170672ac806f1f105a88aa97a5281117e1ddc3368e51a"
[[package]] [[package]]
name = "file-store" name = "file-store"
version = "1.11.0" version = "1.11.2"
dependencies = [ dependencies = [
"tempfile", "tempfile",
"thiserror", "thiserror",
@ -1857,7 +1857,7 @@ dependencies = [
[[package]] [[package]]
name = "filter-parser" name = "filter-parser"
version = "1.11.0" version = "1.11.2"
dependencies = [ dependencies = [
"insta", "insta",
"nom", "nom",
@ -1877,7 +1877,7 @@ dependencies = [
[[package]] [[package]]
name = "flatten-serde-json" name = "flatten-serde-json"
version = "1.11.0" version = "1.11.2"
dependencies = [ dependencies = [
"criterion", "criterion",
"serde_json", "serde_json",
@ -2001,7 +2001,7 @@ dependencies = [
[[package]] [[package]]
name = "fuzzers" name = "fuzzers"
version = "1.11.0" version = "1.11.2"
dependencies = [ dependencies = [
"arbitrary", "arbitrary",
"clap", "clap",
@ -2553,7 +2553,7 @@ checksum = "206ca75c9c03ba3d4ace2460e57b189f39f43de612c2f85836e65c929701bb2d"
[[package]] [[package]]
name = "index-scheduler" name = "index-scheduler"
version = "1.11.0" version = "1.11.2"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"arroy", "arroy",
@ -2747,7 +2747,7 @@ dependencies = [
[[package]] [[package]]
name = "json-depth-checker" name = "json-depth-checker"
version = "1.11.0" version = "1.11.2"
dependencies = [ dependencies = [
"criterion", "criterion",
"serde_json", "serde_json",
@ -3366,7 +3366,7 @@ checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771"
[[package]] [[package]]
name = "meili-snap" name = "meili-snap"
version = "1.11.0" version = "1.11.2"
dependencies = [ dependencies = [
"insta", "insta",
"md5", "md5",
@ -3375,7 +3375,7 @@ dependencies = [
[[package]] [[package]]
name = "meilisearch" name = "meilisearch"
version = "1.11.0" version = "1.11.2"
dependencies = [ dependencies = [
"actix-cors", "actix-cors",
"actix-http", "actix-http",
@ -3465,7 +3465,7 @@ dependencies = [
[[package]] [[package]]
name = "meilisearch-auth" name = "meilisearch-auth"
version = "1.11.0" version = "1.11.2"
dependencies = [ dependencies = [
"base64 0.22.1", "base64 0.22.1",
"enum-iterator", "enum-iterator",
@ -3484,7 +3484,7 @@ dependencies = [
[[package]] [[package]]
name = "meilisearch-types" name = "meilisearch-types"
version = "1.11.0" version = "1.11.2"
dependencies = [ dependencies = [
"actix-web", "actix-web",
"anyhow", "anyhow",
@ -3514,7 +3514,7 @@ dependencies = [
[[package]] [[package]]
name = "meilitool" name = "meilitool"
version = "1.11.0" version = "1.11.2"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"clap", "clap",
@ -3545,7 +3545,7 @@ dependencies = [
[[package]] [[package]]
name = "milli" name = "milli"
version = "1.11.0" version = "1.11.2"
dependencies = [ dependencies = [
"arroy", "arroy",
"big_s", "big_s",
@ -3991,7 +3991,7 @@ checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e"
[[package]] [[package]]
name = "permissive-json-pointer" name = "permissive-json-pointer"
version = "1.11.0" version = "1.11.2"
dependencies = [ dependencies = [
"big_s", "big_s",
"serde_json", "serde_json",
@ -6380,7 +6380,7 @@ dependencies = [
[[package]] [[package]]
name = "xtask" name = "xtask"
version = "1.11.0" version = "1.11.2"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"build-info", "build-info",

View File

@ -22,7 +22,7 @@ members = [
] ]
[workspace.package] [workspace.package]
version = "1.11.0" version = "1.11.2"
authors = [ authors = [
"Quentin de Quelen <quentin@dequelen.me>", "Quentin de Quelen <quentin@dequelen.me>",
"Clément Renault <clement@meilisearch.com>", "Clément Renault <clement@meilisearch.com>",

View File

@ -5201,9 +5201,10 @@ mod tests {
let configs = index_scheduler.embedders(configs).unwrap(); let configs = index_scheduler.embedders(configs).unwrap();
let (hf_embedder, _, _) = configs.get(&simple_hf_name).unwrap(); let (hf_embedder, _, _) = configs.get(&simple_hf_name).unwrap();
let beagle_embed = hf_embedder.embed_one(S("Intel the beagle best doggo")).unwrap(); let beagle_embed =
let lab_embed = hf_embedder.embed_one(S("Max the lab best doggo")).unwrap(); hf_embedder.embed_one(S("Intel the beagle best doggo"), None).unwrap();
let patou_embed = hf_embedder.embed_one(S("kefir the patou best doggo")).unwrap(); let lab_embed = hf_embedder.embed_one(S("Max the lab best doggo"), None).unwrap();
let patou_embed = hf_embedder.embed_one(S("kefir the patou best doggo"), None).unwrap();
(fakerest_name, simple_hf_name, beagle_embed, lab_embed, patou_embed) (fakerest_name, simple_hf_name, beagle_embed, lab_embed, patou_embed)
}; };

View File

@ -796,8 +796,10 @@ fn prepare_search<'t>(
let span = tracing::trace_span!(target: "search::vector", "embed_one"); let span = tracing::trace_span!(target: "search::vector", "embed_one");
let _entered = span.enter(); let _entered = span.enter();
let deadline = std::time::Instant::now() + std::time::Duration::from_secs(10);
embedder embedder
.embed_one(query.q.clone().unwrap()) .embed_one(query.q.clone().unwrap(), Some(deadline))
.map_err(milli::vector::Error::from) .map_err(milli::vector::Error::from)
.map_err(milli::Error::from)? .map_err(milli::Error::from)?
} }

View File

@ -137,13 +137,14 @@ fn long_text() -> &'static str {
} }
async fn create_mock_tokenized() -> (MockServer, Value) { async fn create_mock_tokenized() -> (MockServer, Value) {
create_mock_with_template("{{doc.text}}", ModelDimensions::Large, false).await create_mock_with_template("{{doc.text}}", ModelDimensions::Large, false, false).await
} }
async fn create_mock_with_template( async fn create_mock_with_template(
document_template: &str, document_template: &str,
model_dimensions: ModelDimensions, model_dimensions: ModelDimensions,
fallible: bool, fallible: bool,
slow: bool,
) -> (MockServer, Value) { ) -> (MockServer, Value) {
let mock_server = MockServer::start().await; let mock_server = MockServer::start().await;
const API_KEY: &str = "my-api-key"; const API_KEY: &str = "my-api-key";
@ -154,7 +155,11 @@ async fn create_mock_with_template(
Mock::given(method("POST")) Mock::given(method("POST"))
.and(path("/")) .and(path("/"))
.respond_with(move |req: &Request| { .respond_with(move |req: &Request| {
// 0. maybe return 500 // 0. wait for a long time
if slow {
std::thread::sleep(std::time::Duration::from_secs(1));
}
// 1. maybe return 500
if fallible { if fallible {
let attempt = attempt.fetch_add(1, Ordering::Relaxed); let attempt = attempt.fetch_add(1, Ordering::Relaxed);
let failed = matches!(attempt % 4, 0 | 1 | 3); let failed = matches!(attempt % 4, 0 | 1 | 3);
@ -167,7 +172,7 @@ async fn create_mock_with_template(
})) }))
} }
} }
// 1. check API key // 3. check API key
match req.headers.get("Authorization") { match req.headers.get("Authorization") {
Some(api_key) if api_key == API_KEY_BEARER => { Some(api_key) if api_key == API_KEY_BEARER => {
{} {}
@ -202,7 +207,7 @@ async fn create_mock_with_template(
) )
} }
} }
// 2. parse text inputs // 3. parse text inputs
let query: serde_json::Value = match req.body_json() { let query: serde_json::Value = match req.body_json() {
Ok(query) => query, Ok(query) => query,
Err(_error) => return ResponseTemplate::new(400).set_body_json( Err(_error) => return ResponseTemplate::new(400).set_body_json(
@ -223,7 +228,7 @@ async fn create_mock_with_template(
panic!("Expected {model_dimensions:?}, got {query_model_dimensions:?}") panic!("Expected {model_dimensions:?}, got {query_model_dimensions:?}")
} }
// 3. for each text, find embedding in responses // 4. for each text, find embedding in responses
let serde_json::Value::Array(inputs) = &query["input"] else { let serde_json::Value::Array(inputs) = &query["input"] else {
panic!("Unexpected `input` value") panic!("Unexpected `input` value")
}; };
@ -283,7 +288,7 @@ async fn create_mock_with_template(
"embedding": embedding, "embedding": embedding,
})).collect(); })).collect();
// 4. produce output from embeddings // 5. produce output from embeddings
ResponseTemplate::new(200).set_body_json(json!({ ResponseTemplate::new(200).set_body_json(json!({
"object": "list", "object": "list",
"data": data, "data": data,
@ -317,23 +322,27 @@ const DOGGO_TEMPLATE: &str = r#"{%- if doc.gender == "F" -%}Une chienne nommée
{%- endif %}, de race {{doc.breed}}."#; {%- endif %}, de race {{doc.breed}}."#;
async fn create_mock() -> (MockServer, Value) { async fn create_mock() -> (MockServer, Value) {
create_mock_with_template(DOGGO_TEMPLATE, ModelDimensions::Large, false).await create_mock_with_template(DOGGO_TEMPLATE, ModelDimensions::Large, false, false).await
} }
async fn create_mock_dimensions() -> (MockServer, Value) { async fn create_mock_dimensions() -> (MockServer, Value) {
create_mock_with_template(DOGGO_TEMPLATE, ModelDimensions::Large512, false).await create_mock_with_template(DOGGO_TEMPLATE, ModelDimensions::Large512, false, false).await
} }
async fn create_mock_small_embedding_model() -> (MockServer, Value) { async fn create_mock_small_embedding_model() -> (MockServer, Value) {
create_mock_with_template(DOGGO_TEMPLATE, ModelDimensions::Small, false).await create_mock_with_template(DOGGO_TEMPLATE, ModelDimensions::Small, false, false).await
} }
async fn create_mock_legacy_embedding_model() -> (MockServer, Value) { async fn create_mock_legacy_embedding_model() -> (MockServer, Value) {
create_mock_with_template(DOGGO_TEMPLATE, ModelDimensions::Ada, false).await create_mock_with_template(DOGGO_TEMPLATE, ModelDimensions::Ada, false, false).await
} }
async fn create_fallible_mock() -> (MockServer, Value) { async fn create_fallible_mock() -> (MockServer, Value) {
create_mock_with_template(DOGGO_TEMPLATE, ModelDimensions::Large, true).await create_mock_with_template(DOGGO_TEMPLATE, ModelDimensions::Large, true, false).await
}
async fn create_slow_mock() -> (MockServer, Value) {
create_mock_with_template(DOGGO_TEMPLATE, ModelDimensions::Large, true, true).await
} }
// basic test "it works" // basic test "it works"
@ -1873,4 +1882,114 @@ async fn it_still_works() {
] ]
"###); "###);
} }
// test with a server that responds 500 on 3 out of 4 calls
#[actix_rt::test]
async fn timeout() {
let (_mock, setting) = create_slow_mock().await;
let server = get_server_vector().await;
let index = server.index("doggo");
let (response, code) = index
.update_settings(json!({
"embedders": {
"default": setting,
},
}))
.await;
snapshot!(code, @"202 Accepted");
let task = server.wait_task(response.uid()).await;
snapshot!(task["status"], @r###""succeeded""###);
let documents = json!([
{"id": 0, "name": "kefir", "gender": "M", "birthyear": 2023, "breed": "Patou"},
]);
let (value, code) = index.add_documents(documents, None).await;
snapshot!(code, @"202 Accepted");
let task = index.wait_task(value.uid()).await;
snapshot!(task, @r###"
{
"uid": "[uid]",
"indexUid": "doggo",
"status": "succeeded",
"type": "documentAdditionOrUpdate",
"canceledBy": null,
"details": {
"receivedDocuments": 1,
"indexedDocuments": 1
},
"error": null,
"duration": "[duration]",
"enqueuedAt": "[date]",
"startedAt": "[date]",
"finishedAt": "[date]"
}
"###);
let (documents, _code) = index
.get_all_documents(GetAllDocumentsOptions { retrieve_vectors: true, ..Default::default() })
.await;
snapshot!(json_string!(documents, {".results.*._vectors.default.embeddings" => "[vector]"}), @r###"
{
"results": [
{
"id": 0,
"name": "kefir",
"gender": "M",
"birthyear": 2023,
"breed": "Patou",
"_vectors": {
"default": {
"embeddings": "[vector]",
"regenerate": true
}
}
}
],
"offset": 0,
"limit": 20,
"total": 1
}
"###);
let (response, code) = index
.search_post(json!({
"q": "grand chien de berger des montagnes",
"hybrid": {"semanticRatio": 0.99, "embedder": "default"}
}))
.await;
snapshot!(code, @"200 OK");
snapshot!(json_string!(response["semanticHitCount"]), @"0");
snapshot!(json_string!(response["hits"]), @"[]");
let (response, code) = index
.search_post(json!({
"q": "grand chien de berger des montagnes",
"hybrid": {"semanticRatio": 0.99, "embedder": "default"}
}))
.await;
snapshot!(code, @"200 OK");
snapshot!(json_string!(response["semanticHitCount"]), @"1");
snapshot!(json_string!(response["hits"]), @r###"
[
{
"id": 0,
"name": "kefir",
"gender": "M",
"birthyear": 2023,
"breed": "Patou"
}
]
"###);
let (response, code) = index
.search_post(json!({
"q": "grand chien de berger des montagnes",
"hybrid": {"semanticRatio": 0.99, "embedder": "default"}
}))
.await;
snapshot!(code, @"200 OK");
snapshot!(json_string!(response["semanticHitCount"]), @"0");
snapshot!(json_string!(response["hits"]), @"[]");
}
// test with a server that wrongly responds 400 // test with a server that wrongly responds 400

View File

@ -201,7 +201,9 @@ impl<'a> Search<'a> {
let span = tracing::trace_span!(target: "search::hybrid", "embed_one"); let span = tracing::trace_span!(target: "search::hybrid", "embed_one");
let _entered = span.enter(); let _entered = span.enter();
match embedder.embed_one(query) { let deadline = std::time::Instant::now() + std::time::Duration::from_secs(3);
match embedder.embed_one(query, Some(deadline)) {
Ok(embedding) => embedding, Ok(embedding) => embedding,
Err(error) => { Err(error) => {
tracing::error!(error=%error, "Embedding failed"); tracing::error!(error=%error, "Embedding failed");

View File

@ -1,5 +1,6 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use std::time::Instant;
use arroy::distances::{BinaryQuantizedCosine, Cosine}; use arroy::distances::{BinaryQuantizedCosine, Cosine};
use arroy::ItemId; use arroy::ItemId;
@ -594,18 +595,23 @@ impl Embedder {
pub fn embed( pub fn embed(
&self, &self,
texts: Vec<String>, texts: Vec<String>,
deadline: Option<Instant>,
) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> { ) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> {
match self { match self {
Embedder::HuggingFace(embedder) => embedder.embed(texts), Embedder::HuggingFace(embedder) => embedder.embed(texts),
Embedder::OpenAi(embedder) => embedder.embed(texts), Embedder::OpenAi(embedder) => embedder.embed(texts, deadline),
Embedder::Ollama(embedder) => embedder.embed(texts), Embedder::Ollama(embedder) => embedder.embed(texts, deadline),
Embedder::UserProvided(embedder) => embedder.embed(texts), Embedder::UserProvided(embedder) => embedder.embed(texts),
Embedder::Rest(embedder) => embedder.embed(texts), Embedder::Rest(embedder) => embedder.embed(texts, deadline),
} }
} }
pub fn embed_one(&self, text: String) -> std::result::Result<Embedding, EmbedError> { pub fn embed_one(
let mut embeddings = self.embed(vec![text])?; &self,
text: String,
deadline: Option<Instant>,
) -> std::result::Result<Embedding, EmbedError> {
let mut embeddings = self.embed(vec![text], deadline)?;
let embeddings = embeddings.pop().ok_or_else(EmbedError::missing_embedding)?; let embeddings = embeddings.pop().ok_or_else(EmbedError::missing_embedding)?;
Ok(if embeddings.iter().nth(1).is_some() { Ok(if embeddings.iter().nth(1).is_some() {
tracing::warn!("Ignoring embeddings past the first one in long search query"); tracing::warn!("Ignoring embeddings past the first one in long search query");

View File

@ -1,3 +1,5 @@
use std::time::Instant;
use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _}; use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
use super::error::{EmbedError, EmbedErrorKind, NewEmbedderError, NewEmbedderErrorKind}; use super::error::{EmbedError, EmbedErrorKind, NewEmbedderError, NewEmbedderErrorKind};
@ -75,8 +77,12 @@ impl Embedder {
Ok(Self { rest_embedder }) Ok(Self { rest_embedder })
} }
pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> { pub fn embed(
match self.rest_embedder.embed(texts) { &self,
texts: Vec<String>,
deadline: Option<Instant>,
) -> Result<Vec<Embeddings<f32>>, EmbedError> {
match self.rest_embedder.embed(texts, deadline) {
Ok(embeddings) => Ok(embeddings), Ok(embeddings) => Ok(embeddings),
Err(EmbedError { kind: EmbedErrorKind::RestOtherStatusCode(404, error), fault: _ }) => { Err(EmbedError { kind: EmbedErrorKind::RestOtherStatusCode(404, error), fault: _ }) => {
Err(EmbedError::ollama_model_not_found(error)) Err(EmbedError::ollama_model_not_found(error))
@ -92,7 +98,7 @@ impl Embedder {
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> { ) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
threads threads
.install(move || { .install(move || {
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect() text_chunks.into_par_iter().map(move |chunk| self.embed(chunk, None)).collect()
}) })
.map_err(|error| EmbedError { .map_err(|error| EmbedError {
kind: EmbedErrorKind::PanicInThreadPool(error), kind: EmbedErrorKind::PanicInThreadPool(error),

View File

@ -1,3 +1,5 @@
use std::time::Instant;
use ordered_float::OrderedFloat; use ordered_float::OrderedFloat;
use rayon::iter::{IntoParallelIterator, ParallelIterator as _}; use rayon::iter::{IntoParallelIterator, ParallelIterator as _};
@ -206,32 +208,40 @@ impl Embedder {
Ok(Self { options, rest_embedder, tokenizer }) Ok(Self { options, rest_embedder, tokenizer })
} }
pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> { pub fn embed(
match self.rest_embedder.embed_ref(&texts) { &self,
texts: Vec<String>,
deadline: Option<Instant>,
) -> Result<Vec<Embeddings<f32>>, EmbedError> {
match self.rest_embedder.embed_ref(&texts, deadline) {
Ok(embeddings) => Ok(embeddings), Ok(embeddings) => Ok(embeddings),
Err(EmbedError { kind: EmbedErrorKind::RestBadRequest(error, _), fault: _ }) => { Err(EmbedError { kind: EmbedErrorKind::RestBadRequest(error, _), fault: _ }) => {
tracing::warn!(error=?error, "OpenAI: received `BAD_REQUEST`. Input was maybe too long, retrying on tokenized version. For best performance, limit the size of your document template."); tracing::warn!(error=?error, "OpenAI: received `BAD_REQUEST`. Input was maybe too long, retrying on tokenized version. For best performance, limit the size of your document template.");
self.try_embed_tokenized(&texts) self.try_embed_tokenized(&texts, deadline)
} }
Err(error) => Err(error), Err(error) => Err(error),
} }
} }
fn try_embed_tokenized(&self, text: &[String]) -> Result<Vec<Embeddings<f32>>, EmbedError> { fn try_embed_tokenized(
&self,
text: &[String],
deadline: Option<Instant>,
) -> Result<Vec<Embeddings<f32>>, EmbedError> {
let mut all_embeddings = Vec::with_capacity(text.len()); let mut all_embeddings = Vec::with_capacity(text.len());
for text in text { for text in text {
let max_token_count = self.options.embedding_model.max_token(); let max_token_count = self.options.embedding_model.max_token();
let encoded = self.tokenizer.encode_ordinary(text.as_str()); let encoded = self.tokenizer.encode_ordinary(text.as_str());
let len = encoded.len(); let len = encoded.len();
if len < max_token_count { if len < max_token_count {
all_embeddings.append(&mut self.rest_embedder.embed_ref(&[text])?); all_embeddings.append(&mut self.rest_embedder.embed_ref(&[text], deadline)?);
continue; continue;
} }
let tokens = &encoded.as_slice()[0..max_token_count]; let tokens = &encoded.as_slice()[0..max_token_count];
let mut embeddings_for_prompt = Embeddings::new(self.dimensions()); let mut embeddings_for_prompt = Embeddings::new(self.dimensions());
let embedding = self.rest_embedder.embed_tokens(tokens)?; let embedding = self.rest_embedder.embed_tokens(tokens, deadline)?;
embeddings_for_prompt.append(embedding.into_inner()).map_err(|got| { embeddings_for_prompt.append(embedding.into_inner()).map_err(|got| {
EmbedError::rest_unexpected_dimension(self.dimensions(), got.len()) EmbedError::rest_unexpected_dimension(self.dimensions(), got.len())
})?; })?;
@ -248,7 +258,7 @@ impl Embedder {
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> { ) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
threads threads
.install(move || { .install(move || {
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect() text_chunks.into_par_iter().map(move |chunk| self.embed(chunk, None)).collect()
}) })
.map_err(|error| EmbedError { .map_err(|error| EmbedError {
kind: EmbedErrorKind::PanicInThreadPool(error), kind: EmbedErrorKind::PanicInThreadPool(error),

View File

@ -1,4 +1,5 @@
use std::collections::BTreeMap; use std::collections::BTreeMap;
use std::time::Instant;
use deserr::Deserr; use deserr::Deserr;
use rand::Rng; use rand::Rng;
@ -130,6 +131,7 @@ impl Embedder {
let client = ureq::AgentBuilder::new() let client = ureq::AgentBuilder::new()
.max_idle_connections(REQUEST_PARALLELISM * 2) .max_idle_connections(REQUEST_PARALLELISM * 2)
.max_idle_connections_per_host(REQUEST_PARALLELISM * 2) .max_idle_connections_per_host(REQUEST_PARALLELISM * 2)
.timeout(std::time::Duration::from_secs(30))
.build(); .build();
let request = Request::new(options.request)?; let request = Request::new(options.request)?;
@ -154,19 +156,31 @@ impl Embedder {
Ok(Self { data, dimensions, distribution: options.distribution }) Ok(Self { data, dimensions, distribution: options.distribution })
} }
pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> { pub fn embed(
embed(&self.data, texts.as_slice(), texts.len(), Some(self.dimensions)) &self,
texts: Vec<String>,
deadline: Option<Instant>,
) -> Result<Vec<Embeddings<f32>>, EmbedError> {
embed(&self.data, texts.as_slice(), texts.len(), Some(self.dimensions), deadline)
} }
pub fn embed_ref<S>(&self, texts: &[S]) -> Result<Vec<Embeddings<f32>>, EmbedError> pub fn embed_ref<S>(
&self,
texts: &[S],
deadline: Option<Instant>,
) -> Result<Vec<Embeddings<f32>>, EmbedError>
where where
S: AsRef<str> + Serialize, S: AsRef<str> + Serialize,
{ {
embed(&self.data, texts, texts.len(), Some(self.dimensions)) embed(&self.data, texts, texts.len(), Some(self.dimensions), deadline)
} }
pub fn embed_tokens(&self, tokens: &[usize]) -> Result<Embeddings<f32>, EmbedError> { pub fn embed_tokens(
let mut embeddings = embed(&self.data, tokens, 1, Some(self.dimensions))?; &self,
tokens: &[usize],
deadline: Option<Instant>,
) -> Result<Embeddings<f32>, EmbedError> {
let mut embeddings = embed(&self.data, tokens, 1, Some(self.dimensions), deadline)?;
// unwrap: guaranteed that embeddings.len() == 1, otherwise the previous line terminated in error // unwrap: guaranteed that embeddings.len() == 1, otherwise the previous line terminated in error
Ok(embeddings.pop().unwrap()) Ok(embeddings.pop().unwrap())
} }
@ -178,7 +192,7 @@ impl Embedder {
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> { ) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
threads threads
.install(move || { .install(move || {
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect() text_chunks.into_par_iter().map(move |chunk| self.embed(chunk, None)).collect()
}) })
.map_err(|error| EmbedError { .map_err(|error| EmbedError {
kind: EmbedErrorKind::PanicInThreadPool(error), kind: EmbedErrorKind::PanicInThreadPool(error),
@ -207,7 +221,7 @@ impl Embedder {
} }
fn infer_dimensions(data: &EmbedderData) -> Result<usize, NewEmbedderError> { fn infer_dimensions(data: &EmbedderData) -> Result<usize, NewEmbedderError> {
let v = embed(data, ["test"].as_slice(), 1, None) let v = embed(data, ["test"].as_slice(), 1, None, None)
.map_err(NewEmbedderError::could_not_determine_dimension)?; .map_err(NewEmbedderError::could_not_determine_dimension)?;
// unwrap: guaranteed that v.len() == 1, otherwise the previous line terminated in error // unwrap: guaranteed that v.len() == 1, otherwise the previous line terminated in error
Ok(v.first().unwrap().dimension()) Ok(v.first().unwrap().dimension())
@ -218,6 +232,7 @@ fn embed<S>(
inputs: &[S], inputs: &[S],
expected_count: usize, expected_count: usize,
expected_dimension: Option<usize>, expected_dimension: Option<usize>,
deadline: Option<Instant>,
) -> Result<Vec<Embeddings<f32>>, EmbedError> ) -> Result<Vec<Embeddings<f32>>, EmbedError>
where where
S: Serialize, S: Serialize,
@ -245,7 +260,18 @@ where
} }
Err(retry) => { Err(retry) => {
tracing::warn!("Failed: {}", retry.error); tracing::warn!("Failed: {}", retry.error);
retry.into_duration(attempt) if let Some(deadline) = deadline {
let now = std::time::Instant::now();
if now > deadline {
tracing::warn!("Could not embed due to deadline");
return Err(retry.into_error());
}
let duration_to_deadline = deadline - now;
retry.into_duration(attempt).map(|duration| duration.min(duration_to_deadline))
} else {
retry.into_duration(attempt)
}
} }
}?; }?;