1use super::graph::{ComputationGraph, NodeId, OperationType};
7use crate::errors::{Result, TrustformersError};
8use crate::tensor::Tensor;
9use std::sync::{Arc, Mutex};
10
/// Shared, thread-safe handle to a computation graph; cloning is cheap (Arc bump).
pub type GraphRef = Arc<Mutex<ComputationGraph>>;
13
/// A differentiable tensor handle tied to one node of a [`ComputationGraph`].
///
/// `Clone` clones the `Arc`, so clones share the same graph and node — they
/// are views of the same value, not independent copies of the data.
#[derive(Debug, Clone)]
pub struct Variable {
    // Shared, mutex-guarded graph this variable's node lives in.
    graph: GraphRef,
    // Index of this variable's node within `graph`.
    node_id: NodeId,
    // Cached gradient-tracking flag; mirrored on the graph node.
    requires_grad: bool,
}
24
/// Reference-counted handle to a [`Variable`] for shared ownership across threads.
pub type VariableRef = Arc<Variable>;
27
28impl Variable {
29 pub fn new(tensor: Tensor, requires_grad: bool) -> Self {
31 let graph = Arc::new(Mutex::new(ComputationGraph::new()));
32 let node_id = {
33 let mut graph_guard = graph.lock().expect("lock should not be poisoned");
34 graph_guard.add_node(tensor, requires_grad, None)
35 };
36
37 Self {
38 graph,
39 node_id,
40 requires_grad,
41 }
42 }
43
44 pub fn new_with_name(tensor: Tensor, requires_grad: bool, name: String) -> Self {
46 let graph = Arc::new(Mutex::new(ComputationGraph::new()));
47 let node_id = {
48 let mut graph_guard = graph.lock().expect("lock should not be poisoned");
49 graph_guard.add_node(tensor, requires_grad, Some(name))
50 };
51
52 Self {
53 graph,
54 node_id,
55 requires_grad,
56 }
57 }
58
59 pub fn from_graph(graph: GraphRef, node_id: NodeId, requires_grad: bool) -> Self {
61 Self {
62 graph,
63 node_id,
64 requires_grad,
65 }
66 }
67
68 pub fn data(&self) -> Result<Tensor> {
70 let graph = self.graph.lock().expect("lock should not be poisoned");
71 graph.get_value(self.node_id).cloned().ok_or_else(|| {
72 TrustformersError::tensor_op_error(
73 &format!("Node {} not found in graph", self.node_id),
74 "Variable::data",
75 )
76 })
77 }
78
79 pub fn grad(&self) -> Result<Option<Tensor>> {
81 let graph = self.graph.lock().expect("lock should not be poisoned");
82 Ok(graph.get_gradient(self.node_id).cloned())
83 }
84
85 pub fn node_id(&self) -> NodeId {
87 self.node_id
88 }
89
90 pub fn requires_grad(&self) -> bool {
92 self.requires_grad
93 }
94
95 pub fn graph(&self) -> GraphRef {
97 self.graph.clone()
98 }
99
100 pub fn shape(&self) -> Result<Vec<usize>> {
102 let graph = self.graph.lock().expect("lock should not be poisoned");
103 graph.get_value(self.node_id).map(|tensor| tensor.shape()).ok_or_else(|| {
104 TrustformersError::tensor_op_error(
105 &format!("Node {} not found in graph", self.node_id),
106 "Variable::shape",
107 )
108 })
109 }
110
111 pub fn item(&self) -> Result<f32> {
113 let tensor = self.data()?;
114 tensor.to_scalar()
115 }
116
117 pub fn backward(&self) -> Result<()> {
119 let mut graph = self.graph.lock().expect("lock should not be poisoned");
120 graph.backward(self.node_id, None)
121 }
122
123 pub fn backward_with_grad(&self, grad: Tensor) -> Result<()> {
125 let mut graph = self.graph.lock().expect("lock should not be poisoned");
126 graph.backward(self.node_id, Some(grad))
127 }
128
129 pub fn zero_grad(&self) {
131 let mut graph = self.graph.lock().expect("lock should not be poisoned");
132 graph.zero_grad();
133 }
134
135 pub fn detach(&self) -> Result<Variable> {
137 let tensor = self.data()?;
138 Ok(Variable::new(tensor, false))
139 }
140
141 pub fn requires_grad_(&self) -> Result<Variable> {
143 let tensor = self.data()?;
144 Ok(Variable::new(tensor, true))
145 }
146
147 pub fn set_data(&self, tensor: Tensor) -> Result<()> {
149 let mut graph = self.graph.lock().expect("lock should not be poisoned");
150 graph.update_value(self.node_id, tensor)
151 }
152
153 pub fn add(&self, other: &Variable) -> Result<Variable> {
157 self.binary_op(other, OperationType::Add)
158 }
159
160 pub fn sub(&self, other: &Variable) -> Result<Variable> {
162 self.binary_op(other, OperationType::Subtract)
163 }
164
165 pub fn mul(&self, other: &Variable) -> Result<Variable> {
167 self.binary_op(other, OperationType::Multiply)
168 }
169
170 pub fn div(&self, other: &Variable) -> Result<Variable> {
172 self.binary_op(other, OperationType::Divide)
173 }
174
175 pub fn matmul(&self, other: &Variable) -> Result<Variable> {
177 self.binary_op(other, OperationType::MatrixMultiply)
178 }
179
180 pub fn neg(&self) -> Result<Variable> {
182 self.unary_op(OperationType::Negate)
183 }
184
185 pub fn square(&self) -> Result<Variable> {
187 self.unary_op(OperationType::Square)
188 }
189
190 pub fn sqrt(&self) -> Result<Variable> {
192 self.unary_op(OperationType::Sqrt)
193 }
194
195 pub fn log(&self) -> Result<Variable> {
197 self.unary_op(OperationType::Log)
198 }
199
200 pub fn exp(&self) -> Result<Variable> {
202 self.unary_op(OperationType::Exp)
203 }
204
205 pub fn sigmoid(&self) -> Result<Variable> {
209 self.unary_op(OperationType::Sigmoid)
210 }
211
212 pub fn tanh(&self) -> Result<Variable> {
214 self.unary_op(OperationType::Tanh)
215 }
216
217 pub fn relu(&self) -> Result<Variable> {
219 self.unary_op(OperationType::ReLU)
220 }
221
222 pub fn leaky_relu(&self, alpha: f32) -> Result<Variable> {
224 self.unary_op(OperationType::LeakyReLU(alpha))
225 }
226
227 pub fn softmax(&self) -> Result<Variable> {
229 self.unary_op(OperationType::Softmax)
230 }
231
232 pub fn reshape(&self, shape: Vec<usize>) -> Result<Variable> {
236 self.unary_op(OperationType::Reshape(shape))
237 }
238
239 pub fn transpose(&self, permutation: Vec<usize>) -> Result<Variable> {
241 self.unary_op(OperationType::Transpose(permutation))
242 }
243
244 pub fn sum(&self, axes: Option<Vec<usize>>) -> Result<Variable> {
246 self.unary_op(OperationType::Sum(axes))
247 }
248
249 pub fn mean(&self, axes: Option<Vec<usize>>) -> Result<Variable> {
251 self.unary_op(OperationType::Mean(axes))
252 }
253
254 pub fn max(&self, axes: Option<Vec<usize>>) -> Result<Variable> {
256 self.unary_op(OperationType::Max(axes))
257 }
258
259 pub fn min(&self, axes: Option<Vec<usize>>) -> Result<Variable> {
261 self.unary_op(OperationType::Min(axes))
262 }
263
264 pub fn add_scalar(&self, scalar: f32) -> Result<Variable> {
268 let scalar_tensor = Tensor::scalar(scalar)?;
269 let scalar_var = Variable::new(scalar_tensor, false);
270 self.add(&scalar_var)
271 }
272
273 pub fn sub_scalar(&self, scalar: f32) -> Result<Variable> {
275 let scalar_tensor = Tensor::scalar(scalar)?;
276 let scalar_var = Variable::new(scalar_tensor, false);
277 self.sub(&scalar_var)
278 }
279
280 pub fn mul_scalar(&self, scalar: f32) -> Result<Variable> {
282 let scalar_tensor = Tensor::scalar(scalar)?;
283 let scalar_var = Variable::new(scalar_tensor, false);
284 self.mul(&scalar_var)
285 }
286
287 pub fn div_scalar(&self, scalar: f32) -> Result<Variable> {
289 let scalar_tensor = Tensor::scalar(scalar)?;
290 let scalar_var = Variable::new(scalar_tensor, false);
291 self.div(&scalar_var)
292 }
293
294 fn binary_op(&self, other: &Variable, op: OperationType) -> Result<Variable> {
298 if !Arc::ptr_eq(&self.graph, &other.graph) {
300 return Err(TrustformersError::tensor_op_error(
301 "Variables must be from the same computation graph",
302 "Variable::binary_op",
303 ));
304 }
305
306 let result_tensor = self.compute_binary_tensor_op(&other.data()?, &op)?;
308
309 let requires_grad = self.requires_grad || other.requires_grad;
311 let node_id = {
312 let mut graph = self.graph.lock().expect("lock should not be poisoned");
313 graph.add_operation_node(
314 result_tensor,
315 op,
316 vec![self.node_id, other.node_id],
317 requires_grad,
318 None,
319 )?
320 };
321
322 Ok(Variable::from_graph(
323 self.graph.clone(),
324 node_id,
325 requires_grad,
326 ))
327 }
328
329 fn unary_op(&self, op: OperationType) -> Result<Variable> {
331 let result_tensor = self.compute_unary_tensor_op(&op)?;
333
334 let node_id = {
336 let mut graph = self.graph.lock().expect("lock should not be poisoned");
337 graph.add_operation_node(
338 result_tensor,
339 op,
340 vec![self.node_id],
341 self.requires_grad,
342 None,
343 )?
344 };
345
346 Ok(Variable::from_graph(
347 self.graph.clone(),
348 node_id,
349 self.requires_grad,
350 ))
351 }
352
353 fn compute_binary_tensor_op(&self, other: &Tensor, op: &OperationType) -> Result<Tensor> {
355 let self_tensor = self.data()?;
356
357 match op {
358 OperationType::Add => Tensor::add(&self_tensor, other),
359 OperationType::Subtract => Tensor::sub(&self_tensor, other),
360 OperationType::Multiply => self_tensor.mul(other),
361 OperationType::Divide => Tensor::div(&self_tensor, other),
362 OperationType::MatrixMultiply => self_tensor.matmul(other),
363 _ => Err(TrustformersError::tensor_op_error(
364 &format!("Unsupported binary operation: {:?}", op),
365 "Variable::compute_binary_tensor_op",
366 )),
367 }
368 }
369
370 fn compute_unary_tensor_op(&self, op: &OperationType) -> Result<Tensor> {
372 let self_tensor = self.data()?;
373
374 match op {
375 OperationType::Negate => self_tensor.neg(),
376 OperationType::Square => self_tensor.clone().mul(&self_tensor),
377 OperationType::Sqrt => self_tensor.sqrt(),
378 OperationType::Log => self_tensor.log(),
379 OperationType::Exp => self_tensor.exp(),
380 OperationType::Sigmoid => self_tensor.sigmoid(),
381 OperationType::Tanh => self_tensor.tanh(),
382 OperationType::ReLU => self_tensor.relu(),
383 OperationType::LeakyReLU(alpha) => self_tensor.leaky_relu(*alpha),
384 OperationType::Softmax => self_tensor.softmax(-1),
385 OperationType::Reshape(shape) => self_tensor.reshape(shape),
386 OperationType::Transpose(permutation) => {
387 if permutation.len() >= 2 {
389 self_tensor.transpose(permutation[0], permutation[1])
390 } else {
391 self_tensor.transpose(0, 1)
393 }
394 },
395 OperationType::Sum(axes) => {
396 match axes {
397 Some(axes_vec) => self_tensor.sum_axes(axes_vec),
398 None => {
399 let shape = self_tensor.shape();
401 let all_axes: Vec<usize> = (0..shape.len()).collect();
402 self_tensor.sum_axes(&all_axes)
403 },
404 }
405 },
406 OperationType::Mean(_axes) => {
407 self_tensor.mean()
409 },
410 _ => Err(TrustformersError::tensor_op_error(
411 &format!("Unsupported unary operation: {:?}", op),
412 "Variable::compute_unary_tensor_op",
413 )),
414 }
415 }
416
417 pub fn set_requires_grad(&mut self, requires_grad: bool) {
419 self.requires_grad = requires_grad;
420 if let Ok(mut graph) = self.graph.lock() {
422 if let Some(node) = graph.get_node_mut(self.node_id) {
423 node.requires_grad = requires_grad;
424 }
425 }
426 }
427
428 pub fn from_tensor(tensor: Tensor) -> Self {
430 Variable::new(tensor, false)
431 }
432}
433
// Convenience constructors mirroring common tensor factory functions.
impl Variable {
    /// Creates a 0-dimensional (scalar) variable.
    pub fn scalar(value: f32, requires_grad: bool) -> Result<Self> {
        let tensor = Tensor::scalar(value)?;
        Ok(Variable::new(tensor, requires_grad))
    }

