Skip to main content

prism_tensor/
shape.rs

1//! Higher-rank tensor `ConstrainedTypeShape` carriers
2//! ([`Tensor3Shape`], [`Tensor4Shape`]).
3//!
4//! Per [Wiki ADR-031][09]'s `Tensor<Element, Shape>` shape commitment,
5//! the tensor sub-crate declares typed shape carriers per common rank:
6//!
7//! - rank-1 → [`VectorShape`][crate::tensor::VectorShape]
8//! - rank-2 → [`MatrixShape`][crate::tensor::MatrixShape]
9//! - rank-3 → [`Tensor3Shape`] *(this module)*
10//! - rank-4 → [`Tensor4Shape`] *(this module)*
11//!
12//! Higher ranks compose through `partition_product!` per ADR-033/044;
13//! the carriers in this module cover the common GGUF / ONNX tensor
14//! ranks (embeddings, attention heads, batched attention) directly.
15//!
16//! Per [ADR-017][09]'s closure rule each carrier shares the generic
17//! `https://uor.foundation/type/ConstrainedType` IRI and content-
18//! addresses through `(SITE_COUNT, CONSTRAINTS)`; the Rust-type
19//! distinction is the application-level ergonomics surface for
20//! variable-rank tensor shapes.
21//!
22//! [09]: https://github.com/UOR-Foundation/UOR-Framework/wiki/09-Architecture-Decisions
23
24#![allow(missing_docs)]
25
26use uor_foundation::enforcement::{GroundedShape, ShapeViolation};
27use uor_foundation::pipeline::{ConstrainedTypeShape, ConstraintRef, IntoBindingValue};
28
/// Zero-sized, parametric [`ConstrainedTypeShape`] carrier for a row-major
/// rank-3 tensor whose extent is `D0 × D1 × D2` elements of `ELEM_BYTES`
/// bytes each.
///
/// This is the rank-3 instance of ADR-031's `Tensor<Element, Shape>` shape
/// commitment. Typical GGUF / ONNX uses: per-head attention key / value
/// tensors (`batch × heads × dim`), 3D image volumes, and sequence-of-tokens
/// embeddings.
#[derive(Clone, Copy, Debug)]
pub struct Tensor3Shape<
    const D0: usize,
    const D1: usize,
    const D2: usize,
    const ELEM_BYTES: usize,
>;
38
39impl<const D0: usize, const D1: usize, const D2: usize, const ELEM_BYTES: usize> Default
40    for Tensor3Shape<D0, D1, D2, ELEM_BYTES>
41{
42    fn default() -> Self {
43        Self
44    }
45}
46
47impl<const D0: usize, const D1: usize, const D2: usize, const ELEM_BYTES: usize>
48    ConstrainedTypeShape for Tensor3Shape<D0, D1, D2, ELEM_BYTES>
49{
50    const IRI: &'static str = "https://uor.foundation/type/ConstrainedType";
51    const SITE_COUNT: usize = D0 * D1 * D2 * ELEM_BYTES;
52    const CONSTRAINTS: &'static [ConstraintRef] = &[];
53    #[allow(clippy::cast_possible_truncation)]
54    const CYCLE_SIZE: u64 = 256u64.saturating_pow((D0 * D1 * D2 * ELEM_BYTES) as u32);
55}
56
57impl<const D0: usize, const D1: usize, const D2: usize, const ELEM_BYTES: usize>
58    uor_foundation::pipeline::__sdk_seal::Sealed for Tensor3Shape<D0, D1, D2, ELEM_BYTES>
59{
60}
61impl<const D0: usize, const D1: usize, const D2: usize, const ELEM_BYTES: usize> GroundedShape
62    for Tensor3Shape<D0, D1, D2, ELEM_BYTES>
63{
64}
65impl<const D0: usize, const D1: usize, const D2: usize, const ELEM_BYTES: usize> IntoBindingValue
66    for Tensor3Shape<D0, D1, D2, ELEM_BYTES>
67{
68    const MAX_BYTES: usize = D0 * D1 * D2 * ELEM_BYTES;
69
70    fn into_binding_bytes(&self, _out: &mut [u8]) -> Result<usize, ShapeViolation> {
71        Ok(0)
72    }
73}
74
/// Zero-sized, parametric [`ConstrainedTypeShape`] carrier for a row-major
/// rank-4 tensor whose extent is `D0 × D1 × D2 × D3` elements of
/// `ELEM_BYTES` bytes each.
///
/// This is the rank-4 instance of ADR-031's `Tensor<Element, Shape>` shape
/// commitment. Typical GGUF / ONNX uses: batched multi-head attention
/// (`batch × heads × seq × dim`) and 4D convolution weights
/// (`out_channels × in_channels × kernel_h × kernel_w`).
#[derive(Clone, Copy, Debug)]
pub struct Tensor4Shape<
    const D0: usize,
    const D1: usize,
    const D2: usize,
    const D3: usize,
    const ELEM_BYTES: usize,
>;
90
91impl<
92        const D0: usize,
93        const D1: usize,
94        const D2: usize,
95        const D3: usize,
96        const ELEM_BYTES: usize,
97    > Default for Tensor4Shape<D0, D1, D2, D3, ELEM_BYTES>
98{
99    fn default() -> Self {
100        Self
101    }
102}
103
104impl<
105        const D0: usize,
106        const D1: usize,
107        const D2: usize,
108        const D3: usize,
109        const ELEM_BYTES: usize,
110    > ConstrainedTypeShape for Tensor4Shape<D0, D1, D2, D3, ELEM_BYTES>
111{
112    const IRI: &'static str = "https://uor.foundation/type/ConstrainedType";
113    const SITE_COUNT: usize = D0 * D1 * D2 * D3 * ELEM_BYTES;
114    const CONSTRAINTS: &'static [ConstraintRef] = &[];
115    #[allow(clippy::cast_possible_truncation)]
116    const CYCLE_SIZE: u64 = 256u64.saturating_pow((D0 * D1 * D2 * D3 * ELEM_BYTES) as u32);
117}
118
119impl<
120        const D0: usize,
121        const D1: usize,
122        const D2: usize,
123        const D3: usize,
124        const ELEM_BYTES: usize,
125    > uor_foundation::pipeline::__sdk_seal::Sealed for Tensor4Shape<D0, D1, D2, D3, ELEM_BYTES>
126{
127}
128impl<
129        const D0: usize,
130        const D1: usize,
131        const D2: usize,
132        const D3: usize,
133        const ELEM_BYTES: usize,
134    > GroundedShape for Tensor4Shape<D0, D1, D2, D3, ELEM_BYTES>
135{
136}
137impl<
138        const D0: usize,
139        const D1: usize,
140        const D2: usize,
141        const D3: usize,
142        const ELEM_BYTES: usize,
143    > IntoBindingValue for Tensor4Shape<D0, D1, D2, D3, ELEM_BYTES>
144{
145    const MAX_BYTES: usize = D0 * D1 * D2 * D3 * ELEM_BYTES;
146
147    fn into_binding_bytes(&self, _out: &mut [u8]) -> Result<usize, ShapeViolation> {
148        Ok(0)
149    }
150}