1use crate::error::{Result, WaypointError};
15
/// Maximum recursion depth the guard-expression parser descends to before it
/// rejects the input with "maximum nesting depth exceeded".
///
/// The counter is incremented on every recursive-descent call (OR, AND, NOT,
/// comparison, primary), not per source-level nesting level. So one pair of
/// parentheses consumes five levels, and each `NOT` in a chain consumes one.
const MAX_PARSE_DEPTH: usize = 50;
18
/// Policy selected when a `require` guard fails.
///
/// The variants are interpreted by callers outside this section. `Error` is the
/// default.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum OnRequireFail {
    /// Spelled `"error"`; the default policy.
    #[default]
    Error,
    /// Spelled `"warn"`.
    Warn,
    /// Spelled `"skip"`.
    Skip,
}

impl std::str::FromStr for OnRequireFail {
    type Err = String;

    /// Parses `"error"`, `"warn"` or `"skip"`, ignoring case.
    ///
    /// # Errors
    /// Returns a message quoting the input exactly as given when it matches none
    /// of the accepted spellings.
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "error" => Ok(Self::Error),
            "warn" => Ok(Self::Warn),
            "skip" => Ok(Self::Skip),
            // Report what the user actually wrote, not the lower-cased copy.
            _ => Err(format!("unknown on_require_fail value: '{s}'")),
        }
    }
}
47
48#[derive(Debug, Clone)]
50pub struct GuardsConfig {
51 pub enabled: bool,
53 pub on_require_fail: OnRequireFail,
55}
56
57impl Default for GuardsConfig {
58 fn default() -> Self {
59 Self {
60 enabled: true,
61 on_require_fail: OnRequireFail::default(),
62 }
63 }
64}
65
/// Relational operator used between two numeric guard operands.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ComparisonOp {
    /// `<`
    Lt,
    /// `>`
    Gt,
    /// `<=`
    Le,
    /// `>=`
    Ge,
}

impl std::fmt::Display for ComparisonOp {
    /// Renders the operator exactly as it is written in guard source.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let symbol = match self {
            Self::Lt => "<",
            Self::Gt => ">",
            Self::Le => "<=",
            Self::Ge => ">=",
        };
        f.write_str(symbol)
    }
}
93
/// Abstract syntax tree of a guard expression, as produced by `parse`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GuardExpr {
    /// Call to a built-in function such as `table_exists("users")`.
    FunctionCall {
        /// Function name exactly as written; matched case-sensitively at evaluation.
        name: String,
        /// Argument expressions. The parser accepts any expression here, but
        /// evaluation only accepts string literals.
        args: Vec<GuardExpr>,
    },
    /// `left AND right`; evaluation skips `right` when `left` is false.
    And(Box<GuardExpr>, Box<GuardExpr>),
    /// `left OR right`; evaluation skips `right` when `left` is true.
    Or(Box<GuardExpr>, Box<GuardExpr>),
    /// `NOT inner`; the operand must evaluate to a boolean.
    Not(Box<GuardExpr>),
    /// A single (non-chained) comparison; both operands must evaluate to numbers.
    Comparison {
        /// Left-hand operand.
        left: Box<GuardExpr>,
        /// Relational operator.
        op: ComparisonOp,
        /// Right-hand operand.
        right: Box<GuardExpr>,
    },
    /// Double-quoted string literal, without the quotes.
    StringLiteral(String),
    /// Integer literal; as produced by the tokenizer (digits only) it is never negative.
    NumberLiteral(i64),
    /// `true` / `false`, recognised case-insensitively by the tokenizer.
    BoolLiteral(bool),
}
126
/// Runtime value produced while evaluating a [`GuardExpr`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GuardValue {
    /// Boolean result (predicate function, literal, or logical operator).
    Bool(bool),
    /// Integer result (numeric literal or `row_count`).
    Number(i64),
    /// String literal value.
    Str(String),
}

impl std::fmt::Display for GuardValue {
    /// Formats the value for error messages; strings are wrapped in double quotes.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let rendered = match self {
            Self::Bool(flag) => flag.to_string(),
            Self::Number(value) => value.to_string(),
            Self::Str(text) => format!("\"{text}\""),
        };
        f.write_str(&rendered)
    }
}
147
/// Lexical token produced by `tokenize` and consumed by `Parser`.
#[derive(Debug, Clone, PartialEq, Eq)]
enum Token {
    /// Bare word: a function name, or the normalised `true` / `false`.
    Ident(String),
    /// Contents of a double-quoted literal, without the quotes.
    StringLit(String),
    /// Integer literal (digits only; there is no sign).
    NumberLit(i64),
    /// `AND` (any case).
    And,
    /// `OR` (any case).
    Or,
    /// `NOT` (any case).
    Not,
    /// `<`
    Lt,
    /// `>`
    Gt,
    /// `<=`
    Le,
    /// `>=`
    Ge,
    /// `(`
    LParen,
    /// `)`
    RParen,
    /// `,`
    Comma,
}

