// torsh_tensor/custom_ops.rs
1//! Custom operation registration system for torsh-tensor
2//!
3//! This module provides a flexible system for registering custom operations that can be used
4//! with tensors, including automatic differentiation support. Users can define their own
5//! operations and integrate them seamlessly with the existing tensor API.
6
7use crate::{core_ops::Tensor, TensorElement};
8use scirs2_core::numeric::FromPrimitive;
9use std::any::{Any, TypeId};
10use std::collections::HashMap;
11use std::sync::{Arc, RwLock};
12use torsh_core::error::{Result, TorshError};
13
14/// Trait for custom operation implementations
15///
16/// Custom operations must implement this trait to be registered and used with tensors.
17/// The trait provides both forward and backward operations for automatic differentiation.
18pub trait CustomOperation<T: TensorElement>: Send + Sync {
19    /// Get the name of this operation
20    fn name(&self) -> &str;
21
22    /// Get a description of what this operation does
23    fn description(&self) -> &str;
24
25    /// Execute the forward pass of the operation
26    ///
27    /// # Arguments
28    /// * `inputs` - Input tensors for the operation
29    /// * `params` - Optional parameters for the operation
30    ///
31    /// # Returns
32    /// The result tensor(s) from applying this operation
33    fn forward(&self, inputs: &[Tensor<T>], params: &OperationParams) -> Result<Vec<Tensor<T>>>;
34
35    /// Execute the backward pass of the operation (optional for non-differentiable ops)
36    ///
37    /// # Arguments
38    /// * `grad_outputs` - Gradients with respect to the outputs
39    /// * `inputs` - Original input tensors
40    /// * `outputs` - Original output tensors
41    /// * `params` - Operation parameters
42    ///
43    /// # Returns
44    /// Gradients with respect to the inputs
45    fn backward(
46        &self,
47        grad_outputs: &[Tensor<T>],
48        inputs: &[Tensor<T>],
49        _outputs: &[Tensor<T>],
50        _params: &OperationParams,
51    ) -> Result<Vec<Option<Tensor<T>>>> {
52        // Default implementation for non-differentiable operations
53        // Validate that we have gradient outputs matching expected count
54        let _ = grad_outputs.is_empty(); // Check if empty but continue
55
56        // Return None gradients for all inputs (non-differentiable by default)
57        Ok(vec![None; inputs.len()])
58    }
59
60    /// Validate that the inputs are compatible with this operation
61    ///
62    /// # Arguments
63    /// * `inputs` - Input tensors to validate
64    /// * `params` - Operation parameters
65    ///
66    /// # Returns
67    /// True if inputs are valid, false otherwise
68    fn validate_inputs(&self, inputs: &[Tensor<T>], _params: &OperationParams) -> Result<()> {
69        // Default implementation - basic validation
70        if inputs.is_empty() {
71            return Err(torsh_core::error::TorshError::InvalidShape(
72                "Operation requires at least one input tensor".to_string(),
73            ));
74        }
75
76        // Validate that all input tensors have data
77        for (idx, input) in inputs.iter().enumerate() {
78            let _ = (idx, input.shape.is_empty()); // Check shape validity
79        }
80
81        Ok(())
82    }
83
84    /// Get the expected output shapes given input shapes
85    ///
86    /// # Arguments
87    /// * `input_shapes` - Shapes of input tensors
88    /// * `params` - Operation parameters
89    ///
90    /// # Returns
91    /// Expected shapes of output tensors
92    fn output_shapes(
93        &self,
94        input_shapes: &[Vec<usize>],
95        params: &OperationParams,
96    ) -> Result<Vec<Vec<usize>>>;
97
98    /// Check if this operation supports automatic differentiation
99    fn supports_autograd(&self) -> bool {
100        true // Most operations should support autograd
101    }
102
103    /// Get the number of expected inputs
104    fn num_inputs(&self) -> usize;
105
106    /// Get the number of expected outputs
107    fn num_outputs(&self) -> usize;
108}
109
/// Typed key/value parameter bag passed to custom operations.
///
/// Each supported value type lives in its own map so lookups are unambiguous:
/// the same key may appear in several categories without conflict.
#[derive(Debug, Clone)]
pub struct OperationParams {
    /// String parameters
    pub strings: HashMap<String, String>,
    /// Integer parameters
    pub integers: HashMap<String, i64>,
    /// Float parameters
    pub floats: HashMap<String, f64>,
    /// Boolean parameters
    pub booleans: HashMap<String, bool>,
    /// Vector parameters
    pub vectors: HashMap<String, Vec<f64>>,
    /// Shape parameters
    pub shapes: HashMap<String, Vec<usize>>,
}

