1use std::mem::MaybeUninit;
2
3use tensorgraph_sys::{ptr::Ref, View, ViewMut, device::Device};
4
5use crate::{
6 dims::{Dimension, RemoveDim},
7 storage::{IntoOwned, Storage, StorageMut},
8};
9
10mod matrix;
11mod vector;
12pub use matrix::*;
13pub use vector::*;
14
15pub type Slice<T, D> = Ref<[T], D>;
17
18pub type ViewOf<S> = Slice<<S as Storage>::T, <S as Storage>::Device>;
20
21pub type TensorView<'a, T, D, Dim> = Tensor<&'a Slice<T, D>, Dim>;
23
24pub type TensorViewMut<'a, T, D, Dim> = Tensor<&'a mut Slice<T, D>, Dim>;
26
27pub type UninitTensor<'a, T, D, Dim> = TensorViewMut<'a, MaybeUninit<T>, D, Dim>;
29
30#[derive(Copy, Clone)]
32pub struct Tensor<S: Storage, Dim: Dimension> {
33 shape: Dim,
34 strides: Dim,
35 data: S,
36}
37
38impl<S: Storage, Dim: Dimension> View for Tensor<S, Dim> {
39 type Ref<'a>
40 where
41 Self: 'a,
42 = Tensor<&'a ViewOf<S>, Dim>;
43
44 fn view(&self) -> TensorView<S::T, S::Device, Dim> {
45 Tensor {
46 shape: self.shape.clone(),
47 strides: self.strides.clone(),
48 data: self.data.as_ref(),
49 }
50 }
51}
52
53impl<S: StorageMut, Dim: Dimension> ViewMut for Tensor<S, Dim> {
54 type Mut<'a>
55 where
56 Self: 'a,
57 = Tensor<&'a mut ViewOf<S>, Dim>;
58
59 fn view_mut(&mut self) -> TensorViewMut<S::T, S::Device, Dim> {
60 Tensor {
61 shape: self.shape.clone(),
62 strides: self.strides.clone(),
63 data: self.data.as_mut(),
64 }
65 }
66}
67
68impl<S: Storage, Dim: Dimension> Tensor<S, Dim> {
69 pub fn from_shape(shape: Dim, data: S) -> Self {
74 assert_eq!(data.as_ref().len(), shape.size());
75 let strides = shape.column_major_strides();
76 Self {
77 shape,
78 strides,
79 data,
80 }
81 }
82
83 pub fn into_inner(self) -> S {
85 self.data
86 }
87
88 pub fn reverse_axes(&mut self) {
90 self.shape.as_mut().reverse();
91 self.strides.as_mut().reverse();
92 }
93
94 pub fn swap_axes(&mut self, i: usize, j: usize) {
95 self.shape.as_mut().swap(i, j);
96 self.strides.as_mut().swap(i, j);
97 }
98
99 pub fn t(&self) -> Tensor<&ViewOf<S>, Dim> {
102 let mut view = self.view();
103 view.reverse_axes();
104 view
105 }
106
107 pub fn into_owned(self) -> Tensor<S::Owned, Dim>
110 where
111 S: IntoOwned,
112 S::Owned: Storage<T = S::T, Device = S::Device>,
113 {
114 Tensor {
115 shape: self.shape,
116 strides: self.strides,
117 data: self.data.into_owned(),
118 }
119 }
120
121 pub fn slice_axis(&self, axis: usize, n: usize) -> Tensor<&ViewOf<S>, Dim::Smaller>
126 where
127 Dim: RemoveDim,
128 {
129 assert!(axis < self.shape.as_ref().len());
130
131 let (shape, m) = self.shape.remove(axis);
132 let (strides, s) = self.strides.remove(axis);
133
134 assert!(n < m);
135
136 Tensor {
137 shape,
138 strides,
139 data: &self.data.as_ref()[s * n..],
140 }
141 }
142}
143
144impl<'a, T, D: Device, Dim: Dimension> UninitTensor<'a, T, D, Dim> {
145 pub unsafe fn assume_init(self) -> TensorViewMut<'a, T, D, Dim> {
148 Tensor {
149 shape: self.shape,
150 strides: self.strides,
151 data: self.data.assume_init_mut(),
152 }
153 }
154}
155
156impl<'a, T, D: Device, Dim: Dimension> TensorView<'a, MaybeUninit<T>, D, Dim> {
157 pub unsafe fn assume_init(self) -> TensorView<'a, T, D, Dim> {
160 Tensor {
161 shape: self.shape,
162 strides: self.strides,
163 data: self.data.assume_init(),
164 }
165 }
166}
167
#[cfg(test)]
mod tests {
    use tensorgraph_sys::{View, ViewMut};

