1use std::collections::HashMap;
10use super::facts::{FactValue, TypedFacts};
11use super::working_memory::{WorkingMemory, FactHandle};
12use super::multifield::MultifieldOp;
13
/// Name of a pattern variable (for example `"$name"`) that constraints bind
/// and later reference while matching.
pub type Variable = String;
16
/// A single constraint inside a [`Pattern`], evaluated against one fact's fields.
#[derive(Debug, Clone)]
pub enum PatternConstraint {
    /// Compares `field` against the literal `value` using `operator`
    /// (e.g. `">"`, `"=="`); introduces no bindings.
    Simple {
        /// Name of the fact field being tested.
        field: String,
        /// Comparison operator handed to `TypedFacts::evaluate_condition`.
        operator: String,
        /// Literal right-hand side of the comparison.
        value: FactValue,
    },
    /// Binds the value of `field` to `variable`; fails when the fact has no such field.
    Binding {
        /// Name of the fact field whose value is captured.
        field: String,
        /// Variable that receives the field's value.
        variable: Variable,
    },
    /// Compares `field` against the value already bound to `variable` using
    /// `operator`; fails when `variable` has not been bound.
    Variable {
        /// Name of the fact field being tested.
        field: String,
        /// Comparison operator handed to `TypedFacts::evaluate_condition`.
        operator: String,
        /// Previously bound variable supplying the right-hand side.
        variable: Variable,
    },
    /// Multifield constraint; evaluation is delegated to
    /// `multifield::evaluate_multifield_pattern`.
    MultiField {
        /// Name of the (presumably multi-valued) fact field — confirm in the multifield module.
        field: String,
        /// Optional variable; its role depends on `operator` (see the multifield module).
        variable: Option<Variable>,
        /// Which multifield operation to apply.
        operator: MultifieldOp,
        /// Optional operand; its role depends on `operator` (see the multifield module).
        value: Option<FactValue>,
    },
}
50
51impl PatternConstraint {
52 pub fn simple(field: String, operator: String, value: FactValue) -> Self {
54 Self::Simple { field, operator, value }
55 }
56
57 pub fn binding(field: String, variable: Variable) -> Self {
59 Self::Binding { field, variable }
60 }
61
62 pub fn variable(field: String, operator: String, variable: Variable) -> Self {
64 Self::Variable { field, operator, variable }
65 }
66
67 pub fn multifield(
69 field: String,
70 operator: MultifieldOp,
71 variable: Option<Variable>,
72 value: Option<FactValue>,
73 ) -> Self {
74 Self::MultiField { field, operator, variable, value }
75 }
76
77 pub fn evaluate(
79 &self,
80 facts: &TypedFacts,
81 bindings: &HashMap<Variable, FactValue>,
82 ) -> Option<HashMap<Variable, FactValue>> {
83 match self {
84 PatternConstraint::Simple { field, operator, value } => {
85 if facts.evaluate_condition(field, operator, value) {
86 Some(HashMap::new())
87 } else {
88 None
89 }
90 }
91 PatternConstraint::Binding { field, variable } => {
92 if let Some(fact_value) = facts.get(field) {
93 let mut new_bindings = HashMap::new();
94 new_bindings.insert(variable.clone(), fact_value.clone());
95 Some(new_bindings)
96 } else {
97 None
98 }
99 }
100 PatternConstraint::Variable { field, operator, variable } => {
101 if let Some(bound_value) = bindings.get(variable) {
102 if facts.evaluate_condition(field, operator, bound_value) {
103 Some(HashMap::new())
104 } else {
105 None
106 }
107 } else {
108 None }
110 }
111 PatternConstraint::MultiField { field, operator, variable, value } => {
112 super::multifield::evaluate_multifield_pattern(
114 facts,
115 field,
116 operator,
117 variable.as_deref(),
118 value.as_ref(),
119 bindings,
120 )
121 }
122 }
123 }
124}
125
/// A single-fact pattern: a fact type plus an ordered list of constraints that must
/// all succeed for a fact of that type to match.
#[derive(Debug, Clone)]
pub struct Pattern {
    /// Type name used to select candidate facts from working memory (`get_by_type`).
    pub fact_type: String,
    /// Constraints evaluated in order; the first failing one rejects the fact.
    pub constraints: Vec<PatternConstraint>,
    /// Optional label; not read by the matching code in this file.
    pub name: Option<String>,
}
136
137impl Pattern {
138 pub fn new(fact_type: String) -> Self {
140 Self {
141 fact_type,
142 constraints: Vec::new(),
143 name: None,
144 }
145 }
146
147 pub fn with_constraint(mut self, constraint: PatternConstraint) -> Self {
149 self.constraints.push(constraint);
150 self
151 }
152
153 pub fn with_name(mut self, name: String) -> Self {
155 self.name = Some(name);
156 self
157 }
158
159 pub fn matches(
161 &self,
162 facts: &TypedFacts,
163 bindings: &HashMap<Variable, FactValue>,
164 ) -> Option<HashMap<Variable, FactValue>> {
165 let mut new_bindings = bindings.clone();
166
167 for constraint in &self.constraints {
168 match constraint.evaluate(facts, &new_bindings) {
169 Some(additional_bindings) => {
170 new_bindings.extend(additional_bindings);
171 }
172 None => return None,
173 }
174 }
175
176 Some(new_bindings)
177 }
178
179 pub fn match_in_working_memory(
181 &self,
182 wm: &WorkingMemory,
183 bindings: &HashMap<Variable, FactValue>,
184 ) -> Vec<(FactHandle, HashMap<Variable, FactValue>)> {
185 let mut results = Vec::new();
186
187 for fact in wm.get_by_type(&self.fact_type) {
188 if let Some(new_bindings) = self.matches(&fact.data, bindings) {
189 results.push((fact.handle, new_bindings));
190 }
191 }
192
193 results
194 }
195}
196
/// An ordered conjunction of [`Pattern`]s joined through shared variable bindings.
#[derive(Debug, Clone)]
pub struct MultiPattern {
    /// Patterns matched left to right; each sees the bindings produced by its predecessors.
    pub patterns: Vec<Pattern>,
    /// Label for this multi-pattern; not read by `match_all`.
    pub name: String,
}
205
206impl MultiPattern {
207 pub fn new(name: String) -> Self {
209 Self {
210 patterns: Vec::new(),
211 name,
212 }
213 }
214
215 pub fn with_pattern(mut self, pattern: Pattern) -> Self {
217 self.patterns.push(pattern);
218 self
219 }
220
221 pub fn match_all(
223 &self,
224 wm: &WorkingMemory,
225 ) -> Vec<(Vec<FactHandle>, HashMap<Variable, FactValue>)> {
226 if self.patterns.is_empty() {
227 return Vec::new();
228 }
229
230 let mut results = Vec::new();
232 let first_pattern = &self.patterns[0];
233 let empty_bindings = HashMap::new();
234
235 for (handle, bindings) in first_pattern.match_in_working_memory(wm, &empty_bindings) {
236 results.push((vec![handle], bindings));
237 }
238
239 for pattern in &self.patterns[1..] {
241 let mut new_results = Vec::new();
242
243 for (handles, bindings) in results {
244 for (handle, new_bindings) in pattern.match_in_working_memory(wm, &bindings) {
245 let mut combined_handles = handles.clone();
246 combined_handles.push(handle);
247 new_results.push((combined_handles, new_bindings));
248 }
249 }
250
251 results = new_results;
252
253 if results.is_empty() {
254 break; }
256 }
257
258 results
259 }
260}
261
262pub struct PatternBuilder {
264 pattern: Pattern,
265}
266
267impl PatternBuilder {
268 pub fn for_type(fact_type: impl Into<String>) -> Self {
270 Self {
271 pattern: Pattern::new(fact_type.into()),
272 }
273 }
274
275 pub fn where_field(mut self, field: impl Into<String>, operator: impl Into<String>, value: FactValue) -> Self {
277 self.pattern.constraints.push(PatternConstraint::Simple {
278 field: field.into(),
279 operator: operator.into(),
280 value,
281 });
282 self
283 }
284
285 pub fn bind(mut self, field: impl Into<String>, variable: impl Into<String>) -> Self {
287 self.pattern.constraints.push(PatternConstraint::Binding {
288 field: field.into(),
289 variable: variable.into(),
290 });
291 self
292 }
293
294 pub fn where_var(mut self, field: impl Into<String>, operator: impl Into<String>, variable: impl Into<String>) -> Self {
296 self.pattern.constraints.push(PatternConstraint::Variable {
297 field: field.into(),
298 operator: operator.into(),
299 variable: variable.into(),
300 });
301 self
302 }
303
304 pub fn named(mut self, name: impl Into<String>) -> Self {
306 self.pattern.name = Some(name.into());
307 self
308 }
309
310 pub fn build(self) -> Pattern {
312 self.pattern
313 }
314}
315
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simple_pattern() {
        let pattern = PatternBuilder::for_type("Person")
            .where_field("age", ">", FactValue::Integer(18))
            .where_field("status", "==", FactValue::String("active".to_string()))
            .build();

        let mut facts = TypedFacts::new();
        facts.set("age", 25i64);
        facts.set("status", "active");

        let bindings = HashMap::new();
        let result = pattern.matches(&facts, &bindings);
        assert!(result.is_some());
    }

    /// A fact that violates a literal constraint must be rejected.
    #[test]
    fn test_simple_pattern_rejects_failing_constraint() {
        let pattern = PatternBuilder::for_type("Person")
            .where_field("age", ">", FactValue::Integer(18))
            .build();

        let mut facts = TypedFacts::new();
        facts.set("age", 15i64);

        assert!(pattern.matches(&facts, &HashMap::new()).is_none());
    }

    #[test]
    fn test_variable_binding() {
        let pattern = PatternBuilder::for_type("Person")
            .bind("name", "$personName")
            .bind("age", "$personAge")
            .build();

        let mut facts = TypedFacts::new();
        facts.set("name", "John");
        facts.set("age", 25i64);

        let bindings = HashMap::new();
        let result = pattern.matches(&facts, &bindings).unwrap();

        assert_eq!(result.get("$personName").unwrap().as_string(), "John");
        assert_eq!(result.get("$personAge").unwrap().as_integer(), Some(25));
    }

    /// Binding a field the fact does not have must fail the pattern.
    #[test]
    fn test_binding_missing_field_fails() {
        let pattern = PatternBuilder::for_type("Person")
            .bind("name", "$personName")
            .build();

        let facts = TypedFacts::new();

        assert!(pattern.matches(&facts, &HashMap::new()).is_none());
    }

    #[test]
    fn test_variable_constraint() {
        let mut bindings = HashMap::new();
        bindings.insert("$minAge".to_string(), FactValue::Integer(18));

        let pattern = PatternBuilder::for_type("Person")
            .where_var("age", ">=", "$minAge")
            .build();

        let mut facts = TypedFacts::new();
        facts.set("age", 25i64);

        let result = pattern.matches(&facts, &bindings);
        assert!(result.is_some());
    }

    /// Referencing a variable that was never bound must fail the pattern.
    #[test]
    fn test_variable_constraint_unbound_variable_fails() {
        let pattern = PatternBuilder::for_type("Person")
            .where_var("age", ">=", "$minAge")
            .build();

        let mut facts = TypedFacts::new();
        facts.set("age", 25i64);

        assert!(pattern.matches(&facts, &HashMap::new()).is_none());
    }

    #[test]
    fn test_multi_pattern_join() {
        let mut wm = WorkingMemory::new();

        let mut person = TypedFacts::new();
        person.set("name", "John");
        person.set("age", 25i64);
        wm.insert("Person".to_string(), person);

        let mut order = TypedFacts::new();
        order.set("customer", "John");
        order.set("amount", 1000.0);
        wm.insert("Order".to_string(), order);

        let person_pattern = PatternBuilder::for_type("Person")
            .bind("name", "$name")
            .build();

        let order_pattern = PatternBuilder::for_type("Order")
            .where_var("customer", "==", "$name")
            .build();

        let multi = MultiPattern::new("PersonWithOrder".to_string())
            .with_pattern(person_pattern)
            .with_pattern(order_pattern);

        let matches = multi.match_all(&wm);
        assert_eq!(matches.len(), 1);

        let (handles, bindings) = &matches[0];
        assert_eq!(handles.len(), 2);
        assert_eq!(bindings.get("$name").unwrap().as_string(), "John");
    }

    /// A join whose variable values disagree must produce no matches.
    #[test]
    fn test_multi_pattern_join_no_match() {
        let mut wm = WorkingMemory::new();

        let mut person = TypedFacts::new();
        person.set("name", "John");
        wm.insert("Person".to_string(), person);

        let mut order = TypedFacts::new();
        order.set("customer", "Jane");
        order.set("amount", 1000.0);
        wm.insert("Order".to_string(), order);

        let multi = MultiPattern::new("PersonWithOrder".to_string())
            .with_pattern(PatternBuilder::for_type("Person").bind("name", "$name").build())
            .with_pattern(
                PatternBuilder::for_type("Order")
                    .where_var("customer", "==", "$name")
                    .build(),
            );

        assert!(multi.match_all(&wm).is_empty());
    }

    /// A multi-pattern with no patterns matches nothing.
    #[test]
    fn test_empty_multi_pattern_matches_nothing() {
        let wm = WorkingMemory::new();
        let multi = MultiPattern::new("Empty".to_string());
        assert!(multi.match_all(&wm).is_empty());
    }
}