1use std::sync::Arc;
21use vyre::ir::{DataType, Ident};
22
23#[derive(Debug, Clone, PartialEq, Eq)]
30#[non_exhaustive]
31pub struct TensorRef {
32 pub name: Ident,
34 pub dtype: DataType,
36 pub shape: Arc<[u32]>,
40}
41
42impl TensorRef {
43 #[must_use]
46 pub fn new(name: impl Into<Ident>, dtype: DataType, shape: Vec<u32>) -> Self {
47 Self {
48 name: name.into(),
49 dtype,
50 shape: Arc::from(shape),
51 }
52 }
53
54 #[must_use]
56 pub fn u32_1d(name: impl Into<Ident>, len: u32) -> Self {
57 Self::new(name, DataType::U32, vec![len])
58 }
59
60 #[must_use]
62 pub fn f32_1d(name: impl Into<Ident>, len: u32) -> Self {
63 Self::new(name, DataType::F32, vec![len])
64 }
65
66 #[must_use]
68 pub fn u32_2d(name: impl Into<Ident>, rows: u32, cols: u32) -> Self {
69 Self::new(name, DataType::U32, vec![rows, cols])
70 }
71
72 #[must_use]
74 pub fn f16_1d(name: impl Into<Ident>, len: u32) -> Self {
75 Self::new(name, DataType::F16, vec![len])
76 }
77
78 #[must_use]
80 pub fn f16_2d(name: impl Into<Ident>, rows: u32, cols: u32) -> Self {
81 Self::new(name, DataType::F16, vec![rows, cols])
82 }
83
84 #[must_use]
86 pub fn f32_2d(name: impl Into<Ident>, rows: u32, cols: u32) -> Self {
87 Self::new(name, DataType::F32, vec![rows, cols])
88 }
89
90 #[must_use]
93 pub fn element_count(&self) -> Option<u32> {
94 self.shape
95 .iter()
96 .try_fold(1u32, |acc, &dim| acc.checked_mul(dim))
97 }
98
99 #[must_use]
103 pub fn name_str(&self) -> &str {
104 self.name.as_str()
105 }
106}
107
108#[derive(Debug, Clone, thiserror::Error)]
111#[non_exhaustive]
112pub enum TensorRefError {
113 #[error(
115 "TensorRef `{name}` has dtype `{found:?}`; op `{op}` expects `{expected:?}`. Fix: pass a buffer of the correct dtype or cast."
116 )]
117 DtypeMismatch {
118 name: String,
120 found: DataType,
122 expected: DataType,
124 op: &'static str,
126 },
127 #[error(
129 "TensorRef `{name}` has shape {found:?}; op `{op}` expects {expected:?}. Fix: reshape or pick a compatible op variant."
130 )]
131 ShapeMismatch {
132 name: String,
134 found: Vec<u32>,
136 expected: Vec<u32>,
138 op: &'static str,
140 },
141 #[error(
143 "TensorRef name collision in op `{op}`: `{name}` appears on multiple arguments. Fix: use distinct buffer names per argument."
144 )]
145 NameCollision {
146 name: String,
148 op: &'static str,
150 },
151 #[error(
153 "TensorRef `{name}` element-count overflows u32 for shape {shape:?}. Fix: reduce dimensions below the u32 boundary."
154 )]
155 ElementCountOverflow {
156 name: String,
158 shape: Vec<u32>,
160 },
161}
162
163pub fn check_unique_names(refs: &[&TensorRef], op: &'static str) -> Result<(), TensorRefError> {
166 for (idx, t) in refs.iter().enumerate() {
167 if refs[..idx]
168 .iter()
169 .any(|previous| previous.name_str() == t.name_str())
170 {
171 return Err(TensorRefError::NameCollision {
172 name: t.name.as_str().to_string(),
173 op,
174 });
175 }
176 }
177 Ok(())
178}
179
180pub fn check_dtype(
183 r: &TensorRef,
184 expected: DataType,
185 op: &'static str,
186) -> Result<(), TensorRefError> {
187 if r.dtype != expected {
188 return Err(TensorRefError::DtypeMismatch {
189 name: r.name.as_str().to_string(),
190 found: r.dtype.clone(),
191 expected,
192 op,
193 });
194 }
195 Ok(())
196}
197
198pub fn check_shape(
201 r: &TensorRef,
202 expected: &[u32],
203 op: &'static str,
204) -> Result<(), TensorRefError> {
205 if r.shape.as_ref() != expected {
206 return Err(TensorRefError::ShapeMismatch {
207 name: r.name.as_str().to_string(),
208 found: r.shape.to_vec(),
209 expected: expected.to_vec(),
210 op,
211 });
212 }
213 Ok(())
214}
215
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn u32_1d_builder_produces_expected_fields() {
        let t = TensorRef::u32_1d("x", 64);
        assert_eq!(t.name.as_str(), "x");
        assert_eq!(t.dtype, DataType::U32);
        assert_eq!(t.shape.as_ref(), [64]);
        assert_eq!(t.element_count(), Some(64));
    }

    #[test]
    fn other_builders_set_dtype_and_shape() {
        let a = TensorRef::f32_2d("a", 2, 3);
        assert_eq!(a.dtype, DataType::F32);
        assert_eq!(a.shape.as_ref(), [2, 3]);
        assert_eq!(a.element_count(), Some(6));

        let b = TensorRef::f16_1d("b", 5);
        assert_eq!(b.dtype, DataType::F16);
        assert_eq!(b.shape.as_ref(), [5]);

        let c = TensorRef::f16_2d("c", 7, 2);
        assert_eq!(c.dtype, DataType::F16);
        assert_eq!(c.element_count(), Some(14));
    }

    #[test]
    fn element_count_detects_overflow() {
        let t = TensorRef::new("big", DataType::U32, vec![1u32 << 20, 1u32 << 20]);
        assert_eq!(t.element_count(), None);
    }

    #[test]
    fn element_count_of_scalar_and_zero_extent() {
        // A rank-0 shape has exactly one element.
        let scalar = TensorRef::new("s", DataType::U32, vec![]);
        assert_eq!(scalar.element_count(), Some(1));

        let empty = TensorRef::u32_2d("z", 4, 0);
        assert_eq!(empty.element_count(), Some(0));
    }

    #[test]
    fn check_unique_names_catches_collision() {
        let a = TensorRef::u32_1d("x", 4);
        let b = TensorRef::u32_1d("x", 4);
        let err = check_unique_names(&[&a, &b], "test").unwrap_err();
        assert!(matches!(err, TensorRefError::NameCollision { .. }));
    }

    #[test]
    fn check_unique_names_reports_repeated_name_and_op() {
        let a = TensorRef::u32_1d("x", 4);
        let b = TensorRef::u32_1d("y", 4);
        let c = TensorRef::f32_1d("x", 8);
        match check_unique_names(&[&a, &b, &c], "test").unwrap_err() {
            TensorRefError::NameCollision { name, op } => {
                assert_eq!(name, "x");
                assert_eq!(op, "test");
            }
            other => panic!("expected NameCollision, got {:?}", other),
        }
    }

    #[test]
    fn check_unique_names_accepts_distinct_and_empty() {
        let a = TensorRef::u32_1d("x", 4);
        let b = TensorRef::u32_1d("y", 4);
        assert!(check_unique_names(&[&a, &b], "test").is_ok());
        assert!(check_unique_names(&[], "test").is_ok());
    }

    #[test]
    fn check_dtype_passes_on_match() {
        let t = TensorRef::f32_1d("y", 8);
        assert!(matches!(check_dtype(&t, DataType::F32, "op"), Ok(())));
    }

    #[test]
    fn check_dtype_fails_on_mismatch() {
        let t = TensorRef::u32_1d("y", 8);
        let err = check_dtype(&t, DataType::F32, "op").unwrap_err();
        assert!(matches!(err, TensorRefError::DtypeMismatch { .. }));
    }

    #[test]
    fn check_dtype_error_carries_payload_and_fix_hint() {
        let t = TensorRef::u32_1d("y", 8);
        let err = check_dtype(&t, DataType::F32, "op").unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("`y`"));
        assert!(msg.contains("Fix:"));
        match err {
            TensorRefError::DtypeMismatch {
                name,
                found,
                expected,
                op,
            } => {
                assert_eq!(name, "y");
                assert_eq!(found, DataType::U32);
                assert_eq!(expected, DataType::F32);
                assert_eq!(op, "op");
            }
            other => panic!("expected DtypeMismatch, got {:?}", other),
        }
    }

    #[test]
    fn check_shape_passes_and_fails() {
        let t = TensorRef::u32_2d("m", 4, 8);
        assert!(check_shape(&t, &[4, 8], "op").is_ok());
        let err = check_shape(&t, &[4, 16], "op").unwrap_err();
        assert!(matches!(err, TensorRefError::ShapeMismatch { .. }));
    }

    #[test]
    fn check_shape_rejects_rank_mismatch_with_payload() {
        let t = TensorRef::u32_2d("m", 4, 8);
        match check_shape(&t, &[32], "op").unwrap_err() {
            TensorRefError::ShapeMismatch {
                name,
                found,
                expected,
                op,
            } => {
                assert_eq!(name, "m");
                assert_eq!(found, vec![4, 8]);
                assert_eq!(expected, vec![32]);
                assert_eq!(op, "op");
            }
            other => panic!("expected ShapeMismatch, got {:?}", other),
        }
    }
}