Skip to main content

vyre_libs/
tensor_ref.rs

//! `TensorRef` — typed buffer-argument wrapper for Cat-A ops.
//!
//! Every Cat-A composition that takes a buffer name as `&str` is a
//! landmine: nothing type-checks `attention(q, k, v, out)` when the
//! caller swaps `q` and `k`. `TensorRef` fixes that by pairing the
//! buffer name with shape + dtype metadata so builders can validate
//! at construction time.
//!
//! The type is intentionally shallow: it carries just enough metadata
//! to catch the most common mistakes (dtype mismatch, shape mismatch,
//! name collision). Full tensor-semantic analysis — broadcasting,
//! stride inference, view lifetimes — belongs in a future
//! `vyre-libs-tensor` layer, but the name + shape + dtype trio here
//! is the frozen API every consumer pins to.
//!
//! **Future-proofing:** `TensorRef` is `#[non_exhaustive]` and its
//! constructor takes `impl Into<…>` so we can add fields without
//! breaking existing call sites.
19
20use std::sync::Arc;
21use vyre::ir::{DataType, Ident};
22
23/// A named, typed, shaped buffer argument passed into a Cat-A op.
24///
25/// Construct with [`TensorRef::new`] or the convenience helpers
26/// (`u32_1d`, `f32_1d`, `u32_2d`, `f32_2d`). Downstream ops consume
27/// `TensorRef`s instead of raw `&str` buffer names so type + shape
28/// checks happen at `build()` time.
29#[derive(Debug, Clone, PartialEq, Eq)]
30#[non_exhaustive]
31pub struct TensorRef {
32    /// Name the buffer is registered under. Matches `BufferDecl::name`.
33    pub name: Ident,
34    /// Element dtype. Enforced against each op's expected dtype set.
35    pub dtype: DataType,
36    /// Logical shape in elements (not bytes). An empty vec is a scalar;
37    /// a 2-element vec is a matrix; etc. Used for shape-mismatch
38    /// detection at build-time.
39    pub shape: Arc<[u32]>,
40}
41
42impl TensorRef {
43    /// Construct an explicit `TensorRef`. Callers prefer the shape
44    /// helpers below unless their shape is computed.
45    #[must_use]
46    pub fn new(name: impl Into<Ident>, dtype: DataType, shape: Vec<u32>) -> Self {
47        Self {
48            name: name.into(),
49            dtype,
50            shape: Arc::from(shape),
51        }
52    }
53
54    /// U32 1-D tensor convenience constructor.
55    #[must_use]
56    pub fn u32_1d(name: impl Into<Ident>, len: u32) -> Self {
57        Self::new(name, DataType::U32, vec![len])
58    }
59
60    /// F32 1-D tensor convenience constructor.
61    #[must_use]
62    pub fn f32_1d(name: impl Into<Ident>, len: u32) -> Self {
63        Self::new(name, DataType::F32, vec![len])
64    }
65
66    /// U32 2-D tensor convenience constructor (rows × cols).
67    #[must_use]
68    pub fn u32_2d(name: impl Into<Ident>, rows: u32, cols: u32) -> Self {
69        Self::new(name, DataType::U32, vec![rows, cols])
70    }
71
72    /// F16 1-D tensor convenience constructor.
73    #[must_use]
74    pub fn f16_1d(name: impl Into<Ident>, len: u32) -> Self {
75        Self::new(name, DataType::F16, vec![len])
76    }
77
78    /// F16 2-D tensor convenience constructor (rows × cols).
79    #[must_use]
80    pub fn f16_2d(name: impl Into<Ident>, rows: u32, cols: u32) -> Self {
81        Self::new(name, DataType::F16, vec![rows, cols])
82    }
83
84    /// F32 2-D tensor convenience constructor (rows × cols).
85    #[must_use]
86    pub fn f32_2d(name: impl Into<Ident>, rows: u32, cols: u32) -> Self {
87        Self::new(name, DataType::F32, vec![rows, cols])
88    }
89
90    /// Total element count. Returns `None` on overflow so builders
91    /// can surface a structured error rather than silent wraparound.
92    #[must_use]
93    pub fn element_count(&self) -> Option<u32> {
94        self.shape
95            .iter()
96            .try_fold(1u32, |acc, &dim| acc.checked_mul(dim))
97    }
98
99    /// Borrow the buffer name as `&str`  -  the form every IR builder
100    /// still accepts. Lets Cat-A ops forward to underlying primitives
101    /// while keeping the typed surface on the boundary.
102    #[must_use]
103    pub fn name_str(&self) -> &str {
104        self.name.as_str()
105    }
106}
107
108/// Error returned when [`TensorRef`] arguments fail a builder's
109/// dtype / shape / name-uniqueness check.
110#[derive(Debug, Clone, thiserror::Error)]
111#[non_exhaustive]
112pub enum TensorRefError {
113    /// Dtype doesn't match what the op expects.
114    #[error(
115        "TensorRef `{name}` has dtype `{found:?}`; op `{op}` expects `{expected:?}`. Fix: pass a buffer of the correct dtype or cast."
116    )]
117    DtypeMismatch {
118        /// Tensor name that failed.
119        name: String,
120        /// Caller-provided dtype.
121        found: DataType,
122        /// Dtype the op requires.
123        expected: DataType,
124        /// Op id for the failing builder.
125        op: &'static str,
126    },
127    /// Shape doesn't match what the op expects.
128    #[error(
129        "TensorRef `{name}` has shape {found:?}; op `{op}` expects {expected:?}. Fix: reshape or pick a compatible op variant."
130    )]
131    ShapeMismatch {
132        /// Tensor name that failed.
133        name: String,
134        /// Caller-provided shape.
135        found: Vec<u32>,
136        /// Shape the op requires.
137        expected: Vec<u32>,
138        /// Op id for the failing builder.
139        op: &'static str,
140    },
141    /// Two TensorRef args resolve to the same buffer name.
142    #[error(
143        "TensorRef name collision in op `{op}`: `{name}` appears on multiple arguments. Fix: use distinct buffer names per argument."
144    )]
145    NameCollision {
146        /// The duplicated buffer name.
147        name: String,
148        /// Op id for the failing builder.
149        op: &'static str,
150    },
151    /// Total element count overflows u32.
152    #[error(
153        "TensorRef `{name}` element-count overflows u32 for shape {shape:?}. Fix: reduce dimensions below the u32 boundary."
154    )]
155    ElementCountOverflow {
156        /// Tensor name.
157        name: String,
158        /// Shape that overflowed.
159        shape: Vec<u32>,
160    },
161}
162
163/// Verify that every name in `refs` is unique. Returns
164/// [`TensorRefError::NameCollision`] on the first duplicate.
165pub fn check_unique_names(refs: &[&TensorRef], op: &'static str) -> Result<(), TensorRefError> {
166    for (idx, t) in refs.iter().enumerate() {
167        if refs[..idx]
168            .iter()
169            .any(|previous| previous.name_str() == t.name_str())
170        {
171            return Err(TensorRefError::NameCollision {
172                name: t.name.as_str().to_string(),
173                op,
174            });
175        }
176    }
177    Ok(())
178}
179
180/// Verify a TensorRef matches the expected dtype; returns
181/// [`TensorRefError::DtypeMismatch`] on mismatch.
182pub fn check_dtype(
183    r: &TensorRef,
184    expected: DataType,
185    op: &'static str,
186) -> Result<(), TensorRefError> {
187    if r.dtype != expected {
188        return Err(TensorRefError::DtypeMismatch {
189            name: r.name.as_str().to_string(),
190            found: r.dtype.clone(),
191            expected,
192            op,
193        });
194    }
195    Ok(())
196}
197
198/// Verify a TensorRef matches the expected shape; returns
199/// [`TensorRefError::ShapeMismatch`] on mismatch.
200pub fn check_shape(
201    r: &TensorRef,
202    expected: &[u32],
203    op: &'static str,
204) -> Result<(), TensorRefError> {
205    if r.shape.as_ref() != expected {
206        return Err(TensorRefError::ShapeMismatch {
207            name: r.name.as_str().to_string(),
208            found: r.shape.to_vec(),
209            expected: expected.to_vec(),
210            op,
211        });
212    }
213    Ok(())
214}
215
#[cfg(test)]
mod tests {
    use super::*;

