1use crate::{
2 device::{cpu::Cpu, Device, DeviceBase},
3 dim::{DimDyn, DimTrait},
4 index::Index0D,
5 matrix::{Matrix, Owned, Ref, Repr},
6 num::Num,
7};
8
9#[cfg(feature = "nvidia")]
10use crate::device::nvidia::Nvidia;
11
12#[cfg(feature = "nvidia")]
13use zenu_cuda::kernel::{clip, clip_assign, clip_backward, clip_backward_assign};
14
15pub trait ClipOps {
16 fn clip<T: Num>(
17 input: *const T,
18 output: *mut T,
19 size: usize,
20 stride_in: usize,
21 stride_out: usize,
22 min: T,
23 max: T,
24 );
25 fn clip_assign<T: Num>(input: *mut T, size: usize, stride: usize, min: T, max: T);
26 fn clip_backward<T: Num>(
27 input: *const T,
28 mask: *mut T,
29 max: T,
30 min: T,
31 size: usize,
32 stride_in: usize,
33 stride_out: usize,
34 );
35 fn clip_backward_assign<T: Num>(mask: *mut T, max: T, min: T, size: usize, stride: usize);
36}
37
38impl ClipOps for Cpu {
39 #[expect(clippy::not_unsafe_ptr_arg_deref)]
40 fn clip<T: Num>(
41 input: *const T,
42 output: *mut T,
43 size: usize,
44 stride_in: usize,
45 stride_out: usize,
46 min: T,
47 max: T,
48 ) {
49 let input = unsafe { std::slice::from_raw_parts(input, size * stride_in) };
50 let output = unsafe { std::slice::from_raw_parts_mut(output, size * stride_out) };
51 for i in 0..size {
52 let mut x = input[i * stride_in];
53 if x < min {
54 x = min;
55 } else if x > max {
56 x = max;
57 }
58 output[i * stride_out] = x;
59 }
60 }
61
62 #[expect(clippy::not_unsafe_ptr_arg_deref)]
63 fn clip_assign<T: Num>(input: *mut T, size: usize, stride: usize, min: T, max: T) {
64 let input = unsafe { std::slice::from_raw_parts_mut(input, size * stride) };
65 for i in 0..size {
66 let mut x = input[i * stride];
67 if x < min {
68 x = min;
69 } else if x > max {
70 x = max;
71 }
72 input[i * stride] = x;
73 }
74 }
75
76 #[expect(clippy::not_unsafe_ptr_arg_deref)]
77 fn clip_backward<T: Num>(
78 input: *const T,
79 mask: *mut T,
80 max: T,
81 min: T,
82 size: usize,
83 stride_in: usize,
84 stride_out: usize,
85 ) {
86 let input = unsafe { std::slice::from_raw_parts(input, size * stride_in) };
87 let mask = unsafe { std::slice::from_raw_parts_mut(mask, size * stride_out) };
88 for i in 0..size {
89 let x = input[i * stride_in];
90 if x < min || x > max {
91 mask[i * stride_in] = T::zero();
92 } else {
93 mask[i * stride_in] = T::one();
94 }
95 }
96 }
97
98 #[expect(clippy::not_unsafe_ptr_arg_deref)]
99 fn clip_backward_assign<T: Num>(mask: *mut T, max: T, min: T, size: usize, stride: usize) {
100 let mask = unsafe { std::slice::from_raw_parts_mut(mask, size * stride) };
101 for i in 0..size {
102 let x = mask[i * stride];
103 if x < min || x > max {
104 mask[i * stride] = T::zero();
105 } else {
106 mask[i * stride] = T::one();
107 }
108 }
109 }
110}
111
/// CUDA implementation of the clip kernels.
///
/// Inside these methods the bare names `clip`, `clip_assign`, `clip_backward`
/// and `clip_backward_assign` resolve to the free functions imported from
/// `zenu_cuda::kernel`, not to the trait methods being defined.
#[cfg(feature = "nvidia")]
impl ClipOps for Nvidia {
    fn clip<T>(
        input: *const T,
        output: *mut T,
        size: usize,
        stride_in: usize,
        stride_out: usize,
        min: T,
        max: T,
    ) where
        T: Num,
    {
        clip(input, output, size, stride_in, stride_out, min, max);
    }

