diff --git a/Cargo.lock b/Cargo.lock
index afaacb43e..3f9171edc 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -494,7 +494,7 @@ checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b"
 
 [[package]]
 name = "benchmarks"
-version = "1.6.0"
+version = "1.6.1"
 dependencies = [
  "anyhow",
  "bytes",
@@ -1476,7 +1476,7 @@ dependencies = [
 
 [[package]]
 name = "dump"
-version = "1.6.0"
+version = "1.6.1"
 dependencies = [
  "anyhow",
  "big_s",
@@ -1720,7 +1720,7 @@ dependencies = [
 
 [[package]]
 name = "file-store"
-version = "1.6.0"
+version = "1.6.1"
 dependencies = [
  "faux",
  "tempfile",
@@ -1742,7 +1742,7 @@ dependencies = [
 
 [[package]]
 name = "filter-parser"
-version = "1.6.0"
+version = "1.6.1"
 dependencies = [
  "insta",
  "nom",
@@ -1773,7 +1773,7 @@ dependencies = [
 
 [[package]]
 name = "flatten-serde-json"
-version = "1.6.0"
+version = "1.6.1"
 dependencies = [
  "criterion",
  "serde_json",
@@ -1891,7 +1891,7 @@ dependencies = [
 
 [[package]]
 name = "fuzzers"
-version = "1.6.0"
+version = "1.6.1"
 dependencies = [
  "arbitrary",
  "clap",
@@ -2856,7 +2856,7 @@ checksum = "206ca75c9c03ba3d4ace2460e57b189f39f43de612c2f85836e65c929701bb2d"
 
 [[package]]
 name = "index-scheduler"
-version = "1.6.0"
+version = "1.6.1"
 dependencies = [
  "anyhow",
  "big_s",
@@ -3043,7 +3043,7 @@ dependencies = [
 
 [[package]]
 name = "json-depth-checker"
-version = "1.6.0"
+version = "1.6.1"
 dependencies = [
  "criterion",
  "serde_json",
@@ -3555,7 +3555,7 @@ checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771"
 
 [[package]]
 name = "meili-snap"
-version = "1.6.0"
+version = "1.6.1"
 dependencies = [
  "insta",
  "md5",
@@ -3564,7 +3564,7 @@ dependencies = [
 
 [[package]]
 name = "meilisearch"
-version = "1.6.0"
+version = "1.6.1"
 dependencies = [
  "actix-cors",
  "actix-http",
@@ -3655,7 +3655,7 @@ dependencies = [
 
 [[package]]
 name = "meilisearch-auth"
-version = "1.6.0"
+version = "1.6.1"
 dependencies = [
  "base64 0.21.7",
  "enum-iterator",
@@ -3674,7 +3674,7 @@ dependencies = [
 
 [[package]]
 name = "meilisearch-types"
-version = "1.6.0"
+version = "1.6.1"
 dependencies = [
  "actix-web",
  "anyhow",
@@ -3704,7 +3704,7 @@ dependencies = [
 
 [[package]]
 name = "meilitool"
-version = "1.6.0"
+version = "1.6.1"
 dependencies = [
  "anyhow",
  "clap",
@@ -3743,7 +3743,7 @@ dependencies = [
 
 [[package]]
 name = "milli"
-version = "1.6.0"
+version = "1.6.1"
 dependencies = [
  "arroy",
  "big_s",
@@ -4141,7 +4141,7 @@ checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e"
 
 [[package]]
 name = "permissive-json-pointer"
-version = "1.6.0"
+version = "1.6.1"
 dependencies = [
  "big_s",
  "serde_json",
@@ -6232,7 +6232,7 @@ dependencies = [
 
 [[package]]
 name = "xtask"
-version = "1.6.0"
+version = "1.6.1"
 dependencies = [
  "cargo_metadata",
  "clap",
diff --git a/Cargo.toml b/Cargo.toml
index bb8d7d787..a0c6c3ac9 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -20,7 +20,7 @@ members = [
 ]
 
 [workspace.package]
-version = "1.6.0"
+version = "1.6.1"
 authors = ["Quentin de Quelen <quentin@dequelen.me>", "Clément Renault <clement@meilisearch.com>"]
 description = "Meilisearch HTTP server"
 homepage = "https://meilisearch.com"
diff --git a/meilisearch/Cargo.toml b/meilisearch/Cargo.toml
index 1f85783f6..1d7f53229 100644
--- a/meilisearch/Cargo.toml
+++ b/meilisearch/Cargo.toml
@@ -154,5 +154,5 @@ greek = ["meilisearch-types/greek"]
 khmer = ["meilisearch-types/khmer"]
 
 [package.metadata.mini-dashboard]
-assets-url = "https://github.com/meilisearch/mini-dashboard/releases/download/v0.2.12/build.zip"
-sha1 = "acfe9a018c93eb0604ea87ee87bff7df5474e18e"
+assets-url = "https://github.com/meilisearch/mini-dashboard/releases/download/v0.2.13/build.zip"
"https://github.com/meilisearch/mini-dashboard/releases/download/v0.2.13/build.zip" +sha1 = "e20cc9b390003c6c844f4b8bcc5c5013191a77ff" diff --git a/meilisearch/tests/common/mod.rs b/meilisearch/tests/common/mod.rs index d7888b7db..2b9e5e1d7 100644 --- a/meilisearch/tests/common/mod.rs +++ b/meilisearch/tests/common/mod.rs @@ -64,7 +64,7 @@ impl Display for Value { write!( f, "{}", - json_string!(self, { ".enqueuedAt" => "[date]", ".processedAt" => "[date]", ".finishedAt" => "[date]", ".duration" => "[duration]" }) + json_string!(self, { ".enqueuedAt" => "[date]", ".startedAt" => "[date]", ".finishedAt" => "[date]", ".duration" => "[duration]" }) ) } } diff --git a/meilisearch/tests/documents/add_documents.rs b/meilisearch/tests/documents/add_documents.rs index b2904691f..9733f7741 100644 --- a/meilisearch/tests/documents/add_documents.rs +++ b/meilisearch/tests/documents/add_documents.rs @@ -1760,6 +1760,181 @@ async fn add_documents_invalid_geo_field() { "finishedAt": "[date]" } "###); + + // The three next tests are related to #4333 + + // _geo has a lat and lng but set to `null` + let documents = json!([ + { + "id": "12", + "_geo": { "lng": null, "lat": 67} + } + ]); + + let (response, code) = index.add_documents(documents, None).await; + snapshot!(code, @"202 Accepted"); + let response = index.wait_task(response.uid()).await; + snapshot!(json_string!(response, { ".duration" => "[duration]", ".enqueuedAt" => "[date]", ".startedAt" => "[date]", ".finishedAt" => "[date]" }), + @r###" + { + "uid": 14, + "indexUid": "test", + "status": "failed", + "type": "documentAdditionOrUpdate", + "canceledBy": null, + "details": { + "receivedDocuments": 1, + "indexedDocuments": 0 + }, + "error": { + "message": "Could not parse longitude in the document with the id: `12`. Was expecting a finite number but instead got `null`.", + "code": "invalid_document_geo_field", + "type": "invalid_request", + "link": "https://docs.meilisearch.com/errors#invalid_document_geo_field" + }, + "duration": "[duration]", + "enqueuedAt": "[date]", + "startedAt": "[date]", + "finishedAt": "[date]" + } + "###); + + // _geo has a lat and lng but set to `null` + let documents = json!([ + { + "id": "12", + "_geo": { "lng": 35, "lat": null } + } + ]); + + let (response, code) = index.add_documents(documents, None).await; + snapshot!(code, @"202 Accepted"); + let response = index.wait_task(response.uid()).await; + snapshot!(json_string!(response, { ".duration" => "[duration]", ".enqueuedAt" => "[date]", ".startedAt" => "[date]", ".finishedAt" => "[date]" }), + @r###" + { + "uid": 15, + "indexUid": "test", + "status": "failed", + "type": "documentAdditionOrUpdate", + "canceledBy": null, + "details": { + "receivedDocuments": 1, + "indexedDocuments": 0 + }, + "error": { + "message": "Could not parse latitude in the document with the id: `12`. 
+        "code": "invalid_document_geo_field",
+        "type": "invalid_request",
+        "link": "https://docs.meilisearch.com/errors#invalid_document_geo_field"
+      },
+      "duration": "[duration]",
+      "enqueuedAt": "[date]",
+      "startedAt": "[date]",
+      "finishedAt": "[date]"
+    }
+    "###);
+
+    // _geo has both lat and lng set to `null`
+    let documents = json!([
+        {
+            "id": "13",
+            "_geo": { "lng": null, "lat": null }
+        }
+    ]);
+
+    let (response, code) = index.add_documents(documents, None).await;
+    snapshot!(code, @"202 Accepted");
+    let response = index.wait_task(response.uid()).await;
+    snapshot!(json_string!(response, { ".duration" => "[duration]", ".enqueuedAt" => "[date]", ".startedAt" => "[date]", ".finishedAt" => "[date]" }),
+        @r###"
+    {
+      "uid": 16,
+      "indexUid": "test",
+      "status": "failed",
+      "type": "documentAdditionOrUpdate",
+      "canceledBy": null,
+      "details": {
+        "receivedDocuments": 1,
+        "indexedDocuments": 0
+      },
+      "error": {
+        "message": "Could not parse latitude nor longitude in the document with the id: `13`. Was expecting finite numbers but instead got `null` and `null`.",
+        "code": "invalid_document_geo_field",
+        "type": "invalid_request",
+        "link": "https://docs.meilisearch.com/errors#invalid_document_geo_field"
+      },
+      "duration": "[duration]",
+      "enqueuedAt": "[date]",
+      "startedAt": "[date]",
+      "finishedAt": "[date]"
+    }
+    "###);
+}
+
+// Related to #4333
+#[actix_rt::test]
+async fn add_invalid_geo_and_then_settings() {
+    let server = Server::new().await;
+    let index = server.index("test");
+    index.create(Some("id")).await;
+
+    // _geo is an object, but both coordinates are set to `null`
+    let documents = json!([
+        {
+            "id": "11",
+            "_geo": { "lat": null, "lng": null },
+        }
+    ]);
+    let (ret, code) = index.add_documents(documents, None).await;
+    snapshot!(code, @"202 Accepted");
+    let ret = index.wait_task(ret.uid()).await;
+    snapshot!(ret, @r###"
+    {
+      "uid": 1,
+      "indexUid": "test",
+      "status": "succeeded",
+      "type": "documentAdditionOrUpdate",
+      "canceledBy": null,
+      "details": {
+        "receivedDocuments": 1,
+        "indexedDocuments": 1
+      },
+      "error": null,
+      "duration": "[duration]",
+      "enqueuedAt": "[date]",
+      "startedAt": "[date]",
+      "finishedAt": "[date]"
+    }
+    "###);
+
+    let (ret, code) = index.update_settings(json!({"sortableAttributes": ["_geo"]})).await;
+    snapshot!(code, @"202 Accepted");
+    let ret = index.wait_task(ret.uid()).await;
+    snapshot!(ret, @r###"
+    {
+      "uid": 2,
+      "indexUid": "test",
+      "status": "failed",
+      "type": "settingsUpdate",
+      "canceledBy": null,
+      "details": {
+        "sortableAttributes": [
+          "_geo"
+        ]
+      },
+      "error": {
+        "message": "Could not parse latitude in the document with the id: `\"11\"`. Was expecting a finite number but instead got `null`.",
+        "code": "invalid_document_geo_field",
+        "type": "invalid_request",
+        "link": "https://docs.meilisearch.com/errors#invalid_document_geo_field"
+      },
+      "duration": "[duration]",
+      "enqueuedAt": "[date]",
+      "startedAt": "[date]",
+      "finishedAt": "[date]"
+    }
+    "###);
 }
 
 #[actix_rt::test]
diff --git a/meilisearch/tests/search/hybrid.rs b/meilisearch/tests/search/hybrid.rs
index 79819cab2..6ea9920f6 100644
--- a/meilisearch/tests/search/hybrid.rs
+++ b/meilisearch/tests/search/hybrid.rs
@@ -87,6 +87,52 @@ async fn simple_search() {
     snapshot!(response["hits"], @r###"[{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]},"_semanticScore":0.99029034},{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_semanticScore":0.97434163},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_semanticScore":0.9472136}]"###);
 }
 
+#[actix_rt::test]
+async fn highlighter() {
+    let server = Server::new().await;
+    let index = index_with_documents(&server, &SIMPLE_SEARCH_DOCUMENTS).await;
+
+    let (response, code) = index
+        .search_post(json!({"q": "Captain Marvel", "vector": [1.0, 1.0],
+        "hybrid": {"semanticRatio": 0.2},
+        "attributesToHighlight": [
+            "desc"
+        ],
+        "highlightPreTag": "**BEGIN**",
+        "highlightPostTag": "**END**"
+        }))
+        .await;
+    snapshot!(code, @"200 OK");
+    snapshot!(response["hits"], @r###"[{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]},"_formatted":{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":["2.0","3.0"]}}},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_formatted":{"title":"Shazam!","desc":"a **BEGIN**Captain**END** **BEGIN**Marvel**END** ersatz","id":"1","_vectors":{"default":["1.0","3.0"]}}},{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_formatted":{"title":"Captain Planet","desc":"He's not part of the **BEGIN**Marvel**END** Cinematic Universe","id":"2","_vectors":{"default":["1.0","2.0"]}}}]"###);
+
+    let (response, code) = index
+        .search_post(json!({"q": "Captain Marvel", "vector": [1.0, 1.0],
+        "hybrid": {"semanticRatio": 0.8},
+        "attributesToHighlight": [
+            "desc"
+        ],
+        "highlightPreTag": "**BEGIN**",
+        "highlightPostTag": "**END**"
+        }))
+        .await;
+    snapshot!(code, @"200 OK");
+    snapshot!(response["hits"], @r###"[{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]},"_formatted":{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":["2.0","3.0"]}},"_semanticScore":0.99029034},{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_formatted":{"title":"Captain Planet","desc":"He's not part of the **BEGIN**Marvel**END** Cinematic Universe","id":"2","_vectors":{"default":["1.0","2.0"]}},"_semanticScore":0.97434163},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_formatted":{"title":"Shazam!","desc":"a **BEGIN**Captain**END** **BEGIN**Marvel**END** ersatz","id":"1","_vectors":{"default":["1.0","3.0"]}},"_semanticScore":0.9472136}]"###);
+
+    // no highlighting on full semantic
+    let (response, code) = index
+        .search_post(json!({"q": "Captain Marvel", "vector": [1.0, 1.0],
+        "hybrid": {"semanticRatio": 1.0},
{"semanticRatio": 1.0}, + "attributesToHighlight": [ + "desc" + ], + "highlightPreTag": "**BEGIN**", + "highlightPostTag": "**END**" + })) + .await; + snapshot!(code, @"200 OK"); + snapshot!(response["hits"], @r###"[{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]},"_formatted":{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":["2.0","3.0"]}},"_semanticScore":0.99029034},{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_formatted":{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":["1.0","2.0"]}},"_semanticScore":0.97434163},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_formatted":{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":["1.0","3.0"]}}}]"###); +} + #[actix_rt::test] async fn invalid_semantic_ratio() { let server = Server::new().await; diff --git a/milli/src/search/hybrid.rs b/milli/src/search/hybrid.rs index 67365cf52..b4c79f7f5 100644 --- a/milli/src/search/hybrid.rs +++ b/milli/src/search/hybrid.rs @@ -102,7 +102,7 @@ impl ScoreWithRatioResult { } SearchResult { - matching_words: left.matching_words, + matching_words: right.matching_words, candidates: left.candidates | right.candidates, documents_ids, document_scores, diff --git a/milli/src/update/index_documents/extract/extract_geo_points.rs b/milli/src/update/index_documents/extract/extract_geo_points.rs index 5ee7967d2..b3600e3bc 100644 --- a/milli/src/update/index_documents/extract/extract_geo_points.rs +++ b/milli/src/update/index_documents/extract/extract_geo_points.rs @@ -34,7 +34,9 @@ pub fn extract_geo_points( // since we only need the primary key when we throw an error // we create this getter to lazily get it when needed let document_id = || -> Value { - let document_id = obkv.get(primary_key_id).unwrap(); + let reader = KvReaderDelAdd::new(obkv.get(primary_key_id).unwrap()); + let document_id = + reader.get(DelAdd::Deletion).or(reader.get(DelAdd::Addition)).unwrap(); serde_json::from_slice(document_id).unwrap() }; diff --git a/milli/src/update/index_documents/extract/extract_vector_points.rs b/milli/src/update/index_documents/extract/extract_vector_points.rs index cdf0b37f0..87181edc2 100644 --- a/milli/src/update/index_documents/extract/extract_vector_points.rs +++ b/milli/src/update/index_documents/extract/extract_vector_points.rs @@ -339,9 +339,7 @@ pub fn extract_embeddings( indexer: GrenadParameters, embedder: Arc, ) -> Result>> { - let rt = tokio::runtime::Builder::new_current_thread().enable_io().enable_time().build()?; - - let n_chunks = embedder.chunk_count_hint(); // chunk level parellelism + let n_chunks = embedder.chunk_count_hint(); // chunk level parallelism let n_vectors_per_chunk = embedder.prompt_count_in_chunk_hint(); // number of vectors in a single chunk // docid, state with embedding @@ -375,11 +373,8 @@ pub fn extract_embeddings( current_chunk_ids.push(docid); if chunks.len() == chunks.capacity() { - let chunked_embeds = rt - .block_on( - embedder - .embed_chunks(std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks))), - ) + let chunked_embeds = embedder + .embed_chunks(std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks))) .map_err(crate::vector::Error::from) .map_err(crate::Error::from)?; @@ -396,8 +391,8 @@ pub fn extract_embeddings( // send last chunk if !chunks.is_empty() { - let 
-        let chunked_embeds = rt
-            .block_on(embedder.embed_chunks(std::mem::take(&mut chunks)))
+        let chunked_embeds = embedder
+            .embed_chunks(std::mem::take(&mut chunks))
             .map_err(crate::vector::Error::from)
             .map_err(crate::Error::from)?;
         for (docid, embeddings) in chunks_ids
@@ -410,13 +405,15 @@
     }
 
     if !current_chunk.is_empty() {
-        let embeds = rt
-            .block_on(embedder.embed(std::mem::take(&mut current_chunk)))
+        let embeds = embedder
+            .embed_chunks(vec![std::mem::take(&mut current_chunk)])
             .map_err(crate::vector::Error::from)
             .map_err(crate::Error::from)?;
 
-        for (docid, embeddings) in current_chunk_ids.iter().zip(embeds.iter()) {
-            state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?;
+        if let Some(embeds) = embeds.first() {
+            for (docid, embeddings) in current_chunk_ids.iter().zip(embeds.iter()) {
+                state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?;
+            }
         }
     }
 
diff --git a/milli/src/vector/error.rs b/milli/src/vector/error.rs
index c5cce622d..3673c85e3 100644
--- a/milli/src/vector/error.rs
+++ b/milli/src/vector/error.rs
@@ -67,6 +67,10 @@ pub enum EmbedErrorKind {
     OpenAiUnhandledStatusCode(u16),
     #[error("attempt to embed the following text in a configuration where embeddings must be user provided: {0:?}")]
     ManualEmbed(String),
+    #[error("could not initialize asynchronous runtime: {0}")]
+    OpenAiRuntimeInit(std::io::Error),
+    #[error("initializing web client for sending embedding requests failed: {0}")]
+    InitWebClient(reqwest::Error),
 }
 
 impl EmbedError {
@@ -117,6 +121,14 @@ impl EmbedError {
     pub(crate) fn embed_on_manual_embedder(texts: String) -> EmbedError {
         Self { kind: EmbedErrorKind::ManualEmbed(texts), fault: FaultSource::User }
     }
+
+    pub(crate) fn openai_runtime_init(inner: std::io::Error) -> EmbedError {
+        Self { kind: EmbedErrorKind::OpenAiRuntimeInit(inner), fault: FaultSource::Runtime }
+    }
+
+    pub fn openai_initialize_web_client(inner: reqwest::Error) -> Self {
+        Self { kind: EmbedErrorKind::InitWebClient(inner), fault: FaultSource::Runtime }
+    }
 }
 
 #[derive(Debug, thiserror::Error)]
@@ -183,10 +195,6 @@ impl NewEmbedderError {
         }
     }
 
-    pub fn openai_initialize_web_client(inner: reqwest::Error) -> Self {
-        Self { kind: NewEmbedderErrorKind::InitWebClient(inner), fault: FaultSource::Runtime }
-    }
-
     pub fn openai_invalid_api_key_format(inner: reqwest::header::InvalidHeaderValue) -> Self {
         Self { kind: NewEmbedderErrorKind::InvalidApiKeyFormat(inner), fault: FaultSource::User }
     }
@@ -237,8 +245,6 @@ pub enum NewEmbedderErrorKind {
     #[error("loading model failed: {0}")]
     LoadModel(candle_core::Error),
     // openai
-    #[error("initializing web client for sending embedding requests failed: {0}")]
-    InitWebClient(reqwest::Error),
     #[error("The API key passed to Authorization error was in an invalid format: {0}")]
     InvalidApiKeyFormat(reqwest::header::InvalidHeaderValue),
 }
diff --git a/milli/src/vector/hf.rs b/milli/src/vector/hf.rs
index 7acb09aa8..cdfdbfb75 100644
--- a/milli/src/vector/hf.rs
+++ b/milli/src/vector/hf.rs
@@ -151,7 +151,8 @@ impl Embedder {
         let token_ids = tokens
             .iter()
             .map(|tokens| {
-                let tokens = tokens.get_ids().to_vec();
+                let mut tokens = tokens.get_ids().to_vec();
+                tokens.truncate(512);
                 Tensor::new(tokens.as_slice(), &self.model.device).map_err(EmbedError::tensor_shape)
             })
             .collect::<Result<Vec<_>, EmbedError>>()?;
diff --git a/milli/src/vector/mod.rs b/milli/src/vector/mod.rs
index 81c4cf4a1..99b7bff7e 100644
--- a/milli/src/vector/mod.rs
+++ b/milli/src/vector/mod.rs
@@ -163,18 +163,24 @@
     ) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> {
         match self {
             Embedder::HuggingFace(embedder) => embedder.embed(texts),
-            Embedder::OpenAi(embedder) => embedder.embed(texts).await,
+            Embedder::OpenAi(embedder) => {
+                let client = embedder.new_client()?;
+                embedder.embed(texts, &client).await
+            }
             Embedder::UserProvided(embedder) => embedder.embed(texts),
         }
     }
 
-    pub async fn embed_chunks(
+    /// # Panics
+    ///
+    /// - if called from an asynchronous context
+    pub fn embed_chunks(
         &self,
         text_chunks: Vec<Vec<String>>,
     ) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
         match self {
             Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks),
-            Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks).await,
+            Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks),
             Embedder::UserProvided(embedder) => embedder.embed_chunks(text_chunks),
         }
     }
diff --git a/milli/src/vector/openai.rs b/milli/src/vector/openai.rs
index 53e8a041b..524f83b80 100644
--- a/milli/src/vector/openai.rs
+++ b/milli/src/vector/openai.rs
@@ -8,7 +8,7 @@ use super::{DistributionShift, Embedding, Embeddings};
 
 #[derive(Debug)]
 pub struct Embedder {
-    client: reqwest::Client,
+    headers: reqwest::header::HeaderMap,
     tokenizer: tiktoken_rs::CoreBPE,
     options: EmbedderOptions,
 }
@@ -95,6 +95,13 @@ impl EmbedderOptions {
 }
 
 impl Embedder {
+    pub fn new_client(&self) -> Result<reqwest::Client, EmbedError> {
+        reqwest::ClientBuilder::new()
+            .default_headers(self.headers.clone())
+            .build()
+            .map_err(EmbedError::openai_initialize_web_client)
+    }
+
     pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> {
         let mut headers = reqwest::header::HeaderMap::new();
         let mut inferred_api_key = Default::default();
@@ -111,25 +118,25 @@
             reqwest::header::CONTENT_TYPE,
             reqwest::header::HeaderValue::from_static("application/json"),
         );
-        let client = reqwest::ClientBuilder::new()
-            .default_headers(headers)
-            .build()
-            .map_err(NewEmbedderError::openai_initialize_web_client)?;
 
         // looking at the code it is very unclear that this can actually fail.
         let tokenizer = tiktoken_rs::cl100k_base().unwrap();
 
-        Ok(Self { options, client, tokenizer })
+        Ok(Self { options, headers, tokenizer })
     }
 
-    pub async fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
+    pub async fn embed(
+        &self,
+        texts: Vec<String>,
+        client: &reqwest::Client,
+    ) -> Result<Vec<Embeddings<f32>>, EmbedError> {
         let mut tokenized = false;
 
         for attempt in 0..7 {
             let result = if tokenized {
-                self.try_embed_tokenized(&texts).await
+                self.try_embed_tokenized(&texts, client).await
             } else {
-                self.try_embed(&texts).await
+                self.try_embed(&texts, client).await
             };
 
             let retry_duration = match result {
@@ -145,9 +152,9 @@
         }
 
         let result = if tokenized {
-            self.try_embed_tokenized(&texts).await
+            self.try_embed_tokenized(&texts, client).await
         } else {
-            self.try_embed(&texts).await
+            self.try_embed(&texts, client).await
         };
 
         result.map_err(Retry::into_error)
@@ -225,13 +232,13 @@
     async fn try_embed<S: AsRef<str> + serde::Serialize>(
         &self,
         texts: &[S],
+        client: &reqwest::Client,
     ) -> Result<Vec<Embeddings<f32>>, Retry> {
         for text in texts {
             log::trace!("Received prompt: {}", text.as_ref())
         }
         let request = OpenAiRequest { model: self.options.embedding_model.name(), input: texts };
-        let response = self
-            .client
+        let response = client
             .post(OPENAI_EMBEDDINGS_URL)
             .json(&request)
             .send()
@@ -256,7 +263,11 @@
             .collect())
     }
 
-    async fn try_embed_tokenized(&self, text: &[String]) -> Result<Vec<Embeddings<f32>>, Retry> {
+    async fn try_embed_tokenized(
+        &self,
+        text: &[String],
+        client: &reqwest::Client,
+    ) -> Result<Vec<Embeddings<f32>>, Retry> {
         pub const OVERLAP_SIZE: usize = 200;
         let mut all_embeddings = Vec::with_capacity(text.len());
         for text in text {
@@ -264,7 +275,7 @@
             let encoded = self.tokenizer.encode_ordinary(text.as_str());
             let len = encoded.len();
             if len < max_token_count {
-                all_embeddings.append(&mut self.try_embed(&[text]).await?);
+                all_embeddings.append(&mut self.try_embed(&[text], client).await?);
                 continue;
             }
 
@@ -273,22 +284,26 @@
                 Embeddings::new(self.options.embedding_model.dimensions());
             while tokens.len() > max_token_count {
                 let window = &tokens[..max_token_count];
-                embeddings_for_prompt.push(self.embed_tokens(window).await?).unwrap();
+                embeddings_for_prompt.push(self.embed_tokens(window, client).await?).unwrap();
 
                 tokens = &tokens[max_token_count - OVERLAP_SIZE..];
             }
 
             // end of text
-            embeddings_for_prompt.push(self.embed_tokens(tokens).await?).unwrap();
+            embeddings_for_prompt.push(self.embed_tokens(tokens, client).await?).unwrap();
 
             all_embeddings.push(embeddings_for_prompt);
         }
         Ok(all_embeddings)
     }
 
-    async fn embed_tokens(&self, tokens: &[usize]) -> Result<Embedding, Retry> {
+    async fn embed_tokens(
+        &self,
+        tokens: &[usize],
+        client: &reqwest::Client,
+    ) -> Result<Embedding, Retry> {
         for attempt in 0..9 {
-            let duration = match self.try_embed_tokens(tokens).await {
+            let duration = match self.try_embed_tokens(tokens, client).await {
                 Ok(embedding) => return Ok(embedding),
                 Err(retry) => retry.into_duration(attempt),
             }
@@ -297,14 +312,19 @@
             tokio::time::sleep(duration).await;
         }
 
-        self.try_embed_tokens(tokens).await.map_err(|retry| Retry::give_up(retry.into_error()))
+        self.try_embed_tokens(tokens, client)
+            .await
+            .map_err(|retry| Retry::give_up(retry.into_error()))
     }
 
-    async fn try_embed_tokens(&self, tokens: &[usize]) -> Result<Embedding, Retry> {
+    async fn try_embed_tokens(
+        &self,
+        tokens: &[usize],
+        client: &reqwest::Client,
+    ) -> Result<Embedding, Retry> {
         let request =
             OpenAiTokensRequest { model: self.options.embedding_model.name(), input: tokens };
-        let response = self
-            .client
+        let response = client
             .post(OPENAI_EMBEDDINGS_URL)
             .json(&request)
             .send()
@@ -322,12 +342,19 @@
         Ok(response.data.pop().map(|data| data.embedding).unwrap_or_default())
     }
 
-    pub async fn embed_chunks(
+    pub fn embed_chunks(
         &self,
         text_chunks: Vec<Vec<String>>,
     ) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
-        futures::future::try_join_all(text_chunks.into_iter().map(|prompts| self.embed(prompts)))
-            .await
+        let rt = tokio::runtime::Builder::new_current_thread()
+            .enable_io()
+            .enable_time()
+            .build()
+            .map_err(EmbedError::openai_runtime_init)?;
+        let client = self.new_client()?;
+        rt.block_on(futures::future::try_join_all(
+            text_chunks.into_iter().map(|prompts| self.embed(prompts, &client)),
+        ))
     }
 
     pub fn chunk_count_hint(&self) -> usize {