Skip to main content

prism_tensor/
activation.rs

1//! `ActivationAxis` declaration + parametric element-wise i8 nonlinearity
2//! reference impls.
3
4#![allow(missing_docs)]
5
6use uor_foundation::enforcement::ShapeViolation;
7use uor_foundation_sdk::axis;
8
axis! {
    /// Wiki ADR-031 element-wise nonlinearity axis.
    ///
    /// Reference kernels operate on a fixed-length `N`-element `i8`
    /// vector. `relu` clamps negative values to zero. `sigmoid_q` is
    /// the Q1.7 piecewise-linear sigmoid approximation — the canonical
    /// integer-arithmetic determinism contract per ADR-030.
    ///
    /// Both kernels take the vector as raw bytes (`&[u8]`). The CPU
    /// reference impl in this file reads each byte as a two's-complement
    /// `i8` and writes its results back as bytes the same way.
    pub trait ActivationAxis: AxisExtension {
        /// Address of this axis, an `https://uor.foundation/axis/...` IRI
        /// string. Implementations override it with their own address (the
        /// CPU reference impl appends `/CpuI8Vector`).
        const AXIS_ADDRESS: &'static str = "https://uor.foundation/axis/ActivationAxis";
        /// Per-impl structural output-byte hint. Per ADR-060 the
        /// foundation derives carrier widths from the application's
        /// `HostBounds` structural-count primitives; the axis impl
        /// carries no substrate-arbitrary byte-width cap.
        ///
        /// The trait-level default is `16`; the CPU reference impl
        /// overrides it with its vector length `N`.
        const MAX_OUTPUT_BYTES: usize = 16;
        /// Apply ReLU elementwise.
        ///
        /// On success returns `Ok(n)`, presumably the number of bytes
        /// written to `out` (the CPU reference impl returns its vector
        /// length `N`).
        ///
        /// # Errors
        ///
        /// Returns `ShapeViolation` on input/output length mismatch. The
        /// CPU reference impl additionally rejects a zero-length vector
        /// (`N == 0`) and only requires `out` to be *at least* `N` bytes.
        fn relu(input: &[u8], out: &mut [u8]) -> Result<usize, ShapeViolation>;
        /// Apply Q1.7 piecewise-linear sigmoid elementwise.
        ///
        /// On success returns `Ok(n)`, presumably the number of bytes
        /// written to `out` (the CPU reference impl returns its vector
        /// length `N`).
        ///
        /// # Errors
        ///
        /// Returns `ShapeViolation` on input/output length mismatch. The
        /// CPU reference impl additionally rejects a zero-length vector
        /// (`N == 0`) and only requires `out` to be *at least* `N` bytes.
        fn sigmoid_q(input: &[u8], out: &mut [u8]) -> Result<usize, ShapeViolation>;
    }
}
37
38fn arity_violation(constraint: &'static str) -> ShapeViolation {
39    ShapeViolation {
40        shape_iri: "https://uor.foundation/axis/ActivationAxisShape",
41        constraint_iri: constraint,
42        property_iri: "https://uor.foundation/axis/inputBytes",
43        expected_range: "https://uor.foundation/axis/ActivationInputArity",
44        min_count: 0,
45        max_count: 0,
46        kind: uor_foundation::ViolationKind::ValueCheck,
47    }
48}
49
50fn check_lens(input: &[u8], out: &[u8], n: usize) -> Result<(), ShapeViolation> {
51    if input.len() != n {
52        return Err(arity_violation(
53            "https://uor.foundation/axis/ActivationAxisShape/inputByteLength",
54        ));
55    }
56    if out.len() < n {
57        return Err(arity_violation(
58            "https://uor.foundation/axis/ActivationAxisShape/outputByteLength",
59        ));
60    }
61    Ok(())
62}
63
/// CPU reference activation kernels for an `N`-element `i8` vector stored as
/// raw bytes.
///
/// Each element is transformed on its own by either kernel (ReLU or the
/// Q1.7 sigmoid), with no cross-element state. Determinism therefore holds
/// element by element and, by composition, for the whole vector (ADR-030).
///
/// # `HostBounds` discipline
///
/// The axis places no bound on `N`. [Wiki ADR-060][09] dropped the
/// foundation's `AXIS_OUTPUT_BYTES_MAX` cap. A length-`N` kernel now emits
/// into the source-polymorphic `TermValue` carrier. That carrier's widths
/// are computed by foundation `const fn`s from the application's
/// [`HostBounds`][uor_foundation::HostBounds] structural-count primitives,
/// never from a hard-coded byte-width literal. The application chooses
/// concrete `N` values (16, 32, 64, 128, 256, …); this crate sets no
/// ceiling.
///
/// [09]: https://github.com/UOR-Foundation/UOR-Framework/wiki/09-Architecture-Decisions
#[derive(Debug, Clone, Copy, Default)]
pub struct CpuI8VectorActivation<const N: usize>;
91
92impl<const N: usize> ActivationAxis for CpuI8VectorActivation<N> {
93    const AXIS_ADDRESS: &'static str = "https://uor.foundation/axis/ActivationAxis/CpuI8Vector";
94    const MAX_OUTPUT_BYTES: usize = N;
95
96    fn relu(input: &[u8], out: &mut [u8]) -> Result<usize, ShapeViolation> {
97        // Structural well-formedness only — a zero-length vector is
98        // not a vector. Per ADR-060 there is no byte-width cap; the
99        // output flows through the source-polymorphic `TermValue`
100        // carrier sized from the application's `HostBounds` primitives.
101        if N == 0 {
102            return Err(arity_violation(
103                "https://uor.foundation/axis/ActivationAxisShape/nNonZero",
104            ));
105        }
106        check_lens(input, out, N)?;
107        for i in 0..N {
108            #[allow(clippy::cast_possible_wrap)]
109            let v = input[i] as i8;
110            out[i] = if v > 0 { input[i] } else { 0 };
111        }
112        Ok(N)
113    }
114
115    fn sigmoid_q(input: &[u8], out: &mut [u8]) -> Result<usize, ShapeViolation> {
116        if N == 0 {
117            return Err(arity_violation(
118                "https://uor.foundation/axis/ActivationAxisShape/nNonZero",
119            ));
120        }
121        check_lens(input, out, N)?;
122        for i in 0..N {
123            #[allow(clippy::cast_possible_wrap)]
124            let x = input[i] as i8;
125            let y: i8 = if x <= -64 {
126                0
127            } else if x >= 64 {
128                127
129            } else {
130                #[allow(clippy::cast_possible_truncation)]
131                {
132                    64i8 + (x / 2)
133                }
134            };
135            #[allow(clippy::cast_sign_loss)]
136            {
137                out[i] = y as u8;
138            }
139        }
140        Ok(N)
141    }
142}
143
// ADR-052 generic-form companion for `CpuI8VectorActivation<N>`.
//
// NOTE(review): `axis_extension_impl_for_activation_axis!` is not defined in
// this file; presumably it is generated by the `axis!` invocation that
// declares `ActivationAxis`. Judging by its name, it supplies the
// `AxisExtension` supertrait impl that `ActivationAxis` requires. The
// `@generic` form appears to take the self type followed by the impl's
// generic parameter list (`[const N: usize]`), so one invocation covers
// every `N`. Confirm against the macro definition.
axis_extension_impl_for_activation_axis!(@generic CpuI8VectorActivation<N>, [const N: usize]);