1#[cfg(feature = "nvidia")]
2use std::any::TypeId;
3
4use crate::{
5 device::{cpu::Cpu, Device},
6 dim::{default_stride, DimTrait},
7 index::Index0D,
8 matrix::{Matrix, Ref, Repr},
9 num::Num,
10};
11
/// Per-device kernels for leaky-ReLU forward and backward-mask computation.
///
/// All arguments are raw pointers into device-native memory; `size` is the
/// element count and each `*_stride` is the distance (in elements) between
/// consecutive values of the corresponding buffer.
pub trait ReluOps {
    /// Computes `output[i * output_stride] = x > 0 ? x : alpha * x` with
    /// `x = input[i * input_stride]`, for every `i` in `0..size`.
    fn relu<T: Num>(
        input: *const T,
        output: *mut T,
        alpha: T,
        size: usize,
        input_stride: usize,
        output_stride: usize,
    );
    /// Fills `mask` with the gradient factor per input element: `1` where
    /// `input > 0`. NOTE(review): the CPU backend writes `alpha * -1`
    /// (i.e. `-alpha`) on the negative branch — the usual leaky-ReLU
    /// derivative is `+alpha`; confirm the sign convention against the CUDA
    /// kernel before relying on nonzero `alpha`.
    fn relu_backward_mask<T: Num>(
        input: *const T,
        mask: *mut T,
        alpha: T,
        size: usize,
        input_stride: usize,
        mask_stride: usize,
    );
}
30
31impl ReluOps for Cpu {
32 #[expect(clippy::not_unsafe_ptr_arg_deref)]
33 fn relu<T: Num>(
34 input: *const T,
35 output: *mut T,
36 alpha: T,
37 size: usize,
38 input_stride: usize,
39 output_stride: usize,
40 ) {
41 unsafe {
42 if input_stride == 1 && output_stride == 1 {
43 for i in 0..size {
44 *output.add(i) = if *input.add(i) > T::zero() {
45 *input.add(i)
46 } else {
47 alpha * *input.add(i)
48 };
49 }
50 } else {
51 for i in 0..size {
52 *output.add(i * output_stride) = if *input.add(i * input_stride) > T::zero() {
53 *input.add(i * input_stride)
54 } else {
55 alpha * *input.add(i * input_stride)
56 };
57 }
58 }
59 }
60 }
61
62 #[expect(clippy::not_unsafe_ptr_arg_deref)]
63 fn relu_backward_mask<T: Num>(
64 input: *const T,
65 mask: *mut T,
66 alpha: T,
67 size: usize,
68 input_stride: usize,
69 mask_stride: usize,
70 ) {
71 unsafe {
72 if input_stride == 1 && mask_stride == 1 {
73 for i in 0..size {
74 *mask.add(i) = if *input.add(i) > T::zero() {
75 T::one()
76 } else {
77 alpha * T::minus_one()
78 };
79 }
80 } else {
81 for i in 0..size {
82 *mask.add(i * mask_stride) = if *input.add(i * input_stride) > T::zero() {
83 T::one()
84 } else {
85 alpha * T::minus_one()
86 };
87 }
88 }
89 }
90 }
91}
92
93#[cfg(feature = "nvidia")]
94use crate::device::nvidia::Nvidia;
95
96#[cfg(feature = "nvidia")]
97use zenu_cuda::kernel::activation::{relu, relu_backward_mask};
98
#[cfg(feature = "nvidia")]
impl ReluOps for Nvidia {
    /// GPU leaky ReLU: dispatches on the runtime element type and forwards to
    /// the matching `zenu_cuda` kernel. Only `f32` and `f64` are supported.
    fn relu<T: Num>(
        input: *const T,
        output: *mut T,
        alpha: T,
        size: usize,
        input_stride: usize,
        output_stride: usize,
    ) {
        let elem_ty = TypeId::of::<T>();
        if elem_ty == TypeId::of::<f32>() {
            relu(
                input.cast_mut().cast::<f32>(),
                output.cast(),
                alpha.to_f32().unwrap(),
                size,
                input_stride,
                output_stride,
            );
        } else if elem_ty == TypeId::of::<f64>() {
            relu(
                input.cast_mut().cast::<f64>(),
                output.cast(),
                alpha.to_f64().unwrap(),
                size,
                input_stride,
                output_stride,
            );
        } else {
            panic!("Unsupported data type");
        }
    }

