1use std::cmp::PartialOrd;
36use std::fmt::{Debug, Display};
37use std::ops::{Add, AddAssign, Div, Mul, Neg, Sub, SubAssign};
38
39pub mod activation;
40pub mod linalg;
41pub mod loss;
42pub mod nn;
43pub mod optim;
44pub mod utils;
/// Generic numeric element type required by the generic code in this crate.
///
/// Bundles closed arithmetic (`+ - * /` and unary `-` all returning `Self`),
/// in-place `+=` / `-=`, partial ordering, cheap copying, formatting, and
/// thread-safety bounds. `From<u8>` allows small integer constants to be built
/// generically, e.g. `T::from(2u8)`.
///
/// `Clone` is implied by `Copy`, so it is not listed separately.
pub trait Num:
    Add<Output = Self>
    + Sub<Output = Self>
    + Mul<Output = Self>
    + Div<Output = Self>
    + AddAssign
    + SubAssign
    + Neg<Output = Self>
    + PartialOrd
    + Copy
    + From<u8>
    + Default
    + Display
    + Debug
    + Sync
    + Send
    + 'static
{
}

/// Implements the marker trait [`Num`] for each listed primitive type.
macro_rules! impl_num_for_types {
    ($($type:ty),* $(,)?) => {
        $(impl Num for $type {})*
    };
}

// Unsigned integers are excluded because they do not implement `Neg`.
// `i8` is excluded because it does not implement `From<u8>`.
impl_num_for_types!(i16, i32, i64, i128, f32, f64);
83
84pub trait Float: Num {
90 fn one() -> Self;
91 fn sign(self) -> Self;
93 fn sqrt(self) -> Self;
94 fn exp(self) -> Self;
95 fn ln(self) -> Self;
96 fn powf(self, n: Self) -> Self;
97 fn abs(self) -> Self;
98 fn neg(self) -> Self;
99
100 fn to_f64(self) -> f64;
101 fn to_f32(self) -> f32;
102
103 fn to_i32(self) -> i32;
104
105 fn selu_lambda(self) -> Self;
106
107 fn selu_alpha(self) -> Self;
108
109 fn from_f64(value: f64) -> Self;
110 fn from_usize(value: usize) -> Self;
111
112 fn from_str(value: &str) -> Self;
113 fn cos(self) -> Self;
114 fn pi() -> Self;
115 fn f32_f64(a: f32, b: f64) -> Self;
116 fn if_f32_f64<T>(a: T, b: T) -> T;
117}
118
// Expands to the `Float` methods whose bodies are identical for `f32` and `f64`.
//
// Invoke it *inside* an `impl Float for <type>` block with exactly one type,
// e.g. `impl_some_float_for_types!(f32);`. Passing several types would emit
// duplicate method definitions into the same impl.
macro_rules! impl_some_float_for_types {
    ($($type:ty),*) => {
        $(
            /// Multiplicative identity.
            fn one() -> Self {
                1.0
            }

            /// π at this type's precision. The `f64` constant is cast down,
            /// which yields exactly `std::f32::consts::PI` for `f32`.
            fn pi() -> Self {
                std::f64::consts::PI as Self
            }

            /// Returns `1`, `0` or `-1` according to the sign of `self`.
            /// Both `+0.0` and `-0.0` map to `0`. NaN is propagated instead of
            /// being reported as negative.
            fn sign(self) -> Self {
                match self.partial_cmp(&Self::default()) {
                    Some(std::cmp::Ordering::Greater) => 1.0,
                    Some(std::cmp::Ordering::Equal) => 0.0,
                    Some(std::cmp::Ordering::Less) => -1.0,
                    // Only NaN is unordered with respect to zero.
                    None => self,
                }
            }

            // The calls below resolve to the *inherent* methods of the
            // primitive type. Inherent methods take precedence over trait
            // methods, so these do not recurse into the trait methods being
            // defined here.
            fn sqrt(self) -> Self {
                self.sqrt()
            }
            fn cos(self) -> Self {
                self.cos()
            }
            fn exp(self) -> Self {
                self.exp()
            }
            fn ln(self) -> Self {
                self.ln()
            }
            fn abs(self) -> Self {
                self.abs()
            }
            fn powf(self, n: $type) -> Self {
                self.powf(n)
            }

            /// Negation via `std::ops::Neg`. The call is explicitly qualified
            /// so it cannot resolve back to this method.
            fn neg(self) -> Self {
                Neg::neg(self)
            }

            /// Saturating `as` conversion: the fraction is truncated, values
            /// out of range clamp to `i32::MIN`/`i32::MAX`, and NaN becomes `0`.
            fn to_i32(self) -> i32 {
                self as i32
            }
        )*
    };
}
145
146impl Float for f32 {
147 impl_some_float_for_types!(f32);
148
149 fn to_f64(self) -> f64 {
150 self as f64
151 }
152 fn to_f32(self) -> f32 { self }
153
154 fn selu_lambda(self) -> Self {
155 1.0507f32
156 }
157
158 fn selu_alpha(self) -> Self {
159 1.67326f32
160 }
161 fn from_f64(value: f64) -> Self {
162 value as f32
163 }
164
165 fn from_usize(value: usize) -> Self {
166 value as f32
167 }
168
169 fn from_str(value: &str) -> Self {
170 value.parse::<f32>().unwrap()
171 }
172
173 fn f32_f64(a: f32, _: f64) -> Self {
174 a
175 }
176 fn if_f32_f64<T>(a: T, _: T) -> T {a}
177}
178
179impl Float for f64 {
180 impl_some_float_for_types!(f64);
181 fn to_f64(self) -> f64 {
182 self
183 }
184 fn to_f32(self) -> f32 { self as f32 }
185 fn selu_lambda(self) -> Self {
186 1.050700f64
187 }
188
189 fn selu_alpha(self) -> Self {
190 1.673263f64
191 }
192
193 fn from_f64(value: f64) -> Self {
194 value
195 }
196
197 fn from_usize(value: usize) -> Self {
198 value as f64
199 }
200 fn from_str(value: &str) -> Self {
201 value.parse::<f64>().unwrap()
202 }
203 fn f32_f64(_: f32, b: f64) -> Self {
204 b
205 }
206 fn if_f32_f64<T>(_: T, b: T) -> T { b }
207}
208
/// Zero-sized namespace of helpers; each returns the zero value of one numeric
/// primitive.
///
/// NOTE(review): presumably used to select a generic element type by value
/// (e.g. passing `DataType::f32()`). Confirm against callers.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct DataType;

