1use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use tensorlogic_ir::TLExpr;
13
14use crate::error::AdapterError;
15
/// A named, reusable predicate defined in terms of other predicates.
///
/// A composite predicate declares formal `parameters` and a `body` that refers
/// to them. Expanding it with concrete arguments substitutes those arguments
/// into the body's `Reference` argument lists.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CompositePredicate {
    /// Predicate name; `CompositeRegistry::register` stores the predicate under this key.
    pub name: String,
    /// Formal parameter names in positional order. `validate` rejects duplicates.
    pub parameters: Vec<String>,
    /// Definition of the predicate, written in terms of `parameters`.
    pub body: PredicateBody,
    /// Optional human-readable description; `None` unless set via `with_description`.
    pub description: Option<String>,
}
31
/// The definition of a composite predicate, as a tree of sub-bodies.
///
/// Leaves are either an embedded [`TLExpr`] or a reference to another predicate
/// applied to argument names. Inner nodes combine sub-bodies with `And`, `Or`,
/// or `Not`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum PredicateBody {
    /// An embedded TensorLogic expression.
    /// NOTE(review): parameter validation and substitution currently leave this
    /// variant untouched, so variables inside the expression are not renamed on expansion.
    Expression(Box<TLExpr>),
    /// Application of the predicate `name` to the variables in `args`.
    /// Argument names starting with `_` are treated as free and need not be parameters.
    Reference { name: String, args: Vec<String> },
    /// Conjunction of the sub-bodies.
    And(Vec<PredicateBody>),
    /// Disjunction of the sub-bodies.
    Or(Vec<PredicateBody>),
    /// Negation of a single sub-body.
    Not(Box<PredicateBody>),
}
46
/// Collection of composite predicates, keyed by predicate name.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct CompositeRegistry {
    /// Registered predicates; each key equals the stored predicate's `name`.
    predicates: HashMap<String, CompositePredicate>,
}
52
53impl CompositePredicate {
54 pub fn new(name: impl Into<String>, parameters: Vec<String>, body: PredicateBody) -> Self {
56 CompositePredicate {
57 name: name.into(),
58 parameters,
59 body,
60 description: None,
61 }
62 }
63
64 pub fn with_description(mut self, desc: impl Into<String>) -> Self {
66 self.description = Some(desc.into());
67 self
68 }
69
70 pub fn arity(&self) -> usize {
72 self.parameters.len()
73 }
74
75 pub fn validate(&self) -> Result<(), AdapterError> {
77 let mut seen = std::collections::HashSet::new();
79 for param in &self.parameters {
80 if !seen.insert(param) {
81 return Err(AdapterError::InvalidParametricType(format!(
82 "Duplicate parameter '{}' in predicate '{}'",
83 param, self.name
84 )));
85 }
86 }
87
88 self.body.validate(&self.parameters)?;
90
91 Ok(())
92 }
93
94 pub fn expand(&self, args: &[String]) -> Result<PredicateBody, AdapterError> {
98 if args.len() != self.parameters.len() {
99 return Err(AdapterError::ArityMismatch {
100 name: self.name.clone(),
101 expected: self.parameters.len(),
102 found: args.len(),
103 });
104 }
105
106 let mut substitutions = HashMap::new();
108 for (param, arg) in self.parameters.iter().zip(args.iter()) {
109 substitutions.insert(param.clone(), arg.clone());
110 }
111
112 self.body.substitute(&substitutions)
113 }
114}
115
116impl PredicateBody {
117 fn validate(&self, parameters: &[String]) -> Result<(), AdapterError> {
119 match self {
120 PredicateBody::Expression(_) => Ok(()), PredicateBody::Reference { args, .. } => {
122 for arg in args {
124 if !parameters.contains(arg) && !arg.starts_with('_') {
125 return Err(AdapterError::UnboundVariable(arg.clone()));
126 }
127 }
128 Ok(())
129 }
130 PredicateBody::And(bodies) | PredicateBody::Or(bodies) => {
131 for body in bodies {
132 body.validate(parameters)?;
133 }
134 Ok(())
135 }
136 PredicateBody::Not(body) => body.validate(parameters),
137 }
138 }
139
140 fn substitute(
142 &self,
143 substitutions: &HashMap<String, String>,
144 ) -> Result<PredicateBody, AdapterError> {
145 match self {
146 PredicateBody::Expression(expr) => {
147 Ok(PredicateBody::Expression(expr.clone()))
150 }
151 PredicateBody::Reference { name, args } => {
152 let new_args = args
153 .iter()
154 .map(|arg| {
155 substitutions
156 .get(arg)
157 .cloned()
158 .unwrap_or_else(|| arg.clone())
159 })
160 .collect();
161 Ok(PredicateBody::Reference {
162 name: name.clone(),
163 args: new_args,
164 })
165 }
166 PredicateBody::And(bodies) => {
167 let new_bodies: Result<Vec<_>, _> =
168 bodies.iter().map(|b| b.substitute(substitutions)).collect();
169 Ok(PredicateBody::And(new_bodies?))
170 }
171 PredicateBody::Or(bodies) => {
172 let new_bodies: Result<Vec<_>, _> =
173 bodies.iter().map(|b| b.substitute(substitutions)).collect();
174 Ok(PredicateBody::Or(new_bodies?))
175 }
176 PredicateBody::Not(body) => Ok(PredicateBody::Not(Box::new(
177 body.substitute(substitutions)?,
178 ))),
179 }
180 }
181}
182
183impl CompositeRegistry {
184 pub fn new() -> Self {
186 CompositeRegistry::default()
187 }
188
189 pub fn register(&mut self, predicate: CompositePredicate) -> Result<(), AdapterError> {
191 predicate.validate()?;
192 self.predicates.insert(predicate.name.clone(), predicate);
193 Ok(())
194 }
195
196 pub fn get(&self, name: &str) -> Option<&CompositePredicate> {
198 self.predicates.get(name)
199 }
200
201 pub fn contains(&self, name: &str) -> bool {
203 self.predicates.contains_key(name)
204 }
205
206 pub fn expand(&self, name: &str, args: &[String]) -> Result<PredicateBody, AdapterError> {
208 let predicate = self
209 .get(name)
210 .ok_or_else(|| AdapterError::PredicateNotFound(name.to_string()))?;
211
212 predicate.expand(args)
213 }
214
215 pub fn len(&self) -> usize {
217 self.predicates.len()
218 }
219
220 pub fn is_empty(&self) -> bool {
222 self.predicates.is_empty()
223 }
224
225 pub fn list_predicates(&self) -> Vec<String> {
227 self.predicates.keys().cloned().collect()
228 }
229}
230
/// A parameterized predicate definition from which `CompositePredicate`s are instantiated.
///
/// On instantiation, both type and value parameters are replaced by their
/// positional arguments wherever they appear as `Reference` arguments in the body.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PredicateTemplate {
    /// Base name; instances are named `name<TypeArg1, TypeArg2, ...>`.
    pub name: String,
    /// Type parameter names, bound positionally to the type arguments.
    pub type_params: Vec<String>,
    /// Value parameter names, bound positionally to the value arguments.
    pub value_params: Vec<String>,
    /// Template body written in terms of the parameters.
    pub body: PredicateBody,
}
246
247impl PredicateTemplate {
248 pub fn new(
250 name: impl Into<String>,
251 type_params: Vec<String>,
252 value_params: Vec<String>,
253 body: PredicateBody,
254 ) -> Self {
255 PredicateTemplate {
256 name: name.into(),
257 type_params,
258 value_params,
259 body,
260 }
261 }
262
263 pub fn instantiate(
265 &self,
266 type_args: &[String],
267 value_args: &[String],
268 ) -> Result<CompositePredicate, AdapterError> {
269 if type_args.len() != self.type_params.len() {
270 return Err(AdapterError::ArityMismatch {
271 name: format!("{}[type params]", self.name),
272 expected: self.type_params.len(),
273 found: type_args.len(),
274 });
275 }
276
277 if value_args.len() != self.value_params.len() {
278 return Err(AdapterError::ArityMismatch {
279 name: format!("{}[value params]", self.name),
280 expected: self.value_params.len(),
281 found: value_args.len(),
282 });
283 }
284
285 let mut substitutions = HashMap::new();
287 for (param, arg) in self.type_params.iter().zip(type_args.iter()) {
288 substitutions.insert(param.clone(), arg.clone());
289 }
290 for (param, arg) in self.value_params.iter().zip(value_args.iter()) {
291 substitutions.insert(param.clone(), arg.clone());
292 }
293
294 let instance_name = format!("{}<{}>", self.name, type_args.join(", "));
296
297 let instance_body = self.body.substitute(&substitutions)?;
299
300 Ok(CompositePredicate {
301 name: instance_name,
302 parameters: value_args.to_vec(),
303 body: instance_body,
304 description: None,
305 })
306 }
307}
308
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned list of names from string literals.
    fn names(items: &[&str]) -> Vec<String> {
        items.iter().map(|s| s.to_string()).collect()
    }

    #[test]
    fn test_composite_predicate_creation() {
        let pred = CompositePredicate::new(
            "friend",
            vec!["x".to_string(), "y".to_string()],
            PredicateBody::Reference {
                name: "knows".to_string(),
                args: vec!["x".to_string(), "y".to_string()],
            },
        );

        assert_eq!(pred.name, "friend");
        assert_eq!(pred.arity(), 2);
    }

    #[test]
    fn test_with_description() {
        let pred = CompositePredicate::new(
            "friend",
            names(&["x"]),
            PredicateBody::Reference {
                name: "knows".to_string(),
                args: names(&["x"]),
            },
        )
        .with_description("mutual acquaintance");

        assert_eq!(pred.description, Some("mutual acquaintance".to_string()));
    }

    #[test]
    fn test_composite_predicate_validation() {
        let valid = CompositePredicate::new(
            "test",
            vec!["x".to_string(), "y".to_string()],
            PredicateBody::Reference {
                name: "knows".to_string(),
                args: vec!["x".to_string(), "y".to_string()],
            },
        );
        assert!(valid.validate().is_ok());

        let invalid = CompositePredicate::new(
            "test",
            vec!["x".to_string(), "x".to_string()],
            PredicateBody::Reference {
                name: "knows".to_string(),
                args: vec!["x".to_string()],
            },
        );
        assert!(invalid.validate().is_err());
    }

    #[test]
    fn test_validate_rejects_unbound_variable() {
        let pred = CompositePredicate::new(
            "p",
            names(&["x"]),
            PredicateBody::Reference {
                name: "knows".to_string(),
                args: names(&["x", "z"]),
            },
        );

        let err = pred.validate().unwrap_err();
        assert!(matches!(err, AdapterError::UnboundVariable(ref v) if v == "z"));
    }

    #[test]
    fn test_validate_allows_underscore_wildcards() {
        let pred = CompositePredicate::new(
            "p",
            names(&["x"]),
            PredicateBody::Reference {
                name: "knows".to_string(),
                args: names(&["x", "_any"]),
            },
        );
        assert!(pred.validate().is_ok());

        // Wildcards are not parameters, so expansion leaves them untouched.
        match pred.expand(&names(&["alice"])).unwrap() {
            PredicateBody::Reference { args, .. } => {
                assert_eq!(args, names(&["alice", "_any"]));
            }
            other => panic!("Expected Reference, got {:?}", other),
        }
    }

    #[test]
    fn test_composite_registry() {
        let mut registry = CompositeRegistry::new();

        let pred = CompositePredicate::new(
            "friend",
            vec!["x".to_string(), "y".to_string()],
            PredicateBody::Reference {
                name: "knows".to_string(),
                args: vec!["x".to_string(), "y".to_string()],
            },
        );

        registry.register(pred).unwrap();
        assert!(registry.contains("friend"));
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_registry_rejects_invalid_predicate() {
        let mut registry = CompositeRegistry::new();
        let invalid = CompositePredicate::new(
            "dup",
            names(&["x", "x"]),
            PredicateBody::Reference {
                name: "knows".to_string(),
                args: names(&["x"]),
            },
        );

        assert!(registry.register(invalid).is_err());
        assert!(registry.is_empty());
        assert!(!registry.contains("dup"));
    }

    #[test]
    fn test_registry_expand() {
        let mut registry = CompositeRegistry::new();
        registry
            .register(CompositePredicate::new(
                "friend",
                names(&["x", "y"]),
                PredicateBody::Reference {
                    name: "knows".to_string(),
                    args: names(&["x", "y"]),
                },
            ))
            .unwrap();

        match registry.expand("friend", &names(&["alice", "bob"])).unwrap() {
            PredicateBody::Reference { name, args } => {
                assert_eq!(name, "knows");
                assert_eq!(args, names(&["alice", "bob"]));
            }
            other => panic!("Expected Reference, got {:?}", other),
        }

        let err = registry.expand("missing", &[]).unwrap_err();
        assert!(matches!(err, AdapterError::PredicateNotFound(ref n) if n == "missing"));
    }

    #[test]
    fn test_list_predicates() {
        let mut registry = CompositeRegistry::new();
        for pred_name in ["b", "a"].iter() {
            registry
                .register(CompositePredicate::new(
                    *pred_name,
                    names(&["x"]),
                    PredicateBody::Reference {
                        name: "p".to_string(),
                        args: names(&["x"]),
                    },
                ))
                .unwrap();
        }

        // Sort locally so the assertion does not depend on the listing order.
        let mut listed = registry.list_predicates();
        listed.sort();
        assert_eq!(listed, names(&["a", "b"]));
        assert_eq!(registry.len(), 2);
    }

    #[test]
    fn test_predicate_expansion() {
        let pred = CompositePredicate::new(
            "friend",
            vec!["x".to_string(), "y".to_string()],
            PredicateBody::Reference {
                name: "knows".to_string(),
                args: vec!["x".to_string(), "y".to_string()],
            },
        );

        let expanded = pred
            .expand(&["alice".to_string(), "bob".to_string()])
            .unwrap();

        match expanded {
            PredicateBody::Reference { name, args } => {
                assert_eq!(name, "knows");
                assert_eq!(args, vec!["alice".to_string(), "bob".to_string()]);
            }
            _ => panic!("Expected Reference"),
        }
    }

    #[test]
    fn test_expand_arity_mismatch() {
        let pred = CompositePredicate::new(
            "friend",
            names(&["x", "y"]),
            PredicateBody::Reference {
                name: "knows".to_string(),
                args: names(&["x", "y"]),
            },
        );

        let err = pred.expand(&names(&["alice"])).unwrap_err();
        assert!(matches!(
            err,
            AdapterError::ArityMismatch { ref name, expected: 2, found: 1 } if name == "friend"
        ));
    }

    #[test]
    fn test_nested_expansion() {
        let body = PredicateBody::And(vec![
            PredicateBody::Reference {
                name: "knows".to_string(),
                args: names(&["x", "y"]),
            },
            PredicateBody::Not(Box::new(PredicateBody::Reference {
                name: "enemy".to_string(),
                args: names(&["y", "x"]),
            })),
        ]);
        let pred = CompositePredicate::new("friend", names(&["x", "y"]), body);

        match pred.expand(&names(&["alice", "bob"])).unwrap() {
            PredicateBody::And(parts) => {
                assert_eq!(parts.len(), 2);
                match &parts[0] {
                    PredicateBody::Reference { name, args } => {
                        assert_eq!(name, "knows");
                        assert_eq!(args, &names(&["alice", "bob"]));
                    }
                    other => panic!("Expected Reference, got {:?}", other),
                }
                match &parts[1] {
                    PredicateBody::Not(inner) => match inner.as_ref() {
                        PredicateBody::Reference { name, args } => {
                            assert_eq!(name, "enemy");
                            assert_eq!(args, &names(&["bob", "alice"]));
                        }
                        other => panic!("Expected Reference, got {:?}", other),
                    },
                    other => panic!("Expected Not, got {:?}", other),
                }
            }
            other => panic!("Expected And, got {:?}", other),
        }
    }

    #[test]
    fn test_predicate_template() {
        let template = PredicateTemplate::new(
            "related",
            vec!["T".to_string()],
            vec!["x".to_string(), "y".to_string()],
            PredicateBody::Reference {
                name: "connected".to_string(),
                args: vec!["x".to_string(), "y".to_string()],
            },
        );

        let instance = template
            .instantiate(&["Person".to_string()], &["a".to_string(), "b".to_string()])
            .unwrap();

        assert_eq!(instance.name, "related<Person>");
        assert_eq!(instance.parameters, vec!["a".to_string(), "b".to_string()]);
    }

    #[test]
    fn test_template_substitutes_body() {
        let template = PredicateTemplate::new(
            "related",
            names(&["T"]),
            names(&["x", "y"]),
            PredicateBody::Reference {
                name: "connected".to_string(),
                args: names(&["x", "y"]),
            },
        );

        let instance = template
            .instantiate(&names(&["Person"]), &names(&["a", "b"]))
            .unwrap();

        match instance.body {
            PredicateBody::Reference { name, args } => {
                assert_eq!(name, "connected");
                assert_eq!(args, names(&["a", "b"]));
            }
            other => panic!("Expected Reference, got {:?}", other),
        }
    }

    #[test]
    fn test_template_arity_mismatch() {
        let template = PredicateTemplate::new(
            "related",
            names(&["T"]),
            names(&["x", "y"]),
            PredicateBody::Reference {
                name: "connected".to_string(),
                args: names(&["x", "y"]),
            },
        );

        let err = template.instantiate(&[], &names(&["a", "b"])).unwrap_err();
        assert!(matches!(
            err,
            AdapterError::ArityMismatch { ref name, expected: 1, found: 0 }
                if name == "related[type params]"
        ));

        let err = template
            .instantiate(&names(&["Person"]), &names(&["a"]))
            .unwrap_err();
        assert!(matches!(
            err,
            AdapterError::ArityMismatch { ref name, expected: 2, found: 1 }
                if name == "related[value params]"
        ));
    }

    #[test]
    fn test_composite_and() {
        let body = PredicateBody::And(vec![
            PredicateBody::Reference {
                name: "knows".to_string(),
                args: vec!["x".to_string(), "y".to_string()],
            },
            PredicateBody::Reference {
                name: "trusts".to_string(),
                args: vec!["x".to_string(), "y".to_string()],
            },
        ]);

        let pred = CompositePredicate::new("friend", vec!["x".to_string(), "y".to_string()], body);

        assert!(pred.validate().is_ok());
    }

    #[test]
    fn test_composite_or() {
        let body = PredicateBody::Or(vec![
            PredicateBody::Reference {
                name: "colleague".to_string(),
                args: vec!["x".to_string(), "y".to_string()],
            },
            PredicateBody::Reference {
                name: "friend".to_string(),
                args: vec!["x".to_string(), "y".to_string()],
            },
        ]);

        let pred =
            CompositePredicate::new("connected", vec!["x".to_string(), "y".to_string()], body);

        assert!(pred.validate().is_ok());
    }

    #[test]
    fn test_composite_not() {
        let body = PredicateBody::Not(Box::new(PredicateBody::Reference {
            name: "enemy".to_string(),
            args: vec!["x".to_string(), "y".to_string()],
        }));

        let pred =
            CompositePredicate::new("not_enemy", vec!["x".to_string(), "y".to_string()], body);

        assert!(pred.validate().is_ok());
    }
}