impl OperationParams {
    /// Create a parameter set with no entries in any category.
    pub fn new() -> Self {
        OperationParams {
            strings: HashMap::default(),
            integers: HashMap::default(),
            floats: HashMap::default(),
            booleans: HashMap::default(),
            vectors: HashMap::default(),
            shapes: HashMap::default(),
        }
    }

    /// Builder-style setter for a string parameter; replaces any prior value.
    pub fn with_string(mut self, key: &str, value: &str) -> Self {
        self.strings.insert(key.to_owned(), value.to_owned());
        self
    }

    /// Builder-style setter for an integer parameter.
    pub fn with_int(mut self, key: &str, value: i64) -> Self {
        self.integers.insert(key.to_owned(), value);
        self
    }

    /// Builder-style setter for a float parameter.
    pub fn with_float(mut self, key: &str, value: f64) -> Self {
        self.floats.insert(key.to_owned(), value);
        self
    }

    /// Builder-style setter for a boolean parameter.
    pub fn with_bool(mut self, key: &str, value: bool) -> Self {
        self.booleans.insert(key.to_owned(), value);
        self
    }

    /// Builder-style setter for a vector parameter.
    pub fn with_vector(mut self, key: &str, value: Vec<f64>) -> Self {
        self.vectors.insert(key.to_owned(), value);
        self
    }

    /// Builder-style setter for a shape parameter.
    pub fn with_shape(mut self, key: &str, value: Vec<usize>) -> Self {
        self.shapes.insert(key.to_owned(), value);
        self
    }

    /// Look up a string parameter; `None` when the key is absent.
    pub fn get_string(&self, key: &str) -> Option<&String> {
        self.strings.get(key)
    }

    /// Look up an integer parameter by value.
    pub fn get_int(&self, key: &str) -> Option<i64> {
        self.integers.get(key).cloned()
    }

    /// Look up a float parameter by value.
    pub fn get_float(&self, key: &str) -> Option<f64> {
        self.floats.get(key).cloned()
    }

    /// Look up a boolean parameter by value.
    pub fn get_bool(&self, key: &str) -> Option<bool> {
        self.booleans.get(key).cloned()
    }

    /// Borrow a vector parameter.
    pub fn get_vector(&self, key: &str) -> Option<&Vec<f64>> {
        self.vectors.get(key)
    }

    /// Borrow a shape parameter.
    pub fn get_shape(&self, key: &str) -> Option<&Vec<usize>> {
        self.shapes.get(key)
    }
}

