1use serde::{Deserialize, Serialize};
4
5use crate::parametric_types::{unify, ParametricType, TypeSubstitution};
6use crate::{IrError, TypeAnnotation};
7
8#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
10pub struct PredicateSignature {
11 pub name: String,
12 pub arg_types: Vec<TypeAnnotation>,
13 pub arity: usize,
14 #[serde(skip_serializing_if = "Option::is_none")]
16 pub parametric_types: Option<Vec<ParametricType>>,
17}
18
19impl PredicateSignature {
20 pub fn new(name: impl Into<String>, arg_types: Vec<TypeAnnotation>) -> Self {
21 let arity = arg_types.len();
22 PredicateSignature {
23 name: name.into(),
24 arg_types,
25 arity,
26 parametric_types: None,
27 }
28 }
29
30 pub fn parametric(name: impl Into<String>, parametric_types: Vec<ParametricType>) -> Self {
32 let arity = parametric_types.len();
33 PredicateSignature {
34 name: name.into(),
35 arg_types: Vec::new(), arity,
37 parametric_types: Some(parametric_types),
38 }
39 }
40
41 pub fn untyped(name: impl Into<String>, arity: usize) -> Self {
43 PredicateSignature {
44 name: name.into(),
45 arg_types: Vec::new(),
46 arity,
47 parametric_types: None,
48 }
49 }
50
51 pub fn matches_arity(&self, arg_count: usize) -> bool {
53 self.arity == arg_count
54 }
55
56 pub fn matches_types(&self, arg_types: &[Option<&TypeAnnotation>]) -> bool {
58 if arg_types.len() != self.arity {
59 return false;
60 }
61
62 if self.arg_types.is_empty() && self.parametric_types.is_none() {
64 return true;
65 }
66
67 for (i, expected) in self.arg_types.iter().enumerate() {
69 if let Some(actual) = arg_types[i] {
70 if expected != actual {
71 return false;
72 }
73 }
74 }
76
77 true
78 }
79
80 pub fn unify_parametric(
82 &self,
83 arg_types: &[ParametricType],
84 ) -> Result<TypeSubstitution, IrError> {
85 if arg_types.len() != self.arity {
86 return Err(IrError::ArityMismatch {
87 name: self.name.clone(),
88 expected: self.arity,
89 actual: arg_types.len(),
90 });
91 }
92
93 let Some(ref param_types) = self.parametric_types else {
94 return Ok(TypeSubstitution::new());
96 };
97
98 let mut subst = TypeSubstitution::new();
100 for (expected, actual) in param_types.iter().zip(arg_types.iter()) {
101 let new_subst = unify(expected, actual)?;
102 subst = crate::parametric_types::compose_substitutions(&subst, &new_subst);
104 }
105
106 Ok(subst)
107 }
108
109 pub fn is_parametric(&self) -> bool {
111 self.parametric_types.is_some()
112 }
113
114 pub fn get_parametric_types(&self) -> Option<&[ParametricType]> {
116 self.parametric_types.as_deref()
117 }
118
119 pub fn instantiate(&self, subst: &TypeSubstitution) -> PredicateSignature {
121 let parametric_types = self
122 .parametric_types
123 .as_ref()
124 .map(|types| types.iter().map(|ty| ty.substitute(subst)).collect());
125
126 PredicateSignature {
127 name: self.name.clone(),
128 arg_types: self.arg_types.clone(),
129 arity: self.arity,
130 parametric_types,
131 }
132 }
133}
134
135#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
137pub struct SignatureRegistry {
138 signatures: Vec<PredicateSignature>,
139}
140
141impl SignatureRegistry {
142 pub fn new() -> Self {
143 Self::default()
144 }
145
146 pub fn register(&mut self, signature: PredicateSignature) {
148 self.signatures.push(signature);
149 }
150
151 pub fn get(&self, name: &str) -> Option<&PredicateSignature> {
153 self.signatures.iter().find(|sig| sig.name == name)
154 }
155
156 pub fn all(&self) -> &[PredicateSignature] {
158 &self.signatures
159 }
160
161 pub fn contains(&self, name: &str) -> bool {
163 self.get(name).is_some()
164 }
165}
166
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_signature_creation() {
        let sig = PredicateSignature::new(
            "knows",
            vec![TypeAnnotation::new("Person"), TypeAnnotation::new("Person")],
        );
        assert_eq!(sig.name, "knows");
        assert_eq!(sig.arity, 2);
        assert_eq!(sig.arg_types.len(), 2);
    }

    #[test]
    fn test_signature_arity_matching() {
        let sig = PredicateSignature::new(
            "knows",
            vec![TypeAnnotation::new("Person"), TypeAnnotation::new("Person")],
        );
        assert!(sig.matches_arity(2));
        assert!(!sig.matches_arity(1));
        assert!(!sig.matches_arity(3));
    }

    #[test]
    fn test_signature_type_matching() {
        let sig = PredicateSignature::new(
            "knows",
            vec![TypeAnnotation::new("Person"), TypeAnnotation::new("Person")],
        );

        let person_type = TypeAnnotation::new("Person");
        let thing_type = TypeAnnotation::new("Thing");

        assert!(sig.matches_types(&[Some(&person_type), Some(&person_type)]));

        assert!(!sig.matches_types(&[Some(&person_type), Some(&thing_type)]));

        // `None` marks an argument whose type is unknown; it is accepted.
        assert!(sig.matches_types(&[None, Some(&person_type)]));
        assert!(sig.matches_types(&[None, None]));
    }

    #[test]
    fn test_signature_registry() {
        let mut registry = SignatureRegistry::new();

        let knows_sig = PredicateSignature::new(
            "knows",
            vec![TypeAnnotation::new("Person"), TypeAnnotation::new("Person")],
        );
        registry.register(knows_sig);

        assert!(registry.contains("knows"));
        assert!(!registry.contains("likes"));

        let retrieved = registry.get("knows").unwrap();
        assert_eq!(retrieved.arity, 2);
    }

    #[test]
    fn test_untyped_signature() {
        let sig = PredicateSignature::untyped("pred", 3);
        assert_eq!(sig.arity, 3);
        assert!(sig.arg_types.is_empty());

        // Untyped signatures accept any argument types as long as the count matches.
        let any_type = TypeAnnotation::new("AnyType");
        assert!(sig.matches_types(&[Some(&any_type), None, Some(&any_type)]));
    }

    #[test]
    fn test_parametric_signature_creation() {
        let t_var = ParametricType::variable("T");
        let sig = PredicateSignature::parametric(
            "contains",
            vec![ParametricType::list(t_var.clone()), t_var.clone()],
        );

        assert_eq!(sig.name, "contains");
        assert_eq!(sig.arity, 2);
        assert!(sig.is_parametric());
        assert_eq!(sig.get_parametric_types().unwrap().len(), 2);
    }

    #[test]
    fn test_parametric_signature_unification() {
        let t_var = ParametricType::variable("T");
        let sig = PredicateSignature::parametric(
            "contains",
            vec![ParametricType::list(t_var.clone()), t_var.clone()],
        );

        let int_type = ParametricType::concrete("Int");
        let list_int = ParametricType::list(int_type.clone());

        let subst = sig.unify_parametric(&[list_int, int_type.clone()]).unwrap();
        assert_eq!(subst.get("T").unwrap(), &int_type);
    }

    #[test]
    fn test_parametric_signature_instantiation() {
        let t_var = ParametricType::variable("T");
        let sig = PredicateSignature::parametric("identity", vec![t_var.clone(), t_var.clone()]);

        let int_type = ParametricType::concrete("Int");
        let mut subst = TypeSubstitution::new();
        subst.insert("T".to_string(), int_type.clone());

        let instantiated = sig.instantiate(&subst);
        assert!(instantiated.is_parametric());
        let param_types = instantiated.get_parametric_types().unwrap();
        assert_eq!(param_types[0], int_type);
        assert_eq!(param_types[1], int_type);
    }

    #[test]
    fn test_parametric_signature_arity_mismatch() {
        let t_var = ParametricType::variable("T");
        let sig = PredicateSignature::parametric("pred", vec![t_var.clone()]);

        let int_type = ParametricType::concrete("Int");
        let result = sig.unify_parametric(&[int_type.clone(), int_type]);
        assert!(result.is_err());
    }

    #[test]
    fn test_parametric_signature_complex_types() {
        let t_var = ParametricType::variable("T");
        let u_var = ParametricType::variable("U");

        // map_over(T -> U, List<T>, List<U>)
        let sig = PredicateSignature::parametric(
            "map_over",
            vec![
                ParametricType::function(t_var.clone(), u_var.clone()),
                ParametricType::list(t_var.clone()),
                ParametricType::list(u_var.clone()),
            ],
        );

        let int_type = ParametricType::concrete("Int");
        let string_type = ParametricType::concrete("String");

        let subst = sig
            .unify_parametric(&[
                ParametricType::function(int_type.clone(), string_type.clone()),
                ParametricType::list(int_type.clone()),
                ParametricType::list(string_type.clone()),
            ])
            .unwrap();

        assert_eq!(subst.get("T").unwrap(), &int_type);
        assert_eq!(subst.get("U").unwrap(), &string_type);
    }

    #[test]
    fn test_type_annotation_parametric_conversion() {
        let type_ann = TypeAnnotation::new("Int");
        let param_type = type_ann.to_parametric();
        assert_eq!(param_type, ParametricType::concrete("Int"));

        // A concrete type converts back to the original annotation.
        let converted_back = TypeAnnotation::from_parametric(&param_type);
        assert_eq!(converted_back, Some(type_ann));

        // A list type does not convert back to a plain annotation.
        let list_int = ParametricType::list(ParametricType::concrete("Int"));
        assert!(TypeAnnotation::from_parametric(&list_int).is_none());
    }
}