Skip to main content

tensorlogic_ir/
signature.rs

1//! Predicate signatures and metadata.
2
3use serde::{Deserialize, Serialize};
4
5use crate::parametric_types::{unify, ParametricType, TypeSubstitution};
6use crate::{IrError, TypeAnnotation};
7
8/// Signature for a predicate: defines expected argument types
9#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
10pub struct PredicateSignature {
11    pub name: String,
12    pub arg_types: Vec<TypeAnnotation>,
13    pub arity: usize,
14    /// Optional parametric type signature for generic predicates
15    #[serde(skip_serializing_if = "Option::is_none")]
16    pub parametric_types: Option<Vec<ParametricType>>,
17}
18
19impl PredicateSignature {
20    pub fn new(name: impl Into<String>, arg_types: Vec<TypeAnnotation>) -> Self {
21        let arity = arg_types.len();
22        PredicateSignature {
23            name: name.into(),
24            arg_types,
25            arity,
26            parametric_types: None,
27        }
28    }
29
30    /// Create a parametric signature
31    pub fn parametric(name: impl Into<String>, parametric_types: Vec<ParametricType>) -> Self {
32        let arity = parametric_types.len();
33        PredicateSignature {
34            name: name.into(),
35            arg_types: Vec::new(), // Populated from parametric types if needed
36            arity,
37            parametric_types: Some(parametric_types),
38        }
39    }
40
41    /// Create an untyped signature (for backward compatibility)
42    pub fn untyped(name: impl Into<String>, arity: usize) -> Self {
43        PredicateSignature {
44            name: name.into(),
45            arg_types: Vec::new(),
46            arity,
47            parametric_types: None,
48        }
49    }
50
51    /// Check if this signature matches the given number of arguments
52    pub fn matches_arity(&self, arg_count: usize) -> bool {
53        self.arity == arg_count
54    }
55
56    /// Check if the given argument types match this signature
57    pub fn matches_types(&self, arg_types: &[Option<&TypeAnnotation>]) -> bool {
58        if arg_types.len() != self.arity {
59            return false;
60        }
61
62        // If signature has no type annotations, accept any types
63        if self.arg_types.is_empty() && self.parametric_types.is_none() {
64            return true;
65        }
66
67        // Check each argument type
68        for (i, expected) in self.arg_types.iter().enumerate() {
69            if let Some(actual) = arg_types[i] {
70                if expected != actual {
71                    return false;
72                }
73            }
74            // If actual type is None, we accept it (untyped argument)
75        }
76
77        true
78    }
79
80    /// Unify parametric signature with concrete argument types
81    pub fn unify_parametric(
82        &self,
83        arg_types: &[ParametricType],
84    ) -> Result<TypeSubstitution, IrError> {
85        if arg_types.len() != self.arity {
86            return Err(IrError::ArityMismatch {
87                name: self.name.clone(),
88                expected: self.arity,
89                actual: arg_types.len(),
90            });
91        }
92
93        let Some(ref param_types) = self.parametric_types else {
94            // No parametric types, fall back to simple matching
95            return Ok(TypeSubstitution::new());
96        };
97
98        // Unify each argument type with the parametric signature
99        let mut subst = TypeSubstitution::new();
100        for (expected, actual) in param_types.iter().zip(arg_types.iter()) {
101            let new_subst = unify(expected, actual)?;
102            // Compose substitutions
103            subst = crate::parametric_types::compose_substitutions(&subst, &new_subst);
104        }
105
106        Ok(subst)
107    }
108
109    /// Check if this is a parametric signature
110    pub fn is_parametric(&self) -> bool {
111        self.parametric_types.is_some()
112    }
113
114    /// Get the parametric types if present
115    pub fn get_parametric_types(&self) -> Option<&[ParametricType]> {
116        self.parametric_types.as_deref()
117    }
118
119    /// Instantiate a parametric signature with a substitution
120    pub fn instantiate(&self, subst: &TypeSubstitution) -> PredicateSignature {
121        let parametric_types = self
122            .parametric_types
123            .as_ref()
124            .map(|types| types.iter().map(|ty| ty.substitute(subst)).collect());
125
126        PredicateSignature {
127            name: self.name.clone(),
128            arg_types: self.arg_types.clone(),
129            arity: self.arity,
130            parametric_types,
131        }
132    }
133}
134
135/// Registry of predicate signatures
136#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
137pub struct SignatureRegistry {
138    signatures: Vec<PredicateSignature>,
139}
140
141impl SignatureRegistry {
142    pub fn new() -> Self {
143        Self::default()
144    }
145
146    /// Register a new predicate signature
147    pub fn register(&mut self, signature: PredicateSignature) {
148        self.signatures.push(signature);
149    }
150
151    /// Look up a signature by predicate name
152    pub fn get(&self, name: &str) -> Option<&PredicateSignature> {
153        self.signatures.iter().find(|sig| sig.name == name)
154    }
155
156    /// Get all registered signatures
157    pub fn all(&self) -> &[PredicateSignature] {
158        &self.signatures
159    }
160
161    /// Check if a predicate is registered
162    pub fn contains(&self, name: &str) -> bool {
163        self.get(name).is_some()
164    }
165}
166
#[cfg(test)]
mod tests {
    use super::*;

    /// The `Person` annotation used by the `knows` fixtures.
    fn person() -> TypeAnnotation {
        TypeAnnotation::new("Person")
    }

