1use crate::rete::facts::{FactValue, TypedFacts};
7use crate::errors::{Result, RuleEngineError};
8use std::collections::HashMap;
9
/// Definition of a single field inside a [`Template`].
#[derive(Debug, Clone, PartialEq)]
pub struct FieldDef {
    /// Field name; this is the key looked up in the facts during validation
    /// and the key written when an instance is created.
    pub name: String,
    /// Type that a supplied value must satisfy (see [`FieldType::matches`]).
    pub field_type: FieldType,
    /// Explicit default used when creating an instance. When `None`, the
    /// built-in default of `field_type` is used instead.
    pub default_value: Option<FactValue>,
    /// When `true`, validation fails if the facts contain no value for `name`.
    pub required: bool,
}
18
/// The set of value types a template field can be declared with.
#[derive(Debug, Clone, PartialEq)]
pub enum FieldType {
    /// Accepts `FactValue::String`.
    String,
    /// Accepts `FactValue::Integer`.
    Integer,
    /// Accepts `FactValue::Float`.
    Float,
    /// Accepts `FactValue::Boolean`.
    Boolean,
    /// Accepts `FactValue::Array` whose elements all satisfy the boxed element type.
    Array(Box<FieldType>),
    /// Accepts any value.
    Any,
}
29
30impl FieldType {
31 pub fn matches(&self, value: &FactValue) -> bool {
33 match (self, value) {
34 (FieldType::String, FactValue::String(_)) => true,
35 (FieldType::Integer, FactValue::Integer(_)) => true,
36 (FieldType::Float, FactValue::Float(_)) => true,
37 (FieldType::Boolean, FactValue::Boolean(_)) => true,
38 (FieldType::Array(inner), FactValue::Array(arr)) => {
39 arr.iter().all(|v| inner.matches(v))
41 }
42 (FieldType::Any, _) => true,
43 _ => false,
44 }
45 }
46
47 pub fn default_value(&self) -> FactValue {
49 match self {
50 FieldType::String => FactValue::String(String::new()),
51 FieldType::Integer => FactValue::Integer(0),
52 FieldType::Float => FactValue::Float(0.0),
53 FieldType::Boolean => FactValue::Boolean(false),
54 FieldType::Array(_) => FactValue::Array(Vec::new()),
55 FieldType::Any => FactValue::Null,
56 }
57 }
58}
59
/// A named schema describing the fields a set of facts is expected to have.
#[derive(Debug, Clone)]
pub struct Template {
    /// Template name; used in error messages and as the registry key.
    pub name: String,
    /// Field definitions, in the order they were added.
    pub fields: Vec<FieldDef>,
    /// Maps a field name to its index in `fields`; populated by `add_field`
    /// and consulted by `get_field`.
    field_map: HashMap<String, usize>,
}
67
68impl Template {
69 pub fn new(name: impl Into<String>) -> Self {
71 Self {
72 name: name.into(),
73 fields: Vec::new(),
74 field_map: HashMap::new(),
75 }
76 }
77
78 pub fn add_field(&mut self, field: FieldDef) -> &mut Self {
80 let idx = self.fields.len();
81 self.field_map.insert(field.name.clone(), idx);
82 self.fields.push(field);
83 self
84 }
85
86 pub fn validate(&self, facts: &TypedFacts) -> Result<()> {
88 for field in &self.fields {
90 let value = facts.get(&field.name);
91
92 if field.required && value.is_none() {
93 return Err(RuleEngineError::EvaluationError {
94 message: format!(
95 "Required field '{}' missing in template '{}'",
96 field.name, self.name
97 ),
98 });
99 }
100
101 if let Some(val) = value {
103 if !field.field_type.matches(val) {
104 return Err(RuleEngineError::EvaluationError {
105 message: format!(
106 "Field '{}' has wrong type. Expected {:?}, got {:?}",
107 field.name, field.field_type, val
108 ),
109 });
110 }
111 }
112 }
113
114 Ok(())
115 }
116
117 pub fn create_instance(&self) -> TypedFacts {
119 let mut facts = TypedFacts::new();
120
121 for field in &self.fields {
122 let value = field.default_value.clone()
123 .unwrap_or_else(|| field.field_type.default_value());
124 facts.set(&field.name, value);
125 }
126
127 facts
128 }
129
130 pub fn get_field(&self, name: &str) -> Option<&FieldDef> {
132 self.field_map.get(name).and_then(|idx| self.fields.get(*idx))
133 }
134}
135
136pub struct TemplateBuilder {
138 template: Template,
139}
140
141impl TemplateBuilder {
142 pub fn new(name: impl Into<String>) -> Self {
144 Self {
145 template: Template::new(name),
146 }
147 }
148
149 pub fn string_field(mut self, name: impl Into<String>) -> Self {
151 self.template.add_field(FieldDef {
152 name: name.into(),
153 field_type: FieldType::String,
154 default_value: None,
155 required: false,
156 });
157 self
158 }
159
160 pub fn required_string(mut self, name: impl Into<String>) -> Self {
162 self.template.add_field(FieldDef {
163 name: name.into(),
164 field_type: FieldType::String,
165 default_value: None,
166 required: true,
167 });
168 self
169 }
170
171 pub fn integer_field(mut self, name: impl Into<String>) -> Self {
173 self.template.add_field(FieldDef {
174 name: name.into(),
175 field_type: FieldType::Integer,
176 default_value: None,
177 required: false,
178 });
179 self
180 }
181
182 pub fn float_field(mut self, name: impl Into<String>) -> Self {
184 self.template.add_field(FieldDef {
185 name: name.into(),
186 field_type: FieldType::Float,
187 default_value: None,
188 required: false,
189 });
190 self
191 }
192
193 pub fn boolean_field(mut self, name: impl Into<String>) -> Self {
195 self.template.add_field(FieldDef {
196 name: name.into(),
197 field_type: FieldType::Boolean,
198 default_value: None,
199 required: false,
200 });
201 self
202 }
203
204 pub fn field_with_default(
206 mut self,
207 name: impl Into<String>,
208 field_type: FieldType,
209 default: FactValue,
210 ) -> Self {
211 self.template.add_field(FieldDef {
212 name: name.into(),
213 field_type,
214 default_value: Some(default),
215 required: false,
216 });
217 self
218 }
219
220 pub fn array_field(mut self, name: impl Into<String>, element_type: FieldType) -> Self {
222 self.template.add_field(FieldDef {
223 name: name.into(),
224 field_type: FieldType::Array(Box::new(element_type)),
225 default_value: None,
226 required: false,
227 });
228 self
229 }
230
231 pub fn build(self) -> Template {
233 self.template
234 }
235}
236
237pub struct TemplateRegistry {
239 templates: HashMap<String, Template>,
240}
241
242impl TemplateRegistry {
243 pub fn new() -> Self {
245 Self {
246 templates: HashMap::new(),
247 }
248 }
249
250 pub fn register(&mut self, template: Template) {
252 self.templates.insert(template.name.clone(), template);
253 }
254
255 pub fn get(&self, name: &str) -> Option<&Template> {
257 self.templates.get(name)
258 }
259
260 pub fn create_instance(&self, template_name: &str) -> Result<TypedFacts> {
262 let template = self.get(template_name).ok_or_else(|| {
263 RuleEngineError::EvaluationError {
264 message: format!("Template '{}' not found", template_name),
265 }
266 })?;
267
268 Ok(template.create_instance())
269 }
270
271 pub fn validate(&self, template_name: &str, facts: &TypedFacts) -> Result<()> {
273 let template = self.get(template_name).ok_or_else(|| {
274 RuleEngineError::EvaluationError {
275 message: format!("Template '{}' not found", template_name),
276 }
277 })?;
278
279 template.validate(facts)
280 }
281
282 pub fn list_templates(&self) -> Vec<&str> {
284 self.templates.keys().map(|s| s.as_str()).collect()
285 }
286}
287
288impl Default for TemplateRegistry {
289 fn default() -> Self {
290 Self::new()
291 }
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for an owned string fact value.
    fn text(s: &str) -> FactValue {
        FactValue::String(s.to_string())
    }

    /// Builds a fact set from `(name, value)` pairs.
    fn facts_from(entries: Vec<(&str, FactValue)>) -> TypedFacts {
        let mut facts = TypedFacts::new();
        for (key, value) in entries {
            facts.set(key, value);
        }
        facts
    }

    // The builder records every field and preserves the `required` flag.
    #[test]
    fn test_template_builder() {
        let person = TemplateBuilder::new("Person")
            .required_string("name")
            .integer_field("age")
            .boolean_field("is_adult")
            .build();

        assert_eq!(person.name, "Person");
        assert_eq!(person.fields.len(), 3);
        assert!(person.get_field("name").unwrap().required);
    }

    // Fields without an explicit default get the built-in default of their type.
    #[test]
    fn test_create_instance() {
        let person = TemplateBuilder::new("Person")
            .string_field("name")
            .integer_field("age")
            .build();

        let created = person.create_instance();

        assert_eq!(created.get("name"), Some(&FactValue::String(String::new())));
        assert_eq!(created.get("age"), Some(&FactValue::Integer(0)));
    }

    // Facts that supply every field with the right type validate cleanly.
    #[test]
    fn test_validation_success() {
        let person = TemplateBuilder::new("Person")
            .required_string("name")
            .integer_field("age")
            .build();

        let facts = facts_from(vec![
            ("name", text("Alice")),
            ("age", FactValue::Integer(30)),
        ]);

        assert!(person.validate(&facts).is_ok());
    }

    // Leaving out a required field is rejected.
    #[test]
    fn test_validation_missing_required() {
        let person = TemplateBuilder::new("Person")
            .required_string("name")
            .integer_field("age")
            .build();

        let facts = facts_from(vec![("age", FactValue::Integer(30))]);

        assert!(person.validate(&facts).is_err());
    }

    // A string supplied for an integer field is rejected.
    #[test]
    fn test_validation_wrong_type() {
        let person = TemplateBuilder::new("Person")
            .string_field("name")
            .integer_field("age")
            .build();

        let facts = facts_from(vec![("name", text("Alice")), ("age", text("thirty"))]);

        assert!(person.validate(&facts).is_err());
    }

    // A registered template can be fetched, instantiated and listed by name.
    #[test]
    fn test_template_registry() {
        let order = TemplateBuilder::new("Order")
            .required_string("order_id")
            .float_field("amount")
            .build();

        let mut registry = TemplateRegistry::new();
        registry.register(order);

        assert!(registry.get("Order").is_some());
        assert!(registry.create_instance("Order").is_ok());
        assert_eq!(registry.list_templates(), vec!["Order"]);
    }

    // An array whose elements all match the element type validates.
    #[test]
    fn test_array_field() {
        let cart = TemplateBuilder::new("ShoppingCart")
            .array_field("items", FieldType::String)
            .build();

        let items = FactValue::Array(vec![text("item1"), text("item2")]);
        let facts = facts_from(vec![("items", items)]);

        assert!(cart.validate(&facts).is_ok());
    }

    // An explicit default takes precedence over the type's built-in default.
    #[test]
    fn test_field_with_default() {
        let config = TemplateBuilder::new("Config")
            .field_with_default("timeout", FieldType::Integer, FactValue::Integer(30))
            .build();

        let created = config.create_instance();

        assert_eq!(created.get("timeout"), Some(&FactValue::Integer(30)));
    }
}