From 9cef8ec087107c44394877ca806e6868c9797660 Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Wed, 10 Apr 2024 09:43:33 +0200 Subject: [PATCH] add prompt and context --- meilisearch-types/src/error.rs | 4 ++ meilisearch/src/error.rs | 6 ++ meilisearch/src/search.rs | 58 ++++++++++++----- milli/src/prompt/document.rs | 16 ++++- milli/src/prompt/mod.rs | 5 +- milli/src/prompt/recommend.rs | 112 +++++++++++++++++++++++++++++++++ milli/src/search/recommend.rs | 102 ++++++++++++++++++++++++++++-- 7 files changed, 279 insertions(+), 24 deletions(-) create mode 100644 milli/src/prompt/recommend.rs diff --git a/meilisearch-types/src/error.rs b/meilisearch-types/src/error.rs index 6777a7ebe..107a21ec5 100644 --- a/meilisearch-types/src/error.rs +++ b/meilisearch-types/src/error.rs @@ -245,7 +245,9 @@ InvalidSearchCropMarker , InvalidRequest , BAD_REQUEST ; InvalidSearchFacets , InvalidRequest , BAD_REQUEST ; InvalidSearchSemanticRatio , InvalidRequest , BAD_REQUEST ; InvalidFacetSearchFacetName , InvalidRequest , BAD_REQUEST ; +InvalidRecommendContext , InvalidRequest , BAD_REQUEST ; InvalidRecommendId , InvalidRequest , BAD_REQUEST ; +InvalidRecommendPrompt , InvalidRequest , BAD_REQUEST ; InvalidSearchFilter , InvalidRequest , BAD_REQUEST ; InvalidSearchHighlightPostTag , InvalidRequest , BAD_REQUEST ; InvalidSearchHighlightPreTag , InvalidRequest , BAD_REQUEST ; @@ -309,6 +311,8 @@ MissingFacetSearchFacetName , InvalidRequest , BAD_REQUEST ; MissingIndexUid , InvalidRequest , BAD_REQUEST ; MissingMasterKey , Auth , UNAUTHORIZED ; MissingPayload , InvalidRequest , BAD_REQUEST ; +MissingPrompt , InvalidRequest , BAD_REQUEST ; +MissingPromptOrId , InvalidRequest , BAD_REQUEST ; MissingSearchHybrid , InvalidRequest , BAD_REQUEST ; MissingSwapIndexes , InvalidRequest , BAD_REQUEST ; MissingTaskFilters , InvalidRequest , BAD_REQUEST ; diff --git a/meilisearch/src/error.rs b/meilisearch/src/error.rs index 13e460c24..2d23df4f1 100644 --- a/meilisearch/src/error.rs +++ b/meilisearch/src/error.rs @@ -61,6 +61,10 @@ pub enum MeilisearchHttpError { Join(#[from] JoinError), #[error("Invalid request: missing `hybrid` parameter when both `q` and `vector` are present.")] MissingSearchHybrid, + #[error("Invalid request: `prompt` parameter is required when `context` is present.")] + RecommendMissingPrompt, + #[error("Invalid request: one of the `prompt` or `id` parameters is required.")] + RecommendMissingPromptOrId, } impl ErrorCode for MeilisearchHttpError { @@ -89,6 +93,8 @@ impl ErrorCode for MeilisearchHttpError { MeilisearchHttpError::DocumentFormat(e) => e.error_code(), MeilisearchHttpError::Join(_) => Code::Internal, MeilisearchHttpError::MissingSearchHybrid => Code::MissingSearchHybrid, + MeilisearchHttpError::RecommendMissingPrompt => Code::MissingPrompt, + MeilisearchHttpError::RecommendMissingPromptOrId => Code::MissingPromptOrId, } } } diff --git a/meilisearch/src/search.rs b/meilisearch/src/search.rs index 3fbc20757..94fe2c4f7 100644 --- a/meilisearch/src/search.rs +++ b/meilisearch/src/search.rs @@ -316,7 +316,7 @@ impl SearchQueryWithIndex { #[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)] pub struct RecommendQuery { #[deserr(default, error = DeserrJsonError)] - pub id: String, + pub id: Option, #[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError)] pub offset: usize, #[deserr(default = DEFAULT_SEARCH_LIMIT(), error = DeserrJsonError)] @@ -331,6 +331,11 @@ pub struct RecommendQuery { pub show_ranking_score: bool, #[deserr(default, error = DeserrJsonError, default)] pub show_ranking_score_details: bool, + + #[deserr(default, error = DeserrJsonError)] + pub prompt: Option, + #[deserr(default, error = DeserrJsonError)] + pub context: Option, } #[derive(Debug, Copy, Clone, PartialEq, Eq, Deserr)] @@ -418,7 +423,8 @@ pub struct SearchResult { #[serde(rename_all = "camelCase")] pub struct RecommendResult { pub hits: Vec, - pub id: String, + pub id: Option, + pub prompt: Option, pub processing_time_ms: u128, #[serde(flatten)] pub hits_info: HitsInfo, @@ -836,20 +842,41 @@ pub fn perform_recommend( let before_search = Instant::now(); let rtxn = index.read_txn()?; - let internal_id = index - .external_documents_ids() - .get(&rtxn, &query.id)? - .ok_or_else(|| MeilisearchHttpError::DocumentNotFound(query.id.clone()))?; + let internal_id = query + .id + .as_deref() + .map(|id| -> Result<_, MeilisearchHttpError> { + Ok(index + .external_documents_ids() + .get(&rtxn, id)? + .ok_or_else(|| MeilisearchHttpError::DocumentNotFound(id.to_owned()))?) + }) + .transpose()?; - let mut recommend = milli::Recommend::new( - internal_id, - query.offset, - query.limit, - index, - &rtxn, - embedder_name, - embedder, - ); + let mut recommend = match (query.prompt.as_deref(), internal_id, query.context) { + (None, Some(internal_id), None) => milli::Recommend::with_docid( + internal_id, + query.offset, + query.limit, + index, + &rtxn, + embedder_name, + embedder, + ), + (Some(prompt), internal_id, context) => milli::Recommend::with_prompt( + prompt, + internal_id, + context, + query.offset, + query.limit, + index, + &rtxn, + embedder_name, + embedder, + ), + (None, _, Some(_)) => return Err(MeilisearchHttpError::RecommendMissingPrompt.into()), + (None, None, None) => return Err(MeilisearchHttpError::RecommendMissingPromptOrId.into()), + }; if let Some(ref filter) = query.filter { if let Some(facets) = parse_filter(filter)? { @@ -947,6 +974,7 @@ pub fn perform_recommend( hits: documents, hits_info, id: query.id, + prompt: query.prompt, processing_time_ms: before_search.elapsed().as_millis(), }; Ok(result) diff --git a/milli/src/prompt/document.rs b/milli/src/prompt/document.rs index b5d43b5be..ae5866a33 100644 --- a/milli/src/prompt/document.rs +++ b/milli/src/prompt/document.rs @@ -29,7 +29,7 @@ impl ParsedValue { } impl<'a> Document<'a> { - pub fn new( + pub fn from_deladd_obkv( data: obkv::KvReaderU16<'a>, side: DelAdd, inverted_field_map: &'a FieldsIdsMap, @@ -48,6 +48,20 @@ impl<'a> Document<'a> { Self(out_data) } + pub fn from_doc_obkv( + data: obkv::KvReaderU16<'a>, + inverted_field_map: &'a FieldsIdsMap, + ) -> Self { + let mut out_data = BTreeMap::new(); + for (fid, raw) in data { + let Some(name) = inverted_field_map.name(fid) else { + continue; + }; + out_data.insert(name, (raw, ParsedValue::empty())); + } + Self(out_data) + } + fn is_empty(&self) -> bool { self.0.is_empty() } diff --git a/milli/src/prompt/mod.rs b/milli/src/prompt/mod.rs index 97ccbfb61..1b0a4ab74 100644 --- a/milli/src/prompt/mod.rs +++ b/milli/src/prompt/mod.rs @@ -2,6 +2,7 @@ mod context; mod document; pub(crate) mod error; mod fields; +pub mod recommend; mod template_checker; use std::convert::TryFrom; @@ -9,7 +10,7 @@ use std::convert::TryFrom; use error::{NewPromptError, RenderPromptError}; use self::context::Context; -use self::document::Document; +pub use self::document::Document; use crate::update::del_add::DelAdd; use crate::FieldsIdsMap; @@ -95,7 +96,7 @@ impl Prompt { side: DelAdd, field_id_map: &FieldsIdsMap, ) -> Result { - let document = Document::new(document, side, field_id_map); + let document = Document::from_deladd_obkv(document, side, field_id_map); let context = Context::new(&document, field_id_map); self.template.render(&context).map_err(RenderPromptError::missing_context) diff --git a/milli/src/prompt/recommend.rs b/milli/src/prompt/recommend.rs new file mode 100644 index 000000000..c2ddcf294 --- /dev/null +++ b/milli/src/prompt/recommend.rs @@ -0,0 +1,112 @@ +use liquid::model::{ + DisplayCow, KStringCow, ObjectRender, ObjectSource, State, Value as LiquidValue, +}; +use liquid::{ObjectView, ValueView}; + +use super::document::Document; + +#[derive(Clone, Debug)] +pub struct Context<'a> { + document: Option<&'a Document<'a>>, + context: Option, +} + +impl<'a> Context<'a> { + pub fn new(document: Option<&'a Document<'a>>, context: Option) -> Self { + /// FIXME: unwrap + let context = context.map(|context| liquid::to_object(&context).unwrap()); + Self { document, context } + } +} + +impl<'a> ObjectView for Context<'a> { + fn as_value(&self) -> &dyn ValueView { + self + } + + fn size(&self) -> i64 { + match (self.context.as_ref(), self.document.as_ref()) { + (None, None) => 0, + (None, Some(_)) => 1, + (Some(_), None) => 1, + (Some(_), Some(_)) => 2, + } + } + + fn keys<'k>(&'k self) -> Box> + 'k> { + let keys = match (self.context.as_ref(), self.document.as_ref()) { + (None, None) => [].as_slice(), + (None, Some(_)) => ["doc"].as_slice(), + (Some(_), None) => ["context"].as_slice(), + (Some(_), Some(_)) => ["context", "doc"].as_slice(), + }; + + Box::new(keys.iter().map(|s| KStringCow::from_static(s))) + } + + fn values<'k>(&'k self) -> Box + 'k> { + Box::new( + self.context + .as_ref() + .map(|context| context.as_value()) + .into_iter() + .chain(self.document.map(|document| document.as_value()).into_iter()), + ) + } + + fn iter<'k>(&'k self) -> Box, &'k dyn ValueView)> + 'k> { + Box::new(self.keys().zip(self.values())) + } + + fn contains_key(&self, index: &str) -> bool { + index == "context" || index == "doc" + } + + fn get<'s>(&'s self, index: &str) -> Option<&'s dyn ValueView> { + match index { + "context" => self.context.as_ref().map(|context| context.as_value()), + "doc" => self.document.as_ref().map(|doc| doc.as_value()), + _ => None, + } + } +} + +impl<'a> ValueView for Context<'a> { + fn as_debug(&self) -> &dyn std::fmt::Debug { + self + } + + fn render(&self) -> liquid::model::DisplayCow<'_> { + DisplayCow::Owned(Box::new(ObjectRender::new(self))) + } + + fn source(&self) -> liquid::model::DisplayCow<'_> { + DisplayCow::Owned(Box::new(ObjectSource::new(self))) + } + + fn type_name(&self) -> &'static str { + "object" + } + + fn query_state(&self, state: liquid::model::State) -> bool { + match state { + State::Truthy => true, + State::DefaultValue | State::Empty | State::Blank => false, + } + } + + fn to_kstr(&self) -> liquid::model::KStringCow<'_> { + let s = ObjectRender::new(self).to_string(); + KStringCow::from_string(s) + } + + fn to_value(&self) -> LiquidValue { + LiquidValue::Object( + self.iter().map(|(k, x)| (k.to_string().into(), x.to_value())).collect(), + ) + } + + fn as_object(&self) -> Option<&dyn ObjectView> { + Some(self) + } +} diff --git a/milli/src/search/recommend.rs b/milli/src/search/recommend.rs index 269d65c49..fb9ce56f7 100644 --- a/milli/src/search/recommend.rs +++ b/milli/src/search/recommend.rs @@ -1,13 +1,20 @@ use std::sync::Arc; use ordered_float::OrderedFloat; +use roaring::RoaringBitmap; +use serde_json::Value; use crate::score_details::{self, ScoreDetails}; use crate::vector::Embedder; use crate::{filtered_universe, DocumentId, Filter, Index, Result, SearchResult}; +enum RecommendKind<'a> { + Id(DocumentId), + Prompt { prompt: &'a str, context: Option, id: Option }, +} + pub struct Recommend<'a> { - id: DocumentId, + kind: RecommendKind<'a>, // this should be linked to the String in the query filter: Option>, offset: usize, @@ -19,7 +26,7 @@ pub struct Recommend<'a> { } impl<'a> Recommend<'a> { - pub fn new( + pub fn with_docid( id: DocumentId, offset: usize, limit: usize, @@ -28,7 +35,39 @@ impl<'a> Recommend<'a> { embedder_name: String, embedder: Arc, ) -> Self { - Self { id, filter: None, offset, limit, rtxn, index, embedder_name, embedder } + Self { + kind: RecommendKind::Id(id), + filter: None, + offset, + limit, + rtxn, + index, + embedder_name, + embedder, + } + } + + pub fn with_prompt( + prompt: &'a str, + id: Option, + context: Option, + offset: usize, + limit: usize, + index: &'a Index, + rtxn: &'a heed::RoTxn<'a>, + embedder_name: String, + embedder: Arc, + ) -> Self { + Self { + kind: RecommendKind::Prompt { prompt, context, id }, + filter: None, + offset, + limit, + rtxn, + index, + embedder_name, + embedder, + } } pub fn filter(&mut self, filter: Filter<'a>) -> &mut Self { @@ -62,16 +101,67 @@ impl<'a> Recommend<'a> { let mut results = Vec::new(); + /// FIXME: make id optional... + let id = match &self.kind { + RecommendKind::Id(id) => *id, + RecommendKind::Prompt { prompt, context, id } => id.unwrap(), + }; + + let personalization_vector = if let RecommendKind::Prompt { prompt, context, id } = + &self.kind + { + let fields_ids_map = self.index.fields_ids_map(self.rtxn)?; + + let document = if let Some(id) = id { + Some(self.index.iter_documents(self.rtxn, std::iter::once(*id))?.next().unwrap()?.1) + } else { + None + }; + let document = document + .map(|document| crate::prompt::Document::from_doc_obkv(document, &fields_ids_map)); + + let context = + crate::prompt::recommend::Context::new(document.as_ref(), context.clone()); + + /// FIXME: handle error bad template + let template = + liquid::ParserBuilder::new().stdlib().build().unwrap().parse(prompt).unwrap(); + + /// FIXME: handle error bad context + let rendered = template.render(&context).unwrap(); + + /// FIXME: handle embedding error + Some(self.embedder.embed_one(rendered).unwrap()) + } else { + None + }; + for reader in readers.iter() { let nns_by_item = reader.nns_by_item( self.rtxn, - self.id, + id, self.limit + self.offset + 1, None, Some(&universe), )?; - if let Some(mut nns_by_item) = nns_by_item { - results.append(&mut nns_by_item); + + if let Some(nns_by_item) = nns_by_item { + let mut nns = match &personalization_vector { + Some(vector) => { + let candidates: RoaringBitmap = + nns_by_item.iter().map(|(docid, _)| docid).collect(); + reader.nns_by_vector( + self.rtxn, + vector, + self.limit + self.offset + 1, + None, + Some(&candidates), + )? + } + None => nns_by_item, + }; + + results.append(&mut nns); } }