Skip to main content

rustauth_scim/
filters.rs

1//! SCIM filter parsing.
2//!
3//! User list routes use two evaluation paths:
4//!
5//! - [`list_user_filter_uses_database_pushdown`] — only `userName eq "..."` is pushed
6//!   to SQL as an `email` equality check. This matches Better Auth upstream list
7//!   behavior.
8//! - [`parse_filter`] + [`resource_matches_filter`] — every other expression is
9//!   evaluated in memory against the serialized SCIM User resource (including
10//!   extension profile attributes). Use this for Groups, `.search`, and advanced
11//!   User filters.
12
13use http::StatusCode;
14use serde_json::Value;
15
16use crate::errors::ScimError;
17
/// Operator of a [`ScimDbFilter`] that the database layer applies directly.
///
/// Equality is the only operator produced in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScimFilterOperator {
    /// `field == value`.
    Eq,
}
22
23#[derive(Debug, Clone, PartialEq, Eq)]
24pub struct ScimDbFilter {
25    pub field: String,
26    pub value: String,
27    pub operator: ScimFilterOperator,
28}
29
30#[derive(Debug, Clone, PartialEq)]
31pub enum ScimFilterExpression {
32    Compare {
33        path: ScimAttributePath,
34        operator: ScimCompareOperator,
35        value: Value,
36    },
37    Present(ScimAttributePath),
38    And(Box<ScimFilterExpression>, Box<ScimFilterExpression>),
39    Or(Box<ScimFilterExpression>, Box<ScimFilterExpression>),
40    Not(Box<ScimFilterExpression>),
41}
42
/// Comparison operator of a SCIM filter (RFC 7644 §3.4.2.2).
///
/// Operators are recognised case-insensitively by the parser.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScimCompareOperator {
    /// `eq` — equal.
    Eq,
    /// `ne` — not equal.
    Ne,
    /// `co` — contains.
    Co,
    /// `sw` — starts with.
    Sw,
    /// `ew` — ends with.
    Ew,
    /// `gt` — greater than.
    Gt,
    /// `ge` — greater than or equal.
    Ge,
    /// `lt` — less than.
    Lt,
    /// `le` — less than or equal.
    Le,
}
55
56#[derive(Debug, Clone, PartialEq)]
57pub struct ScimAttributePath {
58    pub attribute: String,
59    pub value_filter: Option<Box<ScimFilterExpression>>,
60    pub sub_attribute: Option<String>,
61}
62
63impl ScimAttributePath {
64    pub fn value_path(
65        attribute: impl Into<String>,
66        value_filter: ScimFilterExpression,
67        sub_attribute: Option<&str>,
68    ) -> Self {
69        Self {
70            attribute: attribute.into(),
71            value_filter: Some(Box::new(value_filter)),
72            sub_attribute: sub_attribute.map(str::to_owned),
73        }
74    }
75}
76
77impl From<&str> for ScimAttributePath {
78    fn from(value: &str) -> Self {
79        Self {
80            attribute: value.to_owned(),
81            value_filter: None,
82            sub_attribute: None,
83        }
84    }
85}
86
87pub fn parse_filter(filter: &str) -> Result<ScimFilterExpression, ScimError> {
88    let tokens = tokenize(filter)?;
89    let mut parser = FilterParser { tokens, cursor: 0 };
90    let expression = parser.parse_or()?;
91    if parser.peek().is_some() {
92        return Err(invalid_filter("Invalid filter expression"));
93    }
94    Ok(expression)
95}
96
97/// Returns true when `GET /Users?filter=...` can apply the filter in the database.
98///
99/// Today this is only the upstream-compatible form `userName eq "<email>"`.
100pub fn list_user_filter_uses_database_pushdown(filter: &str) -> bool {
101    parse_user_filter(filter).is_ok()
102}
103
104pub fn parse_user_filter(filter: &str) -> Result<Vec<ScimDbFilter>, ScimError> {
105    let expression = parse_filter(filter)?;
106    let ScimFilterExpression::Compare {
107        path,
108        operator,
109        value,
110    } = expression
111    else {
112        return Err(invalid_filter("Invalid filter expression"));
113    };
114    if path.attribute != "userName" || path.value_filter.is_some() || path.sub_attribute.is_some() {
115        return Err(invalid_filter(format!(
116            r#"The attribute "{}" is not supported"#,
117            path.attribute
118        )));
119    }
120    if operator != ScimCompareOperator::Eq {
121        return Err(invalid_filter(format!(
122            r#"The operator "{}" is not supported"#,
123            operator.as_str()
124        )));
125    }
126    let Some(value) = value.as_str().map(str::to_ascii_lowercase) else {
127        return Err(invalid_filter("Invalid filter expression"));
128    };
129    if value.is_empty() {
130        return Err(invalid_filter("Invalid filter expression"));
131    }
132
133    Ok(vec![ScimDbFilter {
134        field: "email".to_owned(),
135        value,
136        operator: ScimFilterOperator::Eq,
137    }])
138}
139
140pub fn resource_matches_filter(resource: &Value, filter: &str) -> Result<bool, ScimError> {
141    let expression = parse_filter(filter)?;
142    evaluate_filter(resource, &expression)
143}
144
145fn evaluate_filter(resource: &Value, expression: &ScimFilterExpression) -> Result<bool, ScimError> {
146    match expression {
147        ScimFilterExpression::Compare {
148            path,
149            operator,
150            value,
151        } => Ok(extract_path_values(resource, path)?
152            .iter()
153            .any(|candidate| compare_value(candidate, *operator, value, path))),
154        ScimFilterExpression::Present(path) => Ok(!extract_path_values(resource, path)?.is_empty()),
155        ScimFilterExpression::And(left, right) => {
156            Ok(evaluate_filter(resource, left)? && evaluate_filter(resource, right)?)
157        }
158        ScimFilterExpression::Or(left, right) => {
159            Ok(evaluate_filter(resource, left)? || evaluate_filter(resource, right)?)
160        }
161        ScimFilterExpression::Not(expression) => Ok(!evaluate_filter(resource, expression)?),
162    }
163}
164
165fn extract_path_values<'a>(
166    resource: &'a Value,
167    path: &ScimAttributePath,
168) -> Result<Vec<&'a Value>, ScimError> {
169    let (root_attribute, derived_sub_attribute) =
170        resolve_extension_attribute(resource, &path.attribute);
171    let Some(root) = resource.get(root_attribute) else {
172        return Ok(Vec::new());
173    };
174
175    let values = if let Some(value_filter) = path.value_filter.as_deref() {
176        let Some(items) = root.as_array() else {
177            return Ok(Vec::new());
178        };
179        items
180            .iter()
181            .filter_map(|item| match evaluate_filter(item, value_filter) {
182                Ok(true) => Some(Ok(item)),
183                Ok(false) => None,
184                Err(error) => Some(Err(error)),
185            })
186            .collect::<Result<Vec<_>, _>>()?
187    } else if let Some(items) = root.as_array() {
188        items.iter().collect()
189    } else {
190        vec![root]
191    };
192
193    let sub_attribute = path
194        .sub_attribute
195        .as_deref()
196        .or(derived_sub_attribute.as_deref());
197    if let Some(sub_attribute) = sub_attribute {
198        Ok(values
199            .into_iter()
200            .filter_map(|value| value.get(sub_attribute))
201            .collect())
202    } else {
203        Ok(values)
204    }
205}
206
207fn resolve_extension_attribute<'a>(
208    resource: &Value,
209    attribute: &'a str,
210) -> (&'a str, Option<String>) {
211    if resource.get(attribute).is_some() {
212        return (attribute, None);
213    }
214    let Some((schema, sub_attribute)) = attribute.rsplit_once(':') else {
215        return (attribute, None);
216    };
217    if schema.starts_with("urn:ietf:params:scim:schemas:") && resource.get(schema).is_some() {
218        (schema, Some(sub_attribute.to_owned()))
219    } else {
220        (attribute, None)
221    }
222}
223
224fn compare_value(
225    candidate: &Value,
226    operator: ScimCompareOperator,
227    expected: &Value,
228    path: &ScimAttributePath,
229) -> bool {
230    match (candidate, expected) {
231        (Value::String(left), Value::String(right)) => {
232            compare_strings(left, operator, right, is_case_exact_path(path))
233        }
234        (Value::Bool(left), Value::Bool(right)) => {
235            matches!(operator, ScimCompareOperator::Eq) && left == right
236                || matches!(operator, ScimCompareOperator::Ne) && left != right
237        }
238        (Value::Number(left), Value::Number(right)) => left
239            .as_f64()
240            .zip(right.as_f64())
241            .is_some_and(|(left, right)| compare_f64(left, operator, right)),
242        (Value::Null, Value::Null) => matches!(operator, ScimCompareOperator::Eq),
243        _ => false,
244    }
245}
246
247fn compare_strings(
248    left: &str,
249    operator: ScimCompareOperator,
250    right: &str,
251    case_exact: bool,
252) -> bool {
253    if !case_exact {
254        let left = left.to_ascii_lowercase();
255        let right = right.to_ascii_lowercase();
256        return compare_strings(&left, operator, &right, true);
257    }
258    match operator {
259        ScimCompareOperator::Eq => left == right,
260        ScimCompareOperator::Ne => left != right,
261        ScimCompareOperator::Co => left.contains(right),
262        ScimCompareOperator::Sw => left.starts_with(right),
263        ScimCompareOperator::Ew => left.ends_with(right),
264        ScimCompareOperator::Gt => left > right,
265        ScimCompareOperator::Ge => left >= right,
266        ScimCompareOperator::Lt => left < right,
267        ScimCompareOperator::Le => left <= right,
268    }
269}
270
271fn is_case_exact_path(path: &ScimAttributePath) -> bool {
272    match (path.attribute.as_str(), path.sub_attribute.as_deref()) {
273        ("id", None) | ("externalId", None) | ("meta", _) => true,
274        ("displayName", None) => true,
275        ("groups", Some("value")) => true,
276        (attribute, Some("value")) if attribute.ends_with(":manager") => true,
277        _ => false,
278    }
279}
280
281fn compare_f64(left: f64, operator: ScimCompareOperator, right: f64) -> bool {
282    match operator {
283        ScimCompareOperator::Eq => left == right,
284        ScimCompareOperator::Ne => left != right,
285        ScimCompareOperator::Gt => left > right,
286        ScimCompareOperator::Ge => left >= right,
287        ScimCompareOperator::Lt => left < right,
288        ScimCompareOperator::Le => left <= right,
289        ScimCompareOperator::Co | ScimCompareOperator::Sw | ScimCompareOperator::Ew => false,
290    }
291}
292
/// Lexical unit produced by `tokenize` and consumed by `FilterParser`.
#[derive(Debug, Clone, PartialEq)]
enum Token {
    /// Bare identifier: attribute path, operator, keyword, or unquoted value.
    Word(String),
    /// Contents of a double-quoted literal, with escapes already decoded.
    String(String),
    /// `true` / `false`, recognised case-insensitively.
    Boolean(bool),
    /// Raw numeric text; validated later when converted into a JSON number.
    Number(String),
    /// `null`, recognised case-insensitively.
    Null,
    /// `(`
    LeftParen,
    /// `)`
    RightParen,
    /// `[` — opens a value-path filter.
    LeftBracket,
    /// `]` — closes a value-path filter.
    RightBracket,
    /// A `.` that starts a token, e.g. after `]` in `emails[...].value`.
    Dot,
}
306
307struct FilterParser {
308    tokens: Vec<Token>,
309    cursor: usize,
310}
311
312impl FilterParser {
313    fn parse_or(&mut self) -> Result<ScimFilterExpression, ScimError> {
314        let mut expression = self.parse_and()?;
315        while self.consume_word("or") {
316            expression =
317                ScimFilterExpression::Or(Box::new(expression), Box::new(self.parse_and()?));
318        }
319        Ok(expression)
320    }
321
322    fn parse_and(&mut self) -> Result<ScimFilterExpression, ScimError> {
323        let mut expression = self.parse_not()?;
324        while self.consume_word("and") {
325            expression =
326                ScimFilterExpression::And(Box::new(expression), Box::new(self.parse_not()?));
327        }
328        Ok(expression)
329    }
330
331    fn parse_not(&mut self) -> Result<ScimFilterExpression, ScimError> {
332        if self.consume_word("not") {
333            return Ok(ScimFilterExpression::Not(Box::new(self.parse_not()?)));
334        }
335        self.parse_primary()
336    }
337
338    fn parse_primary(&mut self) -> Result<ScimFilterExpression, ScimError> {
339        if self.consume_symbol(&Token::LeftParen) {
340            let expression = self.parse_or()?;
341            self.expect_symbol(&Token::RightParen)?;
342            return Ok(expression);
343        }
344        let path = self.parse_path()?;
345        if self.consume_word("pr") {
346            return Ok(ScimFilterExpression::Present(path));
347        }
348        let Some(operator) = self.consume_compare_operator() else {
349            return Err(invalid_filter("Invalid filter expression"));
350        };
351        let value = self.parse_value()?;
352        Ok(ScimFilterExpression::Compare {
353            path,
354            operator,
355            value,
356        })
357    }
358
359    fn parse_path(&mut self) -> Result<ScimAttributePath, ScimError> {
360        let Some(Token::Word(mut attribute)) = self.next().cloned() else {
361            return Err(invalid_filter("Invalid filter expression"));
362        };
363        let mut sub_attribute =
364            split_embedded_sub_attribute(&attribute).map(|(root_attribute, sub_attribute)| {
365                attribute = root_attribute;
366                sub_attribute
367            });
368        let value_filter = if self.consume_symbol(&Token::LeftBracket) {
369            let filter = self.parse_or()?;
370            self.expect_symbol(&Token::RightBracket)?;
371            Some(Box::new(filter))
372        } else {
373            None
374        };
375        if self.consume_symbol(&Token::Dot) {
376            let Some(Token::Word(parsed_sub_attribute)) = self.next().cloned() else {
377                return Err(invalid_filter("Invalid filter expression"));
378            };
379            sub_attribute = Some(parsed_sub_attribute);
380        }
381        Ok(ScimAttributePath {
382            attribute,
383            value_filter,
384            sub_attribute,
385        })
386    }
387
388    fn parse_value(&mut self) -> Result<Value, ScimError> {
389        match self.next().cloned() {
390            Some(Token::String(value)) => Ok(Value::String(value)),
391            Some(Token::Boolean(value)) => Ok(Value::Bool(value)),
392            Some(Token::Null) => Ok(Value::Null),
393            Some(Token::Number(value)) => value
394                .parse::<serde_json::Number>()
395                .map(Value::Number)
396                .map_err(|_| invalid_filter("Invalid filter expression")),
397            Some(Token::Word(value)) => Ok(Value::String(value)),
398            _ => Err(invalid_filter("Invalid filter expression")),
399        }
400    }
401
402    fn consume_compare_operator(&mut self) -> Option<ScimCompareOperator> {
403        let Some(Token::Word(value)) = self.peek() else {
404            return None;
405        };
406        let operator = match value.to_ascii_lowercase().as_str() {
407            "eq" => ScimCompareOperator::Eq,
408            "ne" => ScimCompareOperator::Ne,
409            "co" => ScimCompareOperator::Co,
410            "sw" => ScimCompareOperator::Sw,
411            "ew" => ScimCompareOperator::Ew,
412            "gt" => ScimCompareOperator::Gt,
413            "ge" => ScimCompareOperator::Ge,
414            "lt" => ScimCompareOperator::Lt,
415            "le" => ScimCompareOperator::Le,
416            _ => return None,
417        };
418        self.cursor += 1;
419        Some(operator)
420    }
421
422    fn consume_word(&mut self, expected: &str) -> bool {
423        match self.peek() {
424            Some(Token::Word(value)) if value.eq_ignore_ascii_case(expected) => {
425                self.cursor += 1;
426                true
427            }
428            _ => false,
429        }
430    }
431
432    fn consume_symbol(&mut self, expected: &Token) -> bool {
433        if self.peek() == Some(expected) {
434            self.cursor += 1;
435            return true;
436        }
437        false
438    }
439
440    fn expect_symbol(&mut self, expected: &Token) -> Result<(), ScimError> {
441        if self.consume_symbol(expected) {
442            Ok(())
443        } else {
444            Err(invalid_filter("Invalid filter expression"))
445        }
446    }
447
448    fn peek(&self) -> Option<&Token> {
449        self.tokens.get(self.cursor)
450    }
451
452    fn next(&mut self) -> Option<&Token> {
453        let token = self.tokens.get(self.cursor);
454        if token.is_some() {
455            self.cursor += 1;
456        }
457        token
458    }
459}
460
461impl ScimCompareOperator {
462    fn as_str(self) -> &'static str {
463        match self {
464            Self::Eq => "eq",
465            Self::Ne => "ne",
466            Self::Co => "co",
467            Self::Sw => "sw",
468            Self::Ew => "ew",
469            Self::Gt => "gt",
470            Self::Ge => "ge",
471            Self::Lt => "lt",
472            Self::Le => "le",
473        }
474    }
475}
476
477fn tokenize(input: &str) -> Result<Vec<Token>, ScimError> {
478    let mut tokens = Vec::new();
479    let mut chars = input.char_indices().peekable();
480    while let Some((index, ch)) = chars.next() {
481        match ch {
482            ch if ch.is_whitespace() => {}
483            '(' => tokens.push(Token::LeftParen),
484            ')' => tokens.push(Token::RightParen),
485            '[' => tokens.push(Token::LeftBracket),
486            ']' => tokens.push(Token::RightBracket),
487            '.' => tokens.push(Token::Dot),
488            '"' => tokens.push(Token::String(read_string(input, &mut chars)?)),
489            '-' | '0'..='9' => tokens.push(Token::Number(read_number(input, index, &mut chars))),
490            _ => {
491                if is_word_char(ch) {
492                    let word = read_word(input, index, &mut chars);
493                    tokens.push(match word.to_ascii_lowercase().as_str() {
494                        "true" => Token::Boolean(true),
495                        "false" => Token::Boolean(false),
496                        "null" => Token::Null,
497                        _ => Token::Word(word),
498                    });
499                } else {
500                    return Err(invalid_filter("Invalid filter expression"));
501                }
502            }
503        }
504    }
505    if tokens.is_empty() {
506        return Err(invalid_filter("Invalid filter expression"));
507    }
508    Ok(tokens)
509}
510
511fn read_string(
512    input: &str,
513    chars: &mut std::iter::Peekable<std::str::CharIndices<'_>>,
514) -> Result<String, ScimError> {
515    let mut value = String::new();
516    while let Some((_, ch)) = chars.next() {
517        match ch {
518            '"' => return Ok(value),
519            '\\' => {
520                let Some((_, escaped)) = chars.next() else {
521                    return Err(invalid_filter("Invalid filter expression"));
522                };
523                value.push(match escaped {
524                    '"' | '\\' | '/' => escaped,
525                    'n' => '\n',
526                    'r' => '\r',
527                    't' => '\t',
528                    _ => escaped,
529                });
530            }
531            _ => value.push(ch),
532        }
533    }
534    let _ = input;
535    Err(invalid_filter("Invalid filter expression"))
536}
537
/// Returns the numeric literal that begins at byte `start` of `input`.
///
/// The first character has already been consumed by the caller. This greedily consumes the
/// following digits, `.`, `e`/`E`, `+`, and `-`; the text is validated later, when it is
/// parsed as a JSON number.
fn read_number(
    input: &str,
    start: usize,
    chars: &mut std::iter::Peekable<std::str::CharIndices<'_>>,
) -> String {
    while chars
        .next_if(|&(_, ch)| ch.is_ascii_digit() || matches!(ch, '.' | 'e' | 'E' | '+' | '-'))
        .is_some()
    {}
    let end = chars.peek().map_or(input.len(), |&(index, _)| index);
    input[start..end].to_owned()
}
553
554fn read_word(
555    input: &str,
556    start: usize,
557    chars: &mut std::iter::Peekable<std::str::CharIndices<'_>>,
558) -> String {
559    while let Some((_, ch)) = chars.peek() {
560        if is_word_char(*ch) {
561            chars.next();
562        } else {
563            break;
564        }
565    }
566    let end = chars.peek().map(|(index, _)| *index).unwrap_or(input.len());
567    input[start..end].to_owned()
568}
569
/// Characters allowed inside attribute names, keywords, and unquoted values.
///
/// `:` and `.` are included so that URNs and dotted paths such as
/// `urn:…:2.0:User:manager.value` lex as a single word.
fn is_word_char(ch: char) -> bool {
    matches!(ch, 'a'..='z' | 'A'..='Z' | '0'..='9' | '_' | '-' | ':' | '$' | '.')
}
573
/// Splits `root.child` at the first `.` that is not a version separator.
///
/// A `.` with an ASCII digit on both sides, like the `2.0` inside schema URNs, is a version
/// separator and is skipped. Returns `None` when there is no such dot, or when either side of
/// the first one is empty.
fn split_embedded_sub_attribute(path: &str) -> Option<(String, String)> {
    let bytes = path.as_bytes();
    // `.` and ASCII digits are single bytes, and UTF-8 continuation/lead bytes are never ASCII
    // digits, so a byte-level scan finds the same dot as a char-level one.
    let is_version_dot = |index: usize| {
        index > 0
            && bytes[index - 1].is_ascii_digit()
            && bytes.get(index + 1).is_some_and(u8::is_ascii_digit)
    };
    let split_at = (0..bytes.len()).find(|&index| bytes[index] == b'.' && !is_version_dot(index))?;
    let (root, rest) = path.split_at(split_at);
    let child = &rest[1..];
    if root.is_empty() || child.is_empty() {
        return None;
    }
    Some((root.to_owned(), child.to_owned()))
}
591
592fn invalid_filter(detail: impl Into<String>) -> ScimError {
593    ScimError::new(StatusCode::BAD_REQUEST, detail).with_scim_type("invalidFilter")
594}