impl Default for OperationParams {
    /// Equivalent to [`OperationParams::new`].
    fn default() -> Self {
        OperationParams::new()
    }
}
212
/// Metadata about a registered operation
///
/// Captured once at registration time from the operation's trait methods and
/// the extra arguments passed to `CustomOperationRegistry::register`.
#[derive(Debug, Clone)]
pub struct OperationMetadata {
    /// Operation name (also the registry lookup key)
    pub name: String,
    /// Operation description
    pub description: String,
    /// Number of inputs the operation reports via `num_inputs()`
    pub num_inputs: usize,
    /// Number of outputs the operation reports via `num_outputs()`
    pub num_outputs: usize,
    /// Whether the operation supports autograd
    pub supports_autograd: bool,
    /// `TypeId` of the tensor element type `T` this operation is registered for
    pub data_type: TypeId,
    /// Version of the operation (caller-supplied string at registration)
    pub version: String,
    /// Author/creator information
    pub author: Option<String>,
    /// Additional tags for categorization
    pub tags: Vec<String>,
}
235
/// Registry for custom operations
///
/// This registry maintains a collection of custom operations that can be applied to tensors.
/// Operations are stored per data type to ensure type safety: the map key is
/// `(TypeId::of::<T>(), name)`, and the value is a type-erased
/// `Arc<dyn CustomOperation<T>>` (stored behind `dyn Any` and downcast on lookup).
pub struct CustomOperationRegistry {
    /// Operations stored by (TypeId, operation_name); values are
    /// `Arc<Arc<dyn CustomOperation<T>>>` erased to `Arc<dyn Any + Send + Sync>`
    operations: RwLock<HashMap<(TypeId, String), Arc<dyn Any + Send + Sync>>>,
    /// Metadata for registered operations, keyed identically to `operations`
    metadata: RwLock<HashMap<(TypeId, String), OperationMetadata>>,
}
246
247impl CustomOperationRegistry {
248    /// Create a new operation registry
249    pub fn new() -> Self {
250        Self {
251            operations: RwLock::new(HashMap::new()),
252            metadata: RwLock::new(HashMap::new()),
253        }
254    }
255
256    /// Register a custom operation
257    ///
258    /// # Arguments
259    /// * `operation` - The operation implementation
260    /// * `version` - Version string for this operation
261    /// * `author` - Optional author information
262    /// * `tags` - Optional tags for categorization
263    ///
264    /// # Returns
265    /// Success or error if registration fails
266    pub fn register<T: TensorElement + 'static>(
267        &self,
268        operation: Box<dyn CustomOperation<T>>,
269        version: &str,
270        author: Option<String>,
271        tags: Vec<String>,
272    ) -> Result<()> {
273        let type_id = TypeId::of::<T>();
274        let name = operation.name().to_string();
275        let key = (type_id, name.clone());
276
277        // Create metadata
278        let metadata = OperationMetadata {
279            name: name.clone(),
280            description: operation.description().to_string(),
281            num_inputs: operation.num_inputs(),
282            num_outputs: operation.num_outputs(),
283            supports_autograd: operation.supports_autograd(),
284            data_type: type_id,
285            version: version.to_string(),
286            author,
287            tags,
288        };
289
290        // Store the operation and metadata
291        {
292            let mut ops = self
293                .operations
294                .write()
295                .expect("lock should not be poisoned");
296            let mut meta = self.metadata.write().expect("lock should not be poisoned");
297
298            if ops.contains_key(&key) {
299                return Err(TorshError::InvalidArgument(format!(
300                    "Operation '{}' for type {:?} is already registered",
301                    name, type_id
302                )));
303            }
304
305            // Store the operation as Arc<dyn CustomOperation<T>> wrapped in Arc<dyn Any>
306            let arc_op: Arc<dyn CustomOperation<T>> = Arc::from(operation);
307            let boxed_any: Arc<dyn Any + Send + Sync> = Arc::new(arc_op);
308            ops.insert(key.clone(), boxed_any);
309            meta.insert(key, metadata);
310        }
311
312        Ok(())
313    }
314
315    /// Get a registered operation
316    ///
317    /// # Arguments
318    /// * `name` - Name of the operation to retrieve
319    ///
320    /// # Returns
321    /// Reference to the operation if found
322    pub fn get<T: TensorElement + 'static>(
323        &self,
324        name: &str,
325    ) -> Option<Arc<dyn CustomOperation<T>>> {
326        let type_id = TypeId::of::<T>();
327        let key = (type_id, name.to_string());
328
329        let ops = self.operations.read().expect("lock should not be poisoned");
330        ops.get(&key).and_then(|arc_any| {
331            // Downcast Arc<dyn Any> to Arc<dyn CustomOperation<T>>
332            arc_any
333                .downcast_ref::<Arc<dyn CustomOperation<T>>>()
334                .map(|arc_op| Arc::clone(arc_op))
335        })
336    }
337
338    /// Get metadata for a registered operation
339    pub fn get_metadata<T: TensorElement + 'static>(
340        &self,
341        name: &str,
342    ) -> Option<OperationMetadata> {
343        let type_id = TypeId::of::<T>();
344        let key = (type_id, name.to_string());
345
346        let meta = self.metadata.read().expect("lock should not be poisoned");
347        meta.get(&key).cloned()
348    }
349
350    /// List all registered operations for a given type
351    pub fn list_operations<T: TensorElement + 'static>(&self) -> Vec<String> {
352        let type_id = TypeId::of::<T>();
353        let meta = self.metadata.read().expect("lock should not be poisoned");
354
355        meta.keys()
356            .filter(|(tid, _)| *tid == type_id)
357            .map(|(_, name)| name.clone())
358            .collect()
359    }
360
361    /// Remove a registered operation
362    pub fn unregister<T: TensorElement + 'static>(&self, name: &str) -> Result<()> {
363        let type_id = TypeId::of::<T>();
364        let key = (type_id, name.to_string());
365
366        let mut ops = self
367            .operations
368            .write()
369            .expect("lock should not be poisoned");
370        let mut meta = self.metadata.write().expect("lock should not be poisoned");
371
372        if ops.remove(&key).is_none() {
373            return Err(TorshError::InvalidArgument(format!(
374                "Operation '{}' for type {:?} is not registered",
375                name, type_id
376            )));
377        }
378
379        meta.remove(&key);
380        Ok(())
381    }
382
383    /// Check if an operation is registered
384    pub fn is_registered<T: TensorElement + 'static>(&self, name: &str) -> bool {
385        let type_id = TypeId::of::<T>();
386        let key = (type_id, name.to_string());
387
388        let ops = self.operations.read().expect("lock should not be poisoned");
389        ops.contains_key(&key)
390    }
391
392    /// Get total number of registered operations
393    pub fn count(&self) -> usize {
394        let ops = self.operations.read().expect("lock should not be poisoned");
395        ops.len()
396    }
397
398    /// Clear all registered operations
399    pub fn clear(&self) {
400        let mut ops = self
401            .operations
402            .write()
403            .expect("lock should not be poisoned");
404        let mut meta = self.metadata.write().expect("lock should not be poisoned");
405        ops.clear();
406        meta.clear();
407    }
408}
409
impl Default for CustomOperationRegistry {
    /// Equivalent to `CustomOperationRegistry::new()`: an empty registry.
    fn default() -> Self {
        Self::new()
    }
}
415
/// Global custom operation registry, initialized lazily on first access.
static GLOBAL_REGISTRY: std::sync::LazyLock<CustomOperationRegistry> =
    std::sync::LazyLock::new(CustomOperationRegistry::new);

