mirror of
https://github.com/meilisearch/meilisearch.git
synced 2024-11-22 18:17:39 +08:00
Add tokenized test
This commit is contained in:
parent
9d6efd92d2
commit
ab1ec9ca21
@ -1,6 +1,7 @@
|
||||
use std::collections::BTreeMap;
|
||||
use std::io::Write;
|
||||
use std::sync::atomic::{AtomicU32, Ordering};
|
||||
use std::sync::OnceLock;
|
||||
|
||||
use meili_snap::{json_string, snapshot};
|
||||
use wiremock::matchers::{method, path};
|
||||
@ -21,6 +22,12 @@ struct OpenAiResponse {
|
||||
large_512: Option<Vec<f32>>,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct OpenAiTokenizedResponses {
|
||||
tokens: Vec<u64>,
|
||||
embedding: Vec<f32>,
|
||||
}
|
||||
|
||||
impl OpenAiResponses {
|
||||
fn get(&self, text: &str, model_dimensions: ModelDimensions) -> Option<&[f32]> {
|
||||
let entry = self.0.get(text)?;
|
||||
@ -81,7 +88,7 @@ impl ModelDimensions {
|
||||
}
|
||||
|
||||
fn openai_responses() -> &'static OpenAiResponses {
|
||||
static OPENAI_RESPONSES: std::sync::OnceLock<OpenAiResponses> = std::sync::OnceLock::new();
|
||||
static OPENAI_RESPONSES: OnceLock<OpenAiResponses> = OnceLock::new();
|
||||
OPENAI_RESPONSES.get_or_init(|| {
|
||||
// json file that was compressed with gzip
|
||||
// decompress with `gzip --keep -d openai_responses.json.gz`
|
||||
@ -96,6 +103,43 @@ fn openai_responses() -> &'static OpenAiResponses {
|
||||
})
|
||||
}
|
||||
|
||||
fn openai_tokenized_responses() -> &'static OpenAiTokenizedResponses {
|
||||
static OPENAI_TOKENIZED_RESPONSES: OnceLock<OpenAiTokenizedResponses> = OnceLock::new();
|
||||
OPENAI_TOKENIZED_RESPONSES.get_or_init(|| {
|
||||
// json file that was compressed with gzip
|
||||
// decompress with `gzip --keep -d openai_tokenized_responses.json.gz`
|
||||
// recompress with `gzip --keep -c openai_tokenized_responses.json > openai_tokenized_responses.json.gz`
|
||||
let compressed_responses = include_bytes!("openai_tokenized_responses.json.gz");
|
||||
let mut responses = Vec::new();
|
||||
let mut decoder = flate2::write::GzDecoder::new(&mut responses);
|
||||
|
||||
decoder.write_all(compressed_responses).unwrap();
|
||||
drop(decoder);
|
||||
serde_json::from_slice(&responses).unwrap()
|
||||
})
|
||||
}
|
||||
|
||||
fn long_text() -> &'static str {
|
||||
static LONG_TEXT: OnceLock<String> = OnceLock::new();
|
||||
LONG_TEXT.get_or_init(|| {
|
||||
// decompress with `gzip --keep -d intel_gen.txt.gz`
|
||||
// recompress with `gzip --keep -c intel_gen.txt > intel_gen.txt.gz`
|
||||
let compressed_long_text = include_bytes!("intel_gen.txt.gz");
|
||||
let mut long_text = Vec::new();
|
||||
let mut decoder = flate2::write::GzDecoder::new(&mut long_text);
|
||||
|
||||
decoder.write_all(compressed_long_text).unwrap();
|
||||
drop(decoder);
|
||||
let long_text = std::str::from_utf8(&long_text).unwrap();
|
||||
|
||||
long_text.repeat(3)
|
||||
})
|
||||
}
|
||||
|
||||
async fn create_mock_tokenized() -> (MockServer, Value) {
|
||||
create_mock_with_template("{{doc.text}}", ModelDimensions::Large, false).await
|
||||
}
|
||||
|
||||
async fn create_mock_with_template(
|
||||
document_template: &str,
|
||||
model_dimensions: ModelDimensions,
|
||||
@ -176,28 +220,19 @@ async fn create_mock_with_template(
|
||||
};
|
||||
let query_model_dimensions = ModelDimensions::from_request(&query);
|
||||
if query_model_dimensions != model_dimensions {
|
||||
return ResponseTemplate::new(400).set_body_json(json!({
|
||||
"error": {
|
||||
"message": format!("Expected {model_dimensions:?}, got {query_model_dimensions:?}"),
|
||||
"type": "invalid_model_dimensions",
|
||||
"query": query,
|
||||
}
|
||||
}))
|
||||
panic!("Expected {model_dimensions:?}, got {query_model_dimensions:?}")
|
||||
}
|
||||
|
||||
// 3. for each text, find embedding in responses
|
||||
let serde_json::Value::Array(inputs) = &query["input"] else {
|
||||
return ResponseTemplate::new(400).set_body_json(json!({
|
||||
"error": {
|
||||
"message": "Unexpected `input` value",
|
||||
"type": "test_response",
|
||||
"query": query
|
||||
}
|
||||
}))
|
||||
panic!("Unexpected `input` value")
|
||||
};
|
||||
|
||||
let openai_tokenized_responses = openai_tokenized_responses();
|
||||
let embeddings = if inputs == openai_tokenized_responses.tokens.as_slice() {
|
||||
vec![openai_tokenized_responses.embedding.clone()]
|
||||
} else {
|
||||
let mut embeddings = Vec::new();
|
||||
|
||||
for input in inputs {
|
||||
let serde_json::Value::String(input) = input else {
|
||||
return ResponseTemplate::new(400).set_body_json(json!({
|
||||
@ -209,8 +244,21 @@ async fn create_mock_with_template(
|
||||
}))
|
||||
};
|
||||
|
||||
let Some(embedding) = openai_responses().get(input, model_dimensions) else {
|
||||
if input == long_text() {
|
||||
return ResponseTemplate::new(400).set_body_json(json!(
|
||||
{
|
||||
"error": {
|
||||
"message": "This model's maximum context length is 8192 tokens, however you requested 10554 tokens (10554 in your prompt; 0 for the completion). Please reduce your prompt; or completion length.",
|
||||
"type": "invalid_request_error",
|
||||
"param": null,
|
||||
"code": null,
|
||||
}
|
||||
}
|
||||
));
|
||||
}
|
||||
|
||||
let Some(embedding) = openai_responses().get(input, model_dimensions) else {
|
||||
return ResponseTemplate::new(404).set_body_json(json!(
|
||||
{
|
||||
"error": {
|
||||
"message": "Could not find embedding for text",
|
||||
@ -225,6 +273,9 @@ async fn create_mock_with_template(
|
||||
|
||||
embeddings.push(embedding.to_vec());
|
||||
}
|
||||
embeddings
|
||||
};
|
||||
|
||||
|
||||
let data : Vec<_> = embeddings.into_iter().enumerate().map(|(index, embedding)| json!({
|
||||
"object": "embedding",
|
||||
@ -517,6 +568,67 @@ async fn it_works() {
|
||||
|
||||
// tokenize long text
|
||||
|
||||
// basic test "it works"
|
||||
#[actix_rt::test]
|
||||
async fn tokenize_long_text() {
|
||||
let (_mock, setting) = create_mock_tokenized().await;
|
||||
let server = get_server_vector().await;
|
||||
let index = server.index("doggo");
|
||||
|
||||
let (response, code) = index
|
||||
.update_settings(json!({
|
||||
"embedders": {
|
||||
"default": setting,
|
||||
},
|
||||
}))
|
||||
.await;
|
||||
snapshot!(code, @"202 Accepted");
|
||||
let task = server.wait_task(response.uid()).await;
|
||||
snapshot!(task["status"], @r###""succeeded""###);
|
||||
let documents = json!([
|
||||
{"id": 0, "text": long_text()}
|
||||
]);
|
||||
let (value, code) = index.add_documents(documents, None).await;
|
||||
snapshot!(code, @"202 Accepted");
|
||||
let task = index.wait_task(value.uid()).await;
|
||||
snapshot!(task, @r###"
|
||||
{
|
||||
"uid": 1,
|
||||
"indexUid": "doggo",
|
||||
"status": "succeeded",
|
||||
"type": "documentAdditionOrUpdate",
|
||||
"canceledBy": null,
|
||||
"details": {
|
||||
"receivedDocuments": 1,
|
||||
"indexedDocuments": 1
|
||||
},
|
||||
"error": null,
|
||||
"duration": "[duration]",
|
||||
"enqueuedAt": "[date]",
|
||||
"startedAt": "[date]",
|
||||
"finishedAt": "[date]"
|
||||
}
|
||||
"###);
|
||||
|
||||
let (response, code) = index
|
||||
.search_post(json!({
|
||||
"q": "grand chien de berger des montagnes",
|
||||
"showRankingScore": true,
|
||||
"attributesToRetrieve": ["id"],
|
||||
"hybrid": {"semanticRatio": 1.0}
|
||||
}))
|
||||
.await;
|
||||
snapshot!(code, @"200 OK");
|
||||
snapshot!(json_string!(response["hits"]), @r###"
|
||||
[
|
||||
{
|
||||
"id": 0,
|
||||
"_rankingScore": 0.07944583892822266
|
||||
}
|
||||
]
|
||||
"###);
|
||||
}
|
||||
|
||||
// "wrong parameters"
|
||||
|
||||
#[actix_rt::test]
|
||||
|
Loading…
Reference in New Issue
Block a user