sklears_ensemble/
tensor_ops.rs

//! Tensor operations for ensemble methods
//!
//! This module provides high-level tensor operations aimed at ensemble learning,
//! including batch operations, computation-graph recording intended for automatic
//! differentiation, and device-management hooks for future GPU acceleration
//! (all computation currently runs on the CPU).
5
6use scirs2_core::ndarray::{Array1, Array2, ArrayD, Axis, Dimension, IxDyn};
7use sklears_core::error::{Result, SklearsError};
8use sklears_core::types::{Float, Int};
9use std::collections::HashMap;
10use std::ops::{Add, Mul};
11
12/// Multi-dimensional tensor type
13pub type Tensor = ArrayD<Float>;
14
15/// Tensor shape type
16pub type TensorShape = Vec<usize>;
17
18/// Tensor configuration
19#[derive(Debug, Clone)]
20pub struct TensorConfig {
21    /// Enable automatic differentiation
22    pub enable_autograd: bool,
23    /// Default tensor device (CPU or GPU)
24    pub default_device: TensorDevice,
25    /// Memory layout preference
26    pub memory_layout: MemoryLayout,
27    /// Enable graph optimization
28    pub enable_optimization: bool,
29    /// Maximum tensor size for automatic batching
30    pub max_batch_size: usize,
31}
32
/// Target device for tensor computation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TensorDevice {
    /// Host CPU.
    Cpu,
    /// GPU selected by its numeric device ID.
    Gpu(usize),
    /// Let device selection happen automatically.
    Auto,
}
43
/// Memory layout preference for tensors.
///
/// Derives `Eq` and `Hash` (like [`TensorDevice`]) so layouts can be used as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MemoryLayout {
    /// Row-major (C-style) layout.
    RowMajor,
    /// Column-major (Fortran-style) layout.
    ColumnMajor,
    /// Automatic layout selection.
    Auto,
}
54
55/// Tensor operations context
56pub struct TensorOpsContext {
57    config: TensorConfig,
58    computation_graph: ComputationGraph,
59    device_manager: DeviceManager,
60}
61
62/// Computation graph for automatic differentiation
63#[derive(Debug, Default)]
64pub struct ComputationGraph {
65    nodes: Vec<GraphNode>,
66    edges: Vec<GraphEdge>,
67    current_node_id: usize,
68}
69
70/// Graph node representing a tensor operation
71#[derive(Debug, Clone)]
72pub struct GraphNode {
73    pub id: usize,
74    pub operation: TensorOperation,
75    pub shape: TensorShape,
76    pub requires_grad: bool,
77    pub grad: Option<Tensor>,
78}
79
/// Directed connection between two [`GraphNode`]s, identified by node id.
#[derive(Debug, Clone)]
pub struct GraphEdge {
    /// Id of the producing node.
    pub from: usize,
    /// Id of the consuming node.
    pub to: usize,
    /// Operand position on the consuming node (by naming; confirm once edges are recorded).
    pub input_index: usize,
}
87
88/// Tensor operation enumeration
89#[derive(Debug, Clone)]
90pub enum TensorOperation {
91    /// Leaf node (input tensor)
92    Leaf(String),
93    /// Addition operation
94    Add,
95    /// Subtraction operation
96    Sub,
97    /// Multiplication operation
98    Mul,
99    /// Division operation
100    Div,
101    /// Matrix multiplication
102    MatMul,
103    /// Element-wise activation functions
104    Activation(ActivationType),
105    /// Reduction operations
106    Reduction(ReductionType, Option<usize>),
107    /// Reshape operation
108    Reshape(TensorShape),
109    /// Transpose operation
110    Transpose(Vec<usize>),
111    /// Concatenation operation
112    Concat(usize), // axis
113    /// Split operation
114    Split(usize, Vec<usize>), // axis, split points
115    /// Ensemble aggregation
116    EnsembleAgg(AggregationType),
117}
118
119/// Activation function types
120#[derive(Debug, Clone, Copy)]
121pub enum ActivationType {
122    ReLU,
123    Sigmoid,
124    Tanh,
125    Softmax,
126    LogSoftmax,
127    LeakyReLU(Float),
128    ELU(Float),
129    GELU,
130}
131
/// Reductions supported by `TensorOpsContext::reduce`.
///
/// Fieldless, so it derives `Eq` and `Hash` and can be used as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReductionType {
    /// Sum of elements.
    Sum,
    /// Arithmetic mean.
    Mean,
    /// Maximum element.
    Max,
    /// Minimum element.
    Min,
    /// Product of elements.
    Prod,
    /// Standard deviation.
    Std,
    /// Variance.
    Var,
}
143
/// Strategies for combining ensemble members' predictions.
///
/// `Stacking` and `Blending` are declared but `TensorOpsContext::ensemble_aggregate` does not
/// support them yet. Fieldless, so it derives `Eq` and `Hash`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AggregationType {
    /// Unweighted mean of the predictions.
    Average,
    /// Mean weighted by one weight per model.
    WeightedAverage,
    /// Majority vote over rounded predictions.
    Majority,
    /// Stacked meta-model.
    Stacking,
    /// Blending.
    Blending,
}
153
154/// Device manager for tensor operations
155pub struct DeviceManager {
156    available_devices: Vec<TensorDevice>,
157    current_device: TensorDevice,
158    memory_usage: HashMap<TensorDevice, usize>,
159}
160
161/// Ensemble tensor operations
162pub struct EnsembleTensorOps {
163    context: TensorOpsContext,
164}
165
166impl Default for TensorConfig {
167    fn default() -> Self {
168        Self {
169            enable_autograd: false,
170            default_device: TensorDevice::Cpu,
171            memory_layout: MemoryLayout::Auto,
172            enable_optimization: true,
173            max_batch_size: 1024,
174        }
175    }
176}
177
178impl TensorOpsContext {
179    /// Create new tensor operations context
180    pub fn new(config: TensorConfig) -> Self {
181        Self {
182            config,
183            computation_graph: ComputationGraph::default(),
184            device_manager: DeviceManager::new(),
185        }
186    }
187
188    /// Create tensor from ndarray
189    pub fn from_array<D: Dimension>(
190        &mut self,
191        array: &scirs2_core::ndarray::Array<Float, D>,
192    ) -> Result<Tensor> {
193        let tensor = array.clone().into_dyn();
194
195        if self.config.enable_autograd {
196            self.add_leaf_node("input".to_string(), tensor.shape().to_vec());
197        }
198
199        Ok(tensor)
200    }
201
202    /// Create tensor with specific shape and fill value
203    pub fn full(&mut self, shape: &[usize], value: Float) -> Result<Tensor> {
204        let tensor = Tensor::from_elem(IxDyn(shape), value);
205
206        if self.config.enable_autograd {
207            self.add_leaf_node("constant".to_string(), shape.to_vec());
208        }
209
210        Ok(tensor)
211    }
212
213    /// Create zero tensor
214    pub fn zeros(&mut self, shape: &[usize]) -> Result<Tensor> {
215        self.full(shape, 0.0)
216    }
217
218    /// Create ones tensor
219    pub fn ones(&mut self, shape: &[usize]) -> Result<Tensor> {
220        self.full(shape, 1.0)
221    }
222
223    /// Create random tensor
224    pub fn randn(&mut self, shape: &[usize]) -> Result<Tensor> {
225        use scirs2_core::random::prelude::*;
226
227        let size = shape.iter().product();
228        let mut rng = thread_rng();
229        // Use Box-Muller transform to generate normal distribution
230        let data: Vec<Float> = (0..size)
231            .map(|_| {
232                // Simple normal distribution using Box-Muller transform
233                let u1: f64 = rng.gen();
234                let u2: f64 = rng.gen();
235                let z = ((-2.0 * u1.ln()) as f64).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
236                z as Float
237            })
238            .collect();
239
240        let tensor = Tensor::from_shape_vec(IxDyn(shape), data)
241            .map_err(|e| SklearsError::InvalidInput(format!("Shape error: {}", e)))?;
242
243        if self.config.enable_autograd {
244            self.add_leaf_node("random".to_string(), shape.to_vec());
245        }
246
247        Ok(tensor)
248    }
249
250    /// Element-wise addition
251    pub fn add(&mut self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
252        if a.shape() != b.shape() {
253            return Err(SklearsError::ShapeMismatch {
254                expected: format!("{:?}", a.shape()),
255                actual: format!("{:?}", b.shape()),
256            });
257        }
258
259        let result = a + b;
260
261        if self.config.enable_autograd {
262            self.add_binary_op_node(TensorOperation::Add, a.shape().to_vec());
263        }
264
265        Ok(result)
266    }
267
268    /// Element-wise subtraction
269    pub fn sub(&mut self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
270        if a.shape() != b.shape() {
271            return Err(SklearsError::ShapeMismatch {
272                expected: format!("{:?}", a.shape()),
273                actual: format!("{:?}", b.shape()),
274            });
275        }
276
277        let result = a - b;
278
279        if self.config.enable_autograd {
280            self.add_binary_op_node(TensorOperation::Sub, a.shape().to_vec());
281        }
282
283        Ok(result)
284    }
285
286    /// Element-wise multiplication
287    pub fn mul(&mut self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
288        if a.shape() != b.shape() {
289            return Err(SklearsError::ShapeMismatch {
290                expected: format!("{:?}", a.shape()),
291                actual: format!("{:?}", b.shape()),
292            });
293        }
294
295        let result = a * b;
296
297        if self.config.enable_autograd {
298            self.add_binary_op_node(TensorOperation::Mul, a.shape().to_vec());
299        }
300
301        Ok(result)
302    }
303
304    /// Matrix multiplication
305    pub fn matmul(&mut self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
306        // Convert to 2D for matrix multiplication
307        let a_2d = self.ensure_2d(a)?;
308        let b_2d = self.ensure_2d(b)?;
309
310        let result = a_2d.dot(&b_2d).into_dyn();
311
312        if self.config.enable_autograd {
313            let output_shape = vec![a_2d.nrows(), b_2d.ncols()];
314            self.add_binary_op_node(TensorOperation::MatMul, output_shape);
315        }
316
317        Ok(result)
318    }
319
320    /// Apply activation function
321    pub fn activation(&mut self, tensor: &Tensor, activation: ActivationType) -> Result<Tensor> {
322        let result = match activation {
323            ActivationType::ReLU => tensor.mapv(|x| x.max(0.0)),
324            ActivationType::Sigmoid => tensor.mapv(|x| 1.0 / (1.0 + (-x).exp())),
325            ActivationType::Tanh => tensor.mapv(|x| x.tanh()),
326            ActivationType::LeakyReLU(alpha) => {
327                tensor.mapv(|x| if x > 0.0 { x } else { alpha * x })
328            }
329            ActivationType::ELU(alpha) => {
330                tensor.mapv(|x| if x > 0.0 { x } else { alpha * (x.exp() - 1.0) })
331            }
332            ActivationType::GELU => tensor.mapv(|x| {
333                0.5 * x
334                    * (1.0 + (std::f64::consts::FRAC_2_SQRT_PI * (x + 0.044715 * x.powi(3))).tanh())
335            }),
336            ActivationType::Softmax => self.softmax_impl(tensor)?,
337            ActivationType::LogSoftmax => self.log_softmax_impl(tensor)?,
338        };
339
340        if self.config.enable_autograd {
341            self.add_unary_op_node(
342                TensorOperation::Activation(activation),
343                tensor.shape().to_vec(),
344            );
345        }
346
347        Ok(result)
348    }
349
350    /// Reduction operations
351    pub fn reduce(
352        &mut self,
353        tensor: &Tensor,
354        reduction: ReductionType,
355        axis: Option<usize>,
356    ) -> Result<Tensor> {
357        let result = match (reduction, axis) {
358            (ReductionType::Sum, None) => {
359                let sum = tensor.sum();
360                Tensor::from_elem(IxDyn(&[]), sum)
361            }
362            (ReductionType::Sum, Some(ax)) => tensor.sum_axis(Axis(ax)).into_dyn(),
363            (ReductionType::Mean, None) => {
364                let mean = tensor.mean().unwrap_or(0.0);
365                Tensor::from_elem(IxDyn(&[]), mean)
366            }
367            (ReductionType::Mean, Some(ax)) => tensor.mean_axis(Axis(ax)).unwrap().into_dyn(),
368            (ReductionType::Max, Some(ax)) => {
369                // Find max along axis
370                tensor
371                    .fold_axis(Axis(ax), Float::NEG_INFINITY, |&a, &b| a.max(b))
372                    .into_dyn()
373            }
374            (ReductionType::Min, Some(ax)) => {
375                // Find min along axis
376                tensor
377                    .fold_axis(Axis(ax), Float::INFINITY, |&a, &b| a.min(b))
378                    .into_dyn()
379            }
380            _ => {
381                return Err(SklearsError::InvalidInput(format!(
382                    "Reduction {:?} not implemented for axis {:?}",
383                    reduction, axis
384                )));
385            }
386        };
387
388        if self.config.enable_autograd {
389            let output_shape = result.shape().to_vec();
390            self.add_unary_op_node(TensorOperation::Reduction(reduction, axis), output_shape);
391        }
392
393        Ok(result)
394    }
395
396    /// Reshape tensor
397    pub fn reshape(&mut self, tensor: &Tensor, new_shape: &[usize]) -> Result<Tensor> {
398        let total_elements = tensor.len();
399        let new_total = new_shape.iter().product::<usize>();
400
401        if total_elements != new_total {
402            return Err(SklearsError::ShapeMismatch {
403                expected: format!("total elements = {}", total_elements),
404                actual: format!("total elements = {}", new_total),
405            });
406        }
407
408        let result = tensor
409            .clone()
410            .into_shape(IxDyn(new_shape))
411            .map_err(|e| SklearsError::InvalidInput(format!("Reshape error: {}", e)))?;
412
413        if self.config.enable_autograd {
414            self.add_unary_op_node(
415                TensorOperation::Reshape(new_shape.to_vec()),
416                new_shape.to_vec(),
417            );
418        }
419
420        Ok(result)
421    }
422
423    /// Transpose tensor
424    pub fn transpose(&mut self, tensor: &Tensor, axes: &[usize]) -> Result<Tensor> {
425        if axes.len() != tensor.ndim() {
426            return Err(SklearsError::InvalidInput(format!(
427                "Transpose axes count {} != tensor ndim {}",
428                axes.len(),
429                tensor.ndim()
430            )));
431        }
432
433        let result = tensor.clone().permuted_axes(axes);
434
435        if self.config.enable_autograd {
436            let output_shape = axes.iter().map(|&i| tensor.shape()[i]).collect();
437            self.add_unary_op_node(TensorOperation::Transpose(axes.to_vec()), output_shape);
438        }
439
440        Ok(result)
441    }
442
443    /// Concatenate tensors along axis
444    pub fn concat(&mut self, tensors: &[&Tensor], axis: usize) -> Result<Tensor> {
445        if tensors.is_empty() {
446            return Err(SklearsError::InvalidInput(
447                "Cannot concatenate empty tensor list".to_string(),
448            ));
449        }
450
451        // Convert to 2D arrays for concatenation
452        let arrays_2d: Result<Vec<_>> = tensors.iter().map(|t| self.ensure_2d(t)).collect();
453        let arrays_2d = arrays_2d?;
454
455        let views: Vec<_> = arrays_2d.iter().map(|a| a.view()).collect();
456        let result = scirs2_core::ndarray::concatenate(Axis(axis), &views)
457            .map_err(|e| SklearsError::InvalidInput(format!("Concatenation error: {}", e)))?
458            .into_dyn();
459
460        if self.config.enable_autograd {
461            let output_shape = result.shape().to_vec();
462            self.add_variadic_op_node(TensorOperation::Concat(axis), output_shape, tensors.len());
463        }
464
465        Ok(result)
466    }
467
468    /// Ensemble-specific operations
469    pub fn ensemble_aggregate(
470        &mut self,
471        predictions: &[&Tensor],
472        weights: Option<&Tensor>,
473        aggregation: AggregationType,
474    ) -> Result<Tensor> {
475        match aggregation {
476            AggregationType::Average => self.ensemble_average(predictions),
477            AggregationType::WeightedAverage => {
478                if let Some(w) = weights {
479                    self.ensemble_weighted_average(predictions, w)
480                } else {
481                    self.ensemble_average(predictions)
482                }
483            }
484            AggregationType::Majority => self.ensemble_majority_vote(predictions),
485            _ => Err(SklearsError::InvalidInput(format!(
486                "Aggregation type {:?} not yet implemented",
487                aggregation
488            ))),
489        }
490    }
491
492    /// Batch operations for ensemble training
493    pub fn batch_ensemble_forward(
494        &mut self,
495        inputs: &[&Tensor],
496        models: &[&Tensor], // Model parameters
497    ) -> Result<Vec<Tensor>> {
498        let mut outputs = Vec::new();
499
500        for (input, model) in inputs.iter().zip(models.iter()) {
501            // Simplified forward pass - in practice this would depend on model type
502            let output = self.matmul(input, model)?;
503            outputs.push(output);
504        }
505
506        Ok(outputs)
507    }
508
509    /// Backward pass for gradient computation
510    pub fn backward(&mut self, loss: &Tensor) -> Result<HashMap<String, Tensor>> {
511        if !self.config.enable_autograd {
512            return Err(SklearsError::InvalidInput(
513                "Autograd not enabled. Set enable_autograd=true in config.".to_string(),
514            ));
515        }
516
517        // Placeholder for actual backward pass implementation
518        // In a real implementation, this would traverse the computation graph
519        // and compute gradients using the chain rule
520
521        let mut gradients = HashMap::new();
522        gradients.insert("placeholder".to_string(), loss.clone());
523
524        Ok(gradients)
525    }
526
527    /// Get computation graph
528    pub fn get_computation_graph(&self) -> &ComputationGraph {
529        &self.computation_graph
530    }
531
532    /// Clear computation graph
533    pub fn clear_graph(&mut self) {
534        self.computation_graph = ComputationGraph::default();
535    }
536
537    // Private helper methods
538
539    fn ensure_2d(&self, tensor: &Tensor) -> Result<Array2<Float>> {
540        match tensor.ndim() {
541            1 => {
542                let array_1d = tensor
543                    .clone()
544                    .into_dimensionality::<scirs2_core::ndarray::Ix1>()
545                    .map_err(|e| {
546                        SklearsError::InvalidInput(format!("1D conversion error: {}", e))
547                    })?;
548                Ok(array_1d.insert_axis(Axis(0)))
549            }
550            2 => tensor
551                .clone()
552                .into_dimensionality::<scirs2_core::ndarray::Ix2>()
553                .map_err(|e| SklearsError::InvalidInput(format!("2D conversion error: {}", e))),
554            _ => Err(SklearsError::InvalidInput(format!(
555                "Cannot convert {}D tensor to 2D",
556                tensor.ndim()
557            ))),
558        }
559    }
560
561    fn softmax_impl(&self, tensor: &Tensor) -> Result<Tensor> {
562        // Ensure 2D for softmax computation
563        let tensor_2d = self.ensure_2d(tensor)?;
564        let mut result = tensor_2d.clone();
565
566        for mut row in result.rows_mut() {
567            let max_val = row.fold(Float::NEG_INFINITY, |a, &b| a.max(b));
568            row.mapv_inplace(|x| (x - max_val).exp());
569            let sum = row.sum();
570            if sum > 0.0 {
571                row /= sum;
572            }
573        }
574
575        Ok(result.into_dyn())
576    }
577
578    fn log_softmax_impl(&self, tensor: &Tensor) -> Result<Tensor> {
579        let softmax = self.softmax_impl(tensor)?;
580        Ok(softmax.mapv(|x| x.ln()))
581    }
582
583    fn ensemble_average(&mut self, predictions: &[&Tensor]) -> Result<Tensor> {
584        if predictions.is_empty() {
585            return Err(SklearsError::InvalidInput(
586                "No predictions to average".to_string(),
587            ));
588        }
589
590        let mut sum = predictions[0].clone();
591        for pred in predictions.iter().skip(1) {
592            sum = self.add(&sum, pred)?;
593        }
594
595        let n = predictions.len() as Float;
596        Ok(sum.mapv(|x| x / n))
597    }
598
599    fn ensemble_weighted_average(
600        &mut self,
601        predictions: &[&Tensor],
602        weights: &Tensor,
603    ) -> Result<Tensor> {
604        if predictions.is_empty() {
605            return Err(SklearsError::InvalidInput(
606                "No predictions to average".to_string(),
607            ));
608        }
609
610        if weights.len() != predictions.len() {
611            return Err(SklearsError::ShapeMismatch {
612                expected: format!("{} weights", predictions.len()),
613                actual: format!("{} weights", weights.len()),
614            });
615        }
616
617        let mut weighted_sum = self.mul(
618            predictions[0],
619            &weights
620                .slice(scirs2_core::ndarray::s![0..1])
621                .to_owned()
622                .into_dyn(),
623        )?;
624
625        for (i, pred) in predictions.iter().enumerate().skip(1) {
626            let weight = weights
627                .slice(scirs2_core::ndarray::s![i..i + 1])
628                .to_owned()
629                .into_dyn();
630            let weighted_pred = self.mul(pred, &weight)?;
631            weighted_sum = self.add(&weighted_sum, &weighted_pred)?;
632        }
633
634        Ok(weighted_sum)
635    }
636
637    fn ensemble_majority_vote(&mut self, predictions: &[&Tensor]) -> Result<Tensor> {
638        if predictions.is_empty() {
639            return Err(SklearsError::InvalidInput(
640                "No predictions for majority vote".to_string(),
641            ));
642        }
643
644        // Convert predictions to discrete votes and find majority
645        // This is a simplified implementation
646        let first_shape = predictions[0].shape();
647        let mut votes = Tensor::zeros(IxDyn(first_shape));
648
649        for pred in predictions {
650            // Round predictions to nearest integer for voting
651            let rounded = pred.mapv(|x| x.round());
652            votes = self.add(&votes, &rounded)?;
653        }
654
655        // Majority decision
656        let n_models = predictions.len() as Float;
657        Ok(votes.mapv(|x| if x > n_models / 2.0 { 1.0 } else { 0.0 }))
658    }
659
660    fn add_leaf_node(&mut self, name: String, shape: TensorShape) {
661        let node = GraphNode {
662            id: self.computation_graph.current_node_id,
663            operation: TensorOperation::Leaf(name),
664            shape,
665            requires_grad: false,
666            grad: None,
667        };
668
669        self.computation_graph.nodes.push(node);
670        self.computation_graph.current_node_id += 1;
671    }
672
673    fn add_unary_op_node(&mut self, operation: TensorOperation, output_shape: TensorShape) {
674        let node = GraphNode {
675            id: self.computation_graph.current_node_id,
676            operation,
677            shape: output_shape,
678            requires_grad: false,
679            grad: None,
680        };
681
682        self.computation_graph.nodes.push(node);
683        self.computation_graph.current_node_id += 1;
684    }
685
686    fn add_binary_op_node(&mut self, operation: TensorOperation, output_shape: TensorShape) {
687        let node = GraphNode {
688            id: self.computation_graph.current_node_id,
689            operation,
690            shape: output_shape,
691            requires_grad: false,
692            grad: None,
693        };
694
695        self.computation_graph.nodes.push(node);
696        self.computation_graph.current_node_id += 1;
697    }
698
699    fn add_variadic_op_node(
700        &mut self,
701        operation: TensorOperation,
702        output_shape: TensorShape,
703        _n_inputs: usize,
704    ) {
705        let node = GraphNode {
706            id: self.computation_graph.current_node_id,
707            operation,
708            shape: output_shape,
709            requires_grad: false,
710            grad: None,
711        };
712
713        self.computation_graph.nodes.push(node);
714        self.computation_graph.current_node_id += 1;
715    }
716}
717
718impl Default for DeviceManager {
719    fn default() -> Self {
720        Self::new()
721    }
722}
723
724impl DeviceManager {
725    /// Create new device manager
726    pub fn new() -> Self {
727        Self {
728            available_devices: vec![TensorDevice::Cpu],
729            current_device: TensorDevice::Cpu,
730            memory_usage: HashMap::new(),
731        }
732    }
733
734    /// Get available devices
735    pub fn available_devices(&self) -> &[TensorDevice] {
736        &self.available_devices
737    }
738
739    /// Set current device
740    pub fn set_device(&mut self, device: TensorDevice) {
741        self.current_device = device;
742    }
743
744    /// Get current device
745    pub fn current_device(&self) -> TensorDevice {
746        self.current_device
747    }
748
749    /// Get memory usage for device
750    pub fn memory_usage(&self, device: TensorDevice) -> usize {
751        self.memory_usage.get(&device).copied().unwrap_or(0)
752    }
753}
754
755impl EnsembleTensorOps {
756    /// Create new ensemble tensor operations
757    pub fn new(config: TensorConfig) -> Self {
758        Self {
759            context: TensorOpsContext::new(config),
760        }
761    }
762
763    /// Train ensemble with tensor operations
764    pub fn train_ensemble_tensors(
765        &mut self,
766        x: &Array2<Float>,
767        y: &Array1<Int>,
768        n_estimators: usize,
769    ) -> Result<Vec<Tensor>> {
770        let x_tensor = self.context.from_array(x)?;
771        let mut models = Vec::new();
772
773        for _i in 0..n_estimators {
774            // Create a simple linear model (weight matrix)
775            let n_features = x.ncols();
776            let model_weights = self.context.randn(&[n_features, 1])?;
777            models.push(model_weights);
778        }
779
780        Ok(models)
781    }
782
783    /// Predict with ensemble using tensor operations
784    pub fn predict_ensemble_tensors(
785        &mut self,
786        models: &[Tensor],
787        x: &Array2<Float>,
788    ) -> Result<Tensor> {
789        let x_tensor = self.context.from_array(x)?;
790        let mut predictions = Vec::new();
791
792        for model in models {
793            let pred = self.context.matmul(&x_tensor, model)?;
794            predictions.push(pred);
795        }
796
797        // Average predictions
798        let pred_refs: Vec<_> = predictions.iter().collect();
799        self.context
800            .ensemble_aggregate(&pred_refs, None, AggregationType::Average)
801    }
802
803    /// Get mutable context
804    pub fn context_mut(&mut self) -> &mut TensorOpsContext {
805        &mut self.context
806    }
807
808    /// Get context
809    pub fn context(&self) -> &TensorOpsContext {
810        &self.context
811    }
812}
813
/// Convenience macro for tensor operations.
///
/// `tensor_op!(ctx, method, a, b)` expands to `ctx.method(a, b)`. Zero-argument methods can
/// be called as `tensor_op!(ctx, method)`, and a trailing comma is accepted.
#[macro_export]
macro_rules! tensor_op {
    ($ctx:expr, $op:ident $(, $args:expr)* $(,)?) => {
        $ctx.$op($($args),*)
    };
}
821
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Context with the default configuration (autograd disabled).
    fn default_ctx() -> TensorOpsContext {
        TensorOpsContext::new(TensorConfig::default())
    }

    #[test]
    fn test_tensor_config() {
        let cfg = TensorConfig::default();
        assert!(!cfg.enable_autograd);
        assert_eq!(cfg.default_device, TensorDevice::Cpu);
    }

    #[test]
    fn test_tensor_context_creation() {
        let mut ctx = default_ctx();
        let zeros = ctx.zeros(&[2, 3]).unwrap();
        assert_eq!(zeros.shape(), &[2, 3]);
    }

    #[test]
    fn test_tensor_operations() {
        let mut ctx = default_ctx();
        let ones = ctx.ones(&[2, 2]).unwrap();
        let twos = ctx.full(&[2, 2], 2.0).unwrap();

        let total = ctx.add(&ones, &twos).unwrap();

        // 1 + 2 == 3 everywhere.
        for &value in total.iter() {
            assert!((value - 3.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_matrix_multiplication() {
        let mut ctx = default_ctx();
        let lhs = ctx.from_array(&array![[1.0, 2.0], [3.0, 4.0]]).unwrap();
        let rhs = ctx.from_array(&array![[5.0, 6.0], [7.0, 8.0]]).unwrap();

        let product = ctx
            .matmul(&lhs, &rhs)
            .unwrap()
            .into_dimensionality::<scirs2_core::ndarray::Ix2>()
            .unwrap();

        let expected = [[19.0, 22.0], [43.0, 50.0]];
        for (r, row) in expected.iter().enumerate() {
            for (c, &want) in row.iter().enumerate() {
                assert_eq!(product[[r, c]], want);
            }
        }
    }

    #[test]
    fn test_activation_functions() {
        let mut ctx = default_ctx();
        let input = ctx.from_array(&array![[-1.0, 0.0, 1.0]]).unwrap();

        let relu = ctx.activation(&input, ActivationType::ReLU).unwrap();
        let sigmoid = ctx.activation(&input, ActivationType::Sigmoid).unwrap();

        // ReLU zeroes the negative entry and passes the others through.
        assert_eq!(relu.as_slice().unwrap(), &[0.0, 0.0, 1.0]);

        // Sigmoid output lies in [0, 1].
        assert!(sigmoid.iter().all(|&v| (0.0..=1.0).contains(&v)));
    }

    #[test]
    fn test_reduction_operations() {
        let mut ctx = default_ctx();
        let data = ctx.from_array(&array![[1.0, 2.0], [3.0, 4.0]]).unwrap();

        let sum = ctx.reduce(&data, ReductionType::Sum, None).unwrap();
        let mean = ctx.reduce(&data, ReductionType::Mean, None).unwrap();

        assert_eq!(sum.as_slice().unwrap().first(), Some(&10.0));
        assert_eq!(mean.as_slice().unwrap().first(), Some(&2.5));
    }

    #[test]
    fn test_ensemble_operations() {
        let mut ctx = default_ctx();
        let first = ctx.from_array(&array![[1.0, 2.0]]).unwrap();
        let second = ctx.from_array(&array![[3.0, 4.0]]).unwrap();

        let averaged = ctx
            .ensemble_aggregate(&[&first, &second], None, AggregationType::Average)
            .unwrap();

        // Column-wise means of the two predictions.
        assert_eq!(averaged.as_slice().unwrap(), &[2.0, 3.0]);
    }

    #[test]
    fn test_ensemble_tensor_ops() {
        let mut ops = EnsembleTensorOps::new(TensorConfig::default());
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![0, 1];

        let models = ops.train_ensemble_tensors(&x, &y, 3).unwrap();
        assert_eq!(models.len(), 3);

        let predictions = ops.predict_ensemble_tensors(&models, &x).unwrap();
        // One prediction row per sample.
        assert_eq!(predictions.shape()[0], 2);
    }

    #[test]
    fn test_device_manager() {
        let mut devices = DeviceManager::new();
        assert_eq!(devices.current_device(), TensorDevice::Cpu);
        assert_eq!(devices.memory_usage(TensorDevice::Cpu), 0);

        devices.set_device(TensorDevice::Gpu(0));
        assert_eq!(devices.current_device(), TensorDevice::Gpu(0));
    }

    #[test]
    fn test_computation_graph() {
        let mut ctx = TensorOpsContext::new(TensorConfig {
            enable_autograd: true,
            ..Default::default()
        });

        let lhs = ctx.ones(&[2, 2]).unwrap();
        let rhs = ctx.ones(&[2, 2]).unwrap();
        ctx.add(&lhs, &rhs).unwrap();

        assert!(!ctx.get_computation_graph().nodes.is_empty());
    }
}