/// Get the global custom operation registry
///
/// Used by `TensorCustomOps::apply_custom_op` when no explicit registry is given.
pub fn global_registry() -> &'static CustomOperationRegistry {
    &GLOBAL_REGISTRY
}
424
/// Extension trait to add custom operation support to tensors
pub trait TensorCustomOps<T: TensorElement> {
    /// Apply a custom operation to this tensor (looked up in the global registry)
    ///
    /// # Arguments
    /// * `op_name` - Name of the registered operation
    /// * `other_inputs` - Additional input tensors (if any); `self` is always input 0
    /// * `params` - Operation parameters
    ///
    /// # Returns
    /// Result tensor(s) from the operation
    fn apply_custom_op(
        &self,
        op_name: &str,
        other_inputs: &[&Tensor<T>],
        params: &OperationParams,
    ) -> Result<Vec<Tensor<T>>>;

    /// Apply a custom operation using a specific registry instead of the global one
    fn apply_custom_op_with_registry(
        &self,
        registry: &CustomOperationRegistry,
        op_name: &str,
        other_inputs: &[&Tensor<T>],
        params: &OperationParams,
    ) -> Result<Vec<Tensor<T>>>;
}
452
453impl<T: TensorElement + 'static> TensorCustomOps<T> for Tensor<T> {
454    fn apply_custom_op(
455        &self,
456        op_name: &str,
457        other_inputs: &[&Tensor<T>],
458        params: &OperationParams,
459    ) -> Result<Vec<Tensor<T>>> {
460        self.apply_custom_op_with_registry(global_registry(), op_name, other_inputs, params)
461    }
462
463    fn apply_custom_op_with_registry(
464        &self,
465        registry: &CustomOperationRegistry,
466        op_name: &str,
467        other_inputs: &[&Tensor<T>],
468        params: &OperationParams,
469    ) -> Result<Vec<Tensor<T>>> {
470        // Get the operation from the registry
471        let operation = registry.get::<T>(op_name).ok_or_else(|| {
472            TorshError::InvalidArgument(format!(
473                "Custom operation '{}' not found for type",
474                op_name
475            ))
476        })?;
477
478        // Prepare input tensors
479        let mut inputs = vec![self.clone()];
480        inputs.extend(other_inputs.iter().map(|&t| t.clone()));
481
482        // Validate inputs
483        operation.validate_inputs(&inputs, params)?;
484
485        // Check input count
486        if inputs.len() != operation.num_inputs() {
487            return Err(TorshError::InvalidArgument(format!(
488                "Operation '{}' expects {} inputs, got {}",
489                op_name,
490                operation.num_inputs(),
491                inputs.len()
492            )));
493        }
494
495        // Execute the forward pass
496        let outputs = operation.forward(&inputs, params)?;
497
498        // Check output count
499        if outputs.len() != operation.num_outputs() {
500            return Err(TorshError::InvalidArgument(format!(
501                "Operation '{}' produced {} outputs, expected {}",
502                op_name,
503                outputs.len(),
504                operation.num_outputs()
505            )));
506        }
507
508        Ok(outputs)
509    }
510}
511
512// Example custom operations
513
514/// A simple element-wise scaling operation
515pub struct ScaleOperation;
516
517impl<T: TensorElement + Copy + std::ops::Mul<Output = T> + num_traits::FromPrimitive>
518    CustomOperation<T> for ScaleOperation
519{
520    fn name(&self) -> &str {
521        "scale"
522    }
523
524    fn description(&self) -> &str {
525        "Scales tensor elements by a constant factor"
526    }
527
528    fn forward(&self, inputs: &[Tensor<T>], params: &OperationParams) -> Result<Vec<Tensor<T>>> {
529        if inputs.len() != 1 {
530            return Err(TorshError::InvalidArgument(
531                "Scale operation requires exactly 1 input".to_string(),
532            ));
533        }
534
535        let scale = params.get_float("scale").unwrap_or(1.0);
536        let scale_val = <T as FromPrimitive>::from_f64(scale).ok_or_else(|| {
537            TorshError::InvalidArgument("Cannot convert scale factor to tensor type".to_string())
538        })?;
539
540        let result = inputs[0].mul_scalar(scale_val)?;
541        Ok(vec![result])
542    }
543
544    fn backward(
545        &self,
546        grad_outputs: &[Tensor<T>],
547        _inputs: &[Tensor<T>],
548        _outputs: &[Tensor<T>],
549        params: &OperationParams,
550    ) -> Result<Vec<Option<Tensor<T>>>> {
551        let scale = params.get_float("scale").unwrap_or(1.0);
552        let scale_val = <T as FromPrimitive>::from_f64(scale).ok_or_else(|| {
553            TorshError::InvalidArgument("Cannot convert scale factor to tensor type".to_string())
554        })?;
555
556        let grad_input = grad_outputs[0].mul_scalar(scale_val)?;
557        Ok(vec![Some(grad_input)])
558    }
559
560    fn output_shapes(
561        &self,
562        input_shapes: &[Vec<usize>],
563        _params: &OperationParams,
564    ) -> Result<Vec<Vec<usize>>> {
565        if input_shapes.len() != 1 {
566            return Err(TorshError::InvalidArgument(
567                "Scale operation requires exactly 1 input".to_string(),
568            ));
569        }
570        Ok(vec![input_shapes[0].clone()])
571    }
572
573    fn num_inputs(&self) -> usize {
574        1
575    }
576
577    fn num_outputs(&self) -> usize {
578        1
579    }
580}
581
582/// A tensor concatenation operation along a specified axis
583pub struct ConcatOperation;
584
585impl<T: TensorElement + Copy> CustomOperation<T> for ConcatOperation {
586    fn name(&self) -> &str {
587        "concat"
588    }
589
590    fn description(&self) -> &str {
591        "Concatenates tensors along a specified axis"
592    }
593
594    fn forward(&self, inputs: &[Tensor<T>], params: &OperationParams) -> Result<Vec<Tensor<T>>> {
595        if inputs.len() < 2 {
596            return Err(TorshError::InvalidArgument(
597                "Concat operation requires at least 2 inputs".to_string(),
598            ));
599        }
600
601        let axis = params.get_int("axis").unwrap_or(0) as usize;
602
603        // Use the existing cat operation from the tensor API
604        let input_refs: Vec<&Tensor<T>> = inputs.iter().collect();
605        let result = Tensor::cat(&input_refs, axis as i32)?;
606        Ok(vec![result])
607    }
608
609    fn backward(
610        &self,
611        grad_outputs: &[Tensor<T>],
612        inputs: &[Tensor<T>],
613        _outputs: &[Tensor<T>],
614        params: &OperationParams,
615    ) -> Result<Vec<Option<Tensor<T>>>> {
616        let axis = params.get_int("axis").unwrap_or(0) as usize;
617        let grad_output = &grad_outputs[0];
618
619        // Split the gradient back to match input sizes
620        let mut split_sizes = Vec::new();
621        for input in inputs {
622            split_sizes.push(input.shape().dims()[axis]);
623        }
624
625        // Use multiple slice operations instead of split_with_sizes
626        let mut grad_inputs = Vec::new();
627        let mut start = 0;
628        for &size in &split_sizes {
629            let end = start + size;
630            let slice = grad_output.slice_tensor(axis, start, end)?;
631            grad_inputs.push(Some(slice));
632            start = end;
633        }
634        Ok(grad_inputs)
635    }
636
637    fn output_shapes(
638        &self,
639        input_shapes: &[Vec<usize>],
640        params: &OperationParams,
641    ) -> Result<Vec<Vec<usize>>> {
642        if input_shapes.len() < 2 {
643            return Err(TorshError::InvalidArgument(
644                "Concat operation requires at least 2 inputs".to_string(),
645            ));
646        }
647
648        let axis = params.get_int("axis").unwrap_or(0) as usize;
649        let mut output_shape = input_shapes[0].clone();
650
651        if axis >= output_shape.len() {
652            return Err(TorshError::InvalidArgument(format!(
653                "Concat axis {} out of bounds for {} dimensions",
654                axis,
655                output_shape.len()
656            )));
657        }
658
659        // Sum the sizes along the concatenation axis
660        let mut total_size = output_shape[axis];
661        for shape in &input_shapes[1..] {
662            if shape.len() != output_shape.len() {
663                return Err(TorshError::InvalidArgument(
664                    "All tensors must have the same number of dimensions".to_string(),
665                ));
666            }
667
668            // Check that all dimensions except the concat axis match
669            for (i, (&dim1, &dim2)) in output_shape.iter().zip(shape.iter()).enumerate() {
670                if i != axis && dim1 != dim2 {
671                    return Err(TorshError::InvalidArgument(format!(
672                        "Dimension {} mismatch: {} vs {}",
673                        i, dim1, dim2
674                    )));
675                }
676            }
677
678            total_size += shape[axis];
679        }
680
681        output_shape[axis] = total_size;
682        Ok(vec![output_shape])
683    }
684
685    fn num_inputs(&self) -> usize {
686        // Variable number of inputs, but we'll validate at runtime
687        2 // Minimum required
688    }
689
690    fn num_outputs(&self) -> usize {
691        1
692    }
693
694    fn validate_inputs(&self, inputs: &[Tensor<T>], params: &OperationParams) -> Result<()> {
695        if inputs.len() < 2 {
696            return Err(TorshError::InvalidArgument(
697                "Concat operation requires at least 2 inputs".to_string(),
698            ));
699        }
700
701        let axis = params.get_int("axis").unwrap_or(0) as usize;
702        let first_tensor_shape = inputs[0].shape();
703        let first_shape = first_tensor_shape.dims();
704
705        if axis >= first_shape.len() {
706            return Err(TorshError::InvalidArgument(format!(
707                "Concat axis {} out of bounds for {} dimensions",
708                axis,
709                first_shape.len()
710            )));
711        }
712
713        // Validate that all tensors have compatible shapes
714        for (i, tensor) in inputs.iter().enumerate().skip(1) {
715            let tensor_shape = tensor.shape();
716            let shape = tensor_shape.dims();
717            if shape.len() != first_shape.len() {
718                return Err(TorshError::InvalidArgument(format!(
719                    "Tensor {} has {} dimensions, expected {}",
720                    i,
721                    shape.len(),
722                    first_shape.len()
723                )));
724            }
725
726            for (dim_idx, (&dim1, &dim2)) in first_shape.iter().zip(shape.iter()).enumerate() {
727                if dim_idx != axis && dim1 != dim2 {
728                    return Err(TorshError::InvalidArgument(format!(
729                        "Tensor {} dimension {} mismatch: {} vs {}",
730                        i, dim_idx, dim1, dim2
731                    )));
732                }
733            }
734        }
735
736        Ok(())
737    }
738}
739
#[cfg(test)]
mod tests {
    //! Unit tests for parameter handling, registry bookkeeping, and the two
    //! example operations. Several `expect` messages were copy-pasted from
    //! neighboring calls and described the wrong step; they are corrected here.
    use super::*;
    use torsh_core::device::DeviceType;

