use crate::{core_ops::Tensor, TensorElement};
use num_traits::FromPrimitive;
use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use torsh_core::error::{Result, TorshError};

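/// A user-defined tensor operation that can be registered in a
/// [`CustomOperationRegistry`] and invoked by name through [`TensorCustomOps`].
///
/// A minimal implementation sketch (illustrative only; `Identity` is a
/// hypothetical op, and it assumes `Tensor<T>: Clone` just as the blanket
/// impl later in this file does):
///
/// ```ignore
/// struct Identity;
///
/// impl<T: TensorElement> CustomOperation<T> for Identity {
///     fn name(&self) -> &str { "identity" }
///     fn description(&self) -> &str { "Returns its input unchanged" }
///     fn forward(&self, inputs: &[Tensor<T>], _params: &OperationParams) -> Result<Vec<Tensor<T>>> {
///         Ok(vec![inputs[0].clone()])
///     }
///     fn output_shapes(&self, shapes: &[Vec<usize>], _params: &OperationParams) -> Result<Vec<Vec<usize>>> {
///         // Identity preserves the input shape.
///         Ok(vec![shapes[0].clone()])
///     }
///     fn num_inputs(&self) -> usize { 1 }
///     fn num_outputs(&self) -> usize { 1 }
/// }
/// ```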
pub trait CustomOperation<T: TensorElement>: Send + Sync {
    /// Unique name used to register and look up the operation.
    fn name(&self) -> &str;

    /// Human-readable description of what the operation does.
    fn description(&self) -> &str;

    /// Computes the forward pass, returning one tensor per declared output.
    fn forward(&self, inputs: &[Tensor<T>], params: &OperationParams) -> Result<Vec<Tensor<T>>>;

    /// Computes input gradients from output gradients.
    ///
    /// The default implementation returns `None` for every input, meaning no
    /// gradient flows through the operation.
    fn backward(
        &self,
        _grad_outputs: &[Tensor<T>],
        inputs: &[Tensor<T>],
        _outputs: &[Tensor<T>],
        _params: &OperationParams,
    ) -> Result<Vec<Option<Tensor<T>>>> {
        Ok(vec![None; inputs.len()])
    }

    /// Validates inputs before `forward` runs; the default accepts anything.
    fn validate_inputs(&self, _inputs: &[Tensor<T>], _params: &OperationParams) -> Result<()> {
        Ok(())
    }

    /// Infers output shapes from input shapes without running the operation.
    fn output_shapes(
        &self,
        input_shapes: &[Vec<usize>],
        params: &OperationParams,
    ) -> Result<Vec<Vec<usize>>>;

    /// Whether this operation supports autograd. Defaults to `true`.
    fn supports_autograd(&self) -> bool {
        true
    }

    /// Number of input tensors the operation expects.
    fn num_inputs(&self) -> usize;

    /// Number of output tensors the operation produces.
    fn num_outputs(&self) -> usize;
}

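/// Typed key-value parameters passed to custom operations.
///
/// Built fluently with the `with_*` methods and read back with the matching
/// `get_*` methods. A minimal usage sketch:
///
/// ```ignore
/// let params = OperationParams::new()
///     .with_float("scale", 2.0)
///     .with_int("axis", 0);
/// assert_eq!(params.get_float("scale"), Some(2.0));
/// assert_eq!(params.get_int("missing"), None);
/// ```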
#[derive(Debug, Clone)]
pub struct OperationParams {
    /// String-valued parameters.
    pub strings: HashMap<String, String>,
    /// Integer-valued parameters.
    pub integers: HashMap<String, i64>,
    /// Float-valued parameters.
    pub floats: HashMap<String, f64>,
    /// Boolean-valued parameters.
    pub booleans: HashMap<String, bool>,
    /// Numeric vector parameters.
    pub vectors: HashMap<String, Vec<f64>>,
    /// Shape (dimension-list) parameters.
    pub shapes: HashMap<String, Vec<usize>>,
}

impl OperationParams {
    /// Creates an empty parameter set.
    pub fn new() -> Self {
        Self {
            strings: HashMap::new(),
            integers: HashMap::new(),
            floats: HashMap::new(),
            booleans: HashMap::new(),
            vectors: HashMap::new(),
            shapes: HashMap::new(),
        }
    }

    /// Adds a string parameter.
    pub fn with_string(mut self, key: &str, value: &str) -> Self {
        self.strings.insert(key.to_string(), value.to_string());
        self
    }

    /// Adds an integer parameter.
    pub fn with_int(mut self, key: &str, value: i64) -> Self {
        self.integers.insert(key.to_string(), value);
        self
    }

    /// Adds a float parameter.
    pub fn with_float(mut self, key: &str, value: f64) -> Self {
        self.floats.insert(key.to_string(), value);
        self
    }

    /// Adds a boolean parameter.
    pub fn with_bool(mut self, key: &str, value: bool) -> Self {
        self.booleans.insert(key.to_string(), value);
        self
    }

    /// Adds a numeric vector parameter.
    pub fn with_vector(mut self, key: &str, value: Vec<f64>) -> Self {
        self.vectors.insert(key.to_string(), value);
        self
    }

    /// Adds a shape parameter.
    pub fn with_shape(mut self, key: &str, value: Vec<usize>) -> Self {
        self.shapes.insert(key.to_string(), value);
        self
    }

    /// Returns the string parameter for `key`, if present.
    pub fn get_string(&self, key: &str) -> Option<&String> {
        self.strings.get(key)
    }

    /// Returns the integer parameter for `key`, if present.
    pub fn get_int(&self, key: &str) -> Option<i64> {
        self.integers.get(key).copied()
    }

    /// Returns the float parameter for `key`, if present.
    pub fn get_float(&self, key: &str) -> Option<f64> {
        self.floats.get(key).copied()
    }

    /// Returns the boolean parameter for `key`, if present.
    pub fn get_bool(&self, key: &str) -> Option<bool> {
        self.booleans.get(key).copied()
    }

    /// Returns the vector parameter for `key`, if present.
    pub fn get_vector(&self, key: &str) -> Option<&Vec<f64>> {
        self.vectors.get(key)
    }

    /// Returns the shape parameter for `key`, if present.
    pub fn get_shape(&self, key: &str) -> Option<&Vec<usize>> {
        self.shapes.get(key)
    }
}

impl Default for OperationParams {
    fn default() -> Self {
        Self::new()
    }
}

/// Metadata recorded for each registered operation.
#[derive(Debug, Clone)]
pub struct OperationMetadata {
    /// Operation name.
    pub name: String,
    /// Human-readable description.
    pub description: String,
    /// Number of input tensors the operation expects.
    pub num_inputs: usize,
    /// Number of output tensors the operation produces.
    pub num_outputs: usize,
    /// Whether the operation supports autograd.
    pub supports_autograd: bool,
    /// `TypeId` of the element type the operation was registered for.
    pub data_type: TypeId,
    /// Version string supplied at registration.
    pub version: String,
    /// Optional author.
    pub author: Option<String>,
    /// Free-form tags for discovery.
    pub tags: Vec<String>,
}

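/// A thread-safe registry of custom operations, keyed by element `TypeId`
/// and operation name.
///
/// Operations are stored type-erased (an `Arc<dyn Any>` wrapping an
/// `Arc<dyn CustomOperation<T>>`) and recovered by downcasting in
/// [`get`](Self::get). A minimal usage sketch, using the [`ScaleOperation`]
/// defined later in this file:
///
/// ```ignore
/// let registry = CustomOperationRegistry::new();
/// registry.register::<f32>(Box::new(ScaleOperation), "1.0.0", None, vec![])?;
///
/// let op = registry.get::<f32>("scale").expect("registered above");
/// assert_eq!(op.num_inputs(), 1);
/// assert!(registry.get_metadata::<f32>("scale").is_some());
/// ```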
pub struct CustomOperationRegistry {
    /// Type-erased operations keyed by `(element TypeId, operation name)`.
    operations: RwLock<HashMap<(TypeId, String), Arc<dyn Any + Send + Sync>>>,
    /// Registration metadata, keyed identically.
    metadata: RwLock<HashMap<(TypeId, String), OperationMetadata>>,
}

impl CustomOperationRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            operations: RwLock::new(HashMap::new()),
            metadata: RwLock::new(HashMap::new()),
        }
    }

    /// Registers an operation for element type `T`.
    ///
    /// Fails if an operation with the same name is already registered for `T`.
    pub fn register<T: TensorElement + 'static>(
        &self,
        operation: Box<dyn CustomOperation<T>>,
        version: &str,
        author: Option<String>,
        tags: Vec<String>,
    ) -> Result<()> {
        let type_id = TypeId::of::<T>();
        let name = operation.name().to_string();
        let key = (type_id, name.clone());

        let metadata = OperationMetadata {
            name: name.clone(),
            description: operation.description().to_string(),
            num_inputs: operation.num_inputs(),
            num_outputs: operation.num_outputs(),
            supports_autograd: operation.supports_autograd(),
            data_type: type_id,
            version: version.to_string(),
            author,
            tags,
        };

        {
            let mut ops = self.operations.write().unwrap();
            let mut meta = self.metadata.write().unwrap();

            if ops.contains_key(&key) {
                return Err(TorshError::InvalidArgument(format!(
                    "Operation '{}' for type {:?} is already registered",
                    name, type_id
                )));
            }

            // Erase the concrete element type: wrap the Arc<dyn CustomOperation<T>>
            // in an Arc<dyn Any> so operations for different T can share one map.
            let arc_op: Arc<dyn CustomOperation<T>> = Arc::from(operation);
            let boxed_any: Arc<dyn Any + Send + Sync> = Arc::new(arc_op);
            ops.insert(key.clone(), boxed_any);
            meta.insert(key, metadata);
        }

        Ok(())
    }

    /// Looks up an operation registered for element type `T`.
    pub fn get<T: TensorElement + 'static>(
        &self,
        name: &str,
    ) -> Option<Arc<dyn CustomOperation<T>>> {
        let type_id = TypeId::of::<T>();
        let key = (type_id, name.to_string());

        let ops = self.operations.read().unwrap();
        ops.get(&key).and_then(|arc_any| {
            // Recover the concrete operation type by downcasting the erased Arc.
            arc_any
                .downcast_ref::<Arc<dyn CustomOperation<T>>>()
                .map(Arc::clone)
        })
    }

    /// Returns the registration metadata for an operation, if registered.
    pub fn get_metadata<T: TensorElement + 'static>(
        &self,
        name: &str,
    ) -> Option<OperationMetadata> {
        let type_id = TypeId::of::<T>();
        let key = (type_id, name.to_string());

        let meta = self.metadata.read().unwrap();
        meta.get(&key).cloned()
    }

    /// Lists the names of all operations registered for element type `T`.
    pub fn list_operations<T: TensorElement + 'static>(&self) -> Vec<String> {
        let type_id = TypeId::of::<T>();
        let meta = self.metadata.read().unwrap();

        meta.keys()
            .filter(|(tid, _)| *tid == type_id)
            .map(|(_, name)| name.clone())
            .collect()
    }

    /// Removes a registered operation; fails if it was not registered.
    pub fn unregister<T: TensorElement + 'static>(&self, name: &str) -> Result<()> {
        let type_id = TypeId::of::<T>();
        let key = (type_id, name.to_string());

        let mut ops = self.operations.write().unwrap();
        let mut meta = self.metadata.write().unwrap();

        if ops.remove(&key).is_none() {
            return Err(TorshError::InvalidArgument(format!(
                "Operation '{}' for type {:?} is not registered",
                name, type_id
            )));
        }

        meta.remove(&key);
        Ok(())
    }

    /// Returns `true` if an operation named `name` is registered for `T`.
    pub fn is_registered<T: TensorElement + 'static>(&self, name: &str) -> bool {
        let type_id = TypeId::of::<T>();
        let key = (type_id, name.to_string());

        let ops = self.operations.read().unwrap();
        ops.contains_key(&key)
    }

    /// Returns the total number of registered operations across all element types.
    pub fn count(&self) -> usize {
        let ops = self.operations.read().unwrap();
        ops.len()
    }

    /// Removes every registered operation and its metadata.
    pub fn clear(&self) {
        let mut ops = self.operations.write().unwrap();
        let mut meta = self.metadata.write().unwrap();
        ops.clear();
        meta.clear();
    }
}

impl Default for CustomOperationRegistry {
    fn default() -> Self {
        Self::new()
    }
}

static GLOBAL_REGISTRY: std::sync::LazyLock<CustomOperationRegistry> =
    std::sync::LazyLock::new(CustomOperationRegistry::new);

/// Returns a reference to the process-wide custom operation registry.
pub fn global_registry() -> &'static CustomOperationRegistry {
    &GLOBAL_REGISTRY
}

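/// Extension trait that lets a [`Tensor`] invoke registered custom operations
/// by name, with `self` as the first input.
///
/// A minimal usage sketch (mirrors the tests at the bottom of this file):
///
/// ```ignore
/// // Against an explicit registry:
/// let outputs = tensor.apply_custom_op_with_registry(&registry, "scale", &[], &params)?;
/// // Against the process-wide registry:
/// let outputs = tensor.apply_custom_op("scale", &[], &params)?;
/// ```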
pub trait TensorCustomOps<T: TensorElement> {
    /// Applies a named operation from the global registry; `self` is the
    /// first input, followed by `other_inputs`.
    fn apply_custom_op(
        &self,
        op_name: &str,
        other_inputs: &[&Tensor<T>],
        params: &OperationParams,
    ) -> Result<Vec<Tensor<T>>>;

    /// Like [`apply_custom_op`](Self::apply_custom_op), but resolves the
    /// operation in an explicit registry.
    fn apply_custom_op_with_registry(
        &self,
        registry: &CustomOperationRegistry,
        op_name: &str,
        other_inputs: &[&Tensor<T>],
        params: &OperationParams,
    ) -> Result<Vec<Tensor<T>>>;
}

impl<T: TensorElement + 'static> TensorCustomOps<T> for Tensor<T> {
    fn apply_custom_op(
        &self,
        op_name: &str,
        other_inputs: &[&Tensor<T>],
        params: &OperationParams,
    ) -> Result<Vec<Tensor<T>>> {
        self.apply_custom_op_with_registry(global_registry(), op_name, other_inputs, params)
    }

    fn apply_custom_op_with_registry(
        &self,
        registry: &CustomOperationRegistry,
        op_name: &str,
        other_inputs: &[&Tensor<T>],
        params: &OperationParams,
    ) -> Result<Vec<Tensor<T>>> {
        let operation = registry.get::<T>(op_name).ok_or_else(|| {
            TorshError::InvalidArgument(format!(
                "Custom operation '{}' not found for type {}",
                op_name,
                std::any::type_name::<T>()
            ))
        })?;

        // `self` is always the first input, followed by the extra inputs.
        let mut inputs = vec![self.clone()];
        inputs.extend(other_inputs.iter().map(|&t| t.clone()));

        operation.validate_inputs(&inputs, params)?;

        if inputs.len() != operation.num_inputs() {
            return Err(TorshError::InvalidArgument(format!(
                "Operation '{}' expects {} inputs, got {}",
                op_name,
                operation.num_inputs(),
                inputs.len()
            )));
        }

        let outputs = operation.forward(&inputs, params)?;

        if outputs.len() != operation.num_outputs() {
            return Err(TorshError::InvalidArgument(format!(
                "Operation '{}' produced {} outputs, expected {}",
                op_name,
                outputs.len(),
                operation.num_outputs()
            )));
        }

        Ok(outputs)
    }
}

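/// Example operation: multiplies every element by the float parameter
/// `"scale"` (defaulting to `1.0`).
///
/// A minimal usage sketch (mirrors `test_scale_operation` below):
///
/// ```ignore
/// registry.register::<f32>(Box::new(ScaleOperation), "1.0.0", None, vec![])?;
/// let params = OperationParams::new().with_float("scale", 2.0);
/// let doubled = tensor.apply_custom_op_with_registry(&registry, "scale", &[], &params)?;
/// ```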
pub struct ScaleOperation;

impl<T: TensorElement + Copy + std::ops::Mul<Output = T> + FromPrimitive> CustomOperation<T>
    for ScaleOperation
{
    fn name(&self) -> &str {
        "scale"
    }

    fn description(&self) -> &str {
        "Scales tensor elements by a constant factor"
    }

    fn forward(&self, inputs: &[Tensor<T>], params: &OperationParams) -> Result<Vec<Tensor<T>>> {
        if inputs.len() != 1 {
            return Err(TorshError::InvalidArgument(
                "Scale operation requires exactly 1 input".to_string(),
            ));
        }

        let scale = params.get_float("scale").unwrap_or(1.0);
        let scale_val = <T as FromPrimitive>::from_f64(scale).ok_or_else(|| {
            TorshError::InvalidArgument("Cannot convert scale factor to tensor type".to_string())
        })?;

        let result = inputs[0].mul_scalar(scale_val)?;
        Ok(vec![result])
    }

    fn backward(
        &self,
        grad_outputs: &[Tensor<T>],
        _inputs: &[Tensor<T>],
        _outputs: &[Tensor<T>],
        params: &OperationParams,
    ) -> Result<Vec<Option<Tensor<T>>>> {
        // d(s * x)/dx = s, so the input gradient is the output gradient
        // scaled by the same factor.
        let scale = params.get_float("scale").unwrap_or(1.0);
        let scale_val = <T as FromPrimitive>::from_f64(scale).ok_or_else(|| {
            TorshError::InvalidArgument("Cannot convert scale factor to tensor type".to_string())
        })?;

        let grad_input = grad_outputs[0].mul_scalar(scale_val)?;
        Ok(vec![Some(grad_input)])
    }

    fn output_shapes(
        &self,
        input_shapes: &[Vec<usize>],
        _params: &OperationParams,
    ) -> Result<Vec<Vec<usize>>> {
        if input_shapes.len() != 1 {
            return Err(TorshError::InvalidArgument(
                "Scale operation requires exactly 1 input".to_string(),
            ));
        }
        // Element-wise scaling preserves the input shape.
        Ok(vec![input_shapes[0].clone()])
    }

    fn num_inputs(&self) -> usize {
        1
    }

    fn num_outputs(&self) -> usize {
        1
    }
}

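/// Example operation: concatenates tensors along the integer parameter
/// `"axis"` (defaulting to `0`).
///
/// A minimal usage sketch (mirrors `test_concat_operation` below):
///
/// ```ignore
/// registry.register::<f32>(Box::new(ConcatOperation), "1.0.0", None, vec![])?;
/// let params = OperationParams::new().with_int("axis", 0);
/// let joined = a.apply_custom_op_with_registry(&registry, "concat", &[&b], &params)?;
/// ```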
pub struct ConcatOperation;

impl<T: TensorElement + Copy> CustomOperation<T> for ConcatOperation {
    fn name(&self) -> &str {
        "concat"
    }

    fn description(&self) -> &str {
        "Concatenates tensors along a specified axis"
    }

    fn forward(&self, inputs: &[Tensor<T>], params: &OperationParams) -> Result<Vec<Tensor<T>>> {
        if inputs.len() < 2 {
            return Err(TorshError::InvalidArgument(
                "Concat operation requires at least 2 inputs".to_string(),
            ));
        }

        let axis = params.get_int("axis").unwrap_or(0) as usize;

        let input_refs: Vec<&Tensor<T>> = inputs.iter().collect();
        let result = Tensor::cat(&input_refs, axis as i32)?;
        Ok(vec![result])
    }

    fn backward(
        &self,
        grad_outputs: &[Tensor<T>],
        inputs: &[Tensor<T>],
        _outputs: &[Tensor<T>],
        params: &OperationParams,
    ) -> Result<Vec<Option<Tensor<T>>>> {
        let axis = params.get_int("axis").unwrap_or(0) as usize;
        let grad_output = &grad_outputs[0];

        // The gradient of concat is a split: each input receives the slice of
        // the output gradient that spans its extent along `axis`.
        let mut split_sizes = Vec::new();
        for input in inputs {
            split_sizes.push(input.shape().dims()[axis]);
        }

        let mut grad_inputs = Vec::new();
        let mut start = 0;
        for &size in &split_sizes {
            let end = start + size;
            let slice = grad_output.slice_tensor(axis, start, end)?;
            grad_inputs.push(Some(slice));
            start = end;
        }
        Ok(grad_inputs)
    }

    fn output_shapes(
        &self,
        input_shapes: &[Vec<usize>],
        params: &OperationParams,
    ) -> Result<Vec<Vec<usize>>> {
        if input_shapes.len() < 2 {
            return Err(TorshError::InvalidArgument(
                "Concat operation requires at least 2 inputs".to_string(),
            ));
        }

        let axis = params.get_int("axis").unwrap_or(0) as usize;
        let mut output_shape = input_shapes[0].clone();

        if axis >= output_shape.len() {
            return Err(TorshError::InvalidArgument(format!(
                "Concat axis {} out of bounds for {} dimensions",
                axis,
                output_shape.len()
            )));
        }

        // All dimensions except `axis` must match; the output extent along
        // `axis` is the sum of the input extents.
        let mut total_size = output_shape[axis];
        for shape in &input_shapes[1..] {
            if shape.len() != output_shape.len() {
                return Err(TorshError::InvalidArgument(
                    "All tensors must have the same number of dimensions".to_string(),
                ));
            }

            for (i, (&dim1, &dim2)) in output_shape.iter().zip(shape.iter()).enumerate() {
                if i != axis && dim1 != dim2 {
                    return Err(TorshError::InvalidArgument(format!(
                        "Dimension {} mismatch: {} vs {}",
                        i, dim1, dim2
                    )));
                }
            }

            total_size += shape[axis];
        }

        output_shape[axis] = total_size;
        Ok(vec![output_shape])
    }

    fn num_inputs(&self) -> usize {
        // NOTE: concat is conceptually variadic, but the strict arity check in
        // `apply_custom_op_with_registry` means it is only reachable through
        // that path with exactly two inputs.
        2
    }

    fn num_outputs(&self) -> usize {
        1
    }

    fn validate_inputs(&self, inputs: &[Tensor<T>], params: &OperationParams) -> Result<()> {
        if inputs.len() < 2 {
            return Err(TorshError::InvalidArgument(
                "Concat operation requires at least 2 inputs".to_string(),
            ));
        }

        let axis = params.get_int("axis").unwrap_or(0) as usize;
        let first_tensor_shape = inputs[0].shape();
        let first_shape = first_tensor_shape.dims();

        if axis >= first_shape.len() {
            return Err(TorshError::InvalidArgument(format!(
                "Concat axis {} out of bounds for {} dimensions",
                axis,
                first_shape.len()
            )));
        }

        for (i, tensor) in inputs.iter().enumerate().skip(1) {
            let tensor_shape = tensor.shape();
            let shape = tensor_shape.dims();
            if shape.len() != first_shape.len() {
                return Err(TorshError::InvalidArgument(format!(
                    "Tensor {} has {} dimensions, expected {}",
                    i,
                    shape.len(),
                    first_shape.len()
                )));
            }

            for (dim_idx, (&dim1, &dim2)) in first_shape.iter().zip(shape.iter()).enumerate() {
                if dim_idx != axis && dim1 != dim2 {
                    return Err(TorshError::InvalidArgument(format!(
                        "Tensor {} dimension {} mismatch: {} vs {}",
                        i, dim_idx, dim1, dim2
                    )));
                }
            }
        }

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::device::DeviceType;

    #[test]
    fn test_operation_params() {
        let params = OperationParams::new()
            .with_string("mode", "linear")
            .with_int("axis", 1)
            .with_float("scale", 2.5)
            .with_bool("inplace", false)
            .with_vector("weights", vec![1.0, 2.0, 3.0])
            .with_shape("target_shape", vec![10, 20]);

        assert_eq!(params.get_string("mode"), Some(&"linear".to_string()));
        assert_eq!(params.get_int("axis"), Some(1));
        assert_eq!(params.get_float("scale"), Some(2.5));
        assert_eq!(params.get_bool("inplace"), Some(false));
        assert_eq!(params.get_vector("weights"), Some(&vec![1.0, 2.0, 3.0]));
        assert_eq!(params.get_shape("target_shape"), Some(&vec![10, 20]));

        assert_eq!(params.get_string("nonexistent"), None);
    }

    #[test]
    fn test_registry_operations() {
        let registry = CustomOperationRegistry::new();

        let scale_op = Box::new(ScaleOperation);
        registry
            .register::<f32>(
                scale_op,
                "1.0.0",
                Some("Test".to_string()),
                vec!["math".to_string()],
            )
            .unwrap();

        assert!(registry.is_registered::<f32>("scale"));
        assert!(!registry.is_registered::<f32>("nonexistent"));

        let metadata = registry.get_metadata::<f32>("scale").unwrap();
        assert_eq!(metadata.name, "scale");
        assert_eq!(
            metadata.description,
            "Scales tensor elements by a constant factor"
        );
        assert_eq!(metadata.num_inputs, 1);
        assert_eq!(metadata.num_outputs, 1);
        assert_eq!(metadata.version, "1.0.0");
        assert_eq!(metadata.author, Some("Test".to_string()));
        assert_eq!(metadata.tags, vec!["math".to_string()]);

        let ops = registry.list_operations::<f32>();
        assert_eq!(ops, vec!["scale".to_string()]);

        registry.unregister::<f32>("scale").unwrap();
        assert!(!registry.is_registered::<f32>("scale"));
    }

    #[test]
    fn test_scale_operation() {
        let registry = CustomOperationRegistry::new();
        let scale_op = Box::new(ScaleOperation);
        registry
            .register::<f32>(scale_op, "1.0.0", None, vec![])
            .unwrap();

        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        let tensor = Tensor::from_data(data, vec![2, 2], DeviceType::Cpu).unwrap();

        let params = OperationParams::new().with_float("scale", 2.0);
        let results = tensor
            .apply_custom_op_with_registry(&registry, "scale", &[], &params)
            .unwrap();

        assert_eq!(results.len(), 1);
        let result = &results[0];
        let expected_data = vec![2.0f32, 4.0, 6.0, 8.0];
        assert_eq!(result.data().unwrap(), expected_data);
    }

    #[test]
    fn test_concat_operation() {
        let registry = CustomOperationRegistry::new();
        let concat_op = Box::new(ConcatOperation);
        registry
            .register::<f32>(concat_op, "1.0.0", None, vec![])
            .unwrap();

        let data1 = vec![1.0f32, 2.0];
        let tensor1 = Tensor::from_data(data1, vec![2], DeviceType::Cpu).unwrap();

        let data2 = vec![3.0f32, 4.0];
        let tensor2 = Tensor::from_data(data2, vec![2], DeviceType::Cpu).unwrap();

        let params = OperationParams::new().with_int("axis", 0);
        let results = tensor1
            .apply_custom_op_with_registry(&registry, "concat", &[&tensor2], &params)
            .unwrap();

        assert_eq!(results.len(), 1);
        let result = &results[0];
        assert_eq!(result.shape().dims(), &[4]);
        let expected_data = vec![1.0f32, 2.0, 3.0, 4.0];
        assert_eq!(result.data().unwrap(), expected_data);
    }

    #[test]
    fn test_operation_validation() {
        let registry = CustomOperationRegistry::new();
        let concat_op = Box::new(ConcatOperation);
        registry
            .register::<f32>(concat_op, "1.0.0", None, vec![])
            .unwrap();

        // Mismatched ranks (1-D vs 2-D) must be rejected by validate_inputs.
        let data1 = vec![1.0f32, 2.0];
        let tensor1 = Tensor::from_data(data1, vec![2], DeviceType::Cpu).unwrap();

        let data2 = vec![3.0f32, 4.0, 5.0, 6.0];
        let tensor2 = Tensor::from_data(data2, vec![2, 2], DeviceType::Cpu).unwrap();

        let params = OperationParams::new().with_int("axis", 0);
        let result =
            tensor1.apply_custom_op_with_registry(&registry, "concat", &[&tensor2], &params);
        assert!(result.is_err());
    }

    #[test]
    fn test_output_shape_inference() {
        let concat_op = ConcatOperation;

        let input_shapes = vec![vec![3], vec![4]];
        let params = OperationParams::new().with_int("axis", 0);

        let output_shapes = <ConcatOperation as CustomOperation<f32>>::output_shapes(
            &concat_op,
            &input_shapes,
            &params,
        )
        .unwrap();
        assert_eq!(output_shapes, vec![vec![7]]);
    }

    #[test]
    fn test_error_cases() {
        let registry = CustomOperationRegistry::new();

        // Registering the same name twice for the same element type must fail.
        let scale_op1 = Box::new(ScaleOperation);
        let scale_op2 = Box::new(ScaleOperation);

        registry
            .register::<f32>(scale_op1, "1.0.0", None, vec![])
            .unwrap();
        let result = registry.register::<f32>(scale_op2, "1.0.0", None, vec![]);
        assert!(result.is_err());

        // Unregistering an unknown operation must fail.
        let result = registry.unregister::<f32>("nonexistent");
        assert!(result.is_err());

        // Applying an unknown operation must fail.
        let data = vec![1.0f32, 2.0];
        let tensor = Tensor::from_data(data, vec![1, 2], DeviceType::Cpu).unwrap();
        let params = OperationParams::new();
        let result = tensor.apply_custom_op_with_registry(&registry, "nonexistent", &[], &params);
        assert!(result.is_err());
    }

    #[test]
    fn test_global_registry() {
        let registry = global_registry();

        let scale_op = Box::new(ScaleOperation);
        registry
            .register::<f32>(scale_op, "1.0.0", None, vec![])
            .unwrap();

        let data = vec![1.0f32, 2.0, 3.0];
        let tensor = Tensor::from_data(data, vec![3], DeviceType::Cpu).unwrap();
        let params = OperationParams::new().with_float("scale", 3.0);

        let results = tensor.apply_custom_op("scale", &[], &params).unwrap();
        assert_eq!(results.len(), 1);
        let expected_data = vec![3.0f32, 6.0, 9.0];
        assert_eq!(results[0].data().unwrap(), expected_data);

        // Clean up the shared global registry so other tests are unaffected.
        registry.unregister::<f32>("scale").unwrap();
    }
}