1use scirs2_core::ndarray::{Array1, Array2, ArrayD, Axis, Dimension, IxDyn};
7use sklears_core::error::{Result, SklearsError};
8use sklears_core::types::{Float, Int};
9use std::collections::HashMap;
10use std::ops::{Add, Mul};
11
/// Dynamically-dimensioned tensor of `Float` elements; the common value type
/// for every operation in this module.
pub type Tensor = ArrayD<Float>;

/// Tensor shape as a list of per-axis lengths.
pub type TensorShape = Vec<usize>;
17
/// Configuration for a [`TensorOpsContext`].
#[derive(Debug, Clone)]
pub struct TensorConfig {
    /// Record every operation in the computation graph so `backward` can run.
    pub enable_autograd: bool,
    /// Device newly created tensors are associated with.
    pub default_device: TensorDevice,
    /// Preferred in-memory element layout.
    pub memory_layout: MemoryLayout,
    /// Enable operation optimizations.
    /// NOTE(review): not read anywhere in this file — presumably consumed elsewhere.
    pub enable_optimization: bool,
    /// Maximum batch size for batched operations.
    /// NOTE(review): not read anywhere in this file — confirm intended use.
    pub max_batch_size: usize,
}
32
/// Device a tensor lives on.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TensorDevice {
    /// Host CPU.
    Cpu,
    /// GPU identified by its device index.
    Gpu(usize),
    /// Let the implementation choose a device.
    Auto,
}
43
/// Element ordering of tensor storage.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum MemoryLayout {
    /// C-style ordering (last axis varies fastest).
    RowMajor,
    /// Fortran-style ordering (first axis varies fastest).
    ColumnMajor,
    /// Let the implementation choose a layout.
    Auto,
}
54
/// Execution context for tensor operations: holds the configuration, the
/// (optional) autograd computation graph, and device bookkeeping.
pub struct TensorOpsContext {
    config: TensorConfig,
    computation_graph: ComputationGraph,
    device_manager: DeviceManager,
}
61
/// Record of operations performed while autograd is enabled.
#[derive(Debug, Default)]
pub struct ComputationGraph {
    /// One node per recorded operation, in execution order.
    nodes: Vec<GraphNode>,
    /// Data-flow edges between nodes.
    /// NOTE(review): never populated in this file — edges are not yet tracked.
    edges: Vec<GraphEdge>,
    /// Next node id to assign (monotonically increasing).
    current_node_id: usize,
}
69
/// A single recorded operation in the computation graph.
#[derive(Debug, Clone)]
pub struct GraphNode {
    /// Unique, monotonically assigned identifier.
    pub id: usize,
    /// The operation this node represents.
    pub operation: TensorOperation,
    /// Shape of the node's output tensor.
    pub shape: TensorShape,
    /// Whether gradients should flow to this node.
    /// NOTE(review): always set to `false` by the helpers in this file.
    pub requires_grad: bool,
    /// Accumulated gradient, once `backward` computes it.
    pub grad: Option<Tensor>,
}
79
/// Directed data-flow edge between two graph nodes.
#[derive(Debug, Clone)]
pub struct GraphEdge {
    /// Producing node id.
    pub from: usize,
    /// Consuming node id.
    pub to: usize,
    /// Position of this edge among the consumer's inputs.
    pub input_index: usize,
}
87
/// Kinds of operations that can appear in the computation graph.
#[derive(Debug, Clone)]
pub enum TensorOperation {
    /// Graph input (constant, data, or random tensor) with a descriptive name.
    Leaf(String),
    /// Element-wise addition.
    Add,
    /// Element-wise subtraction.
    Sub,
    /// Element-wise multiplication.
    Mul,
    /// Element-wise division.
    Div,
    /// 2-D matrix multiplication.
    MatMul,
    /// Element-wise activation function.
    Activation(ActivationType),
    /// Reduction, optionally along a single axis (`None` = global).
    Reduction(ReductionType, Option<usize>),
    /// Reshape to the given shape.
    Reshape(TensorShape),
    /// Axis permutation.
    Transpose(Vec<usize>),
    /// Concatenation along an axis.
    Concat(usize),
    /// Split along an axis into the given section sizes.
    Split(usize, Vec<usize>),
    /// Ensemble aggregation of multiple predictions.
    EnsembleAgg(AggregationType),
}
118
/// Supported activation functions (see [`TensorOpsContext::activation`]).
#[derive(Debug, Clone, Copy)]
pub enum ActivationType {
    ReLU,
    Sigmoid,
    Tanh,
    /// Row-wise softmax (input is coerced to 2-D).
    Softmax,
    /// Row-wise log-softmax (input is coerced to 2-D).
    LogSoftmax,
    /// Leaky ReLU with the given negative-slope coefficient.
    LeakyReLU(Float),
    /// ELU with the given alpha coefficient.
    ELU(Float),
    GELU,
}
131
/// Supported reductions (see [`TensorOpsContext::reduce`]).
/// NOTE(review): `Std` and `Var` are declared but not yet implemented by `reduce`.
#[derive(Debug, Clone, Copy)]
pub enum ReductionType {
    Sum,
    Mean,
    Max,
    Min,
    Prod,
    Std,
    Var,
}
143
/// Strategies for combining ensemble member predictions.
/// NOTE(review): `Stacking` and `Blending` are declared but not yet implemented
/// by `ensemble_aggregate`.
#[derive(Debug, Clone, Copy)]
pub enum AggregationType {
    Average,
    WeightedAverage,
    Majority,
    Stacking,
    Blending,
}
153
/// Tracks which devices are available/current and per-device memory usage.
pub struct DeviceManager {
    /// Devices usable for tensor placement (currently only CPU).
    available_devices: Vec<TensorDevice>,
    /// Device new work is directed to.
    current_device: TensorDevice,
    /// Bytes in use per device.
    /// NOTE(review): never written in this file — usage always reads as 0.
    memory_usage: HashMap<TensorDevice, usize>,
}
160
/// High-level facade pairing a [`TensorOpsContext`] with ensemble
/// training/prediction helpers.
pub struct EnsembleTensorOps {
    context: TensorOpsContext,
}
165
impl Default for TensorConfig {
    /// Defaults: autograd off, CPU device, automatic layout, optimizations on,
    /// maximum batch size of 1024.
    fn default() -> Self {
        Self {
            enable_autograd: false,
            default_device: TensorDevice::Cpu,
            memory_layout: MemoryLayout::Auto,
            enable_optimization: true,
            max_batch_size: 1024,
        }
    }
}
177
178impl TensorOpsContext {
179 pub fn new(config: TensorConfig) -> Self {
181 Self {
182 config,
183 computation_graph: ComputationGraph::default(),
184 device_manager: DeviceManager::new(),
185 }
186 }
187
    /// Converts an `ndarray` array of any static dimensionality into a
    /// dynamically-dimensioned [`Tensor`] (clones the data).
    ///
    /// When autograd is enabled, registers the tensor as a leaf node named
    /// `"input"` in the computation graph.
    pub fn from_array<D: Dimension>(
        &mut self,
        array: &scirs2_core::ndarray::Array<Float, D>,
    ) -> Result<Tensor> {
        let tensor = array.clone().into_dyn();

        if self.config.enable_autograd {
            self.add_leaf_node("input".to_string(), tensor.shape().to_vec());
        }

        Ok(tensor)
    }
201
    /// Creates a tensor of the given shape with every element set to `value`.
    ///
    /// When autograd is enabled, registers a leaf node named `"constant"`.
    pub fn full(&mut self, shape: &[usize], value: Float) -> Result<Tensor> {
        let tensor = Tensor::from_elem(IxDyn(shape), value);

        if self.config.enable_autograd {
            self.add_leaf_node("constant".to_string(), shape.to_vec());
        }

        Ok(tensor)
    }
212
    /// Creates a tensor of the given shape filled with zeros.
    pub fn zeros(&mut self, shape: &[usize]) -> Result<Tensor> {
        self.full(shape, 0.0)
    }
217
    /// Creates a tensor of the given shape filled with ones.
    pub fn ones(&mut self, shape: &[usize]) -> Result<Tensor> {
        self.full(shape, 1.0)
    }
222
223 pub fn randn(&mut self, shape: &[usize]) -> Result<Tensor> {
225 use scirs2_core::random::prelude::*;
226
227 let size = shape.iter().product();
228 let mut rng = thread_rng();
229 let data: Vec<Float> = (0..size)
231 .map(|_| {
232 let u1: f64 = rng.gen();
234 let u2: f64 = rng.gen();
235 let z = ((-2.0 * u1.ln()) as f64).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
236 z as Float
237 })
238 .collect();
239
240 let tensor = Tensor::from_shape_vec(IxDyn(shape), data)
241 .map_err(|e| SklearsError::InvalidInput(format!("Shape error: {}", e)))?;
242
243 if self.config.enable_autograd {
244 self.add_leaf_node("random".to_string(), shape.to_vec());
245 }
246
247 Ok(tensor)
248 }
249
    /// Element-wise addition of two tensors of identical shape.
    ///
    /// When autograd is enabled, records an `Add` node.
    ///
    /// # Errors
    /// Returns `ShapeMismatch` if the shapes differ (no broadcasting).
    pub fn add(&mut self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
        if a.shape() != b.shape() {
            return Err(SklearsError::ShapeMismatch {
                expected: format!("{:?}", a.shape()),
                actual: format!("{:?}", b.shape()),
            });
        }

        let result = a + b;

        if self.config.enable_autograd {
            self.add_binary_op_node(TensorOperation::Add, a.shape().to_vec());
        }

        Ok(result)
    }
267
268 pub fn sub(&mut self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
270 if a.shape() != b.shape() {
271 return Err(SklearsError::ShapeMismatch {
272 expected: format!("{:?}", a.shape()),
273 actual: format!("{:?}", b.shape()),
274 });
275 }
276
277 let result = a - b;
278
279 if self.config.enable_autograd {
280 self.add_binary_op_node(TensorOperation::Sub, a.shape().to_vec());
281 }
282
283 Ok(result)
284 }
285
286 pub fn mul(&mut self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
288 if a.shape() != b.shape() {
289 return Err(SklearsError::ShapeMismatch {
290 expected: format!("{:?}", a.shape()),
291 actual: format!("{:?}", b.shape()),
292 });
293 }
294
295 let result = a * b;
296
297 if self.config.enable_autograd {
298 self.add_binary_op_node(TensorOperation::Mul, a.shape().to_vec());
299 }
300
301 Ok(result)
302 }
303
304 pub fn matmul(&mut self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
306 let a_2d = self.ensure_2d(a)?;
308 let b_2d = self.ensure_2d(b)?;
309
310 let result = a_2d.dot(&b_2d).into_dyn();
311
312 if self.config.enable_autograd {
313 let output_shape = vec![a_2d.nrows(), b_2d.ncols()];
314 self.add_binary_op_node(TensorOperation::MatMul, output_shape);
315 }
316
317 Ok(result)
318 }
319
320 pub fn activation(&mut self, tensor: &Tensor, activation: ActivationType) -> Result<Tensor> {
322 let result = match activation {
323 ActivationType::ReLU => tensor.mapv(|x| x.max(0.0)),
324 ActivationType::Sigmoid => tensor.mapv(|x| 1.0 / (1.0 + (-x).exp())),
325 ActivationType::Tanh => tensor.mapv(|x| x.tanh()),
326 ActivationType::LeakyReLU(alpha) => {
327 tensor.mapv(|x| if x > 0.0 { x } else { alpha * x })
328 }
329 ActivationType::ELU(alpha) => {
330 tensor.mapv(|x| if x > 0.0 { x } else { alpha * (x.exp() - 1.0) })
331 }
332 ActivationType::GELU => tensor.mapv(|x| {
333 0.5 * x
334 * (1.0 + (std::f64::consts::FRAC_2_SQRT_PI * (x + 0.044715 * x.powi(3))).tanh())
335 }),
336 ActivationType::Softmax => self.softmax_impl(tensor)?,
337 ActivationType::LogSoftmax => self.log_softmax_impl(tensor)?,
338 };
339
340 if self.config.enable_autograd {
341 self.add_unary_op_node(
342 TensorOperation::Activation(activation),
343 tensor.shape().to_vec(),
344 );
345 }
346
347 Ok(result)
348 }
349
350 pub fn reduce(
352 &mut self,
353 tensor: &Tensor,
354 reduction: ReductionType,
355 axis: Option<usize>,
356 ) -> Result<Tensor> {
357 let result = match (reduction, axis) {
358 (ReductionType::Sum, None) => {
359 let sum = tensor.sum();
360 Tensor::from_elem(IxDyn(&[]), sum)
361 }
362 (ReductionType::Sum, Some(ax)) => tensor.sum_axis(Axis(ax)).into_dyn(),
363 (ReductionType::Mean, None) => {
364 let mean = tensor.mean().unwrap_or(0.0);
365 Tensor::from_elem(IxDyn(&[]), mean)
366 }
367 (ReductionType::Mean, Some(ax)) => tensor.mean_axis(Axis(ax)).unwrap().into_dyn(),
368 (ReductionType::Max, Some(ax)) => {
369 tensor
371 .fold_axis(Axis(ax), Float::NEG_INFINITY, |&a, &b| a.max(b))
372 .into_dyn()
373 }
374 (ReductionType::Min, Some(ax)) => {
375 tensor
377 .fold_axis(Axis(ax), Float::INFINITY, |&a, &b| a.min(b))
378 .into_dyn()
379 }
380 _ => {
381 return Err(SklearsError::InvalidInput(format!(
382 "Reduction {:?} not implemented for axis {:?}",
383 reduction, axis
384 )));
385 }
386 };
387
388 if self.config.enable_autograd {
389 let output_shape = result.shape().to_vec();
390 self.add_unary_op_node(TensorOperation::Reduction(reduction, axis), output_shape);
391 }
392
393 Ok(result)
394 }
395
    /// Reshapes `tensor` to `new_shape` (clones the data).
    ///
    /// When autograd is enabled, records a `Reshape` node.
    ///
    /// # Errors
    /// Returns `ShapeMismatch` if the element counts differ, or `InvalidInput`
    /// if `into_shape` rejects the conversion (e.g. incompatible layout).
    pub fn reshape(&mut self, tensor: &Tensor, new_shape: &[usize]) -> Result<Tensor> {
        let total_elements = tensor.len();
        let new_total = new_shape.iter().product::<usize>();

        // Check element counts ourselves to produce a structured error before
        // handing off to ndarray.
        if total_elements != new_total {
            return Err(SklearsError::ShapeMismatch {
                expected: format!("total elements = {}", total_elements),
                actual: format!("total elements = {}", new_total),
            });
        }

        let result = tensor
            .clone()
            .into_shape(IxDyn(new_shape))
            .map_err(|e| SklearsError::InvalidInput(format!("Reshape error: {}", e)))?;

        if self.config.enable_autograd {
            self.add_unary_op_node(
                TensorOperation::Reshape(new_shape.to_vec()),
                new_shape.to_vec(),
            );
        }

        Ok(result)
    }
422
    /// Permutes the axes of `tensor` according to `axes` (clones the data).
    ///
    /// When autograd is enabled, records a `Transpose` node with the permuted
    /// output shape.
    ///
    /// # Errors
    /// Returns `InvalidInput` if `axes.len()` does not equal the tensor's
    /// dimensionality.
    ///
    /// NOTE(review): `permuted_axes` panics if `axes` is not a valid
    /// permutation (e.g. repeated indices) — only the length is checked here.
    pub fn transpose(&mut self, tensor: &Tensor, axes: &[usize]) -> Result<Tensor> {
        if axes.len() != tensor.ndim() {
            return Err(SklearsError::InvalidInput(format!(
                "Transpose axes count {} != tensor ndim {}",
                axes.len(),
                tensor.ndim()
            )));
        }

        let result = tensor.clone().permuted_axes(axes);

        if self.config.enable_autograd {
            // Output shape is the input shape rearranged by the permutation.
            let output_shape = axes.iter().map(|&i| tensor.shape()[i]).collect();
            self.add_unary_op_node(TensorOperation::Transpose(axes.to_vec()), output_shape);
        }

        Ok(result)
    }
442
    /// Concatenates tensors along `axis` after coercing each to 2-D
    /// (1-D tensors become single-row matrices; see `ensure_2d`).
    ///
    /// When autograd is enabled, records a `Concat` node.
    ///
    /// # Errors
    /// Returns `InvalidInput` if the list is empty, if any tensor has more
    /// than 2 dimensions, or if the shapes are incompatible along `axis`.
    pub fn concat(&mut self, tensors: &[&Tensor], axis: usize) -> Result<Tensor> {
        if tensors.is_empty() {
            return Err(SklearsError::InvalidInput(
                "Cannot concatenate empty tensor list".to_string(),
            ));
        }

        // Coerce everything to 2-D so `concatenate` sees a uniform rank.
        let arrays_2d: Result<Vec<_>> = tensors.iter().map(|t| self.ensure_2d(t)).collect();
        let arrays_2d = arrays_2d?;

        let views: Vec<_> = arrays_2d.iter().map(|a| a.view()).collect();
        let result = scirs2_core::ndarray::concatenate(Axis(axis), &views)
            .map_err(|e| SklearsError::InvalidInput(format!("Concatenation error: {}", e)))?
            .into_dyn();

        if self.config.enable_autograd {
            let output_shape = result.shape().to_vec();
            self.add_variadic_op_node(TensorOperation::Concat(axis), output_shape, tensors.len());
        }

        Ok(result)
    }
467
468 pub fn ensemble_aggregate(
470 &mut self,
471 predictions: &[&Tensor],
472 weights: Option<&Tensor>,
473 aggregation: AggregationType,
474 ) -> Result<Tensor> {
475 match aggregation {
476 AggregationType::Average => self.ensemble_average(predictions),
477 AggregationType::WeightedAverage => {
478 if let Some(w) = weights {
479 self.ensemble_weighted_average(predictions, w)
480 } else {
481 self.ensemble_average(predictions)
482 }
483 }
484 AggregationType::Majority => self.ensemble_majority_vote(predictions),
485 _ => Err(SklearsError::InvalidInput(format!(
486 "Aggregation type {:?} not yet implemented",
487 aggregation
488 ))),
489 }
490 }
491
492 pub fn batch_ensemble_forward(
494 &mut self,
495 inputs: &[&Tensor],
496 models: &[&Tensor], ) -> Result<Vec<Tensor>> {
498 let mut outputs = Vec::new();
499
500 for (input, model) in inputs.iter().zip(models.iter()) {
501 let output = self.matmul(input, model)?;
503 outputs.push(output);
504 }
505
506 Ok(outputs)
507 }
508
    /// Backward pass through the computation graph.
    ///
    /// NOTE(review): placeholder implementation — it does not traverse the
    /// graph; it returns a single `"placeholder"` entry containing a clone of
    /// `loss`. Real gradient computation is still TODO.
    ///
    /// # Errors
    /// Returns `InvalidInput` when autograd was not enabled in the config.
    pub fn backward(&mut self, loss: &Tensor) -> Result<HashMap<String, Tensor>> {
        if !self.config.enable_autograd {
            return Err(SklearsError::InvalidInput(
                "Autograd not enabled. Set enable_autograd=true in config.".to_string(),
            ));
        }

        let mut gradients = HashMap::new();
        gradients.insert("placeholder".to_string(), loss.clone());

        Ok(gradients)
    }
526
    /// Returns a read-only view of the recorded computation graph.
    pub fn get_computation_graph(&self) -> &ComputationGraph {
        &self.computation_graph
    }
531
    /// Discards all recorded nodes/edges and resets node ids to zero.
    pub fn clear_graph(&mut self) {
        self.computation_graph = ComputationGraph::default();
    }
536
    /// Coerces a 1-D or 2-D tensor into a concrete `Array2`.
    ///
    /// A 1-D tensor of length n becomes a 1 x n matrix (a single row) —
    /// callers relying on column-vector semantics must transpose themselves.
    ///
    /// # Errors
    /// Returns `InvalidInput` for tensors with 0 or more than 2 dimensions.
    fn ensure_2d(&self, tensor: &Tensor) -> Result<Array2<Float>> {
        match tensor.ndim() {
            1 => {
                let array_1d = tensor
                    .clone()
                    .into_dimensionality::<scirs2_core::ndarray::Ix1>()
                    .map_err(|e| {
                        SklearsError::InvalidInput(format!("1D conversion error: {}", e))
                    })?;
                // Promote to a single-row matrix.
                Ok(array_1d.insert_axis(Axis(0)))
            }
            2 => tensor
                .clone()
                .into_dimensionality::<scirs2_core::ndarray::Ix2>()
                .map_err(|e| SklearsError::InvalidInput(format!("2D conversion error: {}", e))),
            _ => Err(SklearsError::InvalidInput(format!(
                "Cannot convert {}D tensor to 2D",
                tensor.ndim()
            ))),
        }
    }
560
    /// Row-wise softmax over the tensor coerced to 2-D.
    ///
    /// Uses the standard max-subtraction trick for numerical stability. Rows
    /// whose exponentials sum to zero are left unnormalized (guard against
    /// division by zero).
    fn softmax_impl(&self, tensor: &Tensor) -> Result<Tensor> {
        let tensor_2d = self.ensure_2d(tensor)?;
        let mut result = tensor_2d.clone();

        for mut row in result.rows_mut() {
            // Subtract the row max before exponentiating to avoid overflow.
            let max_val = row.fold(Float::NEG_INFINITY, |a, &b| a.max(b));
            row.mapv_inplace(|x| (x - max_val).exp());
            let sum = row.sum();
            if sum > 0.0 {
                row /= sum;
            }
        }

        Ok(result.into_dyn())
    }
577
578 fn log_softmax_impl(&self, tensor: &Tensor) -> Result<Tensor> {
579 let softmax = self.softmax_impl(tensor)?;
580 Ok(softmax.mapv(|x| x.ln()))
581 }
582
    /// Element-wise mean of all predictions.
    ///
    /// # Errors
    /// Returns `InvalidInput` for an empty list, and `ShapeMismatch` (via
    /// `add`) if the predictions do not all share one shape.
    fn ensemble_average(&mut self, predictions: &[&Tensor]) -> Result<Tensor> {
        if predictions.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No predictions to average".to_string(),
            ));
        }

        // Accumulate through `add` so autograd (when enabled) records the sums.
        let mut sum = predictions[0].clone();
        for pred in predictions.iter().skip(1) {
            sum = self.add(&sum, pred)?;
        }

        let n = predictions.len() as Float;
        Ok(sum.mapv(|x| x / n))
    }