    use crate::tensor::{gemm, Tensor};

    /// Host matmul of a 3x2 matrix with a 2x2 matrix.
    #[test]
    fn matmul() {
        // Buffers are column-major:
        //   lhs rows: [0, 1], [2, 3], [4, 5]
        //   rhs rows: [0, 1], [2, 3]
        let lhs = Tensor::from_shape([3, 2], [0., 2., 4., 1., 3., 5.]);
        let rhs = Tensor::from_shape([2, 2], [0., 2., 1., 3.]);

        let product = lhs.matmul(rhs.view());

        assert_eq!(product.into_inner().into_std(), [2., 6., 10., 3., 11., 19.]);
    }

    /// Host matmul where either or both operands are transposed views.
    #[test]
    fn matmul_t() {
        // Column-major: lhs rows are [1, 3], [2, 4]; rhs rows are [5, 7], [6, 8].
        let lhs = Tensor::from_shape([2, 2], [1., 2., 3., 4.]);
        let rhs = Tensor::from_shape([2, 2], [5., 6., 7., 8.]);

        // lhs^T * rhs^T
        let both_transposed = lhs.t().matmul(rhs.t());
        assert_eq!(
            both_transposed.into_inner().into_std(),
            [19.0, 43.0, 22.0, 50.0]
        );

        // lhs * rhs^T
        let rhs_transposed = lhs.matmul(rhs.t());
        assert_eq!(
            rhs_transposed.into_inner().into_std(),
            [26.0, 38.0, 30.0, 44.0]
        );

        // lhs^T * rhs
        let lhs_transposed = lhs.t().matmul(rhs.view());
        assert_eq!(
            lhs_transposed.into_inner().into_std(),
            [17.0, 39.0, 23.0, 53.0]
        );
    }

    /// `slice_axis` drops one axis and offsets the data by `stride * n`.
    #[test]
    fn slice() {
        // Shape [2, 3] with column-major strides [1, 2].
        let grid = Tensor::from_shape([2, 3], [0., 1., 2., 3., 4., 5.]);

        // Fixing axis 0 leaves the axis of extent 3 and stride 2.
        let axis0_at0 = grid.slice_axis(0, 0);
        assert_eq!(&**axis0_at0.into_inner(), [0., 1., 2., 3., 4., 5.]);
        assert_eq!(axis0_at0.shape, [3]);
        assert_eq!(axis0_at0.strides, [2]);

        let axis0_at1 = grid.slice_axis(0, 1);
        assert_eq!(&**axis0_at1.into_inner(), [1., 2., 3., 4., 5.]);
        assert_eq!(axis0_at1.shape, [3]);
        assert_eq!(axis0_at1.strides, [2]);

        // Fixing axis 1 leaves the axis of extent 2 and stride 1.
        let axis1_at0 = grid.slice_axis(1, 0);
        assert_eq!(&**axis1_at0.into_inner(), [0., 1., 2., 3., 4., 5.]);
        assert_eq!(axis1_at0.shape, [2]);
        assert_eq!(axis1_at0.strides, [1]);

        let axis1_at1 = grid.slice_axis(1, 1);
        assert_eq!(&**axis1_at1.into_inner(), [2., 3., 4., 5.]);
        assert_eq!(axis1_at1.shape, [2]);
        assert_eq!(axis1_at1.strides, [1]);

        let axis1_at2 = grid.slice_axis(1, 2);
        assert_eq!(&**axis1_at2.into_inner(), [4., 5.]);
        assert_eq!(axis1_at2.shape, [2]);
        assert_eq!(axis1_at2.strides, [1]);
    }

    /// Same product as `matmul`, with the stream and cuBLAS context passed
    /// explicitly.
    #[test]
    #[cfg(feature = "cublas")]
    fn matmul_cuda() {
        use crate::blas::cublas::CublasContext;
        use tensorgraph_sys::{
            device::cuda::{Context, Stream},
            Vec,
        };

        let device_ctx = Context::quick_init().unwrap();
        let owned_stream = Stream::new(&device_ctx).unwrap();
        let stream = &*owned_stream;

        let lhs_buf = Vec::copy_from_host_in(&[0., 2., 4., 1., 3., 5.], stream);
        let rhs_buf = Vec::copy_from_host_in(&[0., 2., 1., 3.], stream);

        let blas_owner = CublasContext::new();
        let blas = blas_owner.with_stream(Some(stream));

        let lhs = Tensor::from_shape([3, 2], lhs_buf);
        let rhs = Tensor::from_shape([2, 2], rhs_buf);

        let product = lhs.matmul_into(rhs.view(), blas, stream);

        let mut host = vec![0.0_f32; 6];
        product.data.copy_to_host(&mut host);

        assert_eq!(host, vec![2., 6., 10., 3., 11., 19.]);
    }

