1use xlog_core::ScalarType;
4
/// A term appearing in an atom: a variable, a constant, a structured value,
/// a predicate reference, or an aggregate.
#[derive(Clone, Debug, PartialEq)]
pub enum Term {
    /// A named logic variable such as `X`.
    Variable(String),
    /// The anonymous variable; it is never reported by [`Term::variables`].
    Anonymous,
    /// A signed 64-bit integer constant.
    Integer(i64),
    /// A floating-point constant.
    Float(f64),
    /// A string constant.
    String(String),
    /// A symbol constant identified by a numeric id (presumably an interned
    /// symbol table index — confirm).
    Symbol(u32),
    /// A list literal whose elements are kept in order.
    List(Vec<Term>),
    /// A list cell made of a first element and the rest of the list.
    Cons {
        /// The first element.
        head: Box<Term>,
        /// The remaining list term.
        tail: Box<Term>,
    },
    /// A compound term `functor(args...)`.
    Compound {
        /// The functor name.
        functor: String,
        /// The arguments, in positional order.
        args: Vec<Term>,
    },
    /// A reference to a predicate by name.
    PredRef(String),
    /// An aggregate expression (rule heads are scanned for these by
    /// `Rule::has_aggregation`).
    Aggregate(AggExpr),
}

impl Term {
    /// Returns `true` only for a named variable; the anonymous variable is
    /// excluded.
    pub fn is_variable(&self) -> bool {
        self.variable_name().is_some()
    }

    /// Returns `true` only for the anonymous variable.
    pub fn is_anonymous(&self) -> bool {
        matches!(self, Self::Anonymous)
    }

    /// Returns `true` for either a named or the anonymous variable.
    pub fn is_any_variable(&self) -> bool {
        self.is_variable() || self.is_anonymous()
    }

    /// Returns `true` for atomic constants: integers, floats, strings and
    /// symbols. Variables, structured terms, predicate references and
    /// aggregates are not constants.
    pub fn is_constant(&self) -> bool {
        // Spelled out exhaustively so a new variant forces a decision here.
        match self {
            Self::Integer(_) | Self::Float(_) | Self::String(_) | Self::Symbol(_) => true,
            Self::Variable(_)
            | Self::Anonymous
            | Self::List(_)
            | Self::Cons { .. }
            | Self::Compound { .. }
            | Self::PredRef(_)
            | Self::Aggregate(_) => false,
        }
    }

    /// Returns the variable's name when `self` is a named variable.
    pub fn variable_name(&self) -> Option<&str> {
        if let Self::Variable(name) = self {
            Some(name.as_str())
        } else {
            None
        }
    }

    /// Collects every named variable in the term, recursing into lists,
    /// cons cells and compound arguments.
    ///
    /// The order is left to right and duplicates are kept. The anonymous
    /// variable, constants, predicate references and aggregates contribute
    /// nothing.
    pub fn variables(&self) -> Vec<&str> {
        let mut found = Vec::new();
        self.push_variables(&mut found);
        found
    }

    /// Appends the named variables of `self` to `out`, in left-to-right order.
    fn push_variables<'a>(&'a self, out: &mut Vec<&'a str>) {
        match self {
            Self::Variable(name) => out.push(name.as_str()),
            Self::List(items) | Self::Compound { args: items, .. } => {
                for item in items {
                    item.push_variables(out);
                }
            }
            Self::Cons { head, tail } => {
                head.push_variables(out);
                tail.push_variables(out);
            }
            Self::Anonymous
            | Self::Integer(_)
            | Self::Float(_)
            | Self::String(_)
            | Self::Symbol(_)
            | Self::PredRef(_)
            | Self::Aggregate(_) => {}
        }
    }
}

/// An aggregate: an operator applied to a named variable.
#[derive(Clone, Debug, PartialEq)]
pub struct AggExpr {
    /// The aggregation operator.
    pub op: AggOp,
    /// The name of the variable being aggregated.
    pub variable: String,
}

