// tensorgraph_math/tensor.rs

1use std::mem::MaybeUninit;
2
3use tensorgraph_sys::{ptr::Ref, View, ViewMut, device::Device};
4
5use crate::{
6    dims::{Dimension, RemoveDim},
7    storage::{IntoOwned, Storage, StorageMut},
8};
9
mod matrix;
mod vector;
pub use matrix::*;
pub use vector::*;

/// A representation of a slice: `[T]` held on device `D`.
pub type Slice<T, D> = Ref<[T], D>;

/// Gets the view repr of the provided storage: a [`Slice`] over the
/// storage's element type on the storage's device.
pub type ViewOf<S> = Slice<<S as Storage>::T, <S as Storage>::Device>;

/// A 'view' of a tensor, Like `&[T]` is to `Vec<T>`
pub type TensorView<'a, T, D, Dim> = Tensor<&'a Slice<T, D>, Dim>;

/// A 'mut view' of a tensor, Like `&mut [T]` is to `Vec<T>`
pub type TensorViewMut<'a, T, D, Dim> = Tensor<&'a mut Slice<T, D>, Dim>;

/// An uninit tensor. Contents are mutable and specified as [`MaybeUninit`].
pub type UninitTensor<'a, T, D, Dim> = TensorViewMut<'a, MaybeUninit<T>, D, Dim>;
29
30/// A multidimensional data structure not unlike [`ndarray::ArrayBase`](https://docs.rs/ndarray/0.15.4/ndarray/struct.ArrayBase.html).
/// A multidimensional data structure not unlike [`ndarray::ArrayBase`](https://docs.rs/ndarray/0.15.4/ndarray/struct.ArrayBase.html).
// `Copy`/`Clone` are derived, so they apply only when `S` and `Dim` are
// themselves `Copy`/`Clone` (e.g. views over borrowed slices are `Copy`).
#[derive(Copy, Clone)]
pub struct Tensor<S: Storage, Dim: Dimension> {
    // Extent of each axis.
    shape: Dim,
    // Element stride for each axis. `from_shape` initialises these to
    // column-major order; `reverse_axes`/`swap_axes` permute them in place.
    strides: Dim,
    // Backing storage — any `Storage` impl (owned buffer or borrowed slice).
    data: S,
}
37
38impl<S: Storage, Dim: Dimension> View for Tensor<S, Dim> {
39    type Ref<'a>
40    where
41        Self: 'a,
42    = Tensor<&'a ViewOf<S>, Dim>;
43
44    fn view(&self) -> TensorView<S::T, S::Device, Dim> {
45        Tensor {
46            shape: self.shape.clone(),
47            strides: self.strides.clone(),
48            data: self.data.as_ref(),
49        }
50    }
51}
52
53impl<S: StorageMut, Dim: Dimension> ViewMut for Tensor<S, Dim> {
54    type Mut<'a>
55    where
56        Self: 'a,
57    = Tensor<&'a mut ViewOf<S>, Dim>;
58
59    fn view_mut(&mut self) -> TensorViewMut<S::T, S::Device, Dim> {
60        Tensor {
61            shape: self.shape.clone(),
62            strides: self.strides.clone(),
63            data: self.data.as_mut(),
64        }
65    }
66}
67
68impl<S: Storage, Dim: Dimension> Tensor<S, Dim> {
69    /// Creates a new tensor using the shape and the raw data.
70    ///
71    /// # Panics
72    /// The length of the data structure must match the size of the dimensions
73    pub fn from_shape(shape: Dim, data: S) -> Self {
74        assert_eq!(data.as_ref().len(), shape.size());
75        let strides = shape.column_major_strides();
76        Self {
77            shape,
78            strides,
79            data,
80        }
81    }
82
83    /// Consumes the tensor, returning the underlying data
84    pub fn into_inner(self) -> S {
85        self.data
86    }
87
88    /// Reverses the axes of the tensor. An inplace transpose
89    pub fn reverse_axes(&mut self) {
90        self.shape.as_mut().reverse();
91        self.strides.as_mut().reverse();
92    }
93
94    pub fn swap_axes(&mut self, i: usize, j: usize) {
95        self.shape.as_mut().swap(i, j);
96        self.strides.as_mut().swap(i, j);
97    }
98
99    /// Returns a view of the tensor with the contents transposed.
100    /// This operation happens without mutating or cloning any data
101    pub fn t(&self) -> Tensor<&ViewOf<S>, Dim> {
102        let mut view = self.view();
103        view.reverse_axes();
104        view
105    }
106
107    /// Creates a new owned version of the tensor.
108    /// Will only clone the contents if needed
109    pub fn into_owned(self) -> Tensor<S::Owned, Dim>
110    where
111        S: IntoOwned,
112        S::Owned: Storage<T = S::T, Device = S::Device>,
113    {
114        Tensor {
115            shape: self.shape,
116            strides: self.strides,
117            data: self.data.into_owned(),
118        }
119    }
120
121    /// Slices the tensor over a specific axis. The resulting tensor will be a dimension smaller
122    ///
123    /// # Panics
124    /// If the axis is outside of the length of the dimensions
125    pub fn slice_axis(&self, axis: usize, n: usize) -> Tensor<&ViewOf<S>, Dim::Smaller>
126    where
127        Dim: RemoveDim,
128    {
129        assert!(axis < self.shape.as_ref().len());
130
131        let (shape, m) = self.shape.remove(axis);
132        let (strides, s) = self.strides.remove(axis);
133
134        assert!(n < m);
135
136        Tensor {
137            shape,
138            strides,
139            data: &self.data.as_ref()[s * n..],
140        }
141    }
142}
143
impl<'a, T, D: Device, Dim: Dimension> UninitTensor<'a, T, D, Dim> {
    /// Converts a mutable view over `MaybeUninit<T>` into a mutable view
    /// over `T`, keeping shape and strides unchanged.
    ///
    /// # Safety
    /// Contents must be initialised: every element of the underlying slice
    /// must have been written before this call, as for
    /// [`MaybeUninit::assume_init`]. Reading uninitialised memory through the
    /// returned view is undefined behaviour.
    pub unsafe fn assume_init(self) -> TensorViewMut<'a, T, D, Dim> {
        Tensor {
            shape: self.shape,
            strides: self.strides,
            data: self.data.assume_init_mut(),
        }
    }
}
155
impl<'a, T, D: Device, Dim: Dimension> TensorView<'a, MaybeUninit<T>, D, Dim> {
    /// Converts a shared view over `MaybeUninit<T>` into a shared view over
    /// `T`, keeping shape and strides unchanged.
    ///
    /// # Safety
    /// Contents must be initialised: every element of the underlying slice
    /// must have been written before this call, as for
    /// [`MaybeUninit::assume_init`]. Reading uninitialised memory through the
    /// returned view is undefined behaviour.
    pub unsafe fn assume_init(self) -> TensorView<'a, T, D, Dim> {
        Tensor {
            shape: self.shape,
            strides: self.strides,
            data: self.data.assume_init(),
        }
    }
}
167
#[cfg(test)]
mod tests {
    use tensorgraph_sys::{View, ViewMut};

