1use std::collections::HashMap;
10use super::facts::{FactValue, TypedFacts};
11use super::working_memory::{WorkingMemory, FactHandle};
12
13pub type Variable = String;
15
16#[derive(Debug, Clone)]
18pub enum PatternConstraint {
19 Simple {
21 field: String,
22 operator: String,
23 value: FactValue,
24 },
25 Binding {
27 field: String,
28 variable: Variable,
29 },
30 Variable {
32 field: String,
33 operator: String,
34 variable: Variable,
35 },
36}
37
38impl PatternConstraint {
39 pub fn simple(field: String, operator: String, value: FactValue) -> Self {
41 Self::Simple { field, operator, value }
42 }
43
44 pub fn binding(field: String, variable: Variable) -> Self {
46 Self::Binding { field, variable }
47 }
48
49 pub fn variable(field: String, operator: String, variable: Variable) -> Self {
51 Self::Variable { field, operator, variable }
52 }
53
54 pub fn evaluate(
56 &self,
57 facts: &TypedFacts,
58 bindings: &HashMap<Variable, FactValue>,
59 ) -> Option<HashMap<Variable, FactValue>> {
60 match self {
61 PatternConstraint::Simple { field, operator, value } => {
62 if facts.evaluate_condition(field, operator, value) {
63 Some(HashMap::new())
64 } else {
65 None
66 }
67 }
68 PatternConstraint::Binding { field, variable } => {
69 if let Some(fact_value) = facts.get(field) {
70 let mut new_bindings = HashMap::new();
71 new_bindings.insert(variable.clone(), fact_value.clone());
72 Some(new_bindings)
73 } else {
74 None
75 }
76 }
77 PatternConstraint::Variable { field, operator, variable } => {
78 if let Some(bound_value) = bindings.get(variable) {
79 if facts.evaluate_condition(field, operator, bound_value) {
80 Some(HashMap::new())
81 } else {
82 None
83 }
84 } else {
85 None }
87 }
88 }
89 }
90}
91
92#[derive(Debug, Clone)]
94pub struct Pattern {
95 pub fact_type: String,
97 pub constraints: Vec<PatternConstraint>,
99 pub name: Option<String>,
101}
102
103impl Pattern {
104 pub fn new(fact_type: String) -> Self {
106 Self {
107 fact_type,
108 constraints: Vec::new(),
109 name: None,
110 }
111 }
112
113 pub fn with_constraint(mut self, constraint: PatternConstraint) -> Self {
115 self.constraints.push(constraint);
116 self
117 }
118
119 pub fn with_name(mut self, name: String) -> Self {
121 self.name = Some(name);
122 self
123 }
124
125 pub fn matches(
127 &self,
128 facts: &TypedFacts,
129 bindings: &HashMap<Variable, FactValue>,
130 ) -> Option<HashMap<Variable, FactValue>> {
131 let mut new_bindings = bindings.clone();
132
133 for constraint in &self.constraints {
134 match constraint.evaluate(facts, &new_bindings) {
135 Some(additional_bindings) => {
136 new_bindings.extend(additional_bindings);
137 }
138 None => return None,
139 }
140 }
141
142 Some(new_bindings)
143 }
144
145 pub fn match_in_working_memory(
147 &self,
148 wm: &WorkingMemory,
149 bindings: &HashMap<Variable, FactValue>,
150 ) -> Vec<(FactHandle, HashMap<Variable, FactValue>)> {
151 let mut results = Vec::new();
152
153 for fact in wm.get_by_type(&self.fact_type) {
154 if let Some(new_bindings) = self.matches(&fact.data, bindings) {
155 results.push((fact.handle, new_bindings));
156 }
157 }
158
159 results
160 }
161}
162
163#[derive(Debug, Clone)]
165pub struct MultiPattern {
166 pub patterns: Vec<Pattern>,
168 pub name: String,
170}
171
172impl MultiPattern {
173 pub fn new(name: String) -> Self {
175 Self {
176 patterns: Vec::new(),
177 name,
178 }
179 }
180
181 pub fn with_pattern(mut self, pattern: Pattern) -> Self {
183 self.patterns.push(pattern);
184 self
185 }
186
187 pub fn match_all(
189 &self,
190 wm: &WorkingMemory,
191 ) -> Vec<(Vec<FactHandle>, HashMap<Variable, FactValue>)> {
192 if self.patterns.is_empty() {
193 return Vec::new();
194 }
195
196 let mut results = Vec::new();
198 let first_pattern = &self.patterns[0];
199 let empty_bindings = HashMap::new();
200
201 for (handle, bindings) in first_pattern.match_in_working_memory(wm, &empty_bindings) {
202 results.push((vec![handle], bindings));
203 }
204
205 for pattern in &self.patterns[1..] {
207 let mut new_results = Vec::new();
208
209 for (handles, bindings) in results {
210 for (handle, new_bindings) in pattern.match_in_working_memory(wm, &bindings) {
211 let mut combined_handles = handles.clone();
212 combined_handles.push(handle);
213 new_results.push((combined_handles, new_bindings));
214 }
215 }
216
217 results = new_results;
218
219 if results.is_empty() {
220 break; }
222 }
223
224 results
225 }
226}
227
228pub struct PatternBuilder {
230 pattern: Pattern,
231}
232
233impl PatternBuilder {
234 pub fn for_type(fact_type: impl Into<String>) -> Self {
236 Self {
237 pattern: Pattern::new(fact_type.into()),
238 }
239 }
240
241 pub fn where_field(mut self, field: impl Into<String>, operator: impl Into<String>, value: FactValue) -> Self {
243 self.pattern.constraints.push(PatternConstraint::Simple {
244 field: field.into(),
245 operator: operator.into(),
246 value,
247 });
248 self
249 }
250
251 pub fn bind(mut self, field: impl Into<String>, variable: impl Into<String>) -> Self {
253 self.pattern.constraints.push(PatternConstraint::Binding {
254 field: field.into(),
255 variable: variable.into(),
256 });
257 self
258 }
259
260 pub fn where_var(mut self, field: impl Into<String>, operator: impl Into<String>, variable: impl Into<String>) -> Self {
262 self.pattern.constraints.push(PatternConstraint::Variable {
263 field: field.into(),
264 operator: operator.into(),
265 variable: variable.into(),
266 });
267 self
268 }
269
270 pub fn named(mut self, name: impl Into<String>) -> Self {
272 self.pattern.name = Some(name.into());
273 self
274 }
275
276 pub fn build(self) -> Pattern {
278 self.pattern
279 }
280}
281
#[cfg(test)]
mod tests {
    use super::*;

