1use crate::ops::{ElemOp, ReduceOp};
28use crate::traits::TlExecutor;
29use std::collections::HashMap;
30
/// A tensor wrapped with the metadata needed for eager autodiff.
///
/// Carries the underlying tensor value, a flag controlling whether
/// gradients flow through it, and a process-unique id used to key
/// gradient lookups on the tape.
#[derive(Debug, Clone)]
pub struct Variable<T> {
    /// The wrapped tensor value.
    pub tensor: T,
    /// Whether this variable participates in gradient computation.
    pub requires_grad: bool,
    /// Unique identifier assigned at construction time.
    pub id: usize,
}

impl<T> Variable<T> {
    /// Wraps `tensor` in a new variable, assigning it a fresh unique id.
    pub fn new(tensor: T, requires_grad: bool) -> Self {
        use std::sync::atomic::{AtomicUsize, Ordering};
        // Monotonic counter; a `static` inside a function is a single item,
        // shared across all monomorphizations of `T`.
        static NEXT_ID: AtomicUsize = AtomicUsize::new(0);
        Self {
            tensor,
            requires_grad,
            id: NEXT_ID.fetch_add(1, Ordering::Relaxed),
        }
    }

    /// Builds a variable that is excluded from gradient tracking.
    pub fn constant(tensor: T) -> Self {
        Self::new(tensor, false)
    }

    /// Borrows the underlying tensor.
    pub fn tensor(&self) -> &T {
        &self.tensor
    }

    /// Reports whether gradients are tracked for this variable.
    pub fn requires_grad(&self) -> bool {
        self.requires_grad
    }
}
72
/// Gradient storage for a single variable on the tape.
#[derive(Debug, Clone)]
pub struct VariableGrad<T> {
    /// The gradient tensor.
    pub grad: T,
    /// `false` while the value is only a placeholder awaiting computation.
    pub computed: bool,
}

impl<T> VariableGrad<T> {
    /// Wraps an already-computed gradient value.
    pub fn new(grad: T) -> Self {
        Self {
            grad,
            computed: true,
        }
    }

    /// Stores `grad` as an initial placeholder, marked as not yet computed.
    pub fn placeholder(grad: T) -> Self {
        Self {
            grad,
            computed: false,
        }
    }
}
99
100#[derive(Debug)]
105pub struct EagerTape<T> {
106 gradients: HashMap<usize, VariableGrad<T>>,
108 operations: Vec<EagerOp<T>>,
110}
111
112impl<T> EagerTape<T> {
113 pub fn new() -> Self {
115 EagerTape {
116 gradients: HashMap::new(),
117 operations: Vec::new(),
118 }
119 }
120
121 pub fn record_op(&mut self, op: EagerOp<T>) {
123 self.operations.push(op);
124 }
125
126 pub fn set_gradient(&mut self, var_id: usize, grad: VariableGrad<T>) {
128 self.gradients.insert(var_id, grad);
129 }
130
131 pub fn get_gradient(&self, var_id: usize) -> Option<&VariableGrad<T>> {
133 self.gradients.get(&var_id)
134 }
135
136 pub fn gradients(&self) -> &HashMap<usize, VariableGrad<T>> {
138 &self.gradients
139 }
140
141 pub fn operations(&self) -> &[EagerOp<T>] {
143 &self.operations
144 }
145
146 pub fn clear(&mut self) {
148 self.gradients.clear();
149 self.operations.clear();
150 }
151
152 pub fn len(&self) -> usize {
154 self.operations.len()
155 }
156
157 pub fn is_empty(&self) -> bool {
159 self.operations.is_empty()
160 }
161}
162
163impl<T> Default for EagerTape<T> {
164 fn default() -> Self {
165 Self::new()
166 }
167}
168
/// A single operation recorded on the eager tape.
///
/// Each variant stores the participating [`Variable`]s (inputs and output)
/// so a backward pass can revisit the operation after the fact.
#[derive(Debug, Clone)]
pub enum EagerOp<T> {
    /// Element-wise unary operation: `output = op(input)`.
    ElemUnary {
        op: ElemOp,
        input: Variable<T>,
        output: Variable<T>,
    },
    /// Element-wise binary operation: `output = op(left, right)`.
    ElemBinary {
        op: ElemOp,
        left: Variable<T>,
        right: Variable<T>,
        output: Variable<T>,
    },
    /// Reduction of `input` over the listed `axes` (e.g. sum, mean, max).
    Reduce {
        op: ReduceOp,
        input: Variable<T>,
        axes: Vec<usize>,
        output: Variable<T>,
    },
    /// Einstein-summation contraction described by `spec`
    /// (e.g. `"ij,jk->ik"`).
    Einsum {
        spec: String,
        inputs: Vec<Variable<T>>,
        output: Variable<T>,
    },
}
199
/// Eager-mode (define-by-run) automatic differentiation on top of a
/// tensor executor.
///
/// Each `eager_*` method executes its operation immediately and returns
/// the result wrapped in a [`Variable`]. `eager_backward` then produces
/// an [`EagerTape`] of gradients for a given output.
pub trait TlEagerAutodiff: TlExecutor {
    /// Applies the element-wise unary `op` to `x`.
    fn eager_elem_op(
        &mut self,
        op: ElemOp,
        x: &Variable<Self::Tensor>,
    ) -> Result<Variable<Self::Tensor>, Self::Error>;

