1use crate::{core_ops::Tensor, TensorElement};
8use scirs2_core::numeric::FromPrimitive;
9use std::any::{Any, TypeId};
10use std::collections::HashMap;
11use std::sync::{Arc, RwLock};
12use torsh_core::error::{Result, TorshError};
13
14pub trait CustomOperation<T: TensorElement>: Send + Sync {
19 fn name(&self) -> &str;
21
22 fn description(&self) -> &str;
24
25 fn forward(&self, inputs: &[Tensor<T>], params: &OperationParams) -> Result<Vec<Tensor<T>>>;
34
35 fn backward(
46 &self,
47 grad_outputs: &[Tensor<T>],
48 inputs: &[Tensor<T>],
49 _outputs: &[Tensor<T>],
50 _params: &OperationParams,
51 ) -> Result<Vec<Option<Tensor<T>>>> {
52 let _ = grad_outputs.is_empty(); Ok(vec![None; inputs.len()])
58 }
59
60 fn validate_inputs(&self, inputs: &[Tensor<T>], _params: &OperationParams) -> Result<()> {
69 if inputs.is_empty() {
71 return Err(torsh_core::error::TorshError::InvalidShape(
72 "Operation requires at least one input tensor".to_string(),
73 ));
74 }
75
76 for (idx, input) in inputs.iter().enumerate() {
78 let _ = (idx, input.shape.is_empty()); }
80
81 Ok(())
82 }
83
84 fn output_shapes(
93 &self,
94 input_shapes: &[Vec<usize>],
95 params: &OperationParams,
96 ) -> Result<Vec<Vec<usize>>>;
97
98 fn supports_autograd(&self) -> bool {
100 true }
102
103 fn num_inputs(&self) -> usize;
105
106 fn num_outputs(&self) -> usize;
108}
109
/// Typed parameter bag passed to custom operations.
///
/// Each map stores parameters of one primitive kind, keyed by name; values of
/// different kinds may share the same key without clashing.
#[derive(Debug, Clone)]
pub struct OperationParams {
    // String-valued parameters (e.g. mode names).
    pub strings: HashMap<String, String>,
    // Integer-valued parameters (e.g. axis indices).
    pub integers: HashMap<String, i64>,
    // Floating-point parameters (e.g. scale factors).
    pub floats: HashMap<String, f64>,
    // Boolean flags.
    pub booleans: HashMap<String, bool>,
    // Numeric vector parameters (e.g. weights).
    pub vectors: HashMap<String, Vec<f64>>,
    // Shape (dimension-list) parameters.
    pub shapes: HashMap<String, Vec<usize>>,
}
126
127impl OperationParams {
128 pub fn new() -> Self {
130 Self {
131 strings: HashMap::new(),
132 integers: HashMap::new(),
133 floats: HashMap::new(),
134 booleans: HashMap::new(),
135 vectors: HashMap::new(),
136 shapes: HashMap::new(),
137 }
138 }
139
140 pub fn with_string(mut self, key: &str, value: &str) -> Self {
142 self.strings.insert(key.to_string(), value.to_string());
143 self
144 }
145
146 pub fn with_int(mut self, key: &str, value: i64) -> Self {
148 self.integers.insert(key.to_string(), value);
149 self
150 }
151
152 pub fn with_float(mut self, key: &str, value: f64) -> Self {
154 self.floats.insert(key.to_string(), value);
155 self
156 }
157
158 pub fn with_bool(mut self, key: &str, value: bool) -> Self {
160 self.booleans.insert(key.to_string(), value);
161 self
162 }
163
164 pub fn with_vector(mut self, key: &str, value: Vec<f64>) -> Self {
166 self.vectors.insert(key.to_string(), value);
167 self
168 }
169
170 pub fn with_shape(mut self, key: &str, value: Vec<usize>) -> Self {
172 self.shapes.insert(key.to_string(), value);
173 self
174 }
175
176 pub fn get_string(&self, key: &str) -> Option<&String> {
178 self.strings.get(key)
179 }
180
181 pub fn get_int(&self, key: &str) -> Option<i64> {
183 self.integers.get(key).copied()
184 }
185
186 pub fn get_float(&self, key: &str) -> Option<f64> {
188 self.floats.get(key).copied()
189 }
190
191 pub fn get_bool(&self, key: &str) -> Option<bool> {
193 self.booleans.get(key).copied()
194 }
195
196 pub fn get_vector(&self, key: &str) -> Option<&Vec<f64>> {
198 self.vectors.get(key)
199 }
200
201 pub fn get_shape(&self, key: &str) -> Option<&Vec<usize>> {
203 self.shapes.get(key)
204 }
205}
206
207impl Default for OperationParams {
208 fn default() -> Self {
209 Self::new()
210 }
211}
212
/// Descriptive information recorded when an operation is registered.
#[derive(Debug, Clone)]
pub struct OperationMetadata {
    // Registry name of the operation.
    pub name: String,
    // Human-readable description.
    pub description: String,
    // Declared number of input tensors.
    pub num_inputs: usize,
    // Declared number of output tensors.
    pub num_outputs: usize,
    // Whether the operation participates in autograd.
    pub supports_autograd: bool,
    // `TypeId` of the tensor element type the operation was registered for.
    pub data_type: TypeId,
    // Version string supplied at registration.
    pub version: String,
    // Optional author supplied at registration.
    pub author: Option<String>,
    // Free-form tags supplied at registration.
    pub tags: Vec<String>,
}
235
/// Thread-safe registry of custom operations, keyed by element type + name.
///
/// Operations are stored type-erased: each entry is an `Arc<dyn Any>` that
/// wraps an `Arc<dyn CustomOperation<T>>`, so a single registry can hold
/// operations for any tensor element type and recover them via downcast.
pub struct CustomOperationRegistry {
    // Keyed by (element TypeId, operation name); values are the erased ops.
    operations: RwLock<HashMap<(TypeId, String), Arc<dyn Any + Send + Sync>>>,
    // Parallel map of descriptive metadata, same keys as `operations`.
    metadata: RwLock<HashMap<(TypeId, String), OperationMetadata>>,
}
246
247impl CustomOperationRegistry {
248 pub fn new() -> Self {
250 Self {
251 operations: RwLock::new(HashMap::new()),
252 metadata: RwLock::new(HashMap::new()),
253 }
254 }
255
256 pub fn register<T: TensorElement + 'static>(
267 &self,
268 operation: Box<dyn CustomOperation<T>>,
269 version: &str,
270 author: Option<String>,
271 tags: Vec<String>,
272 ) -> Result<()> {
273 let type_id = TypeId::of::<T>();
274 let name = operation.name().to_string();
275 let key = (type_id, name.clone());
276
277 let metadata = OperationMetadata {
279 name: name.clone(),
280 description: operation.description().to_string(),
281 num_inputs: operation.num_inputs(),
282 num_outputs: operation.num_outputs(),
283 supports_autograd: operation.supports_autograd(),
284 data_type: type_id,
285 version: version.to_string(),
286 author,
287 tags,
288 };
289
290 {
292 let mut ops = self
293 .operations
294 .write()
295 .expect("lock should not be poisoned");
296 let mut meta = self.metadata.write().expect("lock should not be poisoned");
297
298 if ops.contains_key(&key) {
299 return Err(TorshError::InvalidArgument(format!(
300 "Operation '{}' for type {:?} is already registered",
301 name, type_id
302 )));
303 }
304
305 let arc_op: Arc<dyn CustomOperation<T>> = Arc::from(operation);
307 let boxed_any: Arc<dyn Any + Send + Sync> = Arc::new(arc_op);
308 ops.insert(key.clone(), boxed_any);
309 meta.insert(key, metadata);
310 }
311
312 Ok(())
313 }
314
315 pub fn get<T: TensorElement + 'static>(
323 &self,
324 name: &str,
325 ) -> Option<Arc<dyn CustomOperation<T>>> {
326 let type_id = TypeId::of::<T>();
327 let key = (type_id, name.to_string());
328
329 let ops = self.operations.read().expect("lock should not be poisoned");
330 ops.get(&key).and_then(|arc_any| {
331 arc_any
333 .downcast_ref::<Arc<dyn CustomOperation<T>>>()
334 .map(|arc_op| Arc::clone(arc_op))
335 })
336 }
337
338 pub fn get_metadata<T: TensorElement + 'static>(
340 &self,
341 name: &str,
342 ) -> Option<OperationMetadata> {
343 let type_id = TypeId::of::<T>();
344 let key = (type_id, name.to_string());
345
346 let meta = self.metadata.read().expect("lock should not be poisoned");
347 meta.get(&key).cloned()
348 }
349
350 pub fn list_operations<T: TensorElement + 'static>(&self) -> Vec<String> {
352 let type_id = TypeId::of::<T>();
353 let meta = self.metadata.read().expect("lock should not be poisoned");
354
355 meta.keys()
356 .filter(|(tid, _)| *tid == type_id)
357 .map(|(_, name)| name.clone())
358 .collect()
359 }
360
361 pub fn unregister<T: TensorElement + 'static>(&self, name: &str) -> Result<()> {
363 let type_id = TypeId::of::<T>();
364 let key = (type_id, name.to_string());
365
366 let mut ops = self
367 .operations
368 .write()
369 .expect("lock should not be poisoned");
370 let mut meta = self.metadata.write().expect("lock should not be poisoned");
371
372 if ops.remove(&key).is_none() {
373 return Err(TorshError::InvalidArgument(format!(
374 "Operation '{}' for type {:?} is not registered",
375 name, type_id
376 )));
377 }
378
379 meta.remove(&key);
380 Ok(())
381 }
382
383 pub fn is_registered<T: TensorElement + 'static>(&self, name: &str) -> bool {
385 let type_id = TypeId::of::<T>();
386 let key = (type_id, name.to_string());
387
388 let ops = self.operations.read().expect("lock should not be poisoned");
389 ops.contains_key(&key)
390 }
391
392 pub fn count(&self) -> usize {
394 let ops = self.operations.read().expect("lock should not be poisoned");
395 ops.len()
396 }
397
398 pub fn clear(&self) {
400 let mut ops = self
401 .operations
402 .write()
403 .expect("lock should not be poisoned");
404 let mut meta = self.metadata.write().expect("lock should not be poisoned");
405 ops.clear();
406 meta.clear();
407 }
408}
409
410impl Default for CustomOperationRegistry {
411 fn default() -> Self {
412 Self::new()
413 }
414}
415
// Process-wide registry instance, initialized lazily on first access.
static GLOBAL_REGISTRY: std::sync::LazyLock<CustomOperationRegistry> =
    std::sync::LazyLock::new(CustomOperationRegistry::new);

