rust_rule_engine/backward/
disjunction.rs1use super::goal::Goal;
23use super::unification::Bindings;
24use std::collections::HashSet;
25
26#[derive(Debug, Clone)]
28pub struct Disjunction {
29 pub branches: Vec<Goal>,
31
32 pub pattern: String,
34}
35
36impl Disjunction {
37 pub fn new(branches: Vec<Goal>, pattern: String) -> Self {
39 assert!(
40 !branches.is_empty(),
41 "Disjunction must have at least one branch"
42 );
43 Self { branches, pattern }
44 }
45
46 pub fn from_pair(left: Goal, right: Goal) -> Self {
48 let pattern = format!("({} OR {})", left.pattern, right.pattern);
49 Self {
50 branches: vec![left, right],
51 pattern,
52 }
53 }
54
55 pub fn add_branch(&mut self, goal: Goal) {
57 self.branches.push(goal);
58 }
59
60 pub fn branch_count(&self) -> usize {
62 self.branches.len()
63 }
64}
65
66#[derive(Debug, Clone)]
68pub struct DisjunctionResult {
69 pub solutions: Vec<Bindings>,
71
72 pub successful_branches: Vec<usize>,
74
75 pub success: bool,
77}
78
79impl DisjunctionResult {
80 pub fn new() -> Self {
82 Self {
83 solutions: Vec::new(),
84 successful_branches: Vec::new(),
85 success: false,
86 }
87 }
88
89 pub fn success(solutions: Vec<Bindings>, successful_branches: Vec<usize>) -> Self {
91 Self {
92 solutions,
93 successful_branches,
94 success: true,
95 }
96 }
97
98 pub fn failure() -> Self {
100 Self {
101 solutions: Vec::new(),
102 successful_branches: Vec::new(),
103 success: false,
104 }
105 }
106
107 pub fn add_branch_solutions(&mut self, branch_index: usize, solutions: Vec<Bindings>) {
109 if !solutions.is_empty() {
110 self.successful_branches.push(branch_index);
111 self.solutions.extend(solutions);
112 self.success = true;
113 }
114 }
115
116 pub fn deduplicate(&mut self) {
118 let mut seen = HashSet::new();
120 let mut unique_solutions = Vec::new();
121
122 for solution in &self.solutions {
123 let binding_map = solution.to_map();
124 let key = format!("{:?}", binding_map);
125
126 if seen.insert(key) {
127 unique_solutions.push(solution.clone());
128 }
129 }
130
131 self.solutions = unique_solutions;
132 }
133
134 pub fn solution_count(&self) -> usize {
136 self.solutions.len()
137 }
138}
139
140impl Default for DisjunctionResult {
141 fn default() -> Self {
142 Self::new()
143 }
144}
145
146pub struct DisjunctionParser;
148
149impl DisjunctionParser {
150 pub fn parse(pattern: &str) -> Option<Disjunction> {
157 let pattern = pattern.trim();
158
159 if !pattern.starts_with('(') || !pattern.ends_with(')') {
161 return None;
162 }
163
164 let inner = &pattern[1..pattern.len() - 1];
166
167 if !inner.contains(" OR ") {
169 return None;
170 }
171
172 let branches: Vec<Goal> = inner
173 .split(" OR ")
174 .map(|s| Goal::new(s.trim().to_string()))
175 .collect();
176
177 if branches.len() < 2 {
178 return None;
179 }
180
181 Some(Disjunction::new(branches, pattern.to_string()))
182 }
183
184 pub fn contains_or(pattern: &str) -> bool {
186 pattern.contains(" OR ")
187 }
188}
189
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a [`Goal`] from a string slice.
    fn goal(pattern: &str) -> Goal {
        Goal::new(pattern.to_string())
    }

    #[test]
    fn test_disjunction_creation() {
        let disj = Disjunction::from_pair(goal("manager(?person)"), goal("senior(?person)"));

        assert_eq!(disj.branch_count(), 2);
        assert!(disj.pattern.contains("OR"));
    }

    #[test]
    fn test_disjunction_add_branch() {
        let mut disj = Disjunction::from_pair(goal("manager(?person)"), goal("senior(?person)"));
        disj.add_branch(goal("director(?person)"));

        assert_eq!(disj.branch_count(), 3);
    }

    #[test]
    fn test_disjunction_result_success() {
        let mut result = DisjunctionResult::new();
        for branch in 0..2 {
            result.add_branch_solutions(branch, vec![Bindings::new()]);
        }

        assert!(result.success);
        assert_eq!(result.solution_count(), 2);
        assert_eq!(result.successful_branches.len(), 2);
    }

    #[test]
    fn test_disjunction_result_empty() {
        let mut result = DisjunctionResult::new();
        for branch in 0..2 {
            result.add_branch_solutions(branch, Vec::new());
        }

        assert!(!result.success);
        assert_eq!(result.solution_count(), 0);
    }

    #[test]
    fn test_parser_simple_or() {
        let disj = DisjunctionParser::parse("(manager(?person) OR senior(?person))")
            .expect("a two-branch OR pattern should parse");

        assert_eq!(disj.branch_count(), 2);
    }

    #[test]
    fn test_parser_triple_or() {
        let disj = DisjunctionParser::parse("(A OR B OR C)")
            .expect("a three-branch OR pattern should parse");

        assert_eq!(disj.branch_count(), 3);
    }

    #[test]
    fn test_parser_no_or() {
        assert!(DisjunctionParser::parse("manager(?person)").is_none());
    }

    #[test]
    fn test_parser_contains_or() {
        assert!(DisjunctionParser::contains_or("(A OR B)"));
        assert!(!DisjunctionParser::contains_or("A AND B"));
    }

    #[test]
    fn test_deduplication() {
        let bindings = Bindings::new();
        let mut result = DisjunctionResult::new();
        result.add_branch_solutions(0, vec![bindings.clone(), bindings]);
        assert_eq!(result.solution_count(), 2);

        result.deduplicate();
        assert_eq!(result.solution_count(), 1);
    }
}