1use super::facts::{FactValue, TypedFacts};
10use super::multifield::MultifieldOp;
11use super::working_memory::{FactHandle, WorkingMemory};
12use std::collections::HashMap;
13
/// Name of a pattern variable (for example `"$name"`); used as the key type of
/// every binding map produced and consumed by the matching code below.
pub type Variable = String;
16
/// A single test applied to one fact's fields while matching a [`Pattern`].
///
/// Constraints are checked by [`PatternConstraint::evaluate`], which either
/// rejects the fact (`None`) or returns the variable bindings the constraint
/// introduces.
#[derive(Debug, Clone)]
pub enum PatternConstraint {
    /// Compares `field` against the literal `value` using `operator`; the
    /// comparison itself is delegated to `TypedFacts::evaluate_condition`.
    Simple {
        /// Name of the fact field to test.
        field: String,
        /// Operator string passed through to `evaluate_condition`
        /// (e.g. `">"`, `">="`, `"=="`).
        operator: String,
        /// Literal right-hand operand of the comparison.
        value: FactValue,
    },
    /// Binds the current value of `field` to `variable`; the constraint fails
    /// when the fact has no such field.
    Binding { field: String, variable: Variable },
    /// Compares `field` against the value already bound to `variable`; the
    /// constraint fails when `variable` has not been bound yet.
    Variable {
        /// Name of the fact field to test.
        field: String,
        /// Operator string passed through to `evaluate_condition`.
        operator: String,
        /// Previously bound variable supplying the right-hand operand.
        variable: Variable,
    },
    /// Multifield test; evaluation is delegated entirely to
    /// `multifield::evaluate_multifield_pattern`.
    /// NOTE(review): the meaning of `variable`/`value` depends on `operator`
    /// and is defined in the `multifield` module — confirm there.
    MultiField {
        /// Name of the (presumably list-valued) fact field to test.
        field: String,
        /// Optional variable handed to the multifield evaluator.
        variable: Option<Variable>,
        operator: MultifieldOp,
        /// Optional operand handed to the multifield evaluator.
        value: Option<FactValue>,
    },
}
47
48impl PatternConstraint {
49 pub fn simple(field: String, operator: String, value: FactValue) -> Self {
51 Self::Simple {
52 field,
53 operator,
54 value,
55 }
56 }
57
58 pub fn binding(field: String, variable: Variable) -> Self {
60 Self::Binding { field, variable }
61 }
62
63 pub fn variable(field: String, operator: String, variable: Variable) -> Self {
65 Self::Variable {
66 field,
67 operator,
68 variable,
69 }
70 }
71
72 pub fn multifield(
74 field: String,
75 operator: MultifieldOp,
76 variable: Option<Variable>,
77 value: Option<FactValue>,
78 ) -> Self {
79 Self::MultiField {
80 field,
81 operator,
82 variable,
83 value,
84 }
85 }
86
87 pub fn evaluate(
89 &self,
90 facts: &TypedFacts,
91 bindings: &HashMap<Variable, FactValue>,
92 ) -> Option<HashMap<Variable, FactValue>> {
93 match self {
94 PatternConstraint::Simple {
95 field,
96 operator,
97 value,
98 } => {
99 if facts.evaluate_condition(field, operator, value) {
100 Some(HashMap::new())
101 } else {
102 None
103 }
104 }
105 PatternConstraint::Binding { field, variable } => {
106 if let Some(fact_value) = facts.get(field) {
107 let mut new_bindings = HashMap::new();
108 new_bindings.insert(variable.clone(), fact_value.clone());
109 Some(new_bindings)
110 } else {
111 None
112 }
113 }
114 PatternConstraint::Variable {
115 field,
116 operator,
117 variable,
118 } => {
119 if let Some(bound_value) = bindings.get(variable) {
120 if facts.evaluate_condition(field, operator, bound_value) {
121 Some(HashMap::new())
122 } else {
123 None
124 }
125 } else {
126 None }
128 }
129 PatternConstraint::MultiField {
130 field,
131 operator,
132 variable,
133 value,
134 } => {
135 super::multifield::evaluate_multifield_pattern(
137 facts,
138 field,
139 operator,
140 variable.as_deref(),
141 value.as_ref(),
142 bindings,
143 )
144 }
145 }
146 }
147}
148
149#[derive(Debug, Clone)]
151pub struct Pattern {
152 pub fact_type: String,
154 pub constraints: Vec<PatternConstraint>,
156 pub name: Option<String>,
158}
159
160impl Pattern {
161 pub fn new(fact_type: String) -> Self {
163 Self {
164 fact_type,
165 constraints: Vec::new(),
166 name: None,
167 }
168 }
169
170 pub fn with_constraint(mut self, constraint: PatternConstraint) -> Self {
172 self.constraints.push(constraint);
173 self
174 }
175
176 pub fn with_name(mut self, name: String) -> Self {
178 self.name = Some(name);
179 self
180 }
181
182 pub fn matches(
184 &self,
185 facts: &TypedFacts,
186 bindings: &HashMap<Variable, FactValue>,
187 ) -> Option<HashMap<Variable, FactValue>> {
188 let mut new_bindings = bindings.clone();
189
190 for constraint in &self.constraints {
191 match constraint.evaluate(facts, &new_bindings) {
192 Some(additional_bindings) => {
193 new_bindings.extend(additional_bindings);
194 }
195 None => return None,
196 }
197 }
198
199 Some(new_bindings)
200 }
201
202 pub fn match_in_working_memory(
204 &self,
205 wm: &WorkingMemory,
206 bindings: &HashMap<Variable, FactValue>,
207 ) -> Vec<(FactHandle, HashMap<Variable, FactValue>)> {
208 let mut results = Vec::new();
209
210 for fact in wm.get_by_type(&self.fact_type) {
211 if let Some(new_bindings) = self.matches(&fact.data, bindings) {
212 results.push((fact.handle, new_bindings));
213 }
214 }
215
216 results
217 }
218}
219
220#[derive(Debug, Clone)]
222pub struct MultiPattern {
223 pub patterns: Vec<Pattern>,
225 pub name: String,
227}
228
229impl MultiPattern {
230 pub fn new(name: String) -> Self {
232 Self {
233 patterns: Vec::new(),
234 name,
235 }
236 }
237
238 pub fn with_pattern(mut self, pattern: Pattern) -> Self {
240 self.patterns.push(pattern);
241 self
242 }
243
244 pub fn match_all(
246 &self,
247 wm: &WorkingMemory,
248 ) -> Vec<(Vec<FactHandle>, HashMap<Variable, FactValue>)> {
249 if self.patterns.is_empty() {
250 return Vec::new();
251 }
252
253 let mut results = Vec::new();
255 let first_pattern = &self.patterns[0];
256 let empty_bindings = HashMap::new();
257
258 for (handle, bindings) in first_pattern.match_in_working_memory(wm, &empty_bindings) {
259 results.push((vec![handle], bindings));
260 }
261
262 for pattern in &self.patterns[1..] {
264 let mut new_results = Vec::new();
265
266 for (handles, bindings) in results {
267 for (handle, new_bindings) in pattern.match_in_working_memory(wm, &bindings) {
268 let mut combined_handles = handles.clone();
269 combined_handles.push(handle);
270 new_results.push((combined_handles, new_bindings));
271 }
272 }
273
274 results = new_results;
275
276 if results.is_empty() {
277 break; }
279 }
280
281 results
282 }
283}
284
285pub struct PatternBuilder {
287 pattern: Pattern,
288}
289
290impl PatternBuilder {
291 pub fn for_type(fact_type: impl Into<String>) -> Self {
293 Self {
294 pattern: Pattern::new(fact_type.into()),
295 }
296 }
297
298 pub fn where_field(
300 mut self,
301 field: impl Into<String>,
302 operator: impl Into<String>,
303 value: FactValue,
304 ) -> Self {
305 self.pattern.constraints.push(PatternConstraint::Simple {
306 field: field.into(),
307 operator: operator.into(),
308 value,
309 });
310 self
311 }
312
313 pub fn bind(mut self, field: impl Into<String>, variable: impl Into<String>) -> Self {
315 self.pattern.constraints.push(PatternConstraint::Binding {
316 field: field.into(),
317 variable: variable.into(),
318 });
319 self
320 }
321
322 pub fn where_var(
324 mut self,
325 field: impl Into<String>,
326 operator: impl Into<String>,
327 variable: impl Into<String>,
328 ) -> Self {
329 self.pattern.constraints.push(PatternConstraint::Variable {
330 field: field.into(),
331 operator: operator.into(),
332 variable: variable.into(),
333 });
334 self
335 }
336
337 pub fn named(mut self, name: impl Into<String>) -> Self {
339 self.pattern.name = Some(name.into());
340 self
341 }
342
343 pub fn build(self) -> Pattern {
345 self.pattern
346 }
347}
348
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simple_pattern() {
        let pattern = PatternBuilder::for_type("Person")
            .where_field("age", ">", FactValue::Integer(18))
            .where_field("status", "==", FactValue::String("active".to_string()))
            .build();

        let mut facts = TypedFacts::new();
        facts.set("age", 25i64);
        facts.set("status", "active");

        let bindings = HashMap::new();
        let result = pattern.matches(&facts, &bindings);
        assert!(result.is_some());
    }

    #[test]
    fn test_simple_pattern_rejects_non_matching_facts() {
        let pattern = PatternBuilder::for_type("Person")
            .where_field("age", ">", FactValue::Integer(18))
            .build();

        let mut facts = TypedFacts::new();
        facts.set("age", 15i64);

        assert!(pattern.matches(&facts, &HashMap::new()).is_none());
    }

    #[test]
    fn test_variable_binding() {
        let pattern = PatternBuilder::for_type("Person")
            .bind("name", "$personName")
            .bind("age", "$personAge")
            .build();

        let mut facts = TypedFacts::new();
        facts.set("name", "John");
        facts.set("age", 25i64);

        let bindings = HashMap::new();
        let result = pattern.matches(&facts, &bindings).unwrap();

        assert_eq!(result.get("$personName").unwrap().as_string(), "John");
        assert_eq!(result.get("$personAge").unwrap().as_integer(), Some(25));
    }

    #[test]
    fn test_binding_missing_field_rejects_fact() {
        let pattern = PatternBuilder::for_type("Person")
            .bind("email", "$email")
            .build();

        let mut facts = TypedFacts::new();
        facts.set("name", "John");

        assert!(pattern.matches(&facts, &HashMap::new()).is_none());
    }

    #[test]
    fn test_matches_preserves_incoming_bindings() {
        let mut bindings = HashMap::new();
        bindings.insert("$existing".to_string(), FactValue::Integer(1));

        let pattern = PatternBuilder::for_type("Person")
            .bind("age", "$age")
            .build();

        let mut facts = TypedFacts::new();
        facts.set("age", 25i64);

        let result = pattern.matches(&facts, &bindings).unwrap();
        assert_eq!(result.get("$existing").unwrap().as_integer(), Some(1));
        assert_eq!(result.get("$age").unwrap().as_integer(), Some(25));
    }

    #[test]
    fn test_variable_constraint() {
        let mut bindings = HashMap::new();
        bindings.insert("$minAge".to_string(), FactValue::Integer(18));

        let pattern = PatternBuilder::for_type("Person")
            .where_var("age", ">=", "$minAge")
            .build();

        let mut facts = TypedFacts::new();
        facts.set("age", 25i64);

        let result = pattern.matches(&facts, &bindings);
        assert!(result.is_some());
    }

    #[test]
    fn test_variable_constraint_requires_bound_variable() {
        // `$minAge` is never bound, so the comparison cannot be made.
        let pattern = PatternBuilder::for_type("Person")
            .where_var("age", ">=", "$minAge")
            .build();

        let mut facts = TypedFacts::new();
        facts.set("age", 25i64);

        assert!(pattern.matches(&facts, &HashMap::new()).is_none());
    }

    #[test]
    fn test_multi_pattern_join() {
        let mut wm = WorkingMemory::new();

        let mut person = TypedFacts::new();
        person.set("name", "John");
        person.set("age", 25i64);
        wm.insert("Person".to_string(), person);

        let mut order = TypedFacts::new();
        order.set("customer", "John");
        order.set("amount", 1000.0);
        wm.insert("Order".to_string(), order);

        let person_pattern = PatternBuilder::for_type("Person")
            .bind("name", "$name")
            .build();

        let order_pattern = PatternBuilder::for_type("Order")
            .where_var("customer", "==", "$name")
            .build();

        let multi = MultiPattern::new("PersonWithOrder".to_string())
            .with_pattern(person_pattern)
            .with_pattern(order_pattern);

        let matches = multi.match_all(&wm);
        assert_eq!(matches.len(), 1);

        let (handles, bindings) = &matches[0];
        assert_eq!(handles.len(), 2);
        assert_eq!(bindings.get("$name").unwrap().as_string(), "John");
    }

    #[test]
    fn test_multi_pattern_join_without_partner() {
        let mut wm = WorkingMemory::new();

        let mut person = TypedFacts::new();
        person.set("name", "John");
        wm.insert("Person".to_string(), person);

        let mut order = TypedFacts::new();
        order.set("customer", "Jane");
        wm.insert("Order".to_string(), order);

        let multi = MultiPattern::new("PersonWithOrder".to_string())
            .with_pattern(PatternBuilder::for_type("Person").bind("name", "$name").build())
            .with_pattern(
                PatternBuilder::for_type("Order")
                    .where_var("customer", "==", "$name")
                    .build(),
            );

        assert!(multi.match_all(&wm).is_empty());
    }

    #[test]
    fn test_empty_multi_pattern_matches_nothing() {
        let mut wm = WorkingMemory::new();

        let mut person = TypedFacts::new();
        person.set("name", "John");
        wm.insert("Person".to_string(), person);

        let multi = MultiPattern::new("Empty".to_string());
        assert!(multi.match_all(&wm).is_empty());
    }
}