    use crate::tensor::{gemm, Tensor};

    /// 3x2 by 2x2 matmul on host storage; all buffers are column-major.
    #[test]
    fn matmul() {
        //     0 1
        // A = 2 3
        //     4 5

        // B = 0 1
        //     2 3

        // column major (read each column first)
        let a = [0., 2., 4., 1., 3., 5.];
        let b = [0., 2., 1., 3.];
        let a = Tensor::from_shape([3, 2], a); // 3 rows x 2 cols
        let b = Tensor::from_shape([2, 2], b); // 2 rows x 2 cols

        //           2  3
        // C = AB =  6 11
        //          10 19

        let c = a.matmul(b.view());
        assert_eq!(c.into_inner().into_std(), [2., 6., 10., 3., 11., 19.]);
    }

    /// Matmul with every combination of transposed operands; `t()` must be a
    /// zero-copy stride swap that BLAS still interprets correctly.
    #[test]
    fn matmul_t() {
        // A = 1 3
        //     2 4

        // B = 5 7
        //     6 8

        let a = [1., 2., 3., 4.];
        let b = [5., 6., 7., 8.];
        let a = Tensor::from_shape([2, 2], a);
        let b = Tensor::from_shape([2, 2], b);

        // C1 = A^B^ = 19 22
        //             43 50

        let c1 = a.t().matmul(b.t());
        assert_eq!(c1.into_inner().into_std(), [19.0, 43.0, 22.0, 50.0]);

        // C2 = AB^ = 26 30
        //            38 44

        let c2 = a.matmul(b.t());
        assert_eq!(c2.into_inner().into_std(), [26.0, 38.0, 30.0, 44.0]);

        // C3 = A^B = 17 23
        //            39 53

        let c3 = a.t().matmul(b.view());
        assert_eq!(c3.into_inner().into_std(), [17.0, 39.0, 23.0, 53.0]);
    }

