diff --git a/meilisearch-auth/src/lib.rs b/meilisearch-auth/src/lib.rs index 8d4a7f2b7..c81f9f20b 100644 --- a/meilisearch-auth/src/lib.rs +++ b/meilisearch-auth/src/lib.rs @@ -8,6 +8,7 @@ use std::path::Path; use std::sync::Arc; use error::{AuthControllerError, Result}; +use meilisearch_types::index_uid_pattern::IndexUidPattern; use meilisearch_types::keys::{Action, CreateApiKey, Key, PatchApiKey}; use meilisearch_types::star_or::StarOr; use serde::{Deserialize, Serialize}; @@ -141,9 +142,7 @@ impl AuthController { .get_expiration_date(uid, action, None)? .or(match index { // else check if the key has access to the requested index. - Some(index) => { - self.store.get_expiration_date(uid, action, Some(index.as_bytes()))? - } + Some(index) => self.store.get_expiration_date(uid, action, Some(index))?, // or to any index if no index has been requested. None => self.store.prefix_first_expiration_date(uid, action)?, }) { @@ -196,8 +195,20 @@ impl Default for SearchRules { impl SearchRules { pub fn is_index_authorized(&self, index: &str) -> bool { match self { - Self::Set(set) => set.contains("*") || set.contains(index), - Self::Map(map) => map.contains_key("*") || map.contains_key(index), + Self::Set(set) => { + set.contains("*") + || set.contains(index) + || set + .iter() // We must store the IndexUidPattern in the Set + .any(|pattern| IndexUidPattern::new_unchecked(pattern).matches_str(index)) + } + Self::Map(map) => { + map.contains_key("*") + || map.contains_key(index) + || map + .keys() // We must store the IndexUidPattern in the Map + .any(|pattern| IndexUidPattern::new_unchecked(pattern).matches_str(index)) + } } } diff --git a/meilisearch-auth/src/store.rs b/meilisearch-auth/src/store.rs index b3f9ed672..d1c2562c1 100644 --- a/meilisearch-auth/src/store.rs +++ b/meilisearch-auth/src/store.rs @@ -9,6 +9,7 @@ use std::str; use std::sync::Arc; use hmac::{Hmac, Mac}; +use meilisearch_types::index_uid_pattern::IndexUidPattern; use meilisearch_types::keys::KeyId; use meilisearch_types::milli; use meilisearch_types::milli::heed::types::{ByteSlice, DecodeIgnore, SerdeJson}; @@ -210,11 +211,28 @@ impl HeedAuthStore { &self, uid: Uuid, action: Action, - index: Option<&[u8]>, + index: Option<&str>, ) -> Result>> { let rtxn = self.env.read_txn()?; - let tuple = (&uid, &action, index); - Ok(self.action_keyid_index_expiration.get(&rtxn, &tuple)?) + let tuple = (&uid, &action, index.map(|s| s.as_bytes())); + match self.action_keyid_index_expiration.get(&rtxn, &tuple)? { + Some(expiration) => Ok(Some(expiration)), + None => { + let tuple = (&uid, &action, None); + for result in self.action_keyid_index_expiration.prefix_iter(&rtxn, &tuple)? { + let ((_, _, index_uid_pattern), expiration) = result?; + if let Some((pattern, index)) = index_uid_pattern.zip(index) { + let index_uid_pattern = str::from_utf8(pattern)?.to_string(); + // TODO I shouldn't unwrap here but rather return an internal error + let pattern = IndexUidPattern::try_from(index_uid_pattern).unwrap(); + if pattern.matches_str(index) { + return Ok(Some(expiration)); + } + } + } + Ok(None) + } + } } pub fn prefix_first_expiration_date(