1use crate::{
2 device::{cpu::Cpu, DeviceBase},
3 dim::{DimDyn, DimTrait},
4 matrix::{Matrix, Ref, Repr},
5 num::Num,
6 shape_stride::ShapeStride,
7};
8
/// Device-level strided copy primitive with BLAS `?copy`-style arguments.
///
/// `copy_raw` copies `n` elements read from `x` (advancing `incx` elements between
/// reads) into `y` (advancing `incy` elements between writes). The CPU
/// implementation forwards to `scopy`/`dcopy`; the NVIDIA implementation forwards
/// to `cublas_copy`.
///
/// NOTE(review): this is a safe fn that takes raw pointers. Nothing here checks that
/// `x`/`y` are valid for the strided ranges they describe. Callers must ensure this,
/// presumably with memory that belongs to this device. `x` and `y` should also not
/// overlap, since the CPU implementation builds a shared and a mutable slice from
/// them. Confirm whether this should become an `unsafe fn`.
pub trait CopyBlas: DeviceBase {
    fn copy_raw<T: Num>(n: usize, x: *const T, incx: usize, y: *mut T, incy: usize);
}
12
13impl CopyBlas for Cpu {
14 #[expect(clippy::similar_names)]
15 fn copy_raw<T: Num>(n: usize, x: *const T, incx: usize, y: *mut T, incy: usize) {
16 extern crate openblas_src;
17 use cblas::{dcopy, scopy};
18 if T::is_f32() {
19 let x = unsafe { std::slice::from_raw_parts(x.cast(), n * incx) };
20 let y = unsafe { std::slice::from_raw_parts_mut(y.cast(), n * incy) };
21 unsafe {
22 scopy(
23 n.try_into().unwrap(),
24 x,
25 incx.try_into().unwrap(),
26 y,
27 incy.try_into().unwrap(),
28 );
29 }
30 } else {
31 let x = unsafe { std::slice::from_raw_parts(x.cast(), n * incx) };
32 let y = unsafe { std::slice::from_raw_parts_mut(y.cast(), n * incy) };
33 unsafe {
34 dcopy(
35 n.try_into().unwrap(),
36 x,
37 incx.try_into().unwrap(),
38 y,
39 incy.try_into().unwrap(),
40 );
41 }
42 }
43 }
44}
45
46#[cfg(feature = "nvidia")]
47use crate::device::nvidia::Nvidia;
48
#[cfg(feature = "nvidia")]
impl CopyBlas for Nvidia {
    /// Delegates the strided copy to `zenu_cuda::cublas::cublas_copy`, panicking
    /// (via `unwrap`) if that call reports a failure.
    #[expect(clippy::similar_names)]
    fn copy_raw<T: Num>(n: usize, x: *const T, incx: usize, y: *mut T, incy: usize) {
        let status = zenu_cuda::cublas::cublas_copy(n, x, incx, y, incy);
        status.unwrap();
    }
}
56
57#[expect(clippy::similar_names, clippy::needless_pass_by_value)]
58pub fn copy_unchecked<T, SA, SB, RB, D>(x: Matrix<Ref<&T>, SA, D>, y: Matrix<Ref<&mut T>, SB, D>)
59where
60 T: Num,
61 SA: DimTrait,
62 SB: DimTrait,
63 D: CopyBlas,
64{
65 let n = x.shape()[0];
66 let incx = x.stride()[0];
67 let incy = y.stride()[0];
68 let x = x.as_ptr();
69 let y = y.as_mut_ptr();
70 D::copy_raw(n, x, incx, y, incy);
71}
72
73fn get_max_shape_idx_of_apply_blas(a: ShapeStride<DimDyn>, b: ShapeStride<DimDyn>) -> usize {
74 let min_len = std::cmp::min(a.shape().len(), b.shape().len());
75 let a_len = a.shape().len();
76 let b_len = b.shape().len();
77
78 match min_len {
79 0 => 0,
80 1 => 1,
81 2 => {
82 let a_stride = a.stride();
83 let b_stride = b.stride();
84 let a_shape = a.shape();
85 let b_shape = b.shape();
86 let a_stride_part: DimDyn = a_stride.slice()[a_len - 2..].into();
87 let b_stride_part: DimDyn = b_stride.slice()[b_len - 2..].into();
88 let a_shape_part: DimDyn = a_shape.slice()[a_len - 2..].into();
89 let b_shape_part: DimDyn = b_shape.slice()[b_len - 2..].into();
90 let a_part = ShapeStride::new(a_shape_part, a_stride_part);
91 let b_part = ShapeStride::new(b_shape_part, b_stride_part);
92 if !(a_part.is_transposed() || b_part.is_transposed())
93 && a_part.is_contiguous()
94 && b_part.is_contiguous()
95 {
96 2
97 } else {
98 1
99 }
100 }
101 _ => {
102 let mut idx = 1;
103 for i in 2..=min_len {
104 let a_shape_part: DimDyn = a.shape().slice()[a_len - i..].into();
105 let b_shape_part: DimDyn = b.shape().slice()[b_len - i..].into();
106 let a_stride_part: DimDyn = a.stride().slice()[a_len - i..].into();
107 let b_stride_part: DimDyn = b.stride().slice()[b_len - i..].into();
108 let a_part = ShapeStride::new(a_shape_part, a_stride_part);
109 let b_part = ShapeStride::new(b_shape_part, b_stride_part);
110 if !a_part.is_transposed()
111 && !b_part.is_transposed()
112 && a_part.is_contiguous()
113 && b_part.is_contiguous()
114 {
115 idx = i;
116 } else {
117 break;
118 }
119 }
120 idx
121 }
122 }
123}
124
125struct PointerOffsetIter {
126 max_idx: usize,
127 to_shape_stride: ShapeStride<DimDyn>,
128 source_shape_stride: ShapeStride<DimDyn>,
129 current_idx: usize,
130 num_iter: usize,
131 to_current_idx: DimDyn,
132 source_current_idx: DimDyn,
133}
134
135fn inc_idx(idx: &mut DimDyn, shape: &DimDyn) {
136 let slice = shape.slice();
137 let len = slice.len();
138
139 for i in (0..len).rev() {
140 idx[i] += 1;
141 if idx[i] < slice[i] {
142 return;
143 }
144 idx[i] = 0;
145 }
146}
147
148fn cal_num_ber_of_iter(shape: DimDyn, max_idx: usize) -> usize {
149 shape.slice()[..shape.len() - max_idx].iter().product()
150}
151
152fn cal_offset(stride: DimDyn, idx: DimDyn) -> usize {
153 let stride_slice = stride.slice();
154 let idx_slice = idx.slice();
155 stride_slice
156 .iter()
157 .zip(idx_slice.iter())
158 .fold(0, |acc, (&s, &i)| acc + s * i)
159}
160
161impl PointerOffsetIter {
162 fn new(to_shape_stride: ShapeStride<DimDyn>, source_shape_stride: ShapeStride<DimDyn>) -> Self {
163 let max_idx = get_max_shape_idx_of_apply_blas(to_shape_stride, source_shape_stride);
164 let num_iter = cal_num_ber_of_iter(to_shape_stride.shape(), max_idx);
165 let to_len = to_shape_stride.shape().len();
166 let source_len = source_shape_stride.shape().len();
167 let to_shape_stride = ShapeStride::new(
168 DimDyn::from(&to_shape_stride.shape().slice()[..to_len - max_idx]),
169 DimDyn::from(&to_shape_stride.stride().slice()[..to_len - max_idx]),
170 );
171 let source_shape_stride = ShapeStride::new(
172 DimDyn::from(&source_shape_stride.shape().slice()[..source_len - max_idx]),
173 DimDyn::from(&source_shape_stride.stride().slice()[..source_len - max_idx]),
174 );
175 let current_len = to_shape_stride.shape().len();
176 let source_current_len = source_shape_stride.shape().len();
177 let to_current_idx = DimDyn::from(&vec![0_usize; current_len] as &[usize]);
178 let source_current_idx = DimDyn::from(&vec![0_usize; source_current_len] as &[usize]);
179 Self {
180 max_idx,
181 to_shape_stride,
182 source_shape_stride,
183 current_idx: 0,
184 num_iter,
185 to_current_idx,
186 source_current_idx,
187 }
188 }
189}
190
191impl Iterator for PointerOffsetIter {
192 type Item = (usize, usize);
193
194 fn next(&mut self) -> Option<Self::Item> {
195 if self.current_idx >= self.num_iter {
196 return None;
197 }
198 inc_idx(&mut self.to_current_idx, &self.to_shape_stride.shape());
199 let to_offset = cal_offset(self.to_shape_stride.stride(), self.to_current_idx);
200 inc_idx(
201 &mut self.source_current_idx,
202 &self.source_shape_stride.shape(),
203 );
204 let source_offset = cal_offset(self.source_shape_stride.stride(), self.source_current_idx);
205 self.current_idx += 1;
206 Some((to_offset, source_offset))
207 }
208}
209
210#[expect(clippy::needless_pass_by_value)]
211fn copy<T: Num, D: DeviceBase + CopyBlas>(
212 to: Matrix<Ref<&mut T>, DimDyn, D>,
213 source: Matrix<Ref<&T>, DimDyn, D>,
214) {
215 if to.shape().is_empty() {
216 let source_value = source.index_item([]);
217 to.index_item_assign([], source_value);
218 return;
219 }
220
221 let iter = PointerOffsetIter::new(to.shape_stride(), source.shape_stride());
222 let max_blas_apply_idx = iter.max_idx;
223
224 let to_shape = to.shape();
225 let to_stride = to.stride();
226 let source_stride = source.stride();
227
228 let to_stride_ = *to_stride.slice()[to_stride.len() - max_blas_apply_idx..]
229 .iter()
230 .min()
231 .unwrap();
232 let source_stride_ = *source_stride.slice()[source_stride.len() - max_blas_apply_idx..]
233 .iter()
234 .min()
235 .unwrap();
236
237 let to_blas_num_elm_ =
238 DimDyn::from(&to_shape.slice()[to_shape.len() - max_blas_apply_idx..]).num_elm();
239
240 let to_ptr = to.as_mut_ptr();
241 let source_ptr = source.as_ptr();
242
243 for (to_offset, source_offset) in iter {
244 let to_ptr = unsafe { to_ptr.add(to_offset) };
245 let source_ptr = unsafe { source_ptr.add(source_offset) };
246 D::copy_raw(
247 to_blas_num_elm_,
248 source_ptr,
249 source_stride_,
250 to_ptr.cast(),
251 to_stride_,
252 );
253 }
254}
255
256impl<T, SA, D> Matrix<Ref<&mut T>, SA, D>
257where
258 T: Num,
259 SA: DimTrait,
260 D: DeviceBase + CopyBlas,
261{
262 #[expect(clippy::missing_panics_doc)]
263 pub fn copy_from<R: Repr<Item = T>, SB: DimTrait>(&self, source: &Matrix<R, SB, D>) {
264 assert!(self.shape().slice() == source.shape().slice());
265 copy(self.clone().into_dyn_dim(), source.to_ref().into_dyn_dim());
266 }
267}
268
#[cfg(test)]
mod deep_copy {
    #![expect(clippy::float_cmp)]