    /// `slice_axis` only offsets the data pointer (it does not trim the
    /// tail), so the expected inner slices below keep all trailing elements;
    /// the shape/stride pairs are what identify the selected row or column.
    #[test]
    fn slice() {
        //     0 2 4
        // A = 1 3 5

        // column major
        let a = [0., 1., 2., 3., 4., 5.];
        let a = Tensor::from_shape([2, 3], a);

        // axis 0 (columns)
        let a00 = a.slice_axis(0, 0);
        assert_eq!(&**a00.into_inner(), [0., 1., 2., 3., 4., 5.]); // represents 0, 2, 4
        assert_eq!(a00.shape, [3]);
        assert_eq!(a00.strides, [2]);

        let a01 = a.slice_axis(0, 1);
        assert_eq!(&**a01.into_inner(), [1., 2., 3., 4., 5.]); // represents 1, 3, 5
        assert_eq!(a01.shape, [3]);
        assert_eq!(a01.strides, [2]);

        // axis 1 (rows)
        let a10 = a.slice_axis(1, 0);
        assert_eq!(&**a10.into_inner(), [0., 1., 2., 3., 4., 5.]); // represents 0, 1
        assert_eq!(a10.shape, [2]);
        assert_eq!(a10.strides, [1]);

        let a11 = a.slice_axis(1, 1);
        assert_eq!(&**a11.into_inner(), [2., 3., 4., 5.]); // represents 2, 3
        assert_eq!(a11.shape, [2]);
        assert_eq!(a11.strides, [1]);

        let a12 = a.slice_axis(1, 2);
        assert_eq!(&**a12.into_inner(), [4., 5.]); // represents 4, 5
        assert_eq!(a12.shape, [2]);
        assert_eq!(a12.strides, [1]);
    }

    /// Same product as `matmul`, but on device buffers with an explicit
    /// cuBLAS context and stream. Requires CUDA hardware.
    #[test]
    #[cfg(feature = "cublas")]
    fn matmul_cuda() {
        use crate::blas::cublas::CublasContext;
        use tensorgraph_sys::{
            device::cuda::{Context, Stream},
            Vec,
        };

        let ctx = Context::quick_init().unwrap();
        let cuda = Stream::new(&ctx).unwrap();
        // reborrow the stream as the allocator/device handle used below
        let cuda = &*cuda;

        // column major
        let a = Vec::copy_from_host_in(&[0., 2., 4., 1., 3., 5.], cuda);
        let b = Vec::copy_from_host_in(&[0., 2., 1., 3.], cuda);

        let ctx = CublasContext::new();
        let ctx = ctx.with_stream(Some(cuda));

        let a = Tensor::from_shape([3, 2], a);
        let b = Tensor::from_shape([2, 2], b);

        let c = a.matmul_into(b.view(), ctx, cuda);

        // copy the device result back before asserting on the host
        let mut out = vec![0.0_f32; 6];
        c.data.copy_to_host(&mut out);

        assert_eq!(out, vec![2., 6., 10., 3., 11., 19.]);
    }