/// Returns the shared, process-wide custom-operation registry.
pub fn global_registry() -> &'static CustomOperationRegistry {
    &GLOBAL_REGISTRY
}
424
/// Extension trait that lets custom operations be invoked directly on a tensor.
pub trait TensorCustomOps<T: TensorElement> {
    /// Applies the named operation from the global registry; `self` becomes
    /// the first input, followed by `other_inputs`.
    fn apply_custom_op(
        &self,
        op_name: &str,
        other_inputs: &[&Tensor<T>],
        params: &OperationParams,
    ) -> Result<Vec<Tensor<T>>>;

    /// Same as [`TensorCustomOps::apply_custom_op`], but resolves the
    /// operation from an explicit `registry` instead of the global one.
    fn apply_custom_op_with_registry(
        &self,
        registry: &CustomOperationRegistry,
        op_name: &str,
        other_inputs: &[&Tensor<T>],
        params: &OperationParams,
    ) -> Result<Vec<Tensor<T>>>;
}
452
453impl<T: TensorElement + 'static> TensorCustomOps<T> for Tensor<T> {
454 fn apply_custom_op(
455 &self,
456 op_name: &str,
457 other_inputs: &[&Tensor<T>],
458 params: &OperationParams,
459 ) -> Result<Vec<Tensor<T>>> {
460 self.apply_custom_op_with_registry(global_registry(), op_name, other_inputs, params)
461 }
462
463 fn apply_custom_op_with_registry(
464 &self,
465 registry: &CustomOperationRegistry,
466 op_name: &str,
467 other_inputs: &[&Tensor<T>],
468 params: &OperationParams,
469 ) -> Result<Vec<Tensor<T>>> {
470 let operation = registry.get::<T>(op_name).ok_or_else(|| {
472 TorshError::InvalidArgument(format!(
473 "Custom operation '{}' not found for type",
474 op_name
475 ))
476 })?;
477
478 let mut inputs = vec![self.clone()];
480 inputs.extend(other_inputs.iter().map(|&t| t.clone()));
481
482 operation.validate_inputs(&inputs, params)?;
484
485 if inputs.len() != operation.num_inputs() {
487 return Err(TorshError::InvalidArgument(format!(
488 "Operation '{}' expects {} inputs, got {}",
489 op_name,
490 operation.num_inputs(),
491 inputs.len()
492 )));
493 }
494
495 let outputs = operation.forward(&inputs, params)?;
497
498 if outputs.len() != operation.num_outputs() {
500 return Err(TorshError::InvalidArgument(format!(
501 "Operation '{}' produced {} outputs, expected {}",
502 op_name,
503 outputs.len(),
504 operation.num_outputs()
505 )));
506 }
507
508 Ok(outputs)
509 }
510}
511
/// Built-in example operation: multiplies every element by a `"scale"` factor.
pub struct ScaleOperation;
516
517impl<T: TensorElement + Copy + std::ops::Mul<Output = T> + num_traits::FromPrimitive>
518 CustomOperation<T> for ScaleOperation
519{
520 fn name(&self) -> &str {
521 "scale"
522 }
523
524 fn description(&self) -> &str {
525 "Scales tensor elements by a constant factor"
526 }
527
528 fn forward(&self, inputs: &[Tensor<T>], params: &OperationParams) -> Result<Vec<Tensor<T>>> {
529 if inputs.len() != 1 {
530 return Err(TorshError::InvalidArgument(
531 "Scale operation requires exactly 1 input".to_string(),
532 ));
533 }
534
535 let scale = params.get_float("scale").unwrap_or(1.0);
536 let scale_val = <T as FromPrimitive>::from_f64(scale).ok_or_else(|| {
537 TorshError::InvalidArgument("Cannot convert scale factor to tensor type".to_string())
538 })?;
539
540 let result = inputs[0].mul_scalar(scale_val)?;
541 Ok(vec![result])
542 }
543
544 fn backward(
545 &self,
546 grad_outputs: &[Tensor<T>],
547 _inputs: &[Tensor<T>],
548 _outputs: &[Tensor<T>],
549 params: &OperationParams,
550 ) -> Result<Vec<Option<Tensor<T>>>> {
551 let scale = params.get_float("scale").unwrap_or(1.0);
552 let scale_val = <T as FromPrimitive>::from_f64(scale).ok_or_else(|| {
553 TorshError::InvalidArgument("Cannot convert scale factor to tensor type".to_string())
554 })?;
555
556 let grad_input = grad_outputs[0].mul_scalar(scale_val)?;
557 Ok(vec![Some(grad_input)])
558 }
559
560 fn output_shapes(
561 &self,
562 input_shapes: &[Vec<usize>],
563 _params: &OperationParams,
564 ) -> Result<Vec<Vec<usize>>> {
565 if input_shapes.len() != 1 {
566 return Err(TorshError::InvalidArgument(
567 "Scale operation requires exactly 1 input".to_string(),
568 ));
569 }
570 Ok(vec![input_shapes[0].clone()])
571 }
572
573 fn num_inputs(&self) -> usize {
574 1
575 }
576
577 fn num_outputs(&self) -> usize {
578 1
579 }
580}
581
/// Built-in example operation: concatenates tensors along an `"axis"` parameter.
pub struct ConcatOperation;
584
impl<T: TensorElement + Copy> CustomOperation<T> for ConcatOperation {
    fn name(&self) -> &str {
        "concat"
    }

