spacetimedb_expr/
rls.rs

1use std::rc::Rc;
2
3use spacetimedb_lib::identity::AuthCtx;
4use spacetimedb_primitives::TableId;
5use spacetimedb_sql_parser::ast::BinOp;
6
7use crate::{
8    check::{parse_and_type_sub, SchemaView},
9    expr::{AggType, Expr, FieldProject, LeftDeepJoin, ProjectList, ProjectName, RelExpr, Relvar},
10};
11
12/// The main driver of RLS resolution for subscription queries.
13/// Mainly a wrapper around [resolve_views_for_expr].
14pub fn resolve_views_for_sub(
15    tx: &impl SchemaView,
16    expr: ProjectName,
17    auth: &AuthCtx,
18    has_param: &mut bool,
19) -> anyhow::Result<Vec<ProjectName>> {
20    // RLS does not apply to the database owner
21    if auth.is_owner() {
22        return Ok(vec![expr]);
23    }
24
25    let Some(return_name) = expr.return_name().map(|name| name.to_owned().into_boxed_str()) else {
26        anyhow::bail!("Could not determine return type during RLS resolution")
27    };
28
29    // Unwrap the underlying `RelExpr`
30    let expr = match expr {
31        ProjectName::None(expr) | ProjectName::Some(expr, _) => expr,
32    };
33
34    resolve_views_for_expr(
35        tx,
36        expr,
37        // Do not ignore the return table when checking for RLS rules
38        None,
39        // Resolve list is empty as we are not yet resolving any RLS rules
40        Rc::new(ResolveList::None),
41        has_param,
42        &mut 0,
43        auth,
44    )
45    .map(|fragments| {
46        fragments
47            .into_iter()
48            // The expanded fragments could be join trees,
49            // so wrap each of them in an outer project.
50            .map(|expr| ProjectName::Some(expr, return_name.clone()))
51            .collect()
52    })
53}
54
55/// The main driver of RLS resolution for sql queries.
56/// Mainly a wrapper around [resolve_views_for_expr].
57pub fn resolve_views_for_sql(tx: &impl SchemaView, expr: ProjectList, auth: &AuthCtx) -> anyhow::Result<ProjectList> {
58    // RLS does not apply to the database owner
59    if auth.is_owner() {
60        return Ok(expr);
61    }
62    // The subscription language is a subset of the sql language.
63    // Use the subscription helper if this is a compliant expression.
64    // Use the generic resolver otherwise.
65    let resolve_for_sub = |expr| resolve_views_for_sub(tx, expr, auth, &mut false);
66    let resolve_for_sql = |expr| {
67        resolve_views_for_expr(
68            // Use all default values
69            tx,
70            expr,
71            None,
72            Rc::new(ResolveList::None),
73            &mut false,
74            &mut 0,
75            auth,
76        )
77    };
78    match expr {
79        ProjectList::Limit(expr, n) => Ok(ProjectList::Limit(Box::new(resolve_views_for_sql(tx, *expr, auth)?), n)),
80        ProjectList::Name(exprs) => Ok(ProjectList::Name(
81            exprs
82                .into_iter()
83                .map(resolve_for_sub)
84                .collect::<Result<Vec<_>, _>>()?
85                .into_iter()
86                .flatten()
87                .collect(),
88        )),
89        ProjectList::List(exprs, fields) => Ok(ProjectList::List(
90            exprs
91                .into_iter()
92                .map(resolve_for_sql)
93                .collect::<Result<Vec<_>, _>>()?
94                .into_iter()
95                .flatten()
96                .collect(),
97            fields,
98        )),
99        ProjectList::Agg(exprs, AggType::Count, name, ty) => Ok(ProjectList::Agg(
100            exprs
101                .into_iter()
102                .map(resolve_for_sql)
103                .collect::<Result<Vec<_>, _>>()?
104                .into_iter()
105                .flatten()
106                .collect(),
107            AggType::Count,
108            name,
109            ty,
110        )),
111    }
112}
113
114/// A list for detecting cycles during RLS resolution.
115enum ResolveList {
116    None,
117    Some(TableId, Rc<ResolveList>),
118}
119
120impl ResolveList {
121    fn new(table_id: TableId, list: Rc<Self>) -> Rc<Self> {
122        Rc::new(Self::Some(table_id, list))
123    }
124
125    fn contains(&self, table_id: &TableId) -> bool {
126        match self {
127            Self::None => false,
128            Self::Some(id, suffix) if id != table_id => suffix.contains(table_id),
129            Self::Some(..) => true,
130        }
131    }
132}
133
134/// The main utility responsible for view resolution.
135///
136/// But what is view resolution and why do we need it?
137///
138/// A view is a named query that can be referenced as though it were just a regular table.
139/// In SpacetimeDB, Row Level Security (RLS) is implemented using views.
140/// We must resolve/expand these views in order to guarantee the correct access controls.
141///
142/// Before we discuss the implementation, a quick word on `return_table_id`.
143///
144/// Why do we care about it?
145/// What does it mean for it to be `None`?
146///
147/// If this IS NOT a user query, it must be a view definition.
148/// In SpacetimeDB this means we're expanding an RLS filter.
149/// RLS filters cannot be self-referential, meaning that within a filter,
150/// we cannot recursively expand references to its return table.
151///
152/// However, a `None` value implies that this expression is a user query,
153/// and so we should attempt to expand references to the return table.
154///
155/// Now back to the implementation.
156///
157/// Take the following join tree as an example:
158/// ```text
159///     x
160///    / \
161///   x   c
162///  / \
163/// a   b
164/// ```
165///
166/// Let's assume b is a view with the following structure:
167/// ```text
168///     x
169///    / \
170///   x   f
171///  / \
172/// d   e
173/// ```
174///
175/// Logically we just want to expand the tree like so:
176/// ```text
177///     x
178///    / \
179///   x   c
180///  / \
181/// a   x
182///    / \
183///   x   f
184///  / \
185/// d   e
186/// ```
187///
188/// However the join trees at this level are left deep.
189/// To maintain this invariant, the correct expansion would be:
190/// ```text
191///         x
192///        / \
193///       x   c
194///      / \
195///     x   f
196///    / \
197///   x   e
198///  / \
199/// a   d
200/// ```
201///
202/// That is, the subtree whose root is the left sibling of the node being expanded,
203/// i.e. the subtree rooted at `a` in the above example,
204/// must be pushed below the leftmost leaf node of the view expansion.
205fn resolve_views_for_expr(
206    tx: &impl SchemaView,
207    view: RelExpr,
208    return_table_id: Option<TableId>,
209    resolving: Rc<ResolveList>,
210    has_param: &mut bool,
211    suffix: &mut usize,
212    auth: &AuthCtx,
213) -> anyhow::Result<Vec<RelExpr>> {
214    let is_return_table = |relvar: &Relvar| return_table_id.is_some_and(|id| relvar.schema.table_id == id);
215
216    // Collect the table ids queried by this view.
217    // Ignore the id of the return table, since RLS views cannot be recursive.
218    let mut names = vec![];
219    view.visit(&mut |expr| match expr {
220        RelExpr::RelVar(rhs)
221        | RelExpr::LeftDeepJoin(LeftDeepJoin { rhs, .. })
222        | RelExpr::EqJoin(LeftDeepJoin { rhs, .. }, ..)
223            if !is_return_table(rhs) =>
224        {
225            names.push((rhs.schema.table_id, rhs.alias.clone()));
226        }
227        _ => {}
228    });
229
230    // Are we currently resolving any of them?
231    if let Some(table_id) = names
232        .iter()
233        .map(|(table_id, _)| table_id)
234        .find(|table_id| resolving.contains(table_id))
235    {
236        anyhow::bail!("Discovered cyclic dependency when resolving RLS rules for table id `{table_id}`");
237    }
238
239    let return_name = |expr: &ProjectName| {
240        expr.return_name()
241            .map(|name| name.to_owned())
242            .ok_or_else(|| anyhow::anyhow!("Could not resolve table reference in RLS filter"))
243    };
244
245    let mut view_def_fragments = vec![];
246
247    for (table_id, alias) in names {
248        let mut view_fragments = vec![];
249
250        for sql in tx.rls_rules_for_table(table_id)? {
251            // Parse and type check the RLS filter
252            let (expr, is_parameterized) = parse_and_type_sub(&sql, tx, auth)?;
253
254            // Are any of the RLS rules parameterized?
255            *has_param = *has_param || is_parameterized;
256
257            // We need to know which relvar is being returned for alpha-renaming
258            let return_name = return_name(&expr)?;
259
260            // Resolve views within the RLS filter itself
261            let fragments = resolve_views_for_expr(
262                tx,
263                expr.unwrap(),
264                Some(table_id),
265                ResolveList::new(table_id, resolving.clone()),
266                has_param,
267                suffix,
268                auth,
269            )?;
270
271            // Run alpha conversion on each view definition
272            alpha_rename_fragments(
273                // The revlar returned from the inner expression
274                &return_name,
275                // Its corresponding alias in the outer expression
276                &alias,
277                fragments,
278                &mut view_fragments,
279                suffix,
280            );
281        }
282
283        if !view_fragments.is_empty() {
284            view_def_fragments.push((table_id, alias, view_fragments));
285        }
286    }
287
288    /// After we collect all the necessary view definitions and run alpha conversion,
289    /// this function handles the actual replacement of the view with its definition.
290    fn expand_views(expr: RelExpr, view_def_fragments: &[(TableId, Box<str>, Vec<RelExpr>)], out: &mut Vec<RelExpr>) {
291        match view_def_fragments {
292            [] => out.push(expr),
293            [(table_id, alias, fragments), view_def_fragments @ ..] => {
294                for fragment in fragments {
295                    let expanded = expand_leaf(expr.clone(), *table_id, alias, fragment);
296                    expand_views(expanded, view_def_fragments, out);
297                }
298            }
299        }
300    }
301
302    let mut resolved = vec![];
303    expand_views(view, &view_def_fragments, &mut resolved);
304    Ok(resolved)
305}
306
307/// This is the main driver of alpha conversion.
308///
309/// For each expression that we alpha convert,
310/// we append a unique suffix to the names in that expression,
311/// with the one exception being the name of the return table.
312/// The return table is aliased in the outer expression,
313/// and so we use the same alias in the inner expression.
314///
315/// Ex.
316///
317/// Let `v` be a view defined as:
318/// ```sql
319/// SELECT r.* FROM r JOIN s ON r.id = s.id
320/// ```
321///
322/// Take the following user query:
323/// ```sql
324/// SELECT t.* FROM v JOIN t ON v.id = t.id WHERE v.x = 0
325/// ```
326///
327/// After alpha conversion, the expansion becomes:
328/// ```sql
329/// SELECT t.*
330/// FROM r AS v
331/// JOIN s AS s_1 ON v.id = s_1.id
332/// JOIN t AS t   ON t.id = v.id WHERE v.x = 0
333/// ```
334fn alpha_rename_fragments(
335    return_name: &str,
336    outer_alias: &str,
337    inputs: Vec<RelExpr>,
338    output: &mut Vec<RelExpr>,
339    suffix: &mut usize,
340) {
341    for mut fragment in inputs {
342        *suffix += 1;
343        alpha_rename(&mut fragment, &mut |name: &str| {
344            if name == return_name {
345                return outer_alias.to_owned().into_boxed_str();
346            }
347            (name.to_owned() + "_" + &suffix.to_string()).into_boxed_str()
348        });
349        output.push(fragment);
350    }
351}
352
353/// When expanding a view, we must do an alpha conversion on the view definition.
354/// This involves renaming the table aliases before replacing the view reference.
355fn alpha_rename(expr: &mut RelExpr, f: &mut impl FnMut(&str) -> Box<str>) {
356    /// Helper for renaming a relvar
357    fn rename(relvar: &mut Relvar, f: &mut impl FnMut(&str) -> Box<str>) {
358        relvar.alias = f(&relvar.alias);
359    }
360    /// Helper for renaming a field reference
361    fn rename_field(field: &mut FieldProject, f: &mut impl FnMut(&str) -> Box<str>) {
362        field.table = f(&field.table);
363    }
364    expr.visit_mut(&mut |expr| match expr {
365        RelExpr::RelVar(rhs) | RelExpr::LeftDeepJoin(LeftDeepJoin { rhs, .. }) => {
366            rename(rhs, f);
367        }
368        RelExpr::EqJoin(LeftDeepJoin { rhs, .. }, a, b) => {
369            rename(rhs, f);
370            rename_field(a, f);
371            rename_field(b, f);
372        }
373        RelExpr::Select(_, expr) => {
374            expr.visit_mut(&mut |expr| {
375                if let Expr::Field(field) = expr {
376                    rename_field(field, f);
377                }
378            });
379        }
380    });
381}
382
383/// Extends a left deep join tree with another.
384///
385/// Ex.
386///
387/// Assume `expr` is given by:
388/// ```text
389///     x
390///    / \
391///   x   f
392///  / \
393/// d   e
394/// ```
395///
396/// Assume `with` is given by:
397/// ```text
398///     x
399///    / \
400///   x   c
401///  / \
402/// a   b
403/// ```
404///
405/// This function extends `expr` by pushing `with` to the left-most leaf node:
406/// ```text
407///           x
408///          / \
409///         x   f
410///        / \
411///       x   e
412///      / \
413///     x   d
414///    / \
415///   x   c
416///  / \
417/// a   b
418/// ```
419fn extend_lhs(expr: RelExpr, with: RelExpr) -> RelExpr {
420    match expr {
421        RelExpr::RelVar(rhs) => RelExpr::LeftDeepJoin(LeftDeepJoin {
422            lhs: Box::new(with),
423            rhs,
424        }),
425        RelExpr::Select(input, expr) => RelExpr::Select(Box::new(extend_lhs(*input, with)), expr),
426        RelExpr::LeftDeepJoin(join) => RelExpr::LeftDeepJoin(LeftDeepJoin {
427            lhs: Box::new(extend_lhs(*join.lhs, with)),
428            ..join
429        }),
430        RelExpr::EqJoin(join, a, b) => RelExpr::EqJoin(
431            LeftDeepJoin {
432                lhs: Box::new(extend_lhs(*join.lhs, with)),
433                ..join
434            },
435            a,
436            b,
437        ),
438    }
439}
440
441/// Replaces the leaf node determined by `table_id` and `alias` with the subtree `with`.
442/// Ensures the expanded tree stays left deep.
443fn expand_leaf(expr: RelExpr, table_id: TableId, alias: &str, with: &RelExpr) -> RelExpr {
444    let ok = |relvar: &Relvar| relvar.schema.table_id == table_id && relvar.alias.as_ref() == alias;
445    match expr {
446        RelExpr::RelVar(relvar, ..) if ok(&relvar) => with.clone(),
447        RelExpr::RelVar(..) => expr,
448        RelExpr::Select(input, expr) => RelExpr::Select(Box::new(expand_leaf(*input, table_id, alias, with)), expr),
449        RelExpr::LeftDeepJoin(join) if ok(&join.rhs) => extend_lhs(with.clone(), *join.lhs),
450        RelExpr::LeftDeepJoin(LeftDeepJoin { lhs, rhs }) => RelExpr::LeftDeepJoin(LeftDeepJoin {
451            lhs: Box::new(expand_leaf(*lhs, table_id, alias, with)),
452            rhs,
453        }),
454        RelExpr::EqJoin(join, a, b) if ok(&join.rhs) => RelExpr::Select(
455            Box::new(extend_lhs(with.clone(), *join.lhs)),
456            Expr::BinOp(BinOp::Eq, Box::new(Expr::Field(a)), Box::new(Expr::Field(b))),
457        ),
458        RelExpr::EqJoin(LeftDeepJoin { lhs, rhs }, a, b) => RelExpr::EqJoin(
459            LeftDeepJoin {
460                lhs: Box::new(expand_leaf(*lhs, table_id, alias, with)),
461                rhs,
462            },
463            a,
464            b,
465        ),
466    }
467}
468
// Unit tests for RLS resolution over a small three-table schema.
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use pretty_assertions as pretty;

    use spacetimedb_lib::{identity::AuthCtx, AlgebraicType, AlgebraicValue, Identity, ProductType};
    use spacetimedb_primitives::TableId;
    use spacetimedb_schema::{
        def::ModuleDef,
        schema::{Schema, TableSchema},
    };
    use spacetimedb_sql_parser::ast::BinOp;

    use crate::{
        check::{parse_and_type_sub, test_utils::build_module_def, SchemaView},
        expr::{Expr, FieldProject, LeftDeepJoin, ProjectName, RelExpr, Relvar},
    };

    use super::resolve_views_for_sub;

    /// Test [SchemaView] backed by a [ModuleDef] with hard-coded table ids:
    /// `users` = 0, `admins` = 1, `player` = 2.
    pub struct SchemaViewer(pub ModuleDef);

    impl SchemaView for SchemaViewer {
        fn table_id(&self, name: &str) -> Option<TableId> {
            match name {
                "users" => Some(TableId(0)),
                "admins" => Some(TableId(1)),
                "player" => Some(TableId(2)),
                _ => None,
            }
        }

        fn schema_for_table(&self, table_id: TableId) -> Option<Arc<TableSchema>> {
            match table_id.idx() {
                0 => Some((TableId(0), "users")),
                1 => Some((TableId(1), "admins")),
                2 => Some((TableId(2), "player")),
                _ => None,
            }
            .and_then(|(table_id, name)| {
                self.0
                    .table(name)
                    .map(|def| Arc::new(TableSchema::from_module_def(&self.0, def, (), table_id)))
            })
        }

        // `users` and `admins` are each filtered to the caller's own row.
        // `player` has two rules, so resolving it yields two queries:
        // one joining `users` (aliased `u`) on `id`,
        // and one cross joining `admins`, which its own rule restricts to the caller.
        fn rls_rules_for_table(&self, table_id: TableId) -> anyhow::Result<Vec<Box<str>>> {
            match table_id {
                TableId(0) => Ok(vec!["select * from users where identity = :sender".into()]),
                TableId(1) => Ok(vec!["select * from admins where identity = :sender".into()]),
                TableId(2) => Ok(vec![
                    "select player.* from player join users u on player.id = u.id".into(),
                    "select player.* from player join admins".into(),
                ]),
                _ => Ok(vec![]),
            }
        }
    }

    /// Column layout: `users`/`admins` = (identity, id), `player` = (id, level_num).
    fn module_def() -> ModuleDef {
        build_module_def(vec![
            (
                "users",
                ProductType::from([("identity", AlgebraicType::identity()), ("id", AlgebraicType::U64)]),
            ),
            (
                "admins",
                ProductType::from([("identity", AlgebraicType::identity()), ("id", AlgebraicType::U64)]),
            ),
            (
                "player",
                ProductType::from([("id", AlgebraicType::U64), ("level_num", AlgebraicType::U64)]),
            ),
        ])
    }

    /// Parse, type check, and resolve RLS rules
    fn resolve(sql: &str, tx: &impl SchemaView, auth: &AuthCtx) -> anyhow::Result<Vec<ProjectName>> {
        let (expr, _) = parse_and_type_sub(sql, tx, auth)?;
        resolve_views_for_sub(tx, expr, auth, &mut false)
    }

    // Owner and caller are the same identity here,
    // so RLS is bypassed and the query comes back untouched (still `ProjectName::None`).
    #[test]
    fn test_rls_for_owner() -> anyhow::Result<()> {
        let tx = SchemaViewer(module_def());
        let auth = AuthCtx::new(Identity::ONE, Identity::ONE);
        let sql = "select * from users";
        let resolved = resolve(sql, &tx, &auth)?;

        let users_schema = tx.schema("users").unwrap();

        pretty::assert_eq!(
            resolved,
            vec![ProjectName::None(RelExpr::RelVar(Relvar {
                schema: users_schema,
                alias: "users".into(),
                delta: None,
            }))]
        );

        Ok(())
    }

    // The caller (`Identity::ONE`) is not the owner (`Identity::ZERO`),
    // so `users` is replaced by its rule with `:sender` bound to the caller.
    #[test]
    fn test_rls_for_non_owner() -> anyhow::Result<()> {
        let tx = SchemaViewer(module_def());
        let auth = AuthCtx::new(Identity::ZERO, Identity::ONE);
        let sql = "select * from users";
        let resolved = resolve(sql, &tx, &auth)?;

        let users_schema = tx.schema("users").unwrap();

        pretty::assert_eq!(
            resolved,
            vec![ProjectName::Some(
                RelExpr::Select(
                    Box::new(RelExpr::RelVar(Relvar {
                        schema: users_schema,
                        alias: "users".into(),
                        delta: None,
                    })),
                    Expr::BinOp(
                        BinOp::Eq,
                        Box::new(Expr::Field(FieldProject {
                            table: "users".into(),
                            field: 0,
                            ty: AlgebraicType::identity(),
                        })),
                        Box::new(Expr::Value(Identity::ONE.into(), AlgebraicType::identity()))
                    )
                ),
                "users".into()
            )]
        );

        Ok(())
    }

    // `player` has two rules, so one query is produced per rule.
    // The alias suffixes come from the shared alpha-conversion counter, bumped once per renamed fragment:
    // 1 = `users` rule inside the first `player` rule, 2 = first `player` rule (`u` -> `u_2`),
    // 3 = `admins` rule inside the second `player` rule, 4 = second `player` rule (`admins` -> `admins_4`).
    // The equi-join `player.id = u.id` becomes a filter once `u` is replaced by its rule.
    #[test]
    fn test_multiple_rls_rules_for_table() -> anyhow::Result<()> {
        let tx = SchemaViewer(module_def());
        let auth = AuthCtx::new(Identity::ZERO, Identity::ONE);
        let sql = "select * from player where level_num = 5";
        let resolved = resolve(sql, &tx, &auth)?;

        let users_schema = tx.schema("users").unwrap();
        let admins_schema = tx.schema("admins").unwrap();
        let player_schema = tx.schema("player").unwrap();

        pretty::assert_eq!(
            resolved,
            vec![
                // First rule: player joined with the caller's `users` row on `id`
                ProjectName::Some(
                    RelExpr::Select(
                        Box::new(RelExpr::Select(
                            Box::new(RelExpr::Select(
                                Box::new(RelExpr::LeftDeepJoin(LeftDeepJoin {
                                    lhs: Box::new(RelExpr::RelVar(Relvar {
                                        schema: player_schema.clone(),
                                        alias: "player".into(),
                                        delta: None,
                                    })),
                                    rhs: Relvar {
                                        schema: users_schema.clone(),
                                        alias: "u_2".into(),
                                        delta: None,
                                    },
                                })),
                                Expr::BinOp(
                                    BinOp::Eq,
                                    Box::new(Expr::Field(FieldProject {
                                        table: "u_2".into(),
                                        field: 0,
                                        ty: AlgebraicType::identity(),
                                    })),
                                    Box::new(Expr::Value(Identity::ONE.into(), AlgebraicType::identity())),
                                ),
                            )),
                            Expr::BinOp(
                                BinOp::Eq,
                                Box::new(Expr::Field(FieldProject {
                                    table: "player".into(),
                                    field: 0,
                                    ty: AlgebraicType::U64,
                                })),
                                Box::new(Expr::Field(FieldProject {
                                    table: "u_2".into(),
                                    field: 1,
                                    ty: AlgebraicType::U64,
                                })),
                            ),
                        )),
                        // The user's own filter stays outermost
                        Expr::BinOp(
                            BinOp::Eq,
                            Box::new(Expr::Field(FieldProject {
                                table: "player".into(),
                                field: 1,
                                ty: AlgebraicType::U64,
                            })),
                            Box::new(Expr::Value(AlgebraicValue::U64(5), AlgebraicType::U64)),
                        ),
                    ),
                    "player".into(),
                ),
                // Second rule: player cross joined with the caller's `admins` row
                ProjectName::Some(
                    RelExpr::Select(
                        Box::new(RelExpr::Select(
                            Box::new(RelExpr::LeftDeepJoin(LeftDeepJoin {
                                lhs: Box::new(RelExpr::RelVar(Relvar {
                                    schema: player_schema.clone(),
                                    alias: "player".into(),
                                    delta: None,
                                })),
                                rhs: Relvar {
                                    schema: admins_schema.clone(),
                                    alias: "admins_4".into(),
                                    delta: None,
                                },
                            })),
                            Expr::BinOp(
                                BinOp::Eq,
                                Box::new(Expr::Field(FieldProject {
                                    table: "admins_4".into(),
                                    field: 0,
                                    ty: AlgebraicType::identity(),
                                })),
                                Box::new(Expr::Value(Identity::ONE.into(), AlgebraicType::identity())),
                            ),
                        )),
                        Expr::BinOp(
                            BinOp::Eq,
                            Box::new(Expr::Field(FieldProject {
                                table: "player".into(),
                                field: 1,
                                ty: AlgebraicType::U64,
                            })),
                            Box::new(Expr::Value(AlgebraicValue::U64(5), AlgebraicType::U64)),
                        ),
                    ),
                    "player".into(),
                ),
            ]
        );

        Ok(())
    }
}