Skip to main content

scivex_nn/
variable.rs

1use std::cell::RefCell;
2use std::collections::HashSet;
3use std::fmt;
4use std::rc::Rc;
5
6use scivex_core::{Float, Tensor};
7
8/// Closure that computes parent gradients from the output gradient.
9///
10/// Given the gradient of the output, returns a `Vec` of gradients for each parent
11/// in the same order as `parents`.
12type GradFn<T> = Box<dyn Fn(&Tensor<T>) -> Vec<Tensor<T>>>;
13
14/// Internal node of the autograd computation graph.
15struct Node<T: Float> {
16    data: Tensor<T>,
17    grad: Option<Tensor<T>>,
18    requires_grad: bool,
19    grad_fn: Option<GradFn<T>>,
20    parents: Vec<Variable<T>>,
21    /// Unique id for topological-sort visited tracking.
22    id: usize,
23}
24
/// Hand out the next node id from a single process-wide counter.
///
/// The first call returns `0`; every later call returns the previous value
/// plus one.
fn next_id() -> usize {
    use std::sync::atomic::{AtomicUsize, Ordering::Relaxed};

    static NEXT_NODE_ID: AtomicUsize = AtomicUsize::new(0);
    NEXT_NODE_ID.fetch_add(1, Relaxed)
}
31
32/// A variable in the computation graph that wraps a [`Tensor`] and supports
33/// reverse-mode automatic differentiation.
34///
35/// `Variable<T>` uses shared ownership (`Rc<RefCell<...>>`) so that the same
36/// node can appear as a parent in multiple downstream operations. Cloning a
37/// `Variable` is cheap — it just increments the reference count.
38pub struct Variable<T: Float> {
39    inner: Rc<RefCell<Node<T>>>,
40}
41
42impl<T: Float> Clone for Variable<T> {
43    fn clone(&self) -> Self {
44        Self {
45            inner: Rc::clone(&self.inner),
46        }
47    }
48}
49
50impl<T: Float> Variable<T> {
51    // ── Constructors ────────────────────────────────────────────────
52
53    /// Create a new leaf variable.
54    ///
55    /// # Examples
56    ///
57    /// ```
58    /// # use scivex_core::Tensor;
59    /// # use scivex_nn::Variable;
60    /// let t = Tensor::<f64>::from_vec(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
61    /// let v = Variable::new(t, true);
62    /// assert!(v.requires_grad());
63    /// assert_eq!(v.shape(), vec![3]);
64    /// ```
65    pub fn new(data: Tensor<T>, requires_grad: bool) -> Self {
66        Self {
67            inner: Rc::new(RefCell::new(Node {
68                data,
69                grad: None,
70                requires_grad,
71                grad_fn: None,
72                parents: Vec::new(),
73                id: next_id(),
74            })),
75        }
76    }
77
78    /// Create an internal (non-leaf) variable produced by an operation.
79    pub(crate) fn from_op(data: Tensor<T>, parents: Vec<Variable<T>>, grad_fn: GradFn<T>) -> Self {
80        Self {
81            inner: Rc::new(RefCell::new(Node {
82                data,
83                grad: None,
84                requires_grad: true,
85                grad_fn: Some(grad_fn),
86                parents,
87                id: next_id(),
88            })),
89        }
90    }
91
92    // ── Accessors ───────────────────────────────────────────────────
93
94    /// Return a clone of the underlying tensor data.
95    pub fn data(&self) -> Tensor<T> {
96        self.inner.borrow().data.clone()
97    }
98
99    /// Return the shape of the underlying tensor.
100    pub fn shape(&self) -> Vec<usize> {
101        self.inner.borrow().data.shape().to_vec()
102    }
103
104    /// Return the accumulated gradient, if any.
105    pub fn grad(&self) -> Option<Tensor<T>> {
106        self.inner.borrow().grad.clone()
107    }
108
109    /// Whether this variable tracks gradients.
110    pub fn requires_grad(&self) -> bool {
111        self.inner.borrow().requires_grad
112    }
113
114    /// Unique node id (used internally for graph traversal).
115    pub(crate) fn id(&self) -> usize {
116        self.inner.borrow().id
117    }
118
119    // ── Gradient helpers ────────────────────────────────────────────
120
121    /// Reset the gradient to `None`.
122    pub fn zero_grad(&self) {
123        self.inner.borrow_mut().grad = None;
124    }
125
126    /// Detach from the computation graph, returning a new leaf variable
127    /// with the same data but no graph history.
128    pub fn detach(&self) -> Self {
129        Self::new(self.data(), false)
130    }
131
132    /// Replace the data tensor (used by optimizers and weight loading).
133    pub fn set_data(&self, data: Tensor<T>) {
134        self.inner.borrow_mut().data = data;
135    }
136
137    /// Replace the gradient tensor (used by `GradScaler` for unscaling).
138    pub fn set_grad(&self, grad: Tensor<T>) {
139        self.inner.borrow_mut().grad = Some(grad);
140    }
141
142    /// Accumulate `g` into this node's gradient (summing if one already exists).
143    pub(crate) fn acc_grad(&self, g: &Tensor<T>) {
144        let mut node = self.inner.borrow_mut();
145        match node.grad.as_mut() {
146            Some(existing) => *existing += g,
147            None => node.grad = Some(g.clone()),
148        }
149    }
150
151    // ── Backward pass ───────────────────────────────────────────────
152
153    /// Run reverse-mode automatic differentiation starting from this variable.
154    ///
155    /// This variable is expected to be a scalar (single-element tensor). Its
156    /// gradient is seeded with `1.0`. After `backward()`, each ancestor with
157    /// `requires_grad == true` will have its `.grad()` populated.
158    pub fn backward(&self) {
159        // Topological sort produces post-order (leaves first).
160        // Reverse to get output-first order for backward pass.
161        let mut order = self.topo_sort();
162        order.reverse();
163
164        // Seed gradient.
165        {
166            let node = self.inner.borrow();
167            let ones = Tensor::ones(node.data.shape().to_vec());
168            drop(node);
169            self.acc_grad(&ones);
170        }
171
172        // Reverse walk.
173        for var in &order {
174            let node = var.inner.borrow();
175            let grad_fn = node.grad_fn.as_ref();
176            let parents_clone: Vec<Variable<T>> = node.parents.clone();
177            let grad_val = node.grad.clone();
178
179            if let (Some(gf), Some(g)) = (grad_fn, grad_val) {
180                let parent_grads = gf(&g);
181                // Drop the borrow before touching parents.
182                drop(node);
183                for (parent, pg) in parents_clone.iter().zip(parent_grads) {
184                    if parent.requires_grad() {
185                        parent.acc_grad(&pg);
186                    }
187                }
188            }
189        }
190    }
191
192    /// Topological sort via iterative DFS (output-first order).
193    fn topo_sort(&self) -> Vec<Variable<T>> {
194        let mut visited = HashSet::new();
195        let mut order = Vec::new();
196
197        // Iterative DFS with explicit stack.
198        // Each entry is (variable, processed_flag).
199        let mut stack: Vec<(Variable<T>, bool)> = vec![(self.clone(), false)];
200
201        while let Some((var, processed)) = stack.pop() {
202            let vid = var.id();
203            if processed {
204                if !visited.contains(&vid) {
205                    visited.insert(vid);
206                    order.push(var);
207                }
208                continue;
209            }
210            if visited.contains(&vid) {
211                continue;
212            }
213            // Push this node again with processed=true so it gets added after children.
214            stack.push((var.clone(), true));
215            let node = var.inner.borrow();
216            for parent in &node.parents {
217                if !visited.contains(&parent.id()) {
218                    stack.push((parent.clone(), false));
219                }
220            }
221        }
222
223        order
224    }
225}
226
227impl<T: Float> fmt::Debug for Variable<T> {
228    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
229        let node = self.inner.borrow();
230        f.debug_struct("Variable")
231            .field("shape", &node.data.shape())
232            .field("requires_grad", &node.requires_grad)
233            .field("has_grad", &node.grad.is_some())
234            .finish()
235    }
236}
237
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a one-element `f64` variable holding `value`.
    fn scalar(value: f64, requires_grad: bool) -> Variable<f64> {
        Variable::new(Tensor::from_vec(vec![value], vec![1]).unwrap(), requires_grad)
    }

    #[test]
    fn test_leaf_variable() {
        let data = Tensor::<f64>::from_vec(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
        let leaf = Variable::new(data.clone(), true);
        assert_eq!(leaf.data().as_slice(), data.as_slice());
        assert!(leaf.requires_grad());
        assert!(leaf.grad().is_none());
    }

    #[test]
    fn test_detach() {
        let tracked = Variable::new(Tensor::<f64>::ones(vec![2, 3]), true);
        let detached = tracked.detach();
        assert!(!detached.requires_grad());
    }

    #[test]
    fn test_zero_grad() {
        let var = Variable::new(Tensor::<f64>::ones(vec![2]), true);
        var.acc_grad(&Tensor::ones(vec![2]));
        assert!(var.grad().is_some());
        var.zero_grad();
        assert!(var.grad().is_none());
    }

    #[test]
    fn test_scalar_backward() {
        // y = identity(x) for scalar x, so dy/dx = 1.
        let x = scalar(3.0, true);
        let y = Variable::from_op(
            x.data(),
            vec![x.clone()],
            Box::new(|upstream: &Tensor<f64>| vec![upstream.clone()]),
        );
        y.backward();
        let x_grad = x.grad().unwrap();
        assert_eq!(x_grad.as_slice(), &[1.0]);
    }

    #[test]
    fn test_shape_accessor() {
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let matrix = Tensor::<f64>::from_vec(values, vec![2, 3]).unwrap();
        let var = Variable::new(matrix, false);
        assert_eq!(var.shape(), vec![2, 3]);
    }

    #[test]
    fn test_no_grad_variable_backward_does_not_accumulate() {
        // A parent created with requires_grad=false must be skipped during
        // accumulation, while a tracked sibling still receives its gradient.
        let frozen = scalar(2.0, false);
        let tracked = scalar(3.0, true);
        let sum = Variable::from_op(
            &frozen.data() + &tracked.data(),
            vec![frozen.clone(), tracked.clone()],
            Box::new(|upstream: &Tensor<f64>| vec![upstream.clone(), upstream.clone()]),
        );
        sum.backward();
        // The frozen parent stays without a gradient.
        assert!(frozen.grad().is_none());
        // The tracked parent receives the seed gradient unchanged.
        assert!(tracked.grad().is_some());
        assert_eq!(tracked.grad().unwrap().as_slice(), &[1.0]);
    }

    #[test]
    fn test_gradient_accumulation() {
        // Two successive acc_grad calls must sum element-wise.
        let var = Variable::new(Tensor::from_vec(vec![1.0_f64, 2.0], vec![2]).unwrap(), true);
        let first = Tensor::from_vec(vec![1.0, 1.0], vec![2]).unwrap();
        let second = Tensor::from_vec(vec![2.0, 3.0], vec![2]).unwrap();
        var.acc_grad(&first);
        var.acc_grad(&second);
        let total = var.grad().unwrap();
        assert_eq!(total.as_slice(), &[3.0, 4.0]);
    }
}