Skip to main content

prism_tensor/
shape.rs

1//! Higher-rank tensor `ConstrainedTypeShape` carriers
2//! ([`Tensor3Shape`], [`Tensor4Shape`]).
3//!
4//! Per [Wiki ADR-031][09]'s `Tensor<Element, Shape>` shape commitment,
5//! the tensor sub-crate declares typed shape carriers per common rank:
6//!
7//! - rank-1 → [`VectorShape`][crate::tensor::VectorShape]
8//! - rank-2 → [`MatrixShape`][crate::tensor::MatrixShape]
9//! - rank-3 → [`Tensor3Shape`] *(this module)*
10//! - rank-4 → [`Tensor4Shape`] *(this module)*
11//!
12//! Higher ranks compose through `partition_product!` per ADR-033/044;
13//! the carriers in this module cover the common GGUF / ONNX tensor
14//! ranks (embeddings, attention heads, batched attention) directly.
15//!
//! Per [ADR-017][09]'s closure rule, each carrier shares the generic
//! `https://uor.foundation/type/ConstrainedType` IRI and is
//! content-addressed through `(SITE_COUNT, CONSTRAINTS)`; the distinct
//! Rust types are the application-level ergonomics surface for
//! variable-rank tensor shapes.
21//!
22//! [09]: https://github.com/UOR-Foundation/UOR-Framework/wiki/09-Architecture-Decisions
23
24#![allow(missing_docs)]
25
26use uor_foundation::enforcement::GroundedShape;
27use uor_foundation::pipeline::{ConstrainedTypeShape, ConstraintRef, IntoBindingValue, TermValue};
28
/// Parametric `ConstrainedTypeShape` for a row-major rank-3 tensor of
/// shape `D0 × D1 × D2` carrying `ELEM_BYTES`-byte elements.
///
/// Per ADR-031's `Tensor<Element, Shape>` shape commitment for rank-3.
/// Common GGUF / ONNX usage: per-head attention key / value tensors
/// (`batch × heads × dim`), 3D image volumes, sequence-of-tokens
/// embeddings.
///
/// The carrier is a zero-sized unit struct: every piece of shape
/// information lives in its const parameters, so values are freely
/// `Copy`able and `Default::default()` is the only value there is.
#[derive(Debug, Clone, Copy, Default)]
pub struct Tensor3Shape<const D0: usize, const D1: usize, const D2: usize, const ELEM_BYTES: usize>;
46
47impl<const D0: usize, const D1: usize, const D2: usize, const ELEM_BYTES: usize>
48    ConstrainedTypeShape for Tensor3Shape<D0, D1, D2, ELEM_BYTES>
49{
50    const IRI: &'static str = "https://uor.foundation/type/ConstrainedType";
51    const SITE_COUNT: usize = D0 * D1 * D2 * ELEM_BYTES;
52    const CONSTRAINTS: &'static [ConstraintRef] = &[];
53    #[allow(clippy::cast_possible_truncation)]
54    const CYCLE_SIZE: u64 = 256u64.saturating_pow((D0 * D1 * D2 * ELEM_BYTES) as u32);
55}
56
57impl<const D0: usize, const D1: usize, const D2: usize, const ELEM_BYTES: usize>
58    uor_foundation::pipeline::__sdk_seal::Sealed for Tensor3Shape<D0, D1, D2, ELEM_BYTES>
59{
60}
61impl<const D0: usize, const D1: usize, const D2: usize, const ELEM_BYTES: usize> GroundedShape
62    for Tensor3Shape<D0, D1, D2, ELEM_BYTES>
63{
64}
65impl<'a, const D0: usize, const D1: usize, const D2: usize, const ELEM_BYTES: usize>
66    IntoBindingValue<'a> for Tensor3Shape<D0, D1, D2, ELEM_BYTES>
67{
68    fn as_binding_value<const INLINE_BYTES: usize>(&self) -> TermValue<'a, INLINE_BYTES> {
69        TermValue::empty()
70    }
71}
72
/// Parametric `ConstrainedTypeShape` for a row-major rank-4 tensor of
/// shape `D0 × D1 × D2 × D3` carrying `ELEM_BYTES`-byte elements.
///
/// Per ADR-031's `Tensor<Element, Shape>` shape commitment for rank-4.
/// Common GGUF / ONNX usage: batched multi-head attention
/// (`batch × heads × seq × dim`), 4D conv weight tensors
/// (`out_channels × in_channels × kernel_h × kernel_w`).
///
/// The carrier is a zero-sized unit struct: every piece of shape
/// information lives in its const parameters, so values are freely
/// `Copy`able and `Default::default()` is the only value there is.
#[derive(Debug, Clone, Copy, Default)]
pub struct Tensor4Shape<
    const D0: usize,
    const D1: usize,
    const D2: usize,
    const D3: usize,
    const ELEM_BYTES: usize,
>;
101
102impl<
103        const D0: usize,
104        const D1: usize,
105        const D2: usize,
106        const D3: usize,
107        const ELEM_BYTES: usize,
108    > ConstrainedTypeShape for Tensor4Shape<D0, D1, D2, D3, ELEM_BYTES>
109{
110    const IRI: &'static str = "https://uor.foundation/type/ConstrainedType";
111    const SITE_COUNT: usize = D0 * D1 * D2 * D3 * ELEM_BYTES;
112    const CONSTRAINTS: &'static [ConstraintRef] = &[];
113    #[allow(clippy::cast_possible_truncation)]
114    const CYCLE_SIZE: u64 = 256u64.saturating_pow((D0 * D1 * D2 * D3 * ELEM_BYTES) as u32);
115}
116
117impl<
118        const D0: usize,
119        const D1: usize,
120        const D2: usize,
121        const D3: usize,
122        const ELEM_BYTES: usize,
123    > uor_foundation::pipeline::__sdk_seal::Sealed for Tensor4Shape<D0, D1, D2, D3, ELEM_BYTES>
124{
125}
126impl<
127        const D0: usize,
128        const D1: usize,
129        const D2: usize,
130        const D3: usize,
131        const ELEM_BYTES: usize,
132    > GroundedShape for Tensor4Shape<D0, D1, D2, D3, ELEM_BYTES>
133{
134}
135impl<
136        'a,
137        const D0: usize,
138        const D1: usize,
139        const D2: usize,
140        const D3: usize,
141        const ELEM_BYTES: usize,
142    > IntoBindingValue<'a> for Tensor4Shape<D0, D1, D2, D3, ELEM_BYTES>
143{
144    fn as_binding_value<const INLINE_BYTES: usize>(&self) -> TermValue<'a, INLINE_BYTES> {
145        TermValue::empty()
146    }
147}