//! Custom operation support: a [`CustomOp`] trait with explicit forward and
//! backward passes, a [`CustomOpContext`] for saving intermediates, and an
//! [`OpRegistry`] pre-populated with common activation functions.

use crate::{Scirs2Tensor, TlBackendError, TlBackendResult};
use std::collections::HashMap;
use std::sync::Arc;

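/// Per-invocation context passed to a custom op's forward and backward
/// passes: holds saved intermediates, string metadata, and the gradient flag.
///
/// A minimal usage sketch; the tensor construction follows the tests below,
/// which treat `Scirs2Tensor` as an `ndarray` `ArrayD<f64>`:
///
/// ```ignore
/// use scirs2_core::ndarray::ArrayD;
///
/// let mut ctx = CustomOpContext::with_grad();
/// let t = ArrayD::from_shape_vec(vec![2], vec![1.0, 2.0]).unwrap();
/// ctx.save_for_backward("t", t);
/// assert!(ctx.get_saved("t").is_some());
/// ```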
#[derive(Debug, Clone, Default)]
pub struct CustomOpContext {
    /// Tensors saved during the forward pass for reuse in backward.
    pub intermediates: HashMap<String, Scirs2Tensor>,

    /// Free-form string metadata for the current invocation.
    pub metadata: HashMap<String, String>,

    /// Whether gradients will be requested; ops may skip saving
    /// intermediates when this is `false`.
    pub requires_grad: bool,
}

impl CustomOpContext {
    /// Creates an empty context with gradients disabled.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a context with `requires_grad` enabled.
    pub fn with_grad() -> Self {
        Self {
            requires_grad: true,
            ..Default::default()
        }
    }

    /// Saves an intermediate tensor under `name` for the backward pass.
    pub fn save_for_backward(&mut self, name: impl Into<String>, tensor: Scirs2Tensor) {
        self.intermediates.insert(name.into(), tensor);
    }

    /// Returns a tensor previously saved with [`Self::save_for_backward`].
    pub fn get_saved(&self, name: &str) -> Option<&Scirs2Tensor> {
        self.intermediates.get(name)
    }

    /// Stores a metadata key/value pair.
    pub fn set_metadata(&mut self, key: impl Into<String>, value: impl Into<String>) {
        self.metadata.insert(key.into(), value.into());
    }

    /// Looks up a metadata value by key.
    pub fn get_metadata(&self, key: &str) -> Option<&String> {
        self.metadata.get(key)
    }
}

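/// A differentiable operation: implementors provide a forward pass and a
/// backward pass that maps the upstream gradient to per-input gradients.
///
/// A sketch of an implementor (a hypothetical `SquareOp`, not part of this
/// module), following the element-wise style of the ops below:
///
/// ```ignore
/// struct SquareOp;
///
/// impl CustomOp for SquareOp {
///     fn name(&self) -> &str {
///         "square"
///     }
///
///     fn forward(
///         &self,
///         inputs: &[&Scirs2Tensor],
///         _ctx: &mut CustomOpContext,
///     ) -> Result<Scirs2Tensor, String> {
///         Ok(inputs[0].mapv(|v| v * v))
///     }
///
///     fn backward(
///         &self,
///         grad: &Scirs2Tensor,
///         inputs: &[&Scirs2Tensor],
///         _ctx: &CustomOpContext,
///     ) -> Result<Vec<Scirs2Tensor>, String> {
///         // d/dx x² = 2x, chained with the upstream gradient.
///         Ok(vec![grad * &inputs[0].mapv(|v| 2.0 * v)])
///     }
/// }
/// ```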
pub trait CustomOp: Send + Sync {
    /// Unique name under which the op is registered.
    fn name(&self) -> &str;

    /// Number of input tensors the op expects (defaults to one).
    fn num_inputs(&self) -> usize {
        1
    }

    /// Computes the forward pass, optionally saving intermediates in `ctx`.
    fn forward(
        &self,
        inputs: &[&Scirs2Tensor],
        ctx: &mut CustomOpContext,
    ) -> Result<Scirs2Tensor, String>;

    /// Computes input gradients from the upstream gradient `grad`.
    ///
    /// Returns one gradient tensor per input, in input order.
    fn backward(
        &self,
        grad: &Scirs2Tensor,
        inputs: &[&Scirs2Tensor],
        ctx: &CustomOpContext,
    ) -> Result<Vec<Scirs2Tensor>, String>;

    /// Checks that the number of inputs matches [`Self::num_inputs`].
    fn validate_inputs(&self, inputs: &[&Scirs2Tensor]) -> Result<(), String> {
        if inputs.len() != self.num_inputs() {
            return Err(format!(
                "Expected {} inputs, got {}",
                self.num_inputs(),
                inputs.len()
            ));
        }
        Ok(())
    }

    /// Infers the output shape; defaults to the shape of the first input.
    fn infer_output_shape(&self, input_shapes: &[&[usize]]) -> Result<Vec<usize>, String> {
        if input_shapes.is_empty() {
            return Err("No input shapes provided".to_string());
        }
        Ok(input_shapes[0].to_vec())
    }
}

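/// A name-keyed registry of [`CustomOp`] implementations.
///
/// A minimal usage sketch (the `?` propagation assumes a caller returning
/// `TlBackendResult`, and tensors are built as in the tests below):
///
/// ```ignore
/// use scirs2_core::ndarray::ArrayD;
///
/// let registry = OpRegistry::with_standard_ops();
/// let mut ctx = CustomOpContext::new();
/// let x = ArrayD::from_shape_vec(vec![1], vec![0.5]).unwrap();
///
/// let y = registry.execute("gelu", &[&x], &mut ctx)?;
/// let grad = ArrayD::from_shape_vec(vec![1], vec![1.0]).unwrap();
/// let grads = registry.backward("gelu", &grad, &[&x], &ctx)?;
/// ```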
#[derive(Default)]
pub struct OpRegistry {
    ops: HashMap<String, Arc<dyn CustomOp>>,
}

impl OpRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            ops: HashMap::new(),
        }
    }

    /// Creates a registry pre-populated with the standard activation ops.
    pub fn with_standard_ops() -> Self {
        let mut registry = Self::new();

        registry.register(Box::new(SoftplusOp));
        registry.register(Box::new(LeakyReluOp::default()));
        registry.register(Box::new(EluOp::default()));
        registry.register(Box::new(SwishOp));
        registry.register(Box::new(MishOp));
        registry.register(Box::new(GeluOp));
        registry.register(Box::new(HardSigmoidOp));
        registry.register(Box::new(HardSwishOp));

        registry
    }

    /// Registers an op under its [`CustomOp::name`], replacing any existing entry.
    pub fn register(&mut self, op: Box<dyn CustomOp>) {
        self.ops.insert(op.name().to_string(), Arc::from(op));
    }

    /// Looks up an op by name.
    pub fn get(&self, name: &str) -> Option<Arc<dyn CustomOp>> {
        self.ops.get(name).cloned()
    }

    /// Returns `true` if an op with `name` is registered.
    pub fn contains(&self, name: &str) -> bool {
        self.ops.contains_key(name)
    }

    /// Lists the names of all registered ops (in arbitrary order).
    pub fn list_ops(&self) -> Vec<&str> {
        self.ops.keys().map(|s| s.as_str()).collect()
    }

    /// Returns the number of registered ops.
    pub fn len(&self) -> usize {
        self.ops.len()
    }

    /// Returns `true` if no ops are registered.
    pub fn is_empty(&self) -> bool {
        self.ops.is_empty()
    }

    /// Validates inputs and runs the forward pass of the op named `name`.
    pub fn execute(
        &self,
        name: &str,
        inputs: &[&Scirs2Tensor],
        ctx: &mut CustomOpContext,
    ) -> TlBackendResult<Scirs2Tensor> {
        let op = self
            .get(name)
            .ok_or_else(|| TlBackendError::unsupported(format!("Unknown operation: {}", name)))?;

        op.validate_inputs(inputs)
            .map_err(TlBackendError::execution)?;

        op.forward(inputs, ctx).map_err(TlBackendError::execution)
    }

    /// Runs the backward pass of the op named `name`.
    pub fn backward(
        &self,
        name: &str,
        grad: &Scirs2Tensor,
        inputs: &[&Scirs2Tensor],
        ctx: &CustomOpContext,
    ) -> TlBackendResult<Vec<Scirs2Tensor>> {
        let op = self
            .get(name)
            .ok_or_else(|| TlBackendError::unsupported(format!("Unknown operation: {}", name)))?;

        op.backward(grad, inputs, ctx)
            .map_err(TlBackendError::gradient)
    }
}

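/// Softplus activation: `softplus(x) = ln(1 + e^x)`, computed with
/// overflow-safe branches; its gradient is the logistic sigmoid `σ(x)`.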
pub struct SoftplusOp;

impl CustomOp for SoftplusOp {
    fn name(&self) -> &str {
        "softplus"
    }

    fn forward(
        &self,
        inputs: &[&Scirs2Tensor],
        _ctx: &mut CustomOpContext,
    ) -> Result<Scirs2Tensor, String> {
        let x = inputs[0];
        Ok(x.mapv(|v| {
            if v > 20.0 {
                // ln(1 + e^v) ≈ v for large v; avoids overflow in exp().
                v
            } else if v < -20.0 {
                // ln(1 + e^v) ≈ e^v for very negative v.
                v.exp()
            } else {
                (1.0 + v.exp()).ln()
            }
        }))
    }

    fn backward(
        &self,
        grad: &Scirs2Tensor,
        inputs: &[&Scirs2Tensor],
        _ctx: &CustomOpContext,
    ) -> Result<Vec<Scirs2Tensor>, String> {
        let x = inputs[0];
        let sigmoid = x.mapv(|v| 1.0 / (1.0 + (-v).exp()));
        Ok(vec![grad * &sigmoid])
    }
}

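/// Leaky ReLU: `f(x) = x` for `x > 0`, `α·x` otherwise.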
pub struct LeakyReluOp {
    /// Slope applied to negative inputs (default `0.01`).
    pub alpha: f64,
}

impl Default for LeakyReluOp {
    fn default() -> Self {
        Self { alpha: 0.01 }
    }
}

impl CustomOp for LeakyReluOp {
    fn name(&self) -> &str {
        "leaky_relu"
    }

    fn forward(
        &self,
        inputs: &[&Scirs2Tensor],
        _ctx: &mut CustomOpContext,
    ) -> Result<Scirs2Tensor, String> {
        let x = inputs[0];
        let alpha = self.alpha;
        Ok(x.mapv(|v| if v > 0.0 { v } else { alpha * v }))
    }

    fn backward(
        &self,
        grad: &Scirs2Tensor,
        inputs: &[&Scirs2Tensor],
        _ctx: &CustomOpContext,
    ) -> Result<Vec<Scirs2Tensor>, String> {
        let x = inputs[0];
        let alpha = self.alpha;
        let grad_input = scirs2_core::ndarray::Zip::from(grad)
            .and(x)
            .map_collect(|&g, &v| if v > 0.0 { g } else { alpha * g });
        Ok(vec![grad_input])
    }
}

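/// ELU: `f(x) = x` for `x > 0`, `α·(e^x − 1)` otherwise. When gradients are
/// required, the forward output is saved so backward can use
/// `f'(x) = f(x) + α` on the negative branch without recomputing `e^x`.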
pub struct EluOp {
    /// Scale of the negative branch (default `1.0`).
    pub alpha: f64,
}

impl Default for EluOp {
    fn default() -> Self {
        Self { alpha: 1.0 }
    }
}

impl CustomOp for EluOp {
    fn name(&self) -> &str {
        "elu"
    }

    fn forward(
        &self,
        inputs: &[&Scirs2Tensor],
        ctx: &mut CustomOpContext,
    ) -> Result<Scirs2Tensor, String> {
        let x = inputs[0];
        let alpha = self.alpha;
        let result = x.mapv(|v| if v > 0.0 { v } else { alpha * (v.exp() - 1.0) });

        if ctx.requires_grad {
            ctx.save_for_backward("output", result.clone());
        }

        Ok(result)
    }

    fn backward(
        &self,
        grad: &Scirs2Tensor,
        inputs: &[&Scirs2Tensor],
        ctx: &CustomOpContext,
    ) -> Result<Vec<Scirs2Tensor>, String> {
        let x = inputs[0];
        let alpha = self.alpha;

        let grad_input = if let Some(output) = ctx.get_saved("output") {
            // On the negative branch, f'(x) = α·e^x = f(x) + α.
            scirs2_core::ndarray::Zip::from(grad)
                .and(x)
                .and(output)
                .map_collect(|&g, &v, &o| if v > 0.0 { g } else { g * (o + alpha) })
        } else {
            // No saved output; recompute α·e^x directly.
            scirs2_core::ndarray::Zip::from(grad)
                .and(x)
                .map_collect(|&g, &v| if v > 0.0 { g } else { g * alpha * v.exp() })
        };

        Ok(vec![grad_input])
    }
}

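/// Swish (SiLU): `f(x) = x·σ(x)`, with gradient
/// `σ(x) + x·σ(x)·(1 − σ(x))`. The sigmoid is cached when gradients are
/// required.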
pub struct SwishOp;

impl CustomOp for SwishOp {
    fn name(&self) -> &str {
        "swish"
    }

    fn forward(
        &self,
        inputs: &[&Scirs2Tensor],
        ctx: &mut CustomOpContext,
    ) -> Result<Scirs2Tensor, String> {
        let x = inputs[0];
        let sigmoid = x.mapv(|v| 1.0 / (1.0 + (-v).exp()));
        let result = x * &sigmoid;

        if ctx.requires_grad {
            ctx.save_for_backward("sigmoid", sigmoid);
        }

        Ok(result)
    }

    fn backward(
        &self,
        grad: &Scirs2Tensor,
        inputs: &[&Scirs2Tensor],
        ctx: &CustomOpContext,
    ) -> Result<Vec<Scirs2Tensor>, String> {
        let x = inputs[0];

        // Reuse the cached sigmoid when available; otherwise recompute it.
        let sigmoid = if let Some(s) = ctx.get_saved("sigmoid") {
            s.clone()
        } else {
            x.mapv(|v| 1.0 / (1.0 + (-v).exp()))
        };

        let grad_input = scirs2_core::ndarray::Zip::from(grad)
            .and(x)
            .and(&sigmoid)
            .map_collect(|&g, &v, &s| g * (s + v * s * (1.0 - s)));

        Ok(vec![grad_input])
    }
}

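/// Mish: `f(x) = x·tanh(softplus(x))`, using the same overflow-safe
/// softplus branches as [`SoftplusOp`].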
pub struct MishOp;

impl CustomOp for MishOp {
    fn name(&self) -> &str {
        "mish"
    }

    fn forward(
        &self,
        inputs: &[&Scirs2Tensor],
        _ctx: &mut CustomOpContext,
    ) -> Result<Scirs2Tensor, String> {
        let x = inputs[0];
        Ok(x.mapv(|v| {
            let softplus = if v > 20.0 {
                v
            } else if v < -20.0 {
                v.exp()
            } else {
                (1.0 + v.exp()).ln()
            };
            v * softplus.tanh()
        }))
    }

    fn backward(
        &self,
        grad: &Scirs2Tensor,
        inputs: &[&Scirs2Tensor],
        _ctx: &CustomOpContext,
    ) -> Result<Vec<Scirs2Tensor>, String> {
        let x = inputs[0];
        let grad_input = scirs2_core::ndarray::Zip::from(grad)
            .and(x)
            .map_collect(|&g, &v| {
                // Closed-form derivative: mish'(x) = e^x·ω/δ², where
                // ω = 4(x + 1) + 4e^{2x} + e^{3x} + e^x(4x + 6)
                // and δ = e^{2x} + 2e^x + 2.
                let e = v.exp();
                let omega = 4.0 * (v + 1.0) + 4.0 * e * e + e * e * e + e * (4.0 * v + 6.0);
                let delta = 2.0 * e + e * e + 2.0;
                g * e * omega / (delta * delta)
            });

        Ok(vec![grad_input])
    }
}

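/// GELU with the tanh approximation:
/// `f(x) = 0.5·x·(1 + tanh(√(2/π)·(x + 0.044715·x³)))`.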
pub struct GeluOp;

impl CustomOp for GeluOp {
    fn name(&self) -> &str {
        "gelu"
    }

    fn forward(
        &self,
        inputs: &[&Scirs2Tensor],
        _ctx: &mut CustomOpContext,
    ) -> Result<Scirs2Tensor, String> {
        let x = inputs[0];
        let sqrt_2_over_pi = (2.0 / std::f64::consts::PI).sqrt();
        Ok(x.mapv(|v| {
            let inner = sqrt_2_over_pi * (v + 0.044715 * v * v * v);
            0.5 * v * (1.0 + inner.tanh())
        }))
    }

    fn backward(
        &self,
        grad: &Scirs2Tensor,
        inputs: &[&Scirs2Tensor],
        _ctx: &CustomOpContext,
    ) -> Result<Vec<Scirs2Tensor>, String> {
        let x = inputs[0];
        let sqrt_2_over_pi = (2.0 / std::f64::consts::PI).sqrt();

        let grad_input = scirs2_core::ndarray::Zip::from(grad)
            .and(x)
            .map_collect(|&g, &v| {
                let inner = sqrt_2_over_pi * (v + 0.044715 * v * v * v);
                let tanh_inner = inner.tanh();
                // sech²(u) = 1 − tanh²(u)
                let sech2 = 1.0 - tanh_inner * tanh_inner;
                let d_inner = sqrt_2_over_pi * (1.0 + 3.0 * 0.044715 * v * v);

                g * (0.5 * (1.0 + tanh_inner) + 0.5 * v * sech2 * d_inner)
            });

        Ok(vec![grad_input])
    }
}

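/// Hard sigmoid: `f(x) = clamp((x + 3)/6, 0, 1)`; gradient `1/6` on
/// `(−3, 3)` and `0` elsewhere.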
pub struct HardSigmoidOp;

impl CustomOp for HardSigmoidOp {
    fn name(&self) -> &str {
        "hard_sigmoid"
    }

    fn forward(
        &self,
        inputs: &[&Scirs2Tensor],
        _ctx: &mut CustomOpContext,
    ) -> Result<Scirs2Tensor, String> {
        let x = inputs[0];
        Ok(x.mapv(|v| ((v + 3.0) / 6.0).clamp(0.0, 1.0)))
    }

    fn backward(
        &self,
        grad: &Scirs2Tensor,
        inputs: &[&Scirs2Tensor],
        _ctx: &CustomOpContext,
    ) -> Result<Vec<Scirs2Tensor>, String> {
        let x = inputs[0];
        let grad_input = scirs2_core::ndarray::Zip::from(grad)
            .and(x)
            .map_collect(|&g, &v| if v > -3.0 && v < 3.0 { g / 6.0 } else { 0.0 });

        Ok(vec![grad_input])
    }
}

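/// Hard swish: `f(x) = x·clamp((x + 3)/6, 0, 1)`; gradient `0` for
/// `x ≤ −3`, `1` for `x ≥ 3`, and `x/3 + 1/2` in between.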
pub struct HardSwishOp;

impl CustomOp for HardSwishOp {
    fn name(&self) -> &str {
        "hard_swish"
    }

    fn forward(
        &self,
        inputs: &[&Scirs2Tensor],
        _ctx: &mut CustomOpContext,
    ) -> Result<Scirs2Tensor, String> {
        let x = inputs[0];
        Ok(x.mapv(|v| {
            let hard_sigmoid = ((v + 3.0) / 6.0).clamp(0.0, 1.0);
            v * hard_sigmoid
        }))
    }

    fn backward(
        &self,
        grad: &Scirs2Tensor,
        inputs: &[&Scirs2Tensor],
        _ctx: &CustomOpContext,
    ) -> Result<Vec<Scirs2Tensor>, String> {
        let x = inputs[0];
        let grad_input = scirs2_core::ndarray::Zip::from(grad)
            .and(x)
            .map_collect(|&g, &v| {
                if v <= -3.0 {
                    0.0
                } else if v >= 3.0 {
                    g
                } else {
                    g * (v / 3.0 + 0.5)
                }
            });

        Ok(vec![grad_input])
    }
}

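/// A two-input op built from element-wise closures: `forward_fn(a, b)`
/// computes the output and `backward_fn(a, b, g)` returns the pair of input
/// gradients `(∂L/∂a, ∂L/∂b)` given the upstream gradient `g`.
///
/// A sketch mirroring the `pow` op from the tests below:
///
/// ```ignore
/// let pow_op = BinaryCustomOp::new(
///     "pow",
///     |a, b| a.powf(b),
///     |a, b, g| (g * b * a.powf(b - 1.0), g * a.powf(b) * a.ln()),
/// );
/// registry.register(Box::new(pow_op));
/// ```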
pub struct BinaryCustomOp<F, G>
where
    F: Fn(f64, f64) -> f64 + Send + Sync,
    G: Fn(f64, f64, f64) -> (f64, f64) + Send + Sync,
{
    name: String,
    forward_fn: F,
    backward_fn: G,
}

impl<F, G> BinaryCustomOp<F, G>
where
    F: Fn(f64, f64) -> f64 + Send + Sync,
    G: Fn(f64, f64, f64) -> (f64, f64) + Send + Sync,
{
    /// Creates a binary op from element-wise forward and backward closures.
    pub fn new(name: impl Into<String>, forward_fn: F, backward_fn: G) -> Self {
        Self {
            name: name.into(),
            forward_fn,
            backward_fn,
        }
    }
}

impl<F, G> CustomOp for BinaryCustomOp<F, G>
where
    F: Fn(f64, f64) -> f64 + Send + Sync,
    G: Fn(f64, f64, f64) -> (f64, f64) + Send + Sync,
{
    fn name(&self) -> &str {
        &self.name
    }

    fn num_inputs(&self) -> usize {
        2
    }

    fn forward(
        &self,
        inputs: &[&Scirs2Tensor],
        _ctx: &mut CustomOpContext,
    ) -> Result<Scirs2Tensor, String> {
        let x = inputs[0];
        let y = inputs[1];

        if x.shape() != y.shape() {
            return Err(format!(
                "Shape mismatch: {:?} vs {:?}",
                x.shape(),
                y.shape()
            ));
        }

        let result = scirs2_core::ndarray::Zip::from(x)
            .and(y)
            .map_collect(|&a, &b| (self.forward_fn)(a, b));

        Ok(result)
    }

    fn backward(
        &self,
        grad: &Scirs2Tensor,
        inputs: &[&Scirs2Tensor],
        _ctx: &CustomOpContext,
    ) -> Result<Vec<Scirs2Tensor>, String> {
        let x = inputs[0];
        let y = inputs[1];

        let mut grad_x = Scirs2Tensor::zeros(x.raw_dim());
        let mut grad_y = Scirs2Tensor::zeros(y.raw_dim());

        scirs2_core::ndarray::Zip::from(&mut grad_x)
            .and(&mut grad_y)
            .and(grad)
            .and(x)
            .and(y)
            .for_each(|gx, gy, &g, &a, &b| {
                let (dx, dy) = (self.backward_fn)(a, b, g);
                *gx = dx;
                *gy = dy;
            });

        Ok(vec![grad_x, grad_y])
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::ArrayD;

    fn create_tensor(data: Vec<f64>, shape: Vec<usize>) -> Scirs2Tensor {
        ArrayD::from_shape_vec(shape, data).unwrap()
    }

    #[test]
    fn test_op_registry_basic() {
        let mut registry = OpRegistry::new();
        assert!(registry.is_empty());

        registry.register(Box::new(SoftplusOp));
        assert_eq!(registry.len(), 1);
        assert!(registry.contains("softplus"));
        assert!(!registry.contains("unknown"));
    }

    #[test]
    fn test_op_registry_with_standard_ops() {
        let registry = OpRegistry::with_standard_ops();
        assert!(registry.contains("softplus"));
        assert!(registry.contains("leaky_relu"));
        assert!(registry.contains("elu"));
        assert!(registry.contains("swish"));
        assert!(registry.contains("mish"));
        assert!(registry.contains("gelu"));
        assert!(registry.contains("hard_sigmoid"));
        assert!(registry.contains("hard_swish"));
    }

    #[test]
    fn test_softplus_forward() {
        let registry = OpRegistry::with_standard_ops();
        let tensor = create_tensor(vec![-1.0, 0.0, 1.0], vec![3]);
        let mut ctx = CustomOpContext::new();

        let result = registry.execute("softplus", &[&tensor], &mut ctx).unwrap();

        // softplus(-1) ≈ 0.3133, softplus(0) = ln 2, softplus(1) ≈ 1.3133
        assert!(result[[0]] > 0.3 && result[[0]] < 0.35);
        assert!((result[[1]] - std::f64::consts::LN_2).abs() < 0.01);
        assert!(result[[2]] > 1.3 && result[[2]] < 1.35);
    }

    #[test]
    fn test_softplus_backward() {
        let registry = OpRegistry::with_standard_ops();
        let tensor = create_tensor(vec![0.0], vec![1]);
        let grad = create_tensor(vec![1.0], vec![1]);
        let ctx = CustomOpContext::new();

        let grads = registry
            .backward("softplus", &grad, &[&tensor], &ctx)
            .unwrap();

        // d/dx softplus(0) = σ(0) = 0.5
        assert!((grads[0][[0]] - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_leaky_relu_forward() {
        let registry = OpRegistry::with_standard_ops();
        let tensor = create_tensor(vec![-2.0, 0.0, 2.0], vec![3]);
        let mut ctx = CustomOpContext::new();

        let result = registry
            .execute("leaky_relu", &[&tensor], &mut ctx)
            .unwrap();

        // -2.0 * 0.01 = -0.02 with the default alpha
        assert!((result[[0]] - (-0.02)).abs() < 0.001);
        assert_eq!(result[[1]], 0.0);
        assert_eq!(result[[2]], 2.0);
    }

    #[test]
    fn test_elu_forward() {
        let registry = OpRegistry::with_standard_ops();
        let tensor = create_tensor(vec![-1.0, 0.0, 1.0], vec![3]);
        let mut ctx = CustomOpContext::with_grad();

        let result = registry.execute("elu", &[&tensor], &mut ctx).unwrap();

        // e^{-1} - 1 ≈ -0.632
        assert!((result[[0]] - (-0.632)).abs() < 0.01);
        assert_eq!(result[[1]], 0.0);
        assert_eq!(result[[2]], 1.0);
    }

    #[test]
    fn test_swish_forward() {
        let registry = OpRegistry::with_standard_ops();
        let tensor = create_tensor(vec![0.0], vec![1]);
        let mut ctx = CustomOpContext::new();

        let result = registry.execute("swish", &[&tensor], &mut ctx).unwrap();

        // swish(0) = 0 · σ(0) = 0
        assert_eq!(result[[0]], 0.0);
    }

    #[test]
    fn test_gelu_forward() {
        let registry = OpRegistry::with_standard_ops();
        let tensor = create_tensor(vec![-1.0, 0.0, 1.0], vec![3]);
        let mut ctx = CustomOpContext::new();

        let result = registry.execute("gelu", &[&tensor], &mut ctx).unwrap();

        assert!((result[[1]]).abs() < 0.01);
        assert!(result[[0]] < 0.0);
        assert!(result[[2]] > 0.5);
    }

    #[test]
    fn test_hard_sigmoid_forward() {
        let registry = OpRegistry::with_standard_ops();
        let tensor = create_tensor(vec![-4.0, 0.0, 4.0], vec![3]);
        let mut ctx = CustomOpContext::new();

        let result = registry
            .execute("hard_sigmoid", &[&tensor], &mut ctx)
            .unwrap();

        assert_eq!(result[[0]], 0.0);
        assert_eq!(result[[1]], 0.5);
        assert_eq!(result[[2]], 1.0);
    }

    #[test]
    fn test_hard_swish_forward() {
        let registry = OpRegistry::with_standard_ops();
        let tensor = create_tensor(vec![-4.0, 0.0, 4.0], vec![3]);
        let mut ctx = CustomOpContext::new();

        let result = registry
            .execute("hard_swish", &[&tensor], &mut ctx)
            .unwrap();

        assert_eq!(result[[0]], 0.0);
        assert_eq!(result[[1]], 0.0);
        assert_eq!(result[[2]], 4.0);
    }

    #[test]
    fn test_custom_op_context() {
        let mut ctx = CustomOpContext::with_grad();
        assert!(ctx.requires_grad);

        let tensor = create_tensor(vec![1.0, 2.0], vec![2]);
        ctx.save_for_backward("test", tensor.clone());

        let saved = ctx.get_saved("test").unwrap();
        assert_eq!(saved[[0]], 1.0);
        assert_eq!(saved[[1]], 2.0);

        ctx.set_metadata("key", "value");
        assert_eq!(ctx.get_metadata("key"), Some(&"value".to_string()));
    }

    #[test]
    fn test_binary_custom_op() {
        let pow_op = BinaryCustomOp::new(
            "pow",
            |a, b| a.powf(b),
            |a, b, g| {
                let da = g * b * a.powf(b - 1.0);
                let db = g * a.powf(b) * a.ln();
                (da, db)
            },
        );

        let mut registry = OpRegistry::new();
        registry.register(Box::new(pow_op));

        let x = create_tensor(vec![2.0, 3.0], vec![2]);
        let y = create_tensor(vec![3.0, 2.0], vec![2]);
        let mut ctx = CustomOpContext::new();

        let result = registry.execute("pow", &[&x, &y], &mut ctx).unwrap();

        assert_eq!(result[[0]], 8.0); // 2^3
        assert_eq!(result[[1]], 9.0); // 3^2
    }

    #[test]
    fn test_validate_inputs() {
        let registry = OpRegistry::with_standard_ops();
        let tensor = create_tensor(vec![1.0], vec![1]);
        let mut ctx = CustomOpContext::new();

        // Correct arity succeeds.
        let result = registry.execute("softplus", &[&tensor], &mut ctx);
        assert!(result.is_ok());

        // Wrong arity is rejected by validate_inputs.
        let result = registry.execute("softplus", &[&tensor, &tensor], &mut ctx);
        assert!(result.is_err());
    }

    #[test]
    fn test_list_ops() {
        let registry = OpRegistry::with_standard_ops();
        let ops = registry.list_ops();

        assert!(ops.contains(&"softplus"));
        assert!(ops.contains(&"gelu"));
    }

    #[test]
    fn test_unknown_operation() {
        let registry = OpRegistry::new();
        let tensor = create_tensor(vec![1.0], vec![1]);
        let mut ctx = CustomOpContext::new();

        let result = registry.execute("unknown", &[&tensor], &mut ctx);
        assert!(result.is_err());
    }

    #[test]
    fn test_mish_forward() {
        let registry = OpRegistry::with_standard_ops();
        let tensor = create_tensor(vec![0.0], vec![1]);
        let mut ctx = CustomOpContext::new();

        let result = registry.execute("mish", &[&tensor], &mut ctx).unwrap();

        // mish(0) = 0 · tanh(ln 2) = 0
        assert!(result[[0]].abs() < 0.01);
    }
}