rust_rule_engine/backward/
disjunction.rs1use super::goal::Goal;
23use super::unification::Bindings;
24use std::collections::HashSet;
25
26#[derive(Debug, Clone)]
28pub struct Disjunction {
29 pub branches: Vec<Goal>,
31
32 pub pattern: String,
34}
35
36impl Disjunction {
37 pub fn new(branches: Vec<Goal>, pattern: String) -> Self {
39 assert!(!branches.is_empty(), "Disjunction must have at least one branch");
40 Self { branches, pattern }
41 }
42
43 pub fn from_pair(left: Goal, right: Goal) -> Self {
45 let pattern = format!("({} OR {})", left.pattern, right.pattern);
46 Self {
47 branches: vec![left, right],
48 pattern,
49 }
50 }
51
52 pub fn add_branch(&mut self, goal: Goal) {
54 self.branches.push(goal);
55 }
56
57 pub fn branch_count(&self) -> usize {
59 self.branches.len()
60 }
61}
62
63#[derive(Debug, Clone)]
65pub struct DisjunctionResult {
66 pub solutions: Vec<Bindings>,
68
69 pub successful_branches: Vec<usize>,
71
72 pub success: bool,
74}
75
76impl DisjunctionResult {
77 pub fn new() -> Self {
79 Self {
80 solutions: Vec::new(),
81 successful_branches: Vec::new(),
82 success: false,
83 }
84 }
85
86 pub fn success(solutions: Vec<Bindings>, successful_branches: Vec<usize>) -> Self {
88 Self {
89 solutions,
90 successful_branches,
91 success: true,
92 }
93 }
94
95 pub fn failure() -> Self {
97 Self {
98 solutions: Vec::new(),
99 successful_branches: Vec::new(),
100 success: false,
101 }
102 }
103
104 pub fn add_branch_solutions(&mut self, branch_index: usize, solutions: Vec<Bindings>) {
106 if !solutions.is_empty() {
107 self.successful_branches.push(branch_index);
108 self.solutions.extend(solutions);
109 self.success = true;
110 }
111 }
112
113 pub fn deduplicate(&mut self) {
115 let mut seen = HashSet::new();
117 let mut unique_solutions = Vec::new();
118
119 for solution in &self.solutions {
120 let binding_map = solution.to_map();
121 let key = format!("{:?}", binding_map);
122
123 if seen.insert(key) {
124 unique_solutions.push(solution.clone());
125 }
126 }
127
128 self.solutions = unique_solutions;
129 }
130
131 pub fn solution_count(&self) -> usize {
133 self.solutions.len()
134 }
135}
136
137impl Default for DisjunctionResult {
138 fn default() -> Self {
139 Self::new()
140 }
141}
142
143pub struct DisjunctionParser;
145
146impl DisjunctionParser {
147 pub fn parse(pattern: &str) -> Option<Disjunction> {
154 let pattern = pattern.trim();
155
156 if !pattern.starts_with('(') || !pattern.ends_with(')') {
158 return None;
159 }
160
161 let inner = &pattern[1..pattern.len()-1];
163
164 if !inner.contains(" OR ") {
166 return None;
167 }
168
169 let branches: Vec<Goal> = inner
170 .split(" OR ")
171 .map(|s| Goal::new(s.trim().to_string()))
172 .collect();
173
174 if branches.len() < 2 {
175 return None;
176 }
177
178 Some(Disjunction::new(branches, pattern.to_string()))
179 }
180
181 pub fn contains_or(pattern: &str) -> bool {
183 pattern.contains(" OR ")
184 }
185}
186
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a goal from a pattern literal.
    fn goal(pattern: &str) -> Goal {
        Goal::new(pattern.to_string())
    }

    /// A pair-built disjunction has two branches and an `OR` in its pattern.
    #[test]
    fn test_disjunction_creation() {
        let pair = Disjunction::from_pair(goal("manager(?person)"), goal("senior(?person)"));

        assert_eq!(pair.branch_count(), 2);
        assert!(pair.pattern.contains("OR"));
    }

    /// Adding a branch to a pair-built disjunction gives three branches.
    #[test]
    fn test_disjunction_add_branch() {
        let mut alternatives =
            Disjunction::from_pair(goal("manager(?person)"), goal("senior(?person)"));
        alternatives.add_branch(goal("director(?person)"));

        assert_eq!(alternatives.branch_count(), 3);
    }

    /// Two branches that each contribute one solution mark the result successful.
    #[test]
    fn test_disjunction_result_success() {
        let mut outcome = DisjunctionResult::new();

        for branch in 0..2 {
            outcome.add_branch_solutions(branch, vec![Bindings::new()]);
        }

        assert!(outcome.success);
        assert_eq!(outcome.solution_count(), 2);
        assert_eq!(outcome.successful_branches.len(), 2);
    }

    /// Branches that contribute nothing leave the result unsuccessful and empty.
    #[test]
    fn test_disjunction_result_empty() {
        let mut outcome = DisjunctionResult::new();

        for branch in 0..2 {
            outcome.add_branch_solutions(branch, Vec::new());
        }

        assert!(!outcome.success);
        assert_eq!(outcome.solution_count(), 0);
    }

    /// A two-way pattern parses into two branches.
    #[test]
    fn test_parser_simple_or() {
        let branch_total = DisjunctionParser::parse("(manager(?person) OR senior(?person))")
            .map(|parsed| parsed.branch_count());

        assert_eq!(branch_total, Some(2));
    }

    /// A three-way pattern parses into three branches.
    #[test]
    fn test_parser_triple_or() {
        let branch_total =
            DisjunctionParser::parse("(A OR B OR C)").map(|parsed| parsed.branch_count());

        assert_eq!(branch_total, Some(3));
    }

    /// A plain goal pattern is not a disjunction.
    #[test]
    fn test_parser_no_or() {
        assert!(DisjunctionParser::parse("manager(?person)").is_none());
    }

    /// `contains_or` detects the ` OR ` keyword and nothing else.
    #[test]
    fn test_parser_contains_or() {
        assert!(DisjunctionParser::contains_or("(A OR B)"));
        assert!(!DisjunctionParser::contains_or("A AND B"));
    }

    /// Identical solutions collapse to a single entry.
    #[test]
    fn test_deduplication() {
        let template = Bindings::new();
        let mut outcome = DisjunctionResult::new();
        outcome.add_branch_solutions(0, vec![template.clone(), template]);
        assert_eq!(outcome.solution_count(), 2);

        outcome.deduplicate();

        assert_eq!(outcome.solution_count(), 1);
    }
}