    #[test]
    fn test_operation_params() {
        let params = OperationParams::new()
            .with_string("mode", "linear")
            .with_int("axis", 1)
            .with_float("scale", 2.5)
            .with_bool("inplace", false)
            .with_vector("weights", vec![1.0, 2.0, 3.0])
            .with_shape("target_shape", vec![10, 20]);

        assert_eq!(params.get_string("mode"), Some(&"linear".to_string()));
        assert_eq!(params.get_int("axis"), Some(1));
        assert_eq!(params.get_float("scale"), Some(2.5));
        assert_eq!(params.get_bool("inplace"), Some(false));
        assert_eq!(params.get_vector("weights"), Some(&vec![1.0, 2.0, 3.0]));
        assert_eq!(params.get_shape("target_shape"), Some(&vec![10, 20]));

        assert_eq!(params.get_string("nonexistent"), None);
    }

    #[test]
    fn test_registry_operations() {
        let registry = CustomOperationRegistry::new();

        // Register a scale operation
        let scale_op = Box::new(ScaleOperation);
        registry
            .register::<f32>(
                scale_op,
                "1.0.0",
                Some("Test".to_string()),
                vec!["math".to_string()],
            )
            .expect("registration should succeed");

        // Check registration
        assert!(registry.is_registered::<f32>("scale"));
        assert!(!registry.is_registered::<f32>("nonexistent"));

        // Get metadata
        let metadata = registry
            .get_metadata::<f32>("scale")
            .expect("metadata retrieval should succeed");
        assert_eq!(metadata.name, "scale");
        assert_eq!(
            metadata.description,
            "Scales tensor elements by a constant factor"
        );
        assert_eq!(metadata.num_inputs, 1);
        assert_eq!(metadata.num_outputs, 1);
        assert_eq!(metadata.version, "1.0.0");
        assert_eq!(metadata.author, Some("Test".to_string()));
        assert_eq!(metadata.tags, vec!["math".to_string()]);

        // List operations
        let ops = registry.list_operations::<f32>();
        assert_eq!(ops, vec!["scale".to_string()]);

        // Unregister
        registry
            .unregister::<f32>("scale")
            .expect("unregister should succeed");
        assert!(!registry.is_registered::<f32>("scale"));
    }

