use std::{
    convert::TryFrom,
    mem::MaybeUninit,
    ops::{AddAssign, Mul, MulAssign},
};

use num_traits::One;
use tensorgraph_sys::{
    device::{DefaultDeviceAllocator, Device},
    ViewMut,
};

use crate::{
    blas::{BLASContext, DefaultBLASContext, BLAS1},
    storage::{IntoOwned, Storage, StorageMut},
};

use super::{Slice, Tensor, ViewOf};
18
19pub type Vector<S> = Tensor<S, [usize; 1]>;
21
22pub type VectorView<'a, T, D> = Vector<&'a Slice<T, D>>;
24
25pub type VectorViewMut<'a, T, D> = Vector<&'a mut Slice<T, D>>;
27
28pub type UninitVector<'a, T, D> = VectorViewMut<'a, MaybeUninit<T>, D>;
30
31impl<S: Storage> Vector<S> {
32 pub fn dot(&self, rhs: Vector<&ViewOf<S>>) -> S::T
34 where
35 S::Device: DefaultBLASContext,
36 S::T: BLAS1<<S::Device as DefaultBLASContext>::Context>,
37 {
38 self.dot_using(rhs, Default::default())
39 }
40
41 #[allow(clippy::cast_possible_wrap, clippy::cast_possible_truncation)]
46 pub fn dot_using<C: BLASContext<Device = S::Device>>(
47 &self,
48 rhs: Vector<&ViewOf<S>>,
49 ctx: C,
50 ) -> S::T
51 where
52 S::T: BLAS1<C>,
53 {
54 let x = self;
55 let y = rhs;
56 let [n] = x.shape;
57 let [m] = y.shape;
58 assert_eq!(n, m);
59
60 let incx = x.strides[0] as i32;
61 let incy = y.strides[0] as i32;
62
63 unsafe {
64 <S::T as BLAS1<C>>::dot(
65 ctx,
66 n as i32,
67 x.data.as_ref().as_ptr(),
68 incx,
69 y.data.as_ptr(),
70 incy,
71 )
72 }
73 }
74}
75
76impl<'a, S: StorageMut> AddAssign<Vector<&'a ViewOf<S>>> for Vector<S>
77where
78 S::Device: DefaultBLASContext,
79 S::T: One + BLAS1<<S::Device as DefaultBLASContext>::Context>,
80{
81 fn add_assign(&mut self, rhs: Vector<&'a ViewOf<S>>) {
82 axpy_ctx(Default::default(), One::one(), rhs, self.view_mut());
83 }
84}
85
86#[allow(
92 clippy::cast_possible_wrap,
93 clippy::cast_possible_truncation,
94 clippy::needless_pass_by_value
95)]
96pub fn axpy_ctx<F: BLAS1<C>, C: BLASContext<Device = D>, D: Device>(
97 ctx: C,
98 alpha: F,
99 x: VectorView<F, D>,
100 y: VectorViewMut<F, D>,
101) {
102 let [n] = x.shape;
103 let [m] = y.shape;
104 assert_eq!(n, m);
105
106 let incx = x.strides[0] as i32;
107 let incy = y.strides[0] as i32;
108
109 unsafe {
110 F::axpy(
111 ctx,
112 n as i32,
113 alpha,
114 x.data.as_ref().as_ptr(),
115 incx,
116 y.data.as_ptr(),
117 incy,
118 );
119 }
120}
121
122impl<S: StorageMut> MulAssign<S::T> for Vector<S>
123where
124 S::Device: DefaultBLASContext,
125 S::T: BLAS1<<S::Device as DefaultBLASContext>::Context>,
126{
127 fn mul_assign(&mut self, rhs: S::T) {
128 self.scale_using(rhs, Default::default());
129 }
130}
131
132impl<S: Storage + IntoOwned> Mul<S::T> for Vector<S>
133where
134 S::Device: DefaultBLASContext + DefaultDeviceAllocator,
135 S::T: BLAS1<<S::Device as DefaultBLASContext>::Context>,
136 S::Owned: Storage<T = S::T, Device = S::Device> + StorageMut,
137{
138 type Output = Vector<S::Owned>;
139 fn mul(self, rhs: S::T) -> Self::Output {
140 let mut x = self.into_owned();
141 x *= rhs;
142 x
143 }
144}
145
146impl<S: StorageMut> Vector<S> {
147 #[allow(clippy::cast_possible_wrap, clippy::cast_possible_truncation)]
149 pub fn scale_using<C: BLASContext<Device = S::Device>>(&mut self, alpha: S::T, ctx: C)
150 where
151 S::T: BLAS1<C>,
152 {
153 scal_ctx(ctx, alpha, self.view_mut());
154 }
155}
156
157#[allow(
163 clippy::cast_possible_wrap,
164 clippy::cast_possible_truncation,
165 clippy::needless_pass_by_value
166)]
167pub fn scal_ctx<F: BLAS1<C>, C: BLASContext<Device = D>, D: Device>(
168 ctx: C,
169 alpha: F,
170 x: VectorViewMut<F, D>,
171) {
172 let [n] = x.shape;
173 let incx = x.strides[0] as i32;
174
175 unsafe {
176 F::scal(ctx, n as i32, alpha, x.data.as_ref().as_ptr(), incx);
177 }
178}