tensorlogic_ir/
signature.rs1use serde::{Deserialize, Serialize};
4
5use crate::TypeAnnotation;
6
7#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
9pub struct PredicateSignature {
10 pub name: String,
11 pub arg_types: Vec<TypeAnnotation>,
12 pub arity: usize,
13}
14
15impl PredicateSignature {
16 pub fn new(name: impl Into<String>, arg_types: Vec<TypeAnnotation>) -> Self {
17 let arity = arg_types.len();
18 PredicateSignature {
19 name: name.into(),
20 arg_types,
21 arity,
22 }
23 }
24
25 pub fn untyped(name: impl Into<String>, arity: usize) -> Self {
27 PredicateSignature {
28 name: name.into(),
29 arg_types: Vec::new(),
30 arity,
31 }
32 }
33
34 pub fn matches_arity(&self, arg_count: usize) -> bool {
36 self.arity == arg_count
37 }
38
39 pub fn matches_types(&self, arg_types: &[Option<&TypeAnnotation>]) -> bool {
41 if arg_types.len() != self.arity {
42 return false;
43 }
44
45 if self.arg_types.is_empty() {
47 return true;
48 }
49
50 for (i, expected) in self.arg_types.iter().enumerate() {
52 if let Some(actual) = arg_types[i] {
53 if expected != actual {
54 return false;
55 }
56 }
57 }
59
60 true
61 }
62}
63
64#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
66pub struct SignatureRegistry {
67 signatures: Vec<PredicateSignature>,
68}
69
70impl SignatureRegistry {
71 pub fn new() -> Self {
72 Self::default()
73 }
74
75 pub fn register(&mut self, signature: PredicateSignature) {
77 self.signatures.push(signature);
78 }
79
80 pub fn get(&self, name: &str) -> Option<&PredicateSignature> {
82 self.signatures.iter().find(|sig| sig.name == name)
83 }
84
85 pub fn all(&self) -> &[PredicateSignature] {
87 &self.signatures
88 }
89
90 pub fn contains(&self, name: &str) -> bool {
92 self.get(name).is_some()
93 }
94}
95
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the binary `knows(Person, Person)` signature used by several tests.
    fn knows_signature() -> PredicateSignature {
        PredicateSignature::new(
            "knows",
            vec![TypeAnnotation::new("Person"), TypeAnnotation::new("Person")],
        )
    }

    #[test]
    fn test_signature_creation() {
        let signature = knows_signature();
        assert_eq!(signature.name, "knows");
        assert_eq!(signature.arity, 2);
        assert_eq!(signature.arg_types.len(), 2);
    }

    #[test]
    fn test_signature_arity_matching() {
        let signature = knows_signature();
        assert!(signature.matches_arity(2));
        // One too few and one too many arguments must both be rejected.
        for &wrong in &[1, 3] {
            assert!(!signature.matches_arity(wrong));
        }
    }

    #[test]
    fn test_signature_type_matching() {
        let signature = knows_signature();
        let person = TypeAnnotation::new("Person");
        let thing = TypeAnnotation::new("Thing");

        // Every position matches its declared type.
        assert!(signature.matches_types(&[Some(&person), Some(&person)]));
        // A single mismatched position rejects the whole call.
        assert!(!signature.matches_types(&[Some(&person), Some(&thing)]));
        // `None` entries are not checked against the declared type.
        assert!(signature.matches_types(&[None, Some(&person)]));
        assert!(signature.matches_types(&[None, None]));
    }

    #[test]
    fn test_signature_registry() {
        let mut registry = SignatureRegistry::new();
        registry.register(knows_signature());

        assert!(registry.contains("knows"));
        assert!(!registry.contains("likes"));

        let found = registry
            .get("knows")
            .expect("`knows` was registered just above");
        assert_eq!(found.arity, 2);
    }

    #[test]
    fn test_untyped_signature() {
        let signature = PredicateSignature::untyped("pred", 3);
        assert_eq!(signature.arity, 3);
        assert!(signature.arg_types.is_empty());

        // With no declared types, any mix of known and unknown types is accepted.
        let any = TypeAnnotation::new("AnyType");
        assert!(signature.matches_types(&[Some(&any), None, Some(&any)]));
    }
}