rust_rule_engine/backward/
disjunction.rs1use super::goal::Goal;
23use super::unification::Bindings;
24use std::collections::HashSet;
25
26#[derive(Debug, Clone)]
28pub struct Disjunction {
29 pub branches: Vec<Goal>,
31
32 pub pattern: String,
34}
35
36impl Disjunction {
37 pub fn new(branches: Vec<Goal>, pattern: String) -> Self {
39 assert!(
40 !branches.is_empty(),
41 "Disjunction must have at least one branch"
42 );
43 Self { branches, pattern }
44 }
45
46 pub fn from_pair(left: Goal, right: Goal) -> Self {
48 let pattern = format!("({} OR {})", left.pattern, right.pattern);
49 Self {
50 branches: vec![left, right],
51 pattern,
52 }
53 }
54
55 pub fn add_branch(&mut self, goal: Goal) {
57 self.branches.push(goal);
58 }
59
60 pub fn branch_count(&self) -> usize {
62 self.branches.len()
63 }
64}
65
66#[derive(Debug, Clone)]
68pub struct DisjunctionResult {
69 pub solutions: Vec<Bindings>,
71
72 pub successful_branches: Vec<usize>,
74
75 pub success: bool,
77}
78
79impl DisjunctionResult {
80 pub fn new() -> Self {
82 Self {
83 solutions: Vec::new(),
84 successful_branches: Vec::new(),
85 success: false,
86 }
87 }
88
89 pub fn success(solutions: Vec<Bindings>, successful_branches: Vec<usize>) -> Self {
91 Self {
92 solutions,
93 successful_branches,
94 success: true,
95 }
96 }
97
98 pub fn failure() -> Self {
100 Self {
101 solutions: Vec::new(),
102 successful_branches: Vec::new(),
103 success: false,
104 }
105 }
106
107 pub fn add_branch_solutions(&mut self, branch_index: usize, solutions: Vec<Bindings>) {
109 if !solutions.is_empty() {
110 self.successful_branches.push(branch_index);
111 self.solutions.extend(solutions);
112 self.success = true;
113 }
114 }
115
116 pub fn deduplicate(&mut self) {
118 let mut seen = HashSet::new();
120 let mut unique_solutions = Vec::new();
121
122 for solution in &self.solutions {
123 let binding_map = solution.to_map();
124 let key = format!("{:?}", binding_map);
125
126 if seen.insert(key) {
127 unique_solutions.push(solution.clone());
128 }
129 }
130
131 self.solutions = unique_solutions;
132 }
133
134 pub fn solution_count(&self) -> usize {
136 self.solutions.len()
137 }
138}
139
140impl Default for DisjunctionResult {
141 fn default() -> Self {
142 Self::new()
143 }
144}
145
/// Splits `input` on every ` OR ` keyword that appears at parenthesis depth zero
/// and outside double-quoted string literals.
///
/// Each returned part is trimmed, and empty parts are dropped. An `OR` nested inside
/// parentheses (e.g. `func(A OR B)`) or inside quotes (e.g. `"a OR b"`) does not split.
/// Backslash escapes inside strings are not recognised.
///
/// The scan works on `char`s rather than bytes, so non-ASCII input is handled correctly.
fn split_top_level_or(input: &str) -> Vec<String> {
    /// The separator, compared char-by-char against the decoded input.
    const OR_SEPARATOR: [char; 4] = [' ', 'O', 'R', ' '];

    /// Pushes the trimmed `raw` onto `parts` unless it is blank.
    fn push_part(parts: &mut Vec<String>, raw: &str) {
        let trimmed = raw.trim();
        if !trimmed.is_empty() {
            parts.push(trimmed.to_string());
        }
    }

    let chars: Vec<char> = input.chars().collect();
    let mut parts = Vec::new();
    let mut current = String::new();
    let mut paren_depth: i32 = 0;
    let mut in_string = false;
    let mut i = 0;

    while i < chars.len() {
        // Compare against the char slice, never `&input[i..]`. `i` is a char index,
        // and slicing the `str` with it would be wrong (or panic) after any
        // multi-byte character.
        if !in_string && paren_depth == 0 && chars[i..].starts_with(&OR_SEPARATOR) {
            push_part(&mut parts, &current);
            current.clear();
            i += OR_SEPARATOR.len();
            continue;
        }

        let ch = chars[i];
        match ch {
            '"' => in_string = !in_string,
            '(' if !in_string => paren_depth += 1,
            ')' if !in_string => paren_depth -= 1,
            _ => {}
        }
        current.push(ch);
        i += 1;
    }

    push_part(&mut parts, &current);
    parts
}
210
211pub struct DisjunctionParser;
213
214impl DisjunctionParser {
215 pub fn parse(pattern: &str) -> Option<Disjunction> {
222 let pattern = pattern.trim();
223
224 if !pattern.starts_with('(') || !pattern.ends_with(')') {
226 return None;
227 }
228
229 let inner = &pattern[1..pattern.len() - 1];
231
232 if !inner.contains(" OR ") {
233 return None;
234 }
235
236 let parts = split_top_level_or(inner);
239
240 let branches: Vec<Goal> = parts
241 .into_iter()
242 .map(|s| Goal::new(s.trim().to_string()))
243 .collect();
244
245 if branches.len() < 2 {
246 return None;
247 }
248
249 Some(Disjunction::new(branches, pattern.to_string()))
250 }
251
252 pub fn contains_or(pattern: &str) -> bool {
254 let parts = split_top_level_or(pattern);
255 parts.len() > 1
256 }
257}
258
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for constructing a goal from a string slice.
    fn goal(text: &str) -> Goal {
        Goal::new(text.to_string())
    }

    /// Parses `pattern`, failing the test if it is not recognised as a disjunction.
    fn parse_ok(pattern: &str) -> Disjunction {
        DisjunctionParser::parse(pattern).expect("pattern should parse as a disjunction")
    }

    #[test]
    fn test_disjunction_creation() {
        let disj = Disjunction::from_pair(goal("manager(?person)"), goal("senior(?person)"));

        assert_eq!(disj.branch_count(), 2);
        assert!(disj.pattern.contains("OR"));
    }

    #[test]
    fn test_disjunction_add_branch() {
        let mut disj = Disjunction::from_pair(goal("manager(?person)"), goal("senior(?person)"));
        disj.add_branch(goal("director(?person)"));

        assert_eq!(disj.branch_count(), 3);
    }

    #[test]
    fn test_disjunction_result_success() {
        let mut result = DisjunctionResult::new();
        for branch in 0..2 {
            result.add_branch_solutions(branch, vec![Bindings::new()]);
        }

        assert!(result.success);
        assert_eq!(result.solution_count(), 2);
        assert_eq!(result.successful_branches.len(), 2);
    }

    #[test]
    fn test_disjunction_result_empty() {
        let mut result = DisjunctionResult::new();
        for branch in 0..2 {
            result.add_branch_solutions(branch, Vec::new());
        }

        assert!(!result.success);
        assert_eq!(result.solution_count(), 0);
    }

    #[test]
    fn test_parser_simple_or() {
        let disj = parse_ok("(manager(?person) OR senior(?person))");
        assert_eq!(disj.branch_count(), 2);
    }

    #[test]
    fn test_parser_triple_or() {
        let disj = parse_ok("(A OR B OR C)");
        assert_eq!(disj.branch_count(), 3);
    }

    #[test]
    fn test_parser_no_or() {
        assert!(DisjunctionParser::parse("manager(?person)").is_none());
    }

    #[test]
    fn test_parser_contains_or() {
        assert!(DisjunctionParser::contains_or("A OR B"));
        assert!(!DisjunctionParser::contains_or("A AND B"));
        assert!(!DisjunctionParser::contains_or("(A OR B)"));
    }

    #[test]
    fn test_parser_nested_parens() {
        let disj = parse_ok("(A OR (B AND C))");
        let patterns: Vec<&str> = disj.branches.iter().map(|b| b.pattern.as_str()).collect();

        assert_eq!(disj.branch_count(), 2);
        assert_eq!(patterns[0], "A");
        assert_eq!(patterns[1], "(B AND C)");
    }

    #[test]
    fn test_parser_nested_or_groups() {
        let disj = parse_ok("((A OR B) OR (C OR D))");
        let patterns: Vec<&str> = disj.branches.iter().map(|b| b.pattern.as_str()).collect();

        assert_eq!(disj.branch_count(), 2);
        assert_eq!(patterns[0], "(A OR B)");
        assert_eq!(patterns[1], "(C OR D)");
    }

    #[test]
    fn test_parser_function_args_with_or_keyword() {
        let disj = parse_ok("(func(a, OR, b) OR c)");
        let patterns: Vec<&str> = disj.branches.iter().map(|b| b.pattern.as_str()).collect();

        assert_eq!(disj.branch_count(), 2);
        assert_eq!(patterns[0], "func(a, OR, b)");
        assert_eq!(patterns[1], "c");
    }

    #[test]
    fn test_parser_deeply_nested() {
        let disj = parse_ok("(A OR (B OR (C AND D)))");
        let patterns: Vec<&str> = disj.branches.iter().map(|b| b.pattern.as_str()).collect();

        assert_eq!(disj.branch_count(), 2);
        assert_eq!(patterns[0], "A");
        assert_eq!(patterns[1], "(B OR (C AND D))");
    }

    #[test]
    fn test_contains_or_nested() {
        // Wrapped in parentheses: not a top-level OR.
        assert!(!DisjunctionParser::contains_or("(A OR B)"));
        // Bare: a top-level OR.
        assert!(DisjunctionParser::contains_or("A OR B"));
        // Inside a call's argument list: not a top-level OR.
        assert!(!DisjunctionParser::contains_or("func(A OR B)"));
    }

    #[test]
    fn test_split_top_level_or_basic() {
        assert_eq!(split_top_level_or("A OR B OR C"), vec!["A", "B", "C"]);
    }

    #[test]
    fn test_split_top_level_or_with_quotes() {
        let parts = split_top_level_or(r#""hello OR world" OR B"#);

        assert_eq!(parts.len(), 2);
        assert_eq!(parts[0], r#""hello OR world""#);
        assert_eq!(parts[1], "B");
    }

    #[test]
    fn test_deduplication() {
        let bindings = Bindings::new();
        let mut result = DisjunctionResult::new();
        result.add_branch_solutions(0, vec![bindings.clone(), bindings.clone()]);
        assert_eq!(result.solution_count(), 2);

        result.deduplicate();
        assert_eq!(result.solution_count(), 1);
    }
}