spade_mir/
eval.rs

1use itertools::Itertools;
2use num::range;
3use num::BigInt;
4use num::BigUint;
5use num::ToPrimitive;
6use num::Zero;
7use rustc_hash::FxHashMap as HashMap;
8use spade_common::num_ext::InfallibleToBigUint;
9
10use crate::{enum_util, types::Type, Binding, Operator, Statement, ValueName};
11
/// A concrete value produced by evaluating MIR statements with `eval_statements`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Value {
    /// A single bit.
    Bit(bool),
    /// A signed integer declared to be `size` bits wide. `val` is an unbounded `BigInt`;
    /// the `as_string`/`as_u64`/`as_u32_chunks` methods render negative values in two's
    /// complement by masking to `size` bits.
    Int {
        size: BigUint,
        val: BigInt,
    },
    /// An unsigned integer declared to be `size` bits wide.
    UInt {
        size: BigUint,
        val: BigUint,
    },
    /// Concatenated values are placed left-to-right, i.e. `Concat(vec![0b11,0b00])` is
    /// 0b1100
    Concat(Vec<Value>),
    /// A number of undefined bits
    Undef(BigUint),
}
29
30impl Value {
31    pub fn assume_int(&self) -> BigInt {
32        match self {
33            Value::Int { val, .. } => val.clone(),
34            other => panic!("Assumed value to be int, was {other:?}"),
35        }
36    }
37    pub fn assume_uint(&self) -> BigUint {
38        match self {
39            Value::UInt { val, .. } => val.clone(),
40            other => panic!("Assumed value to be int, was {other:?}"),
41        }
42    }
43
44    pub fn as_string(&self) -> String {
45        match self {
46            Value::Bit(val) => format!("{}", if *val { 1 } else { 0 }),
47            Value::Int { size, val } => {
48                if *val >= 0i64.into() {
49                    let val_str = format!("{val:b}");
50                    let needed_0s: BigUint = size - val_str.len();
51                    let extra_0 = range(0u64.into(), needed_0s).map(|_| "0").join("");
52                    assert!(*size >= val_str.len().into());
53                    format!("{extra_0}{val_str}")
54                } else {
55                    // Negative numbers in 2's complement are represented by all leadingdigits
56                    // being 1. That is a bit difficult to achieve when our numbers have infinite
57                    // bits (as is the case for BigInt). To fix that, we mask out the bits we do
58                    // want which gives a positive number with the correct binary representation.
59                    // https://stackoverflow.com/questions/12946116/twos-complement-binary-in-python
60                    let size_usize = size.to_usize().unwrap_or_else(|| {
61                        panic!("Variable size {size} is too large to fit a 'usize'")
62                    });
63                    let mask = (BigInt::from(1) << size_usize) - 1;
64                    format!("{:b}", val & mask)
65                }
66            }
67            Value::UInt { size, val } => {
68                let val_str = format!("{val:b}");
69                let needed_0s: BigUint = size - val_str.len();
70                let extra_0 = range(0u64.into(), needed_0s).map(|_| "0").join("");
71
72                assert!(*size >= val_str.len().into());
73
74                format!("{extra_0}{val_str}")
75            }
76            Value::Concat(inner) => inner.iter().map(|i| i.as_string()).join(""),
77            Value::Undef(size) => range(0u64.into(), size.clone()).map(|_| "X").join(""),
78        }
79    }
80
81    pub fn width(&self) -> BigUint {
82        match self {
83            Value::Bit(_) => 1u32.to_biguint(),
84            Value::Int { size, val: _ } => size.clone(),
85            Value::UInt { size, val: _ } => size.clone(),
86            Value::Concat(inner) => inner.iter().map(|i| i.width()).sum(),
87            Value::Undef(size) => size.clone(),
88        }
89    }
90
91    /// Computes the value as a 64. If the type this value represents is wider than 64 bits,
92    /// the behaviour is undefined. If the value is value::Undef, 0 is returned
93    pub fn as_u64(&self) -> u64 {
94        match self {
95            Value::Bit(val) => {
96                if *val {
97                    1
98                } else {
99                    0
100                }
101            }
102            Value::Int { size, val } => {
103                if *val >= 0i64.into() {
104                    val.to_u64().unwrap()
105                } else {
106                    // Negative numbers in 2's complement are represented by all leadingdigits
107                    // being 1. That is a bit difficult to achieve when our numbers have infinite
108                    // bits (as is the case for BigInt). To fix that, we mask out the bits we do
109                    // want which gives a positive number with the correct binary representation.
110                    // https://stackoverflow.com/questions/12946116/twos-complement-binary-in-python
111                    let size_usize = size.to_usize().unwrap_or_else(|| {
112                        panic!("Variable size {size} is too large to fit in a 'usize'")
113                    });
114                    let mask: BigInt = (BigInt::from(1) << size_usize) - 1;
115                    (val & mask).to_u64().unwrap()
116                }
117            }
118            Value::UInt { size: _, val } => val.to_u64().unwrap(),
119            Value::Concat(inner) => {
120                let mut current = 0;
121
122                for next in inner {
123                    current = (current << next.width().to_u64().unwrap()) + next.as_u64()
124                }
125                current
126            }
127            Value::Undef(_) => 0,
128        }
129    }
130
131    pub fn as_u32_chunks(&self) -> BigUint {
132        match self {
133            Value::Bit(val) => {
134                if *val {
135                    1u32.to_biguint()
136                } else {
137                    0u32.to_biguint()
138                }
139            }
140            Value::Int { size, val } => {
141                if *val >= 0i64.into() {
142                    val.to_biguint().unwrap()
143                } else {
144                    // Negative numbers in 2's complement are represented by all leadingdigits
145                    // being 1. That is a bit difficult to achieve when our numbers have infinite
146                    // bits (as is the case for BigInt). To fix that, we mask out the bits we do
147                    // want which gives a positive number with the correct binary representation.
148                    // https://stackoverflow.com/questions/12946116/twos-complement-binary-in-python
149                    let size_usize = size.to_usize().unwrap_or_else(|| {
150                        panic!("Variable size {size} is too large to fit in a usize")
151                    });
152                    let mask: BigInt = (BigInt::from(1) << size_usize) - 1;
153                    (val & mask).to_biguint().unwrap()
154                }
155            }
156            Value::UInt { size: _, val } => val.clone(),
157            Value::Concat(inner) => {
158                let mut current = 0u32.to_biguint();
159
160                for next in inner {
161                    current = (current << next.width().to_u64().unwrap()) + next.as_u64()
162                }
163                current
164            }
165            Value::Undef(_) => 0u32.to_biguint(),
166        }
167    }
168}
169
170impl std::fmt::Display for Value {
171    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
172        match self {
173            Value::Bit(val) => write!(f, "bit({})", if *val { "1" } else { "0" }),
174            Value::Int { size, val } => write!(f, "int<{size}>({val})"),
175            Value::UInt { size, val } => write!(f, "uint<{size}>({val})"),
176            Value::Concat(inner) => write!(
177                f,
178                "concat([{}])",
179                inner.iter().map(|v| format!("{v}")).join(", ")
180            ),
181            Value::Undef(_) => write!(f, "X"),
182        }
183    }
184}
185
#[cfg(test)]
impl Value {
    /// Test helper: an unsigned integer `val` declared `size` bits wide.
    fn uint(size: u64, val: u64) -> Self {
        let (size, val) = (BigUint::from(size), BigUint::from(val));
        Self::UInt { size, val }
    }

