1use rusttyc::{
2 types::Arity, Constructable, Partial, TcErr, TcKey, TcVar, TypeChecker, Variant as TcVariant, VarlessTypeChecker,
3};
4use std::cmp::max;
5use std::convert::TryInto;
6use std::hash::Hash;
7
/// Concrete representation types that an abstract [`Variant`] can be reified into
/// (see the `Constructable` implementation for `Variant`).
///
/// Variant order is significant: the derived `PartialOrd`/`Ord` follow declaration order.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
enum Type {
    /// Signed 128-bit integer; target of every `Variant::Integer` of at most 128 bits.
    Int128,
    /// Fixed-point number with 64 integer and 64 fractional bits.
    FixedPointI64F64,
    /// Boolean value.
    Bool,
}
14
/// Abstract type information tracked by the type checker.
///
/// Variant order is significant: the derived `PartialOrd`/`Ord` follow declaration order.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
enum Variant {
    /// No information yet; the top element returned by `top()`.
    Any,
    /// Fixed-point number given as (integer bits, fractional bits).
    Fixed(u8, u8),
    /// Integer requiring the given number of bits.
    Integer(u8),
    /// Some numeric type, not yet decided between integer and fixed point.
    Numeric,
    /// Boolean value.
    Bool,
}
23
/// Newtype identifying a type-checker variable by index.
///
/// NOTE(review): nothing in this file constructs a `Variable`; `main` uses a var-less
/// checker. Confirm whether this type is still needed.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
struct Variable(usize);
26
27impl TcVar for Variable {}
30
31impl TcVariant for Variant {
32 type Err = String;
33
34 fn top() -> Self {
35 Self::Any
36 }
37
38 fn meet(lhs: Partial<Self>, rhs: Partial<Self>) -> Result<Partial<Self>, Self::Err> {
39 use Variant::*;
40 assert_eq!(lhs.least_arity, 0, "spurious child");
41 assert_eq!(rhs.least_arity, 0, "spurious child");
42 let variant = match (lhs.variant, rhs.variant) {
43 (Any, other) | (other, Any) => Ok(other),
44 (Integer(l), Integer(r)) => Ok(Integer(max(r, l))),
45 (Fixed(li, lf), Fixed(ri, rf)) => Ok(Fixed(max(li, ri), max(lf, rf))),
46 (Fixed(i, f), Integer(u)) | (Integer(u), Fixed(i, f)) if f == 0 => Ok(Integer(max(i, u))),
47 (Fixed(i, f), Integer(u)) | (Integer(u), Fixed(i, f)) => Ok(Fixed(max(i, u), f)),
48 (Bool, Bool) => Ok(Bool),
49 (Bool, _) | (_, Bool) => Err("bool can only be combined with bool"),
50 (Numeric, Integer(w)) | (Integer(w), Numeric) => Ok(Integer(w)),
51 (Numeric, Fixed(i, f)) | (Fixed(i, f), Numeric) => Ok(Fixed(i, f)),
52 (Numeric, Numeric) => Ok(Numeric),
53 }?;
54 Ok(Partial { variant, least_arity: 0 })
55 }
56
57 fn arity(&self) -> Arity {
58 Arity::Fixed(0)
59 }
60}
61
62#[derive(Clone, Copy, Debug)]
64enum ParamType {
65 ParamId(usize),
66 Abstract(Variant),
67}
68
69#[derive(Clone, Debug)]
70enum Expression {
71 Conditional {
75 cond: Box<Expression>,
76 cons: Box<Expression>,
77 alt: Box<Expression>,
78 },
79 PolyFn {
89 name: &'static str,
90 param_constraints: Vec<Option<Variant>>,
91 args: Vec<(ParamType, Expression)>,
92 returns: ParamType,
93 },
94 ConstInt(i128),
95 ConstBool(bool),
96 ConstFixed(i64, u64),
97}
98
99impl Constructable for Variant {
100 type Type = Type;
101
102 fn construct(&self, children: &[Self::Type]) -> Result<Self::Type, Self::Err> {
103 assert!(children.is_empty(), "spurious children");
104 use Variant::*;
105 match self {
106 Any => Err("Cannot reify `Any`.".to_string()),
107 Integer(w) if *w <= 128 => Ok(Type::Int128),
108 Integer(w) => Err(format!("Integer too wide, {}-bit not supported.", w)),
109 Fixed(i, f) if *i <= 64 && *f <= 64 => Ok(Type::FixedPointI64F64),
110 Fixed(i, f) => Err(format!("Fixed point number too wide, I{}F{} not supported.", i, f)),
111 Numeric => {
112 Err("Cannot reify a numeric value. Either define a default (int/fixed) or restrict type.".to_string())
113 }
114 Bool => Ok(Type::Bool),
115 }
116 }
117}
118
119fn tc_expr(tc: &mut VarlessTypeChecker<Variant>, expr: &Expression) -> Result<TcKey, TcErr<Variant>> {
124 use Expression::*;
125 let key_result = tc.new_term_key(); match expr {
127 ConstInt(c) => {
128 let width = (128 - c.leading_zeros()).try_into().unwrap();
129 tc.impose(key_result.concretizes_explicit(Variant::Integer(width)))?;
130 }
131 ConstFixed(i, f) => {
132 let int_width = (64 - i.leading_zeros()).try_into().unwrap();
133 let frac_width = (64 - f.leading_zeros()).try_into().unwrap();
134 tc.impose(key_result.concretizes_explicit(Variant::Fixed(int_width, frac_width)))?;
135 }
136 ConstBool(_) => tc.impose(key_result.concretizes_explicit(Variant::Bool))?,
137 Conditional { cond, cons, alt } => {
138 let key_cond = tc_expr(tc, cond)?;
139 let key_cons = tc_expr(tc, cons)?;
140 let key_alt = tc_expr(tc, alt)?;
141 tc.impose(key_cond.concretizes_explicit(Variant::Bool))?;
142 tc.impose(key_result.is_meet_of(key_cons, key_alt))?;
143 }
144 PolyFn { name: _, param_constraints, args, returns } => {
145 let params: Vec<(Option<Variant>, TcKey)> =
148 param_constraints.iter().map(|p| (*p, tc.new_term_key())).collect();
149 ¶ms;
150 for (arg_ty, arg_expr) in args {
151 let arg_key = tc_expr(tc, arg_expr)?;
152 match arg_ty {
153 ParamType::ParamId(id) => {
154 let (p_constr, p_key) = params[*id];
155 tc.impose(p_key.concretizes(arg_key))?;
158 if let Some(c) = p_constr {
159 tc.impose(arg_key.concretizes_explicit(c))?;
160 }
161 }
162 ParamType::Abstract(at) => tc.impose(arg_key.concretizes_explicit(*at))?,
163 };
164 }
165 match returns {
166 ParamType::Abstract(at) => tc.impose(key_result.concretizes_explicit(*at))?,
167 ParamType::ParamId(id) => {
168 let (constr, key) = params[*id];
169 if let Some(c) = constr {
170 tc.impose(key_result.concretizes_explicit(c))?;
171 }
172 tc.impose(key_result.equate_with(key))?;
173 }
174 }
175 }
176 }
177 Ok(key_result)
178}
179
180fn main() {
181 let expr = build_complex_expression_type_checks();
183 let mut tc: VarlessTypeChecker<Variant> = TypeChecker::without_vars();
185 let res = tc_expr(&mut tc, &expr).and_then(|key| tc.type_check().map(|tt| (key, tt)));
187 match res {
188 Ok((key, tt)) => {
189 let res_type = tt[&key];
190 assert_eq!(res_type, Type::FixedPointI64F64);
194 }
195 Err(_) => panic!("Unexpected type error!"),
196 }
197}
198
199fn build_complex_expression_type_checks() -> Expression {
200 use Expression::*;
201 let const27 = ConstFixed(2, 7);
203 let const3 = ConstInt(3);
204 let const43 = ConstFixed(4, 3);
205 let const_true = ConstBool(true);
206 let exponentiation = PolyFn {
207 name: "exponentiation", param_constraints: vec![Some(Variant::Numeric)],
209 args: vec![(ParamType::ParamId(0), const27), (ParamType::Abstract(Variant::Integer(1)), const3.clone())],
210 returns: ParamType::ParamId(0),
211 };
212 let addition = create_addition(exponentiation, const43);
213 Conditional { cond: Box::new(const_true), cons: Box::new(addition), alt: Box::new(const3) }
214}
215
216fn create_addition(lhs: Expression, rhs: Expression) -> Expression {
217 Expression::PolyFn {
218 name: "addition", param_constraints: vec![Some(Variant::Numeric)],
220 args: vec![(ParamType::ParamId(0), lhs), (ParamType::ParamId(0), rhs)],
221 returns: ParamType::ParamId(0),
222 }
223}