    /// Same product again, but with the stream and cuBLAS context installed
    /// as process-wide globals (the `_handle` guards keep them alive).
    /// Requires CUDA hardware.
    #[test]
    #[cfg(feature = "cublas")]
    fn matmul_cuda_global() {
        use crate::blas::cublas::CublasContext;
        use tensorgraph_sys::{
            device::cuda::{Context, Cuda, Stream},
            DefaultVec,
        };

        let ctx = Context::quick_init().unwrap();
        let cuda = Stream::new(&ctx).unwrap();
        let _handle = cuda.as_global();

        // column major
        let a = DefaultVec::<f32, Cuda>::copy_from_host(&[0., 2., 4., 1., 3., 5.]);
        let b = DefaultVec::<f32, Cuda>::copy_from_host(&[0., 2., 1., 3.]);

        let ctx = CublasContext::new();
        let _handle = ctx.with_stream(Some(&cuda)).as_global();

        let a = Tensor::from_shape([3, 2], a);
        let b = Tensor::from_shape([2, 2], b);

        let c = a.matmul(b.view());

        let mut out = vec![0.0_f32; 6];
        c.data.copy_to_host(&mut out);

        assert_eq!(out, vec![2., 6., 10., 3., 11., 19.]);
    }

    /// Iterated gemm: repeatedly multiplies by `b`, ping-ponging the input
    /// and output buffers. The expected values are precomputed regression
    /// constants; exact float equality doubles as a determinism check.
    #[test]
    fn matmul2() {
        // column major
        let a = [0.001, 1.0, 1.0, 0.];
        let b = a;
        let c = [0.; 4];

        let mut a = Tensor::from_shape([2, 2], a);
        let b = Tensor::from_shape([2, 2], b);
        let mut c = Tensor::from_shape([2, 2], c);

        for _ in 0..1000 {
            gemm(1., a.view(), b.view(), 0., c.view_mut());
            std::mem::swap(&mut a, &mut c);
        }

        let out = c.into_inner();
        let expected = [
            1.1278865019586632,
            0.5210952168646452,
            0.5210952168646452,
            1.1273654067417986,
        ];

        assert_eq!(out[0], expected[0]);
        assert_eq!(out[1], expected[1]);
        assert_eq!(out[2], expected[2]);
        assert_eq!(out[3], expected[3]);
    }

    /// The same iterated gemm on device buffers via `gemm_ctx`; expects
    /// bit-identical results to the host run. Requires CUDA hardware.
    #[test]
    #[cfg(feature = "cublas")]
    fn matmul_cuda2() {
        use crate::{blas::cublas::CublasContext, tensor::gemm_ctx};
        use tensorgraph_sys::{
            device::cuda::{Context, Stream},
            Vec,
        };

        let ctx = Context::quick_init().unwrap();
        let cuda = Stream::new(&ctx).unwrap();
        let cuda = &*cuda;

        // column major
        let a = Vec::copy_from_host_in(&[0.001, 1.0, 1.0, 0.], cuda);
        let b = a.clone();
        let c = b.clone();

        let ctx = CublasContext::new();
        let ctx = ctx.with_stream(Some(cuda));

        let mut a = Tensor::from_shape([2, 2], a);
        let b = Tensor::from_shape([2, 2], b);
        let mut c = Tensor::from_shape([2, 2], c);

        for _ in 0..1000 {
            gemm_ctx(ctx, 1., a.view(), b.view(), 0., c.view_mut());
            std::mem::swap(&mut a, &mut c);
        }

        let mut out = vec![0.; 4];
        c.data.copy_to_host(&mut out);

        let expected = [
            1.1278865019586632,
            0.5210952168646452,
            0.5210952168646452,
            1.1273654067417986,
        ];

        assert_eq!(out[0], expected[0]);
        assert_eq!(out[1], expected[1]);
        assert_eq!(out[2], expected[2]);
        assert_eq!(out[3], expected[3]);
    }
}
403}