prism_tensor/activation.rs
//! `ActivationAxis` declaration plus parametric, element-wise `i8`
//! nonlinearity reference implementations.
3
4#![allow(missing_docs)]
5
6use uor_foundation::enforcement::ShapeViolation;
7use uor_foundation::pipeline::AxisExtension;
8use uor_foundation_sdk::axis;
9
10axis! {
11 /// Wiki ADR-031 element-wise nonlinearity axis.
12 ///
13 /// Reference kernels operate on a fixed-length `N`-element `i8`
14 /// vector. `relu` clamps negative values to zero. `sigmoid_q` is
15 /// the Q1.7 piecewise-linear sigmoid approximation — the canonical
16 /// integer-arithmetic determinism contract per ADR-030.
17 pub trait ActivationAxis: AxisExtension {
18 const AXIS_ADDRESS: &'static str = "https://uor.foundation/axis/ActivationAxis";
19 /// Per-impl axis output ceiling. The application's
20 /// `HostBounds::AXIS_OUTPUT_BYTES_MAX` (ADR-037) is checked
21 /// against this value at dispatch; the axis impl carries no
22 /// substrate-arbitrary cap of its own.
23 const MAX_OUTPUT_BYTES: usize = 16;
24 /// Apply ReLU elementwise.
25 ///
26 /// # Errors
27 ///
28 /// Returns `ShapeViolation` on input/output length mismatch.
29 fn relu(input: &[u8], out: &mut [u8]) -> Result<usize, ShapeViolation>;
30 /// Apply Q1.7 piecewise-linear sigmoid elementwise.
31 ///
32 /// # Errors
33 ///
34 /// Returns `ShapeViolation` on input/output length mismatch.
35 fn sigmoid_q(input: &[u8], out: &mut [u8]) -> Result<usize, ShapeViolation>;
36 }
37}
38
39fn arity_violation(constraint: &'static str) -> ShapeViolation {
40 ShapeViolation {
41 shape_iri: "https://uor.foundation/axis/ActivationAxisShape",
42 constraint_iri: constraint,
43 property_iri: "https://uor.foundation/axis/inputBytes",
44 expected_range: "https://uor.foundation/axis/ActivationInputArity",
45 min_count: 0,
46 max_count: 0,
47 kind: uor_foundation::ViolationKind::ValueCheck,
48 }
49}
50
51fn check_lens(input: &[u8], out: &[u8], n: usize) -> Result<(), ShapeViolation> {
52 if input.len() != n {
53 return Err(arity_violation(
54 "https://uor.foundation/axis/ActivationAxisShape/inputByteLength",
55 ));
56 }
57 if out.len() < n {
58 return Err(arity_violation(
59 "https://uor.foundation/axis/ActivationAxisShape/outputByteLength",
60 ));
61 }
62 Ok(())
63}
64
/// Parametric element-wise activation kernels over an `N`-element `i8`
/// vector.
///
/// `N` is the vector length. The same kernels (ReLU, Q1.7 sigmoid) are
/// applied to every element independently; per-element determinism
/// composes to per-vector determinism per ADR-030. Elements travel as raw
/// bytes: each `u8` is the two's-complement bit pattern of one `i8`.
///
/// # `HostBounds` discipline
///
/// `N` is unconstrained at the axis level per [Wiki ADR-018][09]. The
/// application's [`HostBounds`][uor_foundation::HostBounds] selection
/// declares the ceiling: a `CpuI8VectorActivation<N>` instantiation
/// requires the application's `B` to satisfy
/// `N <= B::AXIS_OUTPUT_BYTES_MAX` per ADR-037. Specific `N` values
/// (16, 32, 64, 128, 256, …) are picked by the application from its
/// declared bounds, not by this crate.
///
/// `N == 0` type-checks, but every kernel call on such an instantiation
/// fails with a `ShapeViolation`, because a zero-length vector is not a
/// vector.
///
/// [09]: https://github.com/UOR-Foundation/UOR-Framework/wiki/09-Architecture-Decisions
// `Default` is derived rather than hand-written: for a unit struct the
// derive yields exactly `Self`, and a const parameter adds no bounds.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct CpuI8VectorActivation<const N: usize>;
91
92impl<const N: usize> ActivationAxis for CpuI8VectorActivation<N> {
93 const AXIS_ADDRESS: &'static str = "https://uor.foundation/axis/ActivationAxis/CpuI8Vector";
94 const MAX_OUTPUT_BYTES: usize = N;
95
96 fn relu(input: &[u8], out: &mut [u8]) -> Result<usize, ShapeViolation> {
97 // Structural well-formedness only — a zero-length vector is
98 // not a vector. Capacity ceilings are the application's
99 // `HostBounds::AXIS_OUTPUT_BYTES_MAX` per ADR-037, enforced
100 // structurally at the dispatch layer.
101 if N == 0 {
102 return Err(arity_violation(
103 "https://uor.foundation/axis/ActivationAxisShape/nNonZero",
104 ));
105 }
106 check_lens(input, out, N)?;
107 for i in 0..N {
108 #[allow(clippy::cast_possible_wrap)]
109 let v = input[i] as i8;
110 out[i] = if v > 0 { input[i] } else { 0 };
111 }
112 Ok(N)
113 }
114
115 fn sigmoid_q(input: &[u8], out: &mut [u8]) -> Result<usize, ShapeViolation> {
116 if N == 0 {
117 return Err(arity_violation(
118 "https://uor.foundation/axis/ActivationAxisShape/nNonZero",
119 ));
120 }
121 check_lens(input, out, N)?;
122 for i in 0..N {
123 #[allow(clippy::cast_possible_wrap)]
124 let x = input[i] as i8;
125 let y: i8 = if x <= -64 {
126 0
127 } else if x >= 64 {
128 127
129 } else {
130 #[allow(clippy::cast_possible_truncation)]
131 {
132 64i8 + (x / 2)
133 }
134 };
135 #[allow(clippy::cast_sign_loss)]
136 {
137 out[i] = y as u8;
138 }
139 }
140 Ok(N)
141 }
142}
143
144// ADR-052 generic-form companion.
145axis_extension_impl_for_activation_axis!(@generic CpuI8VectorActivation<N>, [const N: usize]);