impl DataType {
    /// Returns `0i16`.
    pub fn i16() -> i16 {
        0
    }

    /// Returns `0i32`.
    pub fn i32() -> i32 {
        0
    }

    /// Returns `0i64`.
    pub fn i64() -> i64 {
        0
    }

    /// Returns `0i128`.
    pub fn i128() -> i128 {
        0
    }

    /// Returns `0.0f32`.
    pub fn f32() -> f32 {
        0.0
    }

    /// Returns `0.0f64`.
    pub fn f64() -> f64 {
        0.0
    }
}
233
#[cfg(test)]
mod tests {
    use crate::activation::{Function, ReLU};
    use crate::linalg::Matrix;
    use crate::loss::MSE;
    use crate::matrix;
    use crate::nn::{Linear, Sequential};
    use crate::optim::SGD;
    use std::time::Instant;

    /// Smoke test: runs three chained `Linear` layers with ReLU between them
    /// on a single 1x16 row of ones, printing the elapsed time and the output.
    #[test]
    fn simple_linear() {
        let fc1: Linear<f64> = Linear::new(16, 64, true);
        let fc2: Linear<f64> = Linear::new(64, 64, true);
        let fc3: Linear<f64> = Linear::new(64, 4, true);
        let act = ReLU::new();

        let data = Matrix::new(vec![1.0; 16], 1, 16);

        let start_time = Instant::now();
        let mut ans = fc1.call(data);
        ans = act.call(ans);
        ans = fc2.call(ans);
        ans = act.call(ans);
        ans = fc3.call(ans);
        let elapsed_time = start_time.elapsed();
        println!("Time: {} micros", elapsed_time.as_micros());
        println!("{}", ans);
    }

    /// Trains a single `Linear::new(3, 1, false)` layer on four samples with
    /// SGD and MSE for up to 100 steps. It stops early once the returned loss
    /// drops below 0.1.
    #[test]
    fn linear_regression_with_sgd() {
        let x_mx = matrix![
            [3.0, 6.0, 7.0],
            [2.0, 1.0, 8.0],
            [1.0, 1.0, 1.0],
            [5.0, 3.0, 3.0]
        ];
        let y_mx = matrix![[135.0, 260.0, 220.0, 360.0]].transpose();

        let layers: Vec<Box<dyn Function<f64>>> = vec![Box::new(Linear::new(3, 1, false))];
        let mut nn = Sequential::new(layers);

        let err = MSE::new(0.0);
        let mut optim = SGD::new(0.001);

        for _ in 0..100 {
            let loss = nn.train(x_mx.clone(), y_mx.clone(), &mut optim, &err);
            // `NaN < 0.1` is false, so a diverged run would otherwise loop to
            // the end silently. Fail loudly instead.
            assert!(loss.is_finite(), "training loss is not finite: {loss}");
            if loss < 0.1 {
                break;
            }
            println!("{loss}");
        }
        println!("{:?}", nn[0].get_data().unwrap());
    }
}