    fn description(&self) -> &str {
        "Concatenates tensors along a specified axis"
    }

    /// Concatenates all inputs along the `"axis"` parameter (default 0).
    ///
    /// NOTE(review): `axis` is read as i64, cast to usize, then cast again to
    /// i32 for `Tensor::cat`; a negative axis wraps through the usize cast,
    /// so its handling here differs from `validate_inputs`/`backward` —
    /// confirm intended behavior for negative axes.
    fn forward(&self, inputs: &[Tensor<T>], params: &OperationParams) -> Result<Vec<Tensor<T>>> {
        if inputs.len() < 2 {
            return Err(TorshError::InvalidArgument(
                "Concat operation requires at least 2 inputs".to_string(),
            ));
        }

        let axis = params.get_int("axis").unwrap_or(0) as usize;

        let input_refs: Vec<&Tensor<T>> = inputs.iter().collect();
        let result = Tensor::cat(&input_refs, axis as i32)?;
        Ok(vec![result])
    }

    /// Splits the single output gradient back into per-input slices along
    /// `axis`, each sized by the corresponding input's extent on that axis.
    fn backward(
        &self,
        grad_outputs: &[Tensor<T>],
        inputs: &[Tensor<T>],
        _outputs: &[Tensor<T>],
        params: &OperationParams,
    ) -> Result<Vec<Option<Tensor<T>>>> {
        let axis = params.get_int("axis").unwrap_or(0) as usize;
        let grad_output = &grad_outputs[0];

        // Size of each input along the concat axis determines its slice.
        let mut split_sizes = Vec::new();
        for input in inputs {
            split_sizes.push(input.shape().dims()[axis]);
        }

        // Walk the output gradient, carving out consecutive [start, end)
        // ranges along `axis`, one per input.
        let mut grad_inputs = Vec::new();
        let mut start = 0;
        for &size in &split_sizes {
            let end = start + size;
            let slice = grad_output.slice_tensor(axis, start, end)?;
            grad_inputs.push(Some(slice));
            start = end;
        }
        Ok(grad_inputs)
    }

