zen_expression/functions/
defs.rs

1use crate::functions::arguments::Arguments;
2use crate::variable::VariableType;
3use crate::Variable;
4use std::collections::HashSet;
5use std::rc::Rc;
6
7pub trait FunctionDefinition {
8    fn required_parameters(&self) -> usize;
9    fn optional_parameters(&self) -> usize;
10    fn check_types(&self, args: &[VariableType]) -> FunctionTypecheck;
11    fn call(&self, args: Arguments) -> anyhow::Result<Variable>;
12    fn param_type(&self, index: usize) -> Option<VariableType>;
13    fn param_type_str(&self, index: usize) -> String;
14    fn return_type(&self) -> VariableType;
15    fn return_type_str(&self) -> String;
16}
17
18#[derive(Debug, Default)]
19pub struct FunctionTypecheck {
20    pub general: Option<String>,
21    pub arguments: Vec<(usize, String)>,
22    pub return_type: VariableType,
23}
24
25#[derive(Clone)]
26pub struct FunctionSignature {
27    pub parameters: Vec<VariableType>,
28    pub return_type: VariableType,
29}
30
31impl FunctionSignature {
32    pub fn single(parameter: VariableType, return_type: VariableType) -> Self {
33        Self {
34            parameters: vec![parameter],
35            return_type,
36        }
37    }
38}
39
40#[derive(Clone)]
41pub struct StaticFunction {
42    pub signature: FunctionSignature,
43    pub implementation: Rc<dyn Fn(Arguments) -> anyhow::Result<Variable>>,
44}
45
46impl FunctionDefinition for StaticFunction {
47    fn required_parameters(&self) -> usize {
48        self.signature.parameters.len()
49    }
50
51    fn optional_parameters(&self) -> usize {
52        0
53    }
54
55    fn check_types(&self, args: &[VariableType]) -> FunctionTypecheck {
56        let mut typecheck = FunctionTypecheck::default();
57        typecheck.return_type = self.signature.return_type.clone();
58
59        if args.len() != self.required_parameters() {
60            typecheck.general = Some(format!(
61                "Expected `{}` arguments, got `{}`.",
62                self.required_parameters(),
63                args.len()
64            ));
65        }
66
67        // Check each parameter type
68        for (i, (arg, expected_type)) in args
69            .iter()
70            .zip(self.signature.parameters.iter())
71            .enumerate()
72        {
73            if !arg.satisfies(expected_type) {
74                typecheck.arguments.push((
75                    i,
76                    format!(
77                        "Argument of type `{arg}` is not assignable to parameter of type `{expected_type}`.",
78                    ),
79                ));
80            }
81        }
82
83        typecheck
84    }
85
86    fn call(&self, args: Arguments) -> anyhow::Result<Variable> {
87        (&self.implementation)(args)
88    }
89
90    fn param_type(&self, index: usize) -> Option<VariableType> {
91        self.signature.parameters.get(index).cloned()
92    }
93
94    fn param_type_str(&self, index: usize) -> String {
95        self.signature
96            .parameters
97            .get(index)
98            .map(|x| x.to_string())
99            .unwrap_or_else(|| "never".to_string())
100    }
101
102    fn return_type(&self) -> VariableType {
103        self.signature.return_type.clone()
104    }
105
106    fn return_type_str(&self) -> String {
107        self.signature.return_type.to_string()
108    }
109}
110
111#[derive(Clone)]
112pub struct CompositeFunction {
113    pub signatures: Vec<FunctionSignature>,
114    pub implementation: Rc<dyn Fn(Arguments) -> anyhow::Result<Variable>>,
115}
116
117impl FunctionDefinition for CompositeFunction {
118    fn required_parameters(&self) -> usize {
119        self.signatures
120            .iter()
121            .map(|x| x.parameters.len())
122            .min()
123            .unwrap_or(0)
124    }
125
126    fn optional_parameters(&self) -> usize {
127        let required_params = self.required_parameters();
128        let max = self
129            .signatures
130            .iter()
131            .map(|x| x.parameters.len())
132            .max()
133            .unwrap_or(0);
134
135        max - required_params
136    }
137
138    fn check_types(&self, args: &[VariableType]) -> FunctionTypecheck {
139        let mut typecheck = FunctionTypecheck::default();
140        if self.signatures.is_empty() {
141            typecheck.general = Some("No implementation".to_string());
142            return typecheck;
143        }
144
145        let required_params = self.required_parameters();
146        let optional_params = self.optional_parameters();
147        let total_params = required_params + optional_params;
148
149        if args.len() < required_params || args.len() > total_params {
150            typecheck.general = Some(format!(
151                "Expected `{required_params} - {total_params}` arguments, got `{}`.",
152                args.len()
153            ))
154        }
155
156        for signature in &self.signatures {
157            let all_match = args
158                .iter()
159                .zip(signature.parameters.iter())
160                .all(|(arg, param)| arg.satisfies(param));
161            if all_match {
162                typecheck.return_type = signature.return_type.clone();
163                return typecheck;
164            }
165        }
166
167        for (i, arg) in args.iter().enumerate() {
168            let possible_types: Vec<&VariableType> = self
169                .signatures
170                .iter()
171                .filter_map(|sig| sig.parameters.get(i))
172                .collect();
173
174            if !possible_types.iter().any(|param| arg.satisfies(param)) {
175                let type_union = self.param_type_str(i);
176                typecheck.arguments.push((
177                    i,
178                    format!(
179                        "Argument of type `{arg}` is not assignable to parameter of type `{type_union}`.",
180                    ),
181                ))
182            }
183        }
184
185        let available_signatures = self
186            .signatures
187            .iter()
188            .map(|sig| {
189                let param_list = sig
190                    .parameters
191                    .iter()
192                    .map(|x| x.to_string())
193                    .collect::<Vec<_>>()
194                    .join(", ");
195                format!("`({param_list}) -> {}`", sig.return_type)
196            })
197            .collect::<Vec<_>>()
198            .join("\n");
199        typecheck.general = Some(format!("No function overload matches provided arguments. Available overloads:\n{available_signatures}"));
200
201        typecheck
202    }
203
204    fn call(&self, args: Arguments) -> anyhow::Result<Variable> {
205        (&self.implementation)(args)
206    }
207
208    fn param_type(&self, index: usize) -> Option<VariableType> {
209        self.signatures
210            .iter()
211            .filter_map(|sig| sig.parameters.get(index))
212            .cloned()
213            .reduce(|a, b| a.merge(&b))
214    }
215
216    fn param_type_str(&self, index: usize) -> String {
217        let possible_types: Vec<String> = self
218            .signatures
219            .iter()
220            .filter_map(|sig| sig.parameters.get(index))
221            .map(|x| x.to_string())
222            .collect();
223        if possible_types.is_empty() {
224            return String::from("never");
225        }
226
227        let is_optional = possible_types.len() != self.signatures.len();
228        let possible_types: Vec<String> = possible_types
229            .into_iter()
230            .collect::<HashSet<_>>()
231            .into_iter()
232            .collect();
233
234        let type_union = possible_types.join(" | ");
235        if is_optional {
236            return format!("Optional<{type_union}>");
237        }
238
239        type_union
240    }
241
242    fn return_type(&self) -> VariableType {
243        self.signatures
244            .iter()
245            .map(|sig| &sig.return_type)
246            .cloned()
247            .reduce(|a, b| a.merge(&b))
248            .unwrap_or(VariableType::Null)
249    }
250
251    fn return_type_str(&self) -> String {
252        let possible_types: Vec<String> = self
253            .signatures
254            .iter()
255            .map(|sig| sig.return_type.clone())
256            .map(|x| x.to_string())
257            .collect();
258        if possible_types.is_empty() {
259            return String::from("never");
260        }
261
262        possible_types
263            .into_iter()
264            .collect::<HashSet<_>>()
265            .into_iter()
266            .collect::<Vec<_>>()
267            .join(" | ")
268    }
269}