Skip to main content

prism_tensor/
tensor.rs

1//! `TensorAxis` declaration + parametric square-matmul impl + shape.
2//!
3//! Per [Wiki ADR-031][09-adr-031] the tensor sub-crate exposes
4//! `TensorAxis` as the canonical Layer-3 surface for tensor compute.
5//! The reference impl [`CpuI8MatmulSquare`] is generic over the square
6//! dimension `DIM`, with `i8` inputs and saturating-`i16` outputs —
7//! the integer-arithmetic determinism contract ADR-030 names as the
8//! axis substitution-determinism baseline.
9//!
10//! Variable-rank tensor compute composes through verbs over
11//! `partition_product!`-declared shapes per ADR-033/044; the axis's
12//! role is the fixed-shape atomic primitive.
13//!
14//! # ADR-055 substrate-Term verb body discipline
15//!
16//! Per [Wiki ADR-055](https://github.com/UOR-Foundation/UOR-Framework/wiki/09-Architecture-Decisions)
17//! every `AxisExtension` impl satisfies the substrate-Term verb body
18//! discipline; the hand-written kernel below uses the default empty
19//! `body_arena()` emitted by foundation-sdk 0.4.11's `axis!`
20//! companion macro (the primitive-fast-path-equivalent realization).
21//!
22//! Explicit substrate-Term decomposition of
23//! `CpuI8MatmulSquare<DIM>::matmul` — `fold_n(DIM, ...)` over rows ×
24//! `fold_n(DIM, ...)` over columns × `fold_n(DIM, ...)` over
25//! reductions, with a `sign_extend` sub-verb (matching `Ge(operand,
26//! Literal(0x80, W8))` to select between `Concat(0x00, operand)` and
27//! `Concat(0xff, operand)`) plus W16 `Mul` + W16 `Add` accumulation
28//! plus saturation via `Match` over `Ge(acc, Literal(0x7fff, W16))` /
//! `Lt(acc, Literal(0x8000, W16))` per [ADR-054][09-adr-054] § Substrate-Term
30//! realization examples — is **syntactically expressible** in
31//! foundation-sdk 0.4.11's verb-body grammar. ADR-056 admits
32//! `le`/`lt`/`ge`/`gt` and `concat` in verb/axis bodies (only the
33//! route body's syntactic surface retains the ψ-residuals rejection);
34//! foundation-sdk 0.4.11's depth-2 const-generic-leaf partition-product
35//! projection covers the fold-n composition over matrix shapes. The
36//! remaining work is **operational composition**: the architectural
37//! witness verbs in [`crate::verbs`] (saturating-xor + concat-bytes)
38//! demonstrate the per-element primitives; the unfolded
39//! fold-over-rows-and-columns matmul body is a published-roster
40//! follow-on.
41//!
42//! The hand-written `for`-loop kernel below is the operational form;
43//! byte-output equivalence with BLAS reference outputs at integer
44//! precision is checked at `tests/conformance.rs`.
45//!
46//! [09-adr-031]: https://github.com/UOR-Foundation/UOR-Framework/wiki/09-Architecture-Decisions
47//! [09-adr-054]: https://github.com/UOR-Foundation/UOR-Framework/wiki/09-Architecture-Decisions
48
49#![allow(missing_docs)]
50
51use uor_foundation::enforcement::{GroundedShape, ShapeViolation};
52use uor_foundation::pipeline::{ConstrainedTypeShape, ConstraintRef, IntoBindingValue, TermValue};
53use uor_foundation_sdk::axis;
54
axis! {
    /// Wiki ADR-031 tensor-compute axis.
    ///
    /// The reference impl `CpuI8MatmulSquare<DIM>` is parametric in
    /// `DIM` for square `DIM × DIM` `i8` matrices, emitting a `DIM ×
    /// DIM` `i16` product (saturating) per ADR-030's bit-determinism
    /// commitment.
    pub trait TensorAxis: AxisExtension {
        // IRI of the axis itself. Impls override it with their own
        // address (the reference impl appends `/CpuI8MatmulSquare`).
        const AXIS_ADDRESS: &'static str = "https://uor.foundation/axis/TensorAxis";
        /// Per-impl structural output-byte hint
        /// (`<Impl as TensorAxis>::MAX_OUTPUT_BYTES`). Per ADR-060 the
        /// foundation derives carrier widths from the application's
        /// `HostBounds` structural-count primitives; the axis impl
        /// carries no substrate-arbitrary byte-width cap.
        ///
        /// The default `32` equals the `2 * DIM * DIM` output of a
        /// `4 × 4` matmul; the reference impl overrides it with
        /// `2 * DIM * DIM` for its own `DIM`.
        const MAX_OUTPUT_BYTES: usize = 32;
        /// Multiply two row-major `DIM × DIM` `i8` matrices into a
        /// `DIM × DIM` `i16` product (saturating). Input is `A || B`
        /// (`2 * DIM * DIM` bytes); output is `2 * DIM * DIM` bytes.
        ///
        /// Returns the number of bytes written to `out`. The reference
        /// impl accepts an `out` longer than `2 * DIM * DIM` bytes,
        /// writes only that prefix, and encodes each `i16` cell as two
        /// big-endian bytes in row-major order.
        ///
        /// # Errors
        ///
        /// Returns `ShapeViolation` on input/output byte-length mismatch.
        /// The reference impl also rejects `DIM == 0`.
        fn matmul(input: &[u8], out: &mut [u8]) -> Result<usize, ShapeViolation>;
    }
}
80
81fn arity_violation(constraint: &'static str) -> ShapeViolation {
82    ShapeViolation {
83        shape_iri: "https://uor.foundation/axis/TensorAxisShape",
84        constraint_iri: constraint,
85        property_iri: "https://uor.foundation/axis/inputBytes",
86        expected_range: "https://uor.foundation/axis/TensorInputArity",
87        min_count: 0,
88        max_count: 0,
89        kind: uor_foundation::ViolationKind::ValueCheck,
90    }
91}
92
/// Parametric square `DIM × DIM` `i8` × `i8` → `i16` matmul.
///
/// Determinism: per ADR-030's per-axis substitution-determinism note,
/// the integer-arithmetic CPU impl preserves bit-identity across
/// targets. `DIM` is the square dimension; for non-square /
/// non-integer / variable-shape tensor compute the wiki's pattern is
/// to compose this axis kernel through verbs over `partition_product!`
/// (per ADR-033/044) — the axis layer fixes the atom shape.
///
/// The type is a zero-sized marker: all behavior lives in its
/// `TensorAxis` impl, so any two values with the same `DIM` are
/// interchangeable (hence the `PartialEq`/`Eq`/`Hash` derives).
///
/// # `HostBounds` discipline
///
/// `DIM` is unconstrained at the type level. `CpuI8MatmulSquare<0>` can be
/// named, but its `matmul` rejects every call with a `dimNonZero` shape
/// violation. Per [Wiki ADR-060][09] the foundation removed the
/// `AXIS_OUTPUT_BYTES_MAX` cap: a `CpuI8MatmulSquare<DIM>` kernel's
/// `2 * DIM * DIM`-byte output flows through the source-polymorphic
/// `TermValue` carrier, whose widths derive from the application's
/// [`HostBounds`][uor_foundation::HostBounds] structural-count primitives
/// via foundation `const fn`s — never a pinned byte-width literal.
/// Specific `DIM` values (4, 8, 16, 32, 64, …) are picked by the
/// application; this crate imposes no ceiling.
///
/// [09]: https://github.com/UOR-Foundation/UOR-Framework/wiki/09-Architecture-Decisions
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct CpuI8MatmulSquare<const DIM: usize>;