    fn clip_assign<T>(input: *mut T, size: usize, stride: usize, min: T, max: T)
    where
        T: Num,
    {
        clip_assign(input, size, stride, min, max);
    }

    fn clip_backward<T>(
        input: *const T,
        mask: *mut T,
        max: T,
        min: T,
        size: usize,
        stride_in: usize,
        stride_out: usize,
    ) where
        T: Num,
    {
        // The trait receives the input as `*const T`; it is cast to a mutable
        // pointer here before being handed to the kernel function.
        let input = input.cast_mut();
        clip_backward(input, mask, max, min, size, stride_in, stride_out);
    }

    fn clip_backward_assign<T>(mask: *mut T, max: T, min: T, size: usize, stride: usize)
    where
        T: Num,
    {
        clip_backward_assign(mask, max, min, size, stride);
    }
}
154
155fn clip_1d<T: Num, R: Repr<Item = T>, SI: DimTrait, SO: DimTrait, D: DeviceBase + ClipOps>(
156 input: &Matrix<R, SI, D>,
157 output: &Matrix<Ref<&mut T>, SO, D>,
158 min: T,
159 max: T,
160) {
161 let size = input.shape()[0];
162 let stride_in = input.stride()[0];
163 let stride_out = output.stride()[0];
164 let input_ptr = input.as_ptr();
165 let output_ptr = output.as_mut_ptr();
166 D::clip(input_ptr, output_ptr, size, stride_in, stride_out, min, max);
167}
168
169fn clip_assign_1d<T: Num, S: DimTrait, D: DeviceBase + ClipOps>(
170 input: &Matrix<Ref<&mut T>, S, D>,
171 min: T,
172 max: T,
173) {
174 let size = input.shape()[0];
175 let stride = input.stride()[0];
176 let input_ptr = input.as_mut_ptr();
177 D::clip_assign(input_ptr, size, stride, min, max);
178}
179
180fn clip_backward_1d<T: Num, S: DimTrait, D: Device>(
181 input: &Matrix<Ref<&T>, S, D>,
182 mask: &Matrix<Ref<&mut T>, S, D>,
183 min: T,
184 max: T,
185) {
186 let size = input.shape()[0];
187 let stride_in = input.stride()[0];
188 let stride_out = mask.stride()[0];
189 let input_ptr = input.as_ptr();
190 let mask_ptr = mask.as_mut_ptr();
191 D::clip_backward(input_ptr, mask_ptr, max, min, size, stride_in, stride_out);
192}
193
194fn clip_backward_assign_1d<T: Num, S: DimTrait, D: Device>(
195 mask: &Matrix<Ref<&mut T>, S, D>,
196 min: T,
197 max: T,
198) {
199 let size = mask.shape()[0];
200 let stride = mask.stride()[0];
201 let mask_ptr = mask.as_mut_ptr();
202 D::clip_backward_assign(mask_ptr, max, min, size, stride);
203}
204
205fn clip_inner<T: Num, D: DeviceBase + ClipOps>(
206 input: &Matrix<Ref<&T>, DimDyn, D>,
207 output: &Matrix<Ref<&mut T>, DimDyn, D>,
208 min: T,
209 max: T,
210) {
211 if input.shape().len() == 1 {
212 clip_1d(input, output, min, max);
213 } else if input.shape().is_empty() {
214 unimplemented!();
215 } else {
216 for i in 0..(input.shape()[0]) {
217 clip_inner(
218 &input.index_axis_dyn(Index0D::new(i)),
219 &output.index_axis_mut_dyn(Index0D::new(i)),
220 min,
221 max,
222 );
223 }
224 }
225}
226
227fn clip_assign_inner<T: Num, D: DeviceBase + ClipOps>(
228 input: &Matrix<Ref<&mut T>, DimDyn, D>,
229 min: T,
230 max: T,
231) {
232 if input.shape().len() == 1 {
233 clip_assign_1d(input, min, max);
234 } else if input.shape().is_empty() {
235 unimplemented!();
236 } else {
237 for i in 0..(input.shape()[0]) {
238 clip_assign_inner(&input.index_axis_mut(Index0D::new(i)), min, max);
239 }
240 }
241}
242
243fn clip_backward_inner<T: Num, D: Device>(
244 input: &Matrix<Ref<&T>, DimDyn, D>,
245 mask: &Matrix<Ref<&mut T>, DimDyn, D>,
246 min: T,
247 max: T,
248) {
249 if input.shape().len() == 1 {
250 clip_backward_1d(input, mask, min, max);
251 } else if input.shape().is_empty() {
252 unimplemented!();
253 } else {
254 for i in 0..(input.shape()[0]) {
255 clip_backward_inner(
256 &input.index_axis_dyn(Index0D::new(i)),
257 &mask.index_axis_mut_dyn(Index0D::new(i)),
258 min,
259 max,
260 );
261 }
262 }
263}
264
265fn clip_backward_assign_inner<T: Num, D: Device>(
266 mask: &Matrix<Ref<&mut T>, DimDyn, D>,
267 min: T,
268 max: T,
269) {
270 if mask.shape().len() == 1 {
271 clip_backward_assign_1d(mask, min, max);
272 } else if mask.shape().is_empty() {
273 unimplemented!();
274 } else {
275 for i in 0..(mask.shape()[0]) {
276 clip_backward_assign_inner(&mask.index_axis_mut(Index0D::new(i)), min, max);
277 }
278 }
279}
280
281impl<R: Repr, S: DimTrait, D: Device> Matrix<R, S, D> {
282 pub fn clip(&self, min: R::Item, max: R::Item) -> Matrix<Owned<R::Item>, S, D> {
283 let mut output = Matrix::<_, S, D>::alloc_like(self);
284 let s_v = self.to_ref().into_dyn_dim();
285
286 clip_inner(&s_v, &output.to_ref_mut().into_dyn_dim(), min, max);
287
288 output
289 }
290
291 pub fn clip_backward_mask(&self, min: R::Item, max: R::Item) -> Matrix<Owned<R::Item>, S, D> {
292 let mut output = Matrix::<Owned<R::Item>, S, D>::alloc_like(self);
293 let s_v = self.to_ref().into_dyn_dim();
294 {
295 let mask_v = output.to_ref_mut().into_dyn_dim();
296 clip_backward_inner(&s_v, &mask_v, min, max);
297 }
298 output
299 }
300}
301
302impl<T: Num, D: Device> Matrix<Ref<&mut T>, DimDyn, D> {
303 pub fn clip_assign(&self, min: T, max: T) {
304 clip_assign_inner(self, min, max);
305 }
306
307 pub fn clip_backward_assign_mask(&self, min: T, max: T) {
308 clip_backward_assign_inner(self, min, max);
309 }
310}
311
#[cfg(test)]
mod clip {
    // Results are compared through an exact `asum() == 0.0` check on purpose.
    #![expect(clippy::float_cmp)]

