Skip to main content

prism_tensor/
activation.rs

1//! `ActivationAxis` declaration + parametric element-wise i8 nonlinearity
2//! reference impls.
3
4#![allow(missing_docs)]
5
6use uor_foundation::enforcement::ShapeViolation;
7use uor_foundation::pipeline::AxisExtension;
8use uor_foundation_sdk::axis;
9
axis! {
    /// Wiki ADR-031 element-wise nonlinearity axis.
    ///
    /// Reference kernels operate on a fixed-length `N`-element `i8`
    /// vector. `relu` clamps negative values to zero. `sigmoid_q` is
    /// the Q1.7 piecewise-linear sigmoid approximation — the canonical
    /// integer-arithmetic determinism contract per ADR-030.
    ///
    /// Buffers are passed as raw bytes; the reference kernels reinterpret
    /// each byte as a two's-complement `i8` element.
    pub trait ActivationAxis: AxisExtension {
        // Axis identifier IRI. Implementations may override it with a more
        // specific address (the CPU reference impl appends `/CpuI8Vector`).
        const AXIS_ADDRESS: &'static str = "https://uor.foundation/axis/ActivationAxis";
        /// Upper bound on the number of bytes a single kernel call writes
        /// to `out`.
        ///
        /// Defaults to 16; implementations override it with their own
        /// vector byte-width.
        const MAX_OUTPUT_BYTES: usize = 16;
        /// Apply ReLU elementwise.
        ///
        /// Reads `input` and writes one activated byte per element to the
        /// front of `out`. On success returns the number of bytes written
        /// (the reference impls return their vector length `N`).
        ///
        /// # Errors
        ///
        /// Returns `ShapeViolation` on input/output length mismatch (and,
        /// in the reference impls, when `N` is outside the supported range).
        fn relu(input: &[u8], out: &mut [u8]) -> Result<usize, ShapeViolation>;
        /// Apply Q1.7 piecewise-linear sigmoid elementwise.
        ///
        /// Reads `input` and writes one activated byte per element to the
        /// front of `out`. On success returns the number of bytes written
        /// (the reference impls return their vector length `N`; every
        /// output byte they produce lies in `0..=127`).
        ///
        /// # Errors
        ///
        /// Returns `ShapeViolation` on input/output length mismatch (and,
        /// in the reference impls, when `N` is outside the supported range).
        fn sigmoid_q(input: &[u8], out: &mut [u8]) -> Result<usize, ShapeViolation>;
    }
}
35
/// Largest vector length `N` accepted by [`CpuI8VectorActivation`].
///
/// Kernels instantiated with `N == 0` or `N` above this ceiling reject
/// every call with a `ShapeViolation` instead of touching the buffers.
pub const MAX_ACTIVATION_LEN: usize = 0x100;
39
40fn arity_violation(constraint: &'static str) -> ShapeViolation {
41    ShapeViolation {
42        shape_iri: "https://uor.foundation/axis/ActivationAxisShape",
43        constraint_iri: constraint,
44        property_iri: "https://uor.foundation/axis/inputBytes",
45        expected_range: "https://uor.foundation/axis/ActivationInputArity",
46        min_count: 0,
47        max_count: 0,
48        kind: uor_foundation::ViolationKind::ValueCheck,
49    }
50}
51
52fn check_lens(input: &[u8], out: &[u8], n: usize) -> Result<(), ShapeViolation> {
53    if input.len() != n {
54        return Err(arity_violation(
55            "https://uor.foundation/axis/ActivationAxisShape/inputByteLength",
56        ));
57    }
58    if out.len() < n {
59        return Err(arity_violation(
60            "https://uor.foundation/axis/ActivationAxisShape/outputByteLength",
61        ));
62    }
63    Ok(())
64}
65
/// Parametric element-wise activation kernels over an `N`-element `i8`
/// vector.
///
/// `N` is the vector length. The same kernels (ReLU, Q1.7 sigmoid) are
/// applied to every element independently; per-element determinism
/// composes to per-vector determinism per ADR-030.
///
/// This is a zero-sized marker type: the kernels are associated functions,
/// so values carry no state. `N` must lie in `1..=MAX_ACTIVATION_LEN`
/// (see [`MAX_ACTIVATION_LEN`]). Out-of-range instantiations still compile,
/// but every kernel call on them returns a `ShapeViolation`.
// `Default` is derived rather than hand-written: for a unit struct the
// derived impl is identical, and const parameters add no trait bounds.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct CpuI8VectorActivation<const N: usize>;
80
81impl<const N: usize> ActivationAxis for CpuI8VectorActivation<N> {
82    const AXIS_ADDRESS: &'static str = "https://uor.foundation/axis/ActivationAxis/CpuI8Vector";
83    const MAX_OUTPUT_BYTES: usize = N;
84
85    fn relu(input: &[u8], out: &mut [u8]) -> Result<usize, ShapeViolation> {
86        if N == 0 || N > MAX_ACTIVATION_LEN {
87            return Err(arity_violation(
88                "https://uor.foundation/axis/ActivationAxisShape/nInRange",
89            ));
90        }
91        check_lens(input, out, N)?;
92        for i in 0..N {
93            #[allow(clippy::cast_possible_wrap)]
94            let v = input[i] as i8;
95            out[i] = if v > 0 { input[i] } else { 0 };
96        }
97        Ok(N)
98    }
99
100    fn sigmoid_q(input: &[u8], out: &mut [u8]) -> Result<usize, ShapeViolation> {
101        if N == 0 || N > MAX_ACTIVATION_LEN {
102            return Err(arity_violation(
103                "https://uor.foundation/axis/ActivationAxisShape/nInRange",
104            ));
105        }
106        check_lens(input, out, N)?;
107        for i in 0..N {
108            #[allow(clippy::cast_possible_wrap)]
109            let x = input[i] as i8;
110            let y: i8 = if x <= -64 {
111                0
112            } else if x >= 64 {
113                127
114            } else {
115                #[allow(clippy::cast_possible_truncation)]
116                {
117                    64i8 + (x / 2)
118                }
119            };
120            #[allow(clippy::cast_sign_loss)]
121            {
122                out[i] = y as u8;
123            }
124        }
125        Ok(N)
126    }
127}
128
// ADR-052 generic-form companion.
// NOTE(review): this macro is presumably emitted by `axis!` alongside
// `ActivationAxis`. The `@generic` arm with the `[const N: usize]` parameter
// list appears to provide the axis-extension glue for every
// `CpuI8VectorActivation<N>` in one invocation — confirm against the macro
// definition.
axis_extension_impl_for_activation_axis!(@generic CpuI8VectorActivation<N>, [const N: usize]);
131
// Convenience aliases for the power-of-two vector lengths from 16 up to the
// ceiling; every length here lies within `1..=MAX_ACTIVATION_LEN`, so none
// of these instantiations trips the kernels' `nInRange` check.
/// 16-element `i8` vector activation (the canonical small-vector reference).
pub type CpuI8VectorActivation16 = CpuI8VectorActivation<16>;
/// 32-element `i8` vector activation.
pub type CpuI8VectorActivation32 = CpuI8VectorActivation<32>;
/// 64-element `i8` vector activation.
pub type CpuI8VectorActivation64 = CpuI8VectorActivation<64>;
/// 128-element `i8` vector activation.
pub type CpuI8VectorActivation128 = CpuI8VectorActivation<128>;
/// 256-element `i8` vector activation (the `MAX_ACTIVATION_LEN` ceiling).
pub type CpuI8VectorActivation256 = CpuI8VectorActivation<256>;