use std::collections::HashMap;
use std::sync::{Arc, Mutex, OnceLock};
use torsh_core::{Result as TorshResult, TorshError};
use torsh_tensor::Tensor;

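/// A user-defined differentiable operation with explicit forward and backward
/// passes. The backward pass receives the original inputs and recomputes
/// whatever it needs; see [`CustomAutogradFunctionWithContext`] for a variant
/// that can save state between the two passes.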
pub trait CustomAutogradFunction {
    /// Computes the outputs of this operation from `inputs`.
    fn forward(&self, inputs: &[Tensor]) -> TorshResult<Vec<Tensor>>;

    /// Computes input gradients from `grad_outputs` and the original `inputs`.
    /// Returns `None` for inputs that do not require a gradient.
    fn backward(
        &self,
        grad_outputs: &[Tensor],
        inputs: &[Tensor],
    ) -> TorshResult<Vec<Option<Tensor>>>;

    /// Number of input tensors this function expects.
    fn num_inputs(&self) -> usize;

    /// Number of output tensors this function produces.
    fn num_outputs(&self) -> usize;

    /// Human-readable name used for registration and debugging.
    fn name(&self) -> &str;
}

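/// State carried from the forward pass to the backward pass of a
/// [`CustomAutogradFunctionWithContext`]: saved tensors, named scalar values
/// and shapes, plus per-input gradient requirements.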
#[derive(Debug, Clone)]
pub struct AutogradContext {
    /// Tensors saved during `forward` for reuse in `backward`.
    pub saved_tensors: Vec<Tensor>,
    /// Named scalar values saved during `forward`.
    pub saved_values: HashMap<String, f32>,
    /// Named shapes saved during `forward`.
    pub saved_shapes: HashMap<String, Vec<usize>>,
    /// Whether each input requires a gradient.
    pub needs_input_grad: Vec<bool>,
}

impl AutogradContext {
    /// Creates a context for `num_inputs` inputs, all assumed to need gradients.
    pub fn new(num_inputs: usize) -> Self {
        Self {
            saved_tensors: Vec::new(),
            saved_values: HashMap::new(),
            saved_shapes: HashMap::new(),
            needs_input_grad: vec![true; num_inputs],
        }
    }

    /// Saves a tensor for the backward pass.
    pub fn save_tensor(&mut self, tensor: Tensor) {
        self.saved_tensors.push(tensor);
    }

    /// Saves a named scalar value for the backward pass.
    pub fn save_value(&mut self, key: &str, value: f32) {
        self.saved_values.insert(key.to_string(), value);
    }

    /// Saves a named shape for the backward pass.
    pub fn save_shape(&mut self, key: &str, shape: Vec<usize>) {
        self.saved_shapes.insert(key.to_string(), shape);
    }

    /// Returns the saved tensor at `index`, if any.
    pub fn get_saved_tensor(&self, index: usize) -> Option<&Tensor> {
        self.saved_tensors.get(index)
    }

    /// Returns the saved scalar value under `key`, if any.
    pub fn get_saved_value(&self, key: &str) -> Option<f32> {
        self.saved_values.get(key).copied()
    }

    /// Returns the saved shape under `key`, if any.
    pub fn get_saved_shape(&self, key: &str) -> Option<&Vec<usize>> {
        self.saved_shapes.get(key)
    }

    /// Overrides which inputs require gradients.
    pub fn set_needs_input_grad(&mut self, needs_grad: Vec<bool>) {
        self.needs_input_grad = needs_grad;
    }

    /// Whether the input at `index` requires a gradient; `false` if out of range.
    pub fn needs_input_grad(&self, index: usize) -> bool {
        self.needs_input_grad.get(index).copied().unwrap_or(false)
    }
}

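/// A user-defined differentiable operation that communicates between its
/// forward and backward passes through an [`AutogradContext`] instead of
/// receiving the original inputs again in `backward`.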
pub trait CustomAutogradFunctionWithContext {
    /// Computes outputs from `inputs`, saving any state needed for `backward`
    /// into `ctx`.
    fn forward(&self, ctx: &mut AutogradContext, inputs: &[Tensor]) -> TorshResult<Vec<Tensor>>;

    /// Computes input gradients from `grad_outputs` and the state saved in `ctx`.
    fn backward(
        &self,
        ctx: &AutogradContext,
        grad_outputs: &[Tensor],
    ) -> TorshResult<Vec<Option<Tensor>>>;

    /// Number of input tensors this function expects.
    fn num_inputs(&self) -> usize;

    /// Number of output tensors this function produces.
    fn num_outputs(&self) -> usize;

    /// Human-readable name used for registration and debugging.
    fn name(&self) -> &str;
}

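/// A name-to-implementation table of [`CustomAutogradFunction`]s, so custom
/// ops can be looked up and applied by string name.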
pub struct AutogradRegistry {
    functions: HashMap<String, Arc<dyn CustomAutogradFunction + Send + Sync>>,
}

impl AutogradRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            functions: HashMap::new(),
        }
    }

    /// Registers `function` under `name`, replacing any previous entry.
    pub fn register<F>(&mut self, name: String, function: F)
    where
        F: CustomAutogradFunction + Send + Sync + 'static,
    {
        self.functions.insert(name, Arc::new(function));
    }

    /// Looks up the function registered under `name`.
    pub fn get(&self, name: &str) -> Option<Arc<dyn CustomAutogradFunction + Send + Sync>> {
        self.functions.get(name).cloned()
    }

    /// Lists the names of all registered functions.
    pub fn list_functions(&self) -> Vec<&String> {
        self.functions.keys().collect()
    }
}

impl Default for AutogradRegistry {
    fn default() -> Self {
        Self::new()
    }
}

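/// Runs `function.forward` on `inputs`, checking that the input and output
/// counts match the arities the function declares.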
pub fn apply_custom_function<F>(function: F, inputs: &[Tensor]) -> TorshResult<Vec<Tensor>>
where
    F: CustomAutogradFunction,
{
    if inputs.len() != function.num_inputs() {
        return Err(TorshError::invalid_argument_with_context(
            &format!(
                "Expected {} inputs, got {}",
                function.num_inputs(),
                inputs.len()
            ),
            "apply_custom_function",
        ));
    }

    let outputs = function.forward(inputs)?;

    if outputs.len() != function.num_outputs() {
        return Err(TorshError::invalid_argument_with_context(
            &format!(
                "Expected {} outputs, got {}",
                function.num_outputs(),
                outputs.len()
            ),
            "apply_custom_function",
        ));
    }

    Ok(outputs)
}

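/// Like [`apply_custom_function`], but threads a fresh [`AutogradContext`]
/// through `forward` so the function can save state for a later backward pass.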
pub fn apply_custom_function_with_context<F>(
    function: F,
    inputs: &[Tensor],
) -> TorshResult<Vec<Tensor>>
where
    F: CustomAutogradFunctionWithContext,
{
    if inputs.len() != function.num_inputs() {
        return Err(TorshError::invalid_argument_with_context(
            &format!(
                "Expected {} inputs, got {}",
                function.num_inputs(),
                inputs.len()
            ),
            "apply_custom_function_with_context",
        ));
    }

    let mut ctx = AutogradContext::new(inputs.len());

    let outputs = function.forward(&mut ctx, inputs)?;

    if outputs.len() != function.num_outputs() {
        return Err(TorshError::invalid_argument_with_context(
            &format!(
                "Expected {} outputs, got {}",
                function.num_outputs(),
                outputs.len()
            ),
            "apply_custom_function_with_context",
        ));
    }

    Ok(outputs)
}

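/// Example function: `y = x^2`, with gradient `dy/dx = 2x`.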
pub struct SquareFunction;

impl CustomAutogradFunction for SquareFunction {
    fn forward(&self, inputs: &[Tensor]) -> TorshResult<Vec<Tensor>> {
        let input = &inputs[0];
        let output = input.mul_op(input)?;
        Ok(vec![output])
    }

    fn backward(
        &self,
        grad_outputs: &[Tensor],
        inputs: &[Tensor],
    ) -> TorshResult<Vec<Option<Tensor>>> {
        let grad_output = &grad_outputs[0];
        let input = &inputs[0];

        // d(x^2)/dx = 2x, so grad_input = grad_output * 2x.
        let grad_input = grad_output.mul_op(&input.mul_scalar(2.0)?)?;

        Ok(vec![Some(grad_input)])
    }

    fn num_inputs(&self) -> usize {
        1
    }
    fn num_outputs(&self) -> usize {
        1
    }
    fn name(&self) -> &str {
        "square"
    }
}

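/// Example function: `y = exp(x)`, with gradient `dy/dx = exp(x)`.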
pub struct ExpFunction;

impl CustomAutogradFunction for ExpFunction {
    fn forward(&self, inputs: &[Tensor]) -> TorshResult<Vec<Tensor>> {
        let input = &inputs[0];
        let output = input.exp()?;
        Ok(vec![output])
    }

    fn backward(
        &self,
        grad_outputs: &[Tensor],
        inputs: &[Tensor],
    ) -> TorshResult<Vec<Option<Tensor>>> {
        let grad_output = &grad_outputs[0];
        let input = &inputs[0];

        // d(exp x)/dx = exp(x); recomputed here since this trait keeps no saved state.
        let exp_input = input.exp()?;
        let grad_input = grad_output.mul_op(&exp_input)?;

        Ok(vec![Some(grad_input)])
    }

    fn num_inputs(&self) -> usize {
        1
    }
    fn num_outputs(&self) -> usize {
        1
    }
    fn name(&self) -> &str {
        "exp"
    }
}

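/// Example context-based function: `y = scale * a + b`. The scale factor is
/// saved in the context during `forward` and recovered in `backward`, where
/// `dy/da = scale` and `dy/db = 1`.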
pub struct ScaledAddFunction {
    scale: f32,
}

impl ScaledAddFunction {
    pub fn new(scale: f32) -> Self {
        Self { scale }
    }
}

impl CustomAutogradFunctionWithContext for ScaledAddFunction {
    fn forward(&self, ctx: &mut AutogradContext, inputs: &[Tensor]) -> TorshResult<Vec<Tensor>> {
        let a = &inputs[0];
        let b = &inputs[1];

        // Save the scale factor so backward can recover it from the context.
        ctx.save_value("scale", self.scale);

        let scaled_a = a.mul_scalar(self.scale)?;
        let output = scaled_a.add_op(b)?;

        Ok(vec![output])
    }

    fn backward(
        &self,
        ctx: &AutogradContext,
        grad_outputs: &[Tensor],
    ) -> TorshResult<Vec<Option<Tensor>>> {
        let grad_output = &grad_outputs[0];
        let scale = ctx.get_saved_value("scale").unwrap_or(1.0);

        // d(scale * a + b)/da = scale
        let grad_a = if ctx.needs_input_grad(0) {
            Some(grad_output.mul_scalar(scale)?)
        } else {
            None
        };

        // d(scale * a + b)/db = 1
        let grad_b = if ctx.needs_input_grad(1) {
            Some(grad_output.clone())
        } else {
            None
        };

        Ok(vec![grad_a, grad_b])
    }

    fn num_inputs(&self) -> usize {
        2
    }
    fn num_outputs(&self) -> usize {
        1
    }
    fn name(&self) -> &str {
        "scaled_add"
    }
}

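/// Declares a unit struct and a [`CustomAutogradFunction`] impl for it from
/// the given forward and backward bodies. Because the macro is exported, its
/// expansion refers to `CustomAutogradFunction`, `Tensor`, and `TorshResult`
/// unqualified, so those names must be in scope at the invocation site.
///
/// A minimal usage sketch (illustrative; `DoubleFunction` and its bodies are
/// made up for this example):
///
/// ```ignore
/// create_custom_autograd_function! {
///     name: DoubleFunction,
///     inputs: 1,
///     outputs: 1,
///     forward: |inputs| { Ok(vec![inputs[0].mul_scalar(2.0)?]) },
///     backward: |grad_outputs, _inputs| {
///         // d(2x)/dx = 2, so the gradient is just grad_output * 2.
///         Ok(vec![Some(grad_outputs[0].mul_scalar(2.0)?)])
///     }
/// }
/// ```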
#[macro_export]
macro_rules! create_custom_autograd_function {
    (
        name: $name:ident,
        inputs: $num_inputs:expr,
        outputs: $num_outputs:expr,
        forward: |$inputs:ident| $forward_body:expr,
        backward: |$grad_outputs:ident, $backward_inputs:ident| $backward_body:expr
    ) => {
        pub struct $name;

        impl CustomAutogradFunction for $name {
            fn forward(&self, $inputs: &[Tensor]) -> TorshResult<Vec<Tensor>> {
                $forward_body
            }

            fn backward(
                &self,
                $grad_outputs: &[Tensor],
                $backward_inputs: &[Tensor],
            ) -> TorshResult<Vec<Option<Tensor>>> {
                $backward_body
            }

            fn num_inputs(&self) -> usize {
                $num_inputs
            }
            fn num_outputs(&self) -> usize {
                $num_outputs
            }
            fn name(&self) -> &str {
                stringify!($name)
            }
        }
    };
}

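/// Process-wide registry of custom autograd functions, initialized lazily on
/// first use.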
static GLOBAL_REGISTRY: OnceLock<Mutex<AutogradRegistry>> = OnceLock::new();

/// Returns the global registry, initializing it on first access.
pub fn get_global_registry() -> &'static Mutex<AutogradRegistry> {
    GLOBAL_REGISTRY.get_or_init(|| Mutex::new(AutogradRegistry::new()))
}

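/// Registers `function` under `name` in the global registry.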
pub fn register_custom_function<F>(name: String, function: F)
where
    F: CustomAutogradFunction + Send + Sync + 'static,
{
    get_global_registry()
        .lock()
        .expect("autograd registry lock should not be poisoned")
        .register(name, function);
}

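/// Looks up `name` in the global registry and applies it to `inputs`,
/// validating input and output arities. The registry lock is held for the
/// duration of the call.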
pub fn apply_registered_function(name: &str, inputs: &[Tensor]) -> TorshResult<Vec<Tensor>> {
    let registry = get_global_registry()
        .lock()
        .expect("lock should not be poisoned");
    let function = registry.get(name).ok_or_else(|| {
        TorshError::invalid_argument_with_context(
            &format!("Function '{}' not found in registry", name),
            "apply_registered_function",
        )
    })?;

    if inputs.len() != function.num_inputs() {
        return Err(TorshError::invalid_argument_with_context(
            &format!(
                "Expected {} inputs, got {}",
                function.num_inputs(),
                inputs.len()
            ),
            "apply_registered_function",
        ));
    }

    let outputs = function.forward(inputs)?;

    if outputs.len() != function.num_outputs() {
        return Err(TorshError::invalid_argument_with_context(
            &format!(
                "Expected {} outputs, got {}",
                function.num_outputs(),
                outputs.len()
            ),
            "apply_registered_function",
        ));
    }

    Ok(outputs)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_square_function() -> TorshResult<()> {
        let input = Tensor::from_data(vec![2.0, 3.0, 4.0], vec![3], torsh_core::DeviceType::Cpu)?;
        let square_fn = SquareFunction;

        let outputs = apply_custom_function(square_fn, &[input.clone()])?;
        let output_data = outputs[0].to_vec()?;

        assert!((output_data[0] - 4.0).abs() < 1e-6);
        assert!((output_data[1] - 9.0).abs() < 1e-6);
        assert!((output_data[2] - 16.0).abs() < 1e-6);

        Ok(())
    }

    #[test]
    fn test_exp_function() -> TorshResult<()> {
        let input = Tensor::from_data(vec![0.0, 1.0], vec![2], torsh_core::DeviceType::Cpu)?;
        let exp_fn = ExpFunction;

        let outputs = apply_custom_function(exp_fn, &[input.clone()])?;
        let output_data = outputs[0].to_vec()?;

        assert!((output_data[0] - 1.0).abs() < 1e-6);
        assert!((output_data[1] - std::f32::consts::E).abs() < 1e-6);

        Ok(())
    }

    #[test]
    fn test_scaled_add_function() -> TorshResult<()> {
        let a = Tensor::from_data(vec![1.0, 2.0], vec![2], torsh_core::DeviceType::Cpu)?;
        let b = Tensor::from_data(vec![3.0, 4.0], vec![2], torsh_core::DeviceType::Cpu)?;
        let scaled_add_fn = ScaledAddFunction::new(2.0);

        let outputs = apply_custom_function_with_context(scaled_add_fn, &[a, b])?;
        let output_data = outputs[0].to_vec()?;

        // 2.0 * [1, 2] + [3, 4] = [5, 8]
        assert!((output_data[0] - 5.0).abs() < 1e-6);
        assert!((output_data[1] - 8.0).abs() < 1e-6);

        Ok(())
    }

    #[test]
    fn test_registry() -> TorshResult<()> {
        let mut registry = AutogradRegistry::new();
        registry.register("square".to_string(), SquareFunction);

        let function = registry.get("square").unwrap();
        assert_eq!(function.name(), "square");
        assert_eq!(function.num_inputs(), 1);
        assert_eq!(function.num_outputs(), 1);

        Ok(())
    }

    #[test]
    fn test_global_registry() -> TorshResult<()> {
        register_custom_function("test_square".to_string(), SquareFunction);

        let input = Tensor::from_data(vec![3.0], vec![1], torsh_core::DeviceType::Cpu)?;
        let outputs = apply_registered_function("test_square", &[input])?;
        let output_data = outputs[0].to_vec()?;

        assert!((output_data[0] - 9.0).abs() < 1e-6);

        Ok(())
    }
}