598
599 fn ensemble_weighted_average(
600 &mut self,
601 predictions: &[&Tensor],
602 weights: &Tensor,
603 ) -> Result<Tensor> {
604 if predictions.is_empty() {
605 return Err(SklearsError::InvalidInput(
606 "No predictions to average".to_string(),
607 ));
608 }
609
610 if weights.len() != predictions.len() {
611 return Err(SklearsError::ShapeMismatch {
612 expected: format!("{} weights", predictions.len()),
613 actual: format!("{} weights", weights.len()),
614 });
615 }
616
617 let mut weighted_sum = self.mul(
618 predictions[0],
619 &weights
620 .slice(scirs2_core::ndarray::s![0..1])
621 .to_owned()
622 .into_dyn(),
623 )?;
624
625 for (i, pred) in predictions.iter().enumerate().skip(1) {
626 let weight = weights
627 .slice(scirs2_core::ndarray::s![i..i + 1])
628 .to_owned()
629 .into_dyn();
630 let weighted_pred = self.mul(pred, &weight)?;
631 weighted_sum = self.add(&weighted_sum, &weighted_pred)?;
632 }
633
634 Ok(weighted_sum)
635 }
636
    /// Element-wise majority vote over model predictions.
    ///
    /// Each prediction is rounded to the nearest integer and the rounded
    /// values summed; an output element is 1.0 when the vote total exceeds
    /// half the model count, else 0.0.
    ///
    /// NOTE(review): this assumes binary predictions whose values round to
    /// 0 or 1 — confirm callers only pass binary outputs, since multi-class
    /// labels would distort the vote totals.
    ///
    /// # Errors
    /// Returns `InvalidInput` for an empty list, and `ShapeMismatch` (via
    /// `add`) if the predictions do not all share one shape.
    fn ensemble_majority_vote(&mut self, predictions: &[&Tensor]) -> Result<Tensor> {
        if predictions.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No predictions for majority vote".to_string(),
            ));
        }

        let first_shape = predictions[0].shape();
        let mut votes = Tensor::zeros(IxDyn(first_shape));

        for pred in predictions {
            // Round to the nearest class label before counting the vote.
            let rounded = pred.mapv(|x| x.round());
            votes = self.add(&votes, &rounded)?;
        }

        let n_models = predictions.len() as Float;
        // Strict majority: more than half of the models voted 1.
        Ok(votes.mapv(|x| if x > n_models / 2.0 { 1.0 } else { 0.0 }))
    }