    /// Output shape is the first input's shape with the `axis` dimension set
    /// to the sum of all inputs' sizes along that axis.
    fn output_shapes(
        &self,
        input_shapes: &[Vec<usize>],
        params: &OperationParams,
    ) -> Result<Vec<Vec<usize>>> {
        if input_shapes.len() < 2 {
            return Err(TorshError::InvalidArgument(
                "Concat operation requires at least 2 inputs".to_string(),
            ));
        }

        let axis = params.get_int("axis").unwrap_or(0) as usize;
        let mut output_shape = input_shapes[0].clone();

        if axis >= output_shape.len() {
            return Err(TorshError::InvalidArgument(format!(
                "Concat axis {} out of bounds for {} dimensions",
                axis,
                output_shape.len()
            )));
        }

        // Accumulate the concat-axis extent; all other dims must match.
        let mut total_size = output_shape[axis];
        for shape in &input_shapes[1..] {
            if shape.len() != output_shape.len() {
                return Err(TorshError::InvalidArgument(
                    "All tensors must have the same number of dimensions".to_string(),
                ));
            }

            for (i, (&dim1, &dim2)) in output_shape.iter().zip(shape.iter()).enumerate() {
                if i != axis && dim1 != dim2 {
                    return Err(TorshError::InvalidArgument(format!(
                        "Dimension {} mismatch: {} vs {}",
                        i, dim1, dim2
                    )));
                }
            }

            total_size += shape[axis];
        }

        output_shape[axis] = total_size;
        Ok(vec![output_shape])
    }

