smelte_rs/
traits.rs

1use crate::SmeltError;
2
/// Basic interface for a tensor type: building a zero tensor and querying its shape.
///
/// Implementors must also be [`Clone`].
pub trait Tensor: Clone {
    // NOTE(review): the name suggests every element is zero-initialized; the
    // exact contents/allocation strategy is left to each implementor — confirm.
    /// Creates a new tensor with the given `shape`.
    fn zeros(shape: Vec<usize>) -> Self;

    /// Returns the shape of this tensor.
    fn shape(&self) -> &[usize];
}
10
11/// All common tensor operations
12pub trait TensorOps<T>:
13    TensorMatmul<T>
14    + TensorMatmulT<T>
15    + TensorAdd<T>
16    + TensorMul<T>
17    + TensorNormalize<T>
18    + TensorSelect<T>
19    + TensorGelu<T>
20    + TensorTanh<T>
21    + TensorSoftmax<T>
22{
23}
24
25/// TODO
26pub trait TensorMatmul<T> {
27    /// TODO
28    fn matmul(a: &T, b: &T, c: &mut T) -> Result<(), SmeltError>;
29}
30
31/// TODO
32pub trait TensorMatmulT<T> {
33    /// TODO
34    fn matmul_t(a: &T, b: &T, c: &mut T) -> Result<(), SmeltError>;
35}
36
37/// TODO
38pub trait TensorAdd<T> {
39    /// TODO
40    fn add(a: &T, b: &mut T) -> Result<(), SmeltError>;
41}
42
43/// TODO
44pub trait TensorMul<T> {
45    /// TODO
46    fn mul(a: &T, b: &mut T) -> Result<(), SmeltError>;
47}
48
49/// TODO
50pub trait TensorNormalize<T> {
51    /// TODO
52    fn normalize(x: &mut T, epsilon: f32) -> Result<(), SmeltError>;
53}
54
55/// TODO
56pub trait TensorSelect<T> {
57    /// TODO
58    fn select(x: &[usize], weight: &T, out: &mut T) -> Result<(), SmeltError>;
59}
60
61/// TODO
62pub trait TensorGelu<T> {
63    /// TODO
64    fn gelu(x: &mut T) -> Result<(), SmeltError>;
65}
66
67/// TODO
68pub trait TensorTanh<T> {
69    /// TODO
70    fn tanh(x: &mut T) -> Result<(), SmeltError>;
71}
72
73/// TODO
74pub trait TensorSoftmax<T> {
75    /// TODO
76    fn softmax(x: &mut T) -> Result<(), SmeltError>;
77}