659
    /// Records a leaf (input/constant/random) node in the computation graph
    /// and advances the node-id counter.
    fn add_leaf_node(&mut self, name: String, shape: TensorShape) {
        let node = GraphNode {
            id: self.computation_graph.current_node_id,
            operation: TensorOperation::Leaf(name),
            shape,
            requires_grad: false,
            grad: None,
        };

        self.computation_graph.nodes.push(node);
        self.computation_graph.current_node_id += 1;
    }
672
    /// Records a single-input operation node in the computation graph and
    /// advances the node-id counter.
    /// NOTE(review): no edge to the input node is recorded — edges are not
    /// yet tracked anywhere in this file.
    fn add_unary_op_node(&mut self, operation: TensorOperation, output_shape: TensorShape) {
        let node = GraphNode {
            id: self.computation_graph.current_node_id,
            operation,
            shape: output_shape,
            requires_grad: false,
            grad: None,
        };

        self.computation_graph.nodes.push(node);
        self.computation_graph.current_node_id += 1;
    }
685
686 fn add_binary_op_node(&mut self, operation: TensorOperation, output_shape: TensorShape) {
687 let node = GraphNode {
688 id: self.computation_graph.current_node_id,
689 operation,
690 shape: output_shape,
691 requires_grad: false,
692 grad: None,
693 };
694
695 self.computation_graph.nodes.push(node);
696 self.computation_graph.current_node_id += 1;
697 }
698
699 fn add_variadic_op_node(
700 &mut self,
701 operation: TensorOperation,
702 output_shape: TensorShape,
703 _n_inputs: usize,
704 ) {
705 let node = GraphNode {
706 id: self.computation_graph.current_node_id,
707 operation,
708 shape: output_shape,
709 requires_grad: false,
710 grad: None,
711 };
712
713 self.computation_graph.nodes.push(node);
714 self.computation_graph.current_node_id += 1;
715 }
716}
717
impl Default for DeviceManager {
    /// Equivalent to [`DeviceManager::new`]: CPU-only, no recorded usage.
    fn default() -> Self {
        Self::new()
    }
}
723
impl DeviceManager {
    /// Creates a manager with CPU as the only available (and current) device
    /// and no recorded memory usage.
    pub fn new() -> Self {
        Self {
            available_devices: vec![TensorDevice::Cpu],
            current_device: TensorDevice::Cpu,
            memory_usage: HashMap::new(),
        }
    }

