1#![expect(clippy::module_name_repetitions)]
2
3pub mod activation;
4pub mod concat;
5pub mod creator;
6pub mod functions;
7pub mod loss;
8pub mod nn;
9
10use std::{
11 cell::{Ref, RefCell, RefMut},
12 collections::{BinaryHeap, HashSet},
13 fmt::{Debug, Display},
14 ops::Deref,
15 rc::{Rc, Weak},
16 sync::Mutex,
17};
18
19use creator::{ones::ones, zeros::zeros_like};
20use functions::sum_to::sum_to;
21use lazy_static::lazy_static;
22use serde::{Deserialize, Serialize};
23use zenu_matrix::{
24 device::Device,
25 dim::{larger_shape, DimDyn, DimTrait},
26 matrix::{Matrix, Owned, Ref as MRef},
27 num::Num,
28};
29
/// Process-wide debugging switches for the autograd machinery.
pub(crate) struct ZenuAutogradState {
    /// When `true`, dropping a `VariableInner` prints a trace line to stdout.
    pub(crate) is_drop_name_show: bool,
}

impl Default for ZenuAutogradState {
    /// Builds the state from the process environment.
    ///
    /// `is_drop_name_show` is enabled only when `ZENU_DROP_NAME_SHOW` is set to
    /// exactly `"1"`; an unset, non-Unicode or any other value leaves it off.
    /// A banner line is printed to stdout when the flag is enabled.
    fn default() -> Self {
        // `is_ok_and` compares in place instead of allocating a fallback `String`.
        let is_drop_name_show =
            std::env::var("ZENU_DROP_NAME_SHOW").is_ok_and(|value| value == "1");
        if is_drop_name_show {
            println!("Drop name show");
        }
        Self { is_drop_name_show }
    }
}
44
/// Global autograd debug state, initialised through `ZenuAutogradState::default`
/// by the `once_cell` lazy wrapper.
pub(crate) static ZENU_AUTOGRAD_STATE: once_cell::sync::Lazy<ZenuAutogradState> =
    once_cell::sync::Lazy::new(ZenuAutogradState::default);
