ryna/
inference.rs

1use std::collections::HashMap;
2
3use colored::Colorize;
4
5use crate::compilation::RynaError;
6use crate::context::RynaContext;
7use crate::interfaces::InterfaceConstraint;
8use crate::interfaces::ITERABLE_ID;
9use crate::parser::Location;
10use crate::parser::RynaExpr;
11use crate::functions::*;
12use crate::operations::*;
13use crate::types::Type;
14
15impl RynaContext {
16    pub fn get_first_unary_op(&self, id: usize, arg_type: Type, call_templates: Option<Vec<Type>>, sub_t: bool, l: &Location) -> Result<(usize, Type, bool, Vec<Type>), RynaError> {
17        if let Operator::Unary{operations, ..} = &self.unary_ops[id] {
18            'outer: for (i, op_ov) in operations.iter().enumerate() {
19                if let (true, subs) = arg_type.bindable_to_subtitutions(&op_ov.args, self) { // Take first that matches
20                    if let Some(call_t) = call_templates {
21                        for (i, t) in call_t.iter().enumerate() {
22                            if let Some(s_t) = subs.get(&i) {
23                                if t != s_t {
24                                    break 'outer;
25                                }   
26                            }
27                        }
28                    }
29                    
30                    let t_args = (0..op_ov.templates).map(|i| subs.get(&i).cloned().unwrap_or(Type::TemplateParam(i, vec!()))).collect();
31                    return Ok((i, if sub_t { op_ov.ret.sub_templates(&subs) } else { op_ov.ret.clone() }, op_ov.operation.is_some(), t_args));
32                }
33            }
34        }
35
36        if let Operator::Unary{representation, prefix, ..} = &self.unary_ops[id] {
37            if *prefix {
38                Err(RynaError::compiler_error(format!(
39                    "Unable to get unary operator overload for {}({})",
40                    representation,
41                    arg_type.get_name(self)
42                ), l, vec!()))
43
44            } else {
45                Err(RynaError::compiler_error(format!(
46                    "Unable to get unary operator overload for ({}){}",
47                    arg_type.get_name(self),
48                    representation
49                ), l, vec!()))
50            }
51
52        } else {
53            unreachable!()
54        }
55    }
56
57    pub fn is_unary_op_ambiguous(&self, id: usize, arg_type: Type) -> Option<Vec<(Type, Type)>> {
58        if let Operator::Unary{operations, ..} = &self.unary_ops[id] {
59            let overloads = operations.iter()
60                            .map(|op_ov| (op_ov.args.clone(), op_ov.ret.clone()))
61                            .filter(|(a, _)| arg_type.bindable_to(a, self)).collect::<Vec<_>>();
62
63            // Return Some(overloads) if the call is ambiguous, else return None
64            if overloads.len() > 1 {
65                return Some(overloads);
66
67            } else {
68                return None;
69            }
70        }
71
72        unreachable!();
73    }
74
75    pub fn get_first_binary_op(&self, id: usize, a_type: Type, b_type: Type, call_templates: Option<Vec<Type>>, sub_t: bool, l: &Location) -> Result<(usize, Type, bool, Vec<Type>), RynaError> {
76        let t = Type::And(vec!(a_type.clone(), b_type.clone()));
77
78        if let Operator::Binary{operations, ..} = &self.binary_ops[id] {
79            'outer: for (i, op_ov) in operations.iter().enumerate() {
80                if let (true, subs) = t.bindable_to_subtitutions(&op_ov.args, self) { // Take first that matches
81                    if let Some(call_t) = call_templates {
82                        for (i, t) in call_t.iter().enumerate() {
83                            if let Some(s_t) = subs.get(&i) {
84                                if t != s_t {
85                                    break 'outer;
86                                }   
87                            }
88                        }
89                    }
90
91                    let t_args = (0..op_ov.templates).map(|i| subs.get(&i).cloned().unwrap_or(Type::TemplateParam(i, vec!()))).collect();
92                    return Ok((i, if sub_t { op_ov.ret.sub_templates(&subs) } else { op_ov.ret.clone() }, op_ov.operation.is_some(), t_args));
93                }
94            }
95        }
96
97        if let Operator::Binary{representation, ..} = &self.binary_ops[id] {
98            Err(RynaError::compiler_error(format!(
99                "Unable to get binary operator overload for ({}){}({})",
100                a_type.get_name(self),
101                representation,
102                b_type.get_name(self)
103            ), l, vec!()))
104
105        } else {
106            unreachable!()
107        }
108    }
109
110    pub fn is_binary_op_ambiguous(&self, id: usize, a_type: Type, b_type: Type) -> Option<Vec<(Type, Type, Type)>> {
111        let t = Type::And(vec!(a_type, b_type));
112
113        if let Operator::Binary{operations, ..} = &self.binary_ops[id] {
114            let overloads = operations.iter()
115                            .filter(|op_ov| t.bindable_to(&op_ov.args, self))
116                            .map(|op_ov| {
117                                if let Type::And(t) = &op_ov.args {
118                                    (t[0].clone(), t[1].clone(), op_ov.ret.clone())
119
120                                } else {
121                                    unreachable!()
122                                }
123                            })
124                            .collect::<Vec<_>>();
125
126            // Return Some(overloads) if the call is ambiguous, else return None
127            if overloads.len() > 1 {
128                return Some(overloads);
129
130            } else {
131                return None;
132            }
133        }
134
135        unreachable!();
136    }
137
138    pub fn get_first_nary_op(&self, id: usize, a_type: Type, b_type: Vec<Type>, call_templates: Option<Vec<Type>>, sub_t: bool, l: &Location) -> Result<(usize, Type, bool, Vec<Type>), RynaError> {
139        let mut arg_types = vec!(a_type.clone());
140        arg_types.extend(b_type.iter().cloned());
141
142        let t = Type::And(arg_types.clone());
143
144        if let Operator::Nary{operations, ..} = &self.nary_ops[id] {
145            'outer: for (i, op_ov) in operations.iter().enumerate() {
146                if let (true, subs) = t.bindable_to_subtitutions(&op_ov.args, self) { // Take first that matches
147                    if let Some(call_t) = call_templates {
148                        for (i, t) in call_t.iter().enumerate() {
149                            if let Some(s_t) = subs.get(&i) {
150                                if t != s_t {
151                                    break 'outer;
152                                }   
153                            }
154                        }
155                    }
156
157                    let t_args = (0..op_ov.templates).map(|i| subs.get(&i).cloned().unwrap_or(Type::TemplateParam(i, vec!()))).collect();
158                    return Ok((i, if sub_t { op_ov.ret.sub_templates(&subs) } else { op_ov.ret.clone() }, op_ov.operation.is_some(), t_args));
159                }
160            }
161        }
162
163        if let Operator::Nary{open_rep, close_rep, ..} = &self.nary_ops[id] {
164            Err(RynaError::compiler_error(format!(
165                "Unable to get n-ary operator overload for {}{}{}{}",
166                a_type.get_name(self),
167                open_rep,
168                b_type.iter().map(|i| i.get_name(self)).collect::<Vec<_>>().join(", "),
169                close_rep
170            ), l, vec!()))
171
172        } else {
173            unreachable!()
174        }
175    }
176
177    pub fn is_nary_op_ambiguous(&self, id: usize, a_type: Type, b_type: Vec<Type>) -> Option<Vec<(Type, Vec<Type>, Type)>> {
178        let mut arg_types = vec!(a_type.clone());
179        arg_types.extend(b_type.iter().cloned());
180
181        let t = Type::And(arg_types);
182        
183        if let Operator::Nary{operations, ..} = &self.nary_ops[id] {
184            let overloads = operations.iter()
185                            .filter(|op_ov| t.bindable_to(&op_ov.args, self))
186                            .map(|op_ov| {
187                                if let Type::And(t) = &op_ov.args {
188                                    (t[0].clone(), t[1..].to_vec(), op_ov.ret.clone())
189
190                                } else {
191                                    unreachable!()
192                                }
193                            })
194                            .collect::<Vec<_>>();
195
196            // Return Some(overloads) if the call is ambiguous, else return None
197            if overloads.len() > 1 {
198                return Some(overloads);
199
200            } else {
201                return None;
202            }
203        }
204
205        unreachable!();
206    }
207
208    pub fn get_first_function_overload(&self, id: usize, arg_type: Vec<Type>, call_templates: Option<Vec<Type>>, sub_t: bool, l: &Location) -> Result<(usize, Type, bool, Vec<Type>), RynaError> {
209        let t = Type::And(arg_type.clone());
210
211        'outer: for (i, f_ov) in self.functions[id].overloads.iter().enumerate() {
212            if let (true, subs) = t.bindable_to_subtitutions(&f_ov.args, self) { // Take first that matches
213                if let Some(call_t) = &call_templates {
214                    for (i, t) in call_t.iter().enumerate() {
215                        if let Some(s_t) = subs.get(&i) {
216                            if t != s_t {
217                                break 'outer;
218                            }   
219                        }
220                    }
221                }
222                
223                let t_args = (0..f_ov.templates).map(|i| subs.get(&i).cloned().unwrap_or(Type::TemplateParam(i, vec!()))).collect();
224                return Ok((i, if sub_t { f_ov.ret.sub_templates(&subs) } else { f_ov.ret.clone() }, f_ov.function.is_some(), t_args));
225            }
226        }
227
228        Err(RynaError::compiler_error(format!(
229            "Unable to get function overload for {}{}({})",
230            self.functions[id].name.green(),
231            if call_templates.is_none() || call_templates.as_ref().unwrap().is_empty() { 
232                "".into() 
233            } else { 
234                format!("<{}>", call_templates.unwrap().iter().map(|i| i.get_name(self)).collect::<Vec<_>>().join(", ")) 
235            },
236            arg_type.iter().map(|i| i.get_name(self)).collect::<Vec<_>>().join(", ")
237        ), l, vec!()))
238    }
239
240    pub fn is_function_overload_ambiguous(&self, id: usize, arg_type: Vec<Type>) -> Option<Vec<(Type, Type)>> {
241        let t = Type::And(arg_type);
242
243        let overloads = self.functions[id].overloads.iter()
244                            .map(|f_ov| (f_ov.args.clone(), f_ov.ret.clone()))
245                            .filter(|(a, _)| t.bindable_to(a, self)).collect::<Vec<_>>();
246
247        // Return Some(overloads) if the call is ambiguous, else return None
248        if overloads.len() > 1 {
249            Some(overloads)
250
251        } else {
252            None
253        }
254    }
255
256    pub fn implements_iterable(&self, container_type: &Type) -> bool {
257        for i in &self.interface_impls {
258            if i.interface_id == ITERABLE_ID && container_type.bindable_to(&i.interface_type, self) {
259                return true;
260            }
261        }
262
263        false
264    }
265
266    pub fn get_iterator_type(&self, container_type: &Type, l: &Location) -> Result<(usize, Type, bool, Vec<Type>), RynaError> {
267        self.get_first_function_overload(ITERATOR_FUNC_ID, vec!(container_type.clone()), None, true, l)
268    }
269
270    pub fn get_iterator_output_type(&self, iterator_type: &Type, l: &Location) -> Result<(usize, Type, bool, Vec<Type>), RynaError> {
271        let it_mut = Type::MutRef(Box::new(iterator_type.clone()));
272
273        self.get_first_function_overload(NEXT_FUNC_ID, vec!(it_mut.clone()), None, true, l)
274    }
275
276    pub fn implements_destroyable(&self, t: &Type) -> bool {
277        let dint_id = self.get_interface_id("Destroyable".into()).unwrap();
278
279        self.implements_interface(t, &InterfaceConstraint::new(dint_id, vec!()), &mut HashMap::new(), &mut HashMap::new())
280    }
281
282    pub fn infer_type(&self, expr: &RynaExpr) -> Result<Type, RynaError> {
283        return match expr {
284            RynaExpr::Literal(_, obj) => Ok(obj.get_type()),
285
286            RynaExpr::DoBlock(_, _, t) => Ok(t.clone()),
287
288            RynaExpr::AttributeAccess(_, e, att_idx) => {
289                use Type::*;
290
291                let arg_type = self.infer_type(e)?;
292
293                if let Basic(id) | Template(id, _) = arg_type.deref_type() {
294                    let mut att_type = self.type_templates[*id].attributes[*att_idx].1.clone();
295
296                    // Subtitute template parameters if needed
297                    if let Template(_, ts) = arg_type.deref_type() {
298                        att_type = att_type.sub_templates(&ts.iter().cloned().enumerate().collect());
299                    }
300                    
301                    return match (&arg_type, &att_type) {
302                        (MutRef(_), Ref(_) | MutRef(_)) => Ok(att_type.clone()),
303                        (MutRef(_), _) => Ok(MutRef(Box::new(att_type.clone()))),
304
305                        (Ref(_), MutRef(i)) => Ok(Ref(i.clone())),
306                        (Ref(_), Ref(_)) => Ok(att_type.clone()),
307                        (Ref(_), _) => Ok(Ref(Box::new(att_type.clone()))),
308
309                        (_, _) => Ok(att_type.clone())
310                    };
311
312                } else {
313                    unreachable!()
314                }
315            }
316
317            RynaExpr::CompiledLambda(_, _, _, a, r, _) => Ok(
318                if a.len() == 1 && !matches!(a[0].1, Type::And(..)) {
319                    Type::Function(
320                        Box::new(a[0].1.clone()),
321                        Box::new(r.clone())
322                    )
323
324                } else {
325                    Type::Function(
326                        Box::new(Type::And(a.iter().map(|(_, t)| t).cloned().collect())),
327                        Box::new(r.clone())
328                    )
329                }
330            ),
331            
332            RynaExpr::Tuple(_, e) => {
333                let mut args = vec!();
334
335                for i in e {
336                    args.push(self.infer_type(i)?);
337                }
338
339                Ok(Type::And(args))
340            },
341
342            RynaExpr::Variable(_, _, _, t, _) => {
343                match t {
344                    Type::Ref(_) | Type::MutRef(_) => Ok(t.clone()),
345                    t => Ok(Type::MutRef(Box::new(t.clone())))
346                }
347            },
348
349            RynaExpr::UnaryOperation(l, id, t, a) => {
350                let t_sub_call = t.iter().cloned().enumerate().collect();
351                let args_type = self.infer_type(a)?.sub_templates(&t_sub_call);
352
353                let (_, r, _, subs) = self.get_first_unary_op(*id, args_type, None, false, l)?;
354
355                let t_sub_ov = subs.iter().cloned().enumerate().collect();
356
357                return Ok(r.sub_templates(&t_sub_ov).sub_templates(&t_sub_call));
358            },
359
360            RynaExpr::BinaryOperation(l, id, t, a, b) => {
361                let t_sub_call = t.iter().cloned().enumerate().collect();
362                let a_type = self.infer_type(a)?.sub_templates(&t_sub_call);
363                let b_type = self.infer_type(b)?.sub_templates(&t_sub_call);
364
365                let (_, r, _, subs) = self.get_first_binary_op(*id, a_type, b_type, None, false, l)?;
366
367                let t_sub_ov = subs.iter().cloned().enumerate().collect();
368
369                return Ok(r.sub_templates(&t_sub_ov).sub_templates(&t_sub_call));
370            },
371
372            RynaExpr::NaryOperation(l, id, t, a, b) => {
373                let t_sub_call = t.iter().cloned().enumerate().collect();
374                let a_type = self.infer_type(a)?.sub_templates(&t_sub_call);
375                let b_type = b.iter().map(|i| self.infer_type(i))
376                                     .collect::<Result<Vec<_>, RynaError>>()?
377                                     .into_iter()
378                                     .map(|i| i.sub_templates(&t_sub_call))
379                                     .collect();
380
381                let (_, r, _, subs) = self.get_first_nary_op(*id, a_type, b_type, None, false, l)?;
382
383                let t_sub_ov = subs.iter().cloned().enumerate().collect();
384
385                return Ok(r.sub_templates(&t_sub_ov).sub_templates(&t_sub_call));
386            },
387
388            RynaExpr::FunctionCall(l, id, t, args) => {
389                let t_sub_call = t.iter().cloned().enumerate().collect();
390                let arg_types = args.iter().map(|i| self.infer_type(i))
391                                           .collect::<Result<Vec<_>, RynaError>>()?
392                                           .into_iter()
393                                           .map(|i| i.sub_templates(&t_sub_call))
394                                           .collect();
395
396                let (_, r, _, subs) = self.get_first_function_overload(*id, arg_types, None, true, l)?;
397
398                let t_sub_ov = subs.iter().cloned().enumerate().collect();
399
400                return Ok(r.sub_templates(&t_sub_ov).sub_templates(&t_sub_call));
401            }
402
403            RynaExpr::QualifiedName(l, _, Some(id)) => {
404                let func = &self.functions[*id];
405
406                if func.overloads.len() == 1 {
407                    let ov = &func.overloads[0];
408
409                    if ov.templates != 0 {
410                        return Err(RynaError::compiler_error(
411                            format!(
412                                "Implicit lambda for function with name {} cannot be formed from generic overload",
413                                func.name.green()
414                            ), 
415                            l, vec!()
416                        ));
417                    }
418                    
419                    if let Type::And(a) = &ov.args {
420                        if a.len() == 1 {
421                            return Ok(Type::Function(
422                                Box::new(a[0].clone()),
423                                Box::new(ov.ret.clone())
424                            ))
425        
426                        } else {
427                            return Ok(Type::Function(
428                                Box::new(Type::And(a.clone())),
429                                Box::new(ov.ret.clone())
430                            ))
431                        }
432                    }
433
434                    return Ok(Type::Function(
435                        Box::new(ov.args.clone()),
436                        Box::new(ov.ret.clone())
437                    ))
438                }
439
440                return Err(RynaError::compiler_error(
441                    format!(
442                        "Implicit lambda for function with name {} is ambiguous (found {} overloads)",
443                        func.name.green(),
444                        func.overloads.len()
445                    ), 
446                    l, vec!()
447                ));
448            }
449
450            RynaExpr::QualifiedName(l, _, _) |
451            RynaExpr::AttributeAssignment(l, _, _, _) |
452            RynaExpr::CompiledVariableDefinition(l, _, _, _, _, _) |
453            RynaExpr::CompiledVariableAssignment(l, _, _, _, _, _) |
454            RynaExpr::CompiledFor(l, _, _, _, _, _) |
455            RynaExpr::Macro(l, _, _, _, _, _) |
456            RynaExpr::Lambda(l, _, _, _, _) |
457            RynaExpr::NameReference(l, _) |
458            RynaExpr::VariableDefinition(l, _, _, _) |
459            RynaExpr::VariableAssignment(l, _, _) |
460            RynaExpr::FunctionDefinition(l, _, _, _, _, _, _) |
461            RynaExpr::PrefixOperatorDefinition(l, _, _) |
462            RynaExpr::PostfixOperatorDefinition(l, _, _) |
463            RynaExpr::BinaryOperatorDefinition(l, _, _, _) |
464            RynaExpr::NaryOperatorDefinition(l, _, _, _) |
465            RynaExpr::ClassDefinition(l, _, _, _, _, _, _) |
466            RynaExpr::InterfaceDefinition(l, _, _, _, _, _, _, _) |
467            RynaExpr::InterfaceImplementation(l, _, _, _, _) |
468            RynaExpr::PrefixOperationDefinition(l, _, _, _, _, _, _, _) |
469            RynaExpr::PostfixOperationDefinition(l, _, _, _, _, _, _, _) |
470            RynaExpr::BinaryOperationDefinition(l, _, _, _, _, _, _, _) |
471            RynaExpr::NaryOperationDefinition(l, _, _, _, _, _, _, _) |
472            RynaExpr::If(l, _, _, _, _) |
473            RynaExpr::Break(l) |
474            RynaExpr::Continue(l) |
475            RynaExpr::While(l, _, _) |
476            RynaExpr::For(l, _, _, _) |
477            RynaExpr::Return(l, _) => Err(RynaError::compiler_error(
478                "Expression cannot be evaluated to a type".into(), 
479                l, vec!()
480            ))
481        };
482    }
483}