prism_tensor/tensor.rs
1//! `TensorAxis` declaration + parametric square-matmul impl + shape.
2//!
3//! Per [Wiki ADR-031][09-adr-031] the tensor sub-crate exposes
4//! `TensorAxis` as the canonical Layer-3 surface for tensor compute.
5//! The reference impl [`CpuI8MatmulSquare`] is generic over the square
6//! dimension `DIM`, with `i8` inputs and saturating-`i16` outputs —
7//! the integer-arithmetic determinism contract ADR-030 names as the
8//! axis substitution-determinism baseline.
9//!
10//! Variable-rank tensor compute composes through verbs over
11//! `partition_product!`-declared shapes per ADR-033/044; the axis's
12//! role is the fixed-shape atomic primitive.
13//!
14//! # ADR-055 substrate-Term verb body discipline
15//!
16//! Per [Wiki ADR-055](https://github.com/UOR-Foundation/UOR-Framework/wiki/09-Architecture-Decisions)
17//! every `AxisExtension` impl satisfies the substrate-Term verb body
18//! discipline; the hand-written kernel below uses the default empty
19//! `body_arena()` emitted by foundation-sdk 0.4.11's `axis!`
20//! companion macro (the primitive-fast-path-equivalent realization).
21//!
22//! Explicit substrate-Term decomposition of
23//! `CpuI8MatmulSquare<DIM>::matmul` — `fold_n(DIM, ...)` over rows ×
24//! `fold_n(DIM, ...)` over columns × `fold_n(DIM, ...)` over
25//! reductions, with a `sign_extend` sub-verb (matching `Ge(operand,
26//! Literal(0x80, W8))` to select between `Concat(0x00, operand)` and
27//! `Concat(0xff, operand)`) plus W16 `Mul` + W16 `Add` accumulation
28//! plus saturation via `Match` over `Ge(acc, Literal(0x7fff, W16))` /
29//! `Lt(acc, Literal(0x8000, W16))` per ADR-054 § Substrate-Term
30//! realization examples — is **syntactically expressible** in
31//! foundation-sdk 0.4.11's verb-body grammar. ADR-056 admits
32//! `le`/`lt`/`ge`/`gt` and `concat` in verb/axis bodies (only the
33//! route body's syntactic surface retains the ψ-residuals rejection);
34//! foundation-sdk 0.4.11's depth-2 const-generic-leaf partition-product
35//! projection covers the fold-n composition over matrix shapes. The
36//! remaining work is **operational composition**: the architectural
37//! witness verbs in [`crate::verbs`] (saturating-xor + concat-bytes)
38//! demonstrate the per-element primitives; the unfolded
39//! fold-over-rows-and-columns matmul body is a published-roster
40//! follow-on.
41//!
42//! The hand-written `for`-loop kernel below is the operational form;
43//! byte-output equivalence with BLAS reference outputs at integer
44//! precision is checked at `tests/conformance.rs`.
45//!
46//! [09-adr-031]: https://github.com/UOR-Foundation/UOR-Framework/wiki/09-Architecture-Decisions
47//! [09-adr-054]: https://github.com/UOR-Foundation/UOR-Framework/wiki/09-Architecture-Decisions
48
49#![allow(missing_docs)]
50
51use uor_foundation::enforcement::{GroundedShape, ShapeViolation};
52use uor_foundation::pipeline::{
53 AxisExtension, ConstrainedTypeShape, ConstraintRef, IntoBindingValue,
54};
55use uor_foundation_sdk::axis;
56
// Declares the `TensorAxis` trait through the foundation-sdk `axis!` macro.
// NOTE(review): the macro presumably also emits the
// `axis_extension_impl_for_tensor_axis!` companion used later in this file —
// confirm against the foundation-sdk documentation.
axis! {
    /// Wiki ADR-031 tensor-compute axis.
    ///
    /// The reference impl `CpuI8MatmulSquare<DIM>` is parametric in
    /// `DIM` for square `DIM × DIM` `i8` matrices, emitting a `DIM ×
    /// DIM` `i16` product (saturating) per ADR-030's bit-determinism
    /// commitment.
    pub trait TensorAxis: AxisExtension {
        // Default axis IRI; the reference impl in this file overrides it with
        // an impl-specific address.
        const AXIS_ADDRESS: &'static str = "https://uor.foundation/axis/TensorAxis";
        /// Per-impl axis output ceiling. The application's
        /// `HostBounds::AXIS_OUTPUT_BYTES_MAX` (ADR-037) is checked
        /// against this value at dispatch; the axis impl carries no
        /// substrate-arbitrary cap of its own.
        ///
        /// Defaults to 32 bytes; impls may override it (the reference
        /// impl in this file uses `2 * DIM * DIM`).
        const MAX_OUTPUT_BYTES: usize = 32;
        /// Multiply two row-major `DIM × DIM` `i8` matrices into a
        /// `DIM × DIM` `i16` product (saturating). Input is `A || B`
        /// (`2 * DIM * DIM` bytes); output is `2 * DIM * DIM` bytes.
        ///
        /// On success the reference impl in this file returns the number
        /// of output bytes written and encodes each `i16` cell big-endian.
        ///
        /// # Errors
        ///
        /// Returns `ShapeViolation` on input/output byte-length mismatch.
        /// The reference impl in this file additionally rejects
        /// `DIM == 0` and accepts an `out` buffer longer than required.
        fn matmul(input: &[u8], out: &mut [u8]) -> Result<usize, ShapeViolation>;
    }
}
81
82/// Per-impl `MAX_OUTPUT_BYTES` default for `TensorAxis`: the framework
83/// uses `<Impl as TensorAxis>::MAX_OUTPUT_BYTES` together with the
84/// application's [`HostBounds::AXIS_OUTPUT_BYTES_MAX`][hb] to validate
85/// that the application's substrate selection is wide enough for every
86/// axis impl it composes. The dispatch layer enforces the relation
87/// `<Impl as TensorAxis>::MAX_OUTPUT_BYTES <= B::AXIS_OUTPUT_BYTES_MAX`
88/// structurally; the axis impl carries no per-substrate ceiling of
89/// its own.
90///
91/// [hb]: uor_foundation::HostBounds::AXIS_OUTPUT_BYTES_MAX
92fn arity_violation(constraint: &'static str) -> ShapeViolation {
93 ShapeViolation {
94 shape_iri: "https://uor.foundation/axis/TensorAxisShape",
95 constraint_iri: constraint,
96 property_iri: "https://uor.foundation/axis/inputBytes",
97 expected_range: "https://uor.foundation/axis/TensorInputArity",
98 min_count: 0,
99 max_count: 0,
100 kind: uor_foundation::ViolationKind::ValueCheck,
101 }
102}
103
/// Parametric square `DIM × DIM` `i8` × `i8` → `i16` matmul.
///
/// This is a zero-sized marker type: it carries no data, and the kernel is
/// exposed through its `TensorAxis` impl.
///
/// Determinism: per ADR-030's per-axis substitution-determinism note,
/// the integer-arithmetic CPU impl preserves bit-identity across
/// targets. `DIM` is the square dimension; for non-square /
/// non-integer / variable-shape tensor compute the wiki's pattern is
/// to compose this axis kernel through verbs over `partition_product!`
/// (per ADR-033/044) — the axis layer fixes the atom shape.
///
/// # `HostBounds` discipline
///
/// `DIM` is unconstrained at the axis level per [Wiki ADR-018][09].
/// The application's [`HostBounds`][uor_foundation::HostBounds]
/// selection declares the ceiling: a `CpuI8MatmulSquare<DIM>`
/// instantiation requires the application's `B` to satisfy
/// `2 * DIM * DIM <= B::AXIS_OUTPUT_BYTES_MAX` per ADR-037. Specific
/// `DIM` values (4, 8, 16, 32, 64, …) are picked by the application
/// from its declared bounds, not by this crate.
///
/// [09]: https://github.com/UOR-Foundation/UOR-Framework/wiki/09-Architecture-Decisions
// `Default` is derived rather than hand-written: for a unit struct the
// derived impl is identical and cannot drift from the type definition.
#[derive(Debug, Clone, Copy, Default)]
pub struct CpuI8MatmulSquare<const DIM: usize>;

