zenu_matrix/operation/
clip.rs

1use crate::{
2    device::{cpu::Cpu, Device, DeviceBase},
3    dim::{DimDyn, DimTrait},
4    index::Index0D,
5    matrix::{Matrix, Owned, Ref, Repr},
6    num::Num,
7};
8
9#[cfg(feature = "nvidia")]
10use crate::device::nvidia::Nvidia;
11
12#[cfg(feature = "nvidia")]
13use zenu_cuda::kernel::{clip, clip_assign, clip_backward, clip_backward_assign};
14
15pub trait ClipOps {
16    fn clip<T: Num>(
17        input: *const T,
18        output: *mut T,
19        size: usize,
20        stride_in: usize,
21        stride_out: usize,
22        min: T,
23        max: T,
24    );
25    fn clip_assign<T: Num>(input: *mut T, size: usize, stride: usize, min: T, max: T);
26    fn clip_backward<T: Num>(
27        input: *const T,
28        mask: *mut T,
29        max: T,
30        min: T,
31        size: usize,
32        stride_in: usize,
33        stride_out: usize,
34    );
35    fn clip_backward_assign<T: Num>(mask: *mut T, max: T, min: T, size: usize, stride: usize);
36}
37
38impl ClipOps for Cpu {
39    #[expect(clippy::not_unsafe_ptr_arg_deref)]
40    fn clip<T: Num>(
41        input: *const T,
42        output: *mut T,
43        size: usize,
44        stride_in: usize,
45        stride_out: usize,
46        min: T,
47        max: T,
48    ) {
49        let input = unsafe { std::slice::from_raw_parts(input, size * stride_in) };
50        let output = unsafe { std::slice::from_raw_parts_mut(output, size * stride_out) };
51        for i in 0..size {
52            let mut x = input[i * stride_in];
53            if x < min {
54                x = min;
55            } else if x > max {
56                x = max;
57            }
58            output[i * stride_out] = x;
59        }
60    }
61
62    #[expect(clippy::not_unsafe_ptr_arg_deref)]
63    fn clip_assign<T: Num>(input: *mut T, size: usize, stride: usize, min: T, max: T) {
64        let input = unsafe { std::slice::from_raw_parts_mut(input, size * stride) };
65        for i in 0..size {
66            let mut x = input[i * stride];
67            if x < min {
68                x = min;
69            } else if x > max {
70                x = max;
71            }
72            input[i * stride] = x;
73        }
74    }
75
76    #[expect(clippy::not_unsafe_ptr_arg_deref)]
77    fn clip_backward<T: Num>(
78        input: *const T,
79        mask: *mut T,
80        max: T,
81        min: T,
82        size: usize,
83        stride_in: usize,
84        stride_out: usize,
85    ) {
86        let input = unsafe { std::slice::from_raw_parts(input, size * stride_in) };
87        let mask = unsafe { std::slice::from_raw_parts_mut(mask, size * stride_out) };
88        for i in 0..size {
89            let x = input[i * stride_in];
90            if x < min || x > max {
91                mask[i * stride_in] = T::zero();
92            } else {
93                mask[i * stride_in] = T::one();
94            }
95        }
96    }
97
98    #[expect(clippy::not_unsafe_ptr_arg_deref)]
99    fn clip_backward_assign<T: Num>(mask: *mut T, max: T, min: T, size: usize, stride: usize) {
100        let mask = unsafe { std::slice::from_raw_parts_mut(mask, size * stride) };
101        for i in 0..size {
102            let x = mask[i * stride];
103            if x < min || x > max {
104                mask[i * stride] = T::zero();
105            } else {
106                mask[i * stride] = T::one();
107            }
108        }
109    }
110}
111
// CUDA implementation of `ClipOps`: every method forwards its arguments
// unchanged to the free function of the same name imported from
// `zenu_cuda::kernel`. Inside this impl the bare names `clip`, `clip_assign`,
// ... resolve to those imported functions, not to the trait methods (which
// would have to be spelled `Self::clip`), so the calls are not recursive.
#[cfg(feature = "nvidia")]
impl ClipOps for Nvidia {
    // Forwards to `zenu_cuda::kernel::clip`.
    // NOTE(review): the pointers are presumably device pointers — confirm with callers.
    fn clip<T: Num>(
        input: *const T,
        output: *mut T,
        size: usize,
        stride_in: usize,
        stride_out: usize,
        min: T,
        max: T,
    ) {
        clip(input, output, size, stride_in, stride_out, min, max);
    }

    // Forwards to `zenu_cuda::kernel::clip_assign`.
    fn clip_assign<T: Num>(input: *mut T, size: usize, stride: usize, min: T, max: T) {
        clip_assign(input, size, stride, min, max);
    }

