Skip to main content

tensorlogic_adapters/
composition.rs

1//! Predicate composition system for defining predicates in terms of others.
2//!
3//! This module provides a system for composing predicates from other predicates,
4//! enabling:
5//! - Macro-like predicate expansion
6//! - Predicate templates with parameters
7//! - Derived predicates based on existing ones
8//! - Complex predicate definitions through composition
9
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use tensorlogic_ir::TLExpr;
13
14use crate::error::AdapterError;
15
16/// A composable predicate definition that can be expanded.
17///
18/// Composite predicates are defined in terms of other predicates and can
19/// include parameters that are substituted during expansion.
20#[derive(Clone, Debug, Serialize, Deserialize)]
21pub struct CompositePredicate {
22    /// Name of this composite predicate
23    pub name: String,
24    /// Parameter names (e.g., ["x", "y"])
25    pub parameters: Vec<String>,
26    /// The body expression defining this predicate
27    pub body: PredicateBody,
28    /// Optional description
29    pub description: Option<String>,
30}
31
32/// The body of a composite predicate.
33#[derive(Clone, Debug, Serialize, Deserialize)]
34pub enum PredicateBody {
35    /// A TensorLogic expression
36    Expression(Box<TLExpr>),
37    /// Reference to another composite predicate
38    Reference { name: String, args: Vec<String> },
39    /// Conjunction of multiple predicates
40    And(Vec<PredicateBody>),
41    /// Disjunction of multiple predicates
42    Or(Vec<PredicateBody>),
43    /// Negation
44    Not(Box<PredicateBody>),
45}
46
47/// A registry of composite predicates for lookup and expansion.
48#[derive(Clone, Debug, Default, Serialize, Deserialize)]
49pub struct CompositeRegistry {
50    predicates: HashMap<String, CompositePredicate>,
51}
52
53impl CompositePredicate {
54    /// Creates a new composite predicate.
55    pub fn new(name: impl Into<String>, parameters: Vec<String>, body: PredicateBody) -> Self {
56        CompositePredicate {
57            name: name.into(),
58            parameters,
59            body,
60            description: None,
61        }
62    }
63
64    /// Sets the description for this composite predicate.
65    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
66        self.description = Some(desc.into());
67        self
68    }
69
70    /// Returns the arity (number of parameters) of this predicate.
71    pub fn arity(&self) -> usize {
72        self.parameters.len()
73    }
74
75    /// Validates that this composite predicate is well-formed.
76    pub fn validate(&self) -> Result<(), AdapterError> {
77        // Check that all parameters are unique
78        let mut seen = std::collections::HashSet::new();
79        for param in &self.parameters {
80            if !seen.insert(param) {
81                return Err(AdapterError::InvalidParametricType(format!(
82                    "Duplicate parameter '{}' in predicate '{}'",
83                    param, self.name
84                )));
85            }
86        }
87
88        // Validate the body
89        self.body.validate(&self.parameters)?;
90
91        Ok(())
92    }
93
94    /// Expands this composite predicate with the given arguments.
95    ///
96    /// Substitutes all parameter occurrences in the body with the provided arguments.
97    pub fn expand(&self, args: &[String]) -> Result<PredicateBody, AdapterError> {
98        if args.len() != self.parameters.len() {
99            return Err(AdapterError::ArityMismatch {
100                name: self.name.clone(),
101                expected: self.parameters.len(),
102                found: args.len(),
103            });
104        }
105
106        // Create substitution map
107        let mut substitutions = HashMap::new();
108        for (param, arg) in self.parameters.iter().zip(args.iter()) {
109            substitutions.insert(param.clone(), arg.clone());
110        }
111
112        self.body.substitute(&substitutions)
113    }
114}
115
116impl PredicateBody {
117    /// Validates that this predicate body is well-formed.
118    fn validate(&self, parameters: &[String]) -> Result<(), AdapterError> {
119        match self {
120            PredicateBody::Expression(_) => Ok(()), // TLExpr validation handled elsewhere
121            PredicateBody::Reference { args, .. } => {
122                // Check that all args reference valid parameters
123                for arg in args {
124                    if !parameters.contains(arg) && !arg.starts_with('_') {
125                        return Err(AdapterError::UnboundVariable(arg.clone()));
126                    }
127                }
128                Ok(())
129            }
130            PredicateBody::And(bodies) | PredicateBody::Or(bodies) => {
131                for body in bodies {
132                    body.validate(parameters)?;
133                }
134                Ok(())
135            }
136            PredicateBody::Not(body) => body.validate(parameters),
137        }
138    }
139
140    /// Substitutes parameters with concrete arguments.
141    fn substitute(
142        &self,
143        substitutions: &HashMap<String, String>,
144    ) -> Result<PredicateBody, AdapterError> {
145        match self {
146            PredicateBody::Expression(expr) => {
147                // For now, return as-is. Full expression substitution would require
148                // walking the TLExpr tree and replacing variable names.
149                Ok(PredicateBody::Expression(expr.clone()))
150            }
151            PredicateBody::Reference { name, args } => {
152                let new_args = args
153                    .iter()
154                    .map(|arg| {
155                        substitutions
156                            .get(arg)
157                            .cloned()
158                            .unwrap_or_else(|| arg.clone())
159                    })
160                    .collect();
161                Ok(PredicateBody::Reference {
162                    name: name.clone(),
163                    args: new_args,
164                })
165            }
166            PredicateBody::And(bodies) => {
167                let new_bodies: Result<Vec<_>, _> =
168                    bodies.iter().map(|b| b.substitute(substitutions)).collect();
169                Ok(PredicateBody::And(new_bodies?))
170            }
171            PredicateBody::Or(bodies) => {
172                let new_bodies: Result<Vec<_>, _> =
173                    bodies.iter().map(|b| b.substitute(substitutions)).collect();
174                Ok(PredicateBody::Or(new_bodies?))
175            }
176            PredicateBody::Not(body) => Ok(PredicateBody::Not(Box::new(
177                body.substitute(substitutions)?,
178            ))),
179        }
180    }
181}
182
183impl CompositeRegistry {
184    /// Creates a new empty composite registry.
185    pub fn new() -> Self {
186        CompositeRegistry::default()
187    }
188
189    /// Registers a composite predicate.
190    pub fn register(&mut self, predicate: CompositePredicate) -> Result<(), AdapterError> {
191        predicate.validate()?;
192        self.predicates.insert(predicate.name.clone(), predicate);
193        Ok(())
194    }
195
196    /// Gets a composite predicate by name.
197    pub fn get(&self, name: &str) -> Option<&CompositePredicate> {
198        self.predicates.get(name)
199    }
200
201    /// Checks if a predicate is registered.
202    pub fn contains(&self, name: &str) -> bool {
203        self.predicates.contains_key(name)
204    }
205
206    /// Expands a composite predicate with the given arguments.
207    pub fn expand(&self, name: &str, args: &[String]) -> Result<PredicateBody, AdapterError> {
208        let predicate = self
209            .get(name)
210            .ok_or_else(|| AdapterError::PredicateNotFound(name.to_string()))?;
211
212        predicate.expand(args)
213    }
214
215    /// Returns the number of registered composite predicates.
216    pub fn len(&self) -> usize {
217        self.predicates.len()
218    }
219
220    /// Checks if the registry is empty.
221    pub fn is_empty(&self) -> bool {
222        self.predicates.is_empty()
223    }
224
225    /// Lists all registered predicate names.
226    pub fn list_predicates(&self) -> Vec<String> {
227        self.predicates.keys().cloned().collect()
228    }
229}
230
231/// A template for creating multiple similar predicates.
232///
233/// Templates allow defining patterns for predicates that can be instantiated
234/// with different domains or properties.
235#[derive(Clone, Debug, Serialize, Deserialize)]
236pub struct PredicateTemplate {
237    /// Template name
238    pub name: String,
239    /// Type parameters (e.g., ["T", "U"])
240    pub type_params: Vec<String>,
241    /// Value parameters (e.g., ["relation"])
242    pub value_params: Vec<String>,
243    /// The body defining how to construct the predicate
244    pub body: PredicateBody,
245}
246
247impl PredicateTemplate {
248    /// Creates a new predicate template.
249    pub fn new(
250        name: impl Into<String>,
251        type_params: Vec<String>,
252        value_params: Vec<String>,
253        body: PredicateBody,
254    ) -> Self {
255        PredicateTemplate {
256            name: name.into(),
257            type_params,
258            value_params,
259            body,
260        }
261    }
262
263    /// Instantiates this template with concrete types and values.
264    pub fn instantiate(
265        &self,
266        type_args: &[String],
267        value_args: &[String],
268    ) -> Result<CompositePredicate, AdapterError> {
269        if type_args.len() != self.type_params.len() {
270            return Err(AdapterError::ArityMismatch {
271                name: format!("{}[type params]", self.name),
272                expected: self.type_params.len(),
273                found: type_args.len(),
274            });
275        }
276
277        if value_args.len() != self.value_params.len() {
278            return Err(AdapterError::ArityMismatch {
279                name: format!("{}[value params]", self.name),
280                expected: self.value_params.len(),
281                found: value_args.len(),
282            });
283        }
284
285        // Create substitution map
286        let mut substitutions = HashMap::new();
287        for (param, arg) in self.type_params.iter().zip(type_args.iter()) {
288            substitutions.insert(param.clone(), arg.clone());
289        }
290        for (param, arg) in self.value_params.iter().zip(value_args.iter()) {
291            substitutions.insert(param.clone(), arg.clone());
292        }
293
294        // Generate instance name
295        let instance_name = format!("{}<{}>", self.name, type_args.join(", "));
296
297        // Substitute in body
298        let instance_body = self.body.substitute(&substitutions)?;
299
300        Ok(CompositePredicate {
301            name: instance_name,
302            parameters: value_args.to_vec(),
303            body: instance_body,
304            description: None,
305        })
306    }
307}
308
#[cfg(test)]
mod tests {
    //! Unit tests for composite predicates, the registry, and templates,
    //! covering both success paths and error paths.

