rust_rule_engine/backward/
disjunction.rs

//! Disjunction (OR) support for backward chaining queries
//!
//! This module implements OR patterns in queries, allowing multiple alternative
//! conditions to be specified. The query succeeds if ANY of the alternatives succeeds.
//! Alternatives are written inside parentheses and separated by the upper-case
//! keyword `OR` with a single space on each side, e.g. `(A OR B OR C)`.
//!
//! # Examples
//!
//! ```rust,ignore
//! // Find people who are either managers OR seniors
//! let results = engine.query(
//!     "eligible(?person) WHERE (manager(?person) OR senior(?person))",
//!     &mut facts
//! )?;
//!
//! // Complex OR with multiple conditions
//! let results = engine.query(
//!     "discount(?customer) WHERE (vip(?customer) OR total_spent(?customer, ?amt) > 10000)",
//!     &mut facts
//! )?;
//! ```
21
22use super::goal::Goal;
23use super::unification::Bindings;
24use std::collections::HashSet;
25
26/// Represents a disjunction (OR) of goals
27#[derive(Debug, Clone)]
28pub struct Disjunction {
29    /// Alternative goals - at least one must succeed
30    pub branches: Vec<Goal>,
31
32    /// Original pattern string
33    pub pattern: String,
34}
35
36impl Disjunction {
37    /// Create a new disjunction from a list of goals
38    pub fn new(branches: Vec<Goal>, pattern: String) -> Self {
39        assert!(
40            !branches.is_empty(),
41            "Disjunction must have at least one branch"
42        );
43        Self { branches, pattern }
44    }
45
46    /// Create a disjunction from two goals
47    pub fn from_pair(left: Goal, right: Goal) -> Self {
48        let pattern = format!("({} OR {})", left.pattern, right.pattern);
49        Self {
50            branches: vec![left, right],
51            pattern,
52        }
53    }
54
55    /// Add another branch to this disjunction
56    pub fn add_branch(&mut self, goal: Goal) {
57        self.branches.push(goal);
58    }
59
60    /// Get the number of branches
61    pub fn branch_count(&self) -> usize {
62        self.branches.len()
63    }
64}
65
66/// Result of evaluating a disjunction
67#[derive(Debug, Clone)]
68pub struct DisjunctionResult {
69    /// All solutions from all branches
70    pub solutions: Vec<Bindings>,
71
72    /// Which branches succeeded (by index)
73    pub successful_branches: Vec<usize>,
74
75    /// Whether the disjunction as a whole succeeded
76    pub success: bool,
77}
78
79impl DisjunctionResult {
80    /// Create a new result
81    pub fn new() -> Self {
82        Self {
83            solutions: Vec::new(),
84            successful_branches: Vec::new(),
85            success: false,
86        }
87    }
88
89    /// Create a successful result
90    pub fn success(solutions: Vec<Bindings>, successful_branches: Vec<usize>) -> Self {
91        Self {
92            solutions,
93            successful_branches,
94            success: true,
95        }
96    }
97
98    /// Create a failed result
99    pub fn failure() -> Self {
100        Self {
101            solutions: Vec::new(),
102            successful_branches: Vec::new(),
103            success: false,
104        }
105    }
106
107    /// Add solutions from a branch
108    pub fn add_branch_solutions(&mut self, branch_index: usize, solutions: Vec<Bindings>) {
109        if !solutions.is_empty() {
110            self.successful_branches.push(branch_index);
111            self.solutions.extend(solutions);
112            self.success = true;
113        }
114    }
115
116    /// Deduplicate solutions based on variable bindings
117    pub fn deduplicate(&mut self) {
118        // Use a set to track unique binding combinations
119        let mut seen = HashSet::new();
120        let mut unique_solutions = Vec::new();
121
122        for solution in &self.solutions {
123            let binding_map = solution.to_map();
124            let key = format!("{:?}", binding_map);
125
126            if seen.insert(key) {
127                unique_solutions.push(solution.clone());
128            }
129        }
130
131        self.solutions = unique_solutions;
132    }
133
134    /// Get the total number of solutions
135    pub fn solution_count(&self) -> usize {
136        self.solutions.len()
137    }
138}
139
140impl Default for DisjunctionResult {
141    fn default() -> Self {
142        Self::new()
143    }
144}
145
146/// Parser for OR patterns in queries
147pub struct DisjunctionParser;
148
149impl DisjunctionParser {
150    /// Parse a pattern that might contain OR
151    ///
152    /// Examples:
153    /// - "(A OR B)" -> Disjunction with 2 branches
154    /// - "(A OR B OR C)" -> Disjunction with 3 branches
155    /// - "A" -> None (no OR, single goal)
156    pub fn parse(pattern: &str) -> Option<Disjunction> {
157        let pattern = pattern.trim();
158
159        // Check if pattern starts with '(' and ends with ')'
160        if !pattern.starts_with('(') || !pattern.ends_with(')') {
161            return None;
162        }
163
164        // Remove outer parentheses
165        let inner = &pattern[1..pattern.len() - 1];
166
167        // Split by OR (naive implementation - TODO: handle nested parens)
168        if !inner.contains(" OR ") {
169            return None;
170        }
171
172        let branches: Vec<Goal> = inner
173            .split(" OR ")
174            .map(|s| Goal::new(s.trim().to_string()))
175            .collect();
176
177        if branches.len() < 2 {
178            return None;
179        }
180
181        Some(Disjunction::new(branches, pattern.to_string()))
182    }
183
184    /// Check if a pattern contains OR
185    pub fn contains_or(pattern: &str) -> bool {
186        pattern.contains(" OR ")
187    }
188}
189
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a goal from a string slice.
    fn goal(pattern: &str) -> Goal {
        Goal::new(pattern.to_string())
    }

    #[test]
    fn test_disjunction_creation() {
        let disj = Disjunction::from_pair(goal("manager(?person)"), goal("senior(?person)"));

        assert_eq!(disj.branch_count(), 2);
        assert_eq!(disj.pattern, "(manager(?person) OR senior(?person))");
    }

    #[test]
    #[should_panic(expected = "Disjunction must have at least one branch")]
    fn test_disjunction_new_rejects_empty() {
        let _ = Disjunction::new(Vec::new(), String::from("()"));
    }

    #[test]
    fn test_disjunction_add_branch() {
        let mut disj = Disjunction::from_pair(goal("manager(?person)"), goal("senior(?person)"));
        disj.add_branch(goal("director(?person)"));

        assert_eq!(disj.branch_count(), 3);
    }

    #[test]
    fn test_disjunction_result_success() {
        let mut result = DisjunctionResult::new();

        result.add_branch_solutions(0, vec![Bindings::new()]);
        result.add_branch_solutions(1, vec![Bindings::new()]);

        assert!(result.success);
        assert_eq!(result.solution_count(), 2);
        assert_eq!(result.successful_branches, vec![0, 1]);
    }

    #[test]
    fn test_disjunction_result_empty() {
        let mut result = DisjunctionResult::new();

        result.add_branch_solutions(0, vec![]);
        result.add_branch_solutions(1, vec![]);

        assert!(!result.success);
        assert_eq!(result.solution_count(), 0);
        assert!(result.successful_branches.is_empty());
    }

    #[test]
    fn test_disjunction_result_constructors() {
        assert!(!DisjunctionResult::failure().success);
        assert!(!DisjunctionResult::default().success);
        assert!(DisjunctionResult::success(Vec::new(), Vec::new()).success);
    }

    #[test]
    fn test_parser_simple_or() {
        let disj = DisjunctionParser::parse("(manager(?person) OR senior(?person))")
            .expect("a parenthesised OR pattern should parse");

        assert_eq!(disj.branch_count(), 2);
        assert_eq!(disj.branches[0].pattern, "manager(?person)");
        assert_eq!(disj.branches[1].pattern, "senior(?person)");
    }

    #[test]
    fn test_parser_triple_or() {
        let disj = DisjunctionParser::parse("(A OR B OR C)")
            .expect("a three-way OR pattern should parse");

        assert_eq!(disj.branch_count(), 3);
    }

    #[test]
    fn test_parser_trims_surrounding_whitespace() {
        let disj = DisjunctionParser::parse("  (A OR B)  ")
            .expect("surrounding whitespace should be ignored");

        assert_eq!(disj.pattern, "(A OR B)");
    }

    #[test]
    fn test_parser_no_or() {
        assert!(DisjunctionParser::parse("manager(?person)").is_none());
    }

    #[test]
    fn test_parser_requires_parentheses() {
        assert!(DisjunctionParser::parse("A OR B").is_none());
    }

    #[test]
    fn test_parser_contains_or() {
        assert!(DisjunctionParser::contains_or("(A OR B)"));
        assert!(!DisjunctionParser::contains_or("A AND B"));
    }

    #[test]
    fn test_deduplication() {
        let mut result = DisjunctionResult::new();

        // The same solution added twice should collapse to one.
        let bindings = Bindings::new();
        result.add_branch_solutions(0, vec![bindings.clone(), bindings.clone()]);
        assert_eq!(result.solution_count(), 2);

        result.deduplicate();

        assert_eq!(result.solution_count(), 1);
    }
}
291}