zenu_matrix/operation/relu.rs

1#[cfg(feature = "nvidia")]
2use std::any::TypeId;
3
4use crate::{
5    device::{cpu::Cpu, Device},
6    dim::{default_stride, DimTrait},
7    index::Index0D,
8    matrix::{Matrix, Ref, Repr},
9    num::Num,
10};
11
12pub trait ReluOps {
13    fn relu<T: Num>(
14        input: *const T,
15        output: *mut T,
16        alpha: T,
17        size: usize,
18        input_stride: usize,
19        output_stride: usize,
20    );
21    fn relu_backward_mask<T: Num>(
22        input: *const T,
23        mask: *mut T,
24        alpha: T,
25        size: usize,
26        input_stride: usize,
27        mask_stride: usize,
28    );
29}
30
31impl ReluOps for Cpu {
32    #[expect(clippy::not_unsafe_ptr_arg_deref)]
33    fn relu<T: Num>(
34        input: *const T,
35        output: *mut T,
36        alpha: T,
37        size: usize,
38        input_stride: usize,
39        output_stride: usize,
40    ) {
41        unsafe {
42            if input_stride == 1 && output_stride == 1 {
43                for i in 0..size {
44                    *output.add(i) = if *input.add(i) > T::zero() {
45                        *input.add(i)
46                    } else {
47                        alpha * *input.add(i)
48                    };
49                }
50            } else {
51                for i in 0..size {
52                    *output.add(i * output_stride) = if *input.add(i * input_stride) > T::zero() {
53                        *input.add(i * input_stride)
54                    } else {
55                        alpha * *input.add(i * input_stride)
56                    };
57                }
58            }
59        }
60    }
61
62    #[expect(clippy::not_unsafe_ptr_arg_deref)]
63    fn relu_backward_mask<T: Num>(
64        input: *const T,
65        mask: *mut T,
66        alpha: T,
67        size: usize,
68        input_stride: usize,
69        mask_stride: usize,
70    ) {
71        unsafe {
72            if input_stride == 1 && mask_stride == 1 {
73                for i in 0..size {
74                    *mask.add(i) = if *input.add(i) > T::zero() {
75                        T::one()
76                    } else {
77                        alpha * T::minus_one()
78                    };
79                }
80            } else {
81                for i in 0..size {
82                    *mask.add(i * mask_stride) = if *input.add(i * input_stride) > T::zero() {
83                        T::one()
84                    } else {
85                        alpha * T::minus_one()
86                    };
87                }
88            }
89        }
90    }
91}
92
93#[cfg(feature = "nvidia")]
94use crate::device::nvidia::Nvidia;
95
96#[cfg(feature = "nvidia")]
97use zenu_cuda::kernel::activation::{relu, relu_backward_mask};
98
#[cfg(feature = "nvidia")]
impl ReluOps for Nvidia {
    /// Forward leaky ReLU on device memory, dispatched to the CUDA kernel.
    ///
    /// Only `f32` and `f64` element types are supported. Any other `T` panics with
    /// "Unsupported data type".
    fn relu<T: Num>(
        input: *const T,
        output: *mut T,
        alpha: T,
        size: usize,
        input_stride: usize,
        output_stride: usize,
    ) {
        let elem = TypeId::of::<T>();

        if elem == TypeId::of::<f32>() {
            // `T` is `f32` here, so the casts only change the static pointer type.
            let src: *mut f32 = input.cast_mut().cast();
            let dst: *mut f32 = output.cast();
            relu(src, dst, alpha.to_f32().unwrap(), size, input_stride, output_stride);
            return;
        }

        if elem == TypeId::of::<f64>() {
            let src: *mut f64 = input.cast_mut().cast();
            let dst: *mut f64 = output.cast();
            relu(src, dst, alpha.to_f64().unwrap(), size, input_stride, output_stride);
            return;
        }

        panic!("Unsupported data type");
    }

