// zenu_matrix/operation/copy_from.rs

1use crate::{
2    device::{cpu::Cpu, DeviceBase},
3    dim::{DimDyn, DimTrait},
4    matrix::{Matrix, Ref, Repr},
5    num::Num,
6    shape_stride::ShapeStride,
7};
8
9pub trait CopyBlas: DeviceBase {
10    fn copy_raw<T: Num>(n: usize, x: *const T, incx: usize, y: *mut T, incy: usize);
11}
12
13impl CopyBlas for Cpu {
14    #[expect(clippy::similar_names)]
15    fn copy_raw<T: Num>(n: usize, x: *const T, incx: usize, y: *mut T, incy: usize) {
16        extern crate openblas_src;
17        use cblas::{dcopy, scopy};
18        if T::is_f32() {
19            let x = unsafe { std::slice::from_raw_parts(x.cast(), n * incx) };
20            let y = unsafe { std::slice::from_raw_parts_mut(y.cast(), n * incy) };
21            unsafe {
22                scopy(
23                    n.try_into().unwrap(),
24                    x,
25                    incx.try_into().unwrap(),
26                    y,
27                    incy.try_into().unwrap(),
28                );
29            }
30        } else {
31            let x = unsafe { std::slice::from_raw_parts(x.cast(), n * incx) };
32            let y = unsafe { std::slice::from_raw_parts_mut(y.cast(), n * incy) };
33            unsafe {
34                dcopy(
35                    n.try_into().unwrap(),
36                    x,
37                    incx.try_into().unwrap(),
38                    y,
39                    incy.try_into().unwrap(),
40                );
41            }
42        }
43    }
44}
45
46#[cfg(feature = "nvidia")]
47use crate::device::nvidia::Nvidia;
48
#[cfg(feature = "nvidia")]
impl CopyBlas for Nvidia {
    /// Forwards the strided copy to `zenu_cuda::cublas::cublas_copy` and
    /// panics if that call reports an error.
    #[expect(clippy::similar_names)]
    fn copy_raw<T: Num>(n: usize, x: *const T, incx: usize, y: *mut T, incy: usize) {
        let outcome = zenu_cuda::cublas::cublas_copy(n, x, incx, y, incy);
        outcome.unwrap();
    }
}
56
57#[expect(clippy::similar_names, clippy::needless_pass_by_value)]
58pub fn copy_unchecked<T, SA, SB, RB, D>(x: Matrix<Ref<&T>, SA, D>, y: Matrix<Ref<&mut T>, SB, D>)
59where
60    T: Num,
61    SA: DimTrait,
62    SB: DimTrait,
63    D: CopyBlas,
64{
65    let n = x.shape()[0];
66    let incx = x.stride()[0];
67    let incy = y.stride()[0];
68    let x = x.as_ptr();
69    let y = y.as_mut_ptr();
70    D::copy_raw(n, x, incx, y, incy);
71}
72
73fn get_max_shape_idx_of_apply_blas(a: ShapeStride<DimDyn>, b: ShapeStride<DimDyn>) -> usize {
74    let min_len = std::cmp::min(a.shape().len(), b.shape().len());
75    let a_len = a.shape().len();
76    let b_len = b.shape().len();
77
78    match min_len {
79        0 => 0,
80        1 => 1,
81        2 => {
82            let a_stride = a.stride();
83            let b_stride = b.stride();
84            let a_shape = a.shape();
85            let b_shape = b.shape();
86            let a_stride_part: DimDyn = a_stride.slice()[a_len - 2..].into();
87            let b_stride_part: DimDyn = b_stride.slice()[b_len - 2..].into();
88            let a_shape_part: DimDyn = a_shape.slice()[a_len - 2..].into();
89            let b_shape_part: DimDyn = b_shape.slice()[b_len - 2..].into();
90            let a_part = ShapeStride::new(a_shape_part, a_stride_part);
91            let b_part = ShapeStride::new(b_shape_part, b_stride_part);
92            if !(a_part.is_transposed() || b_part.is_transposed())
93                && a_part.is_contiguous()
94                && b_part.is_contiguous()
95            {
96                2
97            } else {
98                1
99            }
100        }
101        _ => {
102            let mut idx = 1;
103            for i in 2..=min_len {
104                let a_shape_part: DimDyn = a.shape().slice()[a_len - i..].into();
105                let b_shape_part: DimDyn = b.shape().slice()[b_len - i..].into();
106                let a_stride_part: DimDyn = a.stride().slice()[a_len - i..].into();
107                let b_stride_part: DimDyn = b.stride().slice()[b_len - i..].into();
108                let a_part = ShapeStride::new(a_shape_part, a_stride_part);
109                let b_part = ShapeStride::new(b_shape_part, b_stride_part);
110                if !a_part.is_transposed()
111                    && !b_part.is_transposed()
112                    && a_part.is_contiguous()
113                    && b_part.is_contiguous()
114                {
115                    idx = i;
116                } else {
117                    break;
118                }
119            }
120            idx
121        }
122    }
123}
124
125struct PointerOffsetIter {
126    max_idx: usize,
127    to_shape_stride: ShapeStride<DimDyn>,
128    source_shape_stride: ShapeStride<DimDyn>,
129    current_idx: usize,
130    num_iter: usize,
131    to_current_idx: DimDyn,
132    source_current_idx: DimDyn,
133}
134
135fn inc_idx(idx: &mut DimDyn, shape: &DimDyn) {
136    let slice = shape.slice();
137    let len = slice.len();
138
139    for i in (0..len).rev() {
140        idx[i] += 1;
141        if idx[i] < slice[i] {
142            return;
143        }
144        idx[i] = 0;
145    }
146}
147
148fn cal_num_ber_of_iter(shape: DimDyn, max_idx: usize) -> usize {
149    shape.slice()[..shape.len() - max_idx].iter().product()
150}
151
152fn cal_offset(stride: DimDyn, idx: DimDyn) -> usize {
153    let stride_slice = stride.slice();
154    let idx_slice = idx.slice();
155    stride_slice
156        .iter()
157        .zip(idx_slice.iter())
158        .fold(0, |acc, (&s, &i)| acc + s * i)
159}
160
161impl PointerOffsetIter {
162    fn new(to_shape_stride: ShapeStride<DimDyn>, source_shape_stride: ShapeStride<DimDyn>) -> Self {
163        let max_idx = get_max_shape_idx_of_apply_blas(to_shape_stride, source_shape_stride);
164        let num_iter = cal_num_ber_of_iter(to_shape_stride.shape(), max_idx);
165        let to_len = to_shape_stride.shape().len();
166        let source_len = source_shape_stride.shape().len();
167        let to_shape_stride = ShapeStride::new(
168            DimDyn::from(&to_shape_stride.shape().slice()[..to_len - max_idx]),
169            DimDyn::from(&to_shape_stride.stride().slice()[..to_len - max_idx]),
170        );
171        let source_shape_stride = ShapeStride::new(
172            DimDyn::from(&source_shape_stride.shape().slice()[..source_len - max_idx]),
173            DimDyn::from(&source_shape_stride.stride().slice()[..source_len - max_idx]),
174        );
175        let current_len = to_shape_stride.shape().len();
176        let source_current_len = source_shape_stride.shape().len();
177        let to_current_idx = DimDyn::from(&vec![0_usize; current_len] as &[usize]);
178        let source_current_idx = DimDyn::from(&vec![0_usize; source_current_len] as &[usize]);
179        Self {
180            max_idx,
181            to_shape_stride,
182            source_shape_stride,
183            current_idx: 0,
184            num_iter,
185            to_current_idx,
186            source_current_idx,
187        }
188    }
189}
190
191impl Iterator for PointerOffsetIter {
192    type Item = (usize, usize);
193
194    fn next(&mut self) -> Option<Self::Item> {
195        if self.current_idx >= self.num_iter {
196            return None;
197        }
198        inc_idx(&mut self.to_current_idx, &self.to_shape_stride.shape());
199        let to_offset = cal_offset(self.to_shape_stride.stride(), self.to_current_idx);
200        inc_idx(
201            &mut self.source_current_idx,
202            &self.source_shape_stride.shape(),
203        );
204        let source_offset = cal_offset(self.source_shape_stride.stride(), self.source_current_idx);
205        self.current_idx += 1;
206        Some((to_offset, source_offset))
207    }
208}
209
210#[expect(clippy::needless_pass_by_value)]
211fn copy<T: Num, D: DeviceBase + CopyBlas>(
212    to: Matrix<Ref<&mut T>, DimDyn, D>,
213    source: Matrix<Ref<&T>, DimDyn, D>,
214) {
215    if to.shape().is_empty() {
216        let source_value = source.index_item([]);
217        to.index_item_assign([], source_value);
218        return;
219    }
220
221    let iter = PointerOffsetIter::new(to.shape_stride(), source.shape_stride());
222    let max_blas_apply_idx = iter.max_idx;
223
224    let to_shape = to.shape();
225    let to_stride = to.stride();
226    let source_stride = source.stride();
227
228    let to_stride_ = *to_stride.slice()[to_stride.len() - max_blas_apply_idx..]
229        .iter()
230        .min()
231        .unwrap();
232    let source_stride_ = *source_stride.slice()[source_stride.len() - max_blas_apply_idx..]
233        .iter()
234        .min()
235        .unwrap();
236
237    let to_blas_num_elm_ =
238        DimDyn::from(&to_shape.slice()[to_shape.len() - max_blas_apply_idx..]).num_elm();
239
240    let to_ptr = to.as_mut_ptr();
241    let source_ptr = source.as_ptr();
242
243    for (to_offset, source_offset) in iter {
244        let to_ptr = unsafe { to_ptr.add(to_offset) };
245        let source_ptr = unsafe { source_ptr.add(source_offset) };
246        D::copy_raw(
247            to_blas_num_elm_,
248            source_ptr,
249            source_stride_,
250            to_ptr.cast(),
251            to_stride_,
252        );
253    }
254}
255
256impl<T, SA, D> Matrix<Ref<&mut T>, SA, D>
257where
258    T: Num,
259    SA: DimTrait,
260    D: DeviceBase + CopyBlas,
261{
262    #[expect(clippy::missing_panics_doc)]
263    pub fn copy_from<R: Repr<Item = T>, SB: DimTrait>(&self, source: &Matrix<R, SB, D>) {
264        assert!(self.shape().slice() == source.shape().slice());
265        copy(self.clone().into_dyn_dim(), source.to_ref().into_dyn_dim());
266    }
267}
268
#[cfg(test)]
mod deep_copy {
    //! Tests for `copy_from`, each written once as a device-generic function
    //! and instantiated for `Cpu` (and for `Nvidia` when that feature is on).
    #![expect(clippy::float_cmp)]