    use crate::{
        device::Device,
        dim::DimDyn,
        matrix::{Matrix, Owned},
    };

    /// Forward clip on a 1-D matrix, both allocating and in place.
    fn clip_1d<D: Device>() {
        let mut a: Matrix<Owned<f32>, DimDyn, D> = Matrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], [4]);
        let b = a.clip(2.0, 3.0);
        let ans: Matrix<Owned<f32>, DimDyn, D> = Matrix::from_vec(vec![2.0, 2.0, 3.0, 3.0], [4]);
        let diff = b - ans.to_ref();
        let diff_asum = diff.asum();
        assert_eq!(diff_asum, 0.0);

        a.to_ref_mut().clip_assign(2.0, 3.0);
        let diff = a - ans.to_ref();
        let diff_asum = diff.asum();
        assert_eq!(diff_asum, 0.0);
    }
    #[test]
    fn clip_1d_cpu() {
        clip_1d::<crate::device::cpu::Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn clip_1d_nvidia() {
        clip_1d::<crate::device::nvidia::Nvidia>();
    }

    /// Forward clip on a 2-D matrix, both allocating and in place.
    fn clip_2d<D: Device>() {
        let mut a: Matrix<Owned<f32>, DimDyn, D> =
            Matrix::from_vec(vec![1.0, 2.0, 3.0, 4.0], [2, 2]);
        let b = a.clip(2.0, 3.0);
        let ans: Matrix<Owned<f32>, DimDyn, D> = Matrix::from_vec(vec![2.0, 2.0, 3.0, 3.0], [2, 2]);
        let diff = b - ans.to_ref();
        let diff_asum = diff.asum();
        assert_eq!(diff_asum, 0.0);

        a.to_ref_mut().clip_assign(2.0, 3.0);
        let ans: Matrix<_, DimDyn, _> = Matrix::from_vec(vec![2.0, 2.0, 3.0, 3.0], [2, 2]);
        let diff = a - ans;
        let diff_asum = diff.asum();
        assert_eq!(diff_asum, 0.0);
    }
    #[test]
    fn clip_2d_cpu() {
        clip_2d::<crate::device::cpu::Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn clip_2d_nvidia() {
        clip_2d::<crate::device::nvidia::Nvidia>();
    }

    /// Backward mask on a 1-D matrix, both allocating and in place.
    ///
    /// The inputs deliberately avoid the exact bounds so the expectation does
    /// not depend on whether a backend treats the bounds as inclusive.
    fn clip_backward_mask_1d<D: Device>() {
        let mut a: Matrix<Owned<f32>, DimDyn, D> =
            Matrix::from_vec(vec![1.0, 2.5, 2.75, 4.0], [4]);
        let mask = a.clip_backward_mask(2.0, 3.0);
        let ans: Matrix<Owned<f32>, DimDyn, D> = Matrix::from_vec(vec![0.0, 1.0, 1.0, 0.0], [4]);
        let diff = mask - ans.to_ref();
        let diff_asum = diff.asum();
        assert_eq!(diff_asum, 0.0);

        a.to_ref_mut().clip_backward_assign_mask(2.0, 3.0);
        let diff = a - ans.to_ref();
        let diff_asum = diff.asum();
        assert_eq!(diff_asum, 0.0);
    }
    #[test]
    fn clip_backward_mask_1d_cpu() {
        clip_backward_mask_1d::<crate::device::cpu::Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn clip_backward_mask_1d_nvidia() {
        clip_backward_mask_1d::<crate::device::nvidia::Nvidia>();
    }

    /// Backward mask on a 2-D matrix, both allocating and in place.
    fn clip_backward_mask_2d<D: Device>() {
        let mut a: Matrix<Owned<f32>, DimDyn, D> =
            Matrix::from_vec(vec![1.0, 2.5, 2.75, 4.0], [2, 2]);
        let mask = a.clip_backward_mask(2.0, 3.0);
        let ans: Matrix<Owned<f32>, DimDyn, D> = Matrix::from_vec(vec![0.0, 1.0, 1.0, 0.0], [2, 2]);
        let diff = mask - ans.to_ref();
        let diff_asum = diff.asum();
        assert_eq!(diff_asum, 0.0);

        a.to_ref_mut().clip_backward_assign_mask(2.0, 3.0);
        let diff = a - ans.to_ref();
        let diff_asum = diff.asum();
        assert_eq!(diff_asum, 0.0);
    }
    #[test]
    fn clip_backward_mask_2d_cpu() {
        clip_backward_mask_2d::<crate::device::cpu::Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn clip_backward_mask_2d_nvidia() {
        clip_backward_mask_2d::<crate::device::nvidia::Nvidia>();
    }
}