Skip to main content

wave_compiler/mir/
types.rs

1// Copyright 2026 Ojima Abraham
2// SPDX-License-Identifier: Apache-2.0
3
4//! MIR type system, lowered from HIR types.
5//!
6//! MIR types are a simpler representation that directly maps to
7//! WAVE register types and memory operations.
8
9use crate::hir;
10
/// MIR type, a lowered version of HIR types.
///
/// Each variant is one of the register/memory types the MIR works with;
/// see [`MirType::size_bytes`] for the storage width of each.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MirType {
    /// 32-bit integer (signed or unsigned).
    I32,
    /// 32-bit float.
    F32,
    /// 16-bit float.
    F16,
    /// 64-bit float.
    F64,
    /// Boolean predicate (occupies 4 bytes, see [`MirType::size_bytes`]).
    Bool,
    /// Pointer (32-bit address).
    Ptr,
}

impl MirType {
    /// Size in bytes.
    ///
    /// `F16` is 2 bytes and `F64` is 8 bytes. Every other type, including
    /// `Bool` and `Ptr`, is 4 bytes wide.
    #[must_use]
    pub const fn size_bytes(self) -> u32 {
        match self {
            Self::I32 | Self::F32 | Self::Ptr | Self::Bool => 4,
            Self::F16 => 2,
            Self::F64 => 8,
        }
    }

    /// Returns true if this is a floating-point type (`F16`, `F32` or `F64`).
    #[must_use]
    pub const fn is_float(self) -> bool {
        matches!(self, Self::F32 | Self::F16 | Self::F64)
    }

    /// Returns true if this is an integer type.
    ///
    /// Pointers count as integers here because they are plain 32-bit
    /// addresses. `Bool` is classified as neither integer nor float.
    #[must_use]
    pub const fn is_integer(self) -> bool {
        matches!(self, Self::I32 | Self::Ptr)
    }

    /// Lower-case textual name of the type, as printed by `Display`.
    const fn name(self) -> &'static str {
        match self {
            Self::I32 => "i32",
            Self::F32 => "f32",
            Self::F16 => "f16",
            Self::F64 => "f64",
            Self::Bool => "bool",
            Self::Ptr => "ptr",
        }
    }
}

impl std::fmt::Display for MirType {
    /// Writes the lower-case type name (e.g. `i32`, `ptr`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` (unlike `write!`) honours width, fill, alignment and precision
        // flags, so `format!("{:>5}", ty)` aligns the name as callers expect.
        f.pad(self.name())
    }
}
64
65/// Convert an HIR type to a MIR type.
66#[must_use]
67pub fn lower_type(ty: &hir::Type) -> MirType {
68    match ty {
69        hir::Type::U32 | hir::Type::I32 | hir::Type::Void => MirType::I32,
70        hir::Type::F32 => MirType::F32,
71        hir::Type::F16 => MirType::F16,
72        hir::Type::F64 => MirType::F64,
73        hir::Type::Bool => MirType::Bool,
74        hir::Type::Ptr(_) | hir::Type::Array(_, _) => MirType::Ptr,
75    }
76}
77
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hir::types::AddressSpace;

    #[test]
    fn test_mir_type_sizes() {
        assert_eq!(MirType::I32.size_bytes(), 4);
        assert_eq!(MirType::F32.size_bytes(), 4);
        assert_eq!(MirType::F16.size_bytes(), 2);
        assert_eq!(MirType::F64.size_bytes(), 8);
        assert_eq!(MirType::Bool.size_bytes(), 4);
        assert_eq!(MirType::Ptr.size_bytes(), 4);
    }

    #[test]
    fn test_lower_type() {
        assert_eq!(lower_type(&hir::Type::U32), MirType::I32);
        assert_eq!(lower_type(&hir::Type::I32), MirType::I32);
        assert_eq!(lower_type(&hir::Type::F32), MirType::F32);
        assert_eq!(lower_type(&hir::Type::F16), MirType::F16);
        assert_eq!(lower_type(&hir::Type::F64), MirType::F64);
        assert_eq!(lower_type(&hir::Type::Bool), MirType::Bool);
        // Pins the current (possibly provisional) lowering of `Void`.
        assert_eq!(lower_type(&hir::Type::Void), MirType::I32);
        assert_eq!(
            lower_type(&hir::Type::Ptr(AddressSpace::Device)),
            MirType::Ptr
        );
    }

    #[test]
    fn test_type_classification() {
        assert!(MirType::F32.is_float());
        assert!(MirType::F64.is_float());
        assert!(MirType::F16.is_float());
        assert!(!MirType::I32.is_float());
        assert!(!MirType::Ptr.is_float());
        assert!(MirType::I32.is_integer());
        assert!(MirType::Ptr.is_integer());
        assert!(!MirType::F32.is_integer());
        // `Bool` is a predicate: neither integer nor float.
        assert!(!MirType::Bool.is_integer());
        assert!(!MirType::Bool.is_float());
    }

    #[test]
    fn test_display() {
        let cases = [
            (MirType::I32, "i32"),
            (MirType::F32, "f32"),
            (MirType::F16, "f16"),
            (MirType::F64, "f64"),
            (MirType::Bool, "bool"),
            (MirType::Ptr, "ptr"),
        ];
        for &(ty, name) in &cases {
            assert_eq!(ty.to_string(), name);
        }
    }
}