Skip to main content

rust_rule_engine/backward/
disjunction.rs

1//! Disjunction (OR) support for backward chaining queries
2//!
3//! This module implements OR patterns in queries, allowing multiple alternative
4//! conditions to be specified. The query succeeds if ANY of the alternatives succeed.
5//!
6//! # Examples
7//!
8//! ```rust,ignore
9//! // Find people who are either managers OR seniors
10//! let results = engine.query(
11//!     "eligible(?person) WHERE (manager(?person) OR senior(?person))",
12//!     &mut facts
13//! )?;
14//!
15//! // Complex OR with multiple conditions
16//! let results = engine.query(
17//!     "discount(?customer) WHERE (vip(?customer) OR total_spent(?customer, ?amt) > 10000)",
18//!     &mut facts
19//! )?;
20//! ```
21
22use super::goal::Goal;
23use super::unification::Bindings;
24use std::collections::HashSet;
25
26/// Represents a disjunction (OR) of goals
27#[derive(Debug, Clone)]
28pub struct Disjunction {
29    /// Alternative goals - at least one must succeed
30    pub branches: Vec<Goal>,
31
32    /// Original pattern string
33    pub pattern: String,
34}
35
36impl Disjunction {
37    /// Create a new disjunction from a list of goals
38    pub fn new(branches: Vec<Goal>, pattern: String) -> Self {
39        assert!(
40            !branches.is_empty(),
41            "Disjunction must have at least one branch"
42        );
43        Self { branches, pattern }
44    }
45
46    /// Create a disjunction from two goals
47    pub fn from_pair(left: Goal, right: Goal) -> Self {
48        let pattern = format!("({} OR {})", left.pattern, right.pattern);
49        Self {
50            branches: vec![left, right],
51            pattern,
52        }
53    }
54
55    /// Add another branch to this disjunction
56    pub fn add_branch(&mut self, goal: Goal) {
57        self.branches.push(goal);
58    }
59
60    /// Get the number of branches
61    pub fn branch_count(&self) -> usize {
62        self.branches.len()
63    }
64}
65
66/// Result of evaluating a disjunction
67#[derive(Debug, Clone)]
68pub struct DisjunctionResult {
69    /// All solutions from all branches
70    pub solutions: Vec<Bindings>,
71
72    /// Which branches succeeded (by index)
73    pub successful_branches: Vec<usize>,
74
75    /// Whether the disjunction as a whole succeeded
76    pub success: bool,
77}
78
79impl DisjunctionResult {
80    /// Create a new result
81    pub fn new() -> Self {
82        Self {
83            solutions: Vec::new(),
84            successful_branches: Vec::new(),
85            success: false,
86        }
87    }
88
89    /// Create a successful result
90    pub fn success(solutions: Vec<Bindings>, successful_branches: Vec<usize>) -> Self {
91        Self {
92            solutions,
93            successful_branches,
94            success: true,
95        }
96    }
97
98    /// Create a failed result
99    pub fn failure() -> Self {
100        Self {
101            solutions: Vec::new(),
102            successful_branches: Vec::new(),
103            success: false,
104        }
105    }
106
107    /// Add solutions from a branch
108    pub fn add_branch_solutions(&mut self, branch_index: usize, solutions: Vec<Bindings>) {
109        if !solutions.is_empty() {
110            self.successful_branches.push(branch_index);
111            self.solutions.extend(solutions);
112            self.success = true;
113        }
114    }
115
116    /// Deduplicate solutions based on variable bindings
117    pub fn deduplicate(&mut self) {
118        // Use a set to track unique binding combinations
119        let mut seen = HashSet::new();
120        let mut unique_solutions = Vec::new();
121
122        for solution in &self.solutions {
123            let binding_map = solution.to_map();
124            let key = format!("{:?}", binding_map);
125
126            if seen.insert(key) {
127                unique_solutions.push(solution.clone());
128            }
129        }
130
131        self.solutions = unique_solutions;
132    }
133
134    /// Get the total number of solutions
135    pub fn solution_count(&self) -> usize {
136        self.solutions.len()
137    }
138}
139
140impl Default for DisjunctionResult {
141    fn default() -> Self {
142        Self::new()
143    }
144}
145
/// Split a string on ` OR ` separators that appear at the top level only,
/// i.e. outside every pair of parentheses and outside double-quoted strings.
///
/// Each resulting part is trimmed. Parts that are empty after trimming are
/// dropped.
///
/// For example:
/// - `"A OR B"` → `["A", "B"]`
/// - `"A OR (B AND C)"` → `["A", "(B AND C)"]`
/// - `"(A OR B) OR (C OR D)"` → `["(A OR B)", "(C OR D)"]`
/// - `"func(a, OR b) OR c"` → `["func(a, OR b)", "c"]`
/// - `"é OR b"` → `["é", "b"]`
///
/// The scan compares `char`s, never byte offsets, so non-ASCII text before a
/// separator neither panics nor hides the separator.
fn split_top_level_or(input: &str) -> Vec<String> {
    // The separator as characters, so it can be matched against `chars[i..]`.
    const SEPARATOR: [char; 4] = [' ', 'O', 'R', ' '];

    // Push the trimmed contents of `current` onto `parts` unless it is blank.
    fn flush(parts: &mut Vec<String>, current: &str) {
        let trimmed = current.trim();
        if !trimmed.is_empty() {
            parts.push(trimmed.to_string());
        }
    }

    let chars: Vec<char> = input.chars().collect();
    let mut parts = Vec::new();
    let mut current = String::new();
    let mut paren_depth: i32 = 0;
    let mut in_string = false;
    let mut i = 0;

    while i < chars.len() {
        let ch = chars[i];
        match ch {
            '"' => in_string = !in_string,
            '(' if !in_string => paren_depth += 1,
            ')' if !in_string => paren_depth -= 1,
            ' ' if !in_string && paren_depth == 0 && chars[i..].starts_with(&SEPARATOR) => {
                flush(&mut parts, &current);
                current.clear();
                i += SEPARATOR.len();
                continue;
            }
            _ => {}
        }
        current.push(ch);
        i += 1;
    }

    flush(&mut parts, &current);
    parts
}
210
/// Stateless parser that recognises OR groups in query patterns.
///
/// All functionality is exposed through associated functions; the type
/// itself carries no data.
pub struct DisjunctionParser;
213
214impl DisjunctionParser {
215    /// Parse a pattern that might contain OR
216    ///
217    /// Examples:
218    /// - "(A OR B)" -> Disjunction with 2 branches
219    /// - "(A OR B OR C)" -> Disjunction with 3 branches
220    /// - "A" -> None (no OR, single goal)
221    pub fn parse(pattern: &str) -> Option<Disjunction> {
222        let pattern = pattern.trim();
223
224        // Check if pattern starts with '(' and ends with ')'
225        if !pattern.starts_with('(') || !pattern.ends_with(')') {
226            return None;
227        }
228
229        // Remove outer parentheses
230        let inner = &pattern[1..pattern.len() - 1];
231
232        if !inner.contains(" OR ") {
233            return None;
234        }
235
236        // Split by " OR " at the top level only, respecting nested parentheses
237        // and quoted strings so that patterns like "(A OR (B AND C))" split correctly.
238        let parts = split_top_level_or(inner);
239
240        let branches: Vec<Goal> = parts
241            .into_iter()
242            .map(|s| Goal::new(s.trim().to_string()))
243            .collect();
244
245        if branches.len() < 2 {
246            return None;
247        }
248
249        Some(Disjunction::new(branches, pattern.to_string()))
250    }
251
252    /// Check if a pattern contains a top-level OR (not inside nested parentheses)
253    pub fn contains_or(pattern: &str) -> bool {
254        let parts = split_top_level_or(pattern);
255        parts.len() > 1
256    }
257}
258
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a [`Goal`] from a string literal.
    fn goal(text: &str) -> Goal {
        Goal::new(text.to_string())
    }

    /// Parses `pattern`, failing the test if no disjunction is produced.
    fn parse_ok(pattern: &str) -> Disjunction {
        DisjunctionParser::parse(pattern)
            .unwrap_or_else(|| panic!("expected {:?} to parse as a disjunction", pattern))
    }

    // ---- Disjunction construction ------------------------------------------

    #[test]
    fn test_disjunction_creation() {
        let disj = Disjunction::from_pair(goal("manager(?person)"), goal("senior(?person)"));

        assert_eq!(disj.branch_count(), 2);
        assert!(disj.pattern.contains("OR"));
    }

    #[test]
    fn test_disjunction_add_branch() {
        let mut disj = Disjunction::from_pair(goal("manager(?person)"), goal("senior(?person)"));
        disj.add_branch(goal("director(?person)"));

        assert_eq!(disj.branch_count(), 3);
    }

    // ---- DisjunctionResult -------------------------------------------------

    #[test]
    fn test_disjunction_result_success() {
        let mut result = DisjunctionResult::new();
        for branch in 0..2 {
            result.add_branch_solutions(branch, vec![Bindings::new()]);
        }

        assert!(result.success);
        assert_eq!(result.solution_count(), 2);
        assert_eq!(result.successful_branches.len(), 2);
    }

    #[test]
    fn test_disjunction_result_empty() {
        let mut result = DisjunctionResult::new();
        for branch in 0..2 {
            result.add_branch_solutions(branch, Vec::new());
        }

        assert!(!result.success);
        assert_eq!(result.solution_count(), 0);
    }

    #[test]
    fn test_deduplication() {
        let mut result = DisjunctionResult::new();

        // The same binding set reported twice by one branch.
        let bindings = Bindings::new();
        result.add_branch_solutions(0, vec![bindings.clone(), bindings]);
        assert_eq!(result.solution_count(), 2);

        result.deduplicate();
        assert_eq!(result.solution_count(), 1);
    }

    // ---- DisjunctionParser -------------------------------------------------

    #[test]
    fn test_parser_simple_or() {
        let disj = parse_ok("(manager(?person) OR senior(?person))");
        assert_eq!(disj.branch_count(), 2);
    }

    #[test]
    fn test_parser_triple_or() {
        assert_eq!(parse_ok("(A OR B OR C)").branch_count(), 3);
    }

    #[test]
    fn test_parser_no_or() {
        assert!(DisjunctionParser::parse("manager(?person)").is_none());
    }

    #[test]
    fn test_parser_contains_or() {
        assert!(DisjunctionParser::contains_or("A OR B"));
        assert!(!DisjunctionParser::contains_or("A AND B"));
        // A parenthesised OR is not at the top level.
        assert!(!DisjunctionParser::contains_or("(A OR B)"));
    }

    #[test]
    fn test_parser_nested_parens() {
        // Only the top-level OR splits; "(B AND C)" stays one branch.
        let disj = parse_ok("(A OR (B AND C))");
        assert_eq!(disj.branch_count(), 2);
        assert_eq!(disj.branches[0].pattern, "A");
        assert_eq!(disj.branches[1].pattern, "(B AND C)");
    }

    #[test]
    fn test_parser_nested_or_groups() {
        // Inner ORs stay inside their groups; only the middle OR splits.
        let disj = parse_ok("((A OR B) OR (C OR D))");
        assert_eq!(disj.branch_count(), 2);
        assert_eq!(disj.branches[0].pattern, "(A OR B)");
        assert_eq!(disj.branches[1].pattern, "(C OR D)");
    }

    #[test]
    fn test_parser_function_args_with_or_keyword() {
        // An OR inside a call's argument list is not a split point.
        let disj = parse_ok("(func(a, OR, b) OR c)");
        assert_eq!(disj.branch_count(), 2);
        assert_eq!(disj.branches[0].pattern, "func(a, OR, b)");
        assert_eq!(disj.branches[1].pattern, "c");
    }

    #[test]
    fn test_parser_deeply_nested() {
        let disj = parse_ok("(A OR (B OR (C AND D)))");
        assert_eq!(disj.branch_count(), 2);
        assert_eq!(disj.branches[0].pattern, "A");
        assert_eq!(disj.branches[1].pattern, "(B OR (C AND D))");
    }

    #[test]
    fn test_contains_or_nested() {
        let cases = [
            ("(A OR B)", false),     // OR wrapped in parentheses
            ("A OR B", true),        // OR at the top level
            ("func(A OR B)", false), // OR only inside a call's arguments
        ];
        for &(pattern, expected) in cases.iter() {
            assert_eq!(
                DisjunctionParser::contains_or(pattern),
                expected,
                "pattern: {}",
                pattern
            );
        }
    }

    // ---- split_top_level_or ------------------------------------------------

    #[test]
    fn test_split_top_level_or_basic() {
        assert_eq!(split_top_level_or("A OR B OR C"), vec!["A", "B", "C"]);
    }

    #[test]
    fn test_split_top_level_or_with_quotes() {
        let parts = split_top_level_or(r#""hello OR world" OR B"#);
        assert_eq!(parts, vec![r#""hello OR world""#, "B"]);
    }
}