// trustformers_core/autodiff/variable.rs

1//! Variable implementation for automatic differentiation.
2//!
3//! This module provides the Variable type, which wraps tensors and enables
4//! automatic gradient computation through the computational graph.
5
6use super::graph::{ComputationGraph, NodeId, OperationType};
7use crate::errors::{Result, TrustformersError};
8use crate::tensor::Tensor;
9use std::sync::{Arc, Mutex};
10
11/// Reference to a shared computation graph
12pub type GraphRef = Arc<Mutex<ComputationGraph>>;
13
14/// Variable that participates in automatic differentiation
15#[derive(Debug, Clone)]
16pub struct Variable {
17    /// Reference to the computation graph
18    graph: GraphRef,
19    /// Node ID in the computation graph
20    node_id: NodeId,
21    /// Whether this variable requires gradients
22    requires_grad: bool,
23}
24
25/// Shared reference to a variable
26pub type VariableRef = Arc<Variable>;
27
28impl Variable {
29    /// Create a new variable from a tensor
30    pub fn new(tensor: Tensor, requires_grad: bool) -> Self {
31        let graph = Arc::new(Mutex::new(ComputationGraph::new()));
32        let node_id = {
33            let mut graph_guard = graph.lock().expect("lock should not be poisoned");
34            graph_guard.add_node(tensor, requires_grad, None)
35        };
36
37        Self {
38            graph,
39            node_id,
40            requires_grad,
41        }
42    }
43
44    /// Create a new variable with a name
45    pub fn new_with_name(tensor: Tensor, requires_grad: bool, name: String) -> Self {
46        let graph = Arc::new(Mutex::new(ComputationGraph::new()));
47        let node_id = {
48            let mut graph_guard = graph.lock().expect("lock should not be poisoned");
49            graph_guard.add_node(tensor, requires_grad, Some(name))
50        };
51
52        Self {
53            graph,
54            node_id,
55            requires_grad,
56        }
57    }
58
59    /// Create a new variable from an existing graph
60    pub fn from_graph(graph: GraphRef, node_id: NodeId, requires_grad: bool) -> Self {
61        Self {
62            graph,
63            node_id,
64            requires_grad,
65        }
66    }
67
68    /// Get the tensor data
69    pub fn data(&self) -> Result<Tensor> {
70        let graph = self.graph.lock().expect("lock should not be poisoned");
71        graph.get_value(self.node_id).cloned().ok_or_else(|| {
72            TrustformersError::tensor_op_error(
73                &format!("Node {} not found in graph", self.node_id),
74                "Variable::data",
75            )
76        })
77    }
78
79    /// Get the gradient
80    pub fn grad(&self) -> Result<Option<Tensor>> {
81        let graph = self.graph.lock().expect("lock should not be poisoned");
82        Ok(graph.get_gradient(self.node_id).cloned())
83    }
84
85    /// Get the node ID
86    pub fn node_id(&self) -> NodeId {
87        self.node_id
88    }
89
90    /// Check if this variable requires gradients
91    pub fn requires_grad(&self) -> bool {
92        self.requires_grad
93    }
94
95    /// Get the graph reference
96    pub fn graph(&self) -> GraphRef {
97        self.graph.clone()
98    }
99
100    /// Get the shape of the tensor
101    pub fn shape(&self) -> Result<Vec<usize>> {
102        let graph = self.graph.lock().expect("lock should not be poisoned");
103        graph.get_value(self.node_id).map(|tensor| tensor.shape()).ok_or_else(|| {
104            TrustformersError::tensor_op_error(
105                &format!("Node {} not found in graph", self.node_id),
106                "Variable::shape",
107            )
108        })
109    }
110
111    /// Convert to a scalar value
112    pub fn item(&self) -> Result<f32> {
113        let tensor = self.data()?;
114        tensor.to_scalar()
115    }
116
117    /// Compute backward pass for this variable
118    pub fn backward(&self) -> Result<()> {
119        let mut graph = self.graph.lock().expect("lock should not be poisoned");
120        graph.backward(self.node_id, None)
121    }
122
123    /// Compute backward pass with custom gradient
124    pub fn backward_with_grad(&self, grad: Tensor) -> Result<()> {
125        let mut graph = self.graph.lock().expect("lock should not be poisoned");
126        graph.backward(self.node_id, Some(grad))
127    }
128
129    /// Zero the gradients
130    pub fn zero_grad(&self) {
131        let mut graph = self.graph.lock().expect("lock should not be poisoned");
132        graph.zero_grad();
133    }
134
135    /// Detach this variable from the computation graph
136    pub fn detach(&self) -> Result<Variable> {
137        let tensor = self.data()?;
138        Ok(Variable::new(tensor, false))
139    }
140
141    /// Create a copy of this variable that requires gradients
142    pub fn requires_grad_(&self) -> Result<Variable> {
143        let tensor = self.data()?;
144        Ok(Variable::new(tensor, true))
145    }
146
147    /// Update the value of this variable
148    pub fn set_data(&self, tensor: Tensor) -> Result<()> {
149        let mut graph = self.graph.lock().expect("lock should not be poisoned");
150        graph.update_value(self.node_id, tensor)
151    }
152
153    // Arithmetic operations
154
155    /// Add another variable
156    pub fn add(&self, other: &Variable) -> Result<Variable> {
157        self.binary_op(other, OperationType::Add)
158    }
159
160    /// Subtract another variable
161    pub fn sub(&self, other: &Variable) -> Result<Variable> {
162        self.binary_op(other, OperationType::Subtract)
163    }
164
165    /// Multiply by another variable
166    pub fn mul(&self, other: &Variable) -> Result<Variable> {
167        self.binary_op(other, OperationType::Multiply)
168    }
169
170    /// Divide by another variable
171    pub fn div(&self, other: &Variable) -> Result<Variable> {
172        self.binary_op(other, OperationType::Divide)
173    }
174
175    /// Matrix multiplication
176    pub fn matmul(&self, other: &Variable) -> Result<Variable> {
177        self.binary_op(other, OperationType::MatrixMultiply)
178    }
179
180    /// Negation
181    pub fn neg(&self) -> Result<Variable> {
182        self.unary_op(OperationType::Negate)
183    }
184
185    /// Square
186    pub fn square(&self) -> Result<Variable> {
187        self.unary_op(OperationType::Square)
188    }
189
190    /// Square root
191    pub fn sqrt(&self) -> Result<Variable> {
192        self.unary_op(OperationType::Sqrt)
193    }
194
195    /// Natural logarithm
196    pub fn log(&self) -> Result<Variable> {
197        self.unary_op(OperationType::Log)
198    }
199
200    /// Exponential
201    pub fn exp(&self) -> Result<Variable> {
202        self.unary_op(OperationType::Exp)
203    }
204
205    // Activation functions
206
207    /// Sigmoid activation
208    pub fn sigmoid(&self) -> Result<Variable> {
209        self.unary_op(OperationType::Sigmoid)
210    }
211
212    /// Tanh activation
213    pub fn tanh(&self) -> Result<Variable> {
214        self.unary_op(OperationType::Tanh)
215    }
216
217    /// ReLU activation
218    pub fn relu(&self) -> Result<Variable> {
219        self.unary_op(OperationType::ReLU)
220    }
221
222    /// Leaky ReLU activation
223    pub fn leaky_relu(&self, alpha: f32) -> Result<Variable> {
224        self.unary_op(OperationType::LeakyReLU(alpha))
225    }
226
227    /// Softmax activation
228    pub fn softmax(&self) -> Result<Variable> {
229        self.unary_op(OperationType::Softmax)
230    }
231
232    // Tensor operations
233
234    /// Reshape the tensor
235    pub fn reshape(&self, shape: Vec<usize>) -> Result<Variable> {
236        self.unary_op(OperationType::Reshape(shape))
237    }
238
239    /// Transpose the tensor
240    pub fn transpose(&self, permutation: Vec<usize>) -> Result<Variable> {
241        self.unary_op(OperationType::Transpose(permutation))
242    }
243
244    /// Sum along specified axes
245    pub fn sum(&self, axes: Option<Vec<usize>>) -> Result<Variable> {
246        self.unary_op(OperationType::Sum(axes))
247    }
248
249    /// Mean along specified axes
250    pub fn mean(&self, axes: Option<Vec<usize>>) -> Result<Variable> {
251        self.unary_op(OperationType::Mean(axes))
252    }
253
254    /// Max along specified axes
255    pub fn max(&self, axes: Option<Vec<usize>>) -> Result<Variable> {
256        self.unary_op(OperationType::Max(axes))
257    }
258
259    /// Min along specified axes
260    pub fn min(&self, axes: Option<Vec<usize>>) -> Result<Variable> {
261        self.unary_op(OperationType::Min(axes))
262    }
263
264    // Scalar operations
265
266    /// Add a scalar
267    pub fn add_scalar(&self, scalar: f32) -> Result<Variable> {
268        let scalar_tensor = Tensor::scalar(scalar)?;
269        let scalar_var = Variable::new(scalar_tensor, false);
270        self.add(&scalar_var)
271    }
272
273    /// Subtract a scalar
274    pub fn sub_scalar(&self, scalar: f32) -> Result<Variable> {
275        let scalar_tensor = Tensor::scalar(scalar)?;
276        let scalar_var = Variable::new(scalar_tensor, false);
277        self.sub(&scalar_var)
278    }
279
280    /// Multiply by a scalar
281    pub fn mul_scalar(&self, scalar: f32) -> Result<Variable> {
282        let scalar_tensor = Tensor::scalar(scalar)?;
283        let scalar_var = Variable::new(scalar_tensor, false);
284        self.mul(&scalar_var)
285    }
286
287    /// Divide by a scalar
288    pub fn div_scalar(&self, scalar: f32) -> Result<Variable> {
289        let scalar_tensor = Tensor::scalar(scalar)?;
290        let scalar_var = Variable::new(scalar_tensor, false);
291        self.div(&scalar_var)
292    }
293
294    // Helper methods for operations
295
296    /// Binary operation helper
297    fn binary_op(&self, other: &Variable, op: OperationType) -> Result<Variable> {
298        // Check if both variables are from the same graph
299        if !Arc::ptr_eq(&self.graph, &other.graph) {
300            return Err(TrustformersError::tensor_op_error(
301                "Variables must be from the same computation graph",
302                "Variable::binary_op",
303            ));
304        }
305
306        // Compute the operation on the tensor data
307        let result_tensor = self.compute_binary_tensor_op(&other.data()?, &op)?;
308
309        // Add operation node to the graph
310        let requires_grad = self.requires_grad || other.requires_grad;
311        let node_id = {
312            let mut graph = self.graph.lock().expect("lock should not be poisoned");
313            graph.add_operation_node(
314                result_tensor,
315                op,
316                vec![self.node_id, other.node_id],
317                requires_grad,
318                None,
319            )?
320        };
321
322        Ok(Variable::from_graph(
323            self.graph.clone(),
324            node_id,
325            requires_grad,
326        ))
327    }
328
329    /// Unary operation helper
330    fn unary_op(&self, op: OperationType) -> Result<Variable> {
331        // Compute the operation on the tensor data
332        let result_tensor = self.compute_unary_tensor_op(&op)?;
333
334        // Add operation node to the graph
335        let node_id = {
336            let mut graph = self.graph.lock().expect("lock should not be poisoned");
337            graph.add_operation_node(
338                result_tensor,
339                op,
340                vec![self.node_id],
341                self.requires_grad,
342                None,
343            )?
344        };
345
346        Ok(Variable::from_graph(
347            self.graph.clone(),
348            node_id,
349            self.requires_grad,
350        ))
351    }
352
353    /// Compute binary tensor operation
354    fn compute_binary_tensor_op(&self, other: &Tensor, op: &OperationType) -> Result<Tensor> {
355        let self_tensor = self.data()?;
356
357        match op {
358            OperationType::Add => Tensor::add(&self_tensor, other),
359            OperationType::Subtract => Tensor::sub(&self_tensor, other),
360            OperationType::Multiply => self_tensor.mul(other),
361            OperationType::Divide => Tensor::div(&self_tensor, other),
362            OperationType::MatrixMultiply => self_tensor.matmul(other),
363            _ => Err(TrustformersError::tensor_op_error(
364                &format!("Unsupported binary operation: {:?}", op),
365                "Variable::compute_binary_tensor_op",
366            )),
367        }
368    }
369
370    /// Compute unary tensor operation
371    fn compute_unary_tensor_op(&self, op: &OperationType) -> Result<Tensor> {
372        let self_tensor = self.data()?;
373
374        match op {
375            OperationType::Negate => self_tensor.neg(),
376            OperationType::Square => self_tensor.clone().mul(&self_tensor),
377            OperationType::Sqrt => self_tensor.sqrt(),
378            OperationType::Log => self_tensor.log(),
379            OperationType::Exp => self_tensor.exp(),
380            OperationType::Sigmoid => self_tensor.sigmoid(),
381            OperationType::Tanh => self_tensor.tanh(),
382            OperationType::ReLU => self_tensor.relu(),
383            OperationType::LeakyReLU(alpha) => self_tensor.leaky_relu(*alpha),
384            OperationType::Softmax => self_tensor.softmax(-1),
385            OperationType::Reshape(shape) => self_tensor.reshape(shape),
386            OperationType::Transpose(permutation) => {
387                // For now, handle simple 2D transpose case
388                if permutation.len() >= 2 {
389                    self_tensor.transpose(permutation[0], permutation[1])
390                } else {
391                    // Default transpose for 2D case
392                    self_tensor.transpose(0, 1)
393                }
394            },
395            OperationType::Sum(axes) => {
396                match axes {
397                    Some(axes_vec) => self_tensor.sum_axes(axes_vec),
398                    None => {
399                        // Sum all elements (global sum)
400                        let shape = self_tensor.shape();
401                        let all_axes: Vec<usize> = (0..shape.len()).collect();
402                        self_tensor.sum_axes(&all_axes)
403                    },
404                }
405            },
406            OperationType::Mean(_axes) => {
407                // For now, just compute global mean
408                self_tensor.mean()
409            },
410            _ => Err(TrustformersError::tensor_op_error(
411                &format!("Unsupported unary operation: {:?}", op),
412                "Variable::compute_unary_tensor_op",
413            )),
414        }
415    }
416
417    /// Set whether this variable requires gradients
418    pub fn set_requires_grad(&mut self, requires_grad: bool) {
419        self.requires_grad = requires_grad;
420        // Also update the node in the graph
421        if let Ok(mut graph) = self.graph.lock() {
422            if let Some(node) = graph.get_node_mut(self.node_id) {
423                node.requires_grad = requires_grad;
424            }
425        }
426    }
427
428    /// Create a variable from a tensor (with requires_grad = false by default)
429    pub fn from_tensor(tensor: Tensor) -> Self {
430        Variable::new(tensor, false)
431    }
432}
433
434/// Convenience functions for creating variables
435impl Variable {
436    /// Create a variable from a scalar
437    pub fn scalar(value: f32, requires_grad: bool) -> Result<Self> {
438        let tensor = Tensor::scalar(value)?;
439        Ok(Variable::new(tensor, requires_grad))
440    }
441
442    /// Create a variable with zeros
443    pub fn zeros(shape: &[usize], requires_grad: bool) -> Result<Self> {
444        let tensor = Tensor::zeros(shape)?;
445        Ok(Variable::new(tensor, requires_grad))
446    }
447
448    /// Create a variable with ones
449    pub fn ones(shape: &[usize], requires_grad: bool) -> Result<Self> {
450        let tensor = Tensor::ones(shape)?;
451        Ok(Variable::new(tensor, requires_grad))
452    }
453
454    /// Create a variable with random normal distribution
455    pub fn randn(shape: &[usize], requires_grad: bool) -> Result<Self> {
456        let tensor = Tensor::randn(shape)?;
457        Ok(Variable::new(tensor, requires_grad))
458    }
459
460    /// Create a variable with random uniform distribution
461    pub fn rand(shape: &[usize], requires_grad: bool) -> Result<Self> {
462        let tensor = Tensor::randn(shape)?;
463        Ok(Variable::new(tensor, requires_grad))
464    }
465}
466
467/// Operator overloading for Variables
468use std::ops::{Add, Div, Mul, Neg, Sub};
469
470impl Add for &Variable {
471    type Output = Result<Variable>;
472
473    fn add(self, rhs: Self) -> Self::Output {
474        self.add(rhs)
475    }
476}
477
478impl Sub for &Variable {
479    type Output = Result<Variable>;
480
481    fn sub(self, rhs: Self) -> Self::Output {
482        self.sub(rhs)
483    }
484}
485
486impl Mul for &Variable {
487    type Output = Result<Variable>;
488
489    fn mul(self, rhs: Self) -> Self::Output {
490        self.mul(rhs)
491    }
492}
493
494impl Div for &Variable {
495    type Output = Result<Variable>;
496
497    fn div(self, rhs: Self) -> Self::Output {
498        self.div(rhs)
499    }
500}
501
502impl Neg for &Variable {
503    type Output = Result<Variable>;
504
505    fn neg(self) -> Self::Output {
506        self.neg()
507    }
508}
509
510#[cfg(test)]
511mod tests {
512    use super::*;
513    use crate::tensor::Tensor;
514
515    #[test]
516    fn test_variable_creation() {
517        let tensor = Tensor::ones(&[2, 3]).expect("Failed to create ones tensor");
518        let var = Variable::new(tensor, true);
519
520        assert!(var.requires_grad());
521        assert_eq!(var.shape().expect("operation failed in test"), vec![2, 3]);
522    }
523
524    #[test]
525    fn test_variable_operations() {
526        use super::super::AutodiffEngine;
527        use std::sync::Arc;
528
529        let engine = Arc::new(AutodiffEngine::default());
530        let a = engine.variable(Tensor::scalar(2.0).expect("tensor operation failed"), true);
531        let b = engine.variable(Tensor::scalar(3.0).expect("tensor operation failed"), true);
532
533        let c = a.add(&b).expect("Addition failed");
534        assert_eq!(c.item().expect("operation failed in test"), 5.0);
535
536        let d = a.mul(&b).expect("Multiplication failed");
537        assert_eq!(d.item().expect("operation failed in test"), 6.0);
538    }
539
540    #[test]
541    fn test_gradient_computation() {
542        use super::super::AutodiffEngine;
543        use std::sync::Arc;
544
545        let engine = Arc::new(AutodiffEngine::default());
546        let a = engine.variable(Tensor::scalar(2.0).expect("tensor operation failed"), true);
547        let b = engine.variable(Tensor::scalar(3.0).expect("tensor operation failed"), true);
548
549        let c = a.mul(&b).expect("Multiplication failed");
550        engine.backward(&c, None).expect("operation failed in test");
551
552        let grad_a = engine
553            .get_grad(&a)
554            .expect("operation failed in test")
555            .expect("operation failed in test");
556        let grad_b = engine
557            .get_grad(&b)
558            .expect("operation failed in test")
559            .expect("operation failed in test");
560
561        assert_eq!(grad_a.to_scalar().expect("operation failed in test"), 3.0);
562        assert_eq!(grad_b.to_scalar().expect("operation failed in test"), 2.0);
563    }
564
565    #[test]
566    fn test_activation_functions() {
567        let x = Variable::scalar(0.0, true).expect("operation failed in test");
568
569        let sigmoid_x = x.sigmoid().expect("Sigmoid failed");
570        assert_eq!(sigmoid_x.item().expect("operation failed in test"), 0.5);
571
572        let tanh_x = x.tanh().expect("Tanh failed");
573        assert_eq!(tanh_x.item().expect("operation failed in test"), 0.0);
574    }
575
576    #[test]
577    fn test_tensor_operations() {
578        let x = Variable::ones(&[2, 3], true).expect("operation failed in test");
579
580        let sum_x = x.sum(None).expect("operation failed in test");
581        assert_eq!(sum_x.item().expect("operation failed in test"), 6.0);
582
583        let mean_x = x.mean(None).expect("Mean calculation failed");
584        assert_eq!(mean_x.item().expect("operation failed in test"), 1.0);
585    }
586
587    #[test]
588    fn test_reshape_operation() {
589        let x = Variable::ones(&[2, 3], true).expect("operation failed in test");
590        let reshaped = x.reshape(vec![3, 2]).expect("Reshape failed");
591
592        assert_eq!(
593            reshaped.shape().expect("operation failed in test"),
594            vec![3, 2]
595        );
596    }
597
598    #[test]
599    fn test_detach_operation() {
600        let x = Variable::scalar(2.0, true).expect("operation failed in test");
601        let y = x.detach().expect("operation failed in test");
602
603        assert!(x.requires_grad());
604        assert!(!y.requires_grad());
605    }
606}