tensorrs/
lib.rs

1//!
2//! # Tensors
3//!
4//! Tensors is a lightweight machine learning library in Rust. It provides a simple and efficient way to create and train machine learning models with minimal dependencies.
5//! ## Dependencies
6//! The library uses the following dependencies:
7//! - [rayon](https://crates.io/crates/rayon) - for parallel computations on CPU.
8//! - [rand](https://crates.io/crates/rand) - for random number generation.
9//! - [serde](https://crates.io/crates/serde) - for saving models.
10//! - [serde_json](https://crates.io/crates/serde_json) - for loading models.
11//!
12//! ## Example
13//! ```rust
14//! use tensorrs::activation::Function;
15//! use tensorrs::DataType;
16//! use tensorrs::linalg::{Matrix, Vector};
17//! use tensorrs::nn::{Linear, Sequential};
18//! use tensorrs::optim::Adam;
19//! use tensorrs::loss::MSE;
20//! use tensorrs::loss::Loss;
21//!
22//! let x = Matrix::from(Vector::range(-1.0, 1.0, 0.125).unwrap());
23//! let y:Matrix<f32> = 8.0 * &x - 10.0;
24//!
25//! let layers: Vec<Box< dyn Function<f32>>> = vec![Box::new(Linear::new(1, 1, true))];
26//! let mut optim = Adam::new(0.001, &layers);
27//! let mut model = Sequential::new(layers);
28//! let loss = MSE::new(DataType::f32());
29//!
30//! for _ in 0..1000 {
31//!     model.train(x.transpose(), y.transpose(), &mut optim, &loss);
32//! }
33//! ```
34//! Thanks for using Tensors!!!
35use std::cmp::PartialOrd;
36use std::fmt::{Debug, Display};
37use std::ops::{Add, AddAssign, Div, Mul, Neg, Sub, SubAssign};
38
39pub mod activation;
40pub mod linalg;
41pub mod loss;
42pub mod nn;
43pub mod optim;
44pub mod utils;
45//pub(crate) mod onnx_pb;
46
/// Numeric type
///
/// Special Trait
///
/// Marker trait bundling every bound the library needs from a scalar
/// element type: arithmetic ops, ordering, cheap copying, conversion from
/// `u8`, formatting, and thread-safety.
///
/// Implemented below for the built-in numerics (i16, i32, i64, i128, f32, f64).
pub trait Num:
    Add<Output = Self>
    + Sub<Output = Self>
    + Mul<Output = Self>
    + Div<Output = Self>
    + AddAssign
    + SubAssign
    + Neg<Output = Self>
    + PartialOrd
    + Copy
    + From<u8>
    + Default
    + Display
    + Debug
    + Sync
    + Send
    + 'static
{
}
74
75macro_rules! impl_num_for_types {
76    ($($type:ty),*) => {
77        $(
78        impl Num for $type {}
79        )*
80    };
81}
82impl_num_for_types!(i16, i32, i64, i128, f32, f64);
83
/// Float type
///
/// Special numeric trait for all floating-point numbers.
///
/// Implemented below for `f32` and `f64`.
pub trait Float: Num {
    /// Multiplicative identity (`1.0`).
    fn one() -> Self;
    /// 1 for positive 0 for 0 and -1 for negative
    fn sign(self) -> Self;
    /// Square root; both impls forward to the primitive's inherent `sqrt`.
    fn sqrt(self) -> Self;
    /// `e` raised to the power of `self`.
    fn exp(self) -> Self;
    /// Natural logarithm of `self`.
    fn ln(self) -> Self;
    /// `self` raised to the power `n`.
    fn powf(self, n: Self) -> Self;
    /// Absolute value.
    fn abs(self) -> Self;
    /// Arithmetic negation (`-self`).
    fn neg(self) -> Self;

    /// Widens (or keeps) the value as `f64`.
    fn to_f64(self) -> f64;
    /// Narrows (or keeps) the value as `f32`.
    fn to_f32(self) -> f32;

    /// Truncating cast to `i32` (impls use `as i32`).
    fn to_i32(self) -> i32;

    /// SELU λ (scale) constant, ≈1.0507; the receiver value is ignored.
    fn selu_lambda(self) -> Self;

    /// SELU α constant, ≈1.67326; the receiver value is ignored.
    fn selu_alpha(self) -> Self;

    /// Conversion from `f64` (lossy when `Self` is `f32`).
    fn from_f64(value: f64) -> Self;
    /// Conversion from `usize` via an `as` cast (may lose precision).
    fn from_usize(value: usize) -> Self;