    /// Creates a variable filled with zeros of the given shape.
    pub fn zeros(shape: &[usize], requires_grad: bool) -> Result<Self> {
        let tensor = Tensor::zeros(shape)?;
        Ok(Variable::new(tensor, requires_grad))
    }

    /// Creates a variable filled with ones of the given shape.
    pub fn ones(shape: &[usize], requires_grad: bool) -> Result<Self> {
        let tensor = Tensor::ones(shape)?;
        Ok(Variable::new(tensor, requires_grad))
    }

    /// Creates a variable of normally-distributed random values.
    pub fn randn(shape: &[usize], requires_grad: bool) -> Result<Self> {
        let tensor = Tensor::randn(shape)?;
        Ok(Variable::new(tensor, requires_grad))
    }

    /// Creates a variable of random values.
    ///
    /// NOTE(review): this delegates to `Tensor::randn` (normal distribution),
    /// making it identical to `randn` above — likely a copy-paste bug. If
    /// `Tensor::rand` (uniform) exists, this should call it instead; confirm
    /// and fix.
    pub fn rand(shape: &[usize], requires_grad: bool) -> Result<Self> {
        let tensor = Tensor::randn(shape)?;
        Ok(Variable::new(tensor, requires_grad))
    }
}
466
467use std::ops::{Add, Div, Mul, Neg, Sub};
469
470impl Add for &Variable {
471 type Output = Result<Variable>;
472
473 fn add(self, rhs: Self) -> Self::Output {
474 self.add(rhs)
475 }
476}
477
478impl Sub for &Variable {
479 type Output = Result<Variable>;
480
481 fn sub(self, rhs: Self) -> Self::Output {
482 self.sub(rhs)
483 }
484}
485
486impl Mul for &Variable {
487 type Output = Result<Variable>;
488
489 fn mul(self, rhs: Self) -> Self::Output {
490 self.mul(rhs)
491 }
492}
493
494impl Div for &Variable {
495 type Output = Result<Variable>;
496
497 fn div(self, rhs: Self) -> Self::Output {
498 self.div(rhs)
499 }
500}
501
502impl Neg for &Variable {
503 type Output = Result<Variable>;
504
505 fn neg(self) -> Self::Output {
506 self.neg()
507 }
508}
509
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::Tensor;

    // A leaf variable records its shape and requires_grad flag.
    #[test]
    fn test_variable_creation() {
        let tensor = Tensor::ones(&[2, 3]).expect("Failed to create ones tensor");
        let var = Variable::new(tensor, true);

        assert!(var.requires_grad());
        assert_eq!(var.shape().expect("operation failed in test"), vec![2, 3]);
    }

    // Binary ops work when both variables come from the same engine/graph.
    #[test]
    fn test_variable_operations() {
        use super::super::AutodiffEngine;
        use std::sync::Arc;

        let engine = Arc::new(AutodiffEngine::default());
        let a = engine.variable(Tensor::scalar(2.0).expect("tensor operation failed"), true);
        let b = engine.variable(Tensor::scalar(3.0).expect("tensor operation failed"), true);

        let c = a.add(&b).expect("Addition failed");
        assert_eq!(c.item().expect("operation failed in test"), 5.0);

        let d = a.mul(&b).expect("Multiplication failed");
        assert_eq!(d.item().expect("operation failed in test"), 6.0);
    }

    // d(a*b)/da = b and d(a*b)/db = a for scalar inputs.
    #[test]
    fn test_gradient_computation() {
        use super::super::AutodiffEngine;
        use std::sync::Arc;

        let engine = Arc::new(AutodiffEngine::default());
        let a = engine.variable(Tensor::scalar(2.0).expect("tensor operation failed"), true);
        let b = engine.variable(Tensor::scalar(3.0).expect("tensor operation failed"), true);

        let c = a.mul(&b).expect("Multiplication failed");
        engine.backward(&c, None).expect("operation failed in test");

        let grad_a = engine
            .get_grad(&a)
            .expect("operation failed in test")
            .expect("operation failed in test");
        let grad_b = engine
            .get_grad(&b)
            .expect("operation failed in test")
            .expect("operation failed in test");

        assert_eq!(grad_a.to_scalar().expect("operation failed in test"), 3.0);
        assert_eq!(grad_b.to_scalar().expect("operation failed in test"), 2.0);
    }

    // sigmoid(0) = 0.5 and tanh(0) = 0 are exact in f32.
    #[test]
    fn test_activation_functions() {
        let x = Variable::scalar(0.0, true).expect("operation failed in test");

        let sigmoid_x = x.sigmoid().expect("Sigmoid failed");
        assert_eq!(sigmoid_x.item().expect("operation failed in test"), 0.5);

        let tanh_x = x.tanh().expect("Tanh failed");
        assert_eq!(tanh_x.item().expect("operation failed in test"), 0.0);
    }

    // Full-tensor sum and mean of a ones(2, 3) tensor.
    #[test]
    fn test_tensor_operations() {
        let x = Variable::ones(&[2, 3], true).expect("operation failed in test");

        let sum_x = x.sum(None).expect("operation failed in test");
        assert_eq!(sum_x.item().expect("operation failed in test"), 6.0);

        let mean_x = x.mean(None).expect("Mean calculation failed");
        assert_eq!(mean_x.item().expect("operation failed in test"), 1.0);
    }

    // Reshape preserves element count and reports the new shape.
    #[test]
    fn test_reshape_operation() {
        let x = Variable::ones(&[2, 3], true).expect("operation failed in test");
        let reshaped = x.reshape(vec![3, 2]).expect("Reshape failed");

        assert_eq!(
            reshaped.shape().expect("operation failed in test"),
            vec![3, 2]
        );
    }

    // detach() returns a non-tracking copy and leaves the original tracking.
    #[test]
    fn test_detach_operation() {
        let x = Variable::scalar(2.0, true).expect("operation failed in test");
        let y = x.detach().expect("operation failed in test");

        assert!(x.requires_grad());
        assert!(!y.requires_grad());
    }
}