tensorgraph_math/tensor/
vector.rs

1use std::{
2    mem::MaybeUninit,
3    ops::{AddAssign, Mul, MulAssign},
4};
5
6use num_traits::One;
7use tensorgraph_sys::{
8    device::{DefaultDeviceAllocator, Device},
9    ViewMut,
10};
11
12use crate::{
13    blas::{BLASContext, DefaultBLASContext, BLAS1},
14    storage::{IntoOwned, Storage, StorageMut},
15};
16
17use super::{Slice, Tensor, ViewOf};
18
/// A 1-dimensional tensor: a [`Tensor`] whose dimension type is `[usize; 1]`.
pub type Vector<S> = Tensor<S, [usize; 1]>;

/// A 'view' of a vector, like `&[T]` is to `Vec<T>`.
///
/// This is a [`Vector`] whose storage is a shared borrow of a [`Slice`].
pub type VectorView<'a, T, D> = Vector<&'a Slice<T, D>>;

/// A 'mut view' of a vector, like `&mut [T]` is to `Vec<T>`.
///
/// This is a [`Vector`] whose storage is an exclusive borrow of a [`Slice`].
pub type VectorViewMut<'a, T, D> = Vector<&'a mut Slice<T, D>>;

/// An uninit vector. Contents are mutable and specified as [`MaybeUninit`].
///
/// This is a [`VectorViewMut`] whose element type is `MaybeUninit<T>`.
pub type UninitVector<'a, T, D> = VectorViewMut<'a, MaybeUninit<T>, D>;
30
31impl<S: Storage> Vector<S> {
32    /// Vector dot product
33    pub fn dot(&self, rhs: Vector<&ViewOf<S>>) -> S::T
34    where
35        S::Device: DefaultBLASContext,
36        S::T: BLAS1<<S::Device as DefaultBLASContext>::Context>,
37    {
38        self.dot_using(rhs, Default::default())
39    }
40
41    /// Vector dot product, using the specified [`BLASContext`]
42    ///
43    /// # Panics
44    /// If the vectors do not have the same length
45    #[allow(clippy::cast_possible_wrap, clippy::cast_possible_truncation)]
46    pub fn dot_using<C: BLASContext<Device = S::Device>>(
47        &self,
48        rhs: Vector<&ViewOf<S>>,
49        ctx: C,
50    ) -> S::T
51    where
52        S::T: BLAS1<C>,
53    {
54        let x = self;
55        let y = rhs;
56        let [n] = x.shape;
57        let [m] = y.shape;
58        assert_eq!(n, m);
59
60        let incx = x.strides[0] as i32;
61        let incy = y.strides[0] as i32;
62
63        unsafe {
64            <S::T as BLAS1<C>>::dot(
65                ctx,
66                n as i32,
67                x.data.as_ref().as_ptr(),
68                incx,
69                y.data.as_ptr(),
70                incy,
71            )
72        }
73    }
74}
75
76impl<'a, S: StorageMut> AddAssign<Vector<&'a ViewOf<S>>> for Vector<S>
77where
78    S::Device: DefaultBLASContext,
79    S::T: One + BLAS1<<S::Device as DefaultBLASContext>::Context>,
80{
81    fn add_assign(&mut self, rhs: Vector<&'a ViewOf<S>>) {
82        axpy_ctx(Default::default(), One::one(), rhs, self.view_mut());
83    }
84}
85
86/// Performs the basic vector operation.
87/// > y = alpha * x + y.
88///
89/// # Panics
90/// If the vectors do not have the same length
91#[allow(
92    clippy::cast_possible_wrap,
93    clippy::cast_possible_truncation,
94    clippy::needless_pass_by_value
95)]
96pub fn axpy_ctx<F: BLAS1<C>, C: BLASContext<Device = D>, D: Device>(
97    ctx: C,
98    alpha: F,
99    x: VectorView<F, D>,
100    y: VectorViewMut<F, D>,
101) {
102    let [n] = x.shape;
103    let [m] = y.shape;
104    assert_eq!(n, m);
105
106    let incx = x.strides[0] as i32;
107    let incy = y.strides[0] as i32;
108
109    unsafe {
110        F::axpy(
111            ctx,
112            n as i32,
113            alpha,
114            x.data.as_ref().as_ptr(),
115            incx,
116            y.data.as_ptr(),
117            incy,
118        );
119    }
120}
121
122impl<S: StorageMut> MulAssign<S::T> for Vector<S>
123where
124    S::Device: DefaultBLASContext,
125    S::T: BLAS1<<S::Device as DefaultBLASContext>::Context>,
126{
127    fn mul_assign(&mut self, rhs: S::T) {
128        self.scale_using(rhs, Default::default());
129    }
130}
131
132impl<S: Storage + IntoOwned> Mul<S::T> for Vector<S>
133where
134    S::Device: DefaultBLASContext + DefaultDeviceAllocator,
135    S::T: BLAS1<<S::Device as DefaultBLASContext>::Context>,
136    S::Owned: Storage<T = S::T, Device = S::Device> + StorageMut,
137{
138    type Output = Vector<S::Owned>;
139    fn mul(self, rhs: S::T) -> Self::Output {
140        let mut x = self.into_owned();
141        x *= rhs;
142        x
143    }
144}
145
146impl<S: StorageMut> Vector<S> {
147    /// Vector scaling, using the specified [`BLASContext`]
148    #[allow(clippy::cast_possible_wrap, clippy::cast_possible_truncation)]
149    pub fn scale_using<C: BLASContext<Device = S::Device>>(&mut self, alpha: S::T, ctx: C)
150    where
151        S::T: BLAS1<C>,
152    {
153        scal_ctx(ctx, alpha, self.view_mut());
154    }
155}
156
157/// Performs the basic vector operation.
158/// > x = alpha * x.
159///
160/// # Panics
161/// If the vectors do not have the same length
162#[allow(
163    clippy::cast_possible_wrap,
164    clippy::cast_possible_truncation,
165    clippy::needless_pass_by_value
166)]
167pub fn scal_ctx<F: BLAS1<C>, C: BLASContext<Device = D>, D: Device>(
168    ctx: C,
169    alpha: F,
170    x: VectorViewMut<F, D>,
171) {
172    let [n] = x.shape;
173    let incx = x.strides[0] as i32;
174
175    unsafe {
176        F::scal(ctx, n as i32, alpha, x.data.as_ref().as_ptr(), incx);
177    }
178}