    /// `knows(Person, Person)`, shared by several tests below.
    fn knows_signature() -> PredicateSignature {
        PredicateSignature::new("knows", vec![person(), person()])
    }

    #[test]
    fn test_signature_creation() {
        let knows = knows_signature();
        assert_eq!(knows.name, "knows");
        assert_eq!(knows.arity, 2);
        assert_eq!(knows.arg_types.len(), 2);
    }

    #[test]
    fn test_signature_arity_matching() {
        let knows = knows_signature();
        assert!(knows.matches_arity(2));
        for wrong_count in [1, 3] {
            assert!(!knows.matches_arity(wrong_count));
        }
    }

    #[test]
    fn test_signature_type_matching() {
        let knows = knows_signature();
        let person_ty = person();
        let thing_ty = TypeAnnotation::new("Thing");

        // Every position typed and matching.
        assert!(knows.matches_types(&[Some(&person_ty), Some(&person_ty)]));

        // Second position has the wrong type.
        assert!(!knows.matches_types(&[Some(&person_ty), Some(&thing_ty)]));

        // Untyped positions are always accepted.
        assert!(knows.matches_types(&[None, Some(&person_ty)]));
        assert!(knows.matches_types(&[None, None]));
    }

    #[test]
    fn test_signature_registry() {
        let mut registry = SignatureRegistry::new();
        registry.register(knows_signature());

        assert!(registry.contains("knows"));
        assert!(!registry.contains("likes"));

        let found = registry.get("knows").expect("`knows` was just registered");
        assert_eq!(found.arity, 2);
    }

    #[test]
    fn test_untyped_signature() {
        let untyped = PredicateSignature::untyped("pred", 3);
        assert_eq!(untyped.arity, 3);
        assert!(untyped.arg_types.is_empty());

        // With no declared types, any mix of typed and untyped arguments passes.
        let anything = TypeAnnotation::new("AnyType");
        assert!(untyped.matches_types(&[Some(&anything), None, Some(&anything)]));
    }

    #[test]
    fn test_parametric_signature_creation() {
        let t = ParametricType::variable("T");
        let contains =
            PredicateSignature::parametric("contains", vec![ParametricType::list(t.clone()), t]);

        assert_eq!(contains.name, "contains");
        assert_eq!(contains.arity, 2);
        assert!(contains.is_parametric());
        assert_eq!(contains.get_parametric_types().map(|tys| tys.len()), Some(2));
    }

    #[test]
    fn test_parametric_signature_unification() {
        let t = ParametricType::variable("T");
        let contains =
            PredicateSignature::parametric("contains", vec![ParametricType::list(t.clone()), t]);

        let int = ParametricType::concrete("Int");
        let args = [ParametricType::list(int.clone()), int.clone()];

        // (List<T>, T) against (List<Int>, Int) binds T to Int.
        let subst = contains.unify_parametric(&args).unwrap();
        assert_eq!(subst.get("T"), Some(&int));
    }

    #[test]
    fn test_parametric_signature_instantiation() {
        let t = ParametricType::variable("T");
        let identity = PredicateSignature::parametric("identity", vec![t.clone(), t]);

        let int = ParametricType::concrete("Int");
        let mut subst = TypeSubstitution::new();
        subst.insert("T".to_string(), int.clone());

        let instantiated = identity.instantiate(&subst);
        assert!(instantiated.is_parametric());

        let types = instantiated.get_parametric_types().unwrap();
        for ty in &types[..2] {
            assert_eq!(*ty, int);
        }
    }

    #[test]
    fn test_parametric_signature_arity_mismatch() {
        let unary = PredicateSignature::parametric("pred", vec![ParametricType::variable("T")]);

        // Two arguments supplied to a one-argument signature.
        let int = ParametricType::concrete("Int");
        assert!(unary.unify_parametric(&[int.clone(), int]).is_err());
    }

    #[test]
    fn test_parametric_signature_complex_types() {
        let t = ParametricType::variable("T");
        let u = ParametricType::variable("U");

        // map_over: (T -> U, List<T>) -> List<U>
        let map_over = PredicateSignature::parametric(
            "map_over",
            vec![
                ParametricType::function(t.clone(), u.clone()),
                ParametricType::list(t),
                ParametricType::list(u),
            ],
        );

        let int = ParametricType::concrete("Int");
        let string = ParametricType::concrete("String");
        let args = [
            ParametricType::function(int.clone(), string.clone()),
            ParametricType::list(int.clone()),
            ParametricType::list(string.clone()),
        ];

        // Unify with (Int -> String, List<Int>, List<String>).
        let subst = map_over.unify_parametric(&args).unwrap();
        assert_eq!(subst.get("T"), Some(&int));
        assert_eq!(subst.get("U"), Some(&string));
    }

    #[test]
    fn test_type_annotation_parametric_conversion() {
        let int_ann = TypeAnnotation::new("Int");
        let as_param = int_ann.to_parametric();
        assert_eq!(as_param, ParametricType::concrete("Int"));

        // A concrete type converts back to an annotation...
        assert_eq!(TypeAnnotation::from_parametric(&as_param), Some(int_ann));

        // ...but a compound parametric type such as List<Int> does not.
        let list_int = ParametricType::list(ParametricType::concrete("Int"));
        assert!(TypeAnnotation::from_parametric(&list_int).is_none());
    }
}