
// tensorlogic_infer/eager.rs

//! Eager mode automatic differentiation.
//!
//! This module provides eager execution with automatic differentiation,
//! similar to PyTorch's autograd or TensorFlow's eager execution.
//!
//! Unlike `TlAutodiff`, which requires a full `EinsumGraph` up front, eager mode
//! computes gradients by building a dynamic computation graph as operations
//! are executed.
//!
//! # Example
//!
//! ```ignore
//! use tensorlogic_infer::eager::{EagerOps, TlEagerAutodiff, Variable};
//!
//! // Create variables
//! let x = Variable::new(tensor_x, true); // requires_grad = true
//! let y = Variable::new(tensor_y, true);
//!
//! // Execute operations eagerly
//! let z = executor.eager_add(&x, &y)?;
//! let loss = executor.eager_sum(&z, &[0])?;
//!
//! // Compute gradients
//! let grads = executor.eager_backward(&loss)?;
//! ```

use crate::ops::{ElemOp, ReduceOp};
use crate::traits::TlExecutor;
use std::collections::HashMap;

/// A variable in the eager execution graph.
///
/// Wraps a tensor and tracks whether gradients should be computed for it.
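///
/// # Example
///
/// A minimal usage sketch (marked `ignore`, so it is not compiled as a doctest);
/// any tensor type used by the executor works here, `Vec<f64>` is just for
/// illustration:
///
/// ```ignore
/// use tensorlogic_infer::eager::Variable;
///
/// // Trainable parameter: gradients will be tracked for it.
/// let w = Variable::new(vec![0.5_f64, -1.0], true);
/// // Constant input: excluded from gradient computation.
/// let x = Variable::constant(vec![1.0_f64, 2.0]);
///
/// assert!(w.requires_grad());
/// assert!(!x.requires_grad());
/// // Each variable gets a unique id for gradient bookkeeping.
/// assert_ne!(w.id, x.id);
/// ```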
#[derive(Debug, Clone)]
pub struct Variable<T> {
    /// The tensor data
    pub tensor: T,
    /// Whether this variable requires gradient computation
    pub requires_grad: bool,
    /// Unique ID for gradient tracking
    pub id: usize,
}

impl<T> Variable<T> {
    /// Create a new variable.
    pub fn new(tensor: T, requires_grad: bool) -> Self {
        use std::sync::atomic::{AtomicUsize, Ordering};
        static NEXT_ID: AtomicUsize = AtomicUsize::new(0);
        let id = NEXT_ID.fetch_add(1, Ordering::Relaxed);
        Variable {
            tensor,
            requires_grad,
            id,
        }
    }

    /// Create a constant (no gradient).
    pub fn constant(tensor: T) -> Self {
        Self::new(tensor, false)
    }

    /// Get a reference to the tensor.
    pub fn tensor(&self) -> &T {
        &self.tensor
    }

    /// Check if this variable requires gradients.
    pub fn requires_grad(&self) -> bool {
        self.requires_grad
    }
}

/// Gradient storage for a variable.
#[derive(Debug, Clone)]
pub struct VariableGrad<T> {
    /// The gradient tensor
    pub grad: T,
    /// Whether the gradient has been computed
    pub computed: bool,
}

impl<T> VariableGrad<T> {
    /// Create a new gradient container.
    pub fn new(grad: T) -> Self {
        VariableGrad {
            grad,
            computed: true,
        }
    }

    /// Create an uncomputed gradient placeholder.
    pub fn placeholder(grad: T) -> Self {
        VariableGrad {
            grad,
            computed: false,
        }
    }
}

/// Tape for recording eager operations and their gradients.
///
/// The tape stores the computation graph as operations are executed,
/// enabling backward pass for gradient computation.
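///
/// # Example
///
/// A minimal sketch of manual tape bookkeeping (marked `ignore`, so it is not
/// compiled as a doctest; import paths are assumed). In normal use the
/// `TlEagerAutodiff` executor records operations for you:
///
/// ```ignore
/// use tensorlogic_infer::eager::{EagerOp, EagerTape, Variable, VariableGrad};
/// use tensorlogic_infer::ops::ElemOp;
///
/// let mut tape: EagerTape<Vec<f64>> = EagerTape::new();
/// let x = Variable::new(vec![1.0, 2.0], true);
/// let y = Variable::new(vec![3.0, 4.0], true);
/// let z = Variable::new(vec![4.0, 6.0], true);
///
/// // Record z = x + y so a backward pass can revisit it later.
/// tape.record_op(EagerOp::ElemBinary {
///     op: ElemOp::Add,
///     left: x.clone(),
///     right: y.clone(),
///     output: z.clone(),
/// });
///
/// // Seed the output gradient (d loss / d z = 1).
/// tape.set_gradient(z.id, VariableGrad::new(vec![1.0, 1.0]));
/// assert_eq!(tape.len(), 1);
/// ```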
#[derive(Debug)]
pub struct EagerTape<T> {
    /// Map from variable ID to gradient
    gradients: HashMap<usize, VariableGrad<T>>,
    /// Operations recorded on the tape
    operations: Vec<EagerOp<T>>,
}

impl<T> EagerTape<T> {
    /// Create a new empty tape.
    pub fn new() -> Self {
        EagerTape {
            gradients: HashMap::new(),
            operations: Vec::new(),
        }
    }

    /// Record an operation on the tape.
    pub fn record_op(&mut self, op: EagerOp<T>) {
        self.operations.push(op);
    }

    /// Set gradient for a variable.
    pub fn set_gradient(&mut self, var_id: usize, grad: VariableGrad<T>) {
        self.gradients.insert(var_id, grad);
    }

    /// Get gradient for a variable.
    pub fn get_gradient(&self, var_id: usize) -> Option<&VariableGrad<T>> {
        self.gradients.get(&var_id)
    }

    /// Get all gradients.
    pub fn gradients(&self) -> &HashMap<usize, VariableGrad<T>> {
        &self.gradients
    }

    /// Get all operations.
    pub fn operations(&self) -> &[EagerOp<T>] {
        &self.operations
    }

    /// Clear the tape.
    pub fn clear(&mut self) {
        self.gradients.clear();
        self.operations.clear();
    }

    /// Number of operations recorded.
    pub fn len(&self) -> usize {
        self.operations.len()
    }

    /// Check if no operations have been recorded.
    pub fn is_empty(&self) -> bool {
        self.operations.is_empty()
    }
}

impl<T> Default for EagerTape<T> {
    fn default() -> Self {
        Self::new()
    }
}

/// An operation recorded in the eager execution tape.
#[derive(Debug, Clone)]
pub enum EagerOp<T> {
    /// Element-wise unary operation
    ElemUnary {
        op: ElemOp,
        input: Variable<T>,
        output: Variable<T>,
    },
    /// Element-wise binary operation
    ElemBinary {
        op: ElemOp,
        left: Variable<T>,
        right: Variable<T>,
        output: Variable<T>,
    },
    /// Reduction operation
    Reduce {
        op: ReduceOp,
        input: Variable<T>,
        axes: Vec<usize>,
        output: Variable<T>,
    },
    /// Einsum operation
    Einsum {
        spec: String,
        inputs: Vec<Variable<T>>,
        output: Variable<T>,
    },
}

/// Trait for eager execution with automatic differentiation.
///
/// This trait extends `TlExecutor` with eager autodiff capabilities,
/// allowing gradient computation without building a full graph upfront.
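///
/// # Example
///
/// A minimal end-to-end sketch (marked `ignore`, so it is not compiled as a
/// doctest); `executor`, `tensor_x`, and `tensor_y` stand in for a concrete
/// implementor of this trait and its tensors, and import paths are assumed:
///
/// ```ignore
/// use tensorlogic_infer::eager::{TlEagerAutodiff, Variable};
/// use tensorlogic_infer::ops::{ElemOp, ReduceOp};
///
/// let x = Variable::new(tensor_x, true);
/// let y = Variable::new(tensor_y, true);
///
/// // Forward pass: each call returns a new variable and records the op.
/// let z = executor.eager_elem_op_binary(ElemOp::Multiply, &x, &y)?;
/// let s = executor.eager_reduce(ReduceOp::Sum, &z, &[0])?;
///
/// // Backward pass: gradients come back on a tape, keyed by variable id.
/// let tape = executor.eager_backward(&s)?;
/// let grad_x = tape.get_gradient(x.id);
/// ```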
pub trait TlEagerAutodiff: TlExecutor {
    /// Execute element-wise unary operation eagerly.
    ///
    /// Returns a new variable containing the result. If the input requires
    /// gradients, the operation is recorded on the tape.
    fn eager_elem_op(
        &mut self,
        op: ElemOp,
        x: &Variable<Self::Tensor>,
    ) -> Result<Variable<Self::Tensor>, Self::Error>;

    /// Execute element-wise binary operation eagerly.
    fn eager_elem_op_binary(
        &mut self,
        op: ElemOp,
        x: &Variable<Self::Tensor>,
        y: &Variable<Self::Tensor>,
    ) -> Result<Variable<Self::Tensor>, Self::Error>;

    /// Execute reduction operation eagerly.
    fn eager_reduce(
        &mut self,
        op: ReduceOp,
        x: &Variable<Self::Tensor>,
        axes: &[usize],
    ) -> Result<Variable<Self::Tensor>, Self::Error>;

    /// Execute einsum operation eagerly.
    fn eager_einsum(
        &mut self,
        spec: &str,
        inputs: &[Variable<Self::Tensor>],
    ) -> Result<Variable<Self::Tensor>, Self::Error>;

    /// Compute gradients for all variables with respect to the output.
    ///
    /// This performs backpropagation through the recorded operations
    /// to compute gradients.
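    ///
    /// A minimal retrieval sketch (marked `ignore`, so it is not compiled as a
    /// doctest); `executor`, `loss`, and `x` are placeholders:
    ///
    /// ```ignore
    /// let tape = executor.eager_backward(&loss)?;
    /// if let Some(g) = tape.get_gradient(x.id) {
    ///     assert!(g.computed);
    ///     // g.grad holds d(loss)/d(x).
    /// }
    /// ```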
    fn eager_backward(
        &mut self,
        output: &Variable<Self::Tensor>,
    ) -> Result<EagerTape<Self::Tensor>, Self::Error>;

    /// Create a new empty tape for recording operations.
    fn create_tape(&self) -> EagerTape<Self::Tensor> {
        EagerTape::new()
    }
}

/// Convenience methods for common operations.
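///
/// # Example
///
/// A minimal sketch (marked `ignore`, so it is not compiled as a doctest);
/// `executor`, `tensor_x`, and `tensor_y` are placeholders for a concrete
/// `TlEagerAutodiff` implementor and its tensors:
///
/// ```ignore
/// use tensorlogic_infer::eager::{EagerOps, TlEagerAutodiff, Variable};
///
/// let x = Variable::new(tensor_x, true);
/// let y = Variable::new(tensor_y, true);
///
/// // h = relu(x + y); reduce over axis 0 to form the loss.
/// let sum = executor.eager_add(&x, &y)?;
/// let h = executor.eager_relu(&sum)?;
/// let loss = executor.eager_mean(&h, &[0])?;
/// let grads = executor.eager_backward(&loss)?;
/// ```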
pub trait EagerOps: TlEagerAutodiff {
    /// Add two variables.
    fn eager_add(
        &mut self,
        x: &Variable<Self::Tensor>,
        y: &Variable<Self::Tensor>,
    ) -> Result<Variable<Self::Tensor>, Self::Error> {
        self.eager_elem_op_binary(ElemOp::Add, x, y)
    }

    /// Multiply two variables.
    fn eager_mul(
        &mut self,
        x: &Variable<Self::Tensor>,
        y: &Variable<Self::Tensor>,
    ) -> Result<Variable<Self::Tensor>, Self::Error> {
        self.eager_elem_op_binary(ElemOp::Multiply, x, y)
    }

    /// Subtract two variables.
    fn eager_sub(
        &mut self,
        x: &Variable<Self::Tensor>,
        y: &Variable<Self::Tensor>,
    ) -> Result<Variable<Self::Tensor>, Self::Error> {
        self.eager_elem_op_binary(ElemOp::Subtract, x, y)
    }

    /// Apply ReLU activation.
    fn eager_relu(
        &mut self,
        x: &Variable<Self::Tensor>,
    ) -> Result<Variable<Self::Tensor>, Self::Error> {
        self.eager_elem_op(ElemOp::Relu, x)
    }

    /// Apply sigmoid activation.
    fn eager_sigmoid(
        &mut self,
        x: &Variable<Self::Tensor>,
    ) -> Result<Variable<Self::Tensor>, Self::Error> {
        self.eager_elem_op(ElemOp::Sigmoid, x)
    }

    /// Apply one-minus operation (1 - x).
    fn eager_one_minus(
        &mut self,
        x: &Variable<Self::Tensor>,
    ) -> Result<Variable<Self::Tensor>, Self::Error> {
        self.eager_elem_op(ElemOp::OneMinus, x)
    }

    /// Sum reduction along axes.
    fn eager_sum(
        &mut self,
        x: &Variable<Self::Tensor>,
        axes: &[usize],
    ) -> Result<Variable<Self::Tensor>, Self::Error> {
        self.eager_reduce(ReduceOp::Sum, x, axes)
    }

    /// Mean reduction along axes.
    fn eager_mean(
        &mut self,
        x: &Variable<Self::Tensor>,
        axes: &[usize],
    ) -> Result<Variable<Self::Tensor>, Self::Error> {
        self.eager_reduce(ReduceOp::Mean, x, axes)
    }

    /// Max reduction along axes.
    fn eager_max(
        &mut self,
        x: &Variable<Self::Tensor>,
        axes: &[usize],
    ) -> Result<Variable<Self::Tensor>, Self::Error> {
        self.eager_reduce(ReduceOp::Max, x, axes)
    }
}

/// Blanket implementation of `EagerOps` for any type implementing `TlEagerAutodiff`.
impl<T: TlEagerAutodiff> EagerOps for T {}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_variable_creation() {
        let tensor = vec![1.0, 2.0, 3.0];
        let var = Variable::new(tensor.clone(), true);

        assert_eq!(var.tensor, tensor);
        assert!(var.requires_grad);
        // IDs come from a global counter; uniqueness is covered by
        // test_variable_unique_ids, so no specific value is asserted here.
    }

    #[test]
    fn test_variable_constant() {
        let tensor = vec![1.0, 2.0, 3.0];
        let var = Variable::constant(tensor.clone());

        assert_eq!(var.tensor, tensor);
        assert!(!var.requires_grad);
    }

    #[test]
    fn test_variable_unique_ids() {
        let var1 = Variable::new(vec![1.0], true);
        let var2 = Variable::new(vec![2.0], true);

        assert_ne!(var1.id, var2.id);
    }

    #[test]
    fn test_eager_tape_creation() {
        let tape: EagerTape<Vec<f64>> = EagerTape::new();

        assert!(tape.is_empty());
        assert_eq!(tape.len(), 0);
        assert_eq!(tape.gradients().len(), 0);
    }

    #[test]
    fn test_eager_tape_set_gradient() {
        let mut tape = EagerTape::new();
        let grad = VariableGrad::new(vec![1.0, 2.0, 3.0]);

        tape.set_gradient(1, grad);

        assert!(tape.get_gradient(1).is_some());
        assert!(tape.get_gradient(2).is_none());
    }

    #[test]
    fn test_eager_tape_clear() {
        let mut tape = EagerTape::new();
        tape.set_gradient(1, VariableGrad::new(vec![1.0]));

        // Only a gradient was set, so the tape records no operations yet.
        assert!(!tape.gradients().is_empty());

        tape.clear();

        assert!(tape.is_empty());
        assert_eq!(tape.gradients().len(), 0);
    }

    #[test]
    fn test_variable_grad_creation() {
        let grad = VariableGrad::new(vec![1.0, 2.0]);

        assert!(grad.computed);
        assert_eq!(grad.grad, vec![1.0, 2.0]);
    }

    #[test]
    fn test_variable_grad_placeholder() {
        let grad = VariableGrad::placeholder(vec![0.0]);

        assert!(!grad.computed);
    }

    #[test]
    fn test_eager_op_variants() {
        let var1 = Variable::new(vec![1.0], true);
        let var2 = Variable::new(vec![2.0], true);
        let var3 = Variable::new(vec![3.0], true);

        // Test ElemUnary variant
        let _op1 = EagerOp::ElemUnary {
            op: ElemOp::OneMinus,
            input: var1.clone(),
            output: var3.clone(),
        };

        // Test ElemBinary variant
        let _op2 = EagerOp::ElemBinary {
            op: ElemOp::Add,
            left: var1.clone(),
            right: var2.clone(),
            output: var3.clone(),
        };

        // Test Reduce variant
        let _op3 = EagerOp::Reduce {
            op: ReduceOp::Sum,
            input: var1.clone(),
            axes: vec![0],
            output: var3.clone(),
        };

        // Test Einsum variant
        let _op4 = EagerOp::Einsum {
            spec: "ij,jk->ik".to_string(),
            inputs: vec![var1.clone(), var2.clone()],
            output: var3.clone(),
        };
    }

    #[test]
    fn test_tape_record_op() {
        let mut tape = EagerTape::new();
        let var1 = Variable::new(vec![1.0], true);
        let var2 = Variable::new(vec![2.0], true);

        let op = EagerOp::ElemBinary {
            op: ElemOp::Add,
            left: var1,
            right: var2.clone(),
            output: var2,
        };

        tape.record_op(op);

        assert_eq!(tape.len(), 1);
        assert!(!tape.is_empty());
    }

    #[test]
    fn test_variable_methods() {
        let tensor = vec![1.0, 2.0, 3.0];
        let var = Variable::new(tensor.clone(), true);

        assert_eq!(var.tensor(), &tensor);
        assert!(var.requires_grad());
    }

    #[test]
    fn test_tape_default() {
        let tape: EagerTape<Vec<f64>> = EagerTape::default();

        assert!(tape.is_empty());
    }
}