    use super::*;
    use crate::{
        device::cpu::Cpu,
        dim::{Dim1, Dim2},
        matrix::Owned,
        slice,
    };

    #[cfg(feature = "nvidia")]
    use crate::device::nvidia::Nvidia;

    /// Contiguous 1-D copy.
    fn default_stride_1d<D: CopyBlas>() {
        let a = vec![0f32; 6];
        let b = vec![1f32, 2., 3., 4., 5., 6.];

        let mut a: Matrix<Owned<f32>, Dim1, D> = Matrix::from_vec(a, [6]);
        let b: Matrix<Owned<f32>, Dim1, D> = Matrix::from_vec(b, [6]);

        let a_view_mut = a.to_ref_mut();

        a_view_mut.into_dyn_dim().copy_from(&b.into_dyn_dim());

        assert_eq!(a.index_item([0]), 1.);
        assert_eq!(a.index_item([1]), 2.);
        assert_eq!(a.index_item([2]), 3.);
        assert_eq!(a.index_item([3]), 4.);
        assert_eq!(a.index_item([4]), 5.);
        assert_eq!(a.index_item([5]), 6.);
    }
    #[test]
    fn default_stride_1d_cpu() {
        default_stride_1d::<Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn default_stride_1d_nvidia() {
        default_stride_1d::<Nvidia>();
    }

    /// Contiguous source into a stride-2 destination; skipped elements stay untouched.
    fn sliced_1d<D: CopyBlas>() {
        let a = vec![0f32; 6];
        let v = vec![0f32, 1., 2., 3., 4., 5.];

        let mut a: Matrix<Owned<f32>, Dim1, D> = Matrix::from_vec(a, [6]);
        let v: Matrix<Owned<f32>, Dim1, D> = Matrix::from_vec(v, [6]);

        let a_sliced = a.to_ref_mut().slice_mut(slice!(..;2));
        let v_sliced = v.slice(slice!(0..3));

        a_sliced.into_dyn_dim().copy_from(&v_sliced.into_dyn_dim());
        assert_eq!(a.index_item([0]), 0.);
        assert_eq!(a.index_item([1]), 0.);
        assert_eq!(a.index_item([2]), 1.);
        assert_eq!(a.index_item([3]), 0.);
        assert_eq!(a.index_item([4]), 2.);
        assert_eq!(a.index_item([5]), 0.);
    }
    #[test]
    fn sliced_1d_cpu() {
        sliced_1d::<Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn sliced_1d_nvidia() {
        sliced_1d::<Nvidia>();
    }

    /// Stride-2 source into a contiguous destination.
    fn strided_source_1d<D: CopyBlas>() {
        let a = vec![0f32; 3];
        let v = vec![0f32, 1., 2., 3., 4., 5.];

        let mut a: Matrix<Owned<f32>, Dim1, D> = Matrix::from_vec(a, [3]);
        let v: Matrix<Owned<f32>, Dim1, D> = Matrix::from_vec(v, [6]);

        let a_view_mut = a.to_ref_mut();
        let v_sliced = v.slice(slice!(..;2));

        a_view_mut.into_dyn_dim().copy_from(&v_sliced.into_dyn_dim());
        assert_eq!(a.index_item([0]), 0.);
        assert_eq!(a.index_item([1]), 2.);
        assert_eq!(a.index_item([2]), 4.);
    }
    #[test]
    fn strided_source_1d_cpu() {
        strided_source_1d::<Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn strided_source_1d_nvidia() {
        strided_source_1d::<Nvidia>();
    }

    /// Contiguous 2-D copy (whole matrix handled by one BLAS call).
    fn default_stride_2d<D: CopyBlas>() {
        let a = vec![0f32; 6];
        let b = vec![1f32, 2., 3., 4., 5., 6.];

        let mut a: Matrix<Owned<f32>, Dim2, D> = Matrix::from_vec(a, [2, 3]);
        let b: Matrix<Owned<f32>, Dim2, D> = Matrix::from_vec(b, [2, 3]);

        let a_view_mut = a.to_ref_mut();

        a_view_mut.into_dyn_dim().copy_from(&b.into_dyn_dim());

        assert_eq!(a.index_item([0, 0]), 1.);
        assert_eq!(a.index_item([0, 1]), 2.);
        assert_eq!(a.index_item([0, 2]), 3.);
        assert_eq!(a.index_item([1, 0]), 4.);
        assert_eq!(a.index_item([1, 1]), 5.);
        assert_eq!(a.index_item([1, 2]), 6.);
    }
    #[test]
    fn default_stride_2d_cpu() {
        default_stride_2d::<Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn default_stride_2d_nvidia() {
        default_stride_2d::<Nvidia>();
    }

    /// Non-contiguous 2-D sub-block copy (one BLAS call per row).
    fn sliced_2d<D: CopyBlas>() {
        let a = vec![0f32; 12];
        let v = vec![0f32, 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.];

        let mut a: Matrix<Owned<f32>, Dim2, D> = Matrix::from_vec(a, [3, 4]);
        let v: Matrix<Owned<f32>, Dim2, D> = Matrix::from_vec(v, [3, 4]);

        let a_sliced = a.to_ref_mut().slice_mut(slice!(0..2, 0..3));
        let v_sliced = v.slice(slice!(1..3, 1..4));

        a_sliced.into_dyn_dim().copy_from(&v_sliced.into_dyn_dim());
        assert_eq!(a.index_item([0, 0]), 5.);
        assert_eq!(a.index_item([0, 1]), 6.);
        assert_eq!(a.index_item([0, 2]), 7.);
        assert_eq!(a.index_item([0, 3]), 0.);
        assert_eq!(a.index_item([1, 0]), 9.);
        assert_eq!(a.index_item([1, 1]), 10.);
        assert_eq!(a.index_item([1, 2]), 11.);
        assert_eq!(a.index_item([1, 3]), 0.);
        // The row outside the destination slice must be untouched.
        assert_eq!(a.index_item([2, 0]), 0.);
        assert_eq!(a.index_item([2, 1]), 0.);
        assert_eq!(a.index_item([2, 2]), 0.);
        assert_eq!(a.index_item([2, 3]), 0.);
    }
    #[test]
    fn sliced_2d_cpu() {
        sliced_2d::<Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn sliced_2d_nvidia() {
        sliced_2d::<Nvidia>();
    }
}