    /// Test helper: a signed integer `val` declared `size` bits wide.
    fn int(size: u64, val: i64) -> Self {
        let (size, val) = (BigUint::from(size), BigInt::from(val));
        Self::Int { size, val }
    }

    /// Test helper: `size` undefined bits.
    fn undef(size: u64) -> Self {
        Value::Undef(BigUint::from(size))
    }
}
206
207/// Evaluates a list of statements, returning the value of the final statement in the
208/// list. Panics if the list of statements is empty
209pub fn eval_statements(statements: &[Statement]) -> Value {
210    let mut name_vals: HashMap<ValueName, Value> = HashMap::default();
211    let mut name_types: HashMap<ValueName, Type> = HashMap::default();
212
213    let mut last_value = None;
214    for stmt in statements {
215        let (n, v) = match stmt {
216            Statement::Binding(b) => {
217                let Binding {
218                    name,
219                    operator,
220                    operands: ops,
221                    ty,
222                    loc: _,
223                } = b;
224
225                name_types.insert(name.clone(), ty.clone());
226
227                let val = match operator {
228                    Operator::Add => Value::Int {
229                        size: ty.size(),
230                        val: name_vals[&ops[0]].assume_int() + name_vals[&ops[1]].assume_int(),
231                    },
232                    Operator::UnsignedAdd => Value::UInt {
233                        size: ty.size(),
234                        val: name_vals[&ops[0]].assume_uint() + name_vals[&ops[1]].assume_uint(),
235                    },
236                    Operator::Sub => Value::Int {
237                        size: ty.size(),
238                        val: name_vals[&ops[0]].assume_int() - name_vals[&ops[1]].assume_int(),
239                    },
240                    Operator::UnsignedSub => Value::UInt {
241                        size: ty.size(),
242                        val: name_vals[&ops[0]].assume_uint() - name_vals[&ops[1]].assume_uint(),
243                    },
244                    Operator::Mul => todo!(),
245                    Operator::UnsignedMul => todo!(),
246                    Operator::Div => todo!(),
247                    Operator::UnsignedDiv => todo!(),
248                    Operator::Mod => todo!(),
249                    Operator::UnsignedMod => todo!(),
250                    Operator::Eq => todo!(),
251                    Operator::NotEq => todo!(),
252                    Operator::Gt => todo!(),
253                    Operator::Lt => todo!(),
254                    Operator::Ge => todo!(),
255                    Operator::Le => todo!(),
256                    Operator::UnsignedGt => todo!(),
257                    Operator::UnsignedLt => todo!(),
258                    Operator::UnsignedGe => todo!(),
259                    Operator::UnsignedLe => todo!(),
260                    Operator::LeftShift => todo!(),
261                    Operator::RightShift => todo!(),
262                    Operator::ArithmeticRightShift => todo!(),
263                    Operator::LogicalAnd => todo!(),
264                    Operator::LogicalOr => todo!(),
265                    Operator::LogicalXor => todo!(),
266                    Operator::LogicalNot => todo!(),
267                    Operator::BitwiseAnd => todo!(),
268                    Operator::BitwiseOr => todo!(),
269                    Operator::BitwiseXor => todo!(),
270                    Operator::USub => Value::Int {
271                        size: ty.size(),
272                        val: -name_vals[&ops[0]].assume_int(),
273                    },
274                    Operator::Not => todo!(),
275                    Operator::BitwiseNot => todo!(),
276                    Operator::DivPow2 => todo!(),
277                    Operator::ReduceAnd => todo!(),
278                    Operator::ReduceOr => todo!(),
279                    Operator::ReduceXor => todo!(),
280                    Operator::SignExtend => todo!(),
281                    Operator::ZeroExtend => todo!(),
282                    Operator::Truncate => todo!(),
283                    Operator::Concat => todo!(),
284                    Operator::Select => todo!(),
285                    Operator::Match => todo!(),
286                    Operator::ConstructArray => {
287                        Value::Concat(ops.iter().rev().map(|op| name_vals[op].clone()).collect())
288                    }
289                    Operator::DeclClockedMemory { .. } => todo!(),
290                    Operator::IndexArray => todo!(),
291                    Operator::IndexMemory => todo!(),
292                    Operator::RangeIndexArray { .. } => todo!(),
293                    Operator::RangeIndexBits { .. } => todo!(),
294                    Operator::ConstructTuple => {
295                        Value::Concat(ops.iter().map(|op| name_vals[op].clone()).collect())
296                    }
297                    Operator::ConstructEnum { variant } => {
298                        let Type::Enum(options) = ty else {
299                            panic!("Attempted enum construction of non-enum");
300                        };
301
302                        let tag_size = BigUint::from(enum_util::tag_size(options.len()));
303
304                        let mut to_concat = if options.len() <= 1 {
305                            vec![]
306                        } else {
307                            vec![Value::UInt {
308                                size: tag_size.clone(),
309                                val: (*variant).into(),
310                            }]
311                        };
312                        to_concat.append(
313                            &mut ops
314                                .iter()
315                                .map(|op| {
316                                    let val = &name_vals[op];
317
318                                    val.clone()
319                                })
320                                .collect(),
321                        );
322                        let variant_member_size =
323                            ops.iter().map(|op| name_types[op].size()).sum::<BigUint>();
324                        let padding_size = ty.size() - tag_size - variant_member_size;
325                        if padding_size != BigUint::zero() {
326                            to_concat.push(Value::Undef(padding_size))
327                        }
328
329                        Value::Concat(to_concat)
330                    }
331                    Operator::IsEnumVariant { .. } => todo!(),
332                    Operator::EnumMember { .. } => todo!(),
333                    Operator::IndexTuple(_) => todo!(),
334                    Operator::ReadPort => todo!(),
335                    Operator::ReadWriteInOut => todo!(),
336                    Operator::FlipPort => todo!(),
337                    Operator::ReadMutWires => todo!(),
338                    Operator::Instance { .. } => todo!(),
339                    Operator::Alias => name_vals[&ops[0]].clone(),
340                    Operator::Nop => todo!(),
341                };
342
343                (name.clone(), val)
344            }
345            Statement::Register(_) => panic!("trying to evaluate a register"),
346            Statement::Constant(id, ty, val) => {
347                let val = match val {
348                    crate::ConstantValue::Int(i) => Value::Int {
349                        size: ty.size(),
350                        val: i.clone(),
351                    },
352                    crate::ConstantValue::Bool(v) => Value::Bit(*v),
353                    crate::ConstantValue::String(_) => todo!(),
354                    crate::ConstantValue::HighImp => todo!(),
355                };
356                let name = ValueName::Expr(*id);
357                name_types.insert(name.clone(), ty.clone());
358                (name, val)
359            }
360            Statement::Assert(_) => panic!("trying to evaluate an assert statement"),
361            Statement::Set { .. } => panic!("trying to evaluate a `set` statement"),
362            Statement::WalTrace { .. } => panic!("trying to evaluate a `wal_trace`"),
363            Statement::Error => panic!("Trying to evaluate an Error statement"),
364        };
365
366        name_vals.insert(n, v.clone());
367        last_value = Some(v);
368    }
369    last_value.expect("Trying to evaluate empty statement list")
370}
371
#[cfg(test)]
mod string_value_tests {
    use super::*;