    /// Applies the element-wise binary `op` to `x` and `y`.
    fn eager_elem_op_binary(
        &mut self,
        op: ElemOp,
        x: &Variable<Self::Tensor>,
        y: &Variable<Self::Tensor>,
    ) -> Result<Variable<Self::Tensor>, Self::Error>;

    /// Reduces `x` with `op` over the given `axes`.
    fn eager_reduce(
        &mut self,
        op: ReduceOp,
        x: &Variable<Self::Tensor>,
        axes: &[usize],
    ) -> Result<Variable<Self::Tensor>, Self::Error>;

    /// Contracts `inputs` according to the einsum `spec`
    /// (e.g. `"ij,jk->ik"`).
    fn eager_einsum(
        &mut self,
        spec: &str,
        inputs: &[Variable<Self::Tensor>],
    ) -> Result<Variable<Self::Tensor>, Self::Error>;

    /// Runs the backward pass starting from `output`, returning a tape of
    /// gradients. NOTE(review): exact gradient coverage (which variables
    /// receive entries) is implementor-defined — confirm with implementors.
    fn eager_backward(
        &mut self,
        output: &Variable<Self::Tensor>,
    ) -> Result<EagerTape<Self::Tensor>, Self::Error>;

    /// Convenience: creates a fresh, empty tape.
    fn create_tape(&self) -> EagerTape<Self::Tensor> {
        EagerTape::new()
    }
}
252
/// Named convenience wrappers over [`TlEagerAutodiff`]'s generic
/// operation entry points.
pub trait EagerOps: TlEagerAutodiff {
    /// Element-wise addition: `x + y`.
    fn eager_add(
        &mut self,
        x: &Variable<Self::Tensor>,
        y: &Variable<Self::Tensor>,
    ) -> Result<Variable<Self::Tensor>, Self::Error> {
        self.eager_elem_op_binary(ElemOp::Add, x, y)
    }

    /// Element-wise multiplication: `x * y`.
    fn eager_mul(
        &mut self,
        x: &Variable<Self::Tensor>,
        y: &Variable<Self::Tensor>,
    ) -> Result<Variable<Self::Tensor>, Self::Error> {
        self.eager_elem_op_binary(ElemOp::Multiply, x, y)
    }

    /// Element-wise subtraction: `x - y`.
    fn eager_sub(
        &mut self,
        x: &Variable<Self::Tensor>,
        y: &Variable<Self::Tensor>,
    ) -> Result<Variable<Self::Tensor>, Self::Error> {
        self.eager_elem_op_binary(ElemOp::Subtract, x, y)
    }

    /// Element-wise ReLU activation.
    fn eager_relu(
        &mut self,
        x: &Variable<Self::Tensor>,
    ) -> Result<Variable<Self::Tensor>, Self::Error> {
        self.eager_elem_op(ElemOp::Relu, x)
    }

    /// Element-wise sigmoid activation.
    fn eager_sigmoid(
        &mut self,
        x: &Variable<Self::Tensor>,
    ) -> Result<Variable<Self::Tensor>, Self::Error> {
        self.eager_elem_op(ElemOp::Sigmoid, x)
    }

    /// Element-wise one-minus: applies [`ElemOp::OneMinus`] to `x`.
    fn eager_one_minus(
        &mut self,
        x: &Variable<Self::Tensor>,
    ) -> Result<Variable<Self::Tensor>, Self::Error> {
        self.eager_elem_op(ElemOp::OneMinus, x)
    }

    /// Sum-reduction of `x` over `axes`.
    fn eager_sum(
        &mut self,
        x: &Variable<Self::Tensor>,
        axes: &[usize],
    ) -> Result<Variable<Self::Tensor>, Self::Error> {
        self.eager_reduce(ReduceOp::Sum, x, axes)
    }

    /// Mean-reduction of `x` over `axes`.
    fn eager_mean(
        &mut self,
        x: &Variable<Self::Tensor>,
        axes: &[usize],
    ) -> Result<Variable<Self::Tensor>, Self::Error> {
        self.eager_reduce(ReduceOp::Mean, x, axes)
    }

    /// Max-reduction of `x` over `axes`.
    fn eager_max(
        &mut self,
        x: &Variable<Self::Tensor>,
        axes: &[usize],
    ) -> Result<Variable<Self::Tensor>, Self::Error> {
        self.eager_reduce(ReduceOp::Max, x, axes)
    }
}