    // Forwards to `zenu_cuda::kernel::clip_backward`, keeping the
    // `max`-before-`min` argument order of the trait method.
    fn clip_backward<T: Num>(
        input: *const T,
        mask: *mut T,
        max: T,
        min: T,
        size: usize,
        stride_in: usize,
        stride_out: usize,
    ) {
        clip_backward(
            // Constness is cast away for the kernel call.
            // NOTE(review): presumably the kernel binding takes `*mut T` but
            // only reads through it — confirm in zenu_cuda.
            input.cast_mut(),
            mask,
            max,
            min,
            size,
            stride_in,
            stride_out,
        );
    }

    // Forwards to `zenu_cuda::kernel::clip_backward_assign`.
    fn clip_backward_assign<T: Num>(mask: *mut T, max: T, min: T, size: usize, stride: usize) {
        clip_backward_assign(mask, max, min, size, stride);
    }
}
154
155fn clip_1d<T: Num, R: Repr<Item = T>, SI: DimTrait, SO: DimTrait, D: DeviceBase + ClipOps>(
156    input: &Matrix<R, SI, D>,
157    output: &Matrix<Ref<&mut T>, SO, D>,
158    min: T,
159    max: T,
160) {
161    let size = input.shape()[0];
162    let stride_in = input.stride()[0];
163    let stride_out = output.stride()[0];
164    let input_ptr = input.as_ptr();
165    let output_ptr = output.as_mut_ptr();
166    D::clip(input_ptr, output_ptr, size, stride_in, stride_out, min, max);
167}
168
169fn clip_assign_1d<T: Num, S: DimTrait, D: DeviceBase + ClipOps>(
170    input: &Matrix<Ref<&mut T>, S, D>,
171    min: T,
172    max: T,
173) {
174    let size = input.shape()[0];
175    let stride = input.stride()[0];
176    let input_ptr = input.as_mut_ptr();
177    D::clip_assign(input_ptr, size, stride, min, max);
178}
179
180fn clip_backward_1d<T: Num, S: DimTrait, D: Device>(
181    input: &Matrix<Ref<&T>, S, D>,
182    mask: &Matrix<Ref<&mut T>, S, D>,
183    min: T,
184    max: T,
185) {
186    let size = input.shape()[0];
187    let stride_in = input.stride()[0];
188    let stride_out = mask.stride()[0];
189    let input_ptr = input.as_ptr();
190    let mask_ptr = mask.as_mut_ptr();
191    D::clip_backward(input_ptr, mask_ptr, max, min, size, stride_in, stride_out);
192}
193
194fn clip_backward_assign_1d<T: Num, S: DimTrait, D: Device>(
195    mask: &Matrix<Ref<&mut T>, S, D>,
196    min: T,
197    max: T,
198) {
199    let size = mask.shape()[0];
200    let stride = mask.stride()[0];
201    let mask_ptr = mask.as_mut_ptr();
202    D::clip_backward_assign(mask_ptr, max, min, size, stride);
203}
204
205fn clip_inner<T: Num, D: DeviceBase + ClipOps>(
206    input: &Matrix<Ref<&T>, DimDyn, D>,
207    output: &Matrix<Ref<&mut T>, DimDyn, D>,
208    min: T,
209    max: T,
210) {
211    if input.shape().len() == 1 {
212        clip_1d(input, output, min, max);
213    } else if input.shape().is_empty() {
214        unimplemented!();
215    } else {
216        for i in 0..(input.shape()[0]) {
217            clip_inner(
218                &input.index_axis_dyn(Index0D::new(i)),
219                &output.index_axis_mut_dyn(Index0D::new(i)),
220                min,
221                max,
222            );
223        }
224    }
225}
226
227fn clip_assign_inner<T: Num, D: DeviceBase + ClipOps>(
228    input: &Matrix<Ref<&mut T>, DimDyn, D>,
229    min: T,
230    max: T,
231) {
232    if input.shape().len() == 1 {
233        clip_assign_1d(input, min, max);
234    } else if input.shape().is_empty() {
235        unimplemented!();
236    } else {
237        for i in 0..(input.shape()[0]) {
238            clip_assign_inner(&input.index_axis_mut(Index0D::new(i)), min, max);
239        }
240    }
241}
242
243fn clip_backward_inner<T: Num, D: Device>(
244    input: &Matrix<Ref<&T>, DimDyn, D>,
245    mask: &Matrix<Ref<&mut T>, DimDyn, D>,
246    min: T,
247    max: T,
248) {
249    if input.shape().len() == 1 {
250        clip_backward_1d(input, mask, min, max);
251    } else if input.shape().is_empty() {
252        unimplemented!();
253    } else {
254        for i in 0..(input.shape()[0]) {
255            clip_backward_inner(
256                &input.index_axis_dyn(Index0D::new(i)),
257                &mask.index_axis_mut_dyn(Index0D::new(i)),
258                min,
259                max,
260            );
261        }
262    }
263}
264
265fn clip_backward_assign_inner<T: Num, D: Device>(
266    mask: &Matrix<Ref<&mut T>, DimDyn, D>,
267    min: T,
268    max: T,
269) {
270    if mask.shape().len() == 1 {
271        clip_backward_assign_1d(mask, min, max);
272    } else if mask.shape().is_empty() {
273        unimplemented!();
274    } else {
275        for i in 0..(mask.shape()[0]) {
276            clip_backward_assign_inner(&mask.index_axis_mut(Index0D::new(i)), min, max);
277        }
278    }
279}
280
281impl<R: Repr, S: DimTrait, D: Device> Matrix<R, S, D> {
282    pub fn clip(&self, min: R::Item, max: R::Item) -> Matrix<Owned<R::Item>, S, D> {
283        let mut output = Matrix::<_, S, D>::alloc_like(self);
284        let s_v = self.to_ref().into_dyn_dim();
285
286        clip_inner(&s_v, &output.to_ref_mut().into_dyn_dim(), min, max);
287
288        output
289    }
290
291    pub fn clip_backward_mask(&self, min: R::Item, max: R::Item) -> Matrix<Owned<R::Item>, S, D> {
292        let mut output = Matrix::<Owned<R::Item>, S, D>::alloc_like(self);
293        let s_v = self.to_ref().into_dyn_dim();
294        {
295            let mask_v = output.to_ref_mut().into_dyn_dim();
296            clip_backward_inner(&s_v, &mask_v, min, max);
297        }
298        output
299    }
300}
301
302impl<T: Num, D: Device> Matrix<Ref<&mut T>, DimDyn, D> {
303    pub fn clip_assign(&self, min: T, max: T) {
304        clip_assign_inner(self, min, max);
305    }
306
307    pub fn clip_backward_assign_mask(&self, min: T, max: T) {
308        clip_backward_assign_inner(self, min, max);
309    }
310}
311
#[cfg(test)]
mod clip {
    #![expect(clippy::float_cmp)]

