1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
mod bmm;
mod relu;
mod assign;
mod fill_with;
mod make_contiguous;
mod transpose;
mod clone;
mod compare;
mod tensor_tensor_ops;
use crate::tensors::ShapeStrideTrait;
use crate::{GpuAllocated, GpuTensor, GpuTensorView, GpuTensorViewMut, MutShapeStrideTrait};
/// Kernel-backed operations on a borrowed, read-only tensor view.
impl<'a> GpuTensorView<'a> {
    /// Materializes this view into a newly allocated `GpuTensor` by running the
    /// `make_contiguous` kernel on the view's GPU.
    pub async fn contiguous(&self) -> GpuTensor {
        let gpu = self.get_gpu();
        make_contiguous::make_contiguous(gpu, self).await
    }

    /// Compares this view against `other` by delegating to `compare::eq`.
    ///
    /// NOTE(review): unlike `GpuTensor::eq`, there is no host-side empty-tensor guard
    /// here — confirm that `compare::eq` copes with empty views.
    pub async fn eq(&self, other: &Self) -> bool {
        let gpu = self.get_gpu();
        compare::eq(gpu, self, other).await
    }
}
/// Kernel-backed operations on a mutable tensor view.
impl<'a> GpuTensorViewMut<'a> {
    /// Dispatches the `assign` kernel on this view's GPU with the scalar `data`.
    ///
    /// Presumably this writes `data` into every element the view covers — TODO confirm
    /// against `assign::assign`.
    pub async fn assign_kernel(&mut self, data: f32) {
        let gpu = self.get_gpu();
        assign::assign(gpu, self, data).await;
    }
}
/// Kernel-backed operations on owned GPU tensors.
///
/// Each method dispatches a kernel from one of this module's submodules on the GPU
/// returned by `get_gpu()`. Several methods short-circuit on empty tensors on the host
/// instead of launching a kernel.
impl GpuTensor {
    /// Returns `true` if `self` and `other` compare equal.
    ///
    /// If either tensor is empty, the result is decided on the host. Two empty tensors
    /// are equal, and an empty tensor never equals a non-empty one. Otherwise the
    /// comparison is delegated to `compare::eq` on views of both tensors.
    pub async fn eq(&self, other: &Self) -> bool {
        if self.is_empty() || other.is_empty() {
            return self.is_empty() && other.is_empty();
        }
        compare::eq(self.get_gpu(), &self.view(), &other.view()).await
    }

    /// Returns a new tensor produced by the `transpose` kernel.
    ///
    /// An empty tensor is handled on the host by returning a clone of it.
    pub async fn transpose(&self) -> GpuTensor {
        if self.is_empty() {
            return self.clone().await;
        }
        transpose::transpose(self.get_gpu(), self).await
    }

    /// Returns a new tensor holding a copy of this tensor's data (via the `clone` kernel).
    pub async fn clone(&self) -> GpuTensor {
        clone::clone(self.get_gpu(), self).await
    }

    /// Returns a new tensor produced by the `relu::leaky_relu` kernel.
    ///
    /// `leakage` is presumably the slope applied to negative inputs — TODO confirm
    /// against the kernel. An empty tensor is returned as a clone without dispatching.
    pub async fn leaky_relu(&self, leakage: f32) -> GpuTensor {
        if self.is_empty() {
            return self.clone().await;
        }
        relu::leaky_relu(self.get_gpu(), self, leakage).await
    }

    /// Dispatches the `fill_with` kernel to write `value` into this tensor in place.
    ///
    /// Does nothing for an empty tensor.
    pub async fn fill_with(&mut self, value: f32) {
        if self.is_empty() {
            return;
        }
        fill_with::fill_with(self.get_gpu(), self, value).await;
    }

    /// Matrix-multiplies `self` by `other` using the batched `bmm` kernel.
    ///
    /// Both inputs must be rank 2 or rank 3. They are first passed through
    /// `broadcast(other, Some(2))`. Presumably that broadcasts only the leading (batch)
    /// dimension and treats the trailing two as the matrices — TODO confirm against
    /// `broadcast`. Rank-2 views are lifted to rank 3 for `bmm`. If *both* inputs were
    /// rank 2, the result is lowered back to rank 2.
    ///
    /// # Panics
    ///
    /// - if either input is empty;
    /// - if either input's rank is not 2 or 3;
    /// - if the inputs cannot be broadcast together;
    /// - if, after normalization to rank 3, the last dimension of the left operand does
    ///   not equal the middle dimension of the right operand.
    pub async fn matmul<'a>(&'a self, other: &'a Self) -> Self {
        assert!(
            !self.is_empty() && !other.is_empty(),
            "Tried to matmul with at least one empty Tensor"
        );
        let gpu = self.get_gpu();
        let lhs_rank = self.shape().len();
        let rhs_rank = other.shape().len();
        // Reject unsupported ranks up front, naming the actual ranks. Otherwise a rank > 3
        // input would presumably only trip the context-free rank-3 asserts further down.
        assert!(
            (2..=3).contains(&lhs_rank) && (2..=3).contains(&rhs_rank),
            "Input to matmul must be of rank 2 or 3, got ranks {} and {}",
            lhs_rank,
            rhs_rank
        );
        let (mut lhs, mut rhs) = self
            .broadcast(other, Some(2))
            .expect("matmul: inputs could not be broadcast together");
        if lhs.rank() == 2 {
            // No batch dimension: add a unit one to both views so `bmm` sees rank 3.
            // Assumes `broadcast` returns views of equal rank — TODO confirm.
            lhs.increase_rank();
            rhs.increase_rank();
        }
        assert_eq!(lhs.shape().len(), 3);
        assert_eq!(rhs.shape().len(), 3);
        assert_eq!(
            lhs.shape()[2],
            rhs.shape()[1],
            "Shapes do not match for matrix multiply: {:?} and {:?}",
            lhs.shape(),
            rhs.shape()
        );
        let mut res = bmm::bmm_kernel(gpu, &lhs, &rhs).await;
        if lhs_rank == 2 && rhs_rank == 2 {
            // Undo the artificial batch dimension so 2-D x 2-D yields a 2-D result.
            res.decrease_rank();
        }
        res
    }
}