1use std::cell::RefCell;
2use std::collections::HashSet;
3use std::fmt;
4use std::rc::Rc;
5
6use scivex_core::{Float, Tensor};
7
8type GradFn<T> = Box<dyn Fn(&Tensor<T>) -> Vec<Tensor<T>>>;
13
14struct Node<T: Float> {
16 data: Tensor<T>,
17 grad: Option<Tensor<T>>,
18 requires_grad: bool,
19 grad_fn: Option<GradFn<T>>,
20 parents: Vec<Variable<T>>,
21 id: usize,
23}
24
/// Hands out identifiers from a single process-wide counter.
///
/// The counter starts at zero and every call returns its current value and then
/// advances it by one.
fn next_id() -> usize {
    use std::sync::atomic;

    static NEXT_ID: atomic::AtomicUsize = atomic::AtomicUsize::new(0);
    NEXT_ID.fetch_add(1, atomic::Ordering::Relaxed)
}
31
32pub struct Variable<T: Float> {
39 inner: Rc<RefCell<Node<T>>>,
40}
41
42impl<T: Float> Clone for Variable<T> {
43 fn clone(&self) -> Self {
44 Self {
45 inner: Rc::clone(&self.inner),
46 }
47 }
48}
49
50impl<T: Float> Variable<T> {
51 pub fn new(data: Tensor<T>, requires_grad: bool) -> Self {
66 Self {
67 inner: Rc::new(RefCell::new(Node {
68 data,
69 grad: None,
70 requires_grad,
71 grad_fn: None,
72 parents: Vec::new(),
73 id: next_id(),
74 })),
75 }
76 }
77
78 pub(crate) fn from_op(data: Tensor<T>, parents: Vec<Variable<T>>, grad_fn: GradFn<T>) -> Self {
80 Self {
81 inner: Rc::new(RefCell::new(Node {
82 data,
83 grad: None,
84 requires_grad: true,
85 grad_fn: Some(grad_fn),
86 parents,
87 id: next_id(),
88 })),
89 }
90 }
91
92 pub fn data(&self) -> Tensor<T> {
96 self.inner.borrow().data.clone()
97 }
98
99 pub fn shape(&self) -> Vec<usize> {
101 self.inner.borrow().data.shape().to_vec()
102 }
103
104 pub fn grad(&self) -> Option<Tensor<T>> {
106 self.inner.borrow().grad.clone()
107 }
108
109 pub fn requires_grad(&self) -> bool {
111 self.inner.borrow().requires_grad
112 }
113
114 pub(crate) fn id(&self) -> usize {
116 self.inner.borrow().id
117 }
118
119 pub fn zero_grad(&self) {
123 self.inner.borrow_mut().grad = None;
124 }
125
126 pub fn detach(&self) -> Self {
129 Self::new(self.data(), false)
130 }
131
132 pub fn set_data(&self, data: Tensor<T>) {
134 self.inner.borrow_mut().data = data;
135 }
136
137 pub fn set_grad(&self, grad: Tensor<T>) {
139 self.inner.borrow_mut().grad = Some(grad);
140 }
141
142 pub(crate) fn acc_grad(&self, g: &Tensor<T>) {
144 let mut node = self.inner.borrow_mut();
145 match node.grad.as_mut() {
146 Some(existing) => *existing += g,
147 None => node.grad = Some(g.clone()),
148 }
149 }
150
151 pub fn backward(&self) {
159 let mut order = self.topo_sort();
162 order.reverse();
163
164 {
166 let node = self.inner.borrow();
167 let ones = Tensor::ones(node.data.shape().to_vec());
168 drop(node);
169 self.acc_grad(&ones);
170 }
171
172 for var in &order {
174 let node = var.inner.borrow();
175 let grad_fn = node.grad_fn.as_ref();
176 let parents_clone: Vec<Variable<T>> = node.parents.clone();
177 let grad_val = node.grad.clone();
178
179 if let (Some(gf), Some(g)) = (grad_fn, grad_val) {
180 let parent_grads = gf(&g);
181 drop(node);
183 for (parent, pg) in parents_clone.iter().zip(parent_grads) {
184 if parent.requires_grad() {
185 parent.acc_grad(&pg);
186 }
187 }
188 }
189 }
190 }
191
192 fn topo_sort(&self) -> Vec<Variable<T>> {
194 let mut visited = HashSet::new();
195 let mut order = Vec::new();
196
197 let mut stack: Vec<(Variable<T>, bool)> = vec![(self.clone(), false)];
200
201 while let Some((var, processed)) = stack.pop() {
202 let vid = var.id();
203 if processed {
204 if !visited.contains(&vid) {
205 visited.insert(vid);
206 order.push(var);
207 }
208 continue;
209 }
210 if visited.contains(&vid) {
211 continue;
212 }
213 stack.push((var.clone(), true));
215 let node = var.inner.borrow();
216 for parent in &node.parents {
217 if !visited.contains(&parent.id()) {
218 stack.push((parent.clone(), false));
219 }
220 }
221 }
222
223 order
224 }
225}
226
227impl<T: Float> fmt::Debug for Variable<T> {
228 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
229 let node = self.inner.borrow();
230 f.debug_struct("Variable")
231 .field("shape", &node.data.shape())
232 .field("requires_grad", &node.requires_grad)
233 .field("has_grad", &node.grad.is_some())
234 .finish()
235 }
236}
237
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `f64` tensor via `Tensor::from_vec`, unwrapping the result.
    fn tensor(values: Vec<f64>, shape: Vec<usize>) -> Tensor<f64> {
        Tensor::from_vec(values, shape).unwrap()
    }

    /// A freshly created leaf exposes its data and flag and has no gradient.
    #[test]
    fn test_leaf_variable() {
        let values = tensor(vec![1.0, 2.0, 3.0], vec![3]);
        let leaf = Variable::new(values.clone(), true);
        assert_eq!(leaf.data().as_slice(), values.as_slice());
        assert!(leaf.requires_grad());
        assert!(leaf.grad().is_none());
    }

    /// Detaching yields a variable that does not require gradients.
    #[test]
    fn test_detach() {
        let tracked = Variable::new(Tensor::<f64>::ones(vec![2, 3]), true);
        let detached = tracked.detach();
        assert!(!detached.requires_grad());
    }

    /// `zero_grad` removes a previously accumulated gradient.
    #[test]
    fn test_zero_grad() {
        let var = Variable::new(Tensor::<f64>::ones(vec![2]), true);
        var.acc_grad(&Tensor::ones(vec![2]));
        assert!(var.grad().is_some());
        var.zero_grad();
        assert!(var.grad().is_none());
    }

    /// Backward through an identity operation delivers the seed gradient to the input.
    #[test]
    fn test_scalar_backward() {
        let input = Variable::new(tensor(vec![3.0], vec![1]), true);
        let output = Variable::from_op(
            input.data(),
            vec![input.clone()],
            Box::new(|upstream: &Tensor<f64>| vec![upstream.clone()]),
        );
        output.backward();
        let grad = input.grad().unwrap();
        assert_eq!(grad.as_slice(), &[1.0]);
    }

    /// `shape` reports the dimensions of the wrapped tensor.
    #[test]
    fn test_shape_accessor() {
        let matrix = tensor(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let var = Variable::new(matrix, false);
        assert_eq!(var.shape(), vec![2, 3]);
    }

    /// Parents without `requires_grad` receive nothing; the others get the gradient.
    #[test]
    fn test_no_grad_variable_backward_does_not_accumulate() {
        let frozen = Variable::new(tensor(vec![2.0], vec![1]), false);
        let tracked = Variable::new(tensor(vec![3.0], vec![1]), true);
        let sum = &frozen.data() + &tracked.data();
        let output = Variable::from_op(
            sum,
            vec![frozen.clone(), tracked.clone()],
            Box::new(|upstream: &Tensor<f64>| vec![upstream.clone(), upstream.clone()]),
        );
        output.backward();
        assert!(frozen.grad().is_none());
        assert!(tracked.grad().is_some());
        assert_eq!(tracked.grad().unwrap().as_slice(), &[1.0]);
    }

    /// Successive `acc_grad` calls add element-wise.
    #[test]
    fn test_gradient_accumulation() {
        let var = Variable::new(tensor(vec![1.0, 2.0], vec![2]), true);
        var.acc_grad(&tensor(vec![1.0, 1.0], vec![2]));
        var.acc_grad(&tensor(vec![2.0, 3.0], vec![2]));
        let total = var.grad().unwrap();
        assert_eq!(total.as_slice(), &[3.0, 4.0]);
    }
}