impl<const DIM: usize> CpuI8MatmulSquare<DIM> {
    /// Row-major offset of element `(row, col)` within one `DIM × DIM`
    /// matrix (one byte or one cell per element).
    const fn idx(row: usize, col: usize) -> usize {
        row * DIM + col
    }
}
128
129impl<const DIM: usize> TensorAxis for CpuI8MatmulSquare<DIM> {
130    const AXIS_ADDRESS: &'static str = "https://uor.foundation/axis/TensorAxis/CpuI8MatmulSquare";
131    const MAX_OUTPUT_BYTES: usize = 2 * DIM * DIM;
132
133    fn matmul(input: &[u8], out: &mut [u8]) -> Result<usize, ShapeViolation> {
134        // Structural well-formedness only — a 0-dimensional matrix is
135        // not a matrix. Per ADR-060 there is no byte-width cap; the
136        // output flows through the source-polymorphic `TermValue`
137        // carrier sized from the application's `HostBounds` primitives.
138        if DIM == 0 {
139            return Err(arity_violation(
140                "https://uor.foundation/axis/TensorAxisShape/dimNonZero",
141            ));
142        }
143        let mat_bytes = DIM * DIM;
144        let input_bytes = 2 * mat_bytes;
145        let output_bytes = 2 * mat_bytes;
146        if input.len() != input_bytes {
147            return Err(arity_violation(
148                "https://uor.foundation/axis/TensorAxisShape/inputByteLength",
149            ));
150        }
151        if out.len() < output_bytes {
152            return Err(arity_violation(
153                "https://uor.foundation/axis/TensorAxisShape/outputByteLength",
154            ));
155        }
156        let (a_bytes, b_bytes) = input.split_at(mat_bytes);
157        for row in 0..DIM {
158            for col in 0..DIM {
159                let mut acc: i32 = 0;
160                for k in 0..DIM {
161                    #[allow(clippy::cast_possible_wrap)]
162                    let a = i32::from(a_bytes[Self::idx(row, k)] as i8);
163                    #[allow(clippy::cast_possible_wrap)]
164                    let b = i32::from(b_bytes[Self::idx(k, col)] as i8);
165                    acc += a * b;
166                }
167                let saturated: i16 = if acc > i32::from(i16::MAX) {
168                    i16::MAX
169                } else if acc < i32::from(i16::MIN) {
170                    i16::MIN
171                } else {
172                    #[allow(clippy::cast_possible_truncation)]
173                    {
174                        acc as i16
175                    }
176                };
177                let cell = Self::idx(row, col);
178                out[2 * cell..2 * cell + 2].copy_from_slice(&saturated.to_be_bytes());
179            }
180        }
181        Ok(output_bytes)
182    }
183}
184
// ADR-052 generic-form companion. `TensorAxis: AxisExtension`, so every
// `CpuI8MatmulSquare<DIM>` also needs an `AxisExtension` impl. This macro
// (presumably generated by the `axis!` invocation for `TensorAxis`, judging
// by its name) emits that impl generically over `const DIM: usize`. Per the
// module docs, the kernel relies on the default empty `body_arena()`.
axis_extension_impl_for_tensor_axis!(@generic CpuI8MatmulSquare<DIM>, [const DIM: usize]);
187
188// ---- MatrixShape: ConstrainedTypeShape carrier ----
189
/// Parametric `ConstrainedTypeShape` for a row-major `ROWS × COLS`
/// matrix of `ELEM_BYTES`-byte elements. Per ADR-031's `Tensor<Element,
/// Shape>` shape commitment, restricted to matrix rank-2 here; higher
/// ranks compose through `partition_product!` per ADR-033/044.
///
/// A zero-sized marker: the shape is carried entirely by the const
/// parameters, so any two values of the same instantiation compare equal.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct MatrixShape<const ROWS: usize, const COLS: usize, const ELEM_BYTES: usize>;
204
205impl<const ROWS: usize, const COLS: usize, const ELEM_BYTES: usize> ConstrainedTypeShape
206    for MatrixShape<ROWS, COLS, ELEM_BYTES>
207{
208    const IRI: &'static str = "https://uor.foundation/type/ConstrainedType";
209    const SITE_COUNT: usize = ROWS * COLS * ELEM_BYTES;
210    const CONSTRAINTS: &'static [ConstraintRef] = &[];
211    #[allow(clippy::cast_possible_truncation)]
212    const CYCLE_SIZE: u64 = 256u64.saturating_pow((ROWS * COLS * ELEM_BYTES) as u32);
213}
214
215impl<const ROWS: usize, const COLS: usize, const ELEM_BYTES: usize>
216    uor_foundation::pipeline::__sdk_seal::Sealed for MatrixShape<ROWS, COLS, ELEM_BYTES>
217{
218}
219impl<const ROWS: usize, const COLS: usize, const ELEM_BYTES: usize> GroundedShape
220    for MatrixShape<ROWS, COLS, ELEM_BYTES>
221{
222}
223impl<'a, const ROWS: usize, const COLS: usize, const ELEM_BYTES: usize> IntoBindingValue<'a>
224    for MatrixShape<ROWS, COLS, ELEM_BYTES>
225{
226    fn as_binding_value<const INLINE_BYTES: usize>(&self) -> TermValue<'a, INLINE_BYTES> {
227        TermValue::empty()
228    }
229}
230
/// Parametric `ConstrainedTypeShape` for a length-`N` vector of
/// `ELEM_BYTES`-byte elements: ADR-031's `Tensor<Element, Shape>` shape
/// commitment, restricted to rank 1.
///
/// A zero-sized marker: the shape is carried entirely by the const
/// parameters, so any two values of the same instantiation compare equal.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct VectorShape<const N: usize, const ELEM_BYTES: usize>;
242
243impl<const N: usize, const ELEM_BYTES: usize> ConstrainedTypeShape for VectorShape<N, ELEM_BYTES> {
244    const IRI: &'static str = "https://uor.foundation/type/ConstrainedType";
245    const SITE_COUNT: usize = N * ELEM_BYTES;
246    const CONSTRAINTS: &'static [ConstraintRef] = &[];
247    #[allow(clippy::cast_possible_truncation)]
248    const CYCLE_SIZE: u64 = 256u64.saturating_pow((N * ELEM_BYTES) as u32);
249}
250
251impl<const N: usize, const ELEM_BYTES: usize> uor_foundation::pipeline::__sdk_seal::Sealed
252    for VectorShape<N, ELEM_BYTES>
253{
254}
255impl<const N: usize, const ELEM_BYTES: usize> GroundedShape for VectorShape<N, ELEM_BYTES> {}
256impl<'a, const N: usize, const ELEM_BYTES: usize> IntoBindingValue<'a>
257    for VectorShape<N, ELEM_BYTES>
258{
259    fn as_binding_value<const INLINE_BYTES: usize>(&self) -> TermValue<'a, INLINE_BYTES> {
260        TermValue::empty()
261    }
262}