/// The operators available to an [`AggExpr`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum AggOp {
    /// Count aggregation.
    Count,
    /// Sum aggregation.
    Sum,
    /// Minimum aggregation.
    Min,
    /// Maximum aggregation.
    Max,
    /// Log-sum-exp aggregation (presumably `ln(Σ exp(x))` — confirm against
    /// the evaluator).
    LogSumExp,
}
124
125#[derive(Debug, Clone, PartialEq)]
127pub enum ArithExpr {
128 Variable(String),
130 Integer(i64),
132 Float(f64),
134
135 Add(Box<ArithExpr>, Box<ArithExpr>),
137 Sub(Box<ArithExpr>, Box<ArithExpr>),
139 Mul(Box<ArithExpr>, Box<ArithExpr>),
141 Div(Box<ArithExpr>, Box<ArithExpr>),
143 Mod(Box<ArithExpr>, Box<ArithExpr>),
145
146 Abs(Box<ArithExpr>),
148 Min(Box<ArithExpr>, Box<ArithExpr>),
150 Max(Box<ArithExpr>, Box<ArithExpr>),
152 Pow(Box<ArithExpr>, Box<ArithExpr>),
154
155 Cast(Box<ArithExpr>, ScalarType),
157
158 FuncCall {
160 name: String,
162 args: Vec<ArithExpr>,
164 },
165
166 Conditional {
168 cond_left: Box<ArithExpr>,
170 cond_op: CompOp,
172 cond_right: Box<ArithExpr>,
174 then_expr: Box<ArithExpr>,
176 else_expr: Box<ArithExpr>,
178 },
179}
180
181impl ArithExpr {
182 pub fn variables(&self) -> Vec<&str> {
184 match self {
185 ArithExpr::Variable(name) => vec![name.as_str()],
186 ArithExpr::Integer(_) | ArithExpr::Float(_) => vec![],
187 ArithExpr::Add(l, r)
188 | ArithExpr::Sub(l, r)
189 | ArithExpr::Mul(l, r)
190 | ArithExpr::Div(l, r)
191 | ArithExpr::Mod(l, r)
192 | ArithExpr::Min(l, r)
193 | ArithExpr::Max(l, r)
194 | ArithExpr::Pow(l, r) => {
195 let mut vars = l.variables();
196 vars.extend(r.variables());
197 vars
198 }
199 ArithExpr::Abs(e) | ArithExpr::Cast(e, _) => e.variables(),
200 ArithExpr::FuncCall { args, .. } => args.iter().flat_map(|a| a.variables()).collect(),
201 ArithExpr::Conditional {
202 cond_left,
203 cond_right,
204 then_expr,
205 else_expr,
206 ..
207 } => {
208 let mut vars = cond_left.variables();
209 vars.extend(cond_right.variables());
210 vars.extend(then_expr.variables());
211 vars.extend(else_expr.variables());
212 vars
213 }
214 }
215 }
216}
217
218#[derive(Debug, Clone, PartialEq)]
220pub struct IsExpr {
221 pub target: String,
223 pub expr: ArithExpr,
225}
226
227#[derive(Debug, Clone, PartialEq)]
229pub struct Atom {
230 pub predicate: String,
232 pub terms: Vec<Term>,
234}
235
236impl Atom {
237 pub fn arity(&self) -> usize {
239 self.terms.len()
240 }
241
242 pub fn variables(&self) -> Vec<&str> {
244 self.terms.iter().flat_map(Term::variables).collect()
245 }
246}
247
/// The modal operator of an [`EpistemicLiteral`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum EpistemicOp {
    /// The "know" modality (meaning inferred from the name — confirm).
    Know,
    /// The "possible" modality (meaning inferred from the name — confirm).
    Possible,
}
256
257#[derive(Debug, Clone, PartialEq)]
259pub struct EpistemicLiteral {
260 pub op: EpistemicOp,
262 pub negated: bool,
264 pub atom: Atom,
266}
267
/// Comparison operators used by [`Comparison`] and by conditional
/// expressions.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CompOp {
    /// Equal.
    Eq,
    /// Not equal.
    Ne,
    /// Less than.
    Lt,
    /// Less than or equal.
    Le,
    /// Greater than.
    Gt,
    /// Greater than or equal.
    Ge,
}
284
285#[derive(Debug, Clone, PartialEq)]
287pub struct Comparison {
288 pub left: Term,
290 pub op: CompOp,
292 pub right: Term,
294}
295
296#[derive(Debug, Clone, PartialEq)]
298pub struct Univ {
299 pub term: Term,
301 pub parts: Term,
303}
304
305#[derive(Debug, Clone, PartialEq)]
307pub enum BodyLiteral {
308 Positive(Atom),
310 Negated(Atom),
312 Epistemic(EpistemicLiteral),
314 Comparison(Comparison),
316 IsExpr(IsExpr),
318 Univ(Univ),
320}
321
322impl BodyLiteral {
323 pub fn is_positive(&self) -> bool {
325 matches!(self, BodyLiteral::Positive(_))
326 }
327
328 pub fn is_negated(&self) -> bool {
330 matches!(self, BodyLiteral::Negated(_))
331 }
332
333 pub fn atom(&self) -> Option<&Atom> {
335 match self {
336 BodyLiteral::Positive(a) | BodyLiteral::Negated(a) => Some(a),
337 BodyLiteral::Epistemic(lit) => Some(&lit.atom),
338 BodyLiteral::Comparison(_) | BodyLiteral::IsExpr(_) | BodyLiteral::Univ(_) => None,
339 }
340 }
341
342 pub fn variables(&self) -> Vec<&str> {
344 match self {
345 BodyLiteral::Positive(a) | BodyLiteral::Negated(a) => a.variables(),
346 BodyLiteral::Epistemic(lit) => lit.atom.variables(),
347 BodyLiteral::Comparison(c) => {
348 let mut vars = vec![];
349 vars.extend(c.left.variables());
350 vars.extend(c.right.variables());
351 vars
352 }
353 BodyLiteral::IsExpr(is_expr) => {
354 let mut vars = is_expr.expr.variables();
355 vars.push(is_expr.target.as_str());
356 vars
357 }
358 BodyLiteral::Univ(univ) => {
359 let mut vars = univ.term.variables();
360 vars.extend(univ.parts.variables());
361 vars
362 }
363 }
364 }
365}
366
367#[derive(Debug, Clone, PartialEq)]
369pub struct Rule {
370 pub head: Atom,
372 pub body: Vec<BodyLiteral>,
374}
375
376impl Rule {
377 pub fn is_fact(&self) -> bool {
379 self.body.is_empty()
380 }
381
382 pub fn has_negation(&self) -> bool {
384 self.body.iter().any(|l| l.is_negated())
385 }
386
387 pub fn has_aggregation(&self) -> bool {
389 self.head
390 .terms
391 .iter()
392 .any(|t| matches!(t, Term::Aggregate(_)))
393 }
394
395 pub fn body_predicates(&self) -> Vec<&str> {
397 self.body
398 .iter()
399 .filter_map(|l| l.atom().map(|a| a.predicate.as_str()))
400 .collect()
401 }
402
403 pub fn head_variables(&self) -> Vec<&str> {
405 self.head.variables()
406 }
407
408 pub fn body_variables(&self) -> Vec<&str> {
410 self.body.iter().flat_map(|l| l.variables()).collect()
411 }
412}
413
414#[derive(Debug, Clone, PartialEq)]
416pub struct Constraint {
417 pub body: Vec<BodyLiteral>,
419}
420
421#[derive(Debug, Clone, PartialEq)]
423pub struct Query {
424 pub atom: Atom,
426}
427
/// The probabilistic inference engine selected by a directive.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ProbEngine {
    /// Exact inference; the name suggests d-DNNF compilation (confirm).
    /// Used when no engine directive is given.
    ExactDdnnf,
    /// Presumably Monte Carlo sampling — confirm.
    Mc,
}