impl std::fmt::Display for Token {
    /// Renders the token for parse-error messages; keywords appear in upper case.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let fixed = match self {
            Token::Ident(name) => return write!(f, "{name}"),
            Token::StringLit(text) => return write!(f, "\"{text}\""),
            Token::NumberLit(value) => return write!(f, "{value}"),
            Token::And => "AND",
            Token::Or => "OR",
            Token::Not => "NOT",
            Token::Lt => "<",
            Token::Gt => ">",
            Token::Le => "<=",
            Token::Ge => ">=",
            Token::LParen => "(",
            Token::RParen => ")",
            Token::Comma => ",",
        };
        f.write_str(fixed)
    }
}
189
190fn tokenize(input: &str) -> Result<Vec<Token>> {
192 let mut tokens = Vec::new();
193 let chars: Vec<char> = input.chars().collect();
194 let len = chars.len();
195 let mut i = 0;
196
197 while i < len {
198 let ch = chars[i];
199
200 if ch.is_ascii_whitespace() {
202 i += 1;
203 continue;
204 }
205
206 if ch == '"' {
208 i += 1;
209 let start = i;
210 while i < len && chars[i] != '"' {
211 if chars[i] == '\\' && i + 1 < len {
212 i += 2; } else {
214 i += 1;
215 }
216 }
217 if i >= len {
218 return Err(WaypointError::ConfigError(
219 "Guard expression: unterminated string literal".to_string(),
220 ));
221 }
222 let s: String = chars[start..i].iter().collect();
223 tokens.push(Token::StringLit(s));
224 i += 1; continue;
226 }
227
228 if ch == '(' {
230 tokens.push(Token::LParen);
231 i += 1;
232 continue;
233 }
234 if ch == ')' {
235 tokens.push(Token::RParen);
236 i += 1;
237 continue;
238 }
239 if ch == ',' {
240 tokens.push(Token::Comma);
241 i += 1;
242 continue;
243 }
244
245 if ch == '<' {
247 if i + 1 < len && chars[i + 1] == '=' {
248 tokens.push(Token::Le);
249 i += 2;
250 } else {
251 tokens.push(Token::Lt);
252 i += 1;
253 }
254 continue;
255 }
256 if ch == '>' {
257 if i + 1 < len && chars[i + 1] == '=' {
258 tokens.push(Token::Ge);
259 i += 2;
260 } else {
261 tokens.push(Token::Gt);
262 i += 1;
263 }
264 continue;
265 }
266
267 if ch.is_ascii_digit() {
269 let start = i;
270 while i < len && chars[i].is_ascii_digit() {
271 i += 1;
272 }
273 let num_str: String = chars[start..i].iter().collect();
274 let n = num_str.parse::<i64>().map_err(|e| {
275 WaypointError::ConfigError(format!(
276 "Guard expression: invalid number '{num_str}': {e}"
277 ))
278 })?;
279 tokens.push(Token::NumberLit(n));
280 continue;
281 }
282
283 if ch.is_ascii_alphabetic() || ch == '_' {
285 let start = i;
286 while i < len && (chars[i].is_ascii_alphanumeric() || chars[i] == '_') {
287 i += 1;
288 }
289 let word: String = chars[start..i].iter().collect();
290 if word.eq_ignore_ascii_case("AND") {
291 tokens.push(Token::And);
292 } else if word.eq_ignore_ascii_case("OR") {
293 tokens.push(Token::Or);
294 } else if word.eq_ignore_ascii_case("NOT") {
295 tokens.push(Token::Not);
296 } else if word.eq_ignore_ascii_case("TRUE") {
297 tokens.push(Token::Ident("true".to_string()));
298 } else if word.eq_ignore_ascii_case("FALSE") {
299 tokens.push(Token::Ident("false".to_string()));
300 } else {
301 tokens.push(Token::Ident(word));
302 }
303 continue;
304 }
305
306 return Err(WaypointError::ConfigError(format!(
307 "Guard expression: unexpected character '{ch}'"
308 )));
309 }
310
311 Ok(tokens)
312}
313
314struct Parser {
320 tokens: Vec<Token>,
321 pos: usize,
322}
323
324impl Parser {
325 fn new(tokens: Vec<Token>) -> Self {
326 Self { tokens, pos: 0 }
327 }
328
329 fn peek(&self) -> Option<&Token> {
330 self.tokens.get(self.pos)
331 }
332
333 fn advance(&mut self) -> Option<Token> {
334 if self.pos < self.tokens.len() {
335 let tok = self.tokens[self.pos].clone();
336 self.pos += 1;
337 Some(tok)
338 } else {
339 None
340 }
341 }
342
343 fn expect(&mut self, expected: &Token) -> Result<()> {
344 match self.advance() {
345 Some(ref tok) if tok == expected => Ok(()),
346 Some(tok) => Err(WaypointError::ConfigError(format!(
347 "Guard expression: expected '{expected}', found '{tok}'"
348 ))),
349 None => Err(WaypointError::ConfigError(format!(
350 "Guard expression: expected '{expected}', found end of input"
351 ))),
352 }
353 }
354
355 fn parse_expr(&mut self, depth: usize) -> Result<GuardExpr> {
359 self.parse_or_expr(depth)
360 }
361
362 fn parse_or_expr(&mut self, depth: usize) -> Result<GuardExpr> {
364 if depth > MAX_PARSE_DEPTH {
365 return Err(WaypointError::ConfigError(
366 "Guard expression: maximum nesting depth exceeded".to_string(),
367 ));
368 }
369 let mut left = self.parse_and_expr(depth + 1)?;
370 while self.peek() == Some(&Token::Or) {
371 self.advance(); let right = self.parse_and_expr(depth + 1)?;
373 left = GuardExpr::Or(Box::new(left), Box::new(right));
374 }
375 Ok(left)
376 }
377
378 fn parse_and_expr(&mut self, depth: usize) -> Result<GuardExpr> {
380 if depth > MAX_PARSE_DEPTH {
381 return Err(WaypointError::ConfigError(
382 "Guard expression: maximum nesting depth exceeded".to_string(),
383 ));
384 }
385 let mut left = self.parse_not_expr(depth + 1)?;
386 while self.peek() == Some(&Token::And) {
387 self.advance(); let right = self.parse_not_expr(depth + 1)?;
389 left = GuardExpr::And(Box::new(left), Box::new(right));
390 }
391 Ok(left)
392 }
393
394 fn parse_not_expr(&mut self, depth: usize) -> Result<GuardExpr> {
396 if depth > MAX_PARSE_DEPTH {
397 return Err(WaypointError::ConfigError(
398 "Guard expression: maximum nesting depth exceeded".to_string(),
399 ));
400 }
401 if self.peek() == Some(&Token::Not) {
402 self.advance(); let inner = self.parse_not_expr(depth + 1)?;
404 Ok(GuardExpr::Not(Box::new(inner)))
405 } else {
406 self.parse_comparison(depth + 1)
407 }
408 }
409
410 fn parse_comparison(&mut self, depth: usize) -> Result<GuardExpr> {
412 if depth > MAX_PARSE_DEPTH {
413 return Err(WaypointError::ConfigError(
414 "Guard expression: maximum nesting depth exceeded".to_string(),
415 ));
416 }
417 let left = self.parse_primary(depth + 1)?;
418
419 let op = match self.peek() {
420 Some(Token::Lt) => Some(ComparisonOp::Lt),
421 Some(Token::Gt) => Some(ComparisonOp::Gt),
422 Some(Token::Le) => Some(ComparisonOp::Le),
423 Some(Token::Ge) => Some(ComparisonOp::Ge),
424 _ => None,
425 };
426
427 if let Some(op) = op {
428 self.advance(); let right = self.parse_primary(depth + 1)?;
430 Ok(GuardExpr::Comparison {
431 left: Box::new(left),
432 op,
433 right: Box::new(right),
434 })
435 } else {
436 Ok(left)
437 }
438 }
439
440 fn parse_primary(&mut self, depth: usize) -> Result<GuardExpr> {
442 if depth > MAX_PARSE_DEPTH {
443 return Err(WaypointError::ConfigError(
444 "Guard expression: maximum nesting depth exceeded".to_string(),
445 ));
446 }
447 match self.peek().cloned() {
448 Some(Token::Ident(name)) => {
449 if name == "true" {
451 self.advance();
452 return Ok(GuardExpr::BoolLiteral(true));
453 }
454 if name == "false" {
455 self.advance();
456 return Ok(GuardExpr::BoolLiteral(false));
457 }
458
459 if self.pos + 1 < self.tokens.len() && self.tokens[self.pos + 1] == Token::LParen {
461 self.advance(); self.advance(); let args = self.parse_args(depth + 1)?;
464 self.expect(&Token::RParen)?;
465 Ok(GuardExpr::FunctionCall { name, args })
466 } else {
467 Err(WaypointError::ConfigError(format!(
468 "Guard expression: unexpected identifier '{name}' (expected function call)"
469 )))
470 }
471 }
472 Some(Token::LParen) => {
473 self.advance(); let expr = self.parse_expr(depth + 1)?;
475 self.expect(&Token::RParen)?;
476 Ok(expr)
477 }
478 Some(Token::StringLit(s)) => {
479 self.advance();
480 Ok(GuardExpr::StringLiteral(s))
481 }
482 Some(Token::NumberLit(n)) => {
483 self.advance();
484 Ok(GuardExpr::NumberLiteral(n))
485 }
486 Some(tok) => Err(WaypointError::ConfigError(format!(
487 "Guard expression: unexpected token '{tok}'"
488 ))),
489 None => Err(WaypointError::ConfigError(
490 "Guard expression: unexpected end of input".to_string(),
491 )),
492 }
493 }
494
495 fn parse_args(&mut self, depth: usize) -> Result<Vec<GuardExpr>> {
497 let mut args = Vec::new();
498
499 if self.peek() == Some(&Token::RParen) {
501 return Ok(args);
502 }
503
504 args.push(self.parse_expr(depth)?);
505
506 while self.peek() == Some(&Token::Comma) {
507 self.advance(); args.push(self.parse_expr(depth)?);
509 }
510
511 Ok(args)
512 }
513}
514
515pub fn parse(input: &str) -> Result<GuardExpr> {
532 let tokens = tokenize(input)?;
533 if tokens.is_empty() {
534 return Err(WaypointError::ConfigError(
535 "Guard expression: empty expression".to_string(),
536 ));
537 }
538 let mut parser = Parser::new(tokens);
539 let expr = parser.parse_expr(0)?;
540
541 if parser.pos < parser.tokens.len() {
543 let remaining = &parser.tokens[parser.pos];
544 return Err(WaypointError::ConfigError(format!(
545 "Guard expression: unexpected token '{remaining}' after complete expression"
546 )));
547 }
548
549 Ok(expr)
550}
551
552fn builtin_sql(name: &str, args: &[String], schema: &str) -> Result<(String, Vec<String>, bool)> {
562 match name {
563 "table_exists" => {
564 require_args(name, args, 1)?;
565 let table = &args[0];
566 Ok((
567 "SELECT EXISTS(SELECT 1 FROM information_schema.tables \
568 WHERE table_schema = $1 AND table_name = $2)"
569 .to_string(),
570 vec![schema.to_string(), table.to_string()],
571 true,
572 ))
573 }
574 "column_exists" => {
575 require_args(name, args, 2)?;
576 let table = &args[0];
577 let column = &args[1];
578 Ok((
579 "SELECT EXISTS(SELECT 1 FROM information_schema.columns \
580 WHERE table_schema = $1 AND table_name = $2 \
581 AND column_name = $3)"
582 .to_string(),
583 vec![schema.to_string(), table.to_string(), column.to_string()],
584 true,
585 ))
586 }
587 "column_type" => {
588 require_args(name, args, 3)?;
589 let table = &args[0];
590 let column = &args[1];
591 let expected_type = &args[2];
592 Ok((
593 "SELECT EXISTS(SELECT 1 FROM information_schema.columns \
594 WHERE table_schema = $1 AND table_name = $2 \
595 AND column_name = $3 AND data_type = $4)"
596 .to_string(),
597 vec![
598 schema.to_string(),
599 table.to_string(),
600 column.to_string(),
601 expected_type.to_string(),
602 ],
603 true,
604 ))
605 }
606 "column_nullable" => {
607 require_args(name, args, 2)?;
608 let table = &args[0];
609 let column = &args[1];
610 Ok((
611 "SELECT EXISTS(SELECT 1 FROM information_schema.columns \
612 WHERE table_schema = $1 AND table_name = $2 \
613 AND column_name = $3 AND is_nullable = 'YES')"
614 .to_string(),
615 vec![schema.to_string(), table.to_string(), column.to_string()],
616 true,
617 ))
618 }
619 "index_exists" => {
620 require_args(name, args, 1)?;
621 let index = &args[0];
622 Ok((
623 "SELECT EXISTS(SELECT 1 FROM pg_indexes \
624 WHERE schemaname = $1 AND indexname = $2)"
625 .to_string(),
626 vec![schema.to_string(), index.to_string()],
627 true,
628 ))
629 }
630 "constraint_exists" => {
631 require_args(name, args, 2)?;
632 let table = &args[0];
633 let constraint = &args[1];
634 Ok((
635 "SELECT EXISTS(SELECT 1 FROM information_schema.table_constraints \
636 WHERE table_schema = $1 AND table_name = $2 \
637 AND constraint_name = $3)"
638 .to_string(),
639 vec![
640 schema.to_string(),
641 table.to_string(),
642 constraint.to_string(),
643 ],
644 true,
645 ))
646 }
647 "function_exists" => {
648 require_args(name, args, 1)?;
649 let func = &args[0];
650 Ok((
651 "SELECT EXISTS(SELECT 1 FROM pg_proc p \
652 JOIN pg_namespace n ON n.oid = p.pronamespace \
653 WHERE n.nspname = $1 AND p.proname = $2)"
654 .to_string(),
655 vec![schema.to_string(), func.to_string()],
656 true,
657 ))
658 }
659 "enum_exists" => {
660 require_args(name, args, 1)?;
661 let enum_name = &args[0];
662 Ok((
663 "SELECT EXISTS(SELECT 1 FROM pg_type t \
664 JOIN pg_namespace n ON n.oid = t.typnamespace \
665 WHERE n.nspname = $1 AND t.typname = $2 \
666 AND t.typtype = 'e')"
667 .to_string(),
668 vec![schema.to_string(), enum_name.to_string()],
669 true,
670 ))
671 }
672 "row_count" => {
673 require_args(name, args, 1)?;
674 let table = &args[0];
675 Ok((
676 "SELECT COALESCE(n_live_tup, 0)::bigint FROM pg_stat_user_tables \
677 WHERE schemaname = $1 AND relname = $2"
678 .to_string(),
679 vec![schema.to_string(), table.to_string()],
680 false,
681 ))
682 }
683 "sql" => {
684 require_args(name, args, 1)?;
685 let query = &args[0];
686 Ok((query.to_string(), vec![], true))
687 }
688 _ => Err(WaypointError::ConfigError(format!(
689 "Guard expression: unknown function '{name}'"
690 ))),
691 }
692}
693
694fn require_args(name: &str, args: &[String], expected: usize) -> Result<()> {
696 if args.len() != expected {
697 return Err(WaypointError::ConfigError(format!(
698 "Guard expression: {name}() expects {expected} argument(s), got {}",
699 args.len()
700 )));
701 }
702 Ok(())
703}
704
705fn extract_string_args(args: &[GuardExpr]) -> Result<Vec<String>> {
714 let mut result = Vec::with_capacity(args.len());
715 for arg in args {
716 match arg {
717 GuardExpr::StringLiteral(s) => result.push(s.clone()),
718 other => {
719 return Err(WaypointError::ConfigError(format!(
720 "Guard expression: expected string argument, found {other:?}"
721 )));
722 }
723 }
724 }
725 Ok(result)
726}
727
728pub async fn evaluate(
738 client: &tokio_postgres::Client,
739 schema: &str,
740 expr: &GuardExpr,
741) -> Result<bool> {
742 let value = eval_expr(client, schema, expr).await?;
743 match value {
744 GuardValue::Bool(b) => Ok(b),
745 other => Err(WaypointError::ConfigError(format!(
746 "Guard expression: expected boolean result, got {other}"
747 ))),
748 }
749}
750
751fn eval_expr<'a>(
753 client: &'a tokio_postgres::Client,
754 schema: &'a str,
755 expr: &'a GuardExpr,
756) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<GuardValue>> + Send + 'a>> {
757 Box::pin(async move {
758 match expr {
759 GuardExpr::BoolLiteral(b) => Ok(GuardValue::Bool(*b)),
760 GuardExpr::NumberLiteral(n) => Ok(GuardValue::Number(*n)),
761 GuardExpr::StringLiteral(s) => Ok(GuardValue::Str(s.clone())),
762
763 GuardExpr::Not(inner) => {
764 let val = eval_expr(client, schema, inner).await?;
765 match val {
766 GuardValue::Bool(b) => Ok(GuardValue::Bool(!b)),
767 other => Err(WaypointError::ConfigError(format!(
768 "Guard expression: NOT requires boolean, got {other}"
769 ))),
770 }
771 }
772
773 GuardExpr::And(left, right) => {
774 let lval = eval_expr(client, schema, left).await?;
775 match lval {
776 GuardValue::Bool(false) => Ok(GuardValue::Bool(false)),
777 GuardValue::Bool(true) => {
778 let rval = eval_expr(client, schema, right).await?;
779 match rval {
780 GuardValue::Bool(b) => Ok(GuardValue::Bool(b)),
781 other => Err(WaypointError::ConfigError(format!(
782 "Guard expression: AND requires boolean operands, got {other}"
783 ))),
784 }
785 }
786 other => Err(WaypointError::ConfigError(format!(
787 "Guard expression: AND requires boolean operands, got {other}"
788 ))),
789 }
790 }
791
792 GuardExpr::Or(left, right) => {
793 let lval = eval_expr(client, schema, left).await?;
794 match lval {
795 GuardValue::Bool(true) => Ok(GuardValue::Bool(true)),
796 GuardValue::Bool(false) => {
797 let rval = eval_expr(client, schema, right).await?;
798 match rval {
799 GuardValue::Bool(b) => Ok(GuardValue::Bool(b)),
800 other => Err(WaypointError::ConfigError(format!(
801 "Guard expression: OR requires boolean operands, got {other}"
802 ))),
803 }
804 }
805 other => Err(WaypointError::ConfigError(format!(
806 "Guard expression: OR requires boolean operands, got {other}"
807 ))),
808 }
809 }
810
811 GuardExpr::Comparison { left, op, right } => {
812 let lval = eval_expr(client, schema, left).await?;
813 let rval = eval_expr(client, schema, right).await?;
814 match (&lval, &rval) {
815 (GuardValue::Number(a), GuardValue::Number(b)) => {
816 let result = match op {
817 ComparisonOp::Lt => a < b,
818 ComparisonOp::Gt => a > b,
819 ComparisonOp::Le => a <= b,
820 ComparisonOp::Ge => a >= b,
821 };
822 Ok(GuardValue::Bool(result))
823 }
824 _ => Err(WaypointError::ConfigError(format!(
825 "Guard expression: comparison requires numeric operands, got {lval} {op} {rval}"
826 ))),
827 }
828 }
829
830 GuardExpr::FunctionCall { name, args } => {
831 let string_args = extract_string_args(args)?;
832 let (sql, param_values, is_boolean) = builtin_sql(name, &string_args, schema)?;
833 let params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = param_values
834 .iter()
835 .map(|s| s as &(dyn tokio_postgres::types::ToSql + Sync))
836 .collect();
837
838 let row = client.query_one(&sql, ¶ms).await.map_err(|e| {
839 WaypointError::GuardFailed {
840 kind: "evaluation".to_string(),
841 script: String::new(),
842 expression: format!(
843 "{name}({}) failed: {e}",
844 string_args
845 .iter()
846 .map(|a| format!("\"{a}\""))
847 .collect::<Vec<_>>()
848 .join(", ")
849 ),
850 }
851 })?;
852
853 if is_boolean {
854 let val: bool = row.get(0);
855 Ok(GuardValue::Bool(val))
856 } else {
857 let val: i64 = row.get(0);
858 Ok(GuardValue::Number(val))
859 }
860 }
861 }
862 })
863}
864
865#[cfg(test)]
870mod tests {
871 use super::*;
872
873 #[test]
874 fn test_parse_simple_function_call() {
875 let expr = parse("table_exists(\"users\")").unwrap();
876 match expr {
877 GuardExpr::FunctionCall { name, args } => {
878 assert_eq!(name, "table_exists");
879 assert_eq!(args.len(), 1);
880 assert_eq!(args[0], GuardExpr::StringLiteral("users".to_string()));
881 }
882 other => panic!("Expected FunctionCall, got {other:?}"),
883 }
884 }
885
886 #[test]
887 fn test_parse_function_with_multiple_args() {
888 let expr = parse("column_exists(\"users\", \"email\")").unwrap();
889 match expr {
890 GuardExpr::FunctionCall { name, args } => {
891 assert_eq!(name, "column_exists");
892 assert_eq!(args.len(), 2);
893 assert_eq!(args[0], GuardExpr::StringLiteral("users".to_string()));
894 assert_eq!(args[1], GuardExpr::StringLiteral("email".to_string()));
895 }
896 other => panic!("Expected FunctionCall, got {other:?}"),
897 }
898 }
899
900 #[test]
901 fn test_parse_function_with_three_args() {
902 let expr = parse("column_type(\"users\", \"age\", \"integer\")").unwrap();
903 match expr {
904 GuardExpr::FunctionCall { name, args } => {
905 assert_eq!(name, "column_type");
906 assert_eq!(args.len(), 3);
907 assert_eq!(args[0], GuardExpr::StringLiteral("users".to_string()));
908 assert_eq!(args[1], GuardExpr::StringLiteral("age".to_string()));
909 assert_eq!(args[2], GuardExpr::StringLiteral("integer".to_string()));
910 }
911 other => panic!("Expected FunctionCall, got {other:?}"),
912 }
913 }
914
915 #[test]
916 fn test_parse_and_expression() {
917 let expr =
918 parse("table_exists(\"users\") AND column_exists(\"users\", \"email\")").unwrap();
919 match expr {
920 GuardExpr::And(left, right) => {
921 match *left {
922 GuardExpr::FunctionCall { ref name, .. } => assert_eq!(name, "table_exists"),
923 ref other => panic!("Expected FunctionCall on left, got {other:?}"),
924 }
925 match *right {
926 GuardExpr::FunctionCall { ref name, .. } => {
927 assert_eq!(name, "column_exists")
928 }
929 ref other => panic!("Expected FunctionCall on right, got {other:?}"),
930 }
931 }
932 other => panic!("Expected And, got {other:?}"),
933 }
934 }
935
936 #[test]
937 fn test_parse_or_expression() {
938 let expr = parse("table_exists(\"users\") OR table_exists(\"accounts\")").unwrap();
939 match expr {
940 GuardExpr::Or(left, right) => {
941 match *left {
942 GuardExpr::FunctionCall { ref name, ref args } => {
943 assert_eq!(name, "table_exists");
944 assert_eq!(args[0], GuardExpr::StringLiteral("users".to_string()));
945 }
946 ref other => panic!("Expected FunctionCall on left, got {other:?}"),
947 }
948 match *right {
949 GuardExpr::FunctionCall { ref name, ref args } => {
950 assert_eq!(name, "table_exists");
951 assert_eq!(args[0], GuardExpr::StringLiteral("accounts".to_string()));
952 }
953 ref other => panic!("Expected FunctionCall on right, got {other:?}"),
954 }
955 }
956 other => panic!("Expected Or, got {other:?}"),
957 }
958 }
959
960 #[test]
961 fn test_parse_not_expression() {
962 let expr = parse("NOT table_exists(\"legacy\")").unwrap();
963 match expr {
964 GuardExpr::Not(inner) => match *inner {
965 GuardExpr::FunctionCall { ref name, .. } => assert_eq!(name, "table_exists"),
966 ref other => panic!("Expected FunctionCall inside NOT, got {other:?}"),
967 },
968 other => panic!("Expected Not, got {other:?}"),
969 }
970 }
971
972 #[test]
973 fn test_parse_double_not() {
974 let expr = parse("NOT NOT table_exists(\"t\")").unwrap();
975 match expr {
976 GuardExpr::Not(inner) => match *inner {
977 GuardExpr::Not(inner2) => match *inner2 {
978 GuardExpr::FunctionCall { ref name, .. } => {
979 assert_eq!(name, "table_exists")
980 }
981 ref other => panic!("Expected FunctionCall, got {other:?}"),
982 },
983 ref other => panic!("Expected Not, got {other:?}"),
984 },
985 other => panic!("Expected Not, got {other:?}"),
986 }
987 }
988
989 #[test]
990 fn test_parse_nested_parentheses() {
991 let expr =
992 parse("(table_exists(\"a\") AND table_exists(\"b\")) OR table_exists(\"c\")").unwrap();
993 match expr {
994 GuardExpr::Or(left, right) => {
995 match *left {
996 GuardExpr::And(_, _) => {} ref other => panic!("Expected And on left, got {other:?}"),
998 }
999 match *right {
1000 GuardExpr::FunctionCall { ref name, .. } => {
1001 assert_eq!(name, "table_exists")
1002 }
1003 ref other => panic!("Expected FunctionCall on right, got {other:?}"),
1004 }
1005 }
1006 other => panic!("Expected Or, got {other:?}"),
1007 }
1008 }
1009
1010 #[test]
1011 fn test_parse_deeply_nested_parentheses() {
1012 let expr = parse("((table_exists(\"a\")))").unwrap();
1013 match expr {
1014 GuardExpr::FunctionCall { ref name, .. } => assert_eq!(name, "table_exists"),
1015 other => panic!("Expected FunctionCall, got {other:?}"),
1016 }
1017 }
1018
1019 #[test]
1020 fn test_parse_comparison_less_than() {
1021 let expr = parse("row_count(\"users\") < 1000").unwrap();
1022 match expr {
1023 GuardExpr::Comparison { left, op, right } => {
1024 match *left {
1025 GuardExpr::FunctionCall { ref name, .. } => assert_eq!(name, "row_count"),
1026 ref other => panic!("Expected FunctionCall on left, got {other:?}"),
1027 }
1028 assert_eq!(op, ComparisonOp::Lt);
1029 assert_eq!(*right, GuardExpr::NumberLiteral(1000));
1030 }
1031 other => panic!("Expected Comparison, got {other:?}"),
1032 }
1033 }
1034
1035 #[test]
1036 fn test_parse_comparison_greater_than() {
1037 let expr = parse("row_count(\"orders\") > 0").unwrap();
1038 match expr {
1039 GuardExpr::Comparison { op, .. } => assert_eq!(op, ComparisonOp::Gt),
1040 other => panic!("Expected Comparison, got {other:?}"),
1041 }
1042 }
1043
1044 #[test]
1045 fn test_parse_comparison_le_ge() {
1046 let expr = parse("row_count(\"t\") <= 500").unwrap();
1047 match expr {
1048 GuardExpr::Comparison { op, .. } => assert_eq!(op, ComparisonOp::Le),
1049 other => panic!("Expected Comparison, got {other:?}"),
1050 }
1051
1052 let expr = parse("row_count(\"t\") >= 10").unwrap();
1053 match expr {
1054 GuardExpr::Comparison { op, .. } => assert_eq!(op, ComparisonOp::Ge),
1055 other => panic!("Expected Comparison, got {other:?}"),
1056 }
1057 }
1058
1059 #[test]
1060 fn test_parse_error_empty() {
1061 let result = parse("");
1062 assert!(result.is_err());
1063 let err = result.unwrap_err().to_string();
1064 assert!(err.contains("empty expression"), "got: {err}");
1065 }
1066
1067 #[test]
1068 fn test_parse_error_unterminated_string() {
1069 let result = parse("table_exists(\"users)");
1070 assert!(result.is_err());
1071 let err = result.unwrap_err().to_string();
1072 assert!(err.contains("unterminated string"), "got: {err}");
1073 }
1074
1075 #[test]
1076 fn test_parse_error_unexpected_token() {
1077 let result = parse("AND");
1078 assert!(result.is_err());
1079 }
1080
1081 #[test]
1082 fn test_parse_error_missing_closing_paren() {
1083 let result = parse("table_exists(\"users\"");
1084 assert!(result.is_err());
1085 let err = result.unwrap_err().to_string();
1086 assert!(err.contains("expected ')'"), "got: {err}");
1087 }
1088
1089 #[test]
1090 fn test_parse_error_trailing_tokens() {
1091 let result = parse("table_exists(\"users\") table_exists(\"orders\")");
1092 assert!(result.is_err());
1093 let err = result.unwrap_err().to_string();
1094 assert!(err.contains("unexpected"), "got: {err}");
1095 }
1096
1097 #[test]
1098 fn test_parse_error_unexpected_character() {
1099 let result = parse("table_exists(\"users\") @ foo");
1100 assert!(result.is_err());
1101 let err = result.unwrap_err().to_string();
1102 assert!(err.contains("unexpected character"), "got: {err}");
1103 }
1104
1105 #[test]
1106 fn test_parse_complex_expression() {
1107 let input = "(table_exists(\"users\") AND NOT column_exists(\"users\", \"deleted_at\")) \
1110 OR (enum_exists(\"status\") AND row_count(\"users\") < 10000)";
1111 let expr = parse(input).unwrap();
1112 match expr {
1113 GuardExpr::Or(left, right) => {
1114 match *left {
1116 GuardExpr::And(ref a, ref b) => {
1117 match **a {
1118 GuardExpr::FunctionCall { ref name, .. } => {
1119 assert_eq!(name, "table_exists")
1120 }
1121 ref other => panic!("Expected FunctionCall, got {other:?}"),
1122 }
1123 match **b {
1124 GuardExpr::Not(ref inner) => match **inner {
1125 GuardExpr::FunctionCall { ref name, .. } => {
1126 assert_eq!(name, "column_exists")
1127 }
1128 ref other => panic!("Expected FunctionCall, got {other:?}"),
1129 },
1130 ref other => panic!("Expected Not, got {other:?}"),
1131 }
1132 }
1133 ref other => panic!("Expected And, got {other:?}"),
1134 }
1135 match *right {
1137 GuardExpr::And(ref a, ref b) => {
1138 match **a {
1139 GuardExpr::FunctionCall { ref name, .. } => {
1140 assert_eq!(name, "enum_exists")
1141 }
1142 ref other => panic!("Expected FunctionCall, got {other:?}"),
1143 }
1144 match **b {
1145 GuardExpr::Comparison {
1146 ref op, ref right, ..
1147 } => {
1148 assert_eq!(*op, ComparisonOp::Lt);
1149 assert_eq!(**right, GuardExpr::NumberLiteral(10000));
1150 }
1151 ref other => panic!("Expected Comparison, got {other:?}"),
1152 }
1153 }
1154 ref other => panic!("Expected And, got {other:?}"),
1155 }
1156 }
1157 other => panic!("Expected Or, got {other:?}"),
1158 }
1159 }
1160
1161 #[test]
1162 fn test_parse_and_or_precedence() {
1163 let expr =
1166 parse("table_exists(\"a\") OR table_exists(\"b\") AND table_exists(\"c\")").unwrap();
1167 match expr {
1168 GuardExpr::Or(left, right) => {
1169 match *left {
1170 GuardExpr::FunctionCall { ref name, .. } => assert_eq!(name, "table_exists"),
1171 ref other => panic!("Expected FunctionCall, got {other:?}"),
1172 }
1173 match *right {
1174 GuardExpr::And(_, _) => {} ref other => panic!("Expected And on right, got {other:?}"),
1176 }
1177 }
1178 other => panic!("Expected Or, got {other:?}"),
1179 }
1180 }
1181
1182 #[test]
1183 fn test_parse_chained_and() {
1184 let expr =
1185 parse("table_exists(\"a\") AND table_exists(\"b\") AND table_exists(\"c\")").unwrap();
1186 match expr {
1188 GuardExpr::And(left, right) => {
1189 match *left {
1190 GuardExpr::And(_, _) => {} ref other => panic!("Expected And on left (left-assoc), got {other:?}"),
1192 }
1193 match *right {
1194 GuardExpr::FunctionCall { ref name, .. } => assert_eq!(name, "table_exists"),
1195 ref other => panic!("Expected FunctionCall on right, got {other:?}"),
1196 }
1197 }
1198 other => panic!("Expected And, got {other:?}"),
1199 }
1200 }
1201
1202 #[test]
1203 fn test_parse_sql_function() {
1204 let expr = parse("sql(\"SELECT true\")").unwrap();
1205 match expr {
1206 GuardExpr::FunctionCall { name, args } => {
1207 assert_eq!(name, "sql");
1208 assert_eq!(args.len(), 1);
1209 assert_eq!(args[0], GuardExpr::StringLiteral("SELECT true".to_string()));
1210 }
1211 other => panic!("Expected FunctionCall, got {other:?}"),
1212 }
1213 }
1214
1215 #[test]
1216 fn test_parse_not_with_parentheses() {
1217 let expr = parse("NOT (table_exists(\"a\") OR table_exists(\"b\"))").unwrap();
1218 match expr {
1219 GuardExpr::Not(inner) => match *inner {
1220 GuardExpr::Or(_, _) => {} ref other => panic!("Expected Or inside Not, got {other:?}"),
1222 },
1223 other => panic!("Expected Not, got {other:?}"),
1224 }
1225 }
1226
1227 #[test]
1228 fn test_parse_boolean_literals() {
1229 let expr = parse("true").unwrap();
1230 assert_eq!(expr, GuardExpr::BoolLiteral(true));
1231
1232 let expr = parse("false").unwrap();
1233 assert_eq!(expr, GuardExpr::BoolLiteral(false));
1234 }
1235
1236 #[test]
1237 fn test_tokenize_all_operators() {
1238 let tokens = tokenize("< > <= >= AND OR NOT ( ) ,").unwrap();
1239 assert_eq!(
1240 tokens,
1241 vec![
1242 Token::Lt,
1243 Token::Gt,
1244 Token::Le,
1245 Token::Ge,
1246 Token::And,
1247 Token::Or,
1248 Token::Not,
1249 Token::LParen,
1250 Token::RParen,
1251 Token::Comma,
1252 ]
1253 );
1254 }
1255
1256 #[test]
1257 fn test_builtin_sql_table_exists() {
1258 let (sql, params, is_bool) =
1259 builtin_sql("table_exists", &["users".to_string()], "public").unwrap();
1260 assert!(is_bool);
1261 assert!(sql.contains("information_schema.tables"));
1262 assert!(sql.contains("$1"));
1263 assert!(sql.contains("$2"));
1264 assert_eq!(params, vec!["public", "users"]);
1265 }
1266
1267 #[test]
1268 fn test_builtin_sql_column_exists() {
1269 let (sql, params, is_bool) = builtin_sql(
1270 "column_exists",
1271 &["users".to_string(), "email".to_string()],
1272 "public",
1273 )
1274 .unwrap();
1275 assert!(is_bool);
1276 assert!(sql.contains("information_schema.columns"));
1277 assert!(sql.contains("$3"));
1278 assert_eq!(params, vec!["public", "users", "email"]);
1279 }
1280
#[test]
fn test_builtin_sql_row_count() {
    // row_count yields a number (not a boolean) read from pg_stat_user_tables.
    let table = ["users".to_string()];
    let (query, bind_values, returns_bool) =
        builtin_sql("row_count", &table, "public").unwrap();
    assert!(!returns_bool);
    for fragment in ["pg_stat_user_tables", "n_live_tup"] {
        assert!(query.contains(fragment));
    }
    assert_eq!(bind_values, ["public", "users"]);
}
1290
#[test]
fn test_builtin_sql_unknown_function() {
    // Names outside the builtin set are rejected with a descriptive error.
    let message = builtin_sql("unknown_fn", &[], "public")
        .unwrap_err()
        .to_string();
    assert!(message.contains("unknown function"), "got: {message}");
}
1298
#[test]
fn test_builtin_sql_wrong_arg_count() {
    // Calling table_exists with no arguments reports the expected arity.
    let message = builtin_sql("table_exists", &[], "public")
        .unwrap_err()
        .to_string();
    assert!(message.contains("expects 1 argument"), "got: {message}");
}
1306
#[test]
fn test_parse_depth_limit() {
    // 100 stacked NOT operators nest deeper than MAX_PARSE_DEPTH (50).
    let deeply_nested = format!("{}true", "NOT ".repeat(100));
    let message = parse(&deeply_nested).unwrap_err().to_string();
    assert!(
        message.contains("maximum nesting depth exceeded"),
        "got: {message}"
    );
}
1320
#[test]
fn test_builtin_sql_column_type() {
    // The expected type is the fourth bound value, after schema, table and column.
    let arguments: Vec<String> = ["users", "age", "integer"]
        .iter()
        .map(|part| part.to_string())
        .collect();
    let (query, bind_values, returns_bool) =
        builtin_sql("column_type", &arguments, "myschema").unwrap();
    assert!(returns_bool);
    assert!(query.contains("data_type = $4"));
    assert_eq!(bind_values, ["myschema", "users", "age", "integer"]);
}
1337
#[test]
fn test_builtin_sql_column_nullable() {
    // Nullability is checked against information_schema's 'YES' marker.
    let arguments = ["users".to_string(), "name".to_string()];
    let (query, bind_values, returns_bool) =
        builtin_sql("column_nullable", &arguments, "public").unwrap();
    assert!(returns_bool);
    assert!(query.contains("is_nullable = 'YES'"));
    assert_eq!(bind_values, ["public", "users", "name"]);
}
1350
#[test]
fn test_builtin_sql_index_exists() {
    // index_exists looks the index name up in pg_indexes, bound as $2.
    let index = ["idx_users_email".to_string()];
    let (query, bind_values, returns_bool) =
        builtin_sql("index_exists", &index, "public").unwrap();
    assert!(returns_bool);
    for fragment in ["pg_indexes", "$2"] {
        assert!(query.contains(fragment));
    }
    assert_eq!(bind_values, ["public", "idx_users_email"]);
}
1360
#[test]
fn test_builtin_sql_constraint_exists() {
    // constraint_exists binds schema, table and constraint name ($1..$3).
    let arguments = ["users".to_string(), "users_pkey".to_string()];
    let (query, bind_values, returns_bool) =
        builtin_sql("constraint_exists", &arguments, "public").unwrap();
    assert!(returns_bool);
    for fragment in ["table_constraints", "$3"] {
        assert!(query.contains(fragment));
    }
    assert_eq!(bind_values, ["public", "users", "users_pkey"]);
}
1374
#[test]
fn test_builtin_sql_function_exists() {
    // function_exists joins pg_proc with pg_namespace; the function name is $2.
    let function = ["my_func".to_string()];
    let (query, bind_values, returns_bool) =
        builtin_sql("function_exists", &function, "public").unwrap();
    assert!(returns_bool);
    for fragment in ["pg_proc", "pg_namespace", "$2"] {
        assert!(query.contains(fragment));
    }
    assert_eq!(bind_values, ["public", "my_func"]);
}
1385
#[test]
fn test_builtin_sql_enum_exists() {
    // enum_exists filters pg_type on typtype = 'e'; the type name is $2.
    let type_name = ["status_type".to_string()];
    let (query, bind_values, returns_bool) =
        builtin_sql("enum_exists", &type_name, "public").unwrap();
    assert!(returns_bool);
    for fragment in ["pg_type", "typtype = 'e'", "$2"] {
        assert!(query.contains(fragment));
    }
    assert_eq!(bind_values, ["public", "status_type"]);
}
1396
#[test]
fn test_builtin_sql_custom_sql() {
    // The sql() builtin passes its argument through verbatim with no bindings.
    let statement = "SELECT count(*) = 0 FROM old_table";
    let (query, bind_values, returns_bool) =
        builtin_sql("sql", &[statement.to_string()], "public").unwrap();
    assert!(returns_bool);
    assert_eq!(query, statement);
    assert!(bind_values.is_empty());
}
1409
#[test]
fn test_builtin_sql_params_order_table_exists() {
    // The schema is always bound first, followed by the table name.
    let (query, bind_values, returns_bool) =
        builtin_sql("table_exists", &["users".to_string()], "myschema").unwrap();
    assert!(returns_bool);
    let expected = ["myschema", "users"];
    assert_eq!(bind_values.len(), expected.len());
    for (position, want) in expected.iter().enumerate() {
        assert_eq!(bind_values[position], *want);
    }
    for placeholder in ["$1", "$2"] {
        assert!(query.contains(placeholder));
    }
}
1421
#[test]
fn test_builtin_sql_sql_function_empty_params() {
    // Raw SQL produces no bound parameters and is returned unchanged.
    let outcome = builtin_sql("sql", &["SELECT 1".to_string()], "public");
    let (query, bind_values, returns_bool) = outcome.unwrap();
    assert!(returns_bool);
    assert!(bind_values.is_empty());
    assert_eq!(query, "SELECT 1");
}
1430
#[test]
fn test_parse_empty_function_args() {
    // An empty argument list still parses into a FunctionCall node.
    let expected = GuardExpr::FunctionCall {
        name: "table_exists".to_string(),
        args: Vec::new(),
    };
    assert_eq!(parse("table_exists()").unwrap(), expected);
}
1443}