// parametrized_function/
// parametrized_function.rs

1use rusttyc::{
2    types::Arity, Constructable, Partial, TcErr, TcKey, TcVar, TypeChecker, Variant as TcVariant, VarlessTypeChecker,
3};
4use std::cmp::max;
5use std::convert::TryInto;
6use std::hash::Hash;
7
/// Concrete types that abstract variants are reified into once type checking succeeds.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
enum Type {
    /// 128-bit integer; every `Integer` variant up to 128 bits maps here.
    Int128,
    /// Fixed-point number; every `Fixed` variant with at most 64 integer and 64 fractional bits maps here.
    FixedPointI64F64,
    /// Boolean.
    Bool,
}
14
/// Abstract type information the type checker reasons about before reifying it into a `Type`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
enum Variant {
    /// No information at all; the top element.
    Any,
    /// Fixed-point number requiring the given number of integer and fractional bits.
    Fixed(u8, u8),
    /// Integer requiring the given number of bits.
    Integer(u8),
    /// Some numeric value, either integer or fixed-point.
    Numeric,
    /// Boolean.
    Bool,
}
23
/// Identifier of a program variable.
///
/// It implements `TcVar`, although this example type-checks with a `VarlessTypeChecker` and never
/// creates variable keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct Variable(usize);
26
27// ************ IMPLEMENTATION OF REQUIRED TRAITS ************ //
28
29impl TcVar for Variable {}
30
31impl TcVariant for Variant {
32    type Err = String;
33
34    fn top() -> Self {
35        Self::Any
36    }
37
38    fn meet(lhs: Partial<Self>, rhs: Partial<Self>) -> Result<Partial<Self>, Self::Err> {
39        use Variant::*;
40        assert_eq!(lhs.least_arity, 0, "spurious child");
41        assert_eq!(rhs.least_arity, 0, "spurious child");
42        let variant = match (lhs.variant, rhs.variant) {
43            (Any, other) | (other, Any) => Ok(other),
44            (Integer(l), Integer(r)) => Ok(Integer(max(r, l))),
45            (Fixed(li, lf), Fixed(ri, rf)) => Ok(Fixed(max(li, ri), max(lf, rf))),
46            (Fixed(i, f), Integer(u)) | (Integer(u), Fixed(i, f)) if f == 0 => Ok(Integer(max(i, u))),
47            (Fixed(i, f), Integer(u)) | (Integer(u), Fixed(i, f)) => Ok(Fixed(max(i, u), f)),
48            (Bool, Bool) => Ok(Bool),
49            (Bool, _) | (_, Bool) => Err("bool can only be combined with bool"),
50            (Numeric, Integer(w)) | (Integer(w), Numeric) => Ok(Integer(w)),
51            (Numeric, Fixed(i, f)) | (Fixed(i, f), Numeric) => Ok(Fixed(i, f)),
52            (Numeric, Numeric) => Ok(Numeric),
53        }?;
54        Ok(Partial { variant, least_arity: 0 })
55    }
56
57    fn arity(&self) -> Arity {
58        Arity::Fixed(0)
59    }
60}
61
62/// Represents a type in a parametrized function; either refers to a type parameter or is an abstract type.
63#[derive(Clone, Copy, Debug)]
64enum ParamType {
65    ParamId(usize),
66    Abstract(Variant),
67}
68
69#[derive(Clone, Debug)]
70enum Expression {
71    /// Conditional expression.
72    /// Requirement: `cond` more concrete than bool, `cons` and `alt` compatible.
73    /// Returns: meet of `cons` and `alt`.
74    Conditional {
75        cond: Box<Expression>,
76        cons: Box<Expression>,
77        alt: Box<Expression>,
78    },
79    /// Polymorphic function f<T_1: C_1, ..., T_n: C_n>(p_1: C_{n+1}, ..., p_m: C_{n+m}) -> T
80    /// Arguments:
81    ///     `name` for illustration,
82    ///     `param_constraints`: vector of parameters and their constraints if any, in this example vec![C_1, ..., C_n].
83    ///     `arg_types`: vector of type constraints on the arguments, in this example vec![C_{n+1}, ..., C_{n+m}]. May refer to parameters.
84    ///     `args`: vector containing all argument expressions.
85    ///     `returns`: return type. May refer to a parameter.
86    /// Requirement: each argument needs to comply with its constraint. If several arguments refer to the same parametric type, they need to be compliant.
87    /// Returns: `returns`.
88    PolyFn {
89        name: &'static str,
90        param_constraints: Vec<Option<Variant>>,
91        args: Vec<(ParamType, Expression)>,
92        returns: ParamType,
93    },
94    ConstInt(i128),
95    ConstBool(bool),
96    ConstFixed(i64, u64),
97}
98
99impl Constructable for Variant {
100    type Type = Type;
101
102    fn construct(&self, children: &[Self::Type]) -> Result<Self::Type, Self::Err> {
103        assert!(children.is_empty(), "spurious children");
104        use Variant::*;
105        match self {
106            Any => Err("Cannot reify `Any`.".to_string()),
107            Integer(w) if *w <= 128 => Ok(Type::Int128),
108            Integer(w) => Err(format!("Integer too wide, {}-bit not supported.", w)),
109            Fixed(i, f) if *i <= 64 && *f <= 64 => Ok(Type::FixedPointI64F64),
110            Fixed(i, f) => Err(format!("Fixed point number too wide, I{}F{} not supported.", i, f)),
111            Numeric => {
112                Err("Cannot reify a numeric value. Either define a default (int/fixed) or restrict type.".to_string())
113            }
114            Bool => Ok(Type::Bool),
115        }
116    }
117}
118
119/// This function traverses the expression tree.
120/// It creates keys on the fly.  This is not possible for many kinds of type systems, in which case the functions
121/// requires a context with a mapping of e.g. Variable -> Key.  The context can be built during a first pass over the
122/// tree.
123fn tc_expr(tc: &mut VarlessTypeChecker<Variant>, expr: &Expression) -> Result<TcKey, TcErr<Variant>> {
124    use Expression::*;
125    let key_result = tc.new_term_key(); // will be returned
126    match expr {
127        ConstInt(c) => {
128            let width = (128 - c.leading_zeros()).try_into().unwrap();
129            tc.impose(key_result.concretizes_explicit(Variant::Integer(width)))?;
130        }
131        ConstFixed(i, f) => {
132            let int_width = (64 - i.leading_zeros()).try_into().unwrap();
133            let frac_width = (64 - f.leading_zeros()).try_into().unwrap();
134            tc.impose(key_result.concretizes_explicit(Variant::Fixed(int_width, frac_width)))?;
135        }
136        ConstBool(_) => tc.impose(key_result.concretizes_explicit(Variant::Bool))?,
137        Conditional { cond, cons, alt } => {
138            let key_cond = tc_expr(tc, cond)?;
139            let key_cons = tc_expr(tc, cons)?;
140            let key_alt = tc_expr(tc, alt)?;
141            tc.impose(key_cond.concretizes_explicit(Variant::Bool))?;
142            tc.impose(key_result.is_meet_of(key_cons, key_alt))?;
143        }
144        PolyFn { name: _, param_constraints, args, returns } => {
145            // Note: The following line cannot be replaced by `vec![param_constraints.len(); tc.new_key()]` as this
146            // would copy the keys rather than creating new ones.
147            let params: Vec<(Option<Variant>, TcKey)> =
148                param_constraints.iter().map(|p| (*p, tc.new_term_key())).collect();
149            &params;
150            for (arg_ty, arg_expr) in args {
151                let arg_key = tc_expr(tc, arg_expr)?;
152                match arg_ty {
153                    ParamType::ParamId(id) => {
154                        let (p_constr, p_key) = params[*id];
155                        // We need to enforce that the parameter is more concrete than the passed argument and that the
156                        // passed argument satisfies the constraints imposed on the parametric type.
157                        tc.impose(p_key.concretizes(arg_key))?;
158                        if let Some(c) = p_constr {
159                            tc.impose(arg_key.concretizes_explicit(c))?;
160                        }
161                    }
162                    ParamType::Abstract(at) => tc.impose(arg_key.concretizes_explicit(*at))?,
163                };
164            }
165            match returns {
166                ParamType::Abstract(at) => tc.impose(key_result.concretizes_explicit(*at))?,
167                ParamType::ParamId(id) => {
168                    let (constr, key) = params[*id];
169                    if let Some(c) = constr {
170                        tc.impose(key_result.concretizes_explicit(c))?;
171                    }
172                    tc.impose(key_result.equate_with(key))?;
173                }
174            }
175        }
176    }
177    Ok(key_result)
178}
179
180fn main() {
181    // Build an expression to type-check.
182    let expr = build_complex_expression_type_checks();
183    // Create an empty type checker.
184    let mut tc: VarlessTypeChecker<Variant> = TypeChecker::without_vars();
185    // Type check the expression.
186    let res = tc_expr(&mut tc, &expr).and_then(|key| tc.type_check().map(|tt| (key, tt)));
187    match res {
188        Ok((key, tt)) => {
189            let res_type = tt[&key];
190            // Expression `if true then 2.7^3 + 4.3 else 3` should yield type Fixed(3, 3) because the addition requires a
191            // Fixed(2,3) and a Fixed(3,3), which results in a Fixed(3, 3).
192            // Constructing an actual type yields Type::FixedPointI64F64.
193            assert_eq!(res_type, Type::FixedPointI64F64);
194        }
195        Err(_) => panic!("Unexpected type error!"),
196    }
197}
198
199fn build_complex_expression_type_checks() -> Expression {
200    use Expression::*;
201    // Build expression `if true then 2.7^3 + 4.3 else 3`
202    let const27 = ConstFixed(2, 7);
203    let const3 = ConstInt(3);
204    let const43 = ConstFixed(4, 3);
205    let const_true = ConstBool(true);
206    let exponentiation = PolyFn {
207        name: "exponentiation", // Signature: +<T: Numeric>(T, Int<1>) -> T
208        param_constraints: vec![Some(Variant::Numeric)],
209        args: vec![(ParamType::ParamId(0), const27), (ParamType::Abstract(Variant::Integer(1)), const3.clone())],
210        returns: ParamType::ParamId(0),
211    };
212    let addition = create_addition(exponentiation, const43);
213    Conditional { cond: Box::new(const_true), cons: Box::new(addition), alt: Box::new(const3) }
214}
215
216fn create_addition(lhs: Expression, rhs: Expression) -> Expression {
217    Expression::PolyFn {
218        name: "addition", // Signature: +<T: Numeric>(T, T) -> T
219        param_constraints: vec![Some(Variant::Numeric)],
220        args: vec![(ParamType::ParamId(0), lhs), (ParamType::ParamId(0), rhs)],
221        returns: ParamType::ParamId(0),
222    }
223}