diff --git a/Cargo.lock b/Cargo.lock index 8e26cc940..9cee246dd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1026,6 +1026,7 @@ dependencies = [ "chrono 0.4.9 (registry+https://github.com/rust-lang/crates.io-index)", "crossbeam-channel 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)", "env_logger 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)", + "futures 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)", "heed 0.6.1 (registry+https://github.com/rust-lang/crates.io-index)", "http 0.1.19 (registry+https://github.com/rust-lang/crates.io-index)", "http-service 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)", diff --git a/meilisearch-http/Cargo.toml b/meilisearch-http/Cargo.toml index bb6a7a00b..47472c8e6 100644 --- a/meilisearch-http/Cargo.toml +++ b/meilisearch-http/Cargo.toml @@ -39,11 +39,11 @@ tide = "0.5.1" ureq = { version = "0.11.2", features = ["tls"], default-features = false } walkdir = "2.2.9" whoami = "0.6" - +http-service = "0.4.0" +futures = "0.3.1" [dev-dependencies] http-service-mock = "0.4.0" -http-service = "0.4.0" tempdir = "0.3.7" [dev-dependencies.assert-json-diff] diff --git a/meilisearch-http/src/cors.rs b/meilisearch-http/src/cors.rs new file mode 100644 index 000000000..e30ff4d69 --- /dev/null +++ b/meilisearch-http/src/cors.rs @@ -0,0 +1,424 @@ +//! Cors middleware + +use futures::future::BoxFuture; +use http::header::HeaderValue; +use http::{header, Method, StatusCode}; +use http_service::Body; + +use tide::middleware::{Middleware, Next}; +use tide::{Request, Response}; + +/// Middleware for CORS +/// +/// # Example +/// +/// ```no_run +/// use http::header::HeaderValue; +/// use tide::middleware::{Cors, Origin}; +/// +/// Cors::new() +/// .allow_methods(HeaderValue::from_static("GET, POST, OPTIONS")) +/// .allow_origin(Origin::from("*")) +/// .allow_credentials(false); +/// ``` +#[derive(Clone, Debug, Hash)] +pub struct Cors { + allow_credentials: Option, + allow_headers: HeaderValue, + allow_methods: HeaderValue, + allow_origin: Origin, + expose_headers: Option, + max_age: HeaderValue, +} + +pub const DEFAULT_MAX_AGE: &str = "86400"; +pub const DEFAULT_METHODS: &str = "GET, POST, OPTIONS"; +pub const WILDCARD: &str = "*"; + +impl Cors { + /// Creates a new Cors middleware. + pub fn new() -> Self { + Self { + allow_credentials: None, + allow_headers: HeaderValue::from_static(WILDCARD), + allow_methods: HeaderValue::from_static(DEFAULT_METHODS), + allow_origin: Origin::Any, + expose_headers: None, + max_age: HeaderValue::from_static(DEFAULT_MAX_AGE), + } + } + + /// Set allow_credentials and return new Cors + pub fn allow_credentials(mut self, allow_credentials: bool) -> Self { + self.allow_credentials = match HeaderValue::from_str(&allow_credentials.to_string()) { + Ok(header) => Some(header), + Err(_) => None, + }; + self + } + + /// Set allow_headers and return new Cors + pub fn allow_headers>(mut self, headers: T) -> Self { + self.allow_headers = headers.into(); + self + } + + /// Set max_age and return new Cors + pub fn max_age>(mut self, max_age: T) -> Self { + self.max_age = max_age.into(); + self + } + + /// Set allow_methods and return new Cors + pub fn allow_methods>(mut self, methods: T) -> Self { + self.allow_methods = methods.into(); + self + } + + /// Set allow_origin and return new Cors + pub fn allow_origin>(mut self, origin: T) -> Self { + self.allow_origin = origin.into(); + self + } + + /// Set expose_headers and return new Cors + pub fn expose_headers>(mut self, headers: T) -> Self { + self.expose_headers = Some(headers.into()); + self + } + + fn build_preflight_response(&self, origin: &HeaderValue) -> http::response::Response { + let mut response = http::Response::builder() + .status(StatusCode::OK) + .header::<_, HeaderValue>(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.clone()) + .header( + header::ACCESS_CONTROL_ALLOW_METHODS, + self.allow_methods.clone(), + ) + .header( + header::ACCESS_CONTROL_ALLOW_HEADERS, + self.allow_headers.clone(), + ) + .header(header::ACCESS_CONTROL_MAX_AGE, self.max_age.clone()) + .body(Body::empty()) + .unwrap(); + + if let Some(allow_credentials) = self.allow_credentials.clone() { + response + .headers_mut() + .append(header::ACCESS_CONTROL_ALLOW_CREDENTIALS, allow_credentials); + } + + if let Some(expose_headers) = self.expose_headers.clone() { + response + .headers_mut() + .append(header::ACCESS_CONTROL_EXPOSE_HEADERS, expose_headers); + } + + response + } + + /// Look at origin of request and determine allow_origin + fn response_origin>(&self, origin: T) -> Option { + let origin = origin.into(); + if !self.is_valid_origin(origin.clone()) { + return None; + } + + match self.allow_origin { + Origin::Any => Some(HeaderValue::from_static(WILDCARD)), + _ => Some(origin), + } + } + + /// Determine if origin is appropriate + fn is_valid_origin>(&self, origin: T) -> bool { + let origin = match origin.into().to_str() { + Ok(s) => s.to_string(), + Err(_) => return false, + }; + + match &self.allow_origin { + Origin::Any => true, + Origin::Exact(s) => s == &origin, + Origin::List(list) => list.contains(&origin), + } + } +} + +impl Middleware for Cors { + fn handle<'a>(&'a self, req: Request, next: Next<'a, State>) -> BoxFuture<'a, Response> { + Box::pin(async move { + let origin = req + .headers() + .get(header::ORIGIN) + .cloned() + .unwrap_or_else(|| HeaderValue::from_static("")); + + if !self.is_valid_origin(&origin) { + return http::Response::builder() + .status(StatusCode::UNAUTHORIZED) + .body(Body::empty()) + .unwrap() + .into(); + } + + // Return results immediately upon preflight request + if req.method() == Method::OPTIONS { + return self.build_preflight_response(&origin).into(); + } + + let mut response: http_service::Response = next.run(req).await.into(); + let headers = response.headers_mut(); + + headers.append( + header::ACCESS_CONTROL_ALLOW_ORIGIN, + self.response_origin(origin).unwrap(), + ); + + if let Some(allow_credentials) = self.allow_credentials.clone() { + headers.append(header::ACCESS_CONTROL_ALLOW_CREDENTIALS, allow_credentials); + } + + if let Some(expose_headers) = self.expose_headers.clone() { + headers.append(header::ACCESS_CONTROL_EXPOSE_HEADERS, expose_headers); + } + response.into() + }) + } +} + +impl Default for Cors { + fn default() -> Self { + Self::new() + } +} + +/// allow_origin enum +#[derive(Clone, Debug, Hash, PartialEq)] +pub enum Origin { + /// Wildcard. Accept all origin requests + Any, + /// Set a single allow_origin target + Exact(String), + /// Set multiple allow_origin targets + List(Vec), +} + +impl From for Origin { + fn from(s: String) -> Self { + if s == "*" { + return Origin::Any; + } + Origin::Exact(s) + } +} + +impl From<&str> for Origin { + fn from(s: &str) -> Self { + Origin::from(s.to_string()) + } +} + +impl From> for Origin { + fn from(list: Vec) -> Self { + if list.len() == 1 { + return Self::from(list[0].clone()); + } + + Origin::List(list) + } +} + +impl From> for Origin { + fn from(list: Vec<&str>) -> Self { + Origin::from(list.iter().map(|s| s.to_string()).collect::>()) + } +} + +#[cfg(test)] +mod test { + use super::*; + use http::header::HeaderValue; + use http_service::Body; + use http_service_mock::make_server; + + const ALLOW_ORIGIN: &str = "example.com"; + const ALLOW_METHODS: &str = "GET, POST, OPTIONS, DELETE"; + const EXPOSE_HEADER: &str = "X-My-Custom-Header"; + + const ENDPOINT: &str = "/cors"; + + fn app() -> crate::Server<()> { + let mut app = crate::Server::new(); + app.at(ENDPOINT).get(|_| async move { "Hello World" }); + + app + } + + fn request() -> http::Request { + http::Request::get(ENDPOINT) + .header(http::header::ORIGIN, ALLOW_ORIGIN) + .method(http::method::Method::GET) + .body(Body::empty()) + .unwrap() + } + + #[test] + fn preflight_request() { + let mut app = app(); + app.middleware( + Cors::new() + .allow_origin(Origin::from(ALLOW_ORIGIN)) + .allow_methods(HeaderValue::from_static(ALLOW_METHODS)) + .expose_headers(HeaderValue::from_static(EXPOSE_HEADER)) + .allow_credentials(true), + ); + + let mut server = make_server(app.into_http_service()).unwrap(); + + let req = http::Request::get(ENDPOINT) + .header(http::header::ORIGIN, ALLOW_ORIGIN) + .method(http::method::Method::OPTIONS) + .body(Body::empty()) + .unwrap(); + + let res = server.simulate(req).unwrap(); + + assert_eq!(res.status(), 200); + + assert_eq!( + res.headers().get("access-control-allow-origin").unwrap(), + ALLOW_ORIGIN + ); + assert_eq!( + res.headers().get("access-control-allow-methods").unwrap(), + ALLOW_METHODS + ); + assert_eq!( + res.headers().get("access-control-allow-headers").unwrap(), + WILDCARD + ); + assert_eq!( + res.headers().get("access-control-max-age").unwrap(), + DEFAULT_MAX_AGE + ); + + assert_eq!( + res.headers() + .get("access-control-allow-credentials") + .unwrap(), + "true" + ); + } + #[test] + fn default_cors_middleware() { + let mut app = app(); + app.middleware(Cors::new()); + + let mut server = make_server(app.into_http_service()).unwrap(); + let res = server.simulate(request()).unwrap(); + + assert_eq!(res.status(), 200); + + assert_eq!( + res.headers().get("access-control-allow-origin").unwrap(), + "*" + ); + } + + #[test] + fn custom_cors_middleware() { + let mut app = app(); + app.middleware( + Cors::new() + .allow_origin(Origin::from(ALLOW_ORIGIN)) + .allow_credentials(false) + .allow_methods(HeaderValue::from_static(ALLOW_METHODS)) + .expose_headers(HeaderValue::from_static(EXPOSE_HEADER)), + ); + + let mut server = make_server(app.into_http_service()).unwrap(); + let res = server.simulate(request()).unwrap(); + + assert_eq!(res.status(), 200); + assert_eq!( + res.headers().get("access-control-allow-origin").unwrap(), + ALLOW_ORIGIN + ); + } + + #[test] + fn credentials_true() { + let mut app = app(); + app.middleware(Cors::new().allow_credentials(true)); + + let mut server = make_server(app.into_http_service()).unwrap(); + let res = server.simulate(request()).unwrap(); + + assert_eq!(res.status(), 200); + assert_eq!( + res.headers() + .get("access-control-allow-credentials") + .unwrap(), + "true" + ); + } + + #[test] + fn set_allow_origin_list() { + let mut app = app(); + let origins = vec![ALLOW_ORIGIN, "foo.com", "bar.com"]; + app.middleware(Cors::new().allow_origin(origins.clone())); + let mut server = make_server(app.into_http_service()).unwrap(); + + for origin in origins { + let request = http::Request::get(ENDPOINT) + .header(http::header::ORIGIN, origin) + .method(http::method::Method::GET) + .body(Body::empty()) + .unwrap(); + + let res = server.simulate(request).unwrap(); + + assert_eq!(res.status(), 200); + assert_eq!( + res.headers().get("access-control-allow-origin").unwrap(), + origin + ); + } + } + + #[test] + fn not_set_origin_header() { + let mut app = app(); + app.middleware(Cors::new()); + + let request = http::Request::get(ENDPOINT) + .method(http::method::Method::GET) + .body(Body::empty()) + .unwrap(); + + let mut server = make_server(app.into_http_service()).unwrap(); + let res = server.simulate(request).unwrap(); + + assert_eq!(res.status(), 200); + } + + #[test] + fn unauthorized_origin() { + let mut app = app(); + app.middleware(Cors::new().allow_origin(ALLOW_ORIGIN)); + + let request = http::Request::get(ENDPOINT) + .header(http::header::ORIGIN, "unauthorize-origin.net") + .method(http::method::Method::GET) + .body(Body::empty()) + .unwrap(); + + let mut server = make_server(app.into_http_service()).unwrap(); + let res = server.simulate(request).unwrap(); + + assert_eq!(res.status(), 401); + } +} diff --git a/meilisearch-http/src/main.rs b/meilisearch-http/src/main.rs index 5083ca4ec..0517c3875 100644 --- a/meilisearch-http/src/main.rs +++ b/meilisearch-http/src/main.rs @@ -5,7 +5,6 @@ use async_std::task; use log::info; use main_error::MainError; use structopt::StructOpt; -// use tide::middleware::{CorsMiddleware, CorsOrigin}; use tide::middleware::RequestLogger; use meilisearch_http::data::Data; @@ -13,7 +12,10 @@ use meilisearch_http::option::Opt; use meilisearch_http::routes; use meilisearch_http::routes::index::index_update_callback; +use cors::Cors; + mod analytics; +mod cors; #[cfg(target_os = "linux")] #[global_allocator] @@ -36,11 +38,7 @@ pub fn main() -> Result<(), MainError> { let mut app = tide::with_state(data); - // app.middleware( - // CorsMiddleware::new() - // .allow_origin(CorsOrigin::from("*")) - // .allow_methods(HeaderValue::from_static("GET, POST, OPTIONS")), - // ); + app.middleware(Cors::new()); app.middleware(RequestLogger::new()); // app.middleware(tide_compression::Compression::new()); // app.middleware(tide_compression::Decompression::new()); diff --git a/meilisearch-http/src/routes/mod.rs b/meilisearch-http/src/routes/mod.rs index 40305e369..1433f1b2b 100644 --- a/meilisearch-http/src/routes/mod.rs +++ b/meilisearch-http/src/routes/mod.rs @@ -136,6 +136,7 @@ pub fn load_routes(app: &mut tide::Server) { .post(|ctx| into_response(setting::update_displayed(ctx))) .delete(|ctx| into_response(setting::delete_displayed(ctx))); }); + router.at("/index-new-fields") .get(|ctx| into_response(setting::get_index_new_fields(ctx))) .post(|ctx| into_response(setting::update_index_new_fields(ctx)));