// torsh_functional/autograd.rs

//! Custom autograd function creation utilities
//!
//! This module provides utilities for creating custom differentiable operations
//! that can be used with the automatic differentiation system.

use std::collections::HashMap;
use std::sync::{Arc, Mutex, OnceLock};

use torsh_core::{Result as TorshResult, TorshError};
use torsh_tensor::Tensor;

/// Trait for custom autograd functions
///
/// Implementors define a differentiable operation: `forward` computes the
/// outputs from the inputs, and `backward` maps output gradients back to
/// input gradients.
pub trait CustomAutogradFunction {
    /// Forward pass computation
    ///
    /// Takes the operation's inputs and returns its outputs.
    fn forward(&self, inputs: &[Tensor]) -> TorshResult<Vec<Tensor>>;

    /// Backward pass computation
    ///
    /// `grad_outputs` holds one gradient per forward output; `inputs` are the
    /// original forward inputs. Returns one `Option<Tensor>` per input
    /// (`None` when no gradient flows to that input).
    fn backward(
        &self,
        grad_outputs: &[Tensor],
        inputs: &[Tensor],
    ) -> TorshResult<Vec<Option<Tensor>>>;

    /// Number of inputs expected by this function
    fn num_inputs(&self) -> usize;

    /// Number of outputs produced by this function
    fn num_outputs(&self) -> usize;

    /// Get the name of this function (for debugging)
    fn name(&self) -> &str;
}

/// Context for storing intermediate values during forward pass
///
/// A scratch store that a custom function's `forward` can write into and its
/// `backward` can later read from (tensors, scalar values, shapes).
#[derive(Debug, Clone)]
pub struct AutogradContext {
    /// Saved tensors for backward pass
    pub saved_tensors: Vec<Tensor>,
    /// Saved scalar values for backward pass, keyed by name
    pub saved_values: HashMap<String, f32>,
    /// Saved shapes for backward pass, keyed by name
    pub saved_shapes: HashMap<String, Vec<usize>>,
    /// Per-input flags: whether a gradient is required for each input
    pub needs_input_grad: Vec<bool>,
}

46impl AutogradContext {
47    /// Create a new autograd context
48    pub fn new(num_inputs: usize) -> Self {
49        Self {
50            saved_tensors: Vec::new(),
51            saved_values: HashMap::new(),
52            saved_shapes: HashMap::new(),
53            needs_input_grad: vec![true; num_inputs],
54        }
55    }
56
57    /// Save a tensor for backward pass
58    pub fn save_tensor(&mut self, tensor: Tensor) {
59        self.saved_tensors.push(tensor);
60    }
61
62    /// Save a scalar value for backward pass
63    pub fn save_value(&mut self, key: &str, value: f32) {
64        self.saved_values.insert(key.to_string(), value);
65    }
66
67    /// Save a shape for backward pass
68    pub fn save_shape(&mut self, key: &str, shape: Vec<usize>) {
69        self.saved_shapes.insert(key.to_string(), shape);
70    }
71
72    /// Get saved tensor by index
73    pub fn get_saved_tensor(&self, index: usize) -> Option<&Tensor> {
74        self.saved_tensors.get(index)
75    }
76
77    /// Get saved value by key
78    pub fn get_saved_value(&self, key: &str) -> Option<f32> {
79        self.saved_values.get(key).copied()
80    }
81
82    /// Get saved shape by key
83    pub fn get_saved_shape(&self, key: &str) -> Option<&Vec<usize>> {
84        self.saved_shapes.get(key)
85    }
86
87    /// Set whether input gradients are needed
88    pub fn set_needs_input_grad(&mut self, needs_grad: Vec<bool>) {
89        self.needs_input_grad = needs_grad;
90    }
91
92    /// Check if input gradient is needed
93    pub fn needs_input_grad(&self, index: usize) -> bool {
94        self.needs_input_grad.get(index).copied().unwrap_or(false)
95    }
96}
97
/// Base trait for creating custom autograd functions with context
///
/// Like `CustomAutogradFunction`, but `forward` receives a mutable
/// [`AutogradContext`] to stash tensors/values/shapes that `backward`
/// later reads instead of re-deriving them from the inputs.
pub trait CustomAutogradFunctionWithContext {
    /// Forward pass with context
    ///
    /// May save intermediate data into `ctx` for use during `backward`.
    fn forward(&self, ctx: &mut AutogradContext, inputs: &[Tensor]) -> TorshResult<Vec<Tensor>>;

    /// Backward pass with context
    ///
    /// `grad_outputs` holds one gradient per forward output; returns one
    /// `Option<Tensor>` per input (`None` when no gradient is needed).
    fn backward(
        &self,
        ctx: &AutogradContext,
        grad_outputs: &[Tensor],
    ) -> TorshResult<Vec<Option<Tensor>>>;

