// torsh_tensor/custom_ops.rs

//! Custom operation registration system for torsh-tensor
//!
//! This module provides a flexible system for registering custom operations that can be used
//! with tensors, including automatic differentiation support. Users can define their own
//! operations and integrate them seamlessly with the existing tensor API.
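//!
//! # Example
//!
//! A minimal sketch of the intended flow, using the [`ScaleOperation`] defined at the
//! bottom of this module. Marked `ignore` because it assumes this module is exposed as
//! `torsh_tensor::custom_ops` and that re-exports like `torsh_tensor::Tensor` exist:
//!
//! ```ignore
//! use torsh_core::device::DeviceType;
//! use torsh_tensor::custom_ops::{
//!     global_registry, OperationParams, ScaleOperation, TensorCustomOps,
//! };
//! use torsh_tensor::Tensor;
//!
//! // Register the operation once, keyed by element type (here f32).
//! global_registry()
//!     .register::<f32>(Box::new(ScaleOperation), "1.0.0", None, vec![])
//!     .unwrap();
//!
//! // Build a tensor and a typed parameter bag, then invoke the op by name.
//! let tensor = Tensor::from_data(vec![1.0f32, 2.0], vec![2], DeviceType::Cpu).unwrap();
//! let params = OperationParams::new().with_float("scale", 2.0);
//! let outputs = tensor.apply_custom_op("scale", &[], &params).unwrap();
//! assert_eq!(outputs[0].data().unwrap(), vec![2.0f32, 4.0]);
//! ```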

use crate::{core_ops::Tensor, TensorElement};
use num_traits::FromPrimitive;
use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use torsh_core::error::{Result, TorshError};

/// Trait for custom operation implementations
///
/// Custom operations must implement this trait to be registered and used with tensors.
/// The trait provides both forward and backward operations for automatic differentiation.
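///
/// # Example
///
/// A minimal implementation sketch: an identity operation that passes its single
/// input through unchanged (`ignore`d, since it relies on the surrounding crate):
///
/// ```ignore
/// struct IdentityOperation;
///
/// impl<T: TensorElement> CustomOperation<T> for IdentityOperation {
///     fn name(&self) -> &str { "identity" }
///     fn description(&self) -> &str { "Returns its input unchanged" }
///
///     fn forward(&self, inputs: &[Tensor<T>], _params: &OperationParams) -> Result<Vec<Tensor<T>>> {
///         Ok(vec![inputs[0].clone()])
///     }
///
///     // The gradient of the identity is the incoming gradient itself.
///     fn backward(
///         &self,
///         grad_outputs: &[Tensor<T>],
///         _inputs: &[Tensor<T>],
///         _outputs: &[Tensor<T>],
///         _params: &OperationParams,
///     ) -> Result<Vec<Option<Tensor<T>>>> {
///         Ok(vec![Some(grad_outputs[0].clone())])
///     }
///
///     fn output_shapes(
///         &self,
///         input_shapes: &[Vec<usize>],
///         _params: &OperationParams,
///     ) -> Result<Vec<Vec<usize>>> {
///         Ok(vec![input_shapes[0].clone()])
///     }
///
///     fn num_inputs(&self) -> usize { 1 }
///     fn num_outputs(&self) -> usize { 1 }
/// }
/// ```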
pub trait CustomOperation<T: TensorElement>: Send + Sync {
    /// Get the name of this operation
    fn name(&self) -> &str;

    /// Get a description of what this operation does
    fn description(&self) -> &str;

    /// Execute the forward pass of the operation
    ///
    /// # Arguments
    /// * `inputs` - Input tensors for the operation
    /// * `params` - Optional parameters for the operation
    ///
    /// # Returns
    /// The result tensor(s) from applying this operation
    fn forward(&self, inputs: &[Tensor<T>], params: &OperationParams) -> Result<Vec<Tensor<T>>>;

    /// Execute the backward pass of the operation (optional for non-differentiable ops)
    ///
    /// # Arguments
    /// * `grad_outputs` - Gradients with respect to the outputs
    /// * `inputs` - Original input tensors
    /// * `outputs` - Original output tensors
    /// * `params` - Operation parameters
    ///
    /// # Returns
    /// Gradients with respect to the inputs
    fn backward(
        &self,
        _grad_outputs: &[Tensor<T>],
        inputs: &[Tensor<T>],
        _outputs: &[Tensor<T>],
        _params: &OperationParams,
    ) -> Result<Vec<Option<Tensor<T>>>> {
        // Default implementation for non-differentiable operations: no input gradients
        Ok(vec![None; inputs.len()])
    }

    /// Validate that the inputs are compatible with this operation
    ///
    /// # Arguments
    /// * `inputs` - Input tensors to validate
    /// * `params` - Operation parameters
    ///
    /// # Returns
    /// `Ok(())` if the inputs are valid, an error describing the problem otherwise
    fn validate_inputs(&self, _inputs: &[Tensor<T>], _params: &OperationParams) -> Result<()> {
        // Default implementation performs no validation
        Ok(())
    }

    /// Get the expected output shapes given input shapes
    ///
    /// # Arguments
    /// * `input_shapes` - Shapes of input tensors
    /// * `params` - Operation parameters
    ///
    /// # Returns
    /// Expected shapes of output tensors
    fn output_shapes(
        &self,
        input_shapes: &[Vec<usize>],
        params: &OperationParams,
    ) -> Result<Vec<Vec<usize>>>;

    /// Check if this operation supports automatic differentiation
    fn supports_autograd(&self) -> bool {
        true // Most operations should support autograd
    }

    /// Get the number of expected inputs
    fn num_inputs(&self) -> usize;

    /// Get the number of expected outputs
    fn num_outputs(&self) -> usize;
}

/// Parameters that can be passed to custom operations
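///
/// # Example
///
/// A small usage sketch of the builder-style setters and typed getters
/// (`ignore`d because it relies on the surrounding crate):
///
/// ```ignore
/// let params = OperationParams::new()
///     .with_float("scale", 2.0)
///     .with_int("axis", 1);
///
/// assert_eq!(params.get_float("scale"), Some(2.0));
/// assert_eq!(params.get_int("axis"), Some(1));
/// assert_eq!(params.get_bool("missing"), None);
/// ```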
#[derive(Debug, Clone)]
pub struct OperationParams {
    /// String parameters
    pub strings: HashMap<String, String>,
    /// Integer parameters
    pub integers: HashMap<String, i64>,
    /// Float parameters
    pub floats: HashMap<String, f64>,
    /// Boolean parameters
    pub booleans: HashMap<String, bool>,
    /// Vector parameters
    pub vectors: HashMap<String, Vec<f64>>,
    /// Shape parameters
    pub shapes: HashMap<String, Vec<usize>>,
}

impl OperationParams {
    /// Create a new empty parameter set
    pub fn new() -> Self {
        Self {
            strings: HashMap::new(),
            integers: HashMap::new(),
            floats: HashMap::new(),
            booleans: HashMap::new(),
            vectors: HashMap::new(),
            shapes: HashMap::new(),
        }
    }

    /// Add a string parameter
    pub fn with_string(mut self, key: &str, value: &str) -> Self {
        self.strings.insert(key.to_string(), value.to_string());
        self
    }

    /// Add an integer parameter
    pub fn with_int(mut self, key: &str, value: i64) -> Self {
        self.integers.insert(key.to_string(), value);
        self
    }

    /// Add a float parameter
    pub fn with_float(mut self, key: &str, value: f64) -> Self {
        self.floats.insert(key.to_string(), value);
        self
    }

    /// Add a boolean parameter
    pub fn with_bool(mut self, key: &str, value: bool) -> Self {
        self.booleans.insert(key.to_string(), value);
        self
    }

    /// Add a vector parameter
    pub fn with_vector(mut self, key: &str, value: Vec<f64>) -> Self {
        self.vectors.insert(key.to_string(), value);
        self
    }

    /// Add a shape parameter
    pub fn with_shape(mut self, key: &str, value: Vec<usize>) -> Self {
        self.shapes.insert(key.to_string(), value);
        self
    }

    /// Get a string parameter
    pub fn get_string(&self, key: &str) -> Option<&String> {
        self.strings.get(key)
    }

    /// Get an integer parameter
    pub fn get_int(&self, key: &str) -> Option<i64> {
        self.integers.get(key).copied()
    }

    /// Get a float parameter
    pub fn get_float(&self, key: &str) -> Option<f64> {
        self.floats.get(key).copied()
    }

    /// Get a boolean parameter
    pub fn get_bool(&self, key: &str) -> Option<bool> {
        self.booleans.get(key).copied()
    }

    /// Get a vector parameter
    pub fn get_vector(&self, key: &str) -> Option<&Vec<f64>> {
        self.vectors.get(key)
    }

    /// Get a shape parameter
    pub fn get_shape(&self, key: &str) -> Option<&Vec<usize>> {
        self.shapes.get(key)
    }
}

impl Default for OperationParams {
    fn default() -> Self {
        Self::new()
    }
}

/// Metadata about a registered operation
#[derive(Debug, Clone)]
pub struct OperationMetadata {
    /// Operation name
    pub name: String,
    /// Operation description
    pub description: String,
    /// Number of inputs
    pub num_inputs: usize,
    /// Number of outputs
    pub num_outputs: usize,
    /// Whether the operation supports autograd
    pub supports_autograd: bool,
    /// Data type this operation is registered for
    pub data_type: TypeId,
    /// Version of the operation
    pub version: String,
    /// Author/creator information
    pub author: Option<String>,
    /// Additional tags for categorization
    pub tags: Vec<String>,
}

/// Registry for custom operations
///
/// This registry maintains a collection of custom operations that can be applied to tensors.
/// Operations are stored per data type to ensure type safety.
pub struct CustomOperationRegistry {
    /// Operations stored by (TypeId, operation_name)
    operations: RwLock<HashMap<(TypeId, String), Arc<dyn Any + Send + Sync>>>,
    /// Metadata for registered operations
    metadata: RwLock<HashMap<(TypeId, String), OperationMetadata>>,
}

impl CustomOperationRegistry {
    /// Create a new operation registry
    pub fn new() -> Self {
        Self {
            operations: RwLock::new(HashMap::new()),
            metadata: RwLock::new(HashMap::new()),
        }
    }

    /// Register a custom operation
    ///
    /// # Arguments
    /// * `operation` - The operation implementation
    /// * `version` - Version string for this operation
    /// * `author` - Optional author information
    /// * `tags` - Optional tags for categorization
    ///
    /// # Returns
    /// Success or error if registration fails
    pub fn register<T: TensorElement + 'static>(
        &self,
        operation: Box<dyn CustomOperation<T>>,
        version: &str,
        author: Option<String>,
        tags: Vec<String>,
    ) -> Result<()> {
        let type_id = TypeId::of::<T>();
        let name = operation.name().to_string();
        let key = (type_id, name.clone());

        // Create metadata
        let metadata = OperationMetadata {
            name: name.clone(),
            description: operation.description().to_string(),
            num_inputs: operation.num_inputs(),
            num_outputs: operation.num_outputs(),
            supports_autograd: operation.supports_autograd(),
            data_type: type_id,
            version: version.to_string(),
            author,
            tags,
        };

        // Store the operation and metadata
        {
            let mut ops = self.operations.write().unwrap();
            let mut meta = self.metadata.write().unwrap();

            if ops.contains_key(&key) {
                return Err(TorshError::InvalidArgument(format!(
                    "Operation '{}' for type {:?} is already registered",
                    name, type_id
                )));
            }

            // Store the operation as Arc<dyn CustomOperation<T>> wrapped in Arc<dyn Any>
            let arc_op: Arc<dyn CustomOperation<T>> = Arc::from(operation);
            let boxed_any: Arc<dyn Any + Send + Sync> = Arc::new(arc_op);
            ops.insert(key.clone(), boxed_any);
            meta.insert(key, metadata);
        }

        Ok(())
    }

    /// Get a registered operation
    ///
    /// # Arguments
    /// * `name` - Name of the operation to retrieve
    ///
    /// # Returns
    /// Reference to the operation if found
    pub fn get<T: TensorElement + 'static>(
        &self,
        name: &str,
    ) -> Option<Arc<dyn CustomOperation<T>>> {
        let type_id = TypeId::of::<T>();
        let key = (type_id, name.to_string());

        let ops = self.operations.read().unwrap();
        ops.get(&key).and_then(|arc_any| {
            // Downcast Arc<dyn Any> to Arc<dyn CustomOperation<T>>
            arc_any
                .downcast_ref::<Arc<dyn CustomOperation<T>>>()
                .map(Arc::clone)
        })
    }

    /// Get metadata for a registered operation
    pub fn get_metadata<T: TensorElement + 'static>(
        &self,
        name: &str,
    ) -> Option<OperationMetadata> {
        let type_id = TypeId::of::<T>();
        let key = (type_id, name.to_string());

        let meta = self.metadata.read().unwrap();
        meta.get(&key).cloned()
    }

    /// List all registered operations for a given type
    pub fn list_operations<T: TensorElement + 'static>(&self) -> Vec<String> {
        let type_id = TypeId::of::<T>();
        let meta = self.metadata.read().unwrap();

        meta.keys()
            .filter(|(tid, _)| *tid == type_id)
            .map(|(_, name)| name.clone())
            .collect()
    }

    /// Remove a registered operation
    pub fn unregister<T: TensorElement + 'static>(&self, name: &str) -> Result<()> {
        let type_id = TypeId::of::<T>();
        let key = (type_id, name.to_string());

        let mut ops = self.operations.write().unwrap();
        let mut meta = self.metadata.write().unwrap();

        if ops.remove(&key).is_none() {
            return Err(TorshError::InvalidArgument(format!(
                "Operation '{}' for type {:?} is not registered",
                name, type_id
            )));
        }

        meta.remove(&key);
        Ok(())
    }

    /// Check if an operation is registered
    pub fn is_registered<T: TensorElement + 'static>(&self, name: &str) -> bool {
        let type_id = TypeId::of::<T>();
        let key = (type_id, name.to_string());

        let ops = self.operations.read().unwrap();
        ops.contains_key(&key)
    }

    /// Get total number of registered operations
    pub fn count(&self) -> usize {
        let ops = self.operations.read().unwrap();
        ops.len()
    }

    /// Clear all registered operations
    pub fn clear(&self) {
        let mut ops = self.operations.write().unwrap();
        let mut meta = self.metadata.write().unwrap();
        ops.clear();
        meta.clear();
    }
}

impl Default for CustomOperationRegistry {
    fn default() -> Self {
        Self::new()
    }
}

/// Global custom operation registry
static GLOBAL_REGISTRY: std::sync::LazyLock<CustomOperationRegistry> =
    std::sync::LazyLock::new(CustomOperationRegistry::new);

/// Get the global custom operation registry
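///
/// # Example
///
/// A hedged sketch (`ignore`d; it relies on the surrounding crate) of registering
/// against the process-wide registry and querying it:
///
/// ```ignore
/// let registry = global_registry();
/// registry
///     .register::<f32>(Box::new(ScaleOperation), "1.0.0", None, vec![])
///     .unwrap();
/// assert!(registry.is_registered::<f32>("scale"));
/// assert_eq!(registry.list_operations::<f32>(), vec!["scale".to_string()]);
/// ```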
pub fn global_registry() -> &'static CustomOperationRegistry {
    &GLOBAL_REGISTRY
}

/// Extension trait to add custom operation support to tensors
pub trait TensorCustomOps<T: TensorElement> {
    /// Apply a custom operation to this tensor
    ///
    /// # Arguments
    /// * `op_name` - Name of the registered operation
    /// * `other_inputs` - Additional input tensors (if any)
    /// * `params` - Operation parameters
    ///
    /// # Returns
    /// Result tensor(s) from the operation
    fn apply_custom_op(
        &self,
        op_name: &str,
        other_inputs: &[&Tensor<T>],
        params: &OperationParams,
    ) -> Result<Vec<Tensor<T>>>;

    /// Apply a custom operation using a specific registry
    fn apply_custom_op_with_registry(
        &self,
        registry: &CustomOperationRegistry,
        op_name: &str,
        other_inputs: &[&Tensor<T>],
        params: &OperationParams,
    ) -> Result<Vec<Tensor<T>>>;
}

impl<T: TensorElement + 'static> TensorCustomOps<T> for Tensor<T> {
    fn apply_custom_op(
        &self,
        op_name: &str,
        other_inputs: &[&Tensor<T>],
        params: &OperationParams,
    ) -> Result<Vec<Tensor<T>>> {
        self.apply_custom_op_with_registry(global_registry(), op_name, other_inputs, params)
    }

    fn apply_custom_op_with_registry(
        &self,
        registry: &CustomOperationRegistry,
        op_name: &str,
        other_inputs: &[&Tensor<T>],
        params: &OperationParams,
    ) -> Result<Vec<Tensor<T>>> {
        // Get the operation from the registry
        let operation = registry.get::<T>(op_name).ok_or_else(|| {
            TorshError::InvalidArgument(format!(
                "Custom operation '{}' not found for type {}",
                op_name,
                std::any::type_name::<T>()
            ))
        })?;

        // Prepare input tensors
        let mut inputs = vec![self.clone()];
        inputs.extend(other_inputs.iter().map(|&t| t.clone()));

        // Validate inputs
        operation.validate_inputs(&inputs, params)?;

        // Check input count
        if inputs.len() != operation.num_inputs() {
            return Err(TorshError::InvalidArgument(format!(
                "Operation '{}' expects {} inputs, got {}",
                op_name,
                operation.num_inputs(),
                inputs.len()
            )));
        }

        // Execute the forward pass
        let outputs = operation.forward(&inputs, params)?;

        // Check output count
        if outputs.len() != operation.num_outputs() {
            return Err(TorshError::InvalidArgument(format!(
                "Operation '{}' produced {} outputs, expected {}",
                op_name,
                outputs.len(),
                operation.num_outputs()
            )));
        }

        Ok(outputs)
    }
}

// Example custom operations

/// A simple element-wise scaling operation
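///
/// Reads a `"scale"` float parameter (defaulting to `1.0`) and multiplies every
/// element by it. A hedged usage sketch (`ignore`d, as it assumes the tensor
/// constructors used in this module's tests):
///
/// ```ignore
/// let tensor = Tensor::from_data(vec![1.0f32, 2.0], vec![2], DeviceType::Cpu)?;
/// let params = OperationParams::new().with_float("scale", 2.0);
/// let scaled = ScaleOperation.forward(&[tensor], &params)?;
/// ```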
pub struct ScaleOperation;

impl<T: TensorElement + Copy + std::ops::Mul<Output = T> + FromPrimitive> CustomOperation<T>
    for ScaleOperation
{
    fn name(&self) -> &str {
        "scale"
    }

    fn description(&self) -> &str {
        "Scales tensor elements by a constant factor"
    }

    fn forward(&self, inputs: &[Tensor<T>], params: &OperationParams) -> Result<Vec<Tensor<T>>> {
        if inputs.len() != 1 {
            return Err(TorshError::InvalidArgument(
                "Scale operation requires exactly 1 input".to_string(),
            ));
        }

        let scale = params.get_float("scale").unwrap_or(1.0);
        let scale_val = <T as FromPrimitive>::from_f64(scale).ok_or_else(|| {
            TorshError::InvalidArgument("Cannot convert scale factor to tensor type".to_string())
        })?;

        let result = inputs[0].mul_scalar(scale_val)?;
        Ok(vec![result])
    }

    fn backward(
        &self,
        grad_outputs: &[Tensor<T>],
        _inputs: &[Tensor<T>],
        _outputs: &[Tensor<T>],
        params: &OperationParams,
    ) -> Result<Vec<Option<Tensor<T>>>> {
        // d(scale * x)/dx = scale, so the input gradient is the output gradient scaled
        let scale = params.get_float("scale").unwrap_or(1.0);
        let scale_val = <T as FromPrimitive>::from_f64(scale).ok_or_else(|| {
            TorshError::InvalidArgument("Cannot convert scale factor to tensor type".to_string())
        })?;

        let grad_input = grad_outputs[0].mul_scalar(scale_val)?;
        Ok(vec![Some(grad_input)])
    }

    fn output_shapes(
        &self,
        input_shapes: &[Vec<usize>],
        _params: &OperationParams,
    ) -> Result<Vec<Vec<usize>>> {
        if input_shapes.len() != 1 {
            return Err(TorshError::InvalidArgument(
                "Scale operation requires exactly 1 input".to_string(),
            ));
        }
        Ok(vec![input_shapes[0].clone()])
    }

    fn num_inputs(&self) -> usize {
        1
    }

    fn num_outputs(&self) -> usize {
        1
    }
}

/// A tensor concatenation operation along a specified axis
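///
/// Reads an integer `"axis"` parameter (defaulting to `0`). A hedged usage sketch
/// (`ignore`d, as it assumes the tensor constructors used in this module's tests):
///
/// ```ignore
/// let a = Tensor::from_data(vec![1.0f32, 2.0], vec![2], DeviceType::Cpu)?;
/// let b = Tensor::from_data(vec![3.0f32, 4.0], vec![2], DeviceType::Cpu)?;
/// let params = OperationParams::new().with_int("axis", 0);
/// // Concatenating two length-2 vectors yields a length-4 vector.
/// let joined = ConcatOperation.forward(&[a, b], &params)?;
/// ```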
pub struct ConcatOperation;

impl<T: TensorElement + Copy> CustomOperation<T> for ConcatOperation {
    fn name(&self) -> &str {
        "concat"
    }

    fn description(&self) -> &str {
        "Concatenates tensors along a specified axis"
    }

    fn forward(&self, inputs: &[Tensor<T>], params: &OperationParams) -> Result<Vec<Tensor<T>>> {
        if inputs.len() < 2 {
            return Err(TorshError::InvalidArgument(
                "Concat operation requires at least 2 inputs".to_string(),
            ));
        }
        let axis = params.get_int("axis").unwrap_or(0) as i32;

        // Use the existing cat operation from the tensor API
        let input_refs: Vec<&Tensor<T>> = inputs.iter().collect();
        let result = Tensor::cat(&input_refs, axis)?;
        Ok(vec![result])
    }

    fn backward(
        &self,
        grad_outputs: &[Tensor<T>],
        inputs: &[Tensor<T>],
        _outputs: &[Tensor<T>],
        params: &OperationParams,
    ) -> Result<Vec<Option<Tensor<T>>>> {
        let axis = params.get_int("axis").unwrap_or(0) as usize;
        let grad_output = &grad_outputs[0];

        // Split the gradient back to match input sizes
        let mut split_sizes = Vec::new();
        for input in inputs {
            split_sizes.push(input.shape().dims()[axis]);
        }

        // Use multiple slice operations instead of split_with_sizes
        let mut grad_inputs = Vec::new();
        let mut start = 0;
        for &size in &split_sizes {
            let end = start + size;
            let slice = grad_output.slice_tensor(axis, start, end)?;
            grad_inputs.push(Some(slice));
            start = end;
        }
        Ok(grad_inputs)
    }

    fn output_shapes(
        &self,
        input_shapes: &[Vec<usize>],
        params: &OperationParams,
    ) -> Result<Vec<Vec<usize>>> {
        if input_shapes.len() < 2 {
            return Err(TorshError::InvalidArgument(
                "Concat operation requires at least 2 inputs".to_string(),
            ));
        }

        let axis = params.get_int("axis").unwrap_or(0) as usize;
        let mut output_shape = input_shapes[0].clone();

        if axis >= output_shape.len() {
            return Err(TorshError::InvalidArgument(format!(
                "Concat axis {} out of bounds for {} dimensions",
                axis,
                output_shape.len()
            )));
        }

        // Sum the sizes along the concatenation axis
        let mut total_size = output_shape[axis];
        for shape in &input_shapes[1..] {
            if shape.len() != output_shape.len() {
                return Err(TorshError::InvalidArgument(
                    "All tensors must have the same number of dimensions".to_string(),
                ));
            }

            // Check that all dimensions except the concat axis match
            for (i, (&dim1, &dim2)) in output_shape.iter().zip(shape.iter()).enumerate() {
                if i != axis && dim1 != dim2 {
                    return Err(TorshError::InvalidArgument(format!(
                        "Dimension {} mismatch: {} vs {}",
                        i, dim1, dim2
                    )));
                }
            }

            total_size += shape[axis];
        }

        output_shape[axis] = total_size;
        Ok(vec![output_shape])
    }

    fn num_inputs(&self) -> usize {
        // Concat is conceptually variadic; 2 is the minimum. Note that
        // `apply_custom_op_with_registry` checks the input count against this
        // value exactly, so that path accepts exactly two inputs; call
        // `forward` directly to concatenate more.
        2
    }

    fn num_outputs(&self) -> usize {
        1
    }

    fn validate_inputs(&self, inputs: &[Tensor<T>], params: &OperationParams) -> Result<()> {
        if inputs.len() < 2 {
            return Err(TorshError::InvalidArgument(
                "Concat operation requires at least 2 inputs".to_string(),
            ));
        }

        let axis = params.get_int("axis").unwrap_or(0) as usize;
        let first_tensor_shape = inputs[0].shape();
        let first_shape = first_tensor_shape.dims();

        if axis >= first_shape.len() {
            return Err(TorshError::InvalidArgument(format!(
                "Concat axis {} out of bounds for {} dimensions",
                axis,
                first_shape.len()
            )));
        }

        // Validate that all tensors have compatible shapes
        for (i, tensor) in inputs.iter().enumerate().skip(1) {
            let tensor_shape = tensor.shape();
            let shape = tensor_shape.dims();
            if shape.len() != first_shape.len() {
                return Err(TorshError::InvalidArgument(format!(
                    "Tensor {} has {} dimensions, expected {}",
                    i,
                    shape.len(),
                    first_shape.len()
                )));
            }

            for (dim_idx, (&dim1, &dim2)) in first_shape.iter().zip(shape.iter()).enumerate() {
                if dim_idx != axis && dim1 != dim2 {
                    return Err(TorshError::InvalidArgument(format!(
                        "Tensor {} dimension {} mismatch: {} vs {}",
                        i, dim_idx, dim1, dim2
                    )));
                }
            }
        }

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::device::DeviceType;

    #[test]
    fn test_operation_params() {
        let params = OperationParams::new()
            .with_string("mode", "linear")
            .with_int("axis", 1)
            .with_float("scale", 2.5)
            .with_bool("inplace", false)
            .with_vector("weights", vec![1.0, 2.0, 3.0])
            .with_shape("target_shape", vec![10, 20]);

        assert_eq!(params.get_string("mode"), Some(&"linear".to_string()));
        assert_eq!(params.get_int("axis"), Some(1));
        assert_eq!(params.get_float("scale"), Some(2.5));
        assert_eq!(params.get_bool("inplace"), Some(false));
        assert_eq!(params.get_vector("weights"), Some(&vec![1.0, 2.0, 3.0]));
        assert_eq!(params.get_shape("target_shape"), Some(&vec![10, 20]));

        assert_eq!(params.get_string("nonexistent"), None);
    }

    #[test]
    fn test_registry_operations() {
        let registry = CustomOperationRegistry::new();

        // Register a scale operation
        let scale_op = Box::new(ScaleOperation);
        registry
            .register::<f32>(
                scale_op,
                "1.0.0",
                Some("Test".to_string()),
                vec!["math".to_string()],
            )
            .unwrap();

        // Check registration
        assert!(registry.is_registered::<f32>("scale"));
        assert!(!registry.is_registered::<f32>("nonexistent"));

        // Get metadata
        let metadata = registry.get_metadata::<f32>("scale").unwrap();
        assert_eq!(metadata.name, "scale");
        assert_eq!(
            metadata.description,
            "Scales tensor elements by a constant factor"
        );
        assert_eq!(metadata.num_inputs, 1);
        assert_eq!(metadata.num_outputs, 1);
        assert_eq!(metadata.version, "1.0.0");
        assert_eq!(metadata.author, Some("Test".to_string()));
        assert_eq!(metadata.tags, vec!["math".to_string()]);

        // List operations
        let ops = registry.list_operations::<f32>();
        assert_eq!(ops, vec!["scale".to_string()]);

        // Unregister
        registry.unregister::<f32>("scale").unwrap();
        assert!(!registry.is_registered::<f32>("scale"));
    }

    #[test]
    fn test_scale_operation() {
        let registry = CustomOperationRegistry::new();
        let scale_op = Box::new(ScaleOperation);
        registry
            .register::<f32>(scale_op, "1.0.0", None, vec![])
            .unwrap();

        // Create test tensor
        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        let tensor = Tensor::from_data(data, vec![2, 2], DeviceType::Cpu).unwrap();

        // Apply scale operation
        let params = OperationParams::new().with_float("scale", 2.0);
        let results = tensor
            .apply_custom_op_with_registry(&registry, "scale", &[], &params)
            .unwrap();

        assert_eq!(results.len(), 1);
        let result = &results[0];
        let expected_data = vec![2.0f32, 4.0, 6.0, 8.0];
        assert_eq!(result.data().unwrap(), expected_data);
    }

    #[test]
    fn test_concat_operation() {
        let registry = CustomOperationRegistry::new();
        let concat_op = Box::new(ConcatOperation);
        registry
            .register::<f32>(concat_op, "1.0.0", None, vec![])
            .unwrap();

        // Create test tensors (1D to work with current cat implementation)
        let data1 = vec![1.0f32, 2.0];
        let tensor1 = Tensor::from_data(data1, vec![2], DeviceType::Cpu).unwrap();

        let data2 = vec![3.0f32, 4.0];
        let tensor2 = Tensor::from_data(data2, vec![2], DeviceType::Cpu).unwrap();

        // Apply concat operation along axis 0
        let params = OperationParams::new().with_int("axis", 0);
        let results = tensor1
            .apply_custom_op_with_registry(&registry, "concat", &[&tensor2], &params)
            .unwrap();

        assert_eq!(results.len(), 1);
        let result = &results[0];
        assert_eq!(result.shape().dims(), &[4]); // 2 + 2 = 4 elements
        let expected_data = vec![1.0f32, 2.0, 3.0, 4.0];
        assert_eq!(result.data().unwrap(), expected_data);
    }

    #[test]
    fn test_operation_validation() {
        let registry = CustomOperationRegistry::new();
        let concat_op = Box::new(ConcatOperation);
        registry
            .register::<f32>(concat_op, "1.0.0", None, vec![])
            .unwrap();

        // Create tensors with incompatible dimensions (2D vs 1D should fail)
        let data1 = vec![1.0f32, 2.0];
        let tensor1 = Tensor::from_data(data1, vec![2], DeviceType::Cpu).unwrap(); // 1D tensor

        let data2 = vec![3.0f32, 4.0, 5.0, 6.0];
        let tensor2 = Tensor::from_data(data2, vec![2, 2], DeviceType::Cpu).unwrap(); // 2D tensor

        // This should fail validation due to different number of dimensions
        let params = OperationParams::new().with_int("axis", 0);
        let result =
            tensor1.apply_custom_op_with_registry(&registry, "concat", &[&tensor2], &params);
        assert!(result.is_err());
    }

    #[test]
    fn test_output_shape_inference() {
        let concat_op = ConcatOperation;

        // Test shape inference for concat operation (1D tensors)
        let input_shapes = vec![vec![3], vec![4]];
        let params = OperationParams::new().with_int("axis", 0);

        let output_shapes = <ConcatOperation as CustomOperation<f32>>::output_shapes(
            &concat_op,
            &input_shapes,
            &params,
        )
        .unwrap();
        assert_eq!(output_shapes, vec![vec![7]]); // 3 + 4 = 7 along axis 0
    }

    #[test]
    fn test_error_cases() {
        let registry = CustomOperationRegistry::new();

        // Try to register duplicate operation
        let scale_op1 = Box::new(ScaleOperation);
        let scale_op2 = Box::new(ScaleOperation);

        registry
            .register::<f32>(scale_op1, "1.0.0", None, vec![])
            .unwrap();
        let result = registry.register::<f32>(scale_op2, "1.0.0", None, vec![]);
        assert!(result.is_err());

        // Try to unregister non-existent operation
        let result = registry.unregister::<f32>("nonexistent");
        assert!(result.is_err());

        // Try to apply non-existent operation
        let data = vec![1.0f32, 2.0];
        let tensor = Tensor::from_data(data, vec![1, 2], DeviceType::Cpu).unwrap();
        let params = OperationParams::new();
        let result = tensor.apply_custom_op_with_registry(&registry, "nonexistent", &[], &params);
        assert!(result.is_err());
    }

    #[test]
    fn test_global_registry() {
        let registry = global_registry();

        // Register an operation in the global registry
        let scale_op = Box::new(ScaleOperation);
        registry
            .register::<f32>(scale_op, "1.0.0", None, vec![])
            .unwrap();

        // Use the operation via the tensor extension trait
        let data = vec![1.0f32, 2.0, 3.0];
        let tensor = Tensor::from_data(data, vec![3], DeviceType::Cpu).unwrap();
        let params = OperationParams::new().with_float("scale", 3.0);

        let results = tensor.apply_custom_op("scale", &[], &params).unwrap();
        assert_eq!(results.len(), 1);
        let expected_data = vec![3.0f32, 6.0, 9.0];
        assert_eq!(results[0].data().unwrap(), expected_data);

        // Clean up
        registry.unregister::<f32>("scale").unwrap();
    }
}