    /// Gradient mask of the leaky ReLU on device memory, dispatched to the CUDA kernel.
    ///
    /// Only `f32` and `f64` element types are supported. Any other `T` panics with
    /// "Unsupported data type".
    fn relu_backward_mask<T: Num>(
        input: *const T,
        mask: *mut T,
        alpha: T,
        size: usize,
        input_stride: usize,
        mask_stride: usize,
    ) {
        let elem = TypeId::of::<T>();

        if elem == TypeId::of::<f32>() {
            let src: *mut f32 = input.cast_mut().cast();
            let dst: *mut f32 = mask.cast();
            relu_backward_mask(src, dst, alpha.to_f32().unwrap(), size, input_stride, mask_stride);
            return;
        }

        if elem == TypeId::of::<f64>() {
            let src: *mut f64 = input.cast_mut().cast();
            let dst: *mut f64 = mask.cast();
            relu_backward_mask(src, dst, alpha.to_f64().unwrap(), size, input_stride, mask_stride);
            return;
        }

        panic!("Unsupported data type");
    }
}
168
169impl<T: Num, S: DimTrait, D: Device> Matrix<Ref<&mut T>, S, D> {
170    #[expect(clippy::missing_panics_doc)]
171    pub fn relu<R: Repr<Item = T>, SO: DimTrait>(&self, other: &Matrix<R, SO, D>, alpha: T) {
172        assert!(
173            self.shape().slice() == other.shape().slice(),
174            "shape mismatch"
175        );
176
177        let len = self.shape().len();
178        let is_self_default_stride = self.shape_stride().is_default_stride();
179        let is_other_default_stride = other.shape_stride().is_default_stride();
180
181        if len == 0 {
182            D::relu(other.as_ptr(), self.as_mut_ptr(), alpha, 1, 1, 1);
183        } else if is_self_default_stride && is_other_default_stride {
184            let num_elm = self.shape().num_elm();
185            let stride_self = self.stride()[len - 1];
186            let stride_other = other.stride()[len - 1];
187            let self_ptr = self.as_mut_ptr();
188            let other_ptr = other.as_ptr();
189            D::relu(
190                other_ptr,
191                self_ptr,
192                alpha,
193                num_elm,
194                stride_other,
195                stride_self,
196            );
197        } else {
198            for i in 0..self.shape()[0] {
199                self.index_axis_mut_dyn(Index0D::new(i))
200                    .relu(&other.index_axis_dyn(Index0D::new(i)), alpha);
201            }
202        }
203    }
204
205    #[expect(clippy::missing_panics_doc)]
206    pub fn relu_backward_mask<R: Repr<Item = T>, SO: DimTrait>(
207        &self,
208        other: &Matrix<R, SO, D>,
209        alpha: T,
210    ) {
211        assert!(
212            self.shape().slice() == other.shape().slice(),
213            "shape mismatch"
214        );
215
216        let len = self.shape().len();
217        let is_self_default_stride = self.stride() == default_stride(self.shape());
218        let is_other_default_stride = other.stride() == default_stride(other.shape());
219
220        if len == 0 {
221            D::relu_backward_mask(other.as_ptr(), self.as_mut_ptr(), alpha, 1, 1, 1);
222        } else if is_self_default_stride && is_other_default_stride {
223            let num_elm = self.shape().num_elm();
224            let stride_self = self.stride()[len - 1];
225            let stride_other = other.stride()[len - 1];
226            let self_ptr = self.as_mut_ptr();
227            let other_ptr = other.as_ptr();
228            D::relu_backward_mask(
229                other_ptr,
230                self_ptr,
231                alpha,
232                num_elm,
233                stride_other,
234                stride_self,
235            );
236        } else {
237            for i in 0..self.shape()[0] {
238                self.index_axis_mut_dyn(Index0D::new(i))
239                    .relu_backward_mask(&other.index_axis_dyn(Index0D::new(i)), alpha);
240            }
241        }
242    }
243}
244
#[cfg(test)]
mod relu {
    use zenu_test::{assert_mat_eq_epsilon, run_mat_test};