    /// Number of inputs expected by this function
    fn num_inputs(&self) -> usize;

    /// Number of outputs produced by this function
    fn num_outputs(&self) -> usize;

    /// Get the name of this function (for debugging)
    fn name(&self) -> &str;
}

/// Registry for custom autograd functions
///
/// Maps function names to shared, thread-safe trait objects so registered
/// functions can be looked up and applied by name.
pub struct AutogradRegistry {
    // Arc so `get` can hand out a cheap clone without borrowing the map.
    functions: HashMap<String, Arc<dyn CustomAutogradFunction + Send + Sync>>,
}

125impl AutogradRegistry {
126    /// Create a new registry
127    pub fn new() -> Self {
128        Self {
129            functions: HashMap::new(),
130        }
131    }
132
133    /// Register a custom function
134    pub fn register<F>(&mut self, name: String, function: F)
135    where
136        F: CustomAutogradFunction + Send + Sync + 'static,
137    {
138        self.functions.insert(name, Arc::new(function));
139    }
140
141    /// Get a registered function by name
142    pub fn get(&self, name: &str) -> Option<Arc<dyn CustomAutogradFunction + Send + Sync>> {
143        self.functions.get(name).cloned()
144    }
145
146    /// List all registered functions
147    pub fn list_functions(&self) -> Vec<&String> {
148        self.functions.keys().collect()
149    }
150}
151
152impl Default for AutogradRegistry {
153    fn default() -> Self {
154        Self::new()
155    }
156}
157
158/// Apply a custom autograd function
159pub fn apply_custom_function<F>(function: F, inputs: &[Tensor]) -> TorshResult<Vec<Tensor>>
160where
161    F: CustomAutogradFunction,
162{
163    // Validate inputs
164    if inputs.len() != function.num_inputs() {
165        return Err(TorshError::invalid_argument_with_context(
166            &format!(
167                "Expected {} inputs, got {}",
168                function.num_inputs(),
169                inputs.len()
170            ),
171            "apply_custom_function",
172        ));
173    }
174
175    // Apply forward pass
176    let outputs = function.forward(inputs)?;
177
178    // Validate outputs
179    if outputs.len() != function.num_outputs() {
180        return Err(TorshError::invalid_argument_with_context(
181            &format!(
182                "Expected {} outputs, got {}",
183                function.num_outputs(),
184                outputs.len()
185            ),
186            "apply_custom_function",
187        ));
188    }
189
190    Ok(outputs)
191}
192
193/// Apply a custom autograd function with context
194pub fn apply_custom_function_with_context<F>(
195    function: F,
196    inputs: &[Tensor],
197) -> TorshResult<Vec<Tensor>>
198where
199    F: CustomAutogradFunctionWithContext,
200{
201    // Validate inputs
202    if inputs.len() != function.num_inputs() {
203        return Err(TorshError::invalid_argument_with_context(
204            &format!(
205                "Expected {} inputs, got {}",
206                function.num_inputs(),
207                inputs.len()
208            ),
209            "apply_custom_function_with_context",
210        ));
211    }
212
213    // Create context
214    let mut ctx = AutogradContext::new(inputs.len());
215
216    // Apply forward pass
217    let outputs = function.forward(&mut ctx, inputs)?;
218
219    // Validate outputs
220    if outputs.len() != function.num_outputs() {
221        return Err(TorshError::invalid_argument_with_context(
222            &format!(
223                "Expected {} outputs, got {}",
224                function.num_outputs(),
225                outputs.len()
226            ),
227            "apply_custom_function_with_context",
228        ));
229    }
230
231    Ok(outputs)
232}
233
234/// Example custom function: Element-wise square
235pub struct SquareFunction;
236
237impl CustomAutogradFunction for SquareFunction {
238    fn forward(&self, inputs: &[Tensor]) -> TorshResult<Vec<Tensor>> {
239        let input = &inputs[0];
240        let output = input.mul_op(input)?;
241        Ok(vec![output])
242    }
243
244    fn backward(
245        &self,
246        grad_outputs: &[Tensor],
247        inputs: &[Tensor],
248    ) -> TorshResult<Vec<Option<Tensor>>> {
249        let grad_output = &grad_outputs[0];
250        let input = &inputs[0];
251
252        // d/dx(x^2) = 2x
253        let two = Tensor::from_data(vec![2.0f32], vec![1], input.device())?;
254        let grad_input = grad_output.mul_op(&input.mul_op(&two)?)?;
255
256        Ok(vec![Some(grad_input)])
257    }
258
259    fn num_inputs(&self) -> usize {
260        1
261    }
262    fn num_outputs(&self) -> usize {
263        1
264    }
265    fn name(&self) -> &str {
266        "square"
267    }
268}
269
270/// Example custom function: Element-wise exponential
271pub struct ExpFunction;
272
273impl CustomAutogradFunction for ExpFunction {
274    fn forward(&self, inputs: &[Tensor]) -> TorshResult<Vec<Tensor>> {
275        let input = &inputs[0];
276        let output = input.exp()?;
277        Ok(vec![output])
278    }
279
280    fn backward(
281        &self,
282        grad_outputs: &[Tensor],
283        inputs: &[Tensor],
284    ) -> TorshResult<Vec<Option<Tensor>>> {
285        let grad_output = &grad_outputs[0];
286        let input = &inputs[0];
287
288        // d/dx(exp(x)) = exp(x)
289        let exp_input = input.exp()?;
290        let grad_input = grad_output.mul_op(&exp_input)?;
291
292        Ok(vec![Some(grad_input)])
293    }
294
295    fn num_inputs(&self) -> usize {
296        1
297    }
298    fn num_outputs(&self) -> usize {
299        1
300    }
301    fn name(&self) -> &str {
302        "exp"
303    }
304}
305
/// Example custom function with context: Scaled addition
///
/// Computes `scale * a + b` for two input tensors; the scale factor is
/// fixed at construction time.
pub struct ScaledAddFunction {
    /// Multiplier applied to the first input during `forward`.
    scale: f32,
}

