1use std::rc::Rc;
2
3use spacetimedb_lib::identity::AuthCtx;
4use spacetimedb_primitives::TableId;
5use spacetimedb_sql_parser::ast::BinOp;
6
7use crate::{
8 check::{parse_and_type_sub, SchemaView},
9 expr::{AggType, Expr, FieldProject, LeftDeepJoin, ProjectList, ProjectName, RelExpr, Relvar},
10};
11
12pub fn resolve_views_for_sub(
15 tx: &impl SchemaView,
16 expr: ProjectName,
17 auth: &AuthCtx,
18 has_param: &mut bool,
19) -> anyhow::Result<Vec<ProjectName>> {
20 if auth.is_owner() {
22 return Ok(vec![expr]);
23 }
24
25 let Some(return_name) = expr.return_name().map(|name| name.to_owned().into_boxed_str()) else {
26 anyhow::bail!("Could not determine return type during RLS resolution")
27 };
28
29 let expr = match expr {
31 ProjectName::None(expr) | ProjectName::Some(expr, _) => expr,
32 };
33
34 resolve_views_for_expr(
35 tx,
36 expr,
37 None,
39 Rc::new(ResolveList::None),
41 has_param,
42 &mut 0,
43 auth,
44 )
45 .map(|fragments| {
46 fragments
47 .into_iter()
48 .map(|expr| ProjectName::Some(expr, return_name.clone()))
51 .collect()
52 })
53}
54
55pub fn resolve_views_for_sql(tx: &impl SchemaView, expr: ProjectList, auth: &AuthCtx) -> anyhow::Result<ProjectList> {
58 if auth.is_owner() {
60 return Ok(expr);
61 }
62 let resolve_for_sub = |expr| resolve_views_for_sub(tx, expr, auth, &mut false);
66 let resolve_for_sql = |expr| {
67 resolve_views_for_expr(
68 tx,
70 expr,
71 None,
72 Rc::new(ResolveList::None),
73 &mut false,
74 &mut 0,
75 auth,
76 )
77 };
78 match expr {
79 ProjectList::Limit(expr, n) => Ok(ProjectList::Limit(Box::new(resolve_views_for_sql(tx, *expr, auth)?), n)),
80 ProjectList::Name(exprs) => Ok(ProjectList::Name(
81 exprs
82 .into_iter()
83 .map(resolve_for_sub)
84 .collect::<Result<Vec<_>, _>>()?
85 .into_iter()
86 .flatten()
87 .collect(),
88 )),
89 ProjectList::List(exprs, fields) => Ok(ProjectList::List(
90 exprs
91 .into_iter()
92 .map(resolve_for_sql)
93 .collect::<Result<Vec<_>, _>>()?
94 .into_iter()
95 .flatten()
96 .collect(),
97 fields,
98 )),
99 ProjectList::Agg(exprs, AggType::Count, name, ty) => Ok(ProjectList::Agg(
100 exprs
101 .into_iter()
102 .map(resolve_for_sql)
103 .collect::<Result<Vec<_>, _>>()?
104 .into_iter()
105 .flatten()
106 .collect(),
107 AggType::Count,
108 name,
109 ty,
110 )),
111 }
112}
113
114enum ResolveList {
116 None,
117 Some(TableId, Rc<ResolveList>),
118}
119
120impl ResolveList {
121 fn new(table_id: TableId, list: Rc<Self>) -> Rc<Self> {
122 Rc::new(Self::Some(table_id, list))
123 }
124
125 fn contains(&self, table_id: &TableId) -> bool {
126 match self {
127 Self::None => false,
128 Self::Some(id, suffix) if id != table_id => suffix.contains(table_id),
129 Self::Some(..) => true,
130 }
131 }
132}
133
134fn resolve_views_for_expr(
206 tx: &impl SchemaView,
207 view: RelExpr,
208 return_table_id: Option<TableId>,
209 resolving: Rc<ResolveList>,
210 has_param: &mut bool,
211 suffix: &mut usize,
212 auth: &AuthCtx,
213) -> anyhow::Result<Vec<RelExpr>> {
214 let is_return_table = |relvar: &Relvar| return_table_id.is_some_and(|id| relvar.schema.table_id == id);
215
216 let mut names = vec![];
219 view.visit(&mut |expr| match expr {
220 RelExpr::RelVar(rhs)
221 | RelExpr::LeftDeepJoin(LeftDeepJoin { rhs, .. })
222 | RelExpr::EqJoin(LeftDeepJoin { rhs, .. }, ..)
223 if !is_return_table(rhs) =>
224 {
225 names.push((rhs.schema.table_id, rhs.alias.clone()));
226 }
227 _ => {}
228 });
229
230 if let Some(table_id) = names
232 .iter()
233 .map(|(table_id, _)| table_id)
234 .find(|table_id| resolving.contains(table_id))
235 {
236 anyhow::bail!("Discovered cyclic dependency when resolving RLS rules for table id `{table_id}`");
237 }
238
239 let return_name = |expr: &ProjectName| {
240 expr.return_name()
241 .map(|name| name.to_owned())
242 .ok_or_else(|| anyhow::anyhow!("Could not resolve table reference in RLS filter"))
243 };
244
245 let mut view_def_fragments = vec![];
246
247 for (table_id, alias) in names {
248 let mut view_fragments = vec![];
249
250 for sql in tx.rls_rules_for_table(table_id)? {
251 let (expr, is_parameterized) = parse_and_type_sub(&sql, tx, auth)?;
253
254 *has_param = *has_param || is_parameterized;
256
257 let return_name = return_name(&expr)?;
259
260 let fragments = resolve_views_for_expr(
262 tx,
263 expr.unwrap(),
264 Some(table_id),
265 ResolveList::new(table_id, resolving.clone()),
266 has_param,
267 suffix,
268 auth,
269 )?;
270
271 alpha_rename_fragments(
273 &return_name,
275 &alias,
277 fragments,
278 &mut view_fragments,
279 suffix,
280 );
281 }
282
283 if !view_fragments.is_empty() {
284 view_def_fragments.push((table_id, alias, view_fragments));
285 }
286 }
287
288 fn expand_views(expr: RelExpr, view_def_fragments: &[(TableId, Box<str>, Vec<RelExpr>)], out: &mut Vec<RelExpr>) {
291 match view_def_fragments {
292 [] => out.push(expr),
293 [(table_id, alias, fragments), view_def_fragments @ ..] => {
294 for fragment in fragments {
295 let expanded = expand_leaf(expr.clone(), *table_id, alias, fragment);
296 expand_views(expanded, view_def_fragments, out);
297 }
298 }
299 }
300 }
301
302 let mut resolved = vec![];
303 expand_views(view, &view_def_fragments, &mut resolved);
304 Ok(resolved)
305}
306
307fn alpha_rename_fragments(
335 return_name: &str,
336 outer_alias: &str,
337 inputs: Vec<RelExpr>,
338 output: &mut Vec<RelExpr>,
339 suffix: &mut usize,
340) {
341 for mut fragment in inputs {
342 *suffix += 1;
343 alpha_rename(&mut fragment, &mut |name: &str| {
344 if name == return_name {
345 return outer_alias.to_owned().into_boxed_str();
346 }
347 (name.to_owned() + "_" + &suffix.to_string()).into_boxed_str()
348 });
349 output.push(fragment);
350 }
351}
352
353fn alpha_rename(expr: &mut RelExpr, f: &mut impl FnMut(&str) -> Box<str>) {
356 fn rename(relvar: &mut Relvar, f: &mut impl FnMut(&str) -> Box<str>) {
358 relvar.alias = f(&relvar.alias);
359 }
360 fn rename_field(field: &mut FieldProject, f: &mut impl FnMut(&str) -> Box<str>) {
362 field.table = f(&field.table);
363 }
364 expr.visit_mut(&mut |expr| match expr {
365 RelExpr::RelVar(rhs) | RelExpr::LeftDeepJoin(LeftDeepJoin { rhs, .. }) => {
366 rename(rhs, f);
367 }
368 RelExpr::EqJoin(LeftDeepJoin { rhs, .. }, a, b) => {
369 rename(rhs, f);
370 rename_field(a, f);
371 rename_field(b, f);
372 }
373 RelExpr::Select(_, expr) => {
374 expr.visit_mut(&mut |expr| {
375 if let Expr::Field(field) = expr {
376 rename_field(field, f);
377 }
378 });
379 }
380 });
381}
382
383fn extend_lhs(expr: RelExpr, with: RelExpr) -> RelExpr {
420 match expr {
421 RelExpr::RelVar(rhs) => RelExpr::LeftDeepJoin(LeftDeepJoin {
422 lhs: Box::new(with),
423 rhs,
424 }),
425 RelExpr::Select(input, expr) => RelExpr::Select(Box::new(extend_lhs(*input, with)), expr),
426 RelExpr::LeftDeepJoin(join) => RelExpr::LeftDeepJoin(LeftDeepJoin {
427 lhs: Box::new(extend_lhs(*join.lhs, with)),
428 ..join
429 }),
430 RelExpr::EqJoin(join, a, b) => RelExpr::EqJoin(
431 LeftDeepJoin {
432 lhs: Box::new(extend_lhs(*join.lhs, with)),
433 ..join
434 },
435 a,
436 b,
437 ),
438 }
439}
440
441fn expand_leaf(expr: RelExpr, table_id: TableId, alias: &str, with: &RelExpr) -> RelExpr {
444 let ok = |relvar: &Relvar| relvar.schema.table_id == table_id && relvar.alias.as_ref() == alias;
445 match expr {
446 RelExpr::RelVar(relvar, ..) if ok(&relvar) => with.clone(),
447 RelExpr::RelVar(..) => expr,
448 RelExpr::Select(input, expr) => RelExpr::Select(Box::new(expand_leaf(*input, table_id, alias, with)), expr),
449 RelExpr::LeftDeepJoin(join) if ok(&join.rhs) => extend_lhs(with.clone(), *join.lhs),
450 RelExpr::LeftDeepJoin(LeftDeepJoin { lhs, rhs }) => RelExpr::LeftDeepJoin(LeftDeepJoin {
451 lhs: Box::new(expand_leaf(*lhs, table_id, alias, with)),
452 rhs,
453 }),
454 RelExpr::EqJoin(join, a, b) if ok(&join.rhs) => RelExpr::Select(
455 Box::new(extend_lhs(with.clone(), *join.lhs)),
456 Expr::BinOp(BinOp::Eq, Box::new(Expr::Field(a)), Box::new(Expr::Field(b))),
457 ),
458 RelExpr::EqJoin(LeftDeepJoin { lhs, rhs }, a, b) => RelExpr::EqJoin(
459 LeftDeepJoin {
460 lhs: Box::new(expand_leaf(*lhs, table_id, alias, with)),
461 rhs,
462 },
463 a,
464 b,
465 ),
466 }
467}
468
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use pretty_assertions as pretty;

    use spacetimedb_lib::{identity::AuthCtx, AlgebraicType, AlgebraicValue, Identity, ProductType};
    use spacetimedb_primitives::TableId;
    use spacetimedb_schema::{
        def::ModuleDef,
        schema::{Schema, TableSchema},
    };
    use spacetimedb_sql_parser::ast::BinOp;

    use crate::{
        check::{parse_and_type_sub, test_utils::build_module_def, SchemaView},
        expr::{Expr, FieldProject, LeftDeepJoin, ProjectName, RelExpr, Relvar},
    };

    use super::resolve_views_for_sub;

    /// A [`SchemaView`] over a fixed module of three tables with hard-coded RLS rules:
    ///
    /// * `users` (id 0) and `admins` (id 1) each have one rule filtering on `identity = :sender`;
    /// * `player` (id 2) has two rules, one joining `users` and one joining `admins`.
    pub struct SchemaViewer(pub ModuleDef);

    impl SchemaView for SchemaViewer {
        fn table_id(&self, name: &str) -> Option<TableId> {
            match name {
                "users" => Some(TableId(0)),
                "admins" => Some(TableId(1)),
                "player" => Some(TableId(2)),
                _ => None,
            }
        }

        fn schema_for_table(&self, table_id: TableId) -> Option<Arc<TableSchema>> {
            // Map the id back to its table name, then build the schema from the module def.
            match table_id.idx() {
                0 => Some((TableId(0), "users")),
                1 => Some((TableId(1), "admins")),
                2 => Some((TableId(2), "player")),
                _ => None,
            }
            .and_then(|(table_id, name)| {
                self.0
                    .table(name)
                    .map(|def| Arc::new(TableSchema::from_module_def(&self.0, def, (), table_id)))
            })
        }

        fn rls_rules_for_table(&self, table_id: TableId) -> anyhow::Result<Vec<Box<str>>> {
            match table_id {
                TableId(0) => Ok(vec!["select * from users where identity = :sender".into()]),
                TableId(1) => Ok(vec!["select * from admins where identity = :sender".into()]),
                TableId(2) => Ok(vec![
                    "select player.* from player join users u on player.id = u.id".into(),
                    "select player.* from player join admins".into(),
                ]),
                _ => Ok(vec![]),
            }
        }
    }

    /// Module definition backing [`SchemaViewer`]: the `users`, `admins` and `player` tables.
    fn module_def() -> ModuleDef {
        build_module_def(vec![
            (
                "users",
                ProductType::from([("identity", AlgebraicType::identity()), ("id", AlgebraicType::U64)]),
            ),
            (
                "admins",
                ProductType::from([("identity", AlgebraicType::identity()), ("id", AlgebraicType::U64)]),
            ),
            (
                "player",
                ProductType::from([("id", AlgebraicType::U64), ("level_num", AlgebraicType::U64)]),
            ),
        ])
    }

    /// Type-checks `sql` as a subscription query and resolves its RLS rules.
    fn resolve(sql: &str, tx: &impl SchemaView, auth: &AuthCtx) -> anyhow::Result<Vec<ProjectName>> {
        let (expr, _) = parse_and_type_sub(sql, tx, auth)?;
        resolve_views_for_sub(tx, expr, auth, &mut false)
    }

    /// The owner bypasses RLS: the type-checked query comes back unchanged.
    #[test]
    fn test_rls_for_owner() -> anyhow::Result<()> {
        let tx = SchemaViewer(module_def());
        // Owner and caller are the same identity.
        let auth = AuthCtx::new(Identity::ONE, Identity::ONE);
        let sql = "select * from users";
        let resolved = resolve(sql, &tx, &auth)?;

        let users_schema = tx.schema("users").unwrap();

        pretty::assert_eq!(
            resolved,
            vec![ProjectName::None(RelExpr::RelVar(Relvar {
                schema: users_schema,
                alias: "users".into(),
                delta: None,
            }))]
        );

        Ok(())
    }

    /// A non-owner gets `users` filtered by the `users` RLS rule, with
    /// `:sender` bound to the caller's identity.
    #[test]
    fn test_rls_for_non_owner() -> anyhow::Result<()> {
        let tx = SchemaViewer(module_def());
        // Owner `ZERO`, caller `ONE`; the expected filter below compares against `ONE`.
        let auth = AuthCtx::new(Identity::ZERO, Identity::ONE);
        let sql = "select * from users";
        let resolved = resolve(sql, &tx, &auth)?;

        let users_schema = tx.schema("users").unwrap();

        pretty::assert_eq!(
            resolved,
            vec![ProjectName::Some(
                RelExpr::Select(
                    Box::new(RelExpr::RelVar(Relvar {
                        schema: users_schema,
                        alias: "users".into(),
                        delta: None,
                    })),
                    Expr::BinOp(
                        BinOp::Eq,
                        Box::new(Expr::Field(FieldProject {
                            table: "users".into(),
                            field: 0,
                            ty: AlgebraicType::identity(),
                        })),
                        Box::new(Expr::Value(Identity::ONE.into(), AlgebraicType::identity()))
                    )
                ),
                "users".into()
            )]
        );

        Ok(())
    }

    /// `player` has two RLS rules, so the query expands into one fragment per rule,
    /// each with the nested `users`/`admins` rules expanded as well.
    #[test]
    fn test_multiple_rls_rules_for_table() -> anyhow::Result<()> {
        let tx = SchemaViewer(module_def());
        let auth = AuthCtx::new(Identity::ZERO, Identity::ONE);
        let sql = "select * from player where level_num = 5";
        let resolved = resolve(sql, &tx, &auth)?;

        let users_schema = tx.schema("users").unwrap();
        let admins_schema = tx.schema("admins").unwrap();
        let player_schema = tx.schema("player").unwrap();

        // The alias-suffix counter is bumped once per renamed fragment:
        // 1 for the nested `users` rule, 2 for the first `player` rule (`u` -> `u_2`),
        // 3 for the nested `admins` rule, 4 for the second `player` rule (`admins` -> `admins_4`).
        pretty::assert_eq!(
            resolved,
            vec![
                // Rule `player join users u on player.id = u.id`, plus the `users` rule on `u_2`.
                ProjectName::Some(
                    RelExpr::Select(
                        Box::new(RelExpr::Select(
                            Box::new(RelExpr::Select(
                                Box::new(RelExpr::LeftDeepJoin(LeftDeepJoin {
                                    lhs: Box::new(RelExpr::RelVar(Relvar {
                                        schema: player_schema.clone(),
                                        alias: "player".into(),
                                        delta: None,
                                    })),
                                    rhs: Relvar {
                                        schema: users_schema.clone(),
                                        alias: "u_2".into(),
                                        delta: None,
                                    },
                                })),
                                Expr::BinOp(
                                    BinOp::Eq,
                                    Box::new(Expr::Field(FieldProject {
                                        table: "u_2".into(),
                                        field: 0,
                                        ty: AlgebraicType::identity(),
                                    })),
                                    Box::new(Expr::Value(Identity::ONE.into(), AlgebraicType::identity())),
                                ),
                            )),
                            Expr::BinOp(
                                BinOp::Eq,
                                Box::new(Expr::Field(FieldProject {
                                    table: "player".into(),
                                    field: 0,
                                    ty: AlgebraicType::U64,
                                })),
                                Box::new(Expr::Field(FieldProject {
                                    table: "u_2".into(),
                                    field: 1,
                                    ty: AlgebraicType::U64,
                                })),
                            ),
                        )),
                        Expr::BinOp(
                            BinOp::Eq,
                            Box::new(Expr::Field(FieldProject {
                                table: "player".into(),
                                field: 1,
                                ty: AlgebraicType::U64,
                            })),
                            Box::new(Expr::Value(AlgebraicValue::U64(5), AlgebraicType::U64)),
                        ),
                    ),
                    "player".into(),
                ),
                // Rule `player join admins`, plus the `admins` rule on `admins_4`.
                ProjectName::Some(
                    RelExpr::Select(
                        Box::new(RelExpr::Select(
                            Box::new(RelExpr::LeftDeepJoin(LeftDeepJoin {
                                lhs: Box::new(RelExpr::RelVar(Relvar {
                                    schema: player_schema.clone(),
                                    alias: "player".into(),
                                    delta: None,
                                })),
                                rhs: Relvar {
                                    schema: admins_schema.clone(),
                                    alias: "admins_4".into(),
                                    delta: None,
                                },
                            })),
                            Expr::BinOp(
                                BinOp::Eq,
                                Box::new(Expr::Field(FieldProject {
                                    table: "admins_4".into(),
                                    field: 0,
                                    ty: AlgebraicType::identity(),
                                })),
                                Box::new(Expr::Value(Identity::ONE.into(), AlgebraicType::identity())),
                            ),
                        )),
                        Expr::BinOp(
                            BinOp::Eq,
                            Box::new(Expr::Field(FieldProject {
                                table: "player".into(),
                                field: 1,
                                ty: AlgebraicType::U64,
                            })),
                            Box::new(Expr::Value(AlgebraicValue::U64(5), AlgebraicType::U64)),
                        ),
                    ),
                    "player".into(),
                ),
            ]
        );

        Ok(())
    }
}