tensorlogic_scirs_backend/custom_ops.rs

//! Custom operations infrastructure with dynamic registration.
//!
//! This module provides a plugin system for user-defined tensor operations,
//! allowing extensibility without modifying the core executor.
//!
//! ## Features
//!
//! - **Custom Operation Trait**: Define operations with forward and backward passes
//! - **Operation Registry**: Dynamic registration and lookup
//! - **Gradient Support**: Backward passes for custom ops, dispatched through the registry
//! - **Validation**: Input-count checks and output-shape inference for custom ops
//!
//! ## Example
//!
//! ```rust,ignore
//! use tensorlogic_scirs_backend::custom_ops::{CustomOp, OpRegistry, CustomOpContext};
//! use tensorlogic_scirs_backend::Scirs2Tensor;
//!
//! // Define a custom softplus operation
//! struct SoftplusOp;
//!
//! impl CustomOp for SoftplusOp {
//!     fn name(&self) -> &str {
//!         "softplus"
//!     }
//!
//!     fn forward(&self, inputs: &[&Scirs2Tensor], _ctx: &mut CustomOpContext) -> Result<Scirs2Tensor, String> {
//!         let x = inputs[0];
//!         Ok(x.mapv(|v| (1.0 + v.exp()).ln()))
//!     }
//!
//!     fn backward(&self, grad: &Scirs2Tensor, inputs: &[&Scirs2Tensor], _ctx: &CustomOpContext) -> Result<Vec<Scirs2Tensor>, String> {
//!         let x = inputs[0];
//!         let sigmoid = x.mapv(|v| 1.0 / (1.0 + (-v).exp()));
//!         Ok(vec![grad * &sigmoid])
//!     }
//! }
//!
//! // Register and use
//! let mut registry = OpRegistry::new();
//! registry.register(Box::new(SoftplusOp));
//!
//! let result = registry.execute("softplus", &[&tensor], &mut context)?;
//! ```

use crate::{Scirs2Tensor, TlBackendError, TlBackendResult};
use std::collections::HashMap;
use std::sync::Arc;

/// Context for custom operation execution.
///
/// Provides storage for intermediate values needed during backward pass
/// and metadata about the execution environment.
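///
/// A minimal usage sketch (illustrative; `sigmoid_tensor` is a placeholder, not
/// something defined by this module):
///
/// ```rust,ignore
/// let mut ctx = CustomOpContext::with_grad();
///
/// // In a forward pass: stash a value the backward pass will need.
/// ctx.save_for_backward("sigmoid", sigmoid_tensor);
///
/// // In the matching backward pass: fetch it back (None if it was never saved).
/// if let Some(s) = ctx.get_saved("sigmoid") {
///     // ... reuse `s` instead of recomputing it
/// }
///
/// // Arbitrary string metadata travels with the context.
/// ctx.set_metadata("device", "cpu");
/// assert_eq!(ctx.get_metadata("device"), Some(&"cpu".to_string()));
/// ```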
#[derive(Debug, Clone, Default)]
pub struct CustomOpContext {
    /// Storage for intermediate values (forward pass -> backward pass)
    pub intermediates: HashMap<String, Scirs2Tensor>,

    /// Custom metadata
    pub metadata: HashMap<String, String>,

    /// Whether gradient computation is enabled
    pub requires_grad: bool,
}

impl CustomOpContext {
    /// Create a new context.
    pub fn new() -> Self {
        Self::default()
    }

    /// Create a context with gradient computation enabled.
    pub fn with_grad() -> Self {
        Self {
            requires_grad: true,
            ..Default::default()
        }
    }

    /// Store an intermediate tensor for backward pass.
    pub fn save_for_backward(&mut self, name: impl Into<String>, tensor: Scirs2Tensor) {
        self.intermediates.insert(name.into(), tensor);
    }

    /// Retrieve a saved intermediate tensor.
    pub fn get_saved(&self, name: &str) -> Option<&Scirs2Tensor> {
        self.intermediates.get(name)
    }

    /// Set metadata.
    pub fn set_metadata(&mut self, key: impl Into<String>, value: impl Into<String>) {
        self.metadata.insert(key.into(), value.into());
    }

    /// Get metadata.
    pub fn get_metadata(&self, key: &str) -> Option<&String> {
        self.metadata.get(key)
    }
}

/// Trait for custom tensor operations.
///
/// Implement this trait to define custom operations that can be registered
/// with the operation registry.
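///
/// Only `name`, `forward`, and `backward` are required; `num_inputs`,
/// `validate_inputs`, and `infer_output_shape` have defaults suited to unary,
/// shape-preserving ops. A hedged sketch of overriding the shape hook (the
/// `SumOp` type below is illustrative, not part of this crate):
///
/// ```rust,ignore
/// use scirs2_core::ndarray::ArrayD;
///
/// /// Sums every element of its single input into a one-element tensor.
/// struct SumOp;
///
/// impl CustomOp for SumOp {
///     fn name(&self) -> &str {
///         "sum_all"
///     }
///
///     fn forward(
///         &self,
///         inputs: &[&Scirs2Tensor],
///         _ctx: &mut CustomOpContext,
///     ) -> Result<Scirs2Tensor, String> {
///         let total = inputs[0].sum();
///         ArrayD::from_shape_vec(vec![1], vec![total]).map_err(|e| e.to_string())
///     }
///
///     fn backward(
///         &self,
///         grad: &Scirs2Tensor,
///         inputs: &[&Scirs2Tensor],
///         _ctx: &CustomOpContext,
///     ) -> Result<Vec<Scirs2Tensor>, String> {
///         // d(sum)/dx_i = 1, so the upstream gradient is broadcast to the input shape.
///         let g = grad[[0]];
///         Ok(vec![inputs[0].mapv(|_| g)])
///     }
///
///     // The default assumes the output shape matches the first input; override it here.
///     fn infer_output_shape(&self, _input_shapes: &[&[usize]]) -> Result<Vec<usize>, String> {
///         Ok(vec![1])
///     }
/// }
/// ```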
pub trait CustomOp: Send + Sync {
    /// Get the operation name.
    fn name(&self) -> &str;

    /// Number of inputs expected.
    fn num_inputs(&self) -> usize {
        1 // Default to unary operation
    }

    /// Execute the forward pass.
    fn forward(
        &self,
        inputs: &[&Scirs2Tensor],
        ctx: &mut CustomOpContext,
    ) -> Result<Scirs2Tensor, String>;

    /// Execute the backward pass (compute gradients).
    ///
    /// Returns gradients for each input tensor.
    fn backward(
        &self,
        grad: &Scirs2Tensor,
        inputs: &[&Scirs2Tensor],
        ctx: &CustomOpContext,
    ) -> Result<Vec<Scirs2Tensor>, String>;

    /// Validate inputs before execution (the default implementation checks only the input count).
    fn validate_inputs(&self, inputs: &[&Scirs2Tensor]) -> Result<(), String> {
        if inputs.len() != self.num_inputs() {
            return Err(format!(
                "Expected {} inputs, got {}",
                self.num_inputs(),
                inputs.len()
            ));
        }
        Ok(())
    }

    /// Infer output shape from input shapes.
    fn infer_output_shape(&self, input_shapes: &[&[usize]]) -> Result<Vec<usize>, String> {
        // Default: same shape as first input
        if input_shapes.is_empty() {
            return Err("No input shapes provided".to_string());
        }
        Ok(input_shapes[0].to_vec())
    }
}

/// Registry for custom operations.
///
/// Manages registration and lookup of custom operations by name.
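///
/// A short usage sketch (illustrative; `x` and `upstream_grad` are placeholder
/// tensors, not values provided by this module):
///
/// ```rust,ignore
/// // Start from the built-in activations and layer a user-defined op on top.
/// let mut registry = OpRegistry::with_standard_ops();
/// registry.register(Box::new(SoftplusOp)); // registering the same name again overwrites it
///
/// let mut ctx = CustomOpContext::with_grad();
/// let y = registry.execute("gelu", &[&x], &mut ctx)?;
///
/// // Backward pass: one gradient per input, in input order.
/// let grads = registry.backward("gelu", &upstream_grad, &[&x], &ctx)?;
/// ```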
#[derive(Default)]
pub struct OpRegistry {
    /// Registered operations
    ops: HashMap<String, Arc<dyn CustomOp>>,
}

impl OpRegistry {
    /// Create a new empty registry.
    pub fn new() -> Self {
        Self {
            ops: HashMap::new(),
        }
    }

    /// Create a registry with common operations pre-registered.
    pub fn with_standard_ops() -> Self {
        let mut registry = Self::new();

        // Register common operations
        registry.register(Box::new(SoftplusOp));
        registry.register(Box::new(LeakyReluOp::default()));
        registry.register(Box::new(EluOp::default()));
        registry.register(Box::new(SwishOp));
        registry.register(Box::new(MishOp));
        registry.register(Box::new(GeluOp));
        registry.register(Box::new(HardSigmoidOp));
        registry.register(Box::new(HardSwishOp));

        registry
    }

    /// Register a custom operation.
    pub fn register(&mut self, op: Box<dyn CustomOp>) {
        self.ops.insert(op.name().to_string(), Arc::from(op));
    }

    /// Get a registered operation by name.
    pub fn get(&self, name: &str) -> Option<Arc<dyn CustomOp>> {
        self.ops.get(name).cloned()
    }

    /// Check if an operation is registered.
    pub fn contains(&self, name: &str) -> bool {
        self.ops.contains_key(name)
    }

    /// List all registered operations.
    pub fn list_ops(&self) -> Vec<&str> {
        self.ops.keys().map(|s| s.as_str()).collect()
    }

    /// Number of registered operations.
    pub fn len(&self) -> usize {
        self.ops.len()
    }

    /// Check if registry is empty.
    pub fn is_empty(&self) -> bool {
        self.ops.is_empty()
    }

    /// Execute a registered operation.
    pub fn execute(
        &self,
        name: &str,
        inputs: &[&Scirs2Tensor],
        ctx: &mut CustomOpContext,
    ) -> TlBackendResult<Scirs2Tensor> {
        let op = self
            .get(name)
            .ok_or_else(|| TlBackendError::unsupported(format!("Unknown operation: {}", name)))?;

        op.validate_inputs(inputs)
            .map_err(TlBackendError::execution)?;

        op.forward(inputs, ctx).map_err(TlBackendError::execution)
    }

    /// Execute backward pass for a registered operation.
    pub fn backward(
        &self,
        name: &str,
        grad: &Scirs2Tensor,
        inputs: &[&Scirs2Tensor],
        ctx: &CustomOpContext,
    ) -> TlBackendResult<Vec<Scirs2Tensor>> {
        let op = self
            .get(name)
            .ok_or_else(|| TlBackendError::unsupported(format!("Unknown operation: {}", name)))?;

        op.backward(grad, inputs, ctx)
            .map_err(TlBackendError::gradient)
    }
}

// Standard custom operations

/// Softplus activation: ln(1 + exp(x))
pub struct SoftplusOp;

impl CustomOp for SoftplusOp {
    fn name(&self) -> &str {
        "softplus"
    }

    fn forward(
        &self,
        inputs: &[&Scirs2Tensor],
        _ctx: &mut CustomOpContext,
    ) -> Result<Scirs2Tensor, String> {
        let x = inputs[0];
        // Numerically stable softplus
        Ok(x.mapv(|v| {
            if v > 20.0 {
                v // For large values, softplus ≈ x
            } else if v < -20.0 {
                v.exp() // For very negative values, softplus ≈ exp(x)
            } else {
                (1.0 + v.exp()).ln()
            }
        }))
    }

    fn backward(
        &self,
        grad: &Scirs2Tensor,
        inputs: &[&Scirs2Tensor],
        _ctx: &CustomOpContext,
    ) -> Result<Vec<Scirs2Tensor>, String> {
        let x = inputs[0];
        // d/dx softplus(x) = sigmoid(x)
        let sigmoid = x.mapv(|v| 1.0 / (1.0 + (-v).exp()));
        Ok(vec![grad * &sigmoid])
    }
}

/// Leaky ReLU: max(alpha * x, x)
pub struct LeakyReluOp {
    /// Negative slope (default: 0.01)
    pub alpha: f64,
}

impl Default for LeakyReluOp {
    fn default() -> Self {
        Self { alpha: 0.01 }
    }
}

impl CustomOp for LeakyReluOp {
    fn name(&self) -> &str {
        "leaky_relu"
    }

    fn forward(
        &self,
        inputs: &[&Scirs2Tensor],
        _ctx: &mut CustomOpContext,
    ) -> Result<Scirs2Tensor, String> {
        let x = inputs[0];
        let alpha = self.alpha;
        Ok(x.mapv(|v| if v > 0.0 { v } else { alpha * v }))
    }

    fn backward(
        &self,
        grad: &Scirs2Tensor,
        inputs: &[&Scirs2Tensor],
        _ctx: &CustomOpContext,
    ) -> Result<Vec<Scirs2Tensor>, String> {
        let x = inputs[0];
        let alpha = self.alpha;
        let grad_input = scirs2_core::ndarray::Zip::from(grad)
            .and(x)
            .map_collect(|&g, &v| if v > 0.0 { g } else { alpha * g });
        Ok(vec![grad_input])
    }
}

/// ELU: x if x > 0 else alpha * (exp(x) - 1)
pub struct EluOp {
    /// Scale for negative values (default: 1.0)
    pub alpha: f64,
}

impl Default for EluOp {
    fn default() -> Self {
        Self { alpha: 1.0 }
    }
}

impl CustomOp for EluOp {
    fn name(&self) -> &str {
        "elu"
    }

    fn forward(
        &self,
        inputs: &[&Scirs2Tensor],
        ctx: &mut CustomOpContext,
    ) -> Result<Scirs2Tensor, String> {
        let x = inputs[0];
        let alpha = self.alpha;
        let result = x.mapv(|v| if v > 0.0 { v } else { alpha * (v.exp() - 1.0) });

        // Save for backward
        if ctx.requires_grad {
            ctx.save_for_backward("output", result.clone());
        }

        Ok(result)
    }

    fn backward(
        &self,
        grad: &Scirs2Tensor,
        inputs: &[&Scirs2Tensor],
        ctx: &CustomOpContext,
    ) -> Result<Vec<Scirs2Tensor>, String> {
        let x = inputs[0];
        let alpha = self.alpha;

        let grad_input = if let Some(output) = ctx.get_saved("output") {
            // Use saved output for efficiency
            scirs2_core::ndarray::Zip::from(grad)
                .and(x)
                .and(output)
                .map_collect(|&g, &v, &o| if v > 0.0 { g } else { g * (o + alpha) })
        } else {
            // Compute from inputs
            scirs2_core::ndarray::Zip::from(grad)
                .and(x)
                .map_collect(|&g, &v| if v > 0.0 { g } else { g * alpha * v.exp() })
        };

        Ok(vec![grad_input])
    }
}

/// Swish: x * sigmoid(x)
pub struct SwishOp;

impl CustomOp for SwishOp {
    fn name(&self) -> &str {
        "swish"
    }

    fn forward(
        &self,
        inputs: &[&Scirs2Tensor],
        ctx: &mut CustomOpContext,
    ) -> Result<Scirs2Tensor, String> {
        let x = inputs[0];
        let sigmoid = x.mapv(|v| 1.0 / (1.0 + (-v).exp()));
        let result = x * &sigmoid;

        if ctx.requires_grad {
            ctx.save_for_backward("sigmoid", sigmoid);
        }

        Ok(result)
    }

    fn backward(
        &self,
        grad: &Scirs2Tensor,
        inputs: &[&Scirs2Tensor],
        ctx: &CustomOpContext,
    ) -> Result<Vec<Scirs2Tensor>, String> {
        let x = inputs[0];

        let sigmoid = if let Some(s) = ctx.get_saved("sigmoid") {
            s.clone()
        } else {
            x.mapv(|v| 1.0 / (1.0 + (-v).exp()))
        };

        // d/dx swish(x) = swish(x) + sigmoid(x) * (1 - swish(x))
        // = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))
        let grad_input = scirs2_core::ndarray::Zip::from(grad)
            .and(x)
            .and(&sigmoid)
            .map_collect(|&g, &v, &s| g * (s + v * s * (1.0 - s)));

        Ok(vec![grad_input])
    }
}

/// Mish: x * tanh(softplus(x))
pub struct MishOp;

impl CustomOp for MishOp {
    fn name(&self) -> &str {
        "mish"
    }

    fn forward(
        &self,
        inputs: &[&Scirs2Tensor],
        _ctx: &mut CustomOpContext,
    ) -> Result<Scirs2Tensor, String> {
        let x = inputs[0];
        Ok(x.mapv(|v| {
            let softplus = if v > 20.0 {
                v
            } else if v < -20.0 {
                v.exp()
            } else {
                (1.0 + v.exp()).ln()
            };
            v * softplus.tanh()
        }))
    }

    fn backward(
        &self,
        grad: &Scirs2Tensor,
        inputs: &[&Scirs2Tensor],
        _ctx: &CustomOpContext,
    ) -> Result<Vec<Scirs2Tensor>, String> {
        let x = inputs[0];
        // Closed-form gradient: mish'(x) = e^x * omega / delta^2, where
        // omega = 4(x + 1) + 4e^{2x} + e^{3x} + e^x(4x + 6) and delta = e^{2x} + 2e^x + 2
        let grad_input = scirs2_core::ndarray::Zip::from(grad)
            .and(x)
            .map_collect(|&g, &v| {
                let e = v.exp();
                let omega = 4.0 * (v + 1.0) + 4.0 * e * e + e * e * e + e * (4.0 * v + 6.0);
                let delta = 2.0 * e + e * e + 2.0;
                g * e * omega / (delta * delta)
            });

        Ok(vec![grad_input])
    }
}

/// GELU: Gaussian Error Linear Unit
pub struct GeluOp;

impl CustomOp for GeluOp {
    fn name(&self) -> &str {
        "gelu"
    }

    fn forward(
        &self,
        inputs: &[&Scirs2Tensor],
        _ctx: &mut CustomOpContext,
    ) -> Result<Scirs2Tensor, String> {
        let x = inputs[0];
        // Approximate GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
        let sqrt_2_over_pi = (2.0 / std::f64::consts::PI).sqrt();
        Ok(x.mapv(|v| {
            let inner = sqrt_2_over_pi * (v + 0.044715 * v * v * v);
            0.5 * v * (1.0 + inner.tanh())
        }))
    }

    fn backward(
        &self,
        grad: &Scirs2Tensor,
        inputs: &[&Scirs2Tensor],
        _ctx: &CustomOpContext,
    ) -> Result<Vec<Scirs2Tensor>, String> {
        let x = inputs[0];
        let sqrt_2_over_pi = (2.0 / std::f64::consts::PI).sqrt();

        let grad_input = scirs2_core::ndarray::Zip::from(grad)
            .and(x)
            .map_collect(|&g, &v| {
                let inner = sqrt_2_over_pi * (v + 0.044715 * v * v * v);
                let tanh_inner = inner.tanh();
                let sech2 = 1.0 - tanh_inner * tanh_inner;
                let d_inner = sqrt_2_over_pi * (1.0 + 3.0 * 0.044715 * v * v);

                g * (0.5 * (1.0 + tanh_inner) + 0.5 * v * sech2 * d_inner)
            });

        Ok(vec![grad_input])
    }
}

/// Hard Sigmoid: clip((x + 3) / 6, 0, 1)
pub struct HardSigmoidOp;

impl CustomOp for HardSigmoidOp {
    fn name(&self) -> &str {
        "hard_sigmoid"
    }

    fn forward(
        &self,
        inputs: &[&Scirs2Tensor],
        _ctx: &mut CustomOpContext,
    ) -> Result<Scirs2Tensor, String> {
        let x = inputs[0];
        Ok(x.mapv(|v| ((v + 3.0) / 6.0).clamp(0.0, 1.0)))
    }

    fn backward(
        &self,
        grad: &Scirs2Tensor,
        inputs: &[&Scirs2Tensor],
        _ctx: &CustomOpContext,
    ) -> Result<Vec<Scirs2Tensor>, String> {
        let x = inputs[0];
        let grad_input = scirs2_core::ndarray::Zip::from(grad)
            .and(x)
            .map_collect(|&g, &v| if v > -3.0 && v < 3.0 { g / 6.0 } else { 0.0 });

        Ok(vec![grad_input])
    }
}

/// Hard Swish: x * hard_sigmoid(x)
pub struct HardSwishOp;

impl CustomOp for HardSwishOp {
    fn name(&self) -> &str {
        "hard_swish"
    }

    fn forward(
        &self,
        inputs: &[&Scirs2Tensor],
        _ctx: &mut CustomOpContext,
    ) -> Result<Scirs2Tensor, String> {
        let x = inputs[0];
        Ok(x.mapv(|v| {
            let hard_sigmoid = ((v + 3.0) / 6.0).clamp(0.0, 1.0);
            v * hard_sigmoid
        }))
    }

    fn backward(
        &self,
        grad: &Scirs2Tensor,
        inputs: &[&Scirs2Tensor],
        _ctx: &CustomOpContext,
    ) -> Result<Vec<Scirs2Tensor>, String> {
        let x = inputs[0];
        let grad_input = scirs2_core::ndarray::Zip::from(grad)
            .and(x)
            .map_collect(|&g, &v| {
                if v <= -3.0 {
                    0.0
                } else if v >= 3.0 {
                    g
                } else {
                    g * (v / 3.0 + 0.5)
                }
            });

        Ok(vec![grad_input])
    }
}

/// Binary custom operation for element-wise operations
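///
/// The closures supply the per-element forward value and the pair of input
/// gradients, so simple element-wise binary ops don't need a dedicated struct.
/// A sketch mirroring the `pow` op from the tests below:
///
/// ```rust,ignore
/// let pow_op = BinaryCustomOp::new(
///     "pow",
///     |a, b| a.powf(b), // forward: a^b
///     |a, b, g| {
///         let da = g * b * a.powf(b - 1.0); // d/da a^b = b * a^(b-1)
///         let db = g * a.powf(b) * a.ln();  // d/db a^b = a^b * ln(a)
///         (da, db)
///     },
/// );
///
/// let mut registry = OpRegistry::new();
/// registry.register(Box::new(pow_op));
/// ```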
pub struct BinaryCustomOp<F, G>
where
    F: Fn(f64, f64) -> f64 + Send + Sync,
    G: Fn(f64, f64, f64) -> (f64, f64) + Send + Sync,
{
    name: String,
    forward_fn: F,
    backward_fn: G,
}

impl<F, G> BinaryCustomOp<F, G>
where
    F: Fn(f64, f64) -> f64 + Send + Sync,
    G: Fn(f64, f64, f64) -> (f64, f64) + Send + Sync,
{
    /// Create a new binary custom operation.
    pub fn new(name: impl Into<String>, forward_fn: F, backward_fn: G) -> Self {
        Self {
            name: name.into(),
            forward_fn,
            backward_fn,
        }
    }
}

impl<F, G> CustomOp for BinaryCustomOp<F, G>
where
    F: Fn(f64, f64) -> f64 + Send + Sync,
    G: Fn(f64, f64, f64) -> (f64, f64) + Send + Sync,
{
    fn name(&self) -> &str {
        &self.name
    }

    fn num_inputs(&self) -> usize {
        2
    }

    fn forward(
        &self,
        inputs: &[&Scirs2Tensor],
        _ctx: &mut CustomOpContext,
    ) -> Result<Scirs2Tensor, String> {
        let x = inputs[0];
        let y = inputs[1];

        if x.shape() != y.shape() {
            return Err(format!(
                "Shape mismatch: {:?} vs {:?}",
                x.shape(),
                y.shape()
            ));
        }

        let result = scirs2_core::ndarray::Zip::from(x)
            .and(y)
            .map_collect(|&a, &b| (self.forward_fn)(a, b));

        Ok(result)
    }

    fn backward(
        &self,
        grad: &Scirs2Tensor,
        inputs: &[&Scirs2Tensor],
        _ctx: &CustomOpContext,
    ) -> Result<Vec<Scirs2Tensor>, String> {
        let x = inputs[0];
        let y = inputs[1];

        let mut grad_x = Scirs2Tensor::zeros(x.raw_dim());
        let mut grad_y = Scirs2Tensor::zeros(y.raw_dim());

        scirs2_core::ndarray::Zip::from(&mut grad_x)
            .and(&mut grad_y)
            .and(grad)
            .and(x)
            .and(y)
            .for_each(|gx, gy, &g, &a, &b| {
                let (dx, dy) = (self.backward_fn)(a, b, g);
                *gx = dx;
                *gy = dy;
            });

        Ok(vec![grad_x, grad_y])
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::ArrayD;

    fn create_tensor(data: Vec<f64>, shape: Vec<usize>) -> Scirs2Tensor {
        ArrayD::from_shape_vec(shape, data).unwrap()
    }

    #[test]
    fn test_op_registry_basic() {
        let mut registry = OpRegistry::new();
        assert!(registry.is_empty());

        registry.register(Box::new(SoftplusOp));
        assert_eq!(registry.len(), 1);
        assert!(registry.contains("softplus"));
        assert!(!registry.contains("unknown"));
    }

    #[test]
    fn test_op_registry_with_standard_ops() {
        let registry = OpRegistry::with_standard_ops();
        assert!(registry.contains("softplus"));
        assert!(registry.contains("leaky_relu"));
        assert!(registry.contains("elu"));
        assert!(registry.contains("swish"));
        assert!(registry.contains("mish"));
        assert!(registry.contains("gelu"));
        assert!(registry.contains("hard_sigmoid"));
        assert!(registry.contains("hard_swish"));
    }

    #[test]
    fn test_softplus_forward() {
        let registry = OpRegistry::with_standard_ops();
        let tensor = create_tensor(vec![-1.0, 0.0, 1.0], vec![3]);
        let mut ctx = CustomOpContext::new();

        let result = registry.execute("softplus", &[&tensor], &mut ctx).unwrap();

        // softplus(-1) ≈ 0.3133, softplus(0) = ln(2) ≈ 0.6931, softplus(1) ≈ 1.3133
        assert!(result[[0]] > 0.3 && result[[0]] < 0.35);
        assert!((result[[1]] - std::f64::consts::LN_2).abs() < 0.01);
        assert!(result[[2]] > 1.3 && result[[2]] < 1.35);
    }

    #[test]
    fn test_softplus_backward() {
        let registry = OpRegistry::with_standard_ops();
        let tensor = create_tensor(vec![0.0], vec![1]);
        let grad = create_tensor(vec![1.0], vec![1]);
        let ctx = CustomOpContext::new();

        let grads = registry
            .backward("softplus", &grad, &[&tensor], &ctx)
            .unwrap();

        // d/dx softplus(0) = sigmoid(0) = 0.5
        assert!((grads[0][[0]] - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_leaky_relu_forward() {
        let registry = OpRegistry::with_standard_ops();
        let tensor = create_tensor(vec![-2.0, 0.0, 2.0], vec![3]);
        let mut ctx = CustomOpContext::new();

        let result = registry
            .execute("leaky_relu", &[&tensor], &mut ctx)
            .unwrap();

        assert!((result[[0]] - (-0.02)).abs() < 0.001); // -2 * 0.01
        assert_eq!(result[[1]], 0.0);
        assert_eq!(result[[2]], 2.0);
    }

    #[test]
    fn test_elu_forward() {
        let registry = OpRegistry::with_standard_ops();
        let tensor = create_tensor(vec![-1.0, 0.0, 1.0], vec![3]);
        let mut ctx = CustomOpContext::with_grad();

        let result = registry.execute("elu", &[&tensor], &mut ctx).unwrap();

        // elu(-1) = exp(-1) - 1 ≈ -0.632
        assert!((result[[0]] - (-0.632)).abs() < 0.01);
        assert_eq!(result[[1]], 0.0);
        assert_eq!(result[[2]], 1.0);
    }

    #[test]
    fn test_swish_forward() {
        let registry = OpRegistry::with_standard_ops();
        let tensor = create_tensor(vec![0.0], vec![1]);
        let mut ctx = CustomOpContext::new();

        let result = registry.execute("swish", &[&tensor], &mut ctx).unwrap();

        // swish(0) = 0 * 0.5 = 0
        assert_eq!(result[[0]], 0.0);
    }

    #[test]
    fn test_gelu_forward() {
        let registry = OpRegistry::with_standard_ops();
        let tensor = create_tensor(vec![-1.0, 0.0, 1.0], vec![3]);
        let mut ctx = CustomOpContext::new();

        let result = registry.execute("gelu", &[&tensor], &mut ctx).unwrap();

        // gelu(0) = 0
        assert!((result[[1]]).abs() < 0.01);
        // gelu(x) has specific values
        assert!(result[[0]] < 0.0); // gelu(-1) is negative
        assert!(result[[2]] > 0.5); // gelu(1) > 0.5
    }

    #[test]
    fn test_hard_sigmoid_forward() {
        let registry = OpRegistry::with_standard_ops();
        let tensor = create_tensor(vec![-4.0, 0.0, 4.0], vec![3]);
        let mut ctx = CustomOpContext::new();

        let result = registry
            .execute("hard_sigmoid", &[&tensor], &mut ctx)
            .unwrap();

        assert_eq!(result[[0]], 0.0); // Clipped to 0
        assert_eq!(result[[1]], 0.5); // (0 + 3) / 6 = 0.5
        assert_eq!(result[[2]], 1.0); // Clipped to 1
    }

    #[test]
    fn test_hard_swish_forward() {
        let registry = OpRegistry::with_standard_ops();
        let tensor = create_tensor(vec![-4.0, 0.0, 4.0], vec![3]);
        let mut ctx = CustomOpContext::new();

        let result = registry
            .execute("hard_swish", &[&tensor], &mut ctx)
            .unwrap();

        assert_eq!(result[[0]], 0.0); // -4 * 0 = 0
        assert_eq!(result[[1]], 0.0); // 0 * 0.5 = 0
        assert_eq!(result[[2]], 4.0); // 4 * 1 = 4
    }

    #[test]
    fn test_custom_op_context() {
        let mut ctx = CustomOpContext::with_grad();
        assert!(ctx.requires_grad);

        let tensor = create_tensor(vec![1.0, 2.0], vec![2]);
        ctx.save_for_backward("test", tensor.clone());

        let saved = ctx.get_saved("test").unwrap();
        assert_eq!(saved[[0]], 1.0);
        assert_eq!(saved[[1]], 2.0);

        ctx.set_metadata("key", "value");
        assert_eq!(ctx.get_metadata("key"), Some(&"value".to_string()));
    }

    #[test]
    fn test_binary_custom_op() {
        // Define a custom power operation
        let pow_op = BinaryCustomOp::new(
            "pow",
            |a, b| a.powf(b),
            |a, b, g| {
                let da = g * b * a.powf(b - 1.0);
                let db = g * a.powf(b) * a.ln();
                (da, db)
            },
        );

        let mut registry = OpRegistry::new();
        registry.register(Box::new(pow_op));

        let x = create_tensor(vec![2.0, 3.0], vec![2]);
        let y = create_tensor(vec![3.0, 2.0], vec![2]);
        let mut ctx = CustomOpContext::new();

        let result = registry.execute("pow", &[&x, &y], &mut ctx).unwrap();

        assert_eq!(result[[0]], 8.0); // 2^3
        assert_eq!(result[[1]], 9.0); // 3^2
    }

    #[test]
    fn test_validate_inputs() {
        let registry = OpRegistry::with_standard_ops();
        let tensor = create_tensor(vec![1.0], vec![1]);
        let mut ctx = CustomOpContext::new();

        // Correct number of inputs
        let result = registry.execute("softplus", &[&tensor], &mut ctx);
        assert!(result.is_ok());

        // Wrong number of inputs
        let result = registry.execute("softplus", &[&tensor, &tensor], &mut ctx);
        assert!(result.is_err());
    }

    #[test]
    fn test_list_ops() {
        let registry = OpRegistry::with_standard_ops();
        let ops = registry.list_ops();

        assert!(ops.contains(&"softplus"));
        assert!(ops.contains(&"gelu"));
    }

    #[test]
    fn test_unknown_operation() {
        let registry = OpRegistry::new();
        let tensor = create_tensor(vec![1.0], vec![1]);
        let mut ctx = CustomOpContext::new();

        let result = registry.execute("unknown", &[&tensor], &mut ctx);
        assert!(result.is_err());
    }

    #[test]
    fn test_mish_forward() {
        let registry = OpRegistry::with_standard_ops();
        let tensor = create_tensor(vec![0.0], vec![1]);
        let mut ctx = CustomOpContext::new();

        let result = registry.execute("mish", &[&tensor], &mut ctx).unwrap();

        // mish(0) = 0 * tanh(softplus(0)) = 0 * tanh(ln(2)) ≈ 0
        assert!(result[[0]].abs() < 0.01);
    }
}