1use crate::SmeltError;
2
/// Minimal interface every concrete tensor type has to provide.
///
/// Every implementor must also be [`Clone`].
pub trait Tensor
where
    Self: Clone,
{
    /// Constructs a tensor whose dimensions are given by `dims`.
    ///
    /// NOTE(review): the name suggests every element starts at zero —
    /// confirm against the implementations.
    fn zeros(dims: Vec<usize>) -> Self;

    /// Returns the tensor's shape (its dimension sizes), borrowed from `self`.
    fn shape(&self) -> &[usize];
}
10
11pub trait TensorOps<T>:
13 TensorMatmul<T>
14 + TensorMatmulT<T>
15 + TensorAdd<T>
16 + TensorMul<T>
17 + TensorNormalize<T>
18 + TensorSelect<T>
19 + TensorGelu<T>
20 + TensorTanh<T>
21 + TensorSoftmax<T>
22{
23}
24
25pub trait TensorMatmul<T> {
27 fn matmul(a: &T, b: &T, c: &mut T) -> Result<(), SmeltError>;
29}
30
31pub trait TensorMatmulT<T> {
33 fn matmul_t(a: &T, b: &T, c: &mut T) -> Result<(), SmeltError>;
35}
36
37pub trait TensorAdd<T> {
39 fn add(a: &T, b: &mut T) -> Result<(), SmeltError>;
41}
42
43pub trait TensorMul<T> {
45 fn mul(a: &T, b: &mut T) -> Result<(), SmeltError>;
47}
48
49pub trait TensorNormalize<T> {
51 fn normalize(x: &mut T, epsilon: f32) -> Result<(), SmeltError>;
53}
54
55pub trait TensorSelect<T> {
57 fn select(x: &[usize], weight: &T, out: &mut T) -> Result<(), SmeltError>;
59}
60
61pub trait TensorGelu<T> {
63 fn gelu(x: &mut T) -> Result<(), SmeltError>;
65}
66
67pub trait TensorTanh<T> {
69 fn tanh(x: &mut T) -> Result<(), SmeltError>;
71}
72
73pub trait TensorSoftmax<T> {
75 fn softmax(x: &mut T) -> Result<(), SmeltError>;
77}