    use super::*;

    /// Shorthand for building an owned list of parameter/argument names.
    fn names(items: &[&str]) -> Vec<String> {
        items.iter().map(|s| s.to_string()).collect()
    }

    #[test]
    fn test_composite_predicate_creation() {
        let pred = CompositePredicate::new(
            "friend",
            vec!["x".to_string(), "y".to_string()],
            PredicateBody::Reference {
                name: "knows".to_string(),
                args: vec!["x".to_string(), "y".to_string()],
            },
        );

        assert_eq!(pred.name, "friend");
        assert_eq!(pred.arity(), 2);
    }

    #[test]
    fn test_composite_predicate_validation() {
        let valid = CompositePredicate::new(
            "test",
            vec!["x".to_string(), "y".to_string()],
            PredicateBody::Reference {
                name: "knows".to_string(),
                args: vec!["x".to_string(), "y".to_string()],
            },
        );
        assert!(valid.validate().is_ok());

        let invalid = CompositePredicate::new(
            "test",
            vec!["x".to_string(), "x".to_string()], // Duplicate parameter
            PredicateBody::Reference {
                name: "knows".to_string(),
                args: vec!["x".to_string()],
            },
        );
        assert!(invalid.validate().is_err());
    }

    #[test]
    fn test_validation_rejects_unbound_variable() {
        let pred = CompositePredicate::new(
            "bad",
            names(&["x"]),
            PredicateBody::Reference {
                name: "knows".to_string(),
                args: names(&["x", "z"]),
            },
        );

        let err = pred.validate().unwrap_err();
        assert!(matches!(err, AdapterError::UnboundVariable(ref var) if var == "z"));
    }