impl ScaledAddFunction {
    /// Create a scaled-add function with the given scale factor.
    pub fn new(scale: f32) -> Self {
        Self { scale }
    }
}

317impl CustomAutogradFunctionWithContext for ScaledAddFunction {
318    fn forward(&self, ctx: &mut AutogradContext, inputs: &[Tensor]) -> TorshResult<Vec<Tensor>> {
319        let a = &inputs[0];
320        let b = &inputs[1];
321
322        // Save scale for backward pass
323        ctx.save_value("scale", self.scale);
324
325        // Compute scale * a + b
326        let scaled_a = a.mul_scalar(self.scale)?;
327        let output = scaled_a.add_op(b)?;
328
329        Ok(vec![output])
330    }
331
332    fn backward(
333        &self,
334        ctx: &AutogradContext,
335        grad_outputs: &[Tensor],
336    ) -> TorshResult<Vec<Option<Tensor>>> {
337        let grad_output = &grad_outputs[0];
338        let scale = ctx.get_saved_value("scale").unwrap_or(1.0);
339
340        // Gradients: d/da = scale, d/db = 1
341        let grad_a = if ctx.needs_input_grad(0) {
342            Some(grad_output.mul_scalar(scale)?)
343        } else {
344            None
345        };
346
347        let grad_b = if ctx.needs_input_grad(1) {
348            Some(grad_output.clone())
349        } else {
350            None
351        };
352
353        Ok(vec![grad_a, grad_b])
354    }
355
356    fn num_inputs(&self) -> usize {
357        2
358    }
359    fn num_outputs(&self) -> usize {
360        1
361    }
362    fn name(&self) -> &str {
363        "scaled_add"
364    }
365}
366
/// Macro for creating simple custom autograd functions
///
/// Expands to a unit struct `$name` plus a `CustomAutogradFunction` impl
/// whose `forward`/`backward` bodies are the supplied expressions (written
/// with closure-like syntax so the binder identifiers become the method
/// parameter names). `name()` reports the struct identifier via `stringify!`.
#[macro_export]
macro_rules! create_custom_autograd_function {
    (
        name: $name:ident,
        inputs: $num_inputs:expr,
        outputs: $num_outputs:expr,
        forward: |$inputs:ident| $forward_body:expr,
        backward: |$grad_outputs:ident, $backward_inputs:ident| $backward_body:expr
    ) => {
        pub struct $name;

        impl CustomAutogradFunction for $name {
            fn forward(&self, $inputs: &[Tensor]) -> TorshResult<Vec<Tensor>> {
                $forward_body
            }

            fn backward(
                &self,
                $grad_outputs: &[Tensor],
                $backward_inputs: &[Tensor],
            ) -> TorshResult<Vec<Option<Tensor>>> {
                $backward_body
            }

            fn num_inputs(&self) -> usize {
                $num_inputs
            }
            fn num_outputs(&self) -> usize {
                $num_outputs
            }
            fn name(&self) -> &str {
                stringify!($name)
            }
        }
    };
}