    #[test]
    fn test_scale_operation() {
        let registry = CustomOperationRegistry::new();
        let scale_op = Box::new(ScaleOperation);
        registry
            .register::<f32>(scale_op, "1.0.0", None, vec![])
            .expect("registration should succeed");

        // Create test tensor
        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        let tensor = Tensor::from_data(data, vec![2, 2], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        // Apply scale operation
        let params = OperationParams::new().with_float("scale", 2.0);
        let results = tensor
            .apply_custom_op_with_registry(&registry, "scale", &[], &params)
            .expect("custom op application should succeed");

        assert_eq!(results.len(), 1);
        let result = &results[0];
        let expected_data = vec![2.0f32, 4.0, 6.0, 8.0];
        assert_eq!(
            result.data().expect("data retrieval should succeed"),
            expected_data
        );
    }

    #[test]
    fn test_concat_operation() {
        let registry = CustomOperationRegistry::new();
        let concat_op = Box::new(ConcatOperation);
        registry
            .register::<f32>(concat_op, "1.0.0", None, vec![])
            .expect("registration should succeed");

        // Create test tensors (1D to work with current cat implementation)
        let data1 = vec![1.0f32, 2.0];
        let tensor1 = Tensor::from_data(data1, vec![2], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        let data2 = vec![3.0f32, 4.0];
        let tensor2 = Tensor::from_data(data2, vec![2], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        // Apply concat operation along axis 0
        let params = OperationParams::new().with_int("axis", 0);
        let results = tensor1
            .apply_custom_op_with_registry(&registry, "concat", &[&tensor2], &params)
            .expect("custom op application should succeed");

        assert_eq!(results.len(), 1);
        let result = &results[0];
        assert_eq!(result.shape().dims(), &[4]); // 2 + 2 = 4 elements
        let expected_data = vec![1.0f32, 2.0, 3.0, 4.0];
        assert_eq!(
            result.data().expect("data retrieval should succeed"),
            expected_data
        );
    }

    #[test]
    fn test_operation_validation() {
        let registry = CustomOperationRegistry::new();
        let concat_op = Box::new(ConcatOperation);
        registry
            .register::<f32>(concat_op, "1.0.0", None, vec![])
            .expect("registration should succeed");

        // Create tensors with incompatible dimensions (2D vs 1D should fail)
        let data1 = vec![1.0f32, 2.0];
        let tensor1 = Tensor::from_data(data1, vec![2], DeviceType::Cpu)
            .expect("tensor creation should succeed"); // 1D tensor

        let data2 = vec![3.0f32, 4.0, 5.0, 6.0];
        let tensor2 = Tensor::from_data(data2, vec![2, 2], DeviceType::Cpu)
            .expect("tensor creation should succeed"); // 2D tensor

        // This should fail validation due to different number of dimensions
        let params = OperationParams::new().with_int("axis", 0);
        let result =
            tensor1.apply_custom_op_with_registry(&registry, "concat", &[&tensor2], &params);
        assert!(result.is_err());
    }

    #[test]
    fn test_output_shape_inference() {
        let concat_op = ConcatOperation;

        // Test shape inference for concat operation (1D tensors)
        let input_shapes = vec![vec![3], vec![4]];
        let params = OperationParams::new().with_int("axis", 0);

        let output_shapes = <ConcatOperation as CustomOperation<f32>>::output_shapes(
            &concat_op,
            &input_shapes,
            &params,
        )
        .expect("shape inference should succeed");
        assert_eq!(output_shapes, vec![vec![7]]); // 3 + 4 = 7 along axis 0
    }

    #[test]
    fn test_error_cases() {
        let registry = CustomOperationRegistry::new();

        // Try to register duplicate operation
        let scale_op1 = Box::new(ScaleOperation);
        let scale_op2 = Box::new(ScaleOperation);

        registry
            .register::<f32>(scale_op1, "1.0.0", None, vec![])
            .expect("registration should succeed");
        let result = registry.register::<f32>(scale_op2, "1.0.0", None, vec![]);
        assert!(result.is_err());

        // Try to unregister non-existent operation
        let result = registry.unregister::<f32>("nonexistent");
        assert!(result.is_err());

        // Try to apply non-existent operation
        let data = vec![1.0f32, 2.0];
        let tensor = Tensor::from_data(data, vec![1, 2], DeviceType::Cpu)
            .expect("tensor creation should succeed");
        let params = OperationParams::new();
        let result = tensor.apply_custom_op_with_registry(&registry, "nonexistent", &[], &params);
        assert!(result.is_err());
    }

    #[test]
    fn test_global_registry() {
        let registry = global_registry();

        // Register an operation in the global registry
        let scale_op = Box::new(ScaleOperation);
        registry
            .register::<f32>(scale_op, "1.0.0", None, vec![])
            .expect("registration should succeed");

        // Use the operation via the tensor extension trait
        let data = vec![1.0f32, 2.0, 3.0];
        let tensor = Tensor::from_data(data, vec![3], DeviceType::Cpu)
            .expect("tensor creation should succeed");
        let params = OperationParams::new().with_float("scale", 3.0);

        let results = tensor
            .apply_custom_op("scale", &[], &params)
            .expect("custom_op should succeed");
        assert_eq!(results.len(), 1);
        let expected_data = vec![3.0f32, 6.0, 9.0];
        assert_eq!(
            results[0].data().expect("data retrieval should succeed"),
            expected_data
        );

        // Clean up so other runs/tests see a pristine global registry
        registry
            .unregister::<f32>("scale")
            .expect("unregister should succeed");
    }
}