    #[test]
    fn test_validation_allows_underscore_arguments() {
        let pred = CompositePredicate::new(
            "knows_someone",
            names(&["x"]),
            PredicateBody::Reference {
                name: "knows".to_string(),
                args: names(&["x", "_other"]),
            },
        );

        assert!(pred.validate().is_ok());
    }

    #[test]
    fn test_composite_registry() {
        let mut registry = CompositeRegistry::new();

        let pred = CompositePredicate::new(
            "friend",
            vec!["x".to_string(), "y".to_string()],
            PredicateBody::Reference {
                name: "knows".to_string(),
                args: vec!["x".to_string(), "y".to_string()],
            },
        );

        registry.register(pred).unwrap();
        assert!(registry.contains("friend"));
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_registry_rejects_invalid_predicate() {
        let mut registry = CompositeRegistry::new();
        let invalid = CompositePredicate::new(
            "dup",
            names(&["x", "x"]),
            PredicateBody::Reference {
                name: "knows".to_string(),
                args: names(&["x"]),
            },
        );

        assert!(registry.register(invalid).is_err());
        assert!(registry.is_empty());
        assert!(!registry.contains("dup"));
    }

    #[test]
    fn test_registry_expand_unknown_predicate() {
        let registry = CompositeRegistry::new();

        let err = registry.expand("missing", &[]).unwrap_err();
        assert!(matches!(err, AdapterError::PredicateNotFound(ref name) if name == "missing"));
    }

    #[test]
    fn test_registry_lists_all_names() {
        let mut registry = CompositeRegistry::new();
        for name in ["friend", "colleague"] {
            registry
                .register(CompositePredicate::new(
                    name,
                    names(&["x"]),
                    PredicateBody::Reference {
                        name: "knows".to_string(),
                        args: names(&["x"]),
                    },
                ))
                .unwrap();
        }

        // Sort locally so the assertion does not depend on listing order.
        let mut listed = registry.list_predicates();
        listed.sort();
        assert_eq!(listed, names(&["colleague", "friend"]));
    }

    #[test]
    fn test_predicate_expansion() {
        let pred = CompositePredicate::new(
            "friend",
            vec!["x".to_string(), "y".to_string()],
            PredicateBody::Reference {
                name: "knows".to_string(),
                args: vec!["x".to_string(), "y".to_string()],
            },
        );

        let expanded = pred
            .expand(&["alice".to_string(), "bob".to_string()])
            .unwrap();

        match expanded {
            PredicateBody::Reference { name, args } => {
                assert_eq!(name, "knows");
                assert_eq!(args, vec!["alice".to_string(), "bob".to_string()]);
            }
            _ => panic!("Expected Reference"),
        }
    }

    #[test]
    fn test_expand_arity_mismatch() {
        let pred = CompositePredicate::new(
            "friend",
            names(&["x", "y"]),
            PredicateBody::Reference {
                name: "knows".to_string(),
                args: names(&["x", "y"]),
            },
        );

        let err = pred.expand(&names(&["alice"])).unwrap_err();
        assert!(matches!(
            err,
            AdapterError::ArityMismatch {
                expected: 2,
                found: 1,
                ..
            }
        ));
    }

    #[test]
    fn test_nested_expansion() {
        let body = PredicateBody::And(vec![
            PredicateBody::Reference {
                name: "knows".to_string(),
                args: names(&["x", "_any"]),
            },
            PredicateBody::Not(Box::new(PredicateBody::Reference {
                name: "enemy".to_string(),
                args: names(&["y", "x"]),
            })),
        ]);
        let pred = CompositePredicate::new("ally", names(&["x", "y"]), body);

        let expanded = pred.expand(&names(&["alice", "bob"])).unwrap();

        match expanded {
            PredicateBody::And(parts) => {
                assert_eq!(parts.len(), 2);
                match &parts[0] {
                    PredicateBody::Reference { name, args } => {
                        assert_eq!(name, "knows");
                        // Unmapped arguments pass through unchanged.
                        assert_eq!(args, &names(&["alice", "_any"]));
                    }
                    _ => panic!("Expected Reference"),
                }
                match &parts[1] {
                    PredicateBody::Not(inner) => match inner.as_ref() {
                        PredicateBody::Reference { name, args } => {
                            assert_eq!(name, "enemy");
                            assert_eq!(args, &names(&["bob", "alice"]));
                        }
                        _ => panic!("Expected Reference under Not"),
                    },
                    _ => panic!("Expected Not"),
                }
            }
            _ => panic!("Expected And"),
        }
    }

    #[test]
    fn test_predicate_template() {
        let template = PredicateTemplate::new(
            "related",
            vec!["T".to_string()],
            vec!["x".to_string(), "y".to_string()],
            PredicateBody::Reference {
                name: "connected".to_string(),
                args: vec!["x".to_string(), "y".to_string()],
            },
        );

        let instance = template
            .instantiate(&["Person".to_string()], &["a".to_string(), "b".to_string()])
            .unwrap();

        assert_eq!(instance.name, "related<Person>");
        assert_eq!(instance.parameters, vec!["a".to_string(), "b".to_string()]);
    }

    #[test]
    fn test_template_substitutes_type_parameters() {
        let template = PredicateTemplate::new(
            "typed_related",
            names(&["T"]),
            names(&["x"]),
            PredicateBody::Reference {
                name: "instance_of".to_string(),
                args: names(&["x", "T"]),
            },
        );

        let instance = template
            .instantiate(&names(&["Person"]), &names(&["a"]))
            .unwrap();

        assert_eq!(instance.name, "typed_related<Person>");
        match instance.body {
            PredicateBody::Reference { args, .. } => {
                assert_eq!(args, names(&["a", "Person"]));
            }
            _ => panic!("Expected Reference"),
        }
    }

    #[test]
    fn test_template_arity_mismatch() {
        let template = PredicateTemplate::new(
            "related",
            names(&["T"]),
            names(&["x"]),
            PredicateBody::Reference {
                name: "connected".to_string(),
                args: names(&["x"]),
            },
        );

        let err = template.instantiate(&[], &names(&["a"])).unwrap_err();
        assert!(matches!(
            err,
            AdapterError::ArityMismatch {
                expected: 1,
                found: 0,
                ..
            }
        ));
    }

    #[test]
    fn test_composite_and() {
        let body = PredicateBody::And(vec![
            PredicateBody::Reference {
                name: "knows".to_string(),
                args: vec!["x".to_string(), "y".to_string()],
            },
            PredicateBody::Reference {
                name: "trusts".to_string(),
                args: vec!["x".to_string(), "y".to_string()],
            },
        ]);

        let pred = CompositePredicate::new("friend", vec!["x".to_string(), "y".to_string()], body);

        assert!(pred.validate().is_ok());
    }

    #[test]
    fn test_composite_or() {
        let body = PredicateBody::Or(vec![
            PredicateBody::Reference {
                name: "colleague".to_string(),
                args: vec!["x".to_string(), "y".to_string()],
            },
            PredicateBody::Reference {
                name: "friend".to_string(),
                args: vec!["x".to_string(), "y".to_string()],
            },
        ]);

        let pred =
            CompositePredicate::new("connected", vec!["x".to_string(), "y".to_string()], body);

        assert!(pred.validate().is_ok());
    }

    #[test]
    fn test_composite_not() {
        let body = PredicateBody::Not(Box::new(PredicateBody::Reference {
            name: "enemy".to_string(),
            args: vec!["x".to_string(), "y".to_string()],
        }));

        let pred =
            CompositePredicate::new("not_enemy", vec!["x".to_string(), "y".to_string()], body);

        assert!(pred.validate().is_ok());
    }
}