/// On/off switch for probabilistic caching (what is cached is decided
/// elsewhere — confirm).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ProbCache {
    /// Caching enabled.
    On,
    /// Caching disabled.
    Off,
}

/// The semantics used for epistemic literals.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum EpistemicMode {
    /// `G91` semantics (the name presumably refers to Gelfond 1991 — confirm).
    G91,
    /// `FAEEL` semantics. Used when no mode directive is given.
    Faeel,
}

/// The method for handling evidence in probabilistic inference (inferred
/// from the variant names — confirm).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ProbMethod {
    /// Rejection-based handling.
    Rejection,
    /// Evidence-clamping handling.
    EvidenceClamping,
}

/// Controls the magic-sets transformation.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MagicSetsMode {
    /// Let the engine decide.
    Auto,
    /// Force the transformation on.
    On,
    /// Force the transformation off.
    Off,
}

/// Program-level directives. `None` means the directive was not given; the
/// `*_or_default` accessors supply fallback values where one is defined.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct Directives {
    /// The inference engine selection.
    pub prob_engine: Option<ProbEngine>,
    /// The caching switch.
    pub prob_cache: Option<ProbCache>,
    /// The sample count.
    pub prob_samples: Option<usize>,
    /// The random seed.
    pub prob_seed: Option<u64>,
    /// The confidence level.
    pub prob_confidence: Option<f64>,
    /// The evidence-handling method.
    pub prob_method: Option<ProbMethod>,
    /// The iteration bound for non-monotone probabilistic evaluation.
    pub prob_max_nonmonotone_iterations: Option<usize>,
    /// The recursion depth limit.
    pub max_recursion_depth: Option<u32>,
    /// The epistemic semantics.
    pub epistemic_mode: Option<EpistemicMode>,
    /// The magic-sets mode.
    pub magic_sets: Option<MagicSetsMode>,
}

