1use std::convert::TryFrom;
3
4use half::f16;
5use ordered_float::OrderedFloat;
6
7use crate::{
8 error::{anyhow, Result},
9 ty::{ScalarType, Type},
10 var::SpecId,
11};
12
/// A constant scalar value.
///
/// A value is either a raw byte buffer whose scalar type is not known yet
/// ([`ConstantValue::Typeless`], see [`ConstantValue::to_typed`]) or one of the
/// typed scalar variants.
///
/// Floating-point payloads are wrapped in `OrderedFloat` because the derived
/// `Eq` and `Hash` below cannot be implemented for bare `f16`/`f32`/`f64`.
#[non_exhaustive]
#[derive(PartialEq, Eq, Hash, Clone, Debug)]
pub enum ConstantValue {
    /// Raw bytes whose scalar type has not been determined yet.
    Typeless(Box<[u8]>),
    /// Boolean value.
    Bool(bool),
    /// Signed 8-bit integer.
    S8(i8),
    /// Signed 16-bit integer.
    S16(i16),
    /// Signed 32-bit integer.
    S32(i32),
    /// Signed 64-bit integer.
    S64(i64),
    /// Unsigned 8-bit integer.
    U8(u8),
    /// Unsigned 16-bit integer.
    U16(u16),
    /// Unsigned 32-bit integer.
    U32(u32),
    /// Unsigned 64-bit integer.
    U64(u64),
    /// Half-precision float.
    F16(OrderedFloat<f16>),
    /// Single-precision float.
    F32(OrderedFloat<f32>),
    /// Double-precision float.
    F64(OrderedFloat<f64>),
}
31impl From<&[u32]> for ConstantValue {
32 fn from(x: &[u32]) -> Self {
33 let bytes = x.iter().flat_map(|x| x.to_le_bytes()).collect();
34 ConstantValue::Typeless(bytes)
35 }
36}
37impl From<&[u8]> for ConstantValue {
38 fn from(x: &[u8]) -> Self {
39 let bytes = x.to_owned().into_boxed_slice();
40 ConstantValue::Typeless(bytes)
41 }
42}
43impl From<[u8; 4]> for ConstantValue {
44 fn from(x: [u8; 4]) -> Self {
45 ConstantValue::try_from(&x as &[u8]).unwrap()
46 }
47}
48impl From<[u8; 8]> for ConstantValue {
49 fn from(x: [u8; 8]) -> Self {
50 ConstantValue::try_from(&x as &[u8]).unwrap()
51 }
52}
53impl From<bool> for ConstantValue {
54 fn from(x: bool) -> Self {
55 Self::Bool(x)
56 }
57}
58impl From<u32> for ConstantValue {
59 fn from(x: u32) -> Self {
60 Self::U32(x)
61 }
62}
63impl From<i32> for ConstantValue {
64 fn from(x: i32) -> Self {
65 Self::S32(x)
66 }
67}
68impl From<f32> for ConstantValue {
69 fn from(x: f32) -> Self {
70 Self::F32(OrderedFloat(x))
71 }
72}
73impl ConstantValue {
74 pub fn to_typed(&self, ty: &Type) -> Result<Self> {
75 let x = match self {
76 Self::Typeless(x) => x,
77 _ => return Err(anyhow!("{self:?} is already typed")),
78 };
79
80 if let Some(scalar_ty) = ty.as_scalar() {
81 match scalar_ty {
82 ScalarType::Boolean => Ok(ConstantValue::Bool(x.iter().any(|x| x != &0))),
83 ScalarType::Integer {
84 bits: 8,
85 is_signed: true,
86 } if x.len() == 4 => {
87 let x = i8::from_le_bytes([x[0]]);
88 Ok(ConstantValue::S8(x))
89 }
90 ScalarType::Integer {
91 bits: 8,
92 is_signed: false,
93 } if x.len() == 4 => {
94 let x = u8::from_le_bytes([x[0]]);
95 Ok(ConstantValue::U8(x))
96 }
97 ScalarType::Integer {
98 bits: 16,
99 is_signed: true,
100 } if x.len() == 4 => {
101 let x = i16::from_le_bytes([x[0], x[1]]);
102 Ok(ConstantValue::S16(x))
103 }
104 ScalarType::Integer {
105 bits: 16,
106 is_signed: false,
107 } if x.len() == 4 => {
108 let x = u16::from_le_bytes([x[0], x[1]]);
109 Ok(ConstantValue::U16(x))
110 }
111 ScalarType::Integer {
112 bits: 32,
113 is_signed: true,
114 } if x.len() == 4 => {
115 let x = i32::from_le_bytes([x[0], x[1], x[2], x[3]]);
116 Ok(ConstantValue::S32(x))
117 }
118 ScalarType::Integer {
119 bits: 32,
120 is_signed: false,
121 } if x.len() == 4 => {
122 let x = u32::from_le_bytes([x[0], x[1], x[2], x[3]]);
123 Ok(ConstantValue::U32(x))
124 }
125 ScalarType::Integer {
126 bits: 64,
127 is_signed: true,
128 } if x.len() == 8 => {
129 let x = i64::from_le_bytes([x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7]]);
130 Ok(ConstantValue::S64(x))
131 }
132 ScalarType::Integer {
133 bits: 64,
134 is_signed: false,
135 } if x.len() == 8 => {
136 let x = u64::from_le_bytes([x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7]]);
137 Ok(ConstantValue::U64(x))
138 }
139 ScalarType::Float { bits: 16 } if x.len() == 4 => {
140 let x = f16::from_le_bytes([x[0], x[1]]);
141 Ok(ConstantValue::F16(OrderedFloat(x)))
142 }
143 ScalarType::Float { bits: 32 } if x.len() == 4 => {
144 let x = f32::from_le_bytes([x[0], x[1], x[2], x[3]]);
145 Ok(ConstantValue::F32(OrderedFloat(x)))
146 }
147 ScalarType::Float { bits: 64 } if x.len() == 8 => {
148 let x = f64::from_le_bytes([x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7]]);
149 Ok(ConstantValue::F64(OrderedFloat(x)))
150 }
151 _ => Err(anyhow!(
152 "cannot parse {:?} from {} bytes",
153 scalar_ty,
154 x.len()
155 )),
156 }
157 } else {
158 Err(anyhow!("cannot parse {:?} as a constant value", ty))
159 }
160 }
161
162 pub fn to_bool(&self) -> Option<bool> {
163 match self {
164 Self::Bool(x) => Some(*x),
165 _ => None,
166 }
167 }
168 pub fn to_s32(&self) -> Option<i32> {
169 match self {
170 Self::S32(x) => Some(*x),
171 _ => None,
172 }
173 }
174 pub fn to_u32(&self) -> Option<i32> {
175 match self {
176 Self::S32(x) => Some(*x),
177 _ => None,
178 }
179 }
180 pub fn to_f32(&self) -> Option<f32> {
181 match self {
182 Self::F32(x) => Some((*x).into()),
183 _ => None,
184 }
185 }
186
187 pub fn to_typeless(&self) -> Option<Box<[u8]>> {
188 match self {
189 Self::Typeless(x) => Some(x.clone()),
190 Self::S8(x) => Some(Box::new(x.to_le_bytes())),
191 Self::S16(x) => Some(Box::new(x.to_le_bytes())),
192 Self::S32(x) => Some(Box::new(x.to_le_bytes())),
193 Self::S64(x) => Some(Box::new(x.to_le_bytes())),
194 Self::U8(x) => Some(Box::new(x.to_le_bytes())),
195 Self::U16(x) => Some(Box::new(x.to_le_bytes())),
196 Self::U32(x) => Some(Box::new(x.to_le_bytes())),
197 Self::U64(x) => Some(Box::new(x.to_le_bytes())),
198 Self::F16(x) => Some(Box::new(x.to_le_bytes())),
199 Self::F32(x) => Some(Box::new(x.to_le_bytes())),
200 Self::F64(x) => Some(Box::new(x.to_le_bytes())),
201 Self::Bool(x) => Some(Box::new([*x as u8])),
202 }
203 }
204}
205
/// A constant: an optional name, its type, its value and an optional
/// specialization ID.
#[derive(PartialEq, Eq, Hash, Clone, Debug)]
pub struct Constant {
    /// Name of the constant, if it has one.
    pub name: Option<String>,
    /// Type of the constant.
    pub ty: Type,
    /// Value of the constant.
    // NOTE(review): for spec constants this is presumably the default value — confirm.
    pub value: ConstantValue,
    /// ID supplied to `Constant::new_spec`. It is `None` for constants built
    /// with `Constant::new` or `Constant::new_itm`.
    pub spec_id: Option<SpecId>,
}
218impl Constant {
219 pub fn new(name: Option<String>, ty: Type, value: ConstantValue) -> Self {
222 Self {
223 name,
224 ty,
225 value,
226 spec_id: None,
227 }
228 }
229 pub fn new_itm(ty: Type, value: ConstantValue) -> Self {
233 Self {
234 name: None,
235 ty,
236 value,
237 spec_id: None,
238 }
239 }
240 pub fn new_spec(name: Option<String>, ty: Type, value: ConstantValue, spec_id: SpecId) -> Self {
243 Self {
244 name,
245 ty,
246 value: value,
247 spec_id: Some(spec_id),
248 }
249 }
250}