rust_rule_engine/backward/
disjunction.rs

//! Disjunction (OR) support for backward chaining queries
//!
//! This module implements OR patterns in queries, allowing multiple alternative
//! conditions to be specified. The query succeeds if ANY of the alternatives succeeds.
//!
//! # Examples
//!
//! ```rust,ignore
//! // Find people who are either managers OR seniors
//! let results = engine.query(
//!     "eligible(?person) WHERE (manager(?person) OR senior(?person))",
//!     &mut facts
//! )?;
//!
//! // Complex OR with multiple conditions
//! let results = engine.query(
//!     "discount(?customer) WHERE (vip(?customer) OR total_spent(?customer, ?amt) > 10000)",
//!     &mut facts
//! )?;
//! ```

22use super::goal::Goal;
23use super::unification::Bindings;
24use std::collections::HashSet;
25
26/// Represents a disjunction (OR) of goals
27#[derive(Debug, Clone)]
28pub struct Disjunction {
29    /// Alternative goals - at least one must succeed
30    pub branches: Vec<Goal>,
31
32    /// Original pattern string
33    pub pattern: String,
34}
35
36impl Disjunction {
37    /// Create a new disjunction from a list of goals
38    pub fn new(branches: Vec<Goal>, pattern: String) -> Self {
39        assert!(!branches.is_empty(), "Disjunction must have at least one branch");
40        Self { branches, pattern }
41    }
42
43    /// Create a disjunction from two goals
44    pub fn from_pair(left: Goal, right: Goal) -> Self {
45        let pattern = format!("({} OR {})", left.pattern, right.pattern);
46        Self {
47            branches: vec![left, right],
48            pattern,
49        }
50    }
51
52    /// Add another branch to this disjunction
53    pub fn add_branch(&mut self, goal: Goal) {
54        self.branches.push(goal);
55    }
56
57    /// Get the number of branches
58    pub fn branch_count(&self) -> usize {
59        self.branches.len()
60    }
61}
62
63/// Result of evaluating a disjunction
64#[derive(Debug, Clone)]
65pub struct DisjunctionResult {
66    /// All solutions from all branches
67    pub solutions: Vec<Bindings>,
68
69    /// Which branches succeeded (by index)
70    pub successful_branches: Vec<usize>,
71
72    /// Whether the disjunction as a whole succeeded
73    pub success: bool,
74}
75
76impl DisjunctionResult {
77    /// Create a new result
78    pub fn new() -> Self {
79        Self {
80            solutions: Vec::new(),
81            successful_branches: Vec::new(),
82            success: false,
83        }
84    }
85
86    /// Create a successful result
87    pub fn success(solutions: Vec<Bindings>, successful_branches: Vec<usize>) -> Self {
88        Self {
89            solutions,
90            successful_branches,
91            success: true,
92        }
93    }
94
95    /// Create a failed result
96    pub fn failure() -> Self {
97        Self {
98            solutions: Vec::new(),
99            successful_branches: Vec::new(),
100            success: false,
101        }
102    }
103
104    /// Add solutions from a branch
105    pub fn add_branch_solutions(&mut self, branch_index: usize, solutions: Vec<Bindings>) {
106        if !solutions.is_empty() {
107            self.successful_branches.push(branch_index);
108            self.solutions.extend(solutions);
109            self.success = true;
110        }
111    }
112
113    /// Deduplicate solutions based on variable bindings
114    pub fn deduplicate(&mut self) {
115        // Use a set to track unique binding combinations
116        let mut seen = HashSet::new();
117        let mut unique_solutions = Vec::new();
118
119        for solution in &self.solutions {
120            let binding_map = solution.to_map();
121            let key = format!("{:?}", binding_map);
122
123            if seen.insert(key) {
124                unique_solutions.push(solution.clone());
125            }
126        }
127
128        self.solutions = unique_solutions;
129    }
130
131    /// Get the total number of solutions
132    pub fn solution_count(&self) -> usize {
133        self.solutions.len()
134    }
135}
136
137impl Default for DisjunctionResult {
138    fn default() -> Self {
139        Self::new()
140    }
141}
142
143/// Parser for OR patterns in queries
144pub struct DisjunctionParser;
145
146impl DisjunctionParser {
147    /// Parse a pattern that might contain OR
148    ///
149    /// Examples:
150    /// - "(A OR B)" -> Disjunction with 2 branches
151    /// - "(A OR B OR C)" -> Disjunction with 3 branches
152    /// - "A" -> None (no OR, single goal)
153    pub fn parse(pattern: &str) -> Option<Disjunction> {
154        let pattern = pattern.trim();
155
156        // Check if pattern starts with '(' and ends with ')'
157        if !pattern.starts_with('(') || !pattern.ends_with(')') {
158            return None;
159        }
160
161        // Remove outer parentheses
162        let inner = &pattern[1..pattern.len()-1];
163
164        // Split by OR (naive implementation - TODO: handle nested parens)
165        if !inner.contains(" OR ") {
166            return None;
167        }
168
169        let branches: Vec<Goal> = inner
170            .split(" OR ")
171            .map(|s| Goal::new(s.trim().to_string()))
172            .collect();
173
174        if branches.len() < 2 {
175            return None;
176        }
177
178        Some(Disjunction::new(branches, pattern.to_string()))
179    }
180
181    /// Check if a pattern contains OR
182    pub fn contains_or(pattern: &str) -> bool {
183        pattern.contains(" OR ")
184    }
185}
186
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a goal from a string slice.
    fn goal(text: &str) -> Goal {
        Goal::new(text.to_string())
    }

    #[test]
    fn test_disjunction_creation() {
        let disj = Disjunction::from_pair(goal("manager(?person)"), goal("senior(?person)"));

        assert_eq!(disj.branch_count(), 2);
        assert!(disj.pattern.contains("OR"));
    }

    #[test]
    fn test_disjunction_add_branch() {
        let mut disj = Disjunction::from_pair(goal("manager(?person)"), goal("senior(?person)"));
        disj.add_branch(goal("director(?person)"));

        assert_eq!(disj.branch_count(), 3);
    }

    #[test]
    fn test_disjunction_result_success() {
        let mut result = DisjunctionResult::new();
        // Each of the two branches contributes one (empty) solution.
        for branch in 0..2 {
            result.add_branch_solutions(branch, vec![Bindings::new()]);
        }

        assert!(result.success);
        assert_eq!(result.solution_count(), 2);
        assert_eq!(result.successful_branches.len(), 2);
    }

    #[test]
    fn test_disjunction_result_empty() {
        let mut result = DisjunctionResult::new();
        // Neither branch contributes anything.
        for branch in 0..2 {
            result.add_branch_solutions(branch, Vec::new());
        }

        assert!(!result.success);
        assert_eq!(result.solution_count(), 0);
    }

    #[test]
    fn test_parser_simple_or() {
        let disj = DisjunctionParser::parse("(manager(?person) OR senior(?person))")
            .expect("a parenthesised two-way OR should parse");

        assert_eq!(disj.branch_count(), 2);
    }

    #[test]
    fn test_parser_triple_or() {
        let disj = DisjunctionParser::parse("(A OR B OR C)")
            .expect("a parenthesised three-way OR should parse");

        assert_eq!(disj.branch_count(), 3);
    }

    #[test]
    fn test_parser_no_or() {
        assert!(DisjunctionParser::parse("manager(?person)").is_none());
    }

    #[test]
    fn test_parser_contains_or() {
        assert!(DisjunctionParser::contains_or("(A OR B)"));
        assert!(!DisjunctionParser::contains_or("A AND B"));
    }

    #[test]
    fn test_deduplication() {
        let mut result = DisjunctionResult::new();

        // The same (empty) solution reported twice by one branch.
        let bindings = Bindings::new();
        result.add_branch_solutions(0, vec![bindings.clone(), bindings]);
        assert_eq!(result.solution_count(), 2);

        result.deduplicate();
        assert_eq!(result.solution_count(), 1);
    }
}