    /// Asserts that `value` renders as exactly the binary string `expected`.
    #[track_caller]
    fn check_renders_as(value: Value, expected: &str) {
        let rendered = value.as_string();
        assert_eq!(rendered, expected)
    }

    #[test]
    fn positive_integer_works() {
        check_renders_as(Value::int(8, 8), "00001000");
    }

    #[test]
    fn negative_integer_works() {
        check_renders_as(Value::int(8, -8), "11111000");
    }

    #[test]
    fn minus_10_works() {
        check_renders_as(Value::int(8, -10), "11110110");
    }

    #[test]
    fn zero_integer_works() {
        check_renders_as(Value::int(8, 0), "00000000");
    }

    #[test]
    fn positive_uinteger_works() {
        check_renders_as(Value::uint(8, 8), "00001000");
    }

    #[test]
    fn zero_uinteger_works() {
        check_renders_as(Value::uint(8, 0), "00000000");
    }
}
430
#[cfg(test)]
mod test {
    use crate as spade_mir;
    use crate::{statement, types::Type, ConstantValue};
    use pretty_assertions::assert_eq;
    use spade_common::num_ext::InfallibleToBigInt;

    use super::*;

    /// `(size, value, expected two's complement bit pattern)` shared by the integer
    /// conversion tests.
    const INT_CASES: [(u64, i64, u64); 5] = [
        (64, i64::MAX, i64::MAX as u64),
        (64, i64::MIN, 0x8000_0000_0000_0000),
        (64, -1, 0xffff_ffff_ffff_ffff),
        (8, -1, 0xff),
        (8, -128, 0x80),
    ];