// Blanket impl: every eager-autodiff executor gets the convenience ops
// for free, since all methods have default bodies.
impl<T: TlEagerAutodiff> EagerOps for T {}
336
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_variable_creation() {
        let tensor = vec![1.0, 2.0, 3.0];
        let var = Variable::new(tensor.clone(), true);

        assert_eq!(var.tensor, tensor);
        assert!(var.requires_grad);
    }

    #[test]
    fn test_variable_constant() {
        let tensor = vec![1.0, 2.0, 3.0];
        let var = Variable::constant(tensor.clone());

        assert_eq!(var.tensor, tensor);
        assert!(!var.requires_grad);
    }

    #[test]
    fn test_variable_unique_ids() {
        let var1 = Variable::new(vec![1.0], true);
        let var2 = Variable::new(vec![2.0], true);

        assert_ne!(var1.id, var2.id);
    }

    #[test]
    fn test_eager_tape_creation() {
        let tape: EagerTape<Vec<f64>> = EagerTape::new();

        assert!(tape.is_empty());
        assert_eq!(tape.len(), 0);
        assert_eq!(tape.gradients().len(), 0);
    }

    #[test]
    fn test_eager_tape_set_gradient() {
        let mut tape = EagerTape::new();
        let grad = VariableGrad::new(vec![1.0, 2.0, 3.0]);

        tape.set_gradient(1, grad);

        assert!(tape.get_gradient(1).is_some());
        assert!(tape.get_gradient(2).is_none());
    }

    #[test]
    fn test_eager_tape_set_gradient_overwrites() {
        let mut tape = EagerTape::new();
        tape.set_gradient(1, VariableGrad::new(vec![1.0]));
        tape.set_gradient(1, VariableGrad::new(vec![2.0]));

        // The second insert for the same id replaces the first.
        assert_eq!(tape.gradients().len(), 1);
        assert_eq!(tape.get_gradient(1).unwrap().grad, vec![2.0]);
    }

    #[test]
    fn test_eager_tape_clear() {
        let mut tape = EagerTape::new();
        tape.set_gradient(1, VariableGrad::new(vec![1.0]));

        // Setting a gradient does not record an operation, so the op list
        // is still empty while the gradient map is not.
        assert!(tape.is_empty());
        assert_eq!(tape.gradients().len(), 1);

        tape.clear();

        assert!(tape.is_empty());
        assert_eq!(tape.gradients().len(), 0);
    }

    #[test]
    fn test_variable_grad_creation() {
        let grad = VariableGrad::new(vec![1.0, 2.0]);

        assert!(grad.computed);
        assert_eq!(grad.grad, vec![1.0, 2.0]);
    }

    #[test]
    fn test_variable_grad_placeholder() {
        let grad = VariableGrad::placeholder(vec![0.0]);

        assert!(!grad.computed);
    }

    #[test]
    fn test_eager_op_variants() {
        let var1 = Variable::new(vec![1.0], true);
        let var2 = Variable::new(vec![2.0], true);
        let var3 = Variable::new(vec![3.0], true);

        // Construction of each variant must type-check with cloned vars.
        let _op1 = EagerOp::ElemUnary {
            op: ElemOp::OneMinus,
            input: var1.clone(),
            output: var3.clone(),
        };

        let _op2 = EagerOp::ElemBinary {
            op: ElemOp::Add,
            left: var1.clone(),
            right: var2.clone(),
            output: var3.clone(),
        };

        let _op3 = EagerOp::Reduce {
            op: ReduceOp::Sum,
            input: var1.clone(),
            axes: vec![0],
            output: var3.clone(),
        };

        let _op4 = EagerOp::Einsum {
            spec: "ij,jk->ik".to_string(),
            inputs: vec![var1.clone(), var2.clone()],
            output: var3.clone(),
        };
    }

    #[test]
    fn test_tape_record_op() {
        let mut tape = EagerTape::new();
        let var1 = Variable::new(vec![1.0], true);
        let var2 = Variable::new(vec![2.0], true);

        let op = EagerOp::ElemBinary {
            op: ElemOp::Add,
            left: var1,
            right: var2.clone(),
            output: var2,
        };

        tape.record_op(op);

        assert_eq!(tape.len(), 1);
        assert!(!tape.is_empty());
        // The recorded op is visible through the accessor, in order.
        assert_eq!(tape.operations().len(), 1);
        assert!(matches!(
            tape.operations()[0],
            EagerOp::ElemBinary { op: ElemOp::Add, .. }
        ));
    }

    #[test]
    fn test_variable_methods() {
        let tensor = vec![1.0, 2.0, 3.0];
        let var = Variable::new(tensor.clone(), true);

        assert_eq!(var.tensor(), &tensor);
        assert!(var.requires_grad());
    }

    #[test]
    fn test_tape_default() {
        let tape: EagerTape<Vec<f64>> = EagerTape::default();

        assert!(tape.is_empty());
    }
}