sequential extractor

This commit is contained in:
ad hoc 2022-03-04 20:12:44 +01:00
parent af8a5f2c21
commit b57c59baa4
No known key found for this signature in database
GPG Key ID: 4F00A782990CC643
14 changed files with 198 additions and 38 deletions

2
Cargo.lock generated
View File

@ -1780,7 +1780,7 @@ dependencies = [
"once_cell", "once_cell",
"parking_lot 0.11.2", "parking_lot 0.11.2",
"paste", "paste",
"pin-project", "pin-project-lite",
"platform-dirs", "platform-dirs",
"rand", "rand",
"rayon", "rayon",

View File

@ -54,7 +54,6 @@ num_cpus = "1.13.0"
obkv = "0.2.0" obkv = "0.2.0"
once_cell = "1.8.0" once_cell = "1.8.0"
parking_lot = "0.11.2" parking_lot = "0.11.2"
pin-project = "1.0.8"
platform-dirs = "0.3.0" platform-dirs = "0.3.0"
rand = "0.8.4" rand = "0.8.4"
rayon = "1.5.1" rayon = "1.5.1"
@ -78,6 +77,7 @@ tokio = { version = "1.11.0", features = ["full"] }
tokio-stream = "0.1.7" tokio-stream = "0.1.7"
uuid = { version = "0.8.2", features = ["serde"] } uuid = { version = "0.8.2", features = ["serde"] }
walkdir = "2.3.2" walkdir = "2.3.2"
pin-project-lite = "0.2.8"
[dev-dependencies] [dev-dependencies]
actix-rt = "2.2.0" actix-rt = "2.2.0"

View File

@ -41,10 +41,7 @@ impl<P, D> GuardedData<P, D> {
}), }),
None => Err(AuthenticationError::IrretrievableState.into()), None => Err(AuthenticationError::IrretrievableState.into()),
}, },
(token, None) => { (token, None) => Err(AuthenticationError::InvalidToken(token).into()),
let token = token.to_string();
Err(AuthenticationError::InvalidToken(token).into())
}
} }
} }

View File

@ -1,3 +1,4 @@
pub mod payload; pub mod payload;
#[macro_use] #[macro_use]
pub mod authentication; pub mod authentication;
pub mod sequential_extractor;

View File