    #[test]
    fn addition_works() {
        let mir = vec![
            statement!(const 0; Type::int(16); ConstantValue::int(5)),
            statement!(const 1; Type::int(16); ConstantValue::int(10)),
            statement!(e(2); Type::int(16); Add; e(0), e(1)),
        ];

        assert_eq!(eval_statements(&mir), Value::int(16, 15));
    }

    #[test]
    fn enum_construction_works() {
        let enum_t = Type::Enum(vec![vec![], vec![Type::int(16)], vec![]]);

        let mir = vec![
            statement!(const 0; Type::int(16); ConstantValue::int(5)),
            statement!(e(1); enum_t; ConstructEnum({variant: 1}); e(0)),
        ];

        let expected = Value::Concat(vec![Value::uint(2, 1), Value::int(16, 5)]);
        assert_eq!(eval_statements(&mir), expected)
    }

    #[test]
    fn enum_construction_with_padding_works() {
        let enum_t = Type::Enum(vec![vec![Type::int(2)], vec![Type::int(16)], vec![]]);

        let mir = vec![
            statement!(const 0; Type::int(3); ConstantValue::int(5)),
            statement!(e(1); enum_t; ConstructEnum({variant: 1}); e(0)),
        ];

        let expected = Value::Concat(vec![Value::uint(2, 1), Value::int(3, 5), Value::undef(13)]);
        assert_eq!(eval_statements(&mir), expected)
    }