    /// Pattern used by the simple-constraint tests: adult, active persons.
    fn adult_active_person() -> Pattern {
        PatternBuilder::for_type("Person")
            .where_field("age", ">", FactValue::Integer(18))
            .where_field("status", "==", FactValue::String("active".to_string()))
            .build()
    }

    #[test]
    fn test_simple_pattern() {
        let pattern = adult_active_person();

        let mut facts = TypedFacts::new();
        facts.set("age", 25i64);
        facts.set("status", "active");

        let bindings = HashMap::new();
        let result = pattern.matches(&facts, &bindings);
        assert!(result.is_some());
    }

    #[test]
    fn test_simple_pattern_rejects_non_matching_facts() {
        let pattern = adult_active_person();
        let bindings = HashMap::new();

        // Fails the numeric constraint.
        let mut too_young = TypedFacts::new();
        too_young.set("age", 15i64);
        too_young.set("status", "active");
        assert!(pattern.matches(&too_young, &bindings).is_none());

        // Fails the string constraint; one failing constraint is enough.
        let mut inactive = TypedFacts::new();
        inactive.set("age", 25i64);
        inactive.set("status", "inactive");
        assert!(pattern.matches(&inactive, &bindings).is_none());
    }

    #[test]
    fn test_variable_binding() {
        let pattern = PatternBuilder::for_type("Person")
            .bind("name", "$personName")
            .bind("age", "$personAge")
            .build();

        let mut facts = TypedFacts::new();
        facts.set("name", "John");
        facts.set("age", 25i64);

        let bindings = HashMap::new();
        let result = pattern.matches(&facts, &bindings).unwrap();

        assert_eq!(result.get("$personName").unwrap().as_string(), "John");
        assert_eq!(result.get("$personAge").unwrap().as_integer(), Some(25));
    }

    #[test]
    fn test_binding_missing_field_fails() {
        let pattern = PatternBuilder::for_type("Person")
            .bind("email", "$email")
            .build();

        let mut facts = TypedFacts::new();
        facts.set("name", "John");

        let bindings = HashMap::new();
        assert!(pattern.matches(&facts, &bindings).is_none());
    }

    #[test]
    fn test_variable_constraint() {
        let mut bindings = HashMap::new();
        bindings.insert("$minAge".to_string(), FactValue::Integer(18));

        let pattern = PatternBuilder::for_type("Person")
            .where_var("age", ">=", "$minAge")
            .build();

        let mut facts = TypedFacts::new();
        facts.set("age", 25i64);

        let result = pattern.matches(&facts, &bindings);
        assert!(result.is_some());
    }

    #[test]
    fn test_variable_constraint_with_unbound_variable_fails() {
        let pattern = PatternBuilder::for_type("Person")
            .where_var("age", ">=", "$minAge")
            .build();

        let mut facts = TypedFacts::new();
        facts.set("age", 25i64);

        let bindings = HashMap::new();
        assert!(pattern.matches(&facts, &bindings).is_none());
    }

    #[test]
    fn test_multi_pattern_join() {
        let mut wm = WorkingMemory::new();

        let mut person = TypedFacts::new();
        person.set("name", "John");
        person.set("age", 25i64);
        wm.insert("Person".to_string(), person);

        let mut order = TypedFacts::new();
        order.set("customer", "John");
        order.set("amount", 1000.0);
        wm.insert("Order".to_string(), order);

        // An order for someone else: the join must filter it out.
        let mut other_order = TypedFacts::new();
        other_order.set("customer", "Jane");
        other_order.set("amount", 50.0);
        wm.insert("Order".to_string(), other_order);

        let person_pattern = PatternBuilder::for_type("Person")
            .bind("name", "$name")
            .build();

        let order_pattern = PatternBuilder::for_type("Order")
            .where_var("customer", "==", "$name")
            .build();

        let multi = MultiPattern::new("PersonWithOrder".to_string())
            .with_pattern(person_pattern)
            .with_pattern(order_pattern);

        let matches = multi.match_all(&wm);
        assert_eq!(matches.len(), 1);

        let (handles, bindings) = &matches[0];
        assert_eq!(handles.len(), 2);
        assert_eq!(bindings.get("$name").unwrap().as_string(), "John");
    }

    #[test]
    fn test_empty_multi_pattern_matches_nothing() {
        let wm = WorkingMemory::new();
        let multi = MultiPattern::new("Empty".to_string());
        assert!(multi.match_all(&wm).is_empty());
    }
}