1use http::StatusCode;
14use serde_json::Value;
15
16use crate::errors::ScimError;
17
/// Comparison operators that a storage backend can evaluate directly.
///
/// Only equality exists today; it is the only operator [`parse_user_filter`] emits.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ScimFilterOperator {
    /// Equality (`eq`).
    Eq,
}
22
23#[derive(Debug, Clone, PartialEq, Eq)]
24pub struct ScimDbFilter {
25 pub field: String,
26 pub value: String,
27 pub operator: ScimFilterOperator,
28}
29
30#[derive(Debug, Clone, PartialEq)]
31pub enum ScimFilterExpression {
32 Compare {
33 path: ScimAttributePath,
34 operator: ScimCompareOperator,
35 value: Value,
36 },
37 Present(ScimAttributePath),
38 And(Box<ScimFilterExpression>, Box<ScimFilterExpression>),
39 Or(Box<ScimFilterExpression>, Box<ScimFilterExpression>),
40 Not(Box<ScimFilterExpression>),
41}
42
/// Attribute comparison operators defined by RFC 7644 section 3.4.2.2.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ScimCompareOperator {
    /// `eq` — equal.
    Eq,
    /// `ne` — not equal.
    Ne,
    /// `co` — contains.
    Co,
    /// `sw` — starts with.
    Sw,
    /// `ew` — ends with.
    Ew,
    /// `gt` — greater than.
    Gt,
    /// `ge` — greater than or equal.
    Ge,
    /// `lt` — less than.
    Lt,
    /// `le` — less than or equal.
    Le,
}
55
56#[derive(Debug, Clone, PartialEq)]
57pub struct ScimAttributePath {
58 pub attribute: String,
59 pub value_filter: Option<Box<ScimFilterExpression>>,
60 pub sub_attribute: Option<String>,
61}
62
63impl ScimAttributePath {
64 pub fn value_path(
65 attribute: impl Into<String>,
66 value_filter: ScimFilterExpression,
67 sub_attribute: Option<&str>,
68 ) -> Self {
69 Self {
70 attribute: attribute.into(),
71 value_filter: Some(Box::new(value_filter)),
72 sub_attribute: sub_attribute.map(str::to_owned),
73 }
74 }
75}
76
77impl From<&str> for ScimAttributePath {
78 fn from(value: &str) -> Self {
79 Self {
80 attribute: value.to_owned(),
81 value_filter: None,
82 sub_attribute: None,
83 }
84 }
85}
86
87pub fn parse_filter(filter: &str) -> Result<ScimFilterExpression, ScimError> {
88 let tokens = tokenize(filter)?;
89 let mut parser = FilterParser { tokens, cursor: 0 };
90 let expression = parser.parse_or()?;
91 if parser.peek().is_some() {
92 return Err(invalid_filter("Invalid filter expression"));
93 }
94 Ok(expression)
95}
96
97pub fn list_user_filter_uses_database_pushdown(filter: &str) -> bool {
101 parse_user_filter(filter).is_ok()
102}
103
104pub fn parse_user_filter(filter: &str) -> Result<Vec<ScimDbFilter>, ScimError> {
105 let expression = parse_filter(filter)?;
106 let ScimFilterExpression::Compare {
107 path,
108 operator,
109 value,
110 } = expression
111 else {
112 return Err(invalid_filter("Invalid filter expression"));
113 };
114 if path.attribute != "userName" || path.value_filter.is_some() || path.sub_attribute.is_some() {
115 return Err(invalid_filter(format!(
116 r#"The attribute "{}" is not supported"#,
117 path.attribute
118 )));
119 }
120 if operator != ScimCompareOperator::Eq {
121 return Err(invalid_filter(format!(
122 r#"The operator "{}" is not supported"#,
123 operator.as_str()
124 )));
125 }
126 let Some(value) = value.as_str().map(str::to_ascii_lowercase) else {
127 return Err(invalid_filter("Invalid filter expression"));
128 };
129 if value.is_empty() {
130 return Err(invalid_filter("Invalid filter expression"));
131 }
132
133 Ok(vec![ScimDbFilter {
134 field: "email".to_owned(),
135 value,
136 operator: ScimFilterOperator::Eq,
137 }])
138}
139
140pub fn resource_matches_filter(resource: &Value, filter: &str) -> Result<bool, ScimError> {
141 let expression = parse_filter(filter)?;
142 evaluate_filter(resource, &expression)
143}
144
145fn evaluate_filter(resource: &Value, expression: &ScimFilterExpression) -> Result<bool, ScimError> {
146 match expression {
147 ScimFilterExpression::Compare {
148 path,
149 operator,
150 value,
151 } => Ok(extract_path_values(resource, path)?
152 .iter()
153 .any(|candidate| compare_value(candidate, *operator, value, path))),
154 ScimFilterExpression::Present(path) => Ok(!extract_path_values(resource, path)?.is_empty()),
155 ScimFilterExpression::And(left, right) => {
156 Ok(evaluate_filter(resource, left)? && evaluate_filter(resource, right)?)
157 }
158 ScimFilterExpression::Or(left, right) => {
159 Ok(evaluate_filter(resource, left)? || evaluate_filter(resource, right)?)
160 }
161 ScimFilterExpression::Not(expression) => Ok(!evaluate_filter(resource, expression)?),
162 }
163}
164
165fn extract_path_values<'a>(
166 resource: &'a Value,
167 path: &ScimAttributePath,
168) -> Result<Vec<&'a Value>, ScimError> {
169 let (root_attribute, derived_sub_attribute) =
170 resolve_extension_attribute(resource, &path.attribute);
171 let Some(root) = resource.get(root_attribute) else {
172 return Ok(Vec::new());
173 };
174
175 let values = if let Some(value_filter) = path.value_filter.as_deref() {
176 let Some(items) = root.as_array() else {
177 return Ok(Vec::new());
178 };
179 items
180 .iter()
181 .filter_map(|item| match evaluate_filter(item, value_filter) {
182 Ok(true) => Some(Ok(item)),
183 Ok(false) => None,
184 Err(error) => Some(Err(error)),
185 })
186 .collect::<Result<Vec<_>, _>>()?
187 } else if let Some(items) = root.as_array() {
188 items.iter().collect()
189 } else {
190 vec![root]
191 };
192
193 let sub_attribute = path
194 .sub_attribute
195 .as_deref()
196 .or(derived_sub_attribute.as_deref());
197 if let Some(sub_attribute) = sub_attribute {
198 Ok(values
199 .into_iter()
200 .filter_map(|value| value.get(sub_attribute))
201 .collect())
202 } else {
203 Ok(values)
204 }
205}
206
207fn resolve_extension_attribute<'a>(
208 resource: &Value,
209 attribute: &'a str,
210) -> (&'a str, Option<String>) {
211 if resource.get(attribute).is_some() {
212 return (attribute, None);
213 }
214 let Some((schema, sub_attribute)) = attribute.rsplit_once(':') else {
215 return (attribute, None);
216 };
217 if schema.starts_with("urn:ietf:params:scim:schemas:") && resource.get(schema).is_some() {
218 (schema, Some(sub_attribute.to_owned()))
219 } else {
220 (attribute, None)
221 }
222}
223
224fn compare_value(
225 candidate: &Value,
226 operator: ScimCompareOperator,
227 expected: &Value,
228 path: &ScimAttributePath,
229) -> bool {
230 match (candidate, expected) {
231 (Value::String(left), Value::String(right)) => {
232 compare_strings(left, operator, right, is_case_exact_path(path))
233 }
234 (Value::Bool(left), Value::Bool(right)) => {
235 matches!(operator, ScimCompareOperator::Eq) && left == right
236 || matches!(operator, ScimCompareOperator::Ne) && left != right
237 }
238 (Value::Number(left), Value::Number(right)) => left
239 .as_f64()
240 .zip(right.as_f64())
241 .is_some_and(|(left, right)| compare_f64(left, operator, right)),
242 (Value::Null, Value::Null) => matches!(operator, ScimCompareOperator::Eq),
243 _ => false,
244 }
245}
246
247fn compare_strings(
248 left: &str,
249 operator: ScimCompareOperator,
250 right: &str,
251 case_exact: bool,
252) -> bool {
253 if !case_exact {
254 let left = left.to_ascii_lowercase();
255 let right = right.to_ascii_lowercase();
256 return compare_strings(&left, operator, &right, true);
257 }
258 match operator {
259 ScimCompareOperator::Eq => left == right,
260 ScimCompareOperator::Ne => left != right,
261 ScimCompareOperator::Co => left.contains(right),
262 ScimCompareOperator::Sw => left.starts_with(right),
263 ScimCompareOperator::Ew => left.ends_with(right),
264 ScimCompareOperator::Gt => left > right,
265 ScimCompareOperator::Ge => left >= right,
266 ScimCompareOperator::Lt => left < right,
267 ScimCompareOperator::Le => left <= right,
268 }
269}
270
271fn is_case_exact_path(path: &ScimAttributePath) -> bool {
272 match (path.attribute.as_str(), path.sub_attribute.as_deref()) {
273 ("id", None) | ("externalId", None) | ("meta", _) => true,
274 ("displayName", None) => true,
275 ("groups", Some("value")) => true,
276 (attribute, Some("value")) if attribute.ends_with(":manager") => true,
277 _ => false,
278 }
279}
280
281fn compare_f64(left: f64, operator: ScimCompareOperator, right: f64) -> bool {
282 match operator {
283 ScimCompareOperator::Eq => left == right,
284 ScimCompareOperator::Ne => left != right,
285 ScimCompareOperator::Gt => left > right,
286 ScimCompareOperator::Ge => left >= right,
287 ScimCompareOperator::Lt => left < right,
288 ScimCompareOperator::Le => left <= right,
289 ScimCompareOperator::Co | ScimCompareOperator::Sw | ScimCompareOperator::Ew => false,
290 }
291}
292
/// Lexical tokens of a filter string.
#[derive(Clone, Debug, PartialEq)]
enum Token {
    /// Unquoted word: attribute path, operator or logical keyword.
    Word(String),
    /// Decoded contents of a double-quoted string literal.
    String(String),
    /// `true` or `false`.
    Boolean(bool),
    /// Raw text of a numeric literal; it is validated when parsed as a value.
    Number(String),
    /// `null`.
    Null,
    /// `(`
    LeftParen,
    /// `)`
    RightParen,
    /// `[`
    LeftBracket,
    /// `]`
    RightBracket,
    /// `.`
    Dot,
}
306
307struct FilterParser {
308 tokens: Vec<Token>,
309 cursor: usize,
310}
311
312impl FilterParser {
313 fn parse_or(&mut self) -> Result<ScimFilterExpression, ScimError> {
314 let mut expression = self.parse_and()?;
315 while self.consume_word("or") {
316 expression =
317 ScimFilterExpression::Or(Box::new(expression), Box::new(self.parse_and()?));
318 }
319 Ok(expression)
320 }
321
322 fn parse_and(&mut self) -> Result<ScimFilterExpression, ScimError> {
323 let mut expression = self.parse_not()?;
324 while self.consume_word("and") {
325 expression =
326 ScimFilterExpression::And(Box::new(expression), Box::new(self.parse_not()?));
327 }
328 Ok(expression)
329 }
330
331 fn parse_not(&mut self) -> Result<ScimFilterExpression, ScimError> {
332 if self.consume_word("not") {
333 return Ok(ScimFilterExpression::Not(Box::new(self.parse_not()?)));
334 }
335 self.parse_primary()
336 }
337
338 fn parse_primary(&mut self) -> Result<ScimFilterExpression, ScimError> {
339 if self.consume_symbol(&Token::LeftParen) {
340 let expression = self.parse_or()?;
341 self.expect_symbol(&Token::RightParen)?;
342 return Ok(expression);
343 }
344 let path = self.parse_path()?;
345 if self.consume_word("pr") {
346 return Ok(ScimFilterExpression::Present(path));
347 }
348 let Some(operator) = self.consume_compare_operator() else {
349 return Err(invalid_filter("Invalid filter expression"));
350 };
351 let value = self.parse_value()?;
352 Ok(ScimFilterExpression::Compare {
353 path,
354 operator,
355 value,
356 })
357 }
358
359 fn parse_path(&mut self) -> Result<ScimAttributePath, ScimError> {
360 let Some(Token::Word(mut attribute)) = self.next().cloned() else {
361 return Err(invalid_filter("Invalid filter expression"));
362 };
363 let mut sub_attribute =
364 split_embedded_sub_attribute(&attribute).map(|(root_attribute, sub_attribute)| {
365 attribute = root_attribute;
366 sub_attribute
367 });
368 let value_filter = if self.consume_symbol(&Token::LeftBracket) {
369 let filter = self.parse_or()?;
370 self.expect_symbol(&Token::RightBracket)?;
371 Some(Box::new(filter))
372 } else {
373 None
374 };
375 if self.consume_symbol(&Token::Dot) {
376 let Some(Token::Word(parsed_sub_attribute)) = self.next().cloned() else {
377 return Err(invalid_filter("Invalid filter expression"));
378 };
379 sub_attribute = Some(parsed_sub_attribute);
380 }
381 Ok(ScimAttributePath {
382 attribute,
383 value_filter,
384 sub_attribute,
385 })
386 }
387
388 fn parse_value(&mut self) -> Result<Value, ScimError> {
389 match self.next().cloned() {
390 Some(Token::String(value)) => Ok(Value::String(value)),
391 Some(Token::Boolean(value)) => Ok(Value::Bool(value)),
392 Some(Token::Null) => Ok(Value::Null),
393 Some(Token::Number(value)) => value
394 .parse::<serde_json::Number>()
395 .map(Value::Number)
396 .map_err(|_| invalid_filter("Invalid filter expression")),
397 Some(Token::Word(value)) => Ok(Value::String(value)),
398 _ => Err(invalid_filter("Invalid filter expression")),
399 }
400 }
401
402 fn consume_compare_operator(&mut self) -> Option<ScimCompareOperator> {
403 let Some(Token::Word(value)) = self.peek() else {
404 return None;
405 };
406 let operator = match value.to_ascii_lowercase().as_str() {
407 "eq" => ScimCompareOperator::Eq,
408 "ne" => ScimCompareOperator::Ne,
409 "co" => ScimCompareOperator::Co,
410 "sw" => ScimCompareOperator::Sw,
411 "ew" => ScimCompareOperator::Ew,
412 "gt" => ScimCompareOperator::Gt,
413 "ge" => ScimCompareOperator::Ge,
414 "lt" => ScimCompareOperator::Lt,
415 "le" => ScimCompareOperator::Le,
416 _ => return None,
417 };
418 self.cursor += 1;
419 Some(operator)
420 }
421
422 fn consume_word(&mut self, expected: &str) -> bool {
423 match self.peek() {
424 Some(Token::Word(value)) if value.eq_ignore_ascii_case(expected) => {
425 self.cursor += 1;
426 true
427 }
428 _ => false,
429 }
430 }
431
432 fn consume_symbol(&mut self, expected: &Token) -> bool {
433 if self.peek() == Some(expected) {
434 self.cursor += 1;
435 return true;
436 }
437 false
438 }
439
440 fn expect_symbol(&mut self, expected: &Token) -> Result<(), ScimError> {
441 if self.consume_symbol(expected) {
442 Ok(())
443 } else {
444 Err(invalid_filter("Invalid filter expression"))
445 }
446 }
447
448 fn peek(&self) -> Option<&Token> {
449 self.tokens.get(self.cursor)
450 }
451
452 fn next(&mut self) -> Option<&Token> {
453 let token = self.tokens.get(self.cursor);
454 if token.is_some() {
455 self.cursor += 1;
456 }
457 token
458 }
459}
460
461impl ScimCompareOperator {
462 fn as_str(self) -> &'static str {
463 match self {
464 Self::Eq => "eq",
465 Self::Ne => "ne",
466 Self::Co => "co",
467 Self::Sw => "sw",
468 Self::Ew => "ew",
469 Self::Gt => "gt",
470 Self::Ge => "ge",
471 Self::Lt => "lt",
472 Self::Le => "le",
473 }
474 }
475}
476
477fn tokenize(input: &str) -> Result<Vec<Token>, ScimError> {
478 let mut tokens = Vec::new();
479 let mut chars = input.char_indices().peekable();
480 while let Some((index, ch)) = chars.next() {
481 match ch {
482 ch if ch.is_whitespace() => {}
483 '(' => tokens.push(Token::LeftParen),
484 ')' => tokens.push(Token::RightParen),
485 '[' => tokens.push(Token::LeftBracket),
486 ']' => tokens.push(Token::RightBracket),
487 '.' => tokens.push(Token::Dot),
488 '"' => tokens.push(Token::String(read_string(input, &mut chars)?)),
489 '-' | '0'..='9' => tokens.push(Token::Number(read_number(input, index, &mut chars))),
490 _ => {
491 if is_word_char(ch) {
492 let word = read_word(input, index, &mut chars);
493 tokens.push(match word.to_ascii_lowercase().as_str() {
494 "true" => Token::Boolean(true),
495 "false" => Token::Boolean(false),
496 "null" => Token::Null,
497 _ => Token::Word(word),
498 });
499 } else {
500 return Err(invalid_filter("Invalid filter expression"));
501 }
502 }
503 }
504 }
505 if tokens.is_empty() {
506 return Err(invalid_filter("Invalid filter expression"));
507 }
508 Ok(tokens)
509}
510
511fn read_string(
512 input: &str,
513 chars: &mut std::iter::Peekable<std::str::CharIndices<'_>>,
514) -> Result<String, ScimError> {
515 let mut value = String::new();
516 while let Some((_, ch)) = chars.next() {
517 match ch {
518 '"' => return Ok(value),
519 '\\' => {
520 let Some((_, escaped)) = chars.next() else {
521 return Err(invalid_filter("Invalid filter expression"));
522 };
523 value.push(match escaped {
524 '"' | '\\' | '/' => escaped,
525 'n' => '\n',
526 'r' => '\r',
527 't' => '\t',
528 _ => escaped,
529 });
530 }
531 _ => value.push(ch),
532 }
533 }
534 let _ = input;
535 Err(invalid_filter("Invalid filter expression"))
536}
537
/// Consumes the rest of a numeric literal whose first character (at byte offset `start`)
/// has already been read, and returns the literal's text.
///
/// The scan is purely lexical: digits, `.`, `e`, `E`, `+` and `-` are accepted in any
/// order. Validation is left to whoever parses the text.
fn read_number(
    input: &str,
    start: usize,
    chars: &mut std::iter::Peekable<std::str::CharIndices<'_>>,
) -> String {
    let end = loop {
        match chars.peek() {
            Some(&(_, ch)) if ch.is_ascii_digit() || matches!(ch, '.' | 'e' | 'E' | '+' | '-') => {
                chars.next();
            }
            Some(&(index, _)) => break index,
            None => break input.len(),
        }
    };
    input[start..end].to_owned()
}
553
554fn read_word(
555 input: &str,
556 start: usize,
557 chars: &mut std::iter::Peekable<std::str::CharIndices<'_>>,
558) -> String {
559 while let Some((_, ch)) = chars.peek() {
560 if is_word_char(*ch) {
561 chars.next();
562 } else {
563 break;
564 }
565 }
566 let end = chars.peek().map(|(index, _)| *index).unwrap_or(input.len());
567 input[start..end].to_owned()
568}
569
/// Reports whether `ch` may appear in an unquoted word. Words cover attribute names,
/// operators, keywords and schema URNs such as `urn:ietf:...:2.0:User`.
fn is_word_char(ch: char) -> bool {
    matches!(
        ch,
        'a'..='z' | 'A'..='Z' | '0'..='9' | '_' | '-' | ':' | '$' | '.'
    )
}
573
/// Splits `root.child` at the first `.` that is not flanked by ASCII digits on both sides.
///
/// Digit-flanked dots are skipped, so schema versions inside URNs (`...:2.0:User`) are
/// not mistaken for sub-attribute separators. Returns `None` when there is no such dot,
/// or when either side of the first one would be empty.
fn split_embedded_sub_attribute(path: &str) -> Option<(String, String)> {
    let separator = path.char_indices().find_map(|(index, ch)| {
        if ch != '.' {
            return None;
        }
        let before = path[..index].chars().next_back();
        let after = path[index + ch.len_utf8()..].chars().next();
        let inside_version = matches!(
            (before, after),
            (Some(b), Some(a)) if b.is_ascii_digit() && a.is_ascii_digit()
        );
        (!inside_version).then_some(index)
    })?;
    let (root, rest) = path.split_at(separator);
    let child = &rest[1..];
    if root.is_empty() || child.is_empty() {
        return None;
    }
    Some((root.to_owned(), child.to_owned()))
}
591
592fn invalid_filter(detail: impl Into<String>) -> ScimError {
593 ScimError::new(StatusCode::BAD_REQUEST, detail).with_scim_type("invalidFilter")
594}