impl Directives {
    const DEFAULT_PROB_ENGINE: ProbEngine = ProbEngine::ExactDdnnf;
    const DEFAULT_MAX_RECURSION_DEPTH: u32 = 1000;
    const DEFAULT_EPISTEMIC_MODE: EpistemicMode = EpistemicMode::Faeel;
    const DEFAULT_PROB_SAMPLES: usize = 10000;
    const DEFAULT_PROB_SEED: u64 = 0;
    const DEFAULT_PROB_CONFIDENCE: f64 = 0.95;
    const DEFAULT_PROB_MAX_NONMONOTONE_ITERATIONS: usize = 1024;

    /// Returns the configured engine, or `ExactDdnnf` when unset.
    pub fn prob_engine_or_default(&self) -> ProbEngine {
        self.prob_engine.unwrap_or(Self::DEFAULT_PROB_ENGINE)
    }

    /// Returns the configured recursion depth limit, or `1000` when unset.
    pub fn max_recursion_depth_or_default(&self) -> u32 {
        self.max_recursion_depth
            .unwrap_or(Self::DEFAULT_MAX_RECURSION_DEPTH)
    }

    /// Returns the configured epistemic mode, or `Faeel` when unset.
    pub fn epistemic_mode_or_default(&self) -> EpistemicMode {
        self.epistemic_mode.unwrap_or(Self::DEFAULT_EPISTEMIC_MODE)
    }

    /// Returns the configured sample count, or `10000` when unset.
    pub fn prob_samples_or_default(&self) -> usize {
        self.prob_samples.unwrap_or(Self::DEFAULT_PROB_SAMPLES)
    }

    /// Returns the configured seed, or `0` when unset.
    pub fn prob_seed_or_default(&self) -> u64 {
        self.prob_seed.unwrap_or(Self::DEFAULT_PROB_SEED)
    }

    /// Returns the configured confidence level, or `0.95` when unset.
    pub fn prob_confidence_or_default(&self) -> f64 {
        self.prob_confidence.unwrap_or(Self::DEFAULT_PROB_CONFIDENCE)
    }