    /// Declared arity.
    ///
    /// NOTE(review): `forward` accepts two or more inputs, but the apply path
    /// compares against this value with exact equality, so more than 2 inputs
    /// are rejected before `forward` runs — confirm intent.
    fn num_inputs(&self) -> usize {
        2
    }

    fn num_outputs(&self) -> usize {
        1
    }

    /// Checks arity, axis bounds, and that all non-axis dimensions agree.
    fn validate_inputs(&self, inputs: &[Tensor<T>], params: &OperationParams) -> Result<()> {
        if inputs.len() < 2 {
            return Err(TorshError::InvalidArgument(
                "Concat operation requires at least 2 inputs".to_string(),
            ));
        }

        let axis = params.get_int("axis").unwrap_or(0) as usize;
        let first_tensor_shape = inputs[0].shape();
        let first_shape = first_tensor_shape.dims();

        if axis >= first_shape.len() {
            return Err(TorshError::InvalidArgument(format!(
                "Concat axis {} out of bounds for {} dimensions",
                axis,
                first_shape.len()
            )));
        }

        // Every subsequent tensor must match the first in rank and in all
        // dimensions other than the concat axis.
        for (i, tensor) in inputs.iter().enumerate().skip(1) {
            let tensor_shape = tensor.shape();
            let shape = tensor_shape.dims();
            if shape.len() != first_shape.len() {
                return Err(TorshError::InvalidArgument(format!(
                    "Tensor {} has {} dimensions, expected {}",
                    i,
                    shape.len(),
                    first_shape.len()
                )));
            }

            for (dim_idx, (&dim1, &dim2)) in first_shape.iter().zip(shape.iter()).enumerate() {
                if dim_idx != axis && dim1 != dim2 {
                    return Err(TorshError::InvalidArgument(format!(
                        "Tensor {} dimension {} mismatch: {} vs {}",
                        i, dim_idx, dim1, dim2
                    )));
                }
            }
        }

        Ok(())
    }
}
739
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::device::DeviceType;

    /// Round-trips every parameter kind through the builder and getters.
    #[test]
    fn test_operation_params() {
        let params = OperationParams::new()
            .with_string("mode", "linear")
            .with_int("axis", 1)
            .with_float("scale", 2.5)
            .with_bool("inplace", false)
            .with_vector("weights", vec![1.0, 2.0, 3.0])
            .with_shape("target_shape", vec![10, 20]);

        assert_eq!(params.get_string("mode"), Some(&"linear".to_string()));
        assert_eq!(params.get_int("axis"), Some(1));
        assert_eq!(params.get_float("scale"), Some(2.5));
        assert_eq!(params.get_bool("inplace"), Some(false));
        assert_eq!(params.get_vector("weights"), Some(&vec![1.0, 2.0, 3.0]));
        assert_eq!(params.get_shape("target_shape"), Some(&vec![10, 20]));

        assert_eq!(params.get_string("nonexistent"), None);
    }

    /// Exercises register / metadata / list / unregister on a local registry.
    #[test]
    fn test_registry_operations() {
        let registry = CustomOperationRegistry::new();

        let scale_op = Box::new(ScaleOperation);
        registry
            .register::<f32>(
                scale_op,
                "1.0.0",
                Some("Test".to_string()),
                vec!["math".to_string()],
            )
            .expect("registration should succeed");

        assert!(registry.is_registered::<f32>("scale"));
        assert!(!registry.is_registered::<f32>("nonexistent"));

        let metadata = registry
            .get_metadata::<f32>("scale")
            .expect("metadata retrieval should succeed");
        assert_eq!(metadata.name, "scale");
        assert_eq!(
            metadata.description,
            "Scales tensor elements by a constant factor"
        );
        assert_eq!(metadata.num_inputs, 1);
        assert_eq!(metadata.num_outputs, 1);
        assert_eq!(metadata.version, "1.0.0");
        assert_eq!(metadata.author, Some("Test".to_string()));
        assert_eq!(metadata.tags, vec!["math".to_string()]);

        let ops = registry.list_operations::<f32>();
        assert_eq!(ops, vec!["scale".to_string()]);

        registry
            .unregister::<f32>("scale")
            .expect("unregister should succeed");
        assert!(!registry.is_registered::<f32>("scale"));
    }

    /// End-to-end: scale op doubles every element.
    #[test]
    fn test_scale_operation() {
        let registry = CustomOperationRegistry::new();
        let scale_op = Box::new(ScaleOperation);
        registry
            .register::<f32>(scale_op, "1.0.0", None, vec![])
            // Fixed copy-pasted message: this is a registration, not unregister.
            .expect("registration should succeed");

        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        let tensor = Tensor::from_data(data, vec![2, 2], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        let params = OperationParams::new().with_float("scale", 2.0);
        let results = tensor
            .apply_custom_op_with_registry(&registry, "scale", &[], &params)
            // Fixed copy-pasted message: this applies the op, not tensor creation.
            .expect("custom op should succeed");

        assert_eq!(results.len(), 1);
        let result = &results[0];
        let expected_data = vec![2.0f32, 4.0, 6.0, 8.0];
        assert_eq!(
            result.data().expect("data retrieval should succeed"),
            expected_data
        );
    }

    /// End-to-end: concat of two 1-D tensors along axis 0.
    #[test]
    fn test_concat_operation() {
        let registry = CustomOperationRegistry::new();
        let concat_op = Box::new(ConcatOperation);
        registry
            .register::<f32>(concat_op, "1.0.0", None, vec![])
            .expect("registration should succeed");

        let data1 = vec![1.0f32, 2.0];
        let tensor1 = Tensor::from_data(data1, vec![2], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        let data2 = vec![3.0f32, 4.0];
        let tensor2 = Tensor::from_data(data2, vec![2], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        let params = OperationParams::new().with_int("axis", 0);
        let results = tensor1
            .apply_custom_op_with_registry(&registry, "concat", &[&tensor2], &params)
            // Fixed copy-pasted message: this applies the op, not tensor creation.
            .expect("custom op should succeed");

        assert_eq!(results.len(), 1);
        let result = &results[0];
        assert_eq!(result.shape().dims(), &[4]);
        let expected_data = vec![1.0f32, 2.0, 3.0, 4.0];
        assert_eq!(
            result.data().expect("data retrieval should succeed"),
            expected_data
        );
    }

    /// Mismatched ranks must be rejected by concat's validation.
    #[test]
    fn test_operation_validation() {
        let registry = CustomOperationRegistry::new();
        let concat_op = Box::new(ConcatOperation);
        registry
            .register::<f32>(concat_op, "1.0.0", None, vec![])
            .expect("registration should succeed");

        let data1 = vec![1.0f32, 2.0];
        let tensor1 = Tensor::from_data(data1, vec![2], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        let data2 = vec![3.0f32, 4.0, 5.0, 6.0];
        let tensor2 = Tensor::from_data(data2, vec![2, 2], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        let params = OperationParams::new().with_int("axis", 0);
        let result =
            tensor1.apply_custom_op_with_registry(&registry, "concat", &[&tensor2], &params);
        assert!(result.is_err());
    }

    /// `output_shapes` must sum the concat-axis extents: [3] + [4] -> [7].
    #[test]
    fn test_output_shape_inference() {
        let concat_op = ConcatOperation;

        let input_shapes = vec![vec![3], vec![4]];
        let params = OperationParams::new().with_int("axis", 0);

        let output_shapes = <ConcatOperation as CustomOperation<f32>>::output_shapes(
            &concat_op,
            &input_shapes,
            &params,
        )
        // Fixed copy-pasted message: this is shape inference, not a dtype op.
        .expect("shape inference should succeed");
        assert_eq!(output_shapes, vec![vec![7]]);
    }

    /// Duplicate registration, unknown unregister, and unknown apply all fail.
    #[test]
    fn test_error_cases() {
        let registry = CustomOperationRegistry::new();

        let scale_op1 = Box::new(ScaleOperation);
        let scale_op2 = Box::new(ScaleOperation);

        registry
            .register::<f32>(scale_op1, "1.0.0", None, vec![])
            .expect("registration should succeed");
        let result = registry.register::<f32>(scale_op2, "1.0.0", None, vec![]);
        assert!(result.is_err());

        let result = registry.unregister::<f32>("nonexistent");
        assert!(result.is_err());

        let data = vec![1.0f32, 2.0];
        let tensor = Tensor::from_data(data, vec![1, 2], DeviceType::Cpu)
            .expect("tensor creation should succeed");
        let params = OperationParams::new();
        let result = tensor.apply_custom_op_with_registry(&registry, "nonexistent", &[], &params);
        assert!(result.is_err());
    }

    /// Registers into the shared global registry and applies via the
    /// convenience method; unregisters at the end to restore global state.
    #[test]
    fn test_global_registry() {
        let registry = global_registry();

        let scale_op = Box::new(ScaleOperation);
        registry
            .register::<f32>(scale_op, "1.0.0", None, vec![])
            .expect("registration should succeed");

        let data = vec![1.0f32, 2.0, 3.0];
        let tensor = Tensor::from_data(data, vec![3], DeviceType::Cpu)
            .expect("tensor creation should succeed");
        let params = OperationParams::new().with_float("scale", 3.0);

        let results = tensor
            .apply_custom_op("scale", &[], &params)
            .expect("custom_op should succeed");
        assert_eq!(results.len(), 1);
        let expected_data = vec![3.0f32, 6.0, 9.0];
        assert_eq!(
            results[0].data().expect("data retrieval should succeed"),
            expected_data
        );

        registry
            .unregister::<f32>("scale")
            .expect("unregister should succeed");
    }
}