Skip to main content

spade_hir/
expression.rs

1use std::borrow::BorrowMut;
2
3use crate::{ConstGenericWithId, Pattern, TypeExpression, TypeParam, UnitKind};
4
5use super::{Block, NameID};
6use num::{BigInt, BigUint};
7use serde::{Deserialize, Serialize};
8use spade_common::{
9    id_tracker::ExprID,
10    location_info::Loc,
11    name::{Identifier, Path},
12    num_ext::InfallibleToBigInt,
13};
14
/// Binary operators that can appear in an [`ExprKind::BinaryOperator`] expression.
#[derive(Clone, Copy, PartialEq, Debug, Serialize, Deserialize)]
pub enum BinaryOperator {
    // Arithmetic
    Add,
    Sub,
    Mul,
    Div,
    Mod,
    // Comparison
    Eq,
    NotEq,
    Gt,
    Lt,
    Ge,
    Le,
    // Shifts
    LeftShift,
    RightShift,
    ArithmeticRightShift,
    // Logical (boolean) operators
    LogicalAnd,
    LogicalOr,
    LogicalXor,
    // Bitwise operators
    BitwiseOr,
    BitwiseAnd,
    BitwiseXor,
}
38
39impl std::fmt::Display for BinaryOperator {
40    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41        match self {
42            BinaryOperator::Add => write!(f, "+"),
43            BinaryOperator::Sub => write!(f, "-"),
44            BinaryOperator::Mul => write!(f, "*"),
45            BinaryOperator::Div => write!(f, "/"),
46            BinaryOperator::Mod => write!(f, "%"),
47            BinaryOperator::Eq => write!(f, "=="),
48            BinaryOperator::NotEq => write!(f, "!="),
49            BinaryOperator::Gt => write!(f, ">"),
50            BinaryOperator::Lt => write!(f, "<"),
51            BinaryOperator::Ge => write!(f, ">="),
52            BinaryOperator::Le => write!(f, "<="),
53            BinaryOperator::LeftShift => write!(f, ">>"),
54            BinaryOperator::RightShift => write!(f, "<<"),
55            BinaryOperator::ArithmeticRightShift => write!(f, ">>>"),
56            BinaryOperator::LogicalAnd => write!(f, "&&"),
57            BinaryOperator::LogicalOr => write!(f, "||"),
58            BinaryOperator::LogicalXor => write!(f, "^^"),
59            BinaryOperator::BitwiseOr => write!(f, "|"),
60            BinaryOperator::BitwiseAnd => write!(f, "&"),
61            BinaryOperator::BitwiseXor => write!(f, "^"),
62        }
63    }
64}
65
/// Unary operators that can appear in an [`ExprKind::UnaryOperator`] expression.
#[derive(Clone, Copy, PartialEq, Debug, Serialize, Deserialize)]
pub enum UnaryOperator {
    /// Negation, written `-`.
    Sub,
    /// Logical not, written `!`.
    Not,
    /// Bitwise not, written `~`.
    BitwiseNot,
    /// Dereference, written `*`.
    Dereference,
    /// Take a reference, written `&`.
    Reference,
}
74
75impl std::fmt::Display for UnaryOperator {
76    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77        match self {
78            UnaryOperator::Sub => write!(f, "-"),
79            UnaryOperator::Not => write!(f, "!"),
80            UnaryOperator::BitwiseNot => write!(f, "~"),
81            UnaryOperator::Dereference => write!(f, "*"),
82            UnaryOperator::Reference => write!(f, "&"),
83        }
84    }
85}
86
/// Named arguments are used both for type parameters in turbofishes and in argument lists.
/// `T` is the type of the right-hand side of a binding: an [`Expression`] in a call's
/// argument list, or a [`TypeExpression`] in a turbofish.
#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
pub enum NamedArgument<T> {
    /// Binds the argument named by the identifier (the LHS) in the outer scope to the expression
    Full(Loc<Identifier>, Loc<T>),
    /// Binds a local variable to an argument with the same name
    Short(Loc<Identifier>, Loc<T>),
}
96
/// Specifies how an argument is bound. Mainly used for error reporting without
/// code duplication
#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
pub enum ArgumentKind {
    /// Bound by its position in the argument list.
    Positional,
    /// Bound by explicit name; presumably corresponds to [`NamedArgument::Full`].
    Named,
    /// Bound by name shorthand; presumably corresponds to [`NamedArgument::Short`].
    ShortNamed,
}
105
/// The arguments of a call or turbofish. Either every argument is given by name, or every
/// argument is given by position.
#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
pub enum ArgumentList<T> {
    /// Arguments bound by name, in the order they were written.
    Named(Vec<NamedArgument<T>>),
    /// Arguments bound by position.
    Positional(Vec<Loc<T>>),
}
111
112impl<T> ArgumentList<T> {
113    pub fn empty() -> Self {
114        Self::Positional(vec![])
115    }
116
117    pub fn expressions(&self) -> Vec<&Loc<T>> {
118        match self {
119            ArgumentList::Named(n) => n
120                .iter()
121                .map(|arg| match &arg {
122                    NamedArgument::Full(_, expr) => expr,
123                    NamedArgument::Short(_, expr) => expr,
124                })
125                .collect(),
126            ArgumentList::Positional(arg) => arg.iter().collect(),
127        }
128    }
129    pub fn expressions_mut(&mut self) -> Vec<&mut Loc<T>> {
130        match self {
131            ArgumentList::Named(n) => n
132                .iter_mut()
133                .map(|arg| match arg {
134                    NamedArgument::Full(_, expr) => expr,
135                    NamedArgument::Short(_, expr) => expr,
136                })
137                .collect(),
138            ArgumentList::Positional(arg) => arg.iter_mut().collect(),
139        }
140    }
141}
142
/// A single argument bound to the parameter named `target`.
/// NOTE(review): presumably produced when matching an [`ArgumentList`] against a unit's
/// parameters — confirm against callers.
#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
pub struct Argument<T> {
    /// The name of the parameter this argument is bound to.
    pub target: Loc<Identifier>,
    /// The value bound to `target`.
    pub value: Loc<T>,
    /// How the argument was written (positionally, by name, or by shorthand name).
    pub kind: ArgumentKind,
}
149
// FIXME: Migrate entity, pipeline and fn instantiation to this
/// The kind of unit a call instantiates: a function, an entity or a pipeline.
#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
pub enum CallKind {
    /// A function call.
    Function,
    /// An entity instantiation. NOTE(review): the `Loc<()>` presumably marks the `inst`
    /// keyword — confirm.
    Entity(Loc<()>),
    /// A pipeline instantiation.
    Pipeline {
        /// NOTE(review): presumably the location of the `inst` keyword — confirm.
        inst_loc: Loc<()>,
        /// The pipeline depth written at the instantiation site.
        depth: Loc<TypeExpression>,
        /// An expression ID for which the type inferer will infer the depth of the instantiated
        /// pipeline, i.e. inst(<this>)
        depth_typeexpr_id: ExprID,
    },
}
163
/// A tri-state literal value, carried by [`ExprKind::TriLiteral`].
#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
pub enum TriLiteral {
    /// Presumably logic low.
    Low,
    /// Presumably logic high.
    High,
    /// NOTE(review): presumably high impedance (undriven) — confirm.
    HighImp,
}
170
/// How the type of an integer literal was specified; stored alongside the value in
/// [`ExprKind::IntLiteral`].
#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
pub enum IntLiteralKind {
    /// No size was given. Presumably the type is left to type inference.
    Unsized,
    /// A signed literal. NOTE(review): the `BigUint` is presumably the bit width — confirm.
    Signed(BigUint),
    /// An unsigned literal. NOTE(review): the `BigUint` is presumably the bit width — confirm.
    Unsigned(BigUint),
}
177
/// How an [`ExprKind::PipelineRef`] selects the pipeline stage it refers to.
#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
pub enum PipelineRefKind {
    /// A stage identified by name. NOTE(review): presumably a stage label — confirm.
    Absolute(Loc<NameID>),
    /// A stage given as a type-level offset. NOTE(review): presumably relative to the
    /// current stage — confirm.
    Relative(Loc<TypeExpression>),
}
183
/// One entry of `outer_generic_params` in [`ExprKind::LambdaDef`]. It pairs the name a
/// generic parameter has inside the lambda with the name it has in the enclosing body.
/// NOTE(review): interpretation inferred from the field names — confirm.
#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
pub struct OuterLambdaParam {
    /// The name of the parameter inside the lambda.
    pub name_in_lambda: NameID,
    /// The name of the parameter in the surrounding body.
    pub name_in_body: Loc<NameID>,
}
189
/// The safety annotation of a call ([`ExprKind::Call`] / [`ExprKind::MethodCall`]).
#[derive(PartialEq, Eq, Debug, Clone, Copy, Serialize, Deserialize)]
pub enum Safety {
    /// No `unsafe` marker.
    Default,
    /// The call is marked as unsafe. NOTE(review): exact source form not visible here — confirm.
    Unsafe,
}
195
/// The type parameters of the type created by an [`ExprKind::LambdaDef`].
#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
pub struct LambdaTypeParams {
    /// The parameters that will contain the types of each argument
    pub arg: Vec<Loc<TypeParam>>,
    /// The parameter that will contain the type of the lambda's output
    pub output: Loc<TypeParam>,
    /// The parameters that will contain the types of the captured variables
    pub captures: Vec<Loc<TypeParam>>,
    /// The type parameters that are inherited from the unit in which the lambda is defined
    pub outer: Vec<Loc<TypeParam>>,
}
206
207impl LambdaTypeParams {
208    pub fn all(&self) -> impl Iterator<Item = &Loc<TypeParam>> {
209        let Self {
210            arg,
211            output,
212            captures,
213            outer,
214        } = self;
215        arg.iter().chain(Some(output)).chain(captures).chain(outer)
216    }
217}
218
/// The different kinds of HIR expression. Paired with an [`ExprID`], an `ExprKind` forms an
/// [`Expression`].
#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
pub enum ExprKind {
    /// A stand-in for an expression that could not be lowered.
    /// NOTE(review): presumably produced after an error has been reported — confirm.
    Error,
    /// A reference to a named value.
    Identifier(NameID),
    /// An integer literal: its value and how its type was specified.
    IntLiteral(BigInt, IntLiteralKind),
    /// A boolean literal.
    BoolLiteral(bool),
    /// A tri-state literal, see [`TriLiteral`].
    TriLiteral(TriLiteral),
    /// A type-level integer named by the `NameID`, used as a value.
    /// NOTE(review): presumably a generic parameter — confirm.
    TypeLevelInteger(NameID),
    CreatePorts,
    /// A tuple literal with the given elements.
    TupleLiteral(Vec<Loc<Expression>>),
    /// An array literal with the given elements.
    ArrayLiteral(Vec<Loc<Expression>>),
    /// An array built from one element expression and a const-generic length,
    /// presumably `[element; length]`.
    ArrayShorthandLiteral(Box<Loc<Expression>>, Loc<ConstGenericWithId>),
    /// Indexing by an expression. NOTE(review): fields presumably `(target, index)` — confirm.
    Index(Box<Loc<Expression>>, Box<Loc<Expression>>),
    /// Indexing with a range whose bounds are const generics rather than expressions.
    RangeIndex {
        target: Box<Loc<Expression>>,
        start: Loc<ConstGenericWithId>,
        end: Loc<ConstGenericWithId>,
    },
    /// Accesses a tuple element by a literal position.
    TupleIndex(Box<Loc<Expression>>, Loc<u128>),
    /// Accesses a named field of the target expression.
    FieldAccess(Box<Loc<Expression>>, Loc<Identifier>),
    /// A call of the method `name` on `target`.
    MethodCall {
        target: Box<Loc<Expression>>,
        name: Loc<Identifier>,
        args: Loc<ArgumentList<Expression>>,
        call_kind: CallKind,
        /// Explicitly specified type arguments, if any.
        turbofish: Option<Loc<ArgumentList<TypeExpression>>>,
        safety: Safety,
    },
    /// A call of the unit named `callee`.
    Call {
        kind: CallKind,
        callee: Loc<NameID>,
        args: Loc<ArgumentList<Expression>>,
        /// Explicitly specified type arguments, if any.
        turbofish: Option<Loc<ArgumentList<TypeExpression>>>,
        safety: Safety,
    },
    /// `lhs op rhs`.
    BinaryOperator(
        Box<Loc<Expression>>,
        Loc<BinaryOperator>,
        Box<Loc<Expression>>,
    ),
    /// A unary operator applied to a single operand.
    UnaryOperator(Loc<UnaryOperator>, Box<Loc<Expression>>),
    /// A match on the first expression. Each arm is a pattern, an optional expression, and
    /// the arm body. NOTE(review): the optional expression is presumably an `if` guard — confirm.
    Match(
        Box<Loc<Expression>>,
        Vec<(Loc<Pattern>, Option<Loc<Expression>>, Loc<Expression>)>,
    ),
    /// A block expression.
    Block(Box<Block>),
    /// An if-expression. NOTE(review): fields presumably `(condition, then, else)` — confirm.
    If(
        Box<Loc<Expression>>,
        Box<Loc<Expression>>,
        Box<Loc<Expression>>,
    ),
    /// An `if` whose condition is a const generic rather than an expression.
    TypeLevelIf(
        // FIXME: Having a random u64 is not great, let's make TypeExpressions always have associated ids
        Loc<ConstGenericWithId>,
        Box<Loc<Expression>>,
        Box<Loc<Expression>>,
    ),
    /// A reference to `name` in the pipeline stage selected by `stage`.
    PipelineRef {
        stage: Loc<PipelineRefKind>,
        name: Loc<NameID>,
        /// NOTE(review): presumably set when this reference also declares `name` — confirm.
        declares_name: bool,
        /// An expression ID which after typeinference will contain the absolute depth
        /// of this referenced value
        depth_typeexpr_id: ExprID,
    },
    /// A lambda definition.
    LambdaDef {
        /// The kind of unit the lambda defines.
        unit_kind: Loc<UnitKind>,
        /// The type that this lambda definition creates
        lambda_type: NameID,
        /// The type parameters of `lambda_type`, see [`LambdaTypeParams`].
        type_params: LambdaTypeParams,
        /// Generic parameters of the enclosing unit, see [`OuterLambdaParam`].
        outer_generic_params: Vec<OuterLambdaParam>,
        /// The unit which is the `call` method on this lambda
        lambda_unit: NameID,
        /// Patterns binding the lambda's arguments.
        arguments: Vec<Loc<Pattern>>,
        body: Box<Loc<Expression>>,
        /// NOTE(review): presumably the clock the lambda is bound to, if any — confirm.
        clock: Option<Loc<NameID>>,
        /// Captured variables. NOTE(review): presumably (name inside the lambda, captured
        /// value) — confirm.
        captures: Vec<(Loc<Identifier>, Loc<NameID>)>,
    },
    StageValid,
    StageReady,
    /// Carries a message. NOTE(review): presumably reported at compile time if this expression
    /// is reachable — confirm.
    StaticUnreachable(Loc<String>),
    /// This is a special case expression which is never created in user code, but which can be used
    /// in type inference to create virtual expressions with specific IDs
    Null,
}
304
305impl ExprKind {
306    pub fn with_id(self, id: ExprID) -> Expression {
307        Expression { kind: self, id }
308    }
309
310    // FIXME: These really should be #[cfg(test)]'d away
311    pub fn idless(self) -> Expression {
312        Expression {
313            kind: self,
314            id: ExprID(0),
315        }
316    }
317
318    pub fn int_literal(val: i32) -> Self {
319        Self::IntLiteral(val.to_bigint(), IntLiteralKind::Unsized)
320    }
321}
322
/// A HIR expression: an [`ExprKind`] together with the ID used to associate a type with it.
///
/// Equality is implemented manually and compares only `kind`, not `id`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Expression {
    /// What the expression is.
    pub kind: ExprKind,
    // This ID is used to associate types with the expression
    pub id: ExprID,
}
329
330impl Expression {
331    /// Create a new expression referencing an identifier with the specified
332    /// id and name
333    pub fn ident(expr_id: ExprID, name_id: u64, name: &str) -> Expression {
334        ExprKind::Identifier(NameID(name_id, Path::from_strs(&[name]))).with_id(expr_id)
335    }
336
337    /// Returns the block that is this expression if it is a block, an error if it is an Error node, and panics if the expression is not a block or error
338    pub fn assume_block(&self) -> std::result::Result<&Block, ()> {
339        if let ExprKind::Block(ref block) = self.kind {
340            Ok(block)
341        } else if let ExprKind::Error = self.kind {
342            Err(())
343        } else {
344            panic!("Expression is not a block")
345        }
346    }
347
348    /// Returns the block that is this expression. Panics if the expression is not a block
349    pub fn assume_block_mut(&mut self) -> &mut Block {
350        if let ExprKind::Block(block) = &mut self.kind {
351            block.borrow_mut()
352        } else {
353            panic!("Expression is not a block")
354        }
355    }
356}
357
358impl PartialEq for Expression {
359    fn eq(&self, other: &Self) -> bool {
360        self.kind == other.kind
361    }
362}