    use crate::{
        device::Device,
        dim::DimDyn,
        matrix::{Matrix, Owned},
    };

    /// Clips the vector `[1, 2, 3, 4]` to `[2, 3]` on device `D`, first
    /// out-of-place and then in-place, and checks both against `[2, 2, 3, 3]`.
    fn clip_1d<D: Device>() {
        let mut source: Matrix<Owned<f32>, DimDyn, D> =
            Matrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], [4]);
        let expected: Matrix<Owned<f32>, DimDyn, D> =
            Matrix::from_vec(vec![2.0, 2.0, 3.0, 3.0], [4]);

        // Out-of-place: the absolute sum of the difference must be exactly zero.
        let clipped = source.clip(2.0, 3.0);
        assert_eq!((clipped - expected.to_ref()).asum(), 0.0);

        // In-place on a mutable view of the same data.
        source.to_ref_mut().clip_assign(2.0, 3.0);
        assert_eq!((source - expected.to_ref()).asum(), 0.0);
    }

    #[test]
    fn clip_1d_cpu() {
        clip_1d::<crate::device::cpu::Cpu>();
    }

    #[cfg(feature = "nvidia")]
    #[test]
    fn clip_1d_nvidia() {
        clip_1d::<crate::device::nvidia::Nvidia>();
    }

    /// Same check as `clip_1d`, but on a 2x2 matrix so the recursive
    /// multi-axis path is exercised.
    fn clip_2d<D: Device>() {
        let mut source: Matrix<Owned<f32>, DimDyn, D> =
            Matrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], [2, 2]);
        let expected: Matrix<Owned<f32>, DimDyn, D> =
            Matrix::from_vec(vec![2.0, 2.0, 3.0, 3.0], [2, 2]);

        let clipped = source.clip(2.0, 3.0);
        assert_eq!((clipped - expected.to_ref()).asum(), 0.0);

        source.to_ref_mut().clip_assign(2.0, 3.0);
        // A second expected matrix is built so it can be consumed by the
        // owned-minus-owned subtraction below.
        let expected_owned: Matrix<_, DimDyn, _> =
            Matrix::from_vec(vec![2.0, 2.0, 3.0, 3.0], [2, 2]);
        assert_eq!((source - expected_owned).asum(), 0.0);
    }

    #[test]
    fn clip_2d_cpu() {
        clip_2d::<crate::device::cpu::Cpu>();
    }

    #[cfg(feature = "nvidia")]
    #[test]
    fn clip_2d_nvidia() {
        clip_2d::<crate::device::nvidia::Nvidia>();
    }
}
369}