405/// Create a global registry for custom functions
406static GLOBAL_REGISTRY: OnceLock<Mutex<AutogradRegistry>> = OnceLock::new();
407
408/// Get the global autograd registry
409pub fn get_global_registry() -> &'static Mutex<AutogradRegistry> {
410    GLOBAL_REGISTRY.get_or_init(|| Mutex::new(AutogradRegistry::new()))
411}
412
413/// Register a custom function globally
414pub fn register_custom_function<F>(name: String, function: F)
415where
416    F: CustomAutogradFunction + Send + Sync + 'static,
417{
418    get_global_registry()
419        .lock()
420        .expect("autograd registry lock should not be poisoned")
421        .register(name, function);
422}
423
424/// Apply a globally registered function
425pub fn apply_registered_function(name: &str, inputs: &[Tensor]) -> TorshResult<Vec<Tensor>> {
426    let registry = get_global_registry()
427        .lock()
428        .expect("lock should not be poisoned");
429    let function = registry.get(name).ok_or_else(|| {
430        TorshError::invalid_argument_with_context(
431            &format!("Function '{}' not found in registry", name),
432            "apply_registered_function",
433        )
434    })?;
435
436    // Validate inputs
437    if inputs.len() != function.num_inputs() {
438        return Err(TorshError::invalid_argument_with_context(
439            &format!(
440                "Expected {} inputs, got {}",
441                function.num_inputs(),
442                inputs.len()
443            ),
444            "apply_registered_function",
445        ));
446    }
447
448    // Apply forward pass
449    let outputs = function.forward(inputs)?;
450
451    // Validate outputs
452    if outputs.len() != function.num_outputs() {
453        return Err(TorshError::invalid_argument_with_context(
454            &format!(
455                "Expected {} outputs, got {}",
456                function.num_outputs(),
457                outputs.len()
458            ),
459            "apply_registered_function",
460        ));
461    }
462
463    Ok(outputs)
464}
465
#[cfg(test)]
mod tests {
    use super::*;

    // Forward of SquareFunction: each element squared.
    #[test]
    fn test_square_function() -> TorshResult<()> {
        let input = Tensor::from_data(vec![2.0, 3.0, 4.0], vec![3], torsh_core::DeviceType::Cpu)?;
        let square_fn = SquareFunction;

        let outputs = apply_custom_function(square_fn, &[input.clone()])?;
        let output_data = outputs[0].to_vec()?;

        assert!((output_data[0] - 4.0).abs() < 1e-6);
        assert!((output_data[1] - 9.0).abs() < 1e-6);
        assert!((output_data[2] - 16.0).abs() < 1e-6);

        Ok(())
    }

    // Forward of ExpFunction: exp(0) = 1, exp(1) = e.
    #[test]
    fn test_exp_function() -> TorshResult<()> {
        let input = Tensor::from_data(vec![0.0, 1.0], vec![2], torsh_core::DeviceType::Cpu)?;
        let exp_fn = ExpFunction;

        let outputs = apply_custom_function(exp_fn, &[input.clone()])?;
        let output_data = outputs[0].to_vec()?;

        assert!((output_data[0] - 1.0).abs() < 1e-6);
        assert!((output_data[1] - std::f32::consts::E).abs() < 1e-6);

        Ok(())
    }

    // Context-based forward of ScaledAddFunction: scale * a + b.
    #[test]
    fn test_scaled_add_function() -> TorshResult<()> {
        let a = Tensor::from_data(vec![1.0, 2.0], vec![2], torsh_core::DeviceType::Cpu)?;
        let b = Tensor::from_data(vec![3.0, 4.0], vec![2], torsh_core::DeviceType::Cpu)?;
        let scaled_add_fn = ScaledAddFunction::new(2.0);

        let outputs = apply_custom_function_with_context(scaled_add_fn, &[a, b])?;
        let output_data = outputs[0].to_vec()?;

        // 2 * 1 + 3 = 5, 2 * 2 + 4 = 8
        assert!((output_data[0] - 5.0).abs() < 1e-6);
        assert!((output_data[1] - 8.0).abs() < 1e-6);

        Ok(())
    }

    // Register/lookup round-trip on a local registry.
    #[test]
    fn test_registry() -> TorshResult<()> {
        let mut registry = AutogradRegistry::new();
        registry.register("square".to_string(), SquareFunction);

        let function = registry.get("square").unwrap();
        assert_eq!(function.name(), "square");
        assert_eq!(function.num_inputs(), 1);
        assert_eq!(function.num_outputs(), 1);

        Ok(())
    }

    // End-to-end through the process-wide registry.
    // NOTE: uses a name distinct from other tests since the global
    // registry is shared across the test binary.
    #[test]
    fn test_global_registry() -> TorshResult<()> {
        register_custom_function("test_square".to_string(), SquareFunction);

        let input = Tensor::from_data(vec![3.0], vec![1], torsh_core::DeviceType::Cpu)?;
        let outputs = apply_registered_function("test_square", &[input])?;
        let output_data = outputs[0].to_vec()?;

        assert!((output_data[0] - 9.0).abs() < 1e-6);

        Ok(())
    }
}