@ -0,0 +1,148 @@
#![allow(non_snake_case)]
use std::{future::Future, pin::Pin, task::Poll};
use actix_web::{dev::Payload, FromRequest, Handler, HttpRequest};
use pin_project_lite::pin_project;
/// `SeqHandler` is an actix `Handler` that enforces that extractors errors are returned in the
/// same order as they are defined in the wrapped handler. This is needed because, by default, actix
/// to resolves the extractors concurrently, whereas we always need the authentication extractor to
/// throw first.
#[derive(Clone)]
pub struct SeqHandler<H>(pub H);
pub struct SeqFromRequest<T>(T);
/// This macro implements `FromRequest` for arbitrary arity handler, except for one, which is
/// useless anyway.
macro_rules! gen_seq {
($ty:ident; $($T:ident)+) => {
pin_project! {
pub struct $ty<$($T: FromRequest), +> {
$(
#[pin]
$T: ExtractFuture<$T::Future, $T, $T::Error>,
)+
}
}
impl<$($T: FromRequest), +> Future for $ty<$($T),+> {
type Output = Result<SeqFromRequest<($($T),+)>, actix_web::Error>;
fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
let mut this = self.project();
let mut count_fut = 0;
let mut count_finished = 0;
$(
count_fut += 1;
match this.$T.as_mut().project() {
ExtractProj::Future { fut } => match fut.poll(cx) {
Poll::Ready(Ok(output)) => {
count_finished += 1;
let _ = this
.$T
.as_mut()
.project_replace(ExtractFuture::Done { output });
}
Poll::Ready(Err(error)) => {
count_finished += 1;
let _ = this
.$T
.as_mut()
.project_replace(ExtractFuture::Error { error });
}
Poll::Pending => (),
},
ExtractProj::Done { .. } => count_finished += 1,
ExtractProj::Error { .. } => {
// short circuit if all previous are finished and we had an error.
if count_finished == count_fut {
match this.$T.project_replace(ExtractFuture::Empty) {
ExtractReplaceProj::Error { error } => {
return Poll::Ready(Err(error.into()))
}
_ => unreachable!("Invalid future state"),
}
} else {
count_finished += 1;
}
}
ExtractProj::Empty => unreachable!("From request polled after being finished. {}", stringify!($T)),
}
)+
if count_fut == count_finished {
let result = (
$(
match this.$T.project_replace(ExtractFuture::Empty) {
ExtractReplaceProj::Done { output } => output,
ExtractReplaceProj::Error { error } => return Poll::Ready(Err(error.into())),
_ => unreachable!("Invalid future state"),
},
)+
);
Poll::Ready(Ok(SeqFromRequest(result)))
} else {
Poll::Pending
}
}
}
impl<$($T: FromRequest,)+> FromRequest for SeqFromRequest<($($T,)+)> {
type Error = actix_web::Error;
type Future = $ty<$($T),+>;
fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
$ty {
$(
$T: ExtractFuture::Future {
fut: $T::from_request(req, payload),
},
)+
}
}
}
impl<Han, $($T: FromRequest),+> Handler<SeqFromRequest<($($T),+)>> for SeqHandler<Han>
where
Han: Handler<($($T),+)>,
{
type Output = Han::Output;
type Future = Han::Future;
fn call(&self, args: SeqFromRequest<($($T),+)>) -> Self::Future {
self.0.call(args.0)
}
}
};
}
// Not working for a single argument, but then, it is not really necessary.
// gen_seq! { SeqFromRequestFut1; A }
gen_seq! { SeqFromRequestFut2; A B }
gen_seq! { SeqFromRequestFut3; A B C }
gen_seq! { SeqFromRequestFut4; A B C D }
gen_seq! { SeqFromRequestFut5; A B C D E }
gen_seq! { SeqFromRequestFut6; A B C D E F }
pin_project! {
#[project = ExtractProj]
#[project_replace = ExtractReplaceProj]
enum ExtractFuture<Fut, Res, Err> {
Future {
#[pin]
fut: Fut,
},
Done {
output: Res,
},
Error {
error: Err,
},
Empty,
}
}

View File

@ -7,20 +7,23 @@ use serde::{Deserialize, Serialize};
use serde_json::Value; use serde_json::Value;
use time::OffsetDateTime; use time::OffsetDateTime;
use crate::extractors::authentication::{policies::*, GuardedData}; use crate::extractors::{
authentication::{policies::*, GuardedData},
sequential_extractor::SeqHandler,
};
use meilisearch_error::{Code, ResponseError}; use meilisearch_error::{Code, ResponseError};
pub fn configure(cfg: &mut web::ServiceConfig) { pub fn configure(cfg: &mut web::ServiceConfig) {
cfg.service( cfg.service(
web::resource("") web::resource("")
.route(web::post().to(create_api_key)) .route(web::post().to(SeqHandler(create_api_key)))
.route(web::get().to(list_api_keys)), .route(web::get().to(SeqHandler(list_api_keys))),
) )
.service( .service(
web::resource("/{api_key}") web::resource("/{api_key}")
.route(web::get().to(get_api_key)) .route(web::get().to(SeqHandler(get_api_key)))
.route(web::patch().to(patch_api_key)) .route(web::patch().to(SeqHandler(patch_api_key)))
.route(web::delete().to(delete_api_key)), .route(web::delete().to(SeqHandler(delete_api_key))),
); );
} }

View File

