// Copied from "openai.rs" with the sections I actually understand changed for Ollama.
// The common components of the Ollama and OpenAI interfaces might need to be extracted.

use std::fmt::Display;

use reqwest::StatusCode;

use super::error::{EmbedError, NewEmbedderError};
use super::openai::Retry;
use super::{DistributionShift, Embedding, Embeddings};

#[derive(Debug)]
pub struct Embedder {
    headers: reqwest::header::HeaderMap,
    options: EmbedderOptions,
}

#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub struct EmbedderOptions {
    pub embedding_model: EmbeddingModel,
}

#[derive(
    Debug, Clone, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize, deserr::Deserr,
)]
#[deserr(deny_unknown_fields)]
pub struct EmbeddingModel {
    name: String,
    // Unknown (0) until probed; `Embedder::new` fills this in from a live Ollama response.
    dimensions: usize,
}

#[derive(Debug, serde::Serialize)]
struct OllamaRequest<'a> {
    model: &'a str,
    prompt: &'a str,
}

#[derive(Debug, serde::Deserialize)]
struct OllamaResponse {
    embedding: Embedding,
}

#[derive(Debug, serde::Deserialize)]
pub struct OllamaError {
    error: String,
}

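// For illustration only: the types above map onto Ollama's `/api/embeddings`
// endpoint (the default path in `get_ollama_path` below), which exchanges JSON
// shaped roughly like this; the float values are invented:
//
//   request:  {"model": "nomic-embed-text", "prompt": "some text"}
//   response: {"embedding": [0.12, -0.34, ...]}
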
impl EmbeddingModel {
    pub fn max_token(&self) -> usize {
        // This limit might not be the same for all models.
        8192
    }

    pub fn default_dimensions(&self) -> usize {
        // Dimensions for nomic-embed-text
        768
    }

    pub fn name(&self) -> String {
        self.name.clone()
    }

    pub fn from_name(name: &str) -> Self {
        Self { name: name.to_string(), dimensions: 0 }
    }

    pub fn supports_overriding_dimensions(&self) -> bool {
        false
    }
}

impl Default for EmbeddingModel {
    fn default() -> Self {
        Self { name: "nomic-embed-text".to_string(), dimensions: 0 }
    }
}

impl EmbedderOptions {
    pub fn with_default_model() -> Self {
        Self { embedding_model: Default::default() }
    }

    pub fn with_embedding_model(embedding_model: EmbeddingModel) -> Self {
        Self { embedding_model }
    }
}

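// For illustration only, a sketch of how these options are meant to be used
// (assuming a reachable Ollama server; `Embedder::new` performs a blocking
// HTTP probe to fill in the model's dimensions):
//
//     let options = EmbedderOptions::with_default_model();
//     let embedder = Embedder::new(options)?;
//     assert_eq!(embedder.dimensions(), 768); // nomic-embed-text
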
impl Embedder {
    pub fn new_client(&self) -> Result<reqwest::Client, EmbedError> {
        reqwest::ClientBuilder::new()
            .default_headers(self.headers.clone())
            .build()
            .map_err(EmbedError::openai_initialize_web_client)
    }

    pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> {
        let mut headers = reqwest::header::HeaderMap::new();
        headers.insert(
            reqwest::header::CONTENT_TYPE,
            reqwest::header::HeaderValue::from_static("application/json"),
        );

        let mut embedder = Self { options, headers };

        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_io()
            .enable_time()
            .build()
            .map_err(EmbedError::openai_runtime_init)
            .map_err(NewEmbedderError::ollama_could_not_determine_dimension)?;

        // Get the dimensions from Ollama by embedding a dummy prompt.
        let request =
            OllamaRequest { model: &embedder.options.embedding_model.name(), prompt: "test" };
        // TODO: Refactor into shared error type
        let client = embedder
            .new_client()
            .map_err(NewEmbedderError::ollama_could_not_determine_dimension)?;

        rt.block_on(async move {
            let response = client
                .post(get_ollama_path())
                .json(&request)
                .send()
                .await
                .map_err(EmbedError::ollama_unexpected)
                .map_err(NewEmbedderError::ollama_could_not_determine_dimension)?;

            // Handle the error case where the model is not found.
            let response = Self::check_response(response).await.map_err(|_err| {
                let e = EmbedError::ollama_model_not_found(OllamaError {
                    error: format!("model: {}", embedder.options.embedding_model.name()),
                });
                NewEmbedderError::ollama_could_not_determine_dimension(e)
            })?;

            let response: OllamaResponse = response
                .json()
                .await
                .map_err(EmbedError::ollama_unexpected)
                .map_err(NewEmbedderError::ollama_could_not_determine_dimension)?;

            let embedding = Embeddings::from_single_embedding(response.embedding);

            embedder.options.embedding_model.dimensions = embedding.dimension();

            tracing::info!(
                "ollama model {} with dimensionality {} added",
                embedder.options.embedding_model.name(),
                embedding.dimension()
            );

            Ok(embedder)
        })
    }

    async fn check_response(response: reqwest::Response) -> Result<reqwest::Response, Retry> {
        if !response.status().is_success() {
            // Fewer error cases are covered here than in the OpenAI embedder.
            match response.status() {
                StatusCode::TOO_MANY_REQUESTS => {
                    let error_response: OllamaError = response
                        .json()
                        .await
                        .map_err(EmbedError::ollama_unexpected)
                        .map_err(Retry::retry_later)?;

                    return Err(Retry::rate_limited(EmbedError::ollama_too_many_requests(
                        OllamaError { error: error_response.error },
                    )));
                }
                StatusCode::SERVICE_UNAVAILABLE => {
                    let error_response: OllamaError = response
                        .json()
                        .await
                        .map_err(EmbedError::ollama_unexpected)
                        .map_err(Retry::retry_later)?;
                    return Err(Retry::retry_later(EmbedError::ollama_internal_server_error(
                        OllamaError { error: error_response.error },
                    )));
                }
                StatusCode::NOT_FOUND => {
                    let error_response: OllamaError = response
                        .json()
                        .await
                        .map_err(EmbedError::ollama_unexpected)
                        .map_err(Retry::give_up)?;

                    return Err(Retry::give_up(EmbedError::ollama_model_not_found(OllamaError {
                        error: error_response.error,
                    })));
                }
                code => {
                    return Err(Retry::give_up(EmbedError::ollama_unhandled_status_code(
                        code.as_u16(),
                    )));
                }
            }
        }
        Ok(response)
    }

    pub async fn embed(
        &self,
        texts: Vec<String>,
        client: &reqwest::Client,
    ) -> Result<Vec<Embeddings<f32>>, EmbedError> {
        // Ollama only embeds one document at a time.
        let mut results = Vec::with_capacity(texts.len());

        // The retry loop is inside the texts loop; it might have to be switched around.
        for text in texts {
            // Retries copied from openai.rs
            // NOTE: if all attempts fail with retryable errors, the text is silently
            // skipped and `results` ends up shorter than `texts`.
            for attempt in 0..7 {
                let retry_duration = match self.try_embed(&text, client).await {
                    Ok(result) => {
                        results.push(result);
                        break;
                    }
                    Err(retry) => {
                        tracing::warn!("Failed: {}", retry.error);
                        retry.into_duration(attempt)
                    }
                }?;
                tracing::warn!(
                    "Attempt #{}, retrying after {}ms.",
                    attempt,
                    retry_duration.as_millis()
                );
                tokio::time::sleep(retry_duration).await;
            }
        }

        Ok(results)
    }

    async fn try_embed(
        &self,
        text: &str,
        client: &reqwest::Client,
    ) -> Result<Embeddings<f32>, Retry> {
        let request = OllamaRequest { model: &self.options.embedding_model.name(), prompt: text };
        let response = client
            .post(get_ollama_path())
            .json(&request)
            .send()
            .await
            .map_err(EmbedError::openai_network)
            .map_err(Retry::retry_later)?;

        let response = Self::check_response(response).await?;

        let response: OllamaResponse = response
            .json()
            .await
            .map_err(EmbedError::openai_unexpected)
            .map_err(Retry::retry_later)?;

        tracing::trace!("response: {:?}", response.embedding);

        let embedding = Embeddings::from_single_embedding(response.embedding);
        Ok(embedding)
    }

    pub fn embed_chunks(
        &self,
        text_chunks: Vec<Vec<String>>,
    ) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_io()
            .enable_time()
            .build()
            .map_err(EmbedError::openai_runtime_init)?;
        let client = self.new_client()?;
        rt.block_on(futures::future::try_join_all(
            text_chunks.into_iter().map(|prompts| self.embed(prompts, &client)),
        ))
    }

    // Defaults copied from openai.rs
    pub fn chunk_count_hint(&self) -> usize {
        10
    }

    pub fn prompt_count_in_chunk_hint(&self) -> usize {
        10
    }

    pub fn dimensions(&self) -> usize {
        self.options.embedding_model.dimensions
    }

    pub fn distribution(&self) -> Option<DistributionShift> {
        None
    }
}

impl Display for OllamaError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.error)
    }
}

fn get_ollama_path() -> String {
    // Important: the hostname alone is not enough; this must be the full URL of the
    // embeddings endpoint.
    std::env::var("MEILI_OLLAMA_URL")
        .unwrap_or_else(|_| "http://localhost:11434/api/embeddings".to_string())
}
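
// For illustration only: overriding the endpoint via the environment. The host
// and port below are hypothetical; only the `MEILI_OLLAMA_URL` variable name
// and the `/api/embeddings` suffix come from this file:
//
//     MEILI_OLLAMA_URL=http://ollama.example:11434/api/embeddings ./meilisearch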