    /// Returns the configured non-monotone iteration bound, or `1024` when
    /// unset.
    pub fn prob_max_nonmonotone_iterations_or_default(&self) -> usize {
        self.prob_max_nonmonotone_iterations
            .unwrap_or(Self::DEFAULT_PROB_MAX_NONMONOTONE_ITERATIONS)
    }
}
536
537#[derive(Debug, Clone, PartialEq)]
539pub struct ProbFact {
540 pub prob: f64,
542 pub atom: Atom,
544}
545
546#[derive(Debug, Clone, PartialEq)]
560pub struct NeuralPredDecl {
561 pub network: String,
563 pub inputs: Vec<String>,
565 pub output: String,
567 pub labels: Option<Vec<NeuralLabel>>,
570 pub predicate: Atom,
572}
573
574#[derive(Debug, Clone, PartialEq)]
578pub enum NeuralLabel {
579 Integer(i64),
581 Symbol(String),
583}
584
585#[derive(Debug, Clone)]
589pub struct LearnableRule {
590 pub mask_name: String,
592 pub head: Atom,
594 pub body: Vec<BodyLiteral>,
596}
597
598#[derive(Debug, Clone, PartialEq)]
600pub struct AnnotatedDisjunction {
601 pub choices: Vec<ProbFact>,
603}
604
605#[derive(Debug, Clone, PartialEq)]
607pub struct Evidence {
608 pub atom: Atom,
610 pub value: bool,
612}
613
614#[derive(Debug, Clone, PartialEq)]
616pub struct ProbQuery {
617 pub atom: Atom,
619}
620
/// A module import declaration.
#[derive(Clone, Debug, PartialEq)]
pub struct UseDecl {
    /// The module path, one segment per element.
    pub module_path: Vec<String>,
    /// The explicitly imported item names. `None` presumably means "import
    /// everything public" — confirm against the resolver.
    pub imports: Option<Vec<String>>,
}
629
630#[derive(Debug, Clone, PartialEq)]
632pub struct DomainDecl {
633 pub name: String,
635 pub typ: ScalarType,
637}
638
639#[derive(Debug, Clone, PartialEq, Eq)]
641pub enum TypeRef {
642 Scalar(ScalarType),
644 Domain(String),
646 List(Box<TypeRef>),
648 Term,
650 Compound,
652 PredRef,
654}
655
656#[derive(Debug, Clone, PartialEq, Eq)]
658pub struct PredColumn {
659 pub name: Option<String>,
661 pub typ: TypeRef,
663}
664
665#[derive(Debug, Clone, PartialEq)]
667pub struct PredDecl {
668 pub name: String,
670 pub types: Vec<TypeRef>,
672 pub columns: Vec<PredColumn>,
674 pub is_private: bool,
676}
677
678#[derive(Debug, Clone, PartialEq)]
680pub struct FuncParam {
681 pub name: String,
683 pub typ: Option<ScalarType>,
685}
686
687#[derive(Debug, Clone, PartialEq)]
689pub struct CondExpr {
690 pub cond_left: ArithExpr,
692 pub cond_op: CompOp,
694 pub cond_right: ArithExpr,
696 pub then_branch: Box<FuncBody>,
698 pub else_branch: Box<FuncBody>,
700}
701
702#[derive(Debug, Clone, PartialEq)]
704pub enum FuncBody {
705 Arithmetic(ArithExpr),
707 Conditional(CondExpr),
709 Predicate {
711 result: String,
713 body: Vec<BodyLiteral>,
715 },
716}
717
718#[derive(Debug, Clone, PartialEq)]
720pub struct FuncDef {
721 pub name: String,
723 pub params: Vec<FuncParam>,
725 pub return_type: Option<ScalarType>,
727 pub body: FuncBody,
729 pub is_private: bool,
731}
732
733#[derive(Debug, Clone, Default)]
735pub struct Program {
736 pub imports: Vec<UseDecl>,
738 pub functions: Vec<FuncDef>,
740 pub domains: Vec<DomainDecl>,
742 pub predicates: Vec<PredDecl>,
744 pub rules: Vec<Rule>,
746 pub constraints: Vec<Constraint>,
748 pub queries: Vec<Query>,
750 pub prob_facts: Vec<ProbFact>,
752 pub annotated_disjunctions: Vec<AnnotatedDisjunction>,
754 pub evidence: Vec<Evidence>,
756 pub prob_queries: Vec<ProbQuery>,
758 pub neural_predicates: Vec<NeuralPredDecl>,
760 pub learnable_rules: Vec<LearnableRule>,
762 pub directives: Directives,
764}
765
766impl Program {
767 pub fn new() -> Self {
769 Self::default()
770 }
771
772 pub fn facts(&self) -> impl Iterator<Item = &Rule> {
774 self.rules.iter().filter(|r| r.is_fact())
775 }
776
777 pub fn proper_rules(&self) -> impl Iterator<Item = &Rule> {
779 self.rules.iter().filter(|r| !r.is_fact())
780 }
781
782 pub fn defined_predicates(&self) -> Vec<&str> {
784 self.rules
785 .iter()
786 .map(|r| r.head.predicate.as_str())
787 .collect::<std::collections::HashSet<_>>()
788 .into_iter()
789 .collect()
790 }
791
792 pub fn is_probabilistic_profile(&self) -> bool {
794 !self.prob_facts.is_empty()
795 || !self.annotated_disjunctions.is_empty()
796 || !self.evidence.is_empty()
797 || !self.prob_queries.is_empty()
798 || self.directives.prob_engine.is_some()
799 || self.directives.prob_cache.is_some()
800 || self.directives.prob_samples.is_some()
801 || self.directives.prob_seed.is_some()
802 || self.directives.prob_confidence.is_some()
803 || self.directives.prob_method.is_some()
804 || self.directives.prob_max_nonmonotone_iterations.is_some()
805 }
806
807 pub fn prob_engine(&self) -> ProbEngine {
809 self.directives.prob_engine_or_default()
810 }
811
812 pub fn merge_from(
820 &mut self,
821 other: &Program,
822 imported_items: Option<&std::collections::HashSet<String>>,
823 ) {
824 use std::collections::HashSet;
825
826 let private_preds: HashSet<&str> = other
828 .predicates
829 .iter()
830 .filter(|p| p.is_private)
831 .map(|p| p.name.as_str())
832 .collect();
833
834 let _private_funcs: HashSet<&str> = other
835 .functions
836 .iter()
837 .filter(|f| f.is_private)
838 .map(|f| f.name.as_str())
839 .collect();
840
841 for pred in &other.predicates {
843 if pred.is_private {
844 continue;
845 }
846 if let Some(items) = imported_items {
848 if !items.contains(&pred.name) {
849 continue;
850 }
851 }
852 if !self.predicates.iter().any(|p| p.name == pred.name) {
854 self.predicates.push(pred.clone());
855 }
856 }
857
858 for func in &other.functions {
860 if func.is_private {
861 continue;
862 }
863 if let Some(items) = imported_items {
864 if !items.contains(&func.name) {
865 continue;
866 }
867 }
868 if !self.functions.iter().any(|f| f.name == func.name) {
870 self.functions.push(func.clone());
871 }
872 }
873
874 for rule in &other.rules {
876 if private_preds.contains(rule.head.predicate.as_str()) {
878 continue;
879 }
880 if let Some(items) = imported_items {
882 if !items.contains(&rule.head.predicate) {
883 continue;
884 }
885 }
886 self.rules.push(rule.clone());
887 }
888
889 for domain in &other.domains {
891 if !self.domains.iter().any(|d| d.name == domain.name) {
892 self.domains.push(domain.clone());
893 }
894 }
895 }
896}
897
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a named variable term.
    fn var(name: &str) -> Term {
        Term::Variable(name.to_string())
    }

    /// Builds an atom from a predicate name and its argument terms.
    fn atom(predicate: &str, terms: Vec<Term>) -> Atom {
        Atom {
            predicate: predicate.to_string(),
            terms,
        }
    }

    /// Builds the ground fact `edge(1, 2)`.
    fn edge_fact() -> Rule {
        Rule {
            head: atom("edge", vec![Term::Integer(1), Term::Integer(2)]),
            body: Vec::new(),
        }
    }

    #[test]
    fn test_term_variable() {
        let term = var("X");
        assert!(term.is_variable());
        assert!(!term.is_constant());
    }

    #[test]
    fn test_term_constant() {
        let term = Term::Integer(42);
        assert!(term.is_constant());
        assert!(!term.is_variable());
    }

    #[test]
    fn test_atom_arity() {
        let edge = atom("edge", vec![Term::Integer(1), Term::Integer(2)]);
        assert_eq!(edge.arity(), 2);
    }

    #[test]
    fn test_atom_variables() {
        let edge = atom("edge", vec![var("X"), Term::Integer(2)]);
        assert_eq!(edge.variables(), vec!["X"]);
    }

    #[test]
    fn test_rule_is_fact() {
        assert!(edge_fact().is_fact());
    }

    #[test]
    fn test_rule_has_negation() {
        let isolated = Rule {
            head: atom("isolated", vec![var("X")]),
            body: vec![
                BodyLiteral::Positive(atom("node", vec![var("X")])),
                BodyLiteral::Negated(atom("edge", vec![var("X"), var("Y")])),
            ],
        };
        assert!(isolated.has_negation());
    }

    #[test]
    fn test_program_facts() {
        let reach = Rule {
            head: atom("reach", vec![var("X"), var("Y")]),
            body: vec![BodyLiteral::Positive(atom("edge", vec![var("X"), var("Y")]))],
        };
        let mut program = Program::new();
        program.rules.extend(vec![edge_fact(), reach]);
        assert_eq!(program.facts().count(), 1);
        assert_eq!(program.proper_rules().count(), 1);
    }

    #[test]
    fn test_arith_expr_structure() {
        let sum = ArithExpr::Add(
            Box::new(ArithExpr::Variable("X".to_string())),
            Box::new(ArithExpr::Integer(1)),
        );
        assert!(matches!(sum, ArithExpr::Add(..)));
    }

    #[test]
    fn test_is_expr_structure() {
        let binding = IsExpr {
            target: "Z".to_string(),
            expr: ArithExpr::Variable("Y".to_string()),
        };
        assert_eq!(binding.target, "Z");
    }
}
1018}