@ -7,10 +7,13 @@ use serde_json::json;
use crate::analytics::Analytics; use crate::analytics::Analytics;
use crate::extractors::authentication::{policies::*, GuardedData}; use crate::extractors::authentication::{policies::*, GuardedData};
use crate::extractors::sequential_extractor::SeqHandler;
pub fn configure(cfg: &mut web::ServiceConfig) { pub fn configure(cfg: &mut web::ServiceConfig) {
cfg.service(web::resource("").route(web::post().to(create_dump))) cfg.service(web::resource("").route(web::post().to(SeqHandler(create_dump))))
.service(web::resource("/{dump_uid}/status").route(web::get().to(get_dump_status))); .service(
web::resource("/{dump_uid}/status").route(web::get().to(SeqHandler(get_dump_status))),
);
} }
pub async fn create_dump( pub async fn create_dump(

View File

@ -20,6 +20,7 @@ use crate::analytics::Analytics;
use crate::error::MeilisearchHttpError; use crate::error::MeilisearchHttpError;
use crate::extractors::authentication::{policies::*, GuardedData}; use crate::extractors::authentication::{policies::*, GuardedData};
use crate::extractors::payload::Payload; use crate::extractors::payload::Payload;
use crate::extractors::sequential_extractor::SeqHandler;
use crate::task::SummarizedTaskView; use crate::task::SummarizedTaskView;
const DEFAULT_RETRIEVE_DOCUMENTS_OFFSET: usize = 0; const DEFAULT_RETRIEVE_DOCUMENTS_OFFSET: usize = 0;
@ -71,17 +72,17 @@ pub struct DocumentParam {
pub fn configure(cfg: &mut web::ServiceConfig) { pub fn configure(cfg: &mut web::ServiceConfig) {
cfg.service( cfg.service(
web::resource("") web::resource("")
.route(web::get().to(get_all_documents)) .route(web::get().to(SeqHandler(get_all_documents)))
.route(web::post().to(add_documents)) .route(web::post().to(SeqHandler(add_documents)))
.route(web::put().to(update_documents)) .route(web::put().to(SeqHandler(update_documents)))
.route(web::delete().to(clear_all_documents)), .route(web::delete().to(SeqHandler(clear_all_documents))),
) )
// this route needs to be before the /documents/{document_id} to match properly // this route needs to be before the /documents/{document_id} to match properly
.service(web::resource("/delete-batch").route(web::post().to(delete_documents))) .service(web::resource("/delete-batch").route(web::post().to(SeqHandler(delete_documents))))
.service( .service(
web::resource("/{document_id}") web::resource("/{document_id}")
.route(web::get().to(get_document)) .route(web::get().to(SeqHandler(get_document)))
.route(web::delete().to(delete_document)), .route(web::delete().to(SeqHandler(delete_document))),
); );
} }

View File

@ -9,6 +9,7 @@ use time::OffsetDateTime;
use crate::analytics::Analytics; use crate::analytics::Analytics;
use crate::extractors::authentication::{policies::*, GuardedData}; use crate::extractors::authentication::{policies::*, GuardedData};
use crate::extractors::sequential_extractor::SeqHandler;
use crate::task::SummarizedTaskView; use crate::task::SummarizedTaskView;
pub mod documents; pub mod documents;
@ -20,17 +21,17 @@ pub fn configure(cfg: &mut web::ServiceConfig) {
cfg.service( cfg.service(
web::resource("") web::resource("")
.route(web::get().to(list_indexes)) .route(web::get().to(list_indexes))
.route(web::post().to(create_index)), .route(web::post().to(SeqHandler(create_index))),
) )
.service( .service(
web::scope("/{index_uid}") web::scope("/{index_uid}")
.service( .service(
web::resource("") web::resource("")
.route(web::get().to(get_index)) .route(web::get().to(SeqHandler(get_index)))
.route(web::put().to(update_index)) .route(web::put().to(SeqHandler(update_index)))
.route(web::delete().to(delete_index)), .route(web::delete().to(SeqHandler(delete_index))),
) )
.service(web::resource("/stats").route(web::get().to(get_index_stats))) .service(web::resource("/stats").route(web::get().to(SeqHandler(get_index_stats))))
.service(web::scope("/documents").configure(documents::configure)) .service(web::scope("/documents").configure(documents::configure))
.service(web::scope("/search").configure(search::configure)) .service(web::scope("/search").configure(search::configure))
.service(web::scope("/tasks").configure(tasks::configure)) .service(web::scope("/tasks").configure(tasks::configure))

View File

@ -9,12 +9,13 @@ use serde_json::Value;
use crate::analytics::{Analytics, SearchAggregator}; use crate::analytics::{Analytics, SearchAggregator};
use crate::extractors::authentication::{policies::*, GuardedData}; use crate::extractors::authentication::{policies::*, GuardedData};
use crate::extractors::sequential_extractor::SeqHandler;
pub fn configure(cfg: &mut web::ServiceConfig) { pub fn configure(cfg: &mut web::ServiceConfig) {
cfg.service( cfg.service(
web::resource("") web::resource("")
.route(web::get().to(search_with_url_query)) .route(web::get().to(SeqHandler(search_with_url_query)))
.route(web::post().to(search_with_post)), .route(web::post().to(SeqHandler(search_with_post))),
); );
} }

View File

@ -23,6 +23,7 @@ macro_rules! make_setting_route {
use crate::analytics::Analytics; use crate::analytics::Analytics;
use crate::extractors::authentication::{policies::*, GuardedData}; use crate::extractors::authentication::{policies::*, GuardedData};
use crate::extractors::sequential_extractor::SeqHandler;
use crate::task::SummarizedTaskView; use crate::task::SummarizedTaskView;
use meilisearch_error::ResponseError; use meilisearch_error::ResponseError;
@ -98,9 +99,9 @@ macro_rules! make_setting_route {
pub fn resources() -> Resource { pub fn resources() -> Resource {
Resource::new($route) Resource::new($route)
.route(web::get().to(get)) .route(web::get().to(SeqHandler(get)))
.route(web::post().to(update)) .route(web::post().to(SeqHandler(update)))
.route(web::delete().to(delete)) .route(web::delete().to(SeqHandler(delete)))
} }
} }
}; };
@ -226,11 +227,12 @@ make_setting_route!(
macro_rules! generate_configure { macro_rules! generate_configure {
($($mod:ident),*) => { ($($mod:ident),*) => {
pub fn configure(cfg: &mut web::ServiceConfig) { pub fn configure(cfg: &mut web::ServiceConfig) {
use crate::extractors::sequential_extractor::SeqHandler;
cfg.service( cfg.service(
web::resource("") web::resource("")
.route(web::post().to(update_all)) .route(web::post().to(SeqHandler(update_all)))
.route(web::get().to(get_all)) .route(web::get().to(SeqHandler(get_all)))
.route(web::delete().to(delete_all))) .route(web::delete().to(SeqHandler(delete_all))))
$(.service($mod::resources()))*; $(.service($mod::resources()))*;
} }
}; };

View File

@ -8,11 +8,12 @@ use time::OffsetDateTime;
use crate::analytics::Analytics; use crate::analytics::Analytics;
use crate::extractors::authentication::{policies::*, GuardedData}; use crate::extractors::authentication::{policies::*, GuardedData};
use crate::extractors::sequential_extractor::SeqHandler;
use crate::task::{TaskListView, TaskView}; use crate::task::{TaskListView, TaskView};
pub fn configure(cfg: &mut web::ServiceConfig) { pub fn configure(cfg: &mut web::ServiceConfig) {
cfg.service(web::resource("").route(web::get().to(get_all_tasks_status))) cfg.service(web::resource("").route(web::get().to(SeqHandler(get_all_tasks_status))))
.service(web::resource("{task_id}").route(web::get().to(get_task_status))); .service(web::resource("{task_id}").route(web::get().to(SeqHandler(get_task_status))));
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]

View File

@ -7,11 +7,12 @@ use serde_json::json;
use crate::analytics::Analytics; use crate::analytics::Analytics;
use crate::extractors::authentication::{policies::*, GuardedData}; use crate::extractors::authentication::{policies::*, GuardedData};
use crate::extractors::sequential_extractor::SeqHandler;
use crate::task::{TaskListView, TaskView}; use crate::task::{TaskListView, TaskView};
pub fn configure(cfg: &mut web::ServiceConfig) { pub fn configure(cfg: &mut web::ServiceConfig) {
cfg.service(web::resource("").route(web::get().to(get_tasks))) cfg.service(web::resource("").route(web::get().to(SeqHandler(get_tasks))))
.service(web::resource("/{task_id}").route(web::get().to(get_task))); .service(web::resource("/{task_id}").route(web::get().to(SeqHandler(get_task))));
} }
async fn get_tasks( async fn get_tasks(

View File

@ -91,6 +91,7 @@ async fn error_access_expired_key() {
thread::sleep(time::Duration::new(1, 0)); thread::sleep(time::Duration::new(1, 0));
for (method, route) in AUTHORIZATIONS.keys() { for (method, route) in AUTHORIZATIONS.keys() {
dbg!(route);
let (response, code) = server.dummy_request(method, route).await; let (response, code) = server.dummy_request(method, route).await;
assert_eq!(response, INVALID_RESPONSE.clone()); assert_eq!(response, INVALID_RESPONSE.clone());