spirq_core/
constant.rs

1//! Constant and specialization constant representations.
2use std::convert::TryFrom;
3
4use half::f16;
5use ordered_float::OrderedFloat;
6
7use crate::{
8    error::{anyhow, Result},
9    ty::{ScalarType, Type},
10    var::SpecId,
11};
12
/// Typed constant value.
///
/// A value is either still untyped (`Typeless`, raw bytes) or already
/// interpreted as one of the scalar variants below.
#[non_exhaustive]
#[derive(PartialEq, Eq, Hash, Clone, Debug)]
pub enum ConstantValue {
    /// Raw bytes that have not been interpreted as a concrete scalar yet.
    /// The conversions in this file read and write them in little-endian
    /// order.
    Typeless(Box<[u8]>),
    /// Boolean value.
    Bool(bool),
    /// Signed 8-bit integer.
    S8(i8),
    /// Signed 16-bit integer.
    S16(i16),
    /// Signed 32-bit integer.
    S32(i32),
    /// Signed 64-bit integer.
    S64(i64),
    /// Unsigned 8-bit integer.
    U8(u8),
    /// Unsigned 16-bit integer.
    U16(u16),
    /// Unsigned 32-bit integer.
    U32(u32),
    /// Unsigned 64-bit integer.
    U64(u64),
    /// 16-bit float, wrapped in `OrderedFloat` (presumably so that the enum
    /// can derive `Eq` and `Hash`).
    F16(OrderedFloat<f16>),
    /// 32-bit float, wrapped in `OrderedFloat`.
    F32(OrderedFloat<f32>),
    /// 64-bit float, wrapped in `OrderedFloat`.
    F64(OrderedFloat<f64>),
}
31impl From<&[u32]> for ConstantValue {
32    fn from(x: &[u32]) -> Self {
33        let bytes = x.iter().flat_map(|x| x.to_le_bytes()).collect();
34        ConstantValue::Typeless(bytes)
35    }
36}
37impl From<&[u8]> for ConstantValue {
38    fn from(x: &[u8]) -> Self {
39        let bytes = x.to_owned().into_boxed_slice();
40        ConstantValue::Typeless(bytes)
41    }
42}
43impl From<[u8; 4]> for ConstantValue {
44    fn from(x: [u8; 4]) -> Self {
45        ConstantValue::try_from(&x as &[u8]).unwrap()
46    }
47}
48impl From<[u8; 8]> for ConstantValue {
49    fn from(x: [u8; 8]) -> Self {
50        ConstantValue::try_from(&x as &[u8]).unwrap()
51    }
52}
53impl From<bool> for ConstantValue {
54    fn from(x: bool) -> Self {
55        Self::Bool(x)
56    }
57}
58impl From<u32> for ConstantValue {
59    fn from(x: u32) -> Self {
60        Self::U32(x)
61    }
62}
63impl From<i32> for ConstantValue {
64    fn from(x: i32) -> Self {
65        Self::S32(x)
66    }
67}
68impl From<f32> for ConstantValue {
69    fn from(x: f32) -> Self {
70        Self::F32(OrderedFloat(x))
71    }
72}
73impl ConstantValue {
74    pub fn to_typed(&self, ty: &Type) -> Result<Self> {
75        let x = match self {
76            Self::Typeless(x) => x,
77            _ => return Err(anyhow!("{self:?} is already typed")),
78        };
79
80        if let Some(scalar_ty) = ty.as_scalar() {
81            match scalar_ty {
82                ScalarType::Boolean => Ok(ConstantValue::Bool(x.iter().any(|x| x != &0))),
83                ScalarType::Integer {
84                    bits: 8,
85                    is_signed: true,
86                } if x.len() == 4 => {
87                    let x = i8::from_le_bytes([x[0]]);
88                    Ok(ConstantValue::S8(x))
89                }
90                ScalarType::Integer {
91                    bits: 8,
92                    is_signed: false,
93                } if x.len() == 4 => {
94                    let x = u8::from_le_bytes([x[0]]);
95                    Ok(ConstantValue::U8(x))
96                }
97                ScalarType::Integer {
98                    bits: 16,
99                    is_signed: true,
100                } if x.len() == 4 => {
101                    let x = i16::from_le_bytes([x[0], x[1]]);
102                    Ok(ConstantValue::S16(x))
103                }
104                ScalarType::Integer {
105                    bits: 16,
106                    is_signed: false,
107                } if x.len() == 4 => {
108                    let x = u16::from_le_bytes([x[0], x[1]]);
109                    Ok(ConstantValue::U16(x))
110                }
111                ScalarType::Integer {
112                    bits: 32,
113                    is_signed: true,
114                } if x.len() == 4 => {
115                    let x = i32::from_le_bytes([x[0], x[1], x[2], x[3]]);
116                    Ok(ConstantValue::S32(x))
117                }
118                ScalarType::Integer {
119                    bits: 32,
120                    is_signed: false,
121                } if x.len() == 4 => {
122                    let x = u32::from_le_bytes([x[0], x[1], x[2], x[3]]);
123                    Ok(ConstantValue::U32(x))
124                }
125                ScalarType::Integer {
126                    bits: 64,
127                    is_signed: true,
128                } if x.len() == 8 => {
129                    let x = i64::from_le_bytes([x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7]]);
130                    Ok(ConstantValue::S64(x))
131                }
132                ScalarType::Integer {
133                    bits: 64,
134                    is_signed: false,
135                } if x.len() == 8 => {
136                    let x = u64::from_le_bytes([x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7]]);
137                    Ok(ConstantValue::U64(x))
138                }
139                ScalarType::Float { bits: 16 } if x.len() == 4 => {
140                    let x = f16::from_le_bytes([x[0], x[1]]);
141                    Ok(ConstantValue::F16(OrderedFloat(x)))
142                }
143                ScalarType::Float { bits: 32 } if x.len() == 4 => {
144                    let x = f32::from_le_bytes([x[0], x[1], x[2], x[3]]);
145                    Ok(ConstantValue::F32(OrderedFloat(x)))
146                }
147                ScalarType::Float { bits: 64 } if x.len() == 8 => {
148                    let x = f64::from_le_bytes([x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7]]);
149                    Ok(ConstantValue::F64(OrderedFloat(x)))
150                }
151                _ => Err(anyhow!(
152                    "cannot parse {:?} from {} bytes",
153                    scalar_ty,
154                    x.len()
155                )),
156            }
157        } else {
158            Err(anyhow!("cannot parse {:?} as a constant value", ty))
159        }
160    }
161
162    pub fn to_bool(&self) -> Option<bool> {
163        match self {
164            Self::Bool(x) => Some(*x),
165            _ => None,
166        }
167    }
168    pub fn to_s32(&self) -> Option<i32> {
169        match self {
170            Self::S32(x) => Some(*x),
171            _ => None,
172        }
173    }
174    pub fn to_u32(&self) -> Option<i32> {
175        match self {
176            Self::S32(x) => Some(*x),
177            _ => None,
178        }
179    }
180    pub fn to_f32(&self) -> Option<f32> {
181        match self {
182            Self::F32(x) => Some((*x).into()),
183            _ => None,
184        }
185    }
186
187    pub fn to_typeless(&self) -> Option<Box<[u8]>> {
188        match self {
189            Self::Typeless(x) => Some(x.clone()),
190            Self::S8(x) => Some(Box::new(x.to_le_bytes())),
191            Self::S16(x) => Some(Box::new(x.to_le_bytes())),
192            Self::S32(x) => Some(Box::new(x.to_le_bytes())),
193            Self::S64(x) => Some(Box::new(x.to_le_bytes())),
194            Self::U8(x) => Some(Box::new(x.to_le_bytes())),
195            Self::U16(x) => Some(Box::new(x.to_le_bytes())),
196            Self::U32(x) => Some(Box::new(x.to_le_bytes())),
197            Self::U64(x) => Some(Box::new(x.to_le_bytes())),
198            Self::F16(x) => Some(Box::new(x.to_le_bytes())),
199            Self::F32(x) => Some(Box::new(x.to_le_bytes())),
200            Self::F64(x) => Some(Box::new(x.to_le_bytes())),
201            Self::Bool(x) => Some(Box::new([*x as u8])),
202        }
203    }
204}
205
/// Constant or specialization constant record.
#[derive(PartialEq, Eq, Hash, Clone, Debug)]
pub struct Constant {
    /// Optional name of the constant; `None` when no name is attached (for
    /// example, records built with `new_itm`).
    pub name: Option<String>,
    /// Type of constant.
    pub ty: Type,
    /// Defined value of constant, or default value of specialization constant.
    pub value: ConstantValue,
    /// Specialization constant ID, notice that this is NOT an instruction ID.
    /// It is used to identify specialization constants for graphics libraries.
    /// `None` for records built with `new` and `new_itm`; `Some` for records
    /// built with `new_spec`.
    pub spec_id: Option<SpecId>,
}
218impl Constant {
219    /// Create a constant record with name, type and value. `ty` must be a
220    /// `ScalarType`.
221    pub fn new(name: Option<String>, ty: Type, value: ConstantValue) -> Self {
222        Self {
223            name,
224            ty,
225            value,
226            spec_id: None,
227        }
228    }
229    /// Create an intermediate constant record with type and value. Intermediate
230    /// constants don't have names because they contribute to subexpressions in
231    /// arithmetic.
232    pub fn new_itm(ty: Type, value: ConstantValue) -> Self {
233        Self {
234            name: None,
235            ty,
236            value,
237            spec_id: None,
238        }
239    }
240    /// Create a specialization constant record with name, type, default value
241    /// and a `SpecId`.
242    pub fn new_spec(name: Option<String>, ty: Type, value: ConstantValue, spec_id: SpecId) -> Self {
243        Self {
244            name,
245            ty,
246            value: value,
247            spec_id: Some(spec_id),
248        }
249    }
250}