    /// Devices that can be selected for tensor placement.
    pub fn available_devices(&self) -> &[TensorDevice] {
        &self.available_devices
    }

    /// Selects the active device.
    /// NOTE(review): the device is not validated against `available_devices`.
    pub fn set_device(&mut self, device: TensorDevice) {
        self.current_device = device;
    }

    /// Returns the currently active device.
    pub fn current_device(&self) -> TensorDevice {
        self.current_device
    }

    /// Bytes recorded in use on `device`; 0 for devices with no record.
    pub fn memory_usage(&self, device: TensorDevice) -> usize {
        self.memory_usage.get(&device).copied().unwrap_or(0)
    }
}
754
755impl EnsembleTensorOps {
    /// Creates ensemble tensor ops backed by a fresh [`TensorOpsContext`].
    pub fn new(config: TensorConfig) -> Self {
        Self {
            context: TensorOpsContext::new(config),
        }
    }
762
763 pub fn train_ensemble_tensors(
765 &mut self,
766 x: &Array2<Float>,
767 y: &Array1<Int>,
768 n_estimators: usize,
769 ) -> Result<Vec<Tensor>> {
770 let x_tensor = self.context.from_array(x)?;
771 let mut models = Vec::new();
772
773 for _i in 0..n_estimators {
774 let n_features = x.ncols();
776 let model_weights = self.context.randn(&[n_features, 1])?;
777 models.push(model_weights);
778 }
779
780 Ok(models)
781 }
782
    /// Runs each model (`x @ model`) and averages the per-model predictions.
    ///
    /// # Errors
    /// Propagates `matmul` errors, and returns `InvalidInput` (via the
    /// aggregator) when `models` is empty.
    pub fn predict_ensemble_tensors(
        &mut self,
        models: &[Tensor],
        x: &Array2<Float>,
    ) -> Result<Tensor> {
        let x_tensor = self.context.from_array(x)?;
        let mut predictions = Vec::new();

        for model in models {
            let pred = self.context.matmul(&x_tensor, model)?;
            predictions.push(pred);
        }

        let pred_refs: Vec<_> = predictions.iter().collect();
        self.context
            .ensemble_aggregate(&pred_refs, None, AggregationType::Average)
    }
802
    /// Mutable access to the underlying tensor-ops context.
    pub fn context_mut(&mut self) -> &mut TensorOpsContext {
        &mut self.context
    }
807
    /// Shared access to the underlying tensor-ops context.
    pub fn context(&self) -> &TensorOpsContext {
        &self.context
    }
812}
813
/// Convenience macro for invoking a context method:
/// `tensor_op!(ctx, add, &a, &b)` expands to `ctx.add(&a, &b)`.
#[macro_export]
macro_rules! tensor_op {
    ($ctx:expr, $op:ident, $($args:expr),*) => {
        $ctx.$op($($args),*)
    };
}
821
// Removed the dead `#[allow(non_snake_case)]`: every item in this module is
// already snake_case, so the allow suppressed nothing.
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_tensor_config() {
        let config = TensorConfig::default();
        assert!(!config.enable_autograd);
        assert_eq!(config.default_device, TensorDevice::Cpu);
    }

    #[test]
    fn test_tensor_context_creation() {
        let config = TensorConfig::default();
        let mut ctx = TensorOpsContext::new(config);

        let tensor = ctx.zeros(&[2, 3]).unwrap();
        assert_eq!(tensor.shape(), &[2, 3]);
    }

    #[test]
    fn test_tensor_operations() {
        let config = TensorConfig::default();
        let mut ctx = TensorOpsContext::new(config);

        let a = ctx.ones(&[2, 2]).unwrap();
        let b = ctx.full(&[2, 2], 2.0).unwrap();

        let result = ctx.add(&a, &b).unwrap();

        assert!(result.iter().all(|&x| (x - 3.0).abs() < 1e-10));
    }

    #[test]
    fn test_matrix_multiplication() {
        let config = TensorConfig::default();
        let mut ctx = TensorOpsContext::new(config);

        let a_array = array![[1.0, 2.0], [3.0, 4.0]];
        let b_array = array![[5.0, 6.0], [7.0, 8.0]];

        let a = ctx.from_array(&a_array).unwrap();
        let b = ctx.from_array(&b_array).unwrap();

        let result = ctx.matmul(&a, &b).unwrap();

        let result_2d = result
            .into_dimensionality::<scirs2_core::ndarray::Ix2>()
            .unwrap();
        assert_eq!(result_2d[[0, 0]], 19.0);
        assert_eq!(result_2d[[0, 1]], 22.0);
        assert_eq!(result_2d[[1, 0]], 43.0);
        assert_eq!(result_2d[[1, 1]], 50.0);
    }

    #[test]
    fn test_activation_functions() {
        let config = TensorConfig::default();
        let mut ctx = TensorOpsContext::new(config);

        let tensor = ctx.from_array(&array![[-1.0, 0.0, 1.0]]).unwrap();

        let relu_result = ctx.activation(&tensor, ActivationType::ReLU).unwrap();
        let sigmoid_result = ctx.activation(&tensor, ActivationType::Sigmoid).unwrap();

        assert_eq!(relu_result.as_slice().unwrap()[0], 0.0);
        assert_eq!(relu_result.as_slice().unwrap()[1], 0.0);
        assert_eq!(relu_result.as_slice().unwrap()[2], 1.0);

        assert!(sigmoid_result.iter().all(|&x| x >= 0.0 && x <= 1.0));
    }

    #[test]
    fn test_reduction_operations() {
        let config = TensorConfig::default();
        let mut ctx = TensorOpsContext::new(config);

        let tensor = ctx.from_array(&array![[1.0, 2.0], [3.0, 4.0]]).unwrap();

        let sum_result = ctx.reduce(&tensor, ReductionType::Sum, None).unwrap();
        let mean_result = ctx.reduce(&tensor, ReductionType::Mean, None).unwrap();

        assert_eq!(sum_result.as_slice().unwrap()[0], 10.0);
        assert_eq!(mean_result.as_slice().unwrap()[0], 2.5);
    }

    #[test]
    fn test_ensemble_operations() {
        let config = TensorConfig::default();
        let mut ctx = TensorOpsContext::new(config);

        let pred1 = ctx.from_array(&array![[1.0, 2.0]]).unwrap();
        let pred2 = ctx.from_array(&array![[3.0, 4.0]]).unwrap();
        let predictions = vec![&pred1, &pred2];

        let avg_result = ctx
            .ensemble_aggregate(&predictions, None, AggregationType::Average)
            .unwrap();

        assert_eq!(avg_result.as_slice().unwrap()[0], 2.0);
        assert_eq!(avg_result.as_slice().unwrap()[1], 3.0);
    }

    #[test]
    fn test_ensemble_tensor_ops() {
        let config = TensorConfig::default();
        let mut ensemble_ops = EnsembleTensorOps::new(config);

        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![0, 1];

        let models = ensemble_ops.train_ensemble_tensors(&x, &y, 3).unwrap();
        assert_eq!(models.len(), 3);

        let predictions = ensemble_ops.predict_ensemble_tensors(&models, &x).unwrap();
        assert_eq!(predictions.shape()[0], 2);
    }

    #[test]
    fn test_device_manager() {
        let mut manager = DeviceManager::new();

        assert_eq!(manager.current_device(), TensorDevice::Cpu);
        assert_eq!(manager.memory_usage(TensorDevice::Cpu), 0);

        manager.set_device(TensorDevice::Gpu(0));
        assert_eq!(manager.current_device(), TensorDevice::Gpu(0));
    }

    #[test]
    fn test_computation_graph() {
        let config = TensorConfig {
            enable_autograd: true,
            ..Default::default()
        };
        let mut ctx = TensorOpsContext::new(config);

        let a = ctx.ones(&[2, 2]).unwrap();
        let b = ctx.ones(&[2, 2]).unwrap();
        let _c = ctx.add(&a, &b).unwrap();

        let graph = ctx.get_computation_graph();
        assert!(!graph.nodes.is_empty());
    }
}
973}