    use super::*;
    use crate::{
        device::cpu::Cpu,
        dim::{Dim1, Dim2},
        matrix::Owned,
        slice,
    };

    #[cfg(feature = "nvidia")]
    use crate::device::nvidia::Nvidia;

    /// Copies one 1-D matrix with default strides into another.
    fn default_stride_1d<D: CopyBlas>() {
        let a = vec![0f32; 6];
        let b = vec![1f32, 2., 3., 4., 5., 6.];

        let mut a: Matrix<Owned<f32>, Dim1, D> = Matrix::from_vec(a, [6]);
        let b: Matrix<Owned<f32>, Dim1, D> = Matrix::from_vec(b, [6]);

        let a_view_mut = a.to_ref_mut();

        a_view_mut.into_dyn_dim().copy_from(&b.into_dyn_dim());

        assert_eq!(a.index_item([0]), 1.);
        assert_eq!(a.index_item([1]), 2.);
        assert_eq!(a.index_item([2]), 3.);
        assert_eq!(a.index_item([3]), 4.);
        assert_eq!(a.index_item([4]), 5.);
        assert_eq!(a.index_item([5]), 6.);
    }
    #[test]
    fn default_stride_1d_cpu() {
        default_stride_1d::<Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn default_stride_1d_nvidia() {
        default_stride_1d::<Nvidia>();
    }

    /// Copies the first three elements of `v` into every second slot of `a`;
    /// the slots in between must stay zero.
    fn sliced_1d<D: CopyBlas>() {
        let a = vec![0f32; 6];
        let v = vec![0f32, 1., 2., 3., 4., 5.];

        let mut a: Matrix<Owned<f32>, Dim1, D> = Matrix::from_vec(a, [6]);
        let v: Matrix<Owned<f32>, Dim1, D> = Matrix::from_vec(v, [6]);

        let a_sliced = a.to_ref_mut().slice_mut(slice!(..;2));
        let v_sliced = v.slice(slice!(0..3));

        a_sliced.into_dyn_dim().copy_from(&v_sliced.into_dyn_dim());
        assert_eq!(a.index_item([0]), 0.);
        assert_eq!(a.index_item([1]), 0.);
        assert_eq!(a.index_item([2]), 1.);
        assert_eq!(a.index_item([3]), 0.);
        assert_eq!(a.index_item([4]), 2.);
        assert_eq!(a.index_item([5]), 0.);
    }
    #[test]
    fn sliced_1d_cpu() {
        sliced_1d::<Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn sliced_1d_nvidia() {
        sliced_1d::<Nvidia>();
    }

    /// Copies one 2x3 matrix with default strides into another.
    fn default_stride_2d<D: CopyBlas>() {
        let a = vec![0f32; 6];
        let b = vec![1f32, 2., 3., 4., 5., 6.];

        let mut a: Matrix<Owned<f32>, Dim2, D> = Matrix::from_vec(a, [2, 3]);
        let b: Matrix<Owned<f32>, Dim2, D> = Matrix::from_vec(b, [2, 3]);

        let a_view_mut = a.to_ref_mut();

        a_view_mut.into_dyn_dim().copy_from(&b.into_dyn_dim());

        assert_eq!(a.index_item([0, 0]), 1.);
        assert_eq!(a.index_item([0, 1]), 2.);
        assert_eq!(a.index_item([0, 2]), 3.);
        assert_eq!(a.index_item([1, 0]), 4.);
        assert_eq!(a.index_item([1, 1]), 5.);
        assert_eq!(a.index_item([1, 2]), 6.);
    }
    #[test]
    fn default_stride_2d_cpu() {
        default_stride_2d::<Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn default_stride_2d_nvidia() {
        default_stride_2d::<Nvidia>();
    }

    /// Copies the lower-right 2x3 window of `v` into the upper-left 2x3
    /// window of `a`; elements outside that window must stay zero.
    fn sliced_2d<D: CopyBlas>() {
        let a = vec![0f32; 12];
        let v = vec![0f32, 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.];

        let mut a: Matrix<Owned<f32>, Dim2, D> = Matrix::from_vec(a, [3, 4]);
        let v: Matrix<Owned<f32>, Dim2, D> = Matrix::from_vec(v, [3, 4]);

        let a_sliced = a.to_ref_mut().slice_mut(slice!(0..2, 0..3));
        let v_sliced = v.slice(slice!(1..3, 1..4));

        a_sliced.into_dyn_dim().copy_from(&v_sliced.into_dyn_dim());
        assert_eq!(a.index_item([0, 0]), 5.);
        assert_eq!(a.index_item([0, 1]), 6.);
        assert_eq!(a.index_item([0, 2]), 7.);
        assert_eq!(a.index_item([0, 3]), 0.);
        assert_eq!(a.index_item([1, 0]), 9.);
        assert_eq!(a.index_item([1, 1]), 10.);
        assert_eq!(a.index_item([1, 2]), 11.);
        assert_eq!(a.index_item([2, 3]), 0.);
    }
    #[test]
    fn sliced_2d_cpu() {
        sliced_2d::<Cpu>();
    }
    #[cfg(feature = "nvidia")]
    #[test]
    fn sliced_2d_nvidia() {
        sliced_2d::<Nvidia>();
    }
}