    /// GPU leaky-ReLU gradient mask: same type dispatch as [`ReluOps::relu`],
    /// forwarding to the `zenu_cuda` backward-mask kernel.
    fn relu_backward_mask<T: Num>(
        input: *const T,
        mask: *mut T,
        alpha: T,
        size: usize,
        input_stride: usize,
        mask_stride: usize,
    ) {
        let elem_ty = TypeId::of::<T>();
        if elem_ty == TypeId::of::<f32>() {
            relu_backward_mask(
                input.cast_mut().cast::<f32>(),
                mask.cast(),
                alpha.to_f32().unwrap(),
                size,
                input_stride,
                mask_stride,
            );
        } else if elem_ty == TypeId::of::<f64>() {
            relu_backward_mask(
                input.cast_mut().cast::<f64>(),
                mask.cast(),
                alpha.to_f64().unwrap(),
                size,
                input_stride,
                mask_stride,
            );
        } else {
            panic!("Unsupported data type");
        }
    }
}
168
169impl<T: Num, S: DimTrait, D: Device> Matrix<Ref<&mut T>, S, D> {
170 #[expect(clippy::missing_panics_doc)]
171 pub fn relu<R: Repr<Item = T>, SO: DimTrait>(&self, other: &Matrix<R, SO, D>, alpha: T) {
172 assert!(
173 self.shape().slice() == other.shape().slice(),
174 "shape mismatch"
175 );
176
177 let len = self.shape().len();
178 let is_self_default_stride = self.shape_stride().is_default_stride();
179 let is_other_default_stride = other.shape_stride().is_default_stride();
180
181 if len == 0 {
182 D::relu(other.as_ptr(), self.as_mut_ptr(), alpha, 1, 1, 1);
183 } else if is_self_default_stride && is_other_default_stride {
184 let num_elm = self.shape().num_elm();
185 let stride_self = self.stride()[len - 1];
186 let stride_other = other.stride()[len - 1];
187 let self_ptr = self.as_mut_ptr();
188 let other_ptr = other.as_ptr();
189 D::relu(
190 other_ptr,
191 self_ptr,
192 alpha,
193 num_elm,
194 stride_other,
195 stride_self,
196 );
197 } else {
198 for i in 0..self.shape()[0] {
199 self.index_axis_mut_dyn(Index0D::new(i))
200 .relu(&other.index_axis_dyn(Index0D::new(i)), alpha);
201 }
202 }
203 }
204
205 #[expect(clippy::missing_panics_doc)]
206 pub fn relu_backward_mask<R: Repr<Item = T>, SO: DimTrait>(
207 &self,
208 other: &Matrix<R, SO, D>,
209 alpha: T,
210 ) {
211 assert!(
212 self.shape().slice() == other.shape().slice(),
213 "shape mismatch"
214 );
215
216 let len = self.shape().len();
217 let is_self_default_stride = self.stride() == default_stride(self.shape());
218 let is_other_default_stride = other.stride() == default_stride(other.shape());
219
220 if len == 0 {
221 D::relu_backward_mask(other.as_ptr(), self.as_mut_ptr(), alpha, 1, 1, 1);
222 } else if is_self_default_stride && is_other_default_stride {
223 let num_elm = self.shape().num_elm();
224 let stride_self = self.stride()[len - 1];
225 let stride_other = other.stride()[len - 1];
226 let self_ptr = self.as_mut_ptr();
227 let other_ptr = other.as_ptr();
228 D::relu_backward_mask(
229 other_ptr,
230 self_ptr,
231 alpha,
232 num_elm,
233 stride_other,
234 stride_self,
235 );
236 } else {
237 for i in 0..self.shape()[0] {
238 self.index_axis_mut_dyn(Index0D::new(i))
239 .relu_backward_mask(&other.index_axis_dyn(Index0D::new(i)), alpha);
240 }
241 }
242 }
243}
244
#[cfg(test)]
mod relu {
    use zenu_test::{assert_mat_eq_epsilon, run_mat_test};

    use crate::{
        device::Device,
        dim::DimDyn,
        matrix::{Matrix, Owned},
    };

    /// Forward pass on a 2x2 matrix with `alpha = 0` (plain ReLU).
    fn relu<D: Device>() {
        let input = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1.0, -1.0, 0.0, 2.0], [2, 2]);
        let mut output = Matrix::<Owned<f32>, DimDyn, D>::zeros([2, 2]);
        output.to_ref_mut().relu(&input.to_ref(), 0.0);
        let expected = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1.0, 0.0, 0.0, 2.0], [2, 2]);
        assert_mat_eq_epsilon!(output.to_ref(), expected.to_ref(), 1.0e-6);
    }
    run_mat_test!(relu, relu_cpu, relu_nvidia);

    /// Gradient mask on a 2x2 matrix: 1 where input > 0, 0 otherwise
    /// (alpha = 0, so the sign of the negative branch is not exercised).
    fn relu_backward_mask<D: Device>() {
        let input = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1.0, -1.0, 0.0, 2.0], [2, 2]);
        let mut output = Matrix::<Owned<f32>, DimDyn, D>::zeros([2, 2]);
        output.to_ref_mut().relu_backward_mask(&input.to_ref(), 0.0);
        let expected = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1.0, 0.0, 0.0, 1.0], [2, 2]);
        assert_mat_eq_epsilon!(output.to_ref(), expected.to_ref(), 1.0e-6);
    }
    run_mat_test!(
        relu_backward_mask,
        relu_backward_mask_cpu,
        relu_backward_mask_nvidia
    );

    /// Forward pass on a 0-d (scalar) matrix exercises the `len == 0` branch.
    fn relu_0d<D: Device>() {
        let input = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1.0], []);
        let mut output = Matrix::<Owned<f32>, DimDyn, D>::zeros([]);
        output.to_ref_mut().relu(&input.to_ref(), 0.0);
        let expected = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1.0], []);
        assert_mat_eq_epsilon!(output.to_ref(), expected.to_ref(), 1.0e-6);
    }
    run_mat_test!(relu_0d, relu_0d_cpu, relu_0d_nvidia);

    /// Gradient mask on a 0-d (scalar) matrix exercises the `len == 0` branch.
    fn relu_backward_mask_0d<D: Device>() {
        let input = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1.0], []);
        let mut output = Matrix::<Owned<f32>, DimDyn, D>::zeros([]);
        output.to_ref_mut().relu_backward_mask(&input.to_ref(), 0.0);
        let expected = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1.0], []);
        assert_mat_eq_epsilon!(output.to_ref(), expected.to_ref(), 1.0e-6);
    }
    run_mat_test!(
        relu_backward_mask_0d,
        relu_backward_mask_0d_cpu,
        relu_backward_mask_0d_nvidia
    );
}