    /// Parses a decimal string; the provided impls panic on invalid input.
    fn from_str(value: &str) -> Self;
    /// Cosine of `self` (forwards to the primitive's `cos`).
    fn cos(self) -> Self;
    /// Approximation of π.
    fn pi() -> Self;
    /// Returns `a` when `Self` is `f32` and `b` when `Self` is `f64`.
    fn f32_f64(a: f32, b: f64) -> Self;
    /// Generic variant of [`Float::f32_f64`]: picks `a` for `f32`, `b` for `f64`.
    fn if_f32_f64<T>(a: T, b: T) -> T;
}
118
/// Generates the `Float` methods whose bodies are textually identical for
/// `f32` and `f64`; it is expanded inside each `impl Float for ...` block.
///
/// Inside those impls, calls like `self.sqrt()` resolve to the primitive's
/// inherent methods (which take precedence over the trait methods of the
/// same name), so these are plain forwarders.
macro_rules! impl_some_float_for_types {
    ($($type:ty),*) => {
        $(
            fn one() -> Self { 1.0 }
            // Full-precision constant; the previous literal 3.14159 lost
            // ~2.7e-6 of accuracy. The cast rounds correctly for f32.
            fn pi() -> Self { std::f64::consts::PI as $type }
            // 1 for positive, 0 for zero, -1 for negative.
            // Note: NaN compares false in both branches and yields -1.0.
            fn sign(self) -> Self {
                if self > Self::default() {
                    1.0
                } else if self == Self::default() {
                    0.0
                } else {
                    -1.0
                }
            }
            fn sqrt(self) -> Self { self.sqrt() }
            fn cos(self) -> Self { self.cos() }
            fn exp(self) -> Self { self.exp() }
            fn ln(self) -> Self { self.ln() }
            fn abs(self) -> Self { self.abs() }
            fn powf(self, n: $type) -> Self { self.powf(n) }
            // Qualified call so this cannot resolve back to `Float::neg`
            // (which would recurse); floats have no inherent `neg`.
            fn neg(self) -> Self { Neg::neg(self) }
            fn to_i32(self) -> i32 { self as i32 }
        )*
    };
}
145
146impl Float for f32 {
147    impl_some_float_for_types!(f32);
148
149    fn to_f64(self) -> f64 {
150        self as f64
151    }
152    fn to_f32(self) -> f32 { self }
153
154    fn selu_lambda(self) -> Self {
155        1.0507f32
156    }
157
158    fn selu_alpha(self) -> Self {
159        1.67326f32
160    }
161    fn from_f64(value: f64) -> Self {
162        value as f32
163    }
164
165    fn from_usize(value: usize) -> Self {
166        value as f32
167    }
168
169    fn from_str(value: &str) -> Self {
170        value.parse::<f32>().unwrap()
171    }
172
173    fn f32_f64(a: f32, _: f64) -> Self {
174        a
175    }
176    fn if_f32_f64<T>(a: T, _: T) -> T {a}
177}
178
179impl Float for f64 {
180    impl_some_float_for_types!(f64);
181    fn to_f64(self) -> f64 {
182        self
183    }
184    fn to_f32(self) -> f32 { self as f32 }
185    fn selu_lambda(self) -> Self {
186        1.050700f64
187    }
188
189    fn selu_alpha(self) -> Self {
190        1.673263f64
191    }
192
193    fn from_f64(value: f64) -> Self {
194        value
195    }
196
197    fn from_usize(value: usize) -> Self {
198        value as f64
199    }
200    fn from_str(value: &str) -> Self {
201        value.parse::<f64>().unwrap()
202    }
203    fn f32_f64(_: f32, b: f64) -> Self {
204        b
205    }
206    fn if_f32_f64<T>(_: T, b: T) -> T { b }
207}
208
/// Structure to improve readability.
///
/// Provides zero-valued constructors used as type witnesses at call sites,
/// e.g. `MSE::new(DataType::f32())` selects the `f32` instantiation.
pub struct DataType;

impl DataType {
    pub fn i16() -> i16 { 0 }
    pub fn i32() -> i32 { 0 }
    pub fn i64() -> i64 { 0 }
    pub fn i128() -> i128 { 0 }

    pub fn f32() -> f32 { 0.0 }
    pub fn f64() -> f64 { 0.0 }
}
233
234#[cfg(test)]
235mod tests {
236    use crate::activation::{Function, ReLU};
237    use crate::linalg::Matrix;
238    use crate::nn::{Linear, Sequential};
239    use std::time::Instant;
240    use crate::loss::MSE;
241    use crate::matrix;
242    use crate::optim::SGD;
243    //use prost_build::*;
244
245    #[test]
246    fn simple_linear() {
247        let fc1: Linear<f64> = Linear::new(16, 64, true);
248        let fc2: Linear<f64> = Linear::new(64, 64, true);
249        let fc3: Linear<f64> = Linear::new(64, 4, true);
250        let act = ReLU::new();
251
252        let data = Matrix::new(vec![1.0; 16], 1, 16);
253
254        let start_time = Instant::now();
255        let mut ans = fc1.call(data);
256        ans = act.call(ans);
257        ans = fc2.call(ans);
258        ans = act.call(ans);
259        ans = fc3.call(ans);
260        let elapsed_time = start_time.elapsed();
261        println!("Time: {} micros", elapsed_time.as_micros());
262        println!("{}", ans)
263    }
264
265    #[test]
266    fn some_shit() {
267        let x_mx = matrix![
268            [3.0,6.0,7.0],
269            [2.0,1.0,8.0],
270            [1.0, 1.0, 1.0],
271            [5.0, 3.0, 3.0]
272        ];
273        let y_mx = matrix![[135.0, 260.0, 220.0, 360.0]].transpose();
274
275
276        let layers: Vec<Box<dyn Function<f64>>> = vec![
277            Box::new(Linear::new(3, 1, false))
278        ];
279        let mut nn = Sequential::new(layers);
280
281        let err = MSE::new(0.0);
282        let mut optim = SGD::new(0.001);
283
284        for _ in 0..100 {
285            let v = nn.train(
286                x_mx.clone(),
287                y_mx.clone(),
288                &mut optim,
289                &err
290            );
291            if v < 0.1 {
292                break
293            }
294            println!("{v}");
295        }
296        println!("{:?}", nn[0].get_data().unwrap());
297    }
298}