    /// Same product as `matmul`, relying on the globally registered stream and
    /// cuBLAS context.
    #[test]
    #[cfg(feature = "cublas")]
    fn matmul_cuda_global() {
        use crate::blas::cublas::CublasContext;
        use tensorgraph_sys::{
            device::cuda::{Context, Cuda, Stream},
            DefaultVec,
        };

        let device_ctx = Context::quick_init().unwrap();
        let stream = Stream::new(&device_ctx).unwrap();
        // Both guards are named so they live until the end of the test.
        let _stream_guard = stream.as_global();

        let lhs_buf = DefaultVec::<f32, Cuda>::copy_from_host(&[0., 2., 4., 1., 3., 5.]);
        let rhs_buf = DefaultVec::<f32, Cuda>::copy_from_host(&[0., 2., 1., 3.]);

        let blas_owner = CublasContext::new();
        let _blas_guard = blas_owner.with_stream(Some(&stream)).as_global();

        let lhs = Tensor::from_shape([3, 2], lhs_buf);
        let rhs = Tensor::from_shape([2, 2], rhs_buf);

        let product = lhs.matmul(rhs.view());

        let mut host = vec![0.0_f32; 6];
        product.data.copy_to_host(&mut host);

        assert_eq!(host, vec![2., 6., 10., 3., 11., 19.]);
    }

    /// Repeated host `gemm`, ping-ponging between two buffers.
    #[test]
    fn matmul2() {
        let seed = [0.001, 1.0, 1.0, 0.];

        let mut input = Tensor::from_shape([2, 2], seed);
        let factor = Tensor::from_shape([2, 2], seed);
        let mut output = Tensor::from_shape([2, 2], [0.; 4]);

        for _ in 0..1000 {
            gemm(1., input.view(), factor.view(), 0., output.view_mut());
            // The freshly written buffer becomes the next iteration's input.
            std::mem::swap(&mut input, &mut output);
        }

        // After the final swap `output` holds the buffer that was the input
        // of the last `gemm` call; that is what the expectation was taken on.
        let out = output.into_inner();
        let expected = [
            1.1278865019586632,
            0.5210952168646452,
            0.5210952168646452,
            1.1273654067417986,
        ];

        for (got, want) in out.iter().zip(expected.iter()) {
            assert_eq!(got, want);
        }
    }

    /// Repeated cuBLAS `gemm_ctx`, ping-ponging between two device buffers.
    #[test]
    #[cfg(feature = "cublas")]
    fn matmul_cuda2() {
        use crate::{blas::cublas::CublasContext, tensor::gemm_ctx};
        use tensorgraph_sys::{
            device::cuda::{Context, Stream},
            Vec,
        };

        let device_ctx = Context::quick_init().unwrap();
        let owned_stream = Stream::new(&device_ctx).unwrap();
        let stream = &*owned_stream;

        let input_buf = Vec::copy_from_host_in(&[0.001, 1.0, 1.0, 0.], stream);
        let factor_buf = input_buf.clone();
        let output_buf = factor_buf.clone();

        let blas_owner = CublasContext::new();
        let blas = blas_owner.with_stream(Some(stream));

        let mut input = Tensor::from_shape([2, 2], input_buf);
        let factor = Tensor::from_shape([2, 2], factor_buf);
        let mut output = Tensor::from_shape([2, 2], output_buf);

        for _ in 0..1000 {
            gemm_ctx(blas, 1., input.view(), factor.view(), 0., output.view_mut());
            // The freshly written buffer becomes the next iteration's input.
            std::mem::swap(&mut input, &mut output);
        }

        let mut host = vec![0.; 4];
        output.data.copy_to_host(&mut host);

        let expected = [
            1.1278865019586632,
            0.5210952168646452,
            0.5210952168646452,
            1.1273654067417986,
        ];

        for (got, want) in host.iter().zip(expected.iter()) {
            assert_eq!(got, want);
        }
    }
}