    use crate::{
        device::Device,
        dim::DimDyn,
        matrix::{Matrix, Owned},
    };

    fn relu<D: Device>() {
        let x = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1.0, -1.0, 0.0, 2.0], [2, 2]);
        let mut y = Matrix::<Owned<f32>, DimDyn, D>::zeros([2, 2]);
        y.to_ref_mut().relu(&x.to_ref(), 0.0);
        let ans = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1.0, 0.0, 0.0, 2.0], [2, 2]);
        assert_mat_eq_epsilon!(y.to_ref(), ans.to_ref(), 1.0e-6);
    }
    run_mat_test!(relu, relu_cpu, relu_nvidia);

    // Non-zero `alpha`: negative inputs are scaled by `alpha` rather than clamped to zero.
    fn leaky_relu<D: Device>() {
        let x = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1.0, -1.0, 0.0, -4.0], [2, 2]);
        let mut y = Matrix::<Owned<f32>, DimDyn, D>::zeros([2, 2]);
        y.to_ref_mut().relu(&x.to_ref(), 0.5);
        let ans = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1.0, -0.5, 0.0, -2.0], [2, 2]);
        assert_mat_eq_epsilon!(y.to_ref(), ans.to_ref(), 1.0e-6);
    }
    run_mat_test!(leaky_relu, leaky_relu_cpu, leaky_relu_nvidia);

    // 1-D input covers the single-axis dispatch path.
    fn relu_1d<D: Device>() {
        let x = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![3.0, -2.0, 0.5, -0.5], [4]);
        let mut y = Matrix::<Owned<f32>, DimDyn, D>::zeros([4]);
        y.to_ref_mut().relu(&x.to_ref(), 0.0);
        let ans = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![3.0, 0.0, 0.5, 0.0], [4]);
        assert_mat_eq_epsilon!(y.to_ref(), ans.to_ref(), 1.0e-6);
    }
    run_mat_test!(relu_1d, relu_1d_cpu, relu_1d_nvidia);

    fn relu_backward_mask<D: Device>() {
        let x = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1.0, -1.0, 0.0, 2.0], [2, 2]);
        let mut y = Matrix::<Owned<f32>, DimDyn, D>::zeros([2, 2]);
        y.to_ref_mut().relu_backward_mask(&x.to_ref(), 0.0);
        let ans = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1.0, 0.0, 0.0, 1.0], [2, 2]);
        assert_mat_eq_epsilon!(y.to_ref(), ans.to_ref(), 1.0e-6);
    }
    run_mat_test!(
        relu_backward_mask,
        relu_backward_mask_cpu,
        relu_backward_mask_nvidia
    );

    fn relu_0d<D: Device>() {
        let x = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1.0], []);
        let mut y = Matrix::<Owned<f32>, DimDyn, D>::zeros([]);
        y.to_ref_mut().relu(&x.to_ref(), 0.0);
        let ans = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1.0], []);
        assert_mat_eq_epsilon!(y.to_ref(), ans.to_ref(), 1.0e-6);
    }
    run_mat_test!(relu_0d, relu_0d_cpu, relu_0d_nvidia);

    fn relu_backward_mask_0d<D: Device>() {
        let x = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1.0], []);
        let mut y = Matrix::<Owned<f32>, DimDyn, D>::zeros([]);
        y.to_ref_mut().relu_backward_mask(&x.to_ref(), 0.0);
        let ans = Matrix::<Owned<f32>, DimDyn, D>::from_vec(vec![1.0], []);
        assert_mat_eq_epsilon!(y.to_ref(), ans.to_ref(), 1.0e-6);
    }
    run_mat_test!(
        relu_backward_mask_0d,
        relu_backward_mask_0d_cpu,
        relu_backward_mask_0d_nvidia
    );
}