impl<const DIM: usize> CpuI8MatmulSquare<DIM> {
    /// Flat row-major offset of cell `(row, col)` in a `DIM × DIM` matrix.
    ///
    /// No bounds checking is done here; callers index slices with the result.
    const fn idx(row: usize, col: usize) -> usize {
        row * DIM + col
    }
}
138
139impl<const DIM: usize> TensorAxis for CpuI8MatmulSquare<DIM> {
140 const AXIS_ADDRESS: &'static str = "https://uor.foundation/axis/TensorAxis/CpuI8MatmulSquare";
141 const MAX_OUTPUT_BYTES: usize = 2 * DIM * DIM;
142
143 fn matmul(input: &[u8], out: &mut [u8]) -> Result<usize, ShapeViolation> {
144 // Structural well-formedness only — a 0-dimensional matrix is
145 // not a matrix. Capacity ceilings are the application's
146 // `HostBounds::AXIS_OUTPUT_BYTES_MAX` per ADR-037, enforced
147 // structurally at the dispatch layer; no axis-internal cap.
148 if DIM == 0 {
149 return Err(arity_violation(
150 "https://uor.foundation/axis/TensorAxisShape/dimNonZero",
151 ));
152 }
153 let mat_bytes = DIM * DIM;
154 let input_bytes = 2 * mat_bytes;
155 let output_bytes = 2 * mat_bytes;
156 if input.len() != input_bytes {
157 return Err(arity_violation(
158 "https://uor.foundation/axis/TensorAxisShape/inputByteLength",
159 ));
160 }
161 if out.len() < output_bytes {
162 return Err(arity_violation(
163 "https://uor.foundation/axis/TensorAxisShape/outputByteLength",
164 ));
165 }
166 let (a_bytes, b_bytes) = input.split_at(mat_bytes);
167 for row in 0..DIM {
168 for col in 0..DIM {
169 let mut acc: i32 = 0;
170 for k in 0..DIM {
171 #[allow(clippy::cast_possible_wrap)]
172 let a = i32::from(a_bytes[Self::idx(row, k)] as i8);
173 #[allow(clippy::cast_possible_wrap)]
174 let b = i32::from(b_bytes[Self::idx(k, col)] as i8);
175 acc += a * b;
176 }
177 let saturated: i16 = if acc > i32::from(i16::MAX) {
178 i16::MAX
179 } else if acc < i32::from(i16::MIN) {
180 i16::MIN
181 } else {
182 #[allow(clippy::cast_possible_truncation)]
183 {
184 acc as i16
185 }
186 };
187 let cell = Self::idx(row, col);
188 out[2 * cell..2 * cell + 2].copy_from_slice(&saturated.to_be_bytes());
189 }
190 }
191 Ok(output_bytes)
192 }
193}
194
// ADR-052 generic-form companion.
//
// Invokes `axis_extension_impl_for_tensor_axis!` in its `@generic` form,
// passing the impl type and its const-generic parameter list.
// NOTE(review): presumably this macro is emitted by the `axis!` declaration of
// `TensorAxis` and generates the `AxisExtension` impl that the `TensorAxis`
// supertrait bound requires for every `CpuI8MatmulSquare<DIM>` — confirm
// against the foundation-sdk documentation.
axis_extension_impl_for_tensor_axis!(@generic CpuI8MatmulSquare<DIM>, [const DIM: usize]);
197
198// ---- MatrixShape: ConstrainedTypeShape carrier ----
199
/// Parametric ConstrainedTypeShape for a row-major `ROWS × COLS`
/// matrix of `ELEM_BYTES`-byte elements. Per ADR-031's `Tensor<Element,
/// Shape>` shape commitment, restricted to matrix rank-2 here; higher
/// ranks compose through `partition_product!` per ADR-033/044.
///
/// This is a zero-sized, type-level marker: the dimensions live entirely in
/// the const generic parameters.
// `Default` is derived rather than hand-written: for a unit struct the
// derived impl is identical and cannot drift from the type definition.
#[derive(Debug, Clone, Copy, Default)]
pub struct MatrixShape<const ROWS: usize, const COLS: usize, const ELEM_BYTES: usize>;
214
215impl<const ROWS: usize, const COLS: usize, const ELEM_BYTES: usize> ConstrainedTypeShape
216 for MatrixShape<ROWS, COLS, ELEM_BYTES>
217{
218 const IRI: &'static str = "https://uor.foundation/type/ConstrainedType";
219 const SITE_COUNT: usize = ROWS * COLS * ELEM_BYTES;
220 const CONSTRAINTS: &'static [ConstraintRef] = &[];
221 #[allow(clippy::cast_possible_truncation)]
222 const CYCLE_SIZE: u64 = 256u64.saturating_pow((ROWS * COLS * ELEM_BYTES) as u32);
223}
224
// Marker impls with empty bodies: they opt `MatrixShape` into the SDK's
// `Sealed` trait and into `GroundedShape` for every parameter combination.
// NOTE(review): `__sdk_seal::Sealed` is presumably the sealing supertrait
// required by the foundation's shape traits — confirm against `uor_foundation`.
impl<const ROWS: usize, const COLS: usize, const ELEM_BYTES: usize>
    uor_foundation::pipeline::__sdk_seal::Sealed for MatrixShape<ROWS, COLS, ELEM_BYTES>
{
}
impl<const ROWS: usize, const COLS: usize, const ELEM_BYTES: usize> GroundedShape
    for MatrixShape<ROWS, COLS, ELEM_BYTES>
{
}
233impl<const ROWS: usize, const COLS: usize, const ELEM_BYTES: usize> IntoBindingValue
234 for MatrixShape<ROWS, COLS, ELEM_BYTES>
235{
236 const MAX_BYTES: usize = ROWS * COLS * ELEM_BYTES;
237
238 fn into_binding_bytes(&self, _out: &mut [u8]) -> Result<usize, ShapeViolation> {
239 Ok(0)
240 }
241}
242
/// Parametric ConstrainedTypeShape for a length-`N` vector of
/// `ELEM_BYTES`-byte elements. Per ADR-031's `Tensor<Element, Shape>`
/// for rank-1.
///
/// This is a zero-sized, type-level marker: the length and element width live
/// entirely in the const generic parameters.
// `Default` is derived rather than hand-written: for a unit struct the
// derived impl is identical and cannot drift from the type definition.
#[derive(Debug, Clone, Copy, Default)]
pub struct VectorShape<const N: usize, const ELEM_BYTES: usize>;
254
255impl<const N: usize, const ELEM_BYTES: usize> ConstrainedTypeShape for VectorShape<N, ELEM_BYTES> {
256 const IRI: &'static str = "https://uor.foundation/type/ConstrainedType";
257 const SITE_COUNT: usize = N * ELEM_BYTES;
258 const CONSTRAINTS: &'static [ConstraintRef] = &[];
259 #[allow(clippy::cast_possible_truncation)]
260 const CYCLE_SIZE: u64 = 256u64.saturating_pow((N * ELEM_BYTES) as u32);
261}
262
// Marker impls with empty bodies: they opt `VectorShape` into the SDK's
// `Sealed` trait and into `GroundedShape` for every parameter combination.
// NOTE(review): `__sdk_seal::Sealed` is presumably the sealing supertrait
// required by the foundation's shape traits — confirm against `uor_foundation`.
impl<const N: usize, const ELEM_BYTES: usize> uor_foundation::pipeline::__sdk_seal::Sealed
    for VectorShape<N, ELEM_BYTES>
{
}
impl<const N: usize, const ELEM_BYTES: usize> GroundedShape for VectorShape<N, ELEM_BYTES> {}
268impl<const N: usize, const ELEM_BYTES: usize> IntoBindingValue for VectorShape<N, ELEM_BYTES> {
269 const MAX_BYTES: usize = N * ELEM_BYTES;
270
271 fn into_binding_bytes(&self, _out: &mut [u8]) -> Result<usize, ShapeViolation> {
272 Ok(0)
273 }
274}