// zenu_autograd/lib.rs

1#![expect(clippy::module_name_repetitions)]
2
3pub mod activation;
4pub mod concat;
5pub mod creator;
6pub mod functions;
7pub mod loss;
8pub mod nn;
9
10use std::{
11    cell::{Ref, RefCell, RefMut},
12    collections::{BinaryHeap, HashSet},
13    fmt::{Debug, Display},
14    ops::Deref,
15    rc::{Rc, Weak},
16    sync::Mutex,
17};
18
19use creator::{ones::ones, zeros::zeros_like};
20use functions::sum_to::sum_to;
21use lazy_static::lazy_static;
22use serde::{Deserialize, Serialize};
23use zenu_matrix::{
24    device::Device,
25    dim::{larger_shape, DimDyn, DimTrait},
26    matrix::{Matrix, Owned, Ref as MRef},
27    num::Num,
28};
29
/// Crate-wide autograd settings derived from the process environment.
pub(crate) struct ZenuAutogradState {
    /// When `true`, dropping a `VariableInner` prints a "Drop Variable" line to stdout.
    pub(crate) is_drop_name_show: bool,
}

impl Default for ZenuAutogradState {
    /// Builds the settings from the `ZENU_DROP_NAME_SHOW` environment variable.
    ///
    /// Drop-name logging is enabled only when the variable is exactly `"1"`; an unset or
    /// non-UTF-8 value (or any other text) disables it. Prints `Drop name show` when enabled.
    fn default() -> Self {
        // Compare the borrowed value in place instead of allocating a fallback `String`.
        let is_drop_name_show = matches!(
            std::env::var("ZENU_DROP_NAME_SHOW").as_deref(),
            Ok("1")
        );
        if is_drop_name_show {
            println!("Drop name show");
        }
        Self { is_drop_name_show }
    }
}
44
45pub(crate) static ZENU_AUTOGRAD_STATE: once_cell::sync::Lazy<ZenuAutogradState> =
46    once_cell::sync::Lazy::new(ZenuAutogradState::default);
47
48pub trait Function<T: Num, D: Device> {
49    fn forward(&self);
50    fn backward(&self);
51    fn get_inputs(&self) -> Vec<Variable<T, D>>;
52    fn get_gen(&self) -> usize {
53        let inputs = self.get_inputs();
54        inputs.iter().map(Variable::get_gen).max().unwrap()
55    }
56}
57
58lazy_static! {
59    static ref IS_TRAIN: Mutex<bool> = Mutex::new(true);
60}
61
62#[expect(clippy::missing_panics_doc)]
63pub fn no_train() {
64    let mut is_train = IS_TRAIN.lock().unwrap();
65    *is_train = false;
66}
67
68#[expect(clippy::missing_panics_doc)]
69#[must_use]
70pub fn is_train() -> bool {
71    let is_train = IS_TRAIN.lock().unwrap();
72    *is_train
73}
74
75#[expect(clippy::missing_panics_doc)]
76pub fn set_train() {
77    let mut is_train = IS_TRAIN.lock().unwrap();
78    *is_train = true;
79}
80
81#[derive(Clone)]
82pub(crate) struct FunctionQueueItem<T: Num, D: Device> {
83    pub(crate) func: Rc<RefCell<Box<dyn Function<T, D>>>>,
84    pub(crate) gen: usize,
85}
86
87impl<T: Num, D: Device> From<Rc<RefCell<Box<dyn Function<T, D>>>>> for FunctionQueueItem<T, D> {
88    fn from(func: Rc<RefCell<Box<dyn Function<T, D>>>>) -> Self {
89        Self {
90            func: func.clone(),
91            gen: func.borrow().get_gen(),
92        }
93    }
94}
95
96impl<T: Num, D: Device> PartialEq for FunctionQueueItem<T, D> {
97    fn eq(&self, other: &Self) -> bool {
98        self.gen == other.gen
99    }
100}
101
102impl<T: Num, D: Device> Eq for FunctionQueueItem<T, D> {
103    fn assert_receiver_is_total_eq(&self) {}
104}
105
106impl<T: Num, D: Device> PartialOrd for FunctionQueueItem<T, D> {
107    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
108        Some(self.gen.cmp(&other.gen))
109    }
110}
111
112impl<T: Num, D: Device> Ord for FunctionQueueItem<T, D> {
113    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
114        self.gen.cmp(&other.gen)
115    }
116}
117
118impl<T: Num, D: Device> Deref for FunctionQueueItem<T, D> {
119    type Target = Rc<RefCell<Box<dyn Function<T, D>>>>;
120
121    fn deref(&self) -> &Self::Target {
122        &self.func
123    }
124}
125
126#[derive(Clone)]
127pub struct VariableInner<T: Num, D: Device> {
128    data: Matrix<Owned<T>, DimDyn, D>,
129    #[expect(clippy::type_complexity)]
130    creator: Option<Rc<RefCell<Box<dyn Function<T, D>>>>>,
131    grad: Option<Variable<T, D>>,
132    gen: usize,
133    name: Option<String>,
134    is_train: bool,
135}
136
137impl<T: Num, D: Device> Drop for VariableInner<T, D> {
138    fn drop(&mut self) {
139        if ZENU_AUTOGRAD_STATE.is_drop_name_show {
140            if let Some(name) = self.name.clone() {
141                println!("Drop Variable: {name}");
142            } else {
143                println!("Drop Variable");
144            }
145        }
146    }
147}
148
149impl<T, D> Serialize for VariableInner<T, D>
150where
151    T: Num,
152    D: Device,
153{
154    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
155    where
156        S: serde::Serializer,
157    {
158        self.data.serialize(serializer)
159    }
160}
161
162impl<'de, T, D> Deserialize<'de> for VariableInner<T, D>
163where
164    T: Num + Deserialize<'de>,
165    D: Device,
166{
167    fn deserialize<Ds>(deserializer: Ds) -> Result<Self, Ds::Error>
168    where
169        Ds: serde::Deserializer<'de>,
170    {
171        let data = Matrix::<Owned<T>, DimDyn, D>::deserialize(deserializer)?;
172        Ok(VariableInner {
173            data,
174            creator: None,
175            grad: None,
176            gen: 0,
177            name: None,
178            is_train: false,
179        })
180    }
181}
182
183impl<T: Num, D: Device> VariableInner<T, D> {
184    #[must_use]
185    pub fn new(data: Matrix<Owned<T>, DimDyn, D>) -> Self {
186        VariableInner {
187            data,
188            creator: None,
189            grad: None,
190            gen: 0,
191            name: None,
192            is_train: false,
193        }
194    }
195
196    #[expect(clippy::type_complexity)]
197    fn get_creator(&self) -> Option<Rc<RefCell<Box<dyn Function<T, D>>>>> {
198        self.creator.clone()
199    }
200
201    fn set_creator(&mut self, creator: Rc<RefCell<Box<dyn Function<T, D>>>>) {
202        self.creator = Some(creator);
203        let gen = self.creator.as_ref().unwrap().borrow().get_gen();
204        self.gen = gen + 1;
205    }
206
207    fn get_gen(&self) -> usize {
208        self.gen
209    }
210
211    fn get_name(&self) -> Option<String> {
212        self.name.clone()
213    }
214
215    fn set_name(&mut self, name: String) {
216        self.name = Some(name);
217    }
218
219    #[expect(clippy::missing_panics_doc)]
220    pub fn backward(&self) {
221        let mut funcs: BinaryHeap<FunctionQueueItem<T, D>> = BinaryHeap::new();
222        let mut seen_rc = HashSet::new();
223
224        funcs.push(self.creator.clone().unwrap().into());
225
226        while let Some(FunctionQueueItem { func, .. }) = funcs.pop() {
227            func.borrow().backward();
228            func.borrow().get_inputs().iter().for_each(|input| {
229                if let Some(creator) = input.get_creator() {
230                    if !seen_rc.contains(&creator.as_ptr()) {
231                        funcs.push(creator.clone().into());
232                        seen_rc.insert(creator.as_ptr());
233                    }
234                }
235            });
236        }
237    }
238
239    fn clear_grad(&mut self) {
240        if let Some(ref mut grad) = self.grad {
241            grad.inner.borrow_mut().clear_grad();
242        }
243        self.grad = None;
244    }
245
246    fn get_is_train(&self) -> bool {
247        self.is_train
248    }
249
250    fn set_is_train(&mut self, is_train: bool) {
251        self.is_train = is_train;
252    }
253
254    fn get_all_variable(&self) -> Vec<Variable<T, D>> {
255        let mut variables = Vec::new();
256        let mut seen_rc = HashSet::new();
257        let mut funcs: BinaryHeap<FunctionQueueItem<T, D>> = BinaryHeap::new();
258
259        funcs.push(self.creator.clone().unwrap().into());
260
261        while let Some(FunctionQueueItem { func, .. }) = funcs.pop() {
262            let inputs = func.borrow().get_inputs();
263            for input in inputs {
264                if let Some(creator) = input.get_creator() {
265                    if !seen_rc.contains(&creator.as_ptr()) {
266                        funcs.push(creator.clone().into());
267                        seen_rc.insert(creator.as_ptr());
268                    }
269                }
270            }
271            let inputs = func.borrow().get_inputs();
272            for input in inputs {
273                variables.push(input);
274            }
275        }
276
277        variables.dedup_by(|a, b| a.get_data().as_ptr() == b.get_data().as_ptr());
278        variables
279    }
280
281    fn get_all_trainable_variables(&self) -> Vec<Variable<T, D>> {
282        let variables = self.get_all_variable();
283        variables
284            .into_iter()
285            .filter(Variable::get_is_train)
286            .collect()
287    }
288
289    fn to<DO: Device>(&self) -> VariableInner<T, DO> {
290        assert!(self.grad.is_none(), "grad must be None");
291        VariableInner {
292            data: self.data.new_matrix().to(),
293            creator: None,
294            grad: None,
295            gen: 0,
296            name: self.name.clone(),
297            is_train: self.is_train,
298        }
299    }
300}
301
302#[derive(Clone)]
303pub struct Variable<T: Num, D: Device> {
304    inner: Rc<RefCell<VariableInner<T, D>>>,
305}
306
307impl<T: Num, D: Device> From<T> for Variable<T, D> {
308    fn from(data: T) -> Self {
309        let data = Matrix::from_vec(vec![data], DimDyn::new(&[]));
310        Variable::new(data)
311    }
312}
313
314impl<T: Num, D: Device> From<Matrix<Owned<T>, DimDyn, D>> for Variable<T, D> {
315    fn from(data: Matrix<Owned<T>, DimDyn, D>) -> Self {
316        Variable::new(data)
317    }
318}
319
320impl<T, D> Serialize for Variable<T, D>
321where
322    T: Num,
323    D: Device,
324{
325    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
326    where
327        S: serde::Serializer,
328    {
329        self.inner.borrow().clone().serialize(serializer)
330    }
331}
332
333impl<'de, T, D> Deserialize<'de> for Variable<T, D>
334where
335    T: Num + Deserialize<'de>,
336    D: Device,
337{
338    fn deserialize<Ds>(deserializer: Ds) -> Result<Self, Ds::Error>
339    where
340        Ds: serde::Deserializer<'de>,
341    {
342        let inner = VariableInner::<T, D>::deserialize(deserializer)?;
343        Ok(Variable {
344            inner: Rc::new(RefCell::new(inner)),
345        })
346    }
347}
348
349impl<T: Num, D: Device> Variable<T, D> {
350    #[must_use]
351    pub fn new(data: Matrix<Owned<T>, DimDyn, D>) -> Self {
352        Variable {
353            inner: Rc::new(RefCell::new(VariableInner::new(data))),
354        }
355    }
356
357    #[expect(clippy::missing_panics_doc)]
358    pub fn swap_inner(&self, inner: Matrix<Owned<T>, DimDyn, D>) {
359        assert_eq!(
360            self.get_shape(),
361            inner.shape(),
362            "`Variable::swap_inner`, shape must be same"
363        );
364        self.inner.borrow_mut().data = inner;
365    }
366
367    #[must_use]
368    pub fn get_data<'a>(&'a self) -> Ref<'a, Matrix<Owned<T>, DimDyn, D>> {
369        let reference: Ref<'a, VariableInner<T, D>> = self.inner.borrow();
370        Ref::map(reference, |r| &r.data)
371    }
372
373    #[must_use]
374    pub fn get_as_ref<'a>(&self) -> Matrix<MRef<&'a T>, DimDyn, D> {
375        let data = self.get_data();
376        data.to_ref()
377    }
378
379    #[must_use]
380    pub fn get_as_mut<'a>(&self) -> Matrix<MRef<&'a mut T>, DimDyn, D> {
381        let mut data = self.get_data_mut();
382        data.to_ref_mut()
383    }
384
385    #[must_use]
386    pub fn get_data_mut<'a>(&'a self) -> RefMut<'a, Matrix<Owned<T>, DimDyn, D>> {
387        let reference: RefMut<'a, VariableInner<T, D>> = self.inner.borrow_mut();
388        RefMut::map(reference, |r| &mut r.data)
389    }
390
391    pub fn set_creator(&self, creator: Rc<RefCell<Box<dyn Function<T, D>>>>) {
392        self.inner.borrow_mut().set_creator(creator);
393    }
394
395    #[expect(clippy::type_complexity)]
396    #[must_use]
397    pub fn get_creator(&self) -> Option<Rc<RefCell<Box<dyn Function<T, D>>>>> {
398        self.inner.borrow().get_creator().clone()
399    }
400
401    #[must_use]
402    pub fn get_grad<'a>(&'a self) -> Option<Variable<T, D>> {
403        let reference: Ref<'a, VariableInner<T, D>> = self.inner.borrow();
404        let ref_option = Ref::map(reference, |r| &r.grad);
405        ref_option.clone()
406    }
407
408    fn get_grad_mut<'a>(&'a self) -> RefMut<'a, Option<Variable<T, D>>> {
409        let reference: RefMut<'a, VariableInner<T, D>> = self.inner.borrow_mut();
410        RefMut::map(reference, |r| &mut r.grad)
411    }
412
413    pub fn backward(&self) {
414        if self.inner.borrow().grad.is_none() {
415            let ones = ones(self.get_data().shape());
416            ones.set_name(&format!("{:?}_grad", self.get_name().unwrap_or_default()));
417            self.inner.borrow_mut().grad = Some(ones);
418        }
419        self.inner.borrow().backward();
420    }
421
422    #[must_use]
423    pub fn downgrade(self) -> VariableWeak<T, D> {
424        VariableWeak {
425            inner: Rc::downgrade(&self.inner),
426        }
427    }
428
429    #[must_use]
430    pub fn get_gen(&self) -> usize {
431        self.inner.borrow().get_gen()
432    }
433
434    pub fn clear_grad(&self) {
435        self.inner.borrow_mut().clear_grad();
436        let all_val = self.inner.borrow().get_all_variable();
437        for val in all_val {
438            val.inner.borrow_mut().clear_grad();
439        }
440    }
441
442    pub fn set_name(&self, name: &str) {
443        self.inner.borrow_mut().set_name(name.to_string());
444    }
445
446    #[must_use]
447    pub fn get_name(&self) -> Option<String> {
448        self.inner.borrow().get_name().clone()
449    }
450
451    #[expect(clippy::missing_panics_doc)]
452    pub fn with_grad_data<F>(&self, mut f: F)
453    where
454        F: FnMut(&Matrix<Owned<T>, DimDyn, D>),
455    {
456        let inner = self.inner.borrow();
457        if let Some(grad_variable) = &inner.grad {
458            let grad_inner = grad_variable.inner.borrow();
459            f(&grad_inner.data);
460        } else {
461            panic!("grad is None");
462        }
463    }
464
465    #[expect(clippy::missing_panics_doc)]
466    pub fn set_grad(&self, mut grad: Variable<T, D>) {
467        let self_shape = self.get_shape();
468        let grad_shape = grad.get_shape();
469        let larger_shape_ = larger_shape(self_shape, grad_shape);
470        if self_shape.slice() == grad_shape.slice() {
471        } else if self_shape.slice() == larger_shape_.slice() {
472            grad = zeros_like(self) + grad;
473        } else if grad_shape.slice() == larger_shape_.slice() {
474            grad = sum_to(grad, self_shape);
475        } else {
476            panic!("shape of grad and data must be same");
477        }
478        let name = self.get_name().clone().unwrap_or_default();
479        let mut grad_mut = self.get_grad_mut();
480        if let Some(ref mut grad_variable) = *grad_mut {
481            *grad_variable = grad + grad_variable.clone();
482        } else {
483            grad.set_name(&format!("{name}_grad"));
484            *grad_mut = Some(grad);
485        }
486    }
487
488    #[must_use]
489    pub fn get_is_train(&self) -> bool {
490        self.inner.borrow().get_is_train()
491    }
492
493    pub fn set_is_train(&self, is_train: bool) {
494        self.inner.borrow_mut().set_is_train(is_train);
495    }
496
497    #[must_use]
498    pub fn get_all_trainable_variables(&self) -> Vec<Variable<T, D>> {
499        self.inner.borrow().get_all_trainable_variables()
500    }
501
502    #[must_use]
503    pub fn get_shape(&self) -> DimDyn {
504        self.get_data().shape()
505    }
506
507    #[must_use]
508    pub fn to<DO: Device>(&self) -> Variable<T, DO> {
509        if std::any::TypeId::of::<D>() == std::any::TypeId::of::<DO>() {
510            return unsafe { std::mem::transmute::<Variable<T, D>, Variable<T, DO>>(self.clone()) };
511        }
512        Variable {
513            inner: Rc::new(RefCell::new(self.inner.borrow().to())),
514        }
515    }
516}
517
518#[derive(Debug, Clone)]
519pub struct VariableWeak<T: Num, D: Device> {
520    inner: Weak<RefCell<VariableInner<T, D>>>,
521}
522
523impl<T: Num, D: Device> VariableWeak<T, D> {
524    #[must_use]
525    pub fn upgrade(&self) -> Option<Variable<T, D>> {
526        self.inner.upgrade().map(|inner| Variable { inner })
527    }
528}
529
530impl<T: Num, D: Device> Debug for Variable<T, D> {
531    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
532        let inner = self.get_data();
533        write!(f, "Variable {{ data: \n{inner:?} }}")?;
534        Ok(())
535    }
536}
537
538impl<T: Num, D: Device> Display for Variable<T, D> {
539    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
540        let inner = self.get_data();
541        write!(f, "Variable {{ data: \n{inner:?} }}")?;
542        Ok(())
543    }
544}
545
546impl<T: Num, D: Device> Debug for VariableInner<T, D> {
547    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
548        write!(f, "VariableInner {{ data: \n{:?} }}", self.data)?;
549        Ok(())
550    }
551}
552
553impl<T: Num, D: Device> Display for VariableInner<T, D> {
554    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
555        write!(f, "VariableInner {{ data: \n{:?} }}", self.data)?;
556        Ok(())
557    }
558}