    /// The 1-D `u32` helper populates name, dtype, shape, and count.
    #[test]
    fn u32_1d_builder_produces_expected_fields() {
        let tensor = TensorRef::u32_1d("x", 64);
        assert_eq!(tensor.name.as_str(), "x");
        assert_eq!(tensor.dtype, DataType::U32);
        assert_eq!(tensor.shape.as_ref(), [64]);
        assert_eq!(tensor.element_count(), Some(64));
    }

    /// 2^20 * 2^20 does not fit in a `u32`, so the count is `None`.
    #[test]
    fn element_count_detects_overflow() {
        let side = 1u32 << 20;
        let huge = TensorRef::new("big", DataType::U32, vec![side, side]);
        assert_eq!(huge.element_count(), None);
    }

    /// Two arguments with the same name are reported as a collision.
    #[test]
    fn check_unique_names_catches_collision() {
        let first = TensorRef::u32_1d("x", 4);
        let second = TensorRef::u32_1d("x", 4);
        let outcome = check_unique_names(&[&first, &second], "test");
        assert!(matches!(
            outcome,
            Err(TensorRefError::NameCollision { .. })
        ));
    }

    /// A matching dtype passes the check.
    #[test]
    fn check_dtype_passes_on_match() {
        let tensor = TensorRef::f32_1d("y", 8);
        assert!(check_dtype(&tensor, DataType::F32, "op").is_ok());
    }

    /// A differing dtype is reported as a mismatch.
    #[test]
    fn check_dtype_fails_on_mismatch() {
        let tensor = TensorRef::u32_1d("y", 8);
        let outcome = check_dtype(&tensor, DataType::F32, "op");
        assert!(matches!(
            outcome,
            Err(TensorRefError::DtypeMismatch { .. })
        ));
    }

    /// An identical shape passes; a differing column count fails.
    #[test]
    fn check_shape_passes_and_fails() {
        let matrix = TensorRef::u32_2d("m", 4, 8);
        assert!(matches!(check_shape(&matrix, &[4, 8], "op"), Ok(())));
        assert!(matches!(
            check_shape(&matrix, &[4, 16], "op"),
            Err(TensorRefError::ShapeMismatch { .. })
        ));
    }
}