    #[test]
    fn enum_construction_to_string_works() {
        let enum_t = Type::Enum(vec![vec![Type::int(8)], vec![]]);

        let mir = vec![
            statement!(const 0; Type::int(8); ConstantValue::int(0b1010)),
            statement!(e(1); enum_t; ConstructEnum({variant: 0}); e(0)),
        ];

        let rendered = eval_statements(&mir).as_string();
        assert_eq!(rendered, "000001010")
    }

    #[test]
    fn as_u64_works_for_bits() {
        for (bit, expected) in [(false, 0u64), (true, 1)] {
            assert_eq!(Value::Bit(bit).as_u64(), expected);
        }
    }

    #[test]
    fn as_u32_chunks_works_for_bits() {
        for (bit, expected) in [(false, 0u32), (true, 1)] {
            assert_eq!(Value::Bit(bit).as_u32_chunks(), expected.to_biguint());
        }
    }

    #[test]
    fn as_u64_works_for_ints() {
        for (size, val, expected) in INT_CASES {
            assert_eq!(Value::int(size, val).as_u64(), expected);
        }
    }

    #[test]
    fn as_u32_chunks_works_for_ints() {
        for (size, val, expected) in INT_CASES {
            assert_eq!(Value::int(size, val).as_u32_chunks(), expected.to_biguint());
        }
    }

    /// Generates a test checking that `$value` converts to `$expected_u64` through both
    /// `as_u64` and `as_u32_chunks`.
    macro_rules! test_conversion {
        ($name:ident, $value:expr, $expected_u64:expr) => {
            #[test]
            fn $name() {
                let value = $value;
                let expected: u64 = $expected_u64;
                assert_eq!(value.as_u64(), expected);
                assert_eq!(value.as_u32_chunks(), expected.to_biguint());
            }
        };
    }

    test_conversion! {
        concat_works,
        Value::Concat (
            vec![
                Value::Int {
                    size: 8u64.to_biguint(),
                    val: -1.to_bigint(),
                },
                Value::Bit(true),
                Value::Undef(7u64.to_biguint()),
                Value::UInt {
                    size: 8u64.to_biguint(),
                    val: 3u32.to_biguint()
                },
                Value::Bit(true),
            ]
        ),
        0b1111_1111_1000_0000_0000_0011_1u64
    }
}