47
48pub trait Function<T: Num, D: Device> {
49 fn forward(&self);
50 fn backward(&self);
51 fn get_inputs(&self) -> Vec<Variable<T, D>>;
52 fn get_gen(&self) -> usize {
53 let inputs = self.get_inputs();
54 inputs.iter().map(Variable::get_gen).max().unwrap()
55 }
56}
57
// Global training-mode flag. It starts as `true`; `no_train` and `set_train`
// write it and `is_train` reads it, always through the mutex.
lazy_static! {
    static ref IS_TRAIN: Mutex<bool> = Mutex::new(true);
}
61
62#[expect(clippy::missing_panics_doc)]
63pub fn no_train() {
64 let mut is_train = IS_TRAIN.lock().unwrap();
65 *is_train = false;
66}
67
68#[expect(clippy::missing_panics_doc)]
69#[must_use]
70pub fn is_train() -> bool {
71 let is_train = IS_TRAIN.lock().unwrap();
72 *is_train
73}
74
75#[expect(clippy::missing_panics_doc)]
76pub fn set_train() {
77 let mut is_train = IS_TRAIN.lock().unwrap();
78 *is_train = true;
79}
80
/// Entry of the priority queue used when walking the graph backwards:
/// a shared function handle plus the generation it is ordered by.
#[derive(Clone)]
pub(crate) struct FunctionQueueItem<T: Num, D: Device> {
    /// Shared handle to the wrapped function.
    pub(crate) func: Rc<RefCell<Box<dyn Function<T, D>>>>,
    /// Generation value; the only field the comparison impls look at.
    pub(crate) gen: usize,
}
86
87impl<T: Num, D: Device> From<Rc<RefCell<Box<dyn Function<T, D>>>>> for FunctionQueueItem<T, D> {
88 fn from(func: Rc<RefCell<Box<dyn Function<T, D>>>>) -> Self {
89 Self {
90 func: func.clone(),
91 gen: func.borrow().get_gen(),
92 }
93 }
94}
95
96impl<T: Num, D: Device> PartialEq for FunctionQueueItem<T, D> {
97 fn eq(&self, other: &Self) -> bool {
98 self.gen == other.gen
99 }
100}
101
102impl<T: Num, D: Device> Eq for FunctionQueueItem<T, D> {
103 fn assert_receiver_is_total_eq(&self) {}
104}
105
106impl<T: Num, D: Device> PartialOrd for FunctionQueueItem<T, D> {
107 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
108 Some(self.gen.cmp(&other.gen))
109 }
110}
111
112impl<T: Num, D: Device> Ord for FunctionQueueItem<T, D> {
113 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
114 self.gen.cmp(&other.gen)
115 }
116}
117
118impl<T: Num, D: Device> Deref for FunctionQueueItem<T, D> {
119 type Target = Rc<RefCell<Box<dyn Function<T, D>>>>;
120
121 fn deref(&self) -> &Self::Target {
122 &self.func
123 }
124}
125
/// Shared state behind a `Variable` handle: the value plus graph bookkeeping.
#[derive(Clone)]
pub struct VariableInner<T: Num, D: Device> {
    /// The value held by the variable.
    data: Matrix<Owned<T>, DimDyn, D>,
    /// Function that produced this variable; `None` until `set_creator` runs.
    #[expect(clippy::type_complexity)]
    creator: Option<Rc<RefCell<Box<dyn Function<T, D>>>>>,
    /// Gradient for this variable, itself stored as a variable.
    grad: Option<Variable<T, D>>,
    /// Generation: `0` initially, the creator's generation plus one after
    /// `set_creator`.
    gen: usize,
    /// Optional name, shown in the drop trace.
    name: Option<String>,
    /// Flag that `get_all_trainable_variables` filters on.
    is_train: bool,
}
136
137impl<T: Num, D: Device> Drop for VariableInner<T, D> {
138 fn drop(&mut self) {
139 if ZENU_AUTOGRAD_STATE.is_drop_name_show {
140 if let Some(name) = self.name.clone() {
141 println!("Drop Variable: {name}");
142 } else {
143 println!("Drop Variable");
144 }
145 }
146 }
147}
148
149impl<T, D> Serialize for VariableInner<T, D>
150where
151 T: Num,
152 D: Device,
153{
154 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
155 where
156 S: serde::Serializer,
157 {
158 self.data.serialize(serializer)
159 }
160}
161
162impl<'de, T, D> Deserialize<'de> for VariableInner<T, D>
163where
164 T: Num + Deserialize<'de>,
165 D: Device,
166{
167 fn deserialize<Ds>(deserializer: Ds) -> Result<Self, Ds::Error>
168 where
169 Ds: serde::Deserializer<'de>,
170 {
171 let data = Matrix::<Owned<T>, DimDyn, D>::deserialize(deserializer)?;
172 Ok(VariableInner {
173 data,
174 creator: None,
175 grad: None,
176 gen: 0,
177 name: None,
178 is_train: false,
179 })
180 }
181}
182
183impl<T: Num, D: Device> VariableInner<T, D> {
184 #[must_use]
185 pub fn new(data: Matrix<Owned<T>, DimDyn, D>) -> Self {
186 VariableInner {
187 data,
188 creator: None,
189 grad: None,
190 gen: 0,
191 name: None,
192 is_train: false,
193 }
194 }
195
196 #[expect(clippy::type_complexity)]
197 fn get_creator(&self) -> Option<Rc<RefCell<Box<dyn Function<T, D>>>>> {
198 self.creator.clone()
199 }
200
201 fn set_creator(&mut self, creator: Rc<RefCell<Box<dyn Function<T, D>>>>) {
202 self.creator = Some(creator);
203 let gen = self.creator.as_ref().unwrap().borrow().get_gen();
204 self.gen = gen + 1;
205 }
206
207 fn get_gen(&self) -> usize {
208 self.gen
209 }
210
211 fn get_name(&self) -> Option<String> {
212 self.name.clone()
213 }
214
215 fn set_name(&mut self, name: String) {
216 self.name = Some(name);
217 }
218
219 #[expect(clippy::missing_panics_doc)]
220 pub fn backward(&self) {
221 let mut funcs: BinaryHeap<FunctionQueueItem<T, D>> = BinaryHeap::new();
222 let mut seen_rc = HashSet::new();
223
224 funcs.push(self.creator.clone().unwrap().into());
225
226 while let Some(FunctionQueueItem { func, .. }) = funcs.pop() {
227 func.borrow().backward();
228 func.borrow().get_inputs().iter().for_each(|input| {
229 if let Some(creator) = input.get_creator() {
230 if !seen_rc.contains(&creator.as_ptr()) {
231 funcs.push(creator.clone().into());
232 seen_rc.insert(creator.as_ptr());
233 }
234 }
235 });
236 }
237 }
238
239 fn clear_grad(&mut self) {
240 if let Some(ref mut grad) = self.grad {
241 grad.inner.borrow_mut().clear_grad();
242 }
243 self.grad = None;
244 }
245
246 fn get_is_train(&self) -> bool {
247 self.is_train
248 }
249
250 fn set_is_train(&mut self, is_train: bool) {
251 self.is_train = is_train;
252 }
253
254 fn get_all_variable(&self) -> Vec<Variable<T, D>> {
255 let mut variables = Vec::new();
256 let mut seen_rc = HashSet::new();
257 let mut funcs: BinaryHeap<FunctionQueueItem<T, D>> = BinaryHeap::new();
258
259 funcs.push(self.creator.clone().unwrap().into());
260
261 while let Some(FunctionQueueItem { func, .. }) = funcs.pop() {
262 let inputs = func.borrow().get_inputs();
263 for input in inputs {
264 if let Some(creator) = input.get_creator() {
265 if !seen_rc.contains(&creator.as_ptr()) {
266 funcs.push(creator.clone().into());
267 seen_rc.insert(creator.as_ptr());
268 }
269 }
270 }
271 let inputs = func.borrow().get_inputs();
272 for input in inputs {
273 variables.push(input);
274 }
275 }
276
277 variables.dedup_by(|a, b| a.get_data().as_ptr() == b.get_data().as_ptr());
278 variables
279 }
280
281 fn get_all_trainable_variables(&self) -> Vec<Variable<T, D>> {
282 let variables = self.get_all_variable();
283 variables
284 .into_iter()
285 .filter(Variable::get_is_train)
286 .collect()
287 }
288
289 fn to<DO: Device>(&self) -> VariableInner<T, DO> {
290 assert!(self.grad.is_none(), "grad must be None");
291 VariableInner {
292 data: self.data.new_matrix().to(),
293 creator: None,
294 grad: None,
295 gen: 0,
296 name: self.name.clone(),
297 is_train: self.is_train,
298 }
299 }
300}
301
/// Handle to a graph variable. Cloning the handle clones the `Rc`, so all
/// clones share the same `VariableInner`.
#[derive(Clone)]
pub struct Variable<T: Num, D: Device> {
    /// Shared, interior-mutable state of the variable.
    inner: Rc<RefCell<VariableInner<T, D>>>,
}
306
307impl<T: Num, D: Device> From<T> for Variable<T, D> {
308 fn from(data: T) -> Self {
309 let data = Matrix::from_vec(vec![data], DimDyn::new(&[]));
310 Variable::new(data)
311 }
312}
313
314impl<T: Num, D: Device> From<Matrix<Owned<T>, DimDyn, D>> for Variable<T, D> {
315 fn from(data: Matrix<Owned<T>, DimDyn, D>) -> Self {
316 Variable::new(data)
317 }
318}
319
320impl<T, D> Serialize for Variable<T, D>
321where
322 T: Num,
323 D: Device,
324{
325 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
326 where
327 S: serde::Serializer,
328 {
329 self.inner.borrow().clone().serialize(serializer)
330 }
331}
332
333impl<'de, T, D> Deserialize<'de> for Variable<T, D>
334where
335 T: Num + Deserialize<'de>,
336 D: Device,
337{
338 fn deserialize<Ds>(deserializer: Ds) -> Result<Self, Ds::Error>
339 where
340 Ds: serde::Deserializer<'de>,
341 {
342 let inner = VariableInner::<T, D>::deserialize(deserializer)?;
343 Ok(Variable {
344 inner: Rc::new(RefCell::new(inner)),
345 })
346 }
347}
348
349impl<T: Num, D: Device> Variable<T, D> {
350 #[must_use]
351 pub fn new(data: Matrix<Owned<T>, DimDyn, D>) -> Self {
352 Variable {
353 inner: Rc::new(RefCell::new(VariableInner::new(data))),
354 }
355 }
356
357 #[expect(clippy::missing_panics_doc)]
358 pub fn swap_inner(&self, inner: Matrix<Owned<T>, DimDyn, D>) {
359 assert_eq!(
360 self.get_shape(),
361 inner.shape(),
362 "`Variable::swap_inner`, shape must be same"
363 );
364 self.inner.borrow_mut().data = inner;
365 }
366
367 #[must_use]
368 pub fn get_data<'a>(&'a self) -> Ref<'a, Matrix<Owned<T>, DimDyn, D>> {
369 let reference: Ref<'a, VariableInner<T, D>> = self.inner.borrow();
370 Ref::map(reference, |r| &r.data)
371 }
372
373 #[must_use]
374 pub fn get_as_ref<'a>(&self) -> Matrix<MRef<&'a T>, DimDyn, D> {
375 let data = self.get_data();
376 data.to_ref()
377 }
378
379 #[must_use]
380 pub fn get_as_mut<'a>(&self) -> Matrix<MRef<&'a mut T>, DimDyn, D> {
381 let mut data = self.get_data_mut();
382 data.to_ref_mut()
383 }
384
385 #[must_use]
386 pub fn get_data_mut<'a>(&'a self) -> RefMut<'a, Matrix<Owned<T>, DimDyn, D>> {
387 let reference: RefMut<'a, VariableInner<T, D>> = self.inner.borrow_mut();
388 RefMut::map(reference, |r| &mut r.data)
389 }
390
391 pub fn set_creator(&self, creator: Rc<RefCell<Box<dyn Function<T, D>>>>) {
392 self.inner.borrow_mut().set_creator(creator);
393 }
394
395 #[expect(clippy::type_complexity)]
396 #[must_use]
397 pub fn get_creator(&self) -> Option<Rc<RefCell<Box<dyn Function<T, D>>>>> {
398 self.inner.borrow().get_creator().clone()
399 }
400
401 #[must_use]
402 pub fn get_grad<'a>(&'a self) -> Option<Variable<T, D>> {
403 let reference: Ref<'a, VariableInner<T, D>> = self.inner.borrow();
404 let ref_option = Ref::map(reference, |r| &r.grad);
405 ref_option.clone()
406 }
407
408 fn get_grad_mut<'a>(&'a self) -> RefMut<'a, Option<Variable<T, D>>> {
409 let reference: RefMut<'a, VariableInner<T, D>> = self.inner.borrow_mut();
410 RefMut::map(reference, |r| &mut r.grad)
411 }
412
413 pub fn backward(&self) {
414 if self.inner.borrow().grad.is_none() {
415 let ones = ones(self.get_data().shape());
416 ones.set_name(&format!("{:?}_grad", self.get_name().unwrap_or_default()));
417 self.inner.borrow_mut().grad = Some(ones);
418 }
419 self.inner.borrow().backward();
420 }
421
422 #[must_use]
423 pub fn downgrade(self) -> VariableWeak<T, D> {
424 VariableWeak {
425 inner: Rc::downgrade(&self.inner),
426 }
427 }
428
429 #[must_use]
430 pub fn get_gen(&self) -> usize {
431 self.inner.borrow().get_gen()
432 }
433
434 pub fn clear_grad(&self) {
435 self.inner.borrow_mut().clear_grad();
436 let all_val = self.inner.borrow().get_all_variable();
437 for val in all_val {
438 val.inner.borrow_mut().clear_grad();
439 }
440 }
441
442 pub fn set_name(&self, name: &str) {
443 self.inner.borrow_mut().set_name(name.to_string());
444 }
445
446 #[must_use]
447 pub fn get_name(&self) -> Option<String> {
448 self.inner.borrow().get_name().clone()
449 }
450
451 #[expect(clippy::missing_panics_doc)]
452 pub fn with_grad_data<F>(&self, mut f: F)
453 where
454 F: FnMut(&Matrix<Owned<T>, DimDyn, D>),
455 {
456 let inner = self.inner.borrow();
457 if let Some(grad_variable) = &inner.grad {
458 let grad_inner = grad_variable.inner.borrow();
459 f(&grad_inner.data);
460 } else {
461 panic!("grad is None");
462 }
463 }
464
465 #[expect(clippy::missing_panics_doc)]
466 pub fn set_grad(&self, mut grad: Variable<T, D>) {
467 let self_shape = self.get_shape();
468 let grad_shape = grad.get_shape();
469 let larger_shape_ = larger_shape(self_shape, grad_shape);
470 if self_shape.slice() == grad_shape.slice() {
471 } else if self_shape.slice() == larger_shape_.slice() {
472 grad = zeros_like(self) + grad;
473 } else if grad_shape.slice() == larger_shape_.slice() {
474 grad = sum_to(grad, self_shape);
475 } else {
476 panic!("shape of grad and data must be same");
477 }
478 let name = self.get_name().clone().unwrap_or_default();
479 let mut grad_mut = self.get_grad_mut();
480 if let Some(ref mut grad_variable) = *grad_mut {
481 *grad_variable = grad + grad_variable.clone();
482 } else {
483 grad.set_name(&format!("{name}_grad"));
484 *grad_mut = Some(grad);
485 }
486 }
487
488 #[must_use]
489 pub fn get_is_train(&self) -> bool {
490 self.inner.borrow().get_is_train()
491 }
492
493 pub fn set_is_train(&self, is_train: bool) {
494 self.inner.borrow_mut().set_is_train(is_train);
495 }
496
497 #[must_use]
498 pub fn get_all_trainable_variables(&self) -> Vec<Variable<T, D>> {
499 self.inner.borrow().get_all_trainable_variables()
500 }
501
502 #[must_use]
503 pub fn get_shape(&self) -> DimDyn {
504 self.get_data().shape()
505 }
506
507 #[must_use]
508 pub fn to<DO: Device>(&self) -> Variable<T, DO> {
509 if std::any::TypeId::of::<D>() == std::any::TypeId::of::<DO>() {
510 return unsafe { std::mem::transmute::<Variable<T, D>, Variable<T, DO>>(self.clone()) };
511 }
512 Variable {
513 inner: Rc::new(RefCell::new(self.inner.borrow().to())),
514 }
515 }
516}
517
/// Weak counterpart of `Variable`: refers to the same shared state without
/// keeping it alive.
#[derive(Debug, Clone)]
pub struct VariableWeak<T: Num, D: Device> {
    /// Weak reference to the shared variable state.
    inner: Weak<RefCell<VariableInner<T, D>>>,
}
522
523impl<T: Num, D: Device> VariableWeak<T, D> {
524 #[must_use]
525 pub fn upgrade(&self) -> Option<Variable<T, D>> {
526 self.inner.upgrade().map(|inner| Variable { inner })
527 }
528}
529
530impl<T: Num, D: Device> Debug for Variable<T, D> {
531 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
532 let inner = self.get_data();
533 write!(f, "Variable {{ data: \n{inner:?} }}")?;
534 Ok(())
535 }
536}
537
538impl<T: Num, D: Device> Display for Variable<T, D> {
539 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
540 let inner = self.get_data();
541 write!(f, "Variable {{ data: \n{inner:?} }}")?;
542 Ok(())
543 }
544}
545
546impl<T: Num, D: Device> Debug for VariableInner<T, D> {
547 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
548 write!(f, "VariableInner {{ data: \n{:?} }}", self.data)?;
549 Ok(())
550 }
551}
552
553impl<T: Num, D: Device> Display for VariableInner<T, D> {
554 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
555 write!(f, "VariableInner {{ data: \n{:?} }}", self.data)?;
556 Ok(())
557 }
558}