prism_tensor/
activation.rs1#![allow(missing_docs)]
5
6use uor_foundation::enforcement::ShapeViolation;
7use uor_foundation::pipeline::AxisExtension;
8use uor_foundation_sdk::axis;
9
// Declares the `ActivationAxis` trait through the SDK's `axis!` macro.
// NOTE(review): the macro presumably also generates the
// `axis_extension_impl_for_activation_axis!` helper that is invoked later in this
// file (inferred from the names only — confirm against `uor_foundation_sdk::axis`).
axis! {
    // Activation kernels that read an input byte buffer and write an output byte
    // buffer. `AxisExtension` is a required supertrait.
    pub trait ActivationAxis: AxisExtension {
        // Identifying IRI of the axis; implementors may override it (the CPU
        // implementation in this file does).
        const AXIS_ADDRESS: &'static str = "https://uor.foundation/axis/ActivationAxis";
        // Default output capacity in bytes; the CPU implementation in this file
        // overrides it with its vector width `N`.
        const MAX_OUTPUT_BYTES: usize = 16;
        // ReLU activation of `input` into `out`. The implementation in this file
        // returns the number of bytes written, or a `ShapeViolation` when the
        // buffer sizes are wrong.
        fn relu(input: &[u8], out: &mut [u8]) -> Result<usize, ShapeViolation>;
        // Sigmoid-style activation of `input` into `out`, same calling contract as
        // `relu`. NOTE(review): the `_q` suffix presumably means "quantized" —
        // confirm.
        fn sigmoid_q(input: &[u8], out: &mut [u8]) -> Result<usize, ShapeViolation>;
    }
}
35
/// Upper bound, in bytes (one byte per lane), on the vector width `N` that the
/// `CpuI8VectorActivation` kernels accept.
///
/// A width of zero, or any width above this bound, is rejected at call time with
/// the `nInRange` shape violation.
pub const MAX_ACTIVATION_LEN: usize = 256;
39
40fn arity_violation(constraint: &'static str) -> ShapeViolation {
41 ShapeViolation {
42 shape_iri: "https://uor.foundation/axis/ActivationAxisShape",
43 constraint_iri: constraint,
44 property_iri: "https://uor.foundation/axis/inputBytes",
45 expected_range: "https://uor.foundation/axis/ActivationInputArity",
46 min_count: 0,
47 max_count: 0,
48 kind: uor_foundation::ViolationKind::ValueCheck,
49 }
50}
51
52fn check_lens(input: &[u8], out: &[u8], n: usize) -> Result<(), ShapeViolation> {
53 if input.len() != n {
54 return Err(arity_violation(
55 "https://uor.foundation/axis/ActivationAxisShape/inputByteLength",
56 ));
57 }
58 if out.len() < n {
59 return Err(arity_violation(
60 "https://uor.foundation/axis/ActivationAxisShape/outputByteLength",
61 ));
62 }
63 Ok(())
64}
65
/// CPU implementation of [`ActivationAxis`] over a fixed-width vector of `N`
/// lanes. Each input byte is interpreted as a two's-complement `i8`.
///
/// The type is a zero-sized marker; all work happens in the associated functions
/// of the trait implementation. `Default` is derived because a unit struct has
/// exactly one value, so a hand-written impl adds nothing.
#[derive(Debug, Clone, Copy, Default)]
pub struct CpuI8VectorActivation<const N: usize>;
80
81impl<const N: usize> ActivationAxis for CpuI8VectorActivation<N> {
82 const AXIS_ADDRESS: &'static str = "https://uor.foundation/axis/ActivationAxis/CpuI8Vector";
83 const MAX_OUTPUT_BYTES: usize = N;
84
85 fn relu(input: &[u8], out: &mut [u8]) -> Result<usize, ShapeViolation> {
86 if N == 0 || N > MAX_ACTIVATION_LEN {
87 return Err(arity_violation(
88 "https://uor.foundation/axis/ActivationAxisShape/nInRange",
89 ));
90 }
91 check_lens(input, out, N)?;
92 for i in 0..N {
93 #[allow(clippy::cast_possible_wrap)]
94 let v = input[i] as i8;
95 out[i] = if v > 0 { input[i] } else { 0 };
96 }
97 Ok(N)
98 }
99
100 fn sigmoid_q(input: &[u8], out: &mut [u8]) -> Result<usize, ShapeViolation> {
101 if N == 0 || N > MAX_ACTIVATION_LEN {
102 return Err(arity_violation(
103 "https://uor.foundation/axis/ActivationAxisShape/nInRange",
104 ));
105 }
106 check_lens(input, out, N)?;
107 for i in 0..N {
108 #[allow(clippy::cast_possible_wrap)]
109 let x = input[i] as i8;
110 let y: i8 = if x <= -64 {
111 0
112 } else if x >= 64 {
113 127
114 } else {
115 #[allow(clippy::cast_possible_truncation)]
116 {
117 64i8 + (x / 2)
118 }
119 };
120 #[allow(clippy::cast_sign_loss)]
121 {
122 out[i] = y as u8;
123 }
124 }
125 Ok(N)
126 }
127}
128
// NOTE(review): this helper macro is presumably generated by the `axis!`
// invocation that declares `ActivationAxis` (inferred from its name — confirm).
// It appears to supply the `AxisExtension` supertrait impl for every
// `CpuI8VectorActivation<N>`. The `@generic` form receives the impl's
// const-generic parameter list in brackets.
axis_extension_impl_for_activation_axis!(@generic CpuI8VectorActivation<N>, [const N: usize]);
131
132pub type CpuI8VectorActivation16 = CpuI8VectorActivation<16>;
134pub type CpuI8VectorActivation32 = CpuI8VectorActivation<32>;
136pub type CpuI8VectorActivation64 = CpuI8VectorActivation<64>;
138pub type CpuI8VectorActivation128 = CpuI8VectorActivation<128>;
140pub type CpuI8VectorActivation256 = CpuI8VectorActivation<256>;