use std::sync::Arc;

use torsh_core::{
    dtype::TensorElement,
    error::{Result, TorshError},
};

#[cfg(feature = "simd")]
mod simd_imports {
    pub use scirs2_core::ndarray::Array1;
}

#[cfg(feature = "simd")]
use simd_imports::*;

#[cfg(feature = "parallel")]
use scirs2_core::chunking::{
    CacheAwareness, ChunkConfig, ChunkStrategy, ComputeIntensity, GpuChunkSettings, MemoryPattern,
    NumaStrategy,
};

use crate::core_ops::{Operation, Tensor};

#[cfg(feature = "simd")]
pub(crate) mod adaptive_simd {
    use super::*;
    use scirs2_core::ndarray::ArrayView1;

    /// ReLU over an `f32` view, delegating to the scirs2 SIMD kernel.
    pub fn adaptive_simd_relu_f32(input: &ArrayView1<f32>) -> Array1<f32> {
        scirs2_core::simd::activation::simd_relu_f32(input)
    }

    /// Sigmoid over an `f32` view, delegating to the scirs2 SIMD kernel.
    pub fn adaptive_simd_sigmoid_f32(input: &ArrayView1<f32>) -> Array1<f32> {
        scirs2_core::simd::transcendental::simd_sigmoid_f32(input)
    }

    /// GELU over an `f32` view, delegating to the scirs2 SIMD kernel.
    pub fn adaptive_simd_gelu_f32(input: &ArrayView1<f32>) -> Array1<f32> {
        scirs2_core::simd::transcendental::simd_gelu_f32(input)
    }
}

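/// Heuristic chunk-configuration selection for parallel tensor kernels.
///
/// `create_optimal_chunk_config` picks a `ChunkStrategy` from the operation
/// kind and tensor size; `intelligent_parallel_process` then maps a closure
/// over the data. Illustrative sketch (`ignore`d because it depends on the
/// `parallel` feature and on module-private items):
///
/// ```ignore
/// let doubled = intelligent_parallel_process(
///     vec![1.0f32; 4096],
///     TensorOpType::ElementWise,
///     device, // any `torsh_core::device::DeviceType` value
///     |x| x * 2.0,
/// );
/// ```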
#[cfg(feature = "parallel")]
mod intelligent_chunking {
    use super::*;

    /// Kind of tensor operation, used to pick a chunking strategy.
    #[derive(Debug, Clone, Copy)]
    pub enum TensorOpType {
        /// Memory-bound elementwise arithmetic (add, mul, ...).
        ElementWise,
        /// Compute-heavier activation functions (relu, sigmoid, gelu, ...).
        Activation,
    }

    /// Build a `ChunkConfig` tuned to the operation kind and tensor size.
    pub fn create_optimal_chunk_config(
        tensor_size: usize,
        op_type: TensorOpType,
        _device: torsh_core::device::DeviceType,
        is_gpu_available: bool,
    ) -> ChunkConfig {
        match op_type {
            TensorOpType::ElementWise => ChunkConfig {
                strategy: if tensor_size > 100_000 {
                    ChunkStrategy::MemoryOptimized
                } else {
                    ChunkStrategy::CacheOptimized
                },
                min_chunk_size: 64,
                max_chunk_size: 8192,
                prefer_work_stealing: true,
                memory_pattern: MemoryPattern::Sequential,
                compute_intensity: ComputeIntensity::MemoryBound,
                enable_monitoring: false,
                load_balance_factor: 0.1,
                cache_awareness: CacheAwareness::L2,
                numa_strategy: NumaStrategy::LocalPreferred,
                gpu_settings: if is_gpu_available {
                    Some(GpuChunkSettings::default())
                } else {
                    None
                },
            },

            TensorOpType::Activation => ChunkConfig {
                strategy: ChunkStrategy::CacheOptimized,
                min_chunk_size: 64,
                max_chunk_size: 4096,
                prefer_work_stealing: true,
                memory_pattern: MemoryPattern::Sequential,
                compute_intensity: ComputeIntensity::ComputeIntensive,
                enable_monitoring: false,
                load_balance_factor: 0.1,
                cache_awareness: CacheAwareness::L1,
                numa_strategy: NumaStrategy::LocalPreferred,
                gpu_settings: if is_gpu_available {
                    Some(GpuChunkSettings {
                        gpu_memory_ratio: 0.7,
                        gpu_min_chunk: 2048,
                        overlap_compute: true,
                        gpu_bandwidth: None,
                        transfer_bandwidth: None,
                    })
                } else {
                    None
                },
            },
        }
    }

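    /// Map `operation` over `data`, intended to use the chunk configuration.
    ///
    /// A chunk config is computed up front but not yet threaded into the
    /// executor (hence the `_chunk_config` binding below); the current
    /// implementation simply maps with `into_par_iter`.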
    pub fn intelligent_parallel_process<T, F, R>(
        data: Vec<T>,
        op_type: TensorOpType,
        device: torsh_core::device::DeviceType,
        operation: F,
    ) -> Vec<R>
    where
        T: Send + Sync,
        R: Send + Sync,
        F: Fn(T) -> R + Send + Sync,
    {
        let is_gpu_available = matches!(
            device,
            torsh_core::device::DeviceType::Cuda(_)
                | torsh_core::device::DeviceType::Metal(_)
                | torsh_core::device::DeviceType::Wgpu(_)
        );

        // Computed up front for future executor integration; not yet consumed.
        let _chunk_config =
            create_optimal_chunk_config(data.len(), op_type, device, is_gpu_available);

        // This module is only compiled with the `parallel` feature, so the
        // parallel path is unconditionally available here.
        use scirs2_core::parallel_ops::*;
        data.into_par_iter().map(operation).collect()
    }
}

#[cfg(feature = "parallel")]
use intelligent_chunking::*;

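/// NumPy-style broadcast compatibility check: dimensions are compared from
/// the trailing end, and each pair must be equal or contain a 1.
///
/// For example, `[2, 1, 3]` and `[4, 3]` are compatible (result `[2, 4, 3]`),
/// while `[2, 3]` and `[4]` are not (`3` vs `4`).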
fn can_broadcast(shape1: &[usize], shape2: &[usize]) -> bool {
    let max_dims = shape1.len().max(shape2.len());

    for i in 0..max_dims {
        let dim1 = if i < shape1.len() {
            shape1[shape1.len() - 1 - i]
        } else {
            1
        };
        let dim2 = if i < shape2.len() {
            shape2[shape2.len() - 1 - i]
        } else {
            1
        };

        if dim1 != dim2 && dim1 != 1 && dim2 != 1 {
            return false;
        }
    }
    true
}

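/// Compute the broadcast result shape, or a `ShapeMismatch` error if the two
/// shapes are incompatible.
///
/// E.g. broadcasting `[2, 1, 3]` with `[4, 3]` yields `[2, 4, 3]`: trailing
/// dims `3`/`3` agree, `1` stretches to `4`, and the leading `2` carries over.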
fn compute_broadcast_shape(shape1: &[usize], shape2: &[usize]) -> Result<Vec<usize>> {
    let max_dims = shape1.len().max(shape2.len());
    let mut result = Vec::with_capacity(max_dims);

    for i in 0..max_dims {
        let dim1 = if i < shape1.len() {
            shape1[shape1.len() - 1 - i]
        } else {
            1
        };
        let dim2 = if i < shape2.len() {
            shape2[shape2.len() - 1 - i]
        } else {
            1
        };

        if dim1 == dim2 {
            result.push(dim1);
        } else if dim1 == 1 {
            result.push(dim2);
        } else if dim2 == 1 {
            result.push(dim1);
        } else {
            return Err(TorshError::ShapeMismatch {
                expected: shape1.to_vec(),
                got: shape2.to_vec(),
            });
        }
    }

    result.reverse();
    Ok(result)
}

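/// Map a flat index in the broadcast result back to a flat index in the
/// original (smaller) tensor: broadcast coordinates are decoded dimension by
/// dimension, and any dimension of size 1 in the original collapses to 0.
///
/// Worked example: `broadcast_shape = [2, 3]`, `original_shape = [3]`,
/// `flat_idx = 4` decodes to coordinates `(1, 1)`; the leading axis is
/// missing from the original, so the original flat index is `1`.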
fn compute_broadcast_index(
    flat_idx: usize,
    broadcast_shape: &[usize],
    original_shape: &[usize],
) -> usize {
    let mut result = 0;
    let mut remaining = flat_idx;

    let dims_diff = broadcast_shape.len() - original_shape.len();

    for (i, &broadcast_dim) in broadcast_shape.iter().enumerate() {
        let stride = broadcast_shape[i + 1..].iter().product::<usize>().max(1);
        let coord = remaining / stride;
        remaining %= stride;

        debug_assert!(
            coord < broadcast_dim,
            "Coordinate {} out of bounds for dimension {} of size {}",
            coord,
            i,
            broadcast_dim
        );

        if i >= dims_diff {
            let original_dim = original_shape[i - dims_diff];
            let adjusted_coord = if original_dim == 1 { 0 } else { coord };
            result = result * original_dim + adjusted_coord;
        }
    }

    result
}

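/// Scalar arithmetic: each operation comes in an out-of-place form
/// (`add_scalar`) and an in-place form (`add_scalar_`) that first ensures
/// unique ownership of the storage.
///
/// Illustrative sketch (`ignore`d; the constructor arguments mirror this
/// file's internal `from_data` calls, and `device` is an assumed binding):
///
/// ```ignore
/// let t = Tensor::from_data(vec![1.0f32, 2.0, 3.0], vec![3], device)?;
/// let shifted = t.add_scalar(1.0)?; // [2.0, 3.0, 4.0]
/// ```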
impl<T: TensorElement + Copy> Tensor<T> {
    /// In-place scalar addition.
    pub fn add_scalar_(&mut self, scalar: T) -> Result<()>
    where
        T: Copy + std::ops::Add<Output = T>,
    {
        self.make_unique()?;
        self.apply_(|x| x + scalar)
    }

    /// Scalar addition, returning a new tensor.
    pub fn add_scalar(&self, scalar: T) -> Result<Self>
    where
        T: Copy + std::ops::Add<Output = T>,
    {
        self.map(|x| x + scalar)
    }

    /// In-place scalar subtraction.
    pub fn sub_scalar_(&mut self, scalar: T) -> Result<()>
    where
        T: Copy + std::ops::Sub<Output = T>,
    {
        self.make_unique()?;
        self.apply_(|x| x - scalar)
    }

    /// Scalar subtraction, returning a new tensor.
    pub fn sub_scalar(&self, scalar: T) -> Result<Self>
    where
        T: Copy + std::ops::Sub<Output = T>,
    {
        self.map(|x| x - scalar)
    }

    /// In-place scalar multiplication.
    pub fn mul_scalar_(&mut self, scalar: T) -> Result<()>
    where
        T: Copy + std::ops::Mul<Output = T>,
    {
        self.make_unique()?;
        self.apply_(|x| x * scalar)
    }

    /// Scalar multiplication, returning a new tensor.
    pub fn mul_scalar(&self, scalar: T) -> Result<Self>
    where
        T: Copy + std::ops::Mul<Output = T>,
    {
        self.map(|x| x * scalar)
    }

    /// In-place scalar division.
    pub fn div_scalar_(&mut self, scalar: T) -> Result<()>
    where
        T: Copy + std::ops::Div<Output = T>,
    {
        self.make_unique()?;
        self.apply_(|x| x / scalar)
    }

    /// Scalar division, returning a new tensor.
    pub fn div_scalar(&self, scalar: T) -> Result<Self>
    where
        T: Copy + std::ops::Div<Output = T>,
    {
        self.map(|x| x / scalar)
    }

    /// Elementwise addition with broadcasting and autograd bookkeeping, plus
    /// an f32 SIMD fast path for large same-shape tensors.
    pub fn add(&self, other: &Self) -> Result<Self>
    where
        T: std::ops::Add<Output = T>,
    {
        if self.shape() != other.shape() {
            return self.broadcast_add(other);
        }

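        // Fast path: when T is statically f32 and the tensor is large enough
        // to amortize dispatch overhead, reinterpret both buffers as `&[f32]`
        // and call the hand-written SIMD kernel. The transmutes are sound
        // because the TypeId check guarantees T == f32 (same size/alignment).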
        if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() {
            let self_data = self.data()?;
            if self_data.len() >= 1024 {
                let other_data = other.data()?;
                let a_f32: &[f32] = unsafe {
                    std::slice::from_raw_parts(self_data.as_ptr() as *const f32, self_data.len())
                };
                let b_f32: &[f32] = unsafe {
                    std::slice::from_raw_parts(other_data.as_ptr() as *const f32, other_data.len())
                };
                let mut out = vec![0.0f32; self_data.len()];
                crate::simd_ops_f32::add_into_f32(a_f32, b_f32, &mut out);
                // Move the f32 buffer back into a Vec<T> without copying.
                let result_data: Vec<T> = unsafe {
                    let mut v = std::mem::ManuallyDrop::new(out);
                    Vec::from_raw_parts(v.as_mut_ptr() as *mut T, v.len(), v.capacity())
                };
                let mut result =
                    Self::from_data(result_data, self.shape().dims().to_vec(), self.device)?;
                if self.requires_grad || other.requires_grad {
                    result.requires_grad = true;
                    result.operation = Operation::Add {
                        lhs: Arc::new(self.clone()),
                        rhs: Arc::new(other.clone()),
                    };
                }
                return Ok(result);
            }
        }

        let mut result = self.elementwise_operation(other, |a, b| a + b)?;

        if self.requires_grad || other.requires_grad {
            result.requires_grad = true;
            result.operation = Operation::Add {
                lhs: Arc::new(self.clone()),
                rhs: Arc::new(other.clone()),
            };
        }

        Ok(result)
    }

    /// Addition with NumPy-style broadcasting; used when shapes differ.
    fn broadcast_add(&self, other: &Self) -> Result<Self>
    where
        T: std::ops::Add<Output = T>,
    {
        let self_shape_binding = self.shape();
        let other_shape_binding = other.shape();
        let self_shape = self_shape_binding.dims();
        let other_shape = other_shape_binding.dims();

        if !can_broadcast(self_shape, other_shape) {
            return Err(TorshError::ShapeMismatch {
                expected: self_shape.to_vec(),
                got: other_shape.to_vec(),
            });
        }

        let broadcast_shape = compute_broadcast_shape(self_shape, other_shape)?;

        let self_data = self.data()?;
        let other_data = other.data()?;

        let mut result_data = Vec::with_capacity(broadcast_shape.iter().product());

        for i in 0..broadcast_shape.iter().product::<usize>() {
            let self_idx = compute_broadcast_index(i, &broadcast_shape, self_shape);
            let other_idx = compute_broadcast_index(i, &broadcast_shape, other_shape);

            let self_val = *self_data
                .get(self_idx)
                .ok_or_else(|| TorshError::IndexError {
                    index: self_idx,
                    size: self_data.len(),
                })?;
            let other_val = *other_data
                .get(other_idx)
                .ok_or_else(|| TorshError::IndexError {
                    index: other_idx,
                    size: other_data.len(),
                })?;
            result_data.push(self_val + other_val);
        }

        let mut result = Self::from_data(result_data, broadcast_shape, self.device)?;

        if self.requires_grad || other.requires_grad {
            result.requires_grad = true;
            result.operation = Operation::Add {
                lhs: Arc::new(self.clone()),
                rhs: Arc::new(other.clone()),
            };
        }

        Ok(result)
    }

    /// Elementwise subtraction, with an f32 SIMD fast path for large
    /// same-shape tensors.
    pub fn sub(&self, other: &Self) -> Result<Self>
    where
        T: std::ops::Sub<Output = T>,
    {
        if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>()
            && self.shape() == other.shape()
        {
            let self_data = self.data()?;
            if self_data.len() >= 1024 {
                let other_data = other.data()?;
                let a_f32: &[f32] = unsafe {
                    std::slice::from_raw_parts(self_data.as_ptr() as *const f32, self_data.len())
                };
                let b_f32: &[f32] = unsafe {
                    std::slice::from_raw_parts(other_data.as_ptr() as *const f32, other_data.len())
                };
                let mut out = vec![0.0f32; self_data.len()];
                crate::simd_ops_f32::sub_into_f32(a_f32, b_f32, &mut out);
                let result_data: Vec<T> = unsafe {
                    let mut v = std::mem::ManuallyDrop::new(out);
                    Vec::from_raw_parts(v.as_mut_ptr() as *mut T, v.len(), v.capacity())
                };
                return Self::from_data(result_data, self.shape().dims().to_vec(), self.device);
            }
        }
        self.elementwise_operation(other, |a, b| a - b)
    }

    /// Elementwise multiplication, with an f32 SIMD fast path for large
    /// same-shape tensors.
    pub fn mul(&self, other: &Self) -> Result<Self>
    where
        T: std::ops::Mul<Output = T>,
    {
        if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>()
            && self.shape() == other.shape()
        {
            let self_data = self.data()?;
            if self_data.len() >= 1024 {
                let other_data = other.data()?;
                let a_f32: &[f32] = unsafe {
                    std::slice::from_raw_parts(self_data.as_ptr() as *const f32, self_data.len())
                };
                let b_f32: &[f32] = unsafe {
                    std::slice::from_raw_parts(other_data.as_ptr() as *const f32, other_data.len())
                };
                let mut out = vec![0.0f32; self_data.len()];
                crate::simd_ops_f32::mul_into_f32(a_f32, b_f32, &mut out);
                let result_data: Vec<T> = unsafe {
                    let mut v = std::mem::ManuallyDrop::new(out);
                    Vec::from_raw_parts(v.as_mut_ptr() as *mut T, v.len(), v.capacity())
                };
                return Self::from_data(result_data, self.shape().dims().to_vec(), self.device);
            }
        }
        self.elementwise_operation(other, |a, b| a * b)
    }

    /// Elementwise division, with an f32 SIMD fast path for large
    /// same-shape tensors.
    pub fn div(&self, other: &Self) -> Result<Self>
    where
        T: std::ops::Div<Output = T>,
    {
        if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>()
            && self.shape() == other.shape()
        {
            let self_data = self.data()?;
            if self_data.len() >= 1024 {
                let other_data = other.data()?;
                let a_f32: &[f32] = unsafe {
                    std::slice::from_raw_parts(self_data.as_ptr() as *const f32, self_data.len())
                };
                let b_f32: &[f32] = unsafe {
                    std::slice::from_raw_parts(other_data.as_ptr() as *const f32, other_data.len())
                };
                let mut out = vec![0.0f32; self_data.len()];
                crate::simd_ops_f32::div_into_f32(a_f32, b_f32, &mut out);
                let result_data: Vec<T> = unsafe {
                    let mut v = std::mem::ManuallyDrop::new(out);
                    Vec::from_raw_parts(v.as_mut_ptr() as *mut T, v.len(), v.capacity())
                };
                return Self::from_data(result_data, self.shape().dims().to_vec(), self.device);
            }
        }
        self.elementwise_operation(other, |a, b| a / b)
    }

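    /// Generic broadcasting binary operation: walks every coordinate of the
    /// broadcast shape odometer-style and maps each back into the two source
    /// tensors. Slower than the specialized paths, but works for any closure.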
    fn broadcast_binary_op<F>(&self, other: &Self, op: F) -> Result<Self>
    where
        F: Fn(T, T) -> T + Send + Sync,
    {
        use crate::broadcast::BroadcastOps;

        let self_shape_binding = self.shape();
        let self_shape = self_shape_binding.dims();
        let other_shape_binding = other.shape();
        let other_shape = other_shape_binding.dims();

        let broadcast_shape = BroadcastOps::compute_broadcast_shape(self_shape, other_shape)?;

        let self_data = self.data()?;
        let other_data = other.data()?;

        let total_elements = broadcast_shape.iter().product::<usize>();
        let mut result_data = Vec::with_capacity(total_elements);

        let mut indices = vec![0; broadcast_shape.len()];
        for _ in 0..total_elements {
            let self_idx = self.compute_broadcast_index(&indices, self_shape, &broadcast_shape)?;
            let other_idx =
                other.compute_broadcast_index(&indices, other_shape, &broadcast_shape)?;

            let result = op(self_data[self_idx], other_data[other_idx]);
            result_data.push(result);

            Self::increment_indices(&mut indices, &broadcast_shape);
        }

        Self::from_data(result_data, broadcast_shape, self.device)
    }

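    /// Advance a multi-dimensional index like an odometer: bump the last
    /// coordinate, carrying into earlier ones on overflow. E.g. with shape
    /// `[2, 3]`, `[0, 2]` steps to `[1, 0]`.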
    fn increment_indices(indices: &mut [usize], shape: &[usize]) {
        for i in (0..indices.len()).rev() {
            indices[i] += 1;
            if indices[i] < shape[i] {
                break;
            }
            indices[i] = 0;
        }
    }

    /// Flatten broadcast coordinates into an index of this tensor's buffer,
    /// collapsing size-1 dimensions to coordinate 0.
    fn compute_broadcast_index(
        &self,
        broadcast_indices: &[usize],
        original_shape: &[usize],
        broadcast_shape: &[usize],
    ) -> Result<usize> {
        let ndim_diff = broadcast_shape.len() - original_shape.len();
        let mut flat_index = 0;
        let mut stride = 1;

        for i in (0..original_shape.len()).rev() {
            let broadcast_idx = broadcast_indices[ndim_diff + i];
            let original_size = original_shape[i];

            let actual_idx = if original_size == 1 { 0 } else { broadcast_idx };

            flat_index += actual_idx * stride;
            stride *= original_size;
        }

        Ok(flat_index)
    }

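    /// Same-shape elementwise combine with a tiered dispatch: SIMD-capable
    /// path above 1000 elements (when the `simd` feature is on), parallel
    /// map above 100 elements (when `parallel` is on), plain iterator zip
    /// otherwise. Mismatched shapes fall back to the broadcasting path.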
    fn elementwise_operation<F>(&self, other: &Self, op: F) -> Result<Self>
    where
        F: Fn(T, T) -> T + Send + Sync,
    {
        if self.shape() != other.shape() {
            return self.broadcast_binary_op(other, op);
        }

        let self_data = self.data()?;
        let other_data = other.data()?;

        #[cfg(feature = "simd")]
        {
            if self_data.len() > 1000 {
                let result_data = self.simd_elementwise_operation(&self_data, &other_data, op)?;
                return Self::from_data(result_data, self.shape().dims().to_vec(), self.device);
            }
        }

        #[cfg(feature = "parallel")]
        {
            if self_data.len() > 100 {
                let paired_data: Vec<(T, T)> = self_data
                    .iter()
                    .zip(other_data.iter())
                    .map(|(&a, &b)| (a, b))
                    .collect();
                let result_data = intelligent_parallel_process(
                    paired_data,
                    TensorOpType::ElementWise,
                    self.device.clone(),
                    |(a, b)| op(a, b),
                );
                return Self::from_data(result_data, self.shape().dims().to_vec(), self.device);
            }
        }

        let result_data: Vec<T> = self_data
            .iter()
            .zip(other_data.iter())
            .map(|(&a, &b)| op(a, b))
            .collect();

        Self::from_data(result_data, self.shape().dims().to_vec(), self.device)
    }

    /// Backing implementation for the SIMD dispatch tier. A generic closure
    /// cannot be vectorized directly, so this falls back to a parallel (or
    /// sequential) scalar map; the dedicated f32 kernels are invoked from the
    /// public entry points instead.
    #[cfg(feature = "simd")]
    fn simd_elementwise_operation<F>(&self, data_a: &[T], data_b: &[T], op: F) -> Result<Vec<T>>
    where
        F: Fn(T, T) -> T + Send + Sync,
        T: TensorElement,
    {
        if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() {
            #[cfg(feature = "parallel")]
            {
                use scirs2_core::parallel_ops::*;
                return Ok(data_a
                    .par_iter()
                    .zip(data_b.par_iter())
                    .map(|(&a, &b)| op(a, b))
                    .collect());
            }
            #[cfg(not(feature = "parallel"))]
            {
                return Ok(data_a
                    .iter()
                    .zip(data_b.iter())
                    .map(|(&a, &b)| op(a, b))
                    .collect());
            }
        }

        Ok(data_a
            .iter()
            .zip(data_b.iter())
            .map(|(&a, &b)| op(a, b))
            .collect())
    }
}

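/// In-place elementwise arithmetic (`add_`, `sub_`, `mul_`, `div_`).
///
/// These refuse to run on tensors that require gradients, take an f32 SIMD
/// fast path for 1024+ elements, and otherwise update storage element by
/// element. Illustrative sketch (`ignore`d; `device` is an assumed binding):
///
/// ```ignore
/// let mut a = Tensor::from_data(vec![1.0f32; 4], vec![4], device)?;
/// let b = Tensor::from_data(vec![2.0f32; 4], vec![4], device)?;
/// a.add_(&b)?; // a is now [3.0, 3.0, 3.0, 3.0]
/// ```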
impl<T: TensorElement + Copy> Tensor<T> {
    /// In-place elementwise addition: `self += other`.
    pub fn add_(&mut self, other: &Self) -> Result<&mut Self>
    where
        T: std::ops::Add<Output = T>,
    {
        if self.requires_grad {
            return Err(TorshError::InvalidArgument(
                "In-place operation on tensor that requires grad is not allowed".to_string(),
            ));
        }
        if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>()
            && self.shape() == other.shape()
            && self.numel() >= 1024
        {
            self.make_unique()?;
            let other_data = other.data()?;
            self.storage.with_slice_mut(|out_t: &mut [T]| {
                let out_f32: &mut [f32] = unsafe {
                    std::slice::from_raw_parts_mut(out_t.as_mut_ptr() as *mut f32, out_t.len())
                };
                let rhs_f32: &[f32] = unsafe {
                    std::slice::from_raw_parts(other_data.as_ptr() as *const f32, other_data.len())
                };
                crate::simd_ops_f32::add_assign_f32(out_f32, rhs_f32);
                Ok(())
            })?;
            return Ok(self);
        }
        let len = self.storage.len();
        let other_data = other.data()?;
        for i in 0..len {
            let a = self.storage.get(i)?;
            let b = *other_data.get(i).ok_or_else(|| TorshError::IndexError {
                index: i,
                size: other_data.len(),
            })?;
            self.storage.set(i, a + b)?;
        }
        Ok(self)
    }

    /// In-place elementwise subtraction: `self -= other`.
    pub fn sub_(&mut self, other: &Self) -> Result<&mut Self>
    where
        T: std::ops::Sub<Output = T>,
    {
        if self.requires_grad {
            return Err(TorshError::InvalidArgument(
                "In-place operation on tensor that requires grad is not allowed".to_string(),
            ));
        }
        if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>()
            && self.shape() == other.shape()
            && self.numel() >= 1024
        {
            self.make_unique()?;
            let other_data = other.data()?;
            self.storage.with_slice_mut(|out_t: &mut [T]| {
                let out_f32: &mut [f32] = unsafe {
                    std::slice::from_raw_parts_mut(out_t.as_mut_ptr() as *mut f32, out_t.len())
                };
                let rhs_f32: &[f32] = unsafe {
                    std::slice::from_raw_parts(other_data.as_ptr() as *const f32, other_data.len())
                };
                crate::simd_ops_f32::sub_assign_f32(out_f32, rhs_f32);
                Ok(())
            })?;
            return Ok(self);
        }
        let len = self.storage.len();
        let other_data = other.data()?;
        for i in 0..len {
            let a = self.storage.get(i)?;
            let b = *other_data.get(i).ok_or_else(|| TorshError::IndexError {
                index: i,
                size: other_data.len(),
            })?;
            self.storage.set(i, a - b)?;
        }
        Ok(self)
    }

    /// In-place elementwise multiplication: `self *= other`.
    pub fn mul_(&mut self, other: &Self) -> Result<&mut Self>
    where
        T: std::ops::Mul<Output = T>,
    {
        if self.requires_grad {
            return Err(TorshError::InvalidArgument(
                "In-place operation on tensor that requires grad is not allowed".to_string(),
            ));
        }
        if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>()
            && self.shape() == other.shape()
            && self.numel() >= 1024
        {
            self.make_unique()?;
            let other_data = other.data()?;
            self.storage.with_slice_mut(|out_t: &mut [T]| {
                let out_f32: &mut [f32] = unsafe {
                    std::slice::from_raw_parts_mut(out_t.as_mut_ptr() as *mut f32, out_t.len())
                };
                let rhs_f32: &[f32] = unsafe {
                    std::slice::from_raw_parts(other_data.as_ptr() as *const f32, other_data.len())
                };
                crate::simd_ops_f32::mul_assign_f32(out_f32, rhs_f32);
                Ok(())
            })?;
            return Ok(self);
        }
        let len = self.storage.len();
        let other_data = other.data()?;
        for i in 0..len {
            let a = self.storage.get(i)?;
            let b = *other_data.get(i).ok_or_else(|| TorshError::IndexError {
                index: i,
                size: other_data.len(),
            })?;
            self.storage.set(i, a * b)?;
        }
        Ok(self)
    }

    /// In-place elementwise division: `self /= other`.
    pub fn div_(&mut self, other: &Self) -> Result<&mut Self>
    where
        T: std::ops::Div<Output = T>,
    {
        if self.requires_grad {
            return Err(TorshError::InvalidArgument(
                "In-place operation on tensor that requires grad is not allowed".to_string(),
            ));
        }
        if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>()
            && self.shape() == other.shape()
            && self.numel() >= 1024
        {
            self.make_unique()?;
            let other_data = other.data()?;
            self.storage.with_slice_mut(|out_t: &mut [T]| {
                let out_f32: &mut [f32] = unsafe {
                    std::slice::from_raw_parts_mut(out_t.as_mut_ptr() as *mut f32, out_t.len())
                };
                let rhs_f32: &[f32] = unsafe {
                    std::slice::from_raw_parts(other_data.as_ptr() as *const f32, other_data.len())
                };
                crate::simd_ops_f32::div_assign_f32(out_f32, rhs_f32);
                Ok(())
            })?;
            return Ok(self);
        }
        let len = self.storage.len();
        let other_data = other.data()?;
        for i in 0..len {
            let a = self.storage.get(i)?;
            let b = *other_data.get(i).ok_or_else(|| TorshError::IndexError {
                index: i,
                size: other_data.len(),
            })?;
            self.storage.set(i, a / b)?;
        }
        Ok(self)
    }
}

impl<T: TensorElement + Copy> Tensor<T>
where
    T: scirs2_core::numeric::Float + torsh_core::dtype::FloatElement,
{
    /// Elementwise square root.
    pub fn sqrt(&self) -> Result<Self> {
        self.map(|x| x.sqrt())
    }

    /// Elementwise square.
    pub fn square(&self) -> Result<Self> {
        self.map(|x| x * x)
    }

    /// Elementwise reciprocal square root: `1 / sqrt(x)`.
    pub fn rsqrt(&self) -> Result<Self> {
        self.map(|x| T::from(1.0).expect("numeric conversion should succeed") / x.sqrt())
    }

    /// Elementwise reciprocal: `1 / x`.
    pub fn reciprocal(&self) -> Result<Self> {
        self.map(|x| T::from(1.0).expect("numeric conversion should succeed") / x)
    }

    /// Elementwise exponential.
    pub fn exp(&self) -> Result<Self> {
        self.map(|x| x.exp())
    }

    /// Elementwise natural logarithm.
    pub fn ln(&self) -> Result<Self> {
        self.map(|x| x.ln())
    }

    /// Elementwise base-10 logarithm.
    pub fn log10(&self) -> Result<Self> {
        self.map(|x| x.log10())
    }

    /// Elementwise base-2 logarithm.
    pub fn log2(&self) -> Result<Self> {
        self.map(|x| x.log2())
    }

    /// Alias for [`ln`](Self::ln).
    pub fn log(&self) -> Result<Self> {
        self.map(|x| x.ln())
    }

    /// Elementwise sine.
    pub fn sin(&self) -> Result<Self> {
        self.map(|x| x.sin())
    }

    /// Elementwise cosine.
    pub fn cos(&self) -> Result<Self> {
        self.map(|x| x.cos())
    }

    /// Elementwise tangent.
    pub fn tan(&self) -> Result<Self> {
        self.map(|x| x.tan())
    }

    /// GELU activation with tiered dispatch: GPU above 50_000 elements,
    /// f32 SIMD above 1000, parallel map above 100, scalar map otherwise.
    pub fn gelu(&self) -> Result<Self> {
        #[cfg(feature = "gpu")]
        {
            if self.numel() > 50000 {
                if let Ok(result) = self.gpu_gelu() {
                    return Ok(result);
                }
            }
        }

        #[cfg(feature = "simd")]
        {
            if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() && self.numel() > 1000 {
                return self.simd_gelu_f32();
            }
        }

        #[cfg(feature = "parallel")]
        {
            if self.numel() > 100 {
                return self.parallel_map(|x| self.compute_gelu_scalar(x));
            }
        }

        self.map(|x| self.compute_gelu_scalar(x))
    }

    /// GPU GELU is currently a stub that always reports unavailability, so
    /// `gelu` falls through to the CPU paths.
    #[cfg(feature = "gpu")]
    #[allow(dead_code)]
    fn gpu_gelu(&self) -> Result<Self>
    where
        T: torsh_core::dtype::FloatElement,
    {
        Err(TorshError::InvalidArgument(
            "GPU GELU temporarily unavailable".to_string(),
        ))
    }

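    /// Scalar GELU via the tanh approximation:
    /// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.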
    fn compute_gelu_scalar(&self, x: T) -> T {
        let pi = T::from(std::f64::consts::PI).expect("numeric conversion should succeed");
        let two = T::from(2.0).expect("numeric conversion should succeed");
        let sqrt_2_over_pi = (two / pi).sqrt();
        let point_044715 = T::from(0.044715).expect("numeric conversion should succeed");
        let one = <T as scirs2_core::numeric::One>::one();
        let half = T::from(0.5).expect("numeric conversion should succeed");

        let x_cubed = x * x * x;
        let tanh_input = sqrt_2_over_pi * (x + point_044715 * x_cubed);
        half * x * (one + tanh_input.tanh())
    }

    /// f32 SIMD GELU: reinterpret the buffer as `&[f32]`, run the adaptive
    /// kernel, and move the results back into `Vec<T>`.
    #[cfg(feature = "simd")]
    fn simd_gelu_f32(&self) -> Result<Self> {
        use scirs2_core::ndarray::ArrayView1;

        let data = self.data()?;

        // Sound: callers have already checked TypeId::of::<T>() == f32.
        let data_f32: &[f32] =
            unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, data.len()) };

        let data_view = ArrayView1::from(data_f32);

        let result_array = adaptive_simd::adaptive_simd_gelu_f32(&data_view);

        let result_vec: Vec<T> = result_array
            .to_vec()
            .into_iter()
            .map(|f| unsafe { std::mem::transmute_copy::<f32, T>(&f) })
            .collect();

        Self::from_data(
            result_vec,
            self.shape().dims().to_vec(),
            self.device.clone(),
        )
    }

    /// Leaky ReLU: `x` for positive inputs, `negative_slope * x` otherwise.
    pub fn leaky_relu(&self, negative_slope: T) -> Result<Self> {
        self.map(|x| {
            if x > scirs2_core::numeric::Zero::zero() {
                x
            } else {
                negative_slope * x
            }
        })
    }

    /// Elementwise arcsine.
    pub fn asin(&self) -> Result<Self> {
        self.map(|x| x.asin())
    }

    /// Elementwise arccosine.
    pub fn acos(&self) -> Result<Self> {
        self.map(|x| x.acos())
    }

    /// Elementwise arctangent.
    pub fn atan(&self) -> Result<Self> {
        self.map(|x| x.atan())
    }

    /// Elementwise hyperbolic sine.
    pub fn sinh(&self) -> Result<Self> {
        self.map(|x| x.sinh())
    }

    /// Elementwise hyperbolic cosine.
    pub fn cosh(&self) -> Result<Self> {
        self.map(|x| x.cosh())
    }

    /// Elementwise hyperbolic tangent.
    pub fn tanh(&self) -> Result<Self> {
        self.map(|x| x.tanh())
    }

    /// Elementwise power with autograd bookkeeping.
    pub fn pow(&self, exponent: T) -> Result<Self>
    where
        T: TensorElement + Into<f32>,
    {
        let exponent_f32: f32 = exponent.into();

        let mut result = self.map(|x| x.powf(exponent))?;

        if self.requires_grad {
            result.requires_grad = true;
            result.operation = Operation::Power {
                input: Arc::new(self.clone()),
                exponent: exponent_f32,
            };
        }

        Ok(result)
    }

    /// Alias for [`pow`](Self::pow).
    pub fn pow_scalar(&self, exponent: T) -> Result<Self>
    where
        T: TensorElement + Into<f32>,
    {
        self.pow(exponent)
    }

    /// Elementwise power with a tensor of exponents.
    pub fn pow_tensor(&self, exponent: &Self) -> Result<Self> {
        self.elementwise_operation(exponent, |base, exp| base.powf(exp))
    }

    /// Elementwise floor.
    pub fn floor(&self) -> Result<Self> {
        self.map(|x| x.floor())
    }

    /// Elementwise ceiling.
    pub fn ceil(&self) -> Result<Self> {
        self.map(|x| x.ceil())
    }

    /// Elementwise rounding to the nearest integer.
    pub fn round(&self) -> Result<Self> {
        self.map(|x| x.round())
    }

    /// Elementwise truncation toward zero.
    pub fn trunc(&self) -> Result<Self> {
        self.map(|x| x.trunc())
    }

    /// Elementwise fractional part.
    pub fn fract(&self) -> Result<Self> {
        self.map(|x| x.fract())
    }

    /// Elementwise negation.
    pub fn neg(&self) -> Result<Self>
    where
        T: std::ops::Neg<Output = T>,
    {
        self.map(|x| -x)
    }

    /// Elementwise sign: `1`, `-1`, or `0`.
    pub fn sign(&self) -> Result<Self> {
        self.map(|x| {
            if x > <T as scirs2_core::numeric::Zero>::zero() {
                <T as scirs2_core::numeric::One>::one()
            } else if x < <T as scirs2_core::numeric::Zero>::zero() {
                -<T as scirs2_core::numeric::One>::one()
            } else {
                <T as scirs2_core::numeric::Zero>::zero()
            }
        })
    }
}

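/// Operation aliases and activation functions on generic element types.
///
/// `sigmoid` and `relu` use the same tiered dispatch as `gelu`: f32 SIMD
/// above 1000 elements, parallel map above 100, then a plain scalar path.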
impl<T: TensorElement + Copy> Tensor<T> {
    /// Alias for [`add`](Self::add).
    pub fn add_op(&self, other: &Self) -> Result<Self>
    where
        T: std::ops::Add<Output = T>,
    {
        self.add(other)
    }

    /// Alias for [`mul`](Self::mul).
    pub fn mul_op(&self, other: &Self) -> Result<Self>
    where
        T: std::ops::Mul<Output = T>,
    {
        self.mul(other)
    }

    /// Sigmoid activation: `1 / (1 + exp(-x))`, with tiered dispatch.
    pub fn sigmoid(&self) -> Result<Self>
    where
        T: torsh_core::dtype::FloatElement,
    {
        #[cfg(feature = "simd")]
        {
            if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() && self.numel() > 1000 {
                return self.simd_sigmoid_f32();
            }
        }

        #[cfg(feature = "parallel")]
        {
            if self.numel() > 100 {
                let one = <T as scirs2_core::numeric::One>::one();
                return self.parallel_map(|x| one / (one + (-x).exp()));
            }
        }

        // Fallback composed from tensor ops: 1 / (1 + exp(-x)).
        let one = <T as scirs2_core::numeric::One>::one();
        let neg_self = self.neg()?;
        let exp_neg = neg_self.exp()?;
        let one_plus_exp = exp_neg.add_scalar(one)?;
        let ones = Self::ones(self.shape().dims(), self.device)?;
        ones.div(&one_plus_exp)
    }

    /// f32 SIMD sigmoid: reinterpret the buffer, run the adaptive kernel,
    /// and move the results back into `Vec<T>`.
    #[cfg(feature = "simd")]
    fn simd_sigmoid_f32(&self) -> Result<Self> {
        use scirs2_core::ndarray::ArrayView1;

        let data = self.data()?;

        // Sound: callers have already checked TypeId::of::<T>() == f32.
        let data_f32: &[f32] =
            unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, data.len()) };

        let data_view = ArrayView1::from(data_f32);

        let result_array = adaptive_simd::adaptive_simd_sigmoid_f32(&data_view);

        let result_vec: Vec<T> = result_array
            .to_vec()
            .into_iter()
            .map(|f| unsafe { std::mem::transmute_copy::<f32, T>(&f) })
            .collect();

        Self::from_data(
            result_vec,
            self.shape().dims().to_vec(),
            self.device.clone(),
        )
    }

    /// ReLU activation: `max(x, 0)`, with tiered dispatch.
    pub fn relu(&self) -> Result<Self>
    where
        T: std::cmp::PartialOrd + scirs2_core::numeric::Zero,
    {
        let zero = <T as scirs2_core::numeric::Zero>::zero();

        #[cfg(feature = "simd")]
        {
            if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() && self.numel() > 1000 {
                return self.simd_relu_f32();
            }
        }

        #[cfg(feature = "parallel")]
        {
            if self.numel() > 100 {
                return self.parallel_map(|x| if x > zero { x } else { zero });
            }
        }

        self.map(|x| if x > zero { x } else { zero })
    }

    /// f32 SIMD ReLU; see [`simd_sigmoid_f32`](Self::simd_sigmoid_f32) for
    /// the buffer-reinterpretation pattern.
    #[cfg(feature = "simd")]
    fn simd_relu_f32(&self) -> Result<Self> {
        use scirs2_core::ndarray::ArrayView1;

        let data = self.data()?;

        let data_f32: &[f32] =
            unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, data.len()) };

        let data_view = ArrayView1::from(data_f32);

        let result_array = adaptive_simd::adaptive_simd_relu_f32(&data_view);

        let result_vec: Vec<T> = result_array
            .to_vec()
            .into_iter()
            .map(|f| unsafe { std::mem::transmute_copy::<f32, T>(&f) })
            .collect();

        Self::from_data(
            result_vec,
            self.shape().dims().to_vec(),
            self.device.clone(),
        )
    }

    /// Parallel elementwise map used by the activation fallbacks.
    #[cfg(feature = "parallel")]
    fn parallel_map<F>(&self, op: F) -> Result<Self>
    where
        F: Fn(T) -> T + Send + Sync,
    {
        let data = self.data()?;

        let result_data = intelligent_parallel_process(
            data.iter().copied().collect::<Vec<_>>(),
            TensorOpType::Activation,
            self.device.clone(),
            op,
        );

        Self::from_data(result_data, self.shape().dims().to_vec(), self.device)
    }

    /// Elementwise minimum of two tensors.
    pub fn minimum(&self, other: &Self) -> Result<Self>
    where
        T: std::cmp::PartialOrd,
    {
        self.elementwise_operation(other, |a, b| if a < b { a } else { b })
    }

    /// Elementwise maximum of two tensors.
    pub fn maximum(&self, other: &Self) -> Result<Self>
    where
        T: std::cmp::PartialOrd,
    {
        self.elementwise_operation(other, |a, b| if a > b { a } else { b })
    }

    /// Clamp every element into `[min, max]`, returning a new tensor.
    pub fn clamp(&self, min: T, max: T) -> Result<Self>
    where
        T: std::cmp::PartialOrd + Copy,
    {
        let data = self.to_vec()?;
        let clamped_data: Vec<T> = data
            .iter()
            .map(|&x| {
                if x < min {
                    min
                } else if x > max {
                    max
                } else {
                    x
                }
            })
            .collect();

        Self::from_data(
            clamped_data,
            self.shape().dims().to_vec(),
            self.device.clone(),
        )
    }

    /// Dot product: elementwise multiply followed by a full sum.
    pub fn dot(&self, other: &Self) -> Result<Self>
    where
        T: std::ops::Mul<Output = T> + std::ops::Add<Output = T> + scirs2_core::numeric::Zero,
    {
        let elementwise = self.mul(other)?;
        elementwise.sum()
    }
}

impl<T: TensorElement + Copy + scirs2_core::numeric::FromPrimitive> Tensor<T> {
    /// Alias for [`add`](Self::add), kept for scirs2-flavored call sites.
    pub fn add_scirs2(&self, other: &Self) -> Result<Self>
    where
        T: std::ops::Add<Output = T> + scirs2_core::numeric::Float,
    {
        self.add(other)
    }

    /// Alias for [`mul`](Self::mul).
    pub fn mul_scirs2(&self, other: &Self) -> Result<Self>
    where
        T: std::ops::Mul<Output = T> + scirs2_core::numeric::Float,
    {
        self.mul(other)
    }

    /// Alias for [`sub`](Self::sub).
    pub fn sub_scirs2(&self, other: &Self) -> Result<Self>
    where
        T: std::ops::Sub<Output = T> + scirs2_core::numeric::Float,
    {
        self.sub(other)
    }

    /// Alias for [`div`](Self::div).
    pub fn div_scirs2(&self, other: &Self) -> Result<Self>
    where
        T: std::ops::Div<Output = T> + scirs2_core::numeric::Float,
    {
        self.div(other)
    }
}

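// Operator-overload sugar: `&a + &b` and friends delegate to the fallible
// inherent methods and panic (via `expect`) on shape or storage errors.
// Illustrative sketch (not compiled here; `a` and `b` are assumed tensors):
//
//     let c = &a + &b; // panics instead of returning Err on mismatch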
impl<T: TensorElement + Copy> std::ops::Add for &Tensor<T>
where
    T: std::ops::Add<Output = T>,
{
    type Output = Tensor<T>;

    fn add(self, rhs: Self) -> Self::Output {
        // Resolves to the inherent, fallible `Tensor::add`: inherent methods
        // take precedence over trait methods, so this does not recurse.
        self.add(rhs).expect("tensor addition should succeed")
    }
}

impl<T: TensorElement + Copy> std::ops::Sub for &Tensor<T>
where
    T: std::ops::Sub<Output = T>,
{
    type Output = Tensor<T>;

    fn sub(self, rhs: Self) -> Self::Output {
        self.sub(rhs).expect("tensor subtraction should succeed")
    }
}

impl<T: TensorElement + Copy> std::ops::Mul for &Tensor<T>
where
    T: std::ops::Mul<Output = T>,
{
    type Output = Tensor<T>;

    fn mul(self, rhs: Self) -> Self::Output {
        self.mul(rhs).expect("tensor multiplication should succeed")
    }
}

impl<T: TensorElement + Copy> std::ops::Div for &Tensor<T>
where
    T: std::ops::Div<Output = T>,
{
    type Output = Tensor<T>;

    fn div(self, rhs: Self) -> Self::Output {
        self.div(rhs).expect("tensor division should succeed")
    }
}

impl<T: TensorElement + Copy> std::ops::Neg for &Tensor<T>
where
    T: std::ops::Neg<Output = T>,
{
    type Output = Tensor<T>;

    fn neg(self) -> Self::Output {
        self.map(|x| -x).expect("negation map should succeed")
    }
}

impl<T: TensorElement + Copy + scirs2_core::numeric::Float> Tensor<T> {
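    /// Explicit SIMD addition through scirs2's `SimdUnifiedOps` trait: both
    /// tensors are copied into `Array1` buffers, combined with `T::simd_add`,
    /// and copied back. Requires identical shapes; no broadcasting.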
    #[cfg(feature = "simd")]
    pub fn add_simd(&self, other: &Self) -> Result<Self>
    where
        T: scirs2_core::simd_ops::SimdUnifiedOps,
    {
        use scirs2_core::ndarray::Array1;

        if self.shape().dims() != other.shape().dims() {
            return Err(TorshError::ShapeMismatch {
                expected: self.shape().dims().to_vec(),
                got: other.shape().dims().to_vec(),
            });
        }

        let data_a = self.to_vec()?;
        let data_b = other.to_vec()?;

        let arr_a = Array1::from_vec(data_a);
        let arr_b = Array1::from_vec(data_b);

        let result_arr = T::simd_add(&arr_a.view(), &arr_b.view());

        Tensor::from_vec(result_arr.to_vec(), self.shape().dims())
    }

    /// Explicit SIMD multiplication through `SimdUnifiedOps`; same contract
    /// as [`add_simd`](Self::add_simd).
    #[cfg(feature = "simd")]
    pub fn mul_simd(&self, other: &Self) -> Result<Self>
    where
        T: scirs2_core::simd_ops::SimdUnifiedOps,
    {
        use scirs2_core::ndarray::Array1;

        if self.shape().dims() != other.shape().dims() {
            return Err(TorshError::ShapeMismatch {
                expected: self.shape().dims().to_vec(),
                got: other.shape().dims().to_vec(),
            });
        }

        let data_a = self.to_vec()?;
        let data_b = other.to_vec()?;

        let arr_a = Array1::from_vec(data_a);
        let arr_b = Array1::from_vec(data_b);

        let result_arr = T::simd_mul(&arr_a.view(), &arr_b.view());

        Tensor::from_vec(result_arr.to_vec(), self.shape().dims())
    }

    /// SIMD dot product, returning the scalar sum of elementwise products.
    #[cfg(feature = "simd")]
    pub fn dot_simd(&self, other: &Self) -> Result<T>
    where
        T: scirs2_core::simd_ops::SimdUnifiedOps,
    {
        use scirs2_core::ndarray::Array1;

        if self.shape().dims() != other.shape().dims() {
            return Err(TorshError::ShapeMismatch {
                expected: self.shape().dims().to_vec(),
                got: other.shape().dims().to_vec(),
            });
        }

        let data_a = self.to_vec()?;
        let data_b = other.to_vec()?;

        let arr_a = Array1::from_vec(data_a);
        let arr_b = Array1::from_vec(data_b);

        Ok(T::simd_dot(&arr_a.view(), &arr_b.view()))
    }

    /// Fold all elements with `func`, returning zero for an empty tensor.
    pub fn reduce_memory_efficient<F>(&self, func: F) -> Result<T>
    where
        F: Fn(T, T) -> T + Send + Sync,
    {
        let data = self.to_vec()?;
        Ok(data
            .into_iter()
            .reduce(func)
            .unwrap_or_else(|| <T as scirs2_core::numeric::Zero>::zero()))
    }
}

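/// In-place activations (`relu_`, `sigmoid_`, `tanh_`, `gelu_`,
/// `leaky_relu_`, `clamp_`). All refuse gradient-tracking tensors; the f32
/// specializations hand 1024+ element buffers to the SIMD assign kernels.
///
/// Illustrative sketch (`ignore`d; `device` is an assumed binding):
///
/// ```ignore
/// let mut t = Tensor::from_data(vec![-1.0f32, 2.0], vec![2], device)?;
/// t.relu_()?; // t is now [0.0, 2.0]
/// ```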
impl<T: TensorElement + Copy + std::ops::Mul<Output = T>> Tensor<T> {
    /// In-place ReLU: negative elements are zeroed.
    pub fn relu_(&mut self) -> Result<&mut Self>
    where
        T: std::cmp::PartialOrd + scirs2_core::numeric::Zero,
    {
        if self.requires_grad {
            return Err(TorshError::InvalidArgument(
                "In-place operation on tensor that requires grad is not allowed".to_string(),
            ));
        }

        if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() && self.numel() >= 1024 {
            self.make_unique()?;
            self.storage.with_slice_mut(|out_t: &mut [T]| {
                let out_f32: &mut [f32] = unsafe {
                    std::slice::from_raw_parts_mut(out_t.as_mut_ptr() as *mut f32, out_t.len())
                };
                crate::simd_ops_f32::relu_assign_f32(out_f32);
                Ok(())
            })?;
            return Ok(self);
        }

        let zero = <T as scirs2_core::numeric::Zero>::zero();
        let len = self.storage.len();

        for i in 0..len {
            let current = self.storage.get(i)?;
            if current < zero {
                self.storage.set(i, zero)?;
            }
        }

        Ok(self)
    }

    /// In-place sigmoid: `x = 1 / (1 + exp(-x))`.
    pub fn sigmoid_(&mut self) -> Result<&mut Self>
    where
        T: torsh_core::dtype::FloatElement,
    {
        if self.requires_grad {
            return Err(TorshError::InvalidArgument(
                "In-place operation on tensor that requires grad is not allowed".to_string(),
            ));
        }

        let one = <T as scirs2_core::numeric::One>::one();
        let len = self.storage.len();

        for i in 0..len {
            let x = self.storage.get(i)?;
            let sigmoid_val = one / (one + (-x).exp());
            self.storage.set(i, sigmoid_val)?;
        }

        Ok(self)
    }

    /// In-place hyperbolic tangent.
    pub fn tanh_(&mut self) -> Result<&mut Self>
    where
        T: torsh_core::dtype::FloatElement,
    {
        if self.requires_grad {
            return Err(TorshError::InvalidArgument(
                "In-place operation on tensor that requires grad is not allowed".to_string(),
            ));
        }

        let len = self.storage.len();

        for i in 0..len {
            let x = self.storage.get(i)?;
            self.storage.set(i, x.tanh())?;
        }

        Ok(self)
    }

    /// In-place GELU using the same tanh approximation as `compute_gelu_scalar`.
    pub fn gelu_(&mut self) -> Result<&mut Self>
    where
        T: torsh_core::dtype::FloatElement,
    {
        if self.requires_grad {
            return Err(TorshError::InvalidArgument(
                "In-place operation on tensor that requires grad is not allowed".to_string(),
            ));
        }

        let len = self.storage.len();
        let pi = T::from(std::f64::consts::PI).expect("numeric conversion should succeed");
        let two = T::from(2.0).expect("numeric conversion should succeed");
        let sqrt_2_over_pi = (two / pi).sqrt();
        let point_044715 = T::from(0.044715).expect("numeric conversion should succeed");
        let one = <T as scirs2_core::numeric::One>::one();
        let half = T::from(0.5).expect("numeric conversion should succeed");

        for i in 0..len {
            let x = self.storage.get(i)?;
            let x_cubed = x * x * x;
            let tanh_input = sqrt_2_over_pi * (x + point_044715 * x_cubed);
            let gelu_val = half * x * (one + tanh_input.tanh());
            self.storage.set(i, gelu_val)?;
        }

        Ok(self)
    }

    /// In-place leaky ReLU: negative elements are scaled by `negative_slope`.
    pub fn leaky_relu_(&mut self, negative_slope: T) -> Result<&mut Self>
    where
        T: std::cmp::PartialOrd + scirs2_core::numeric::Zero,
    {
        if self.requires_grad {
            return Err(TorshError::InvalidArgument(
                "In-place operation on tensor that requires grad is not allowed".to_string(),
            ));
        }

        if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() && self.numel() >= 1024 {
            // Sound: the TypeId check guarantees T == f32.
            let slope_f32: f32 = unsafe { std::mem::transmute_copy::<T, f32>(&negative_slope) };
            self.make_unique()?;
            self.storage.with_slice_mut(|out_t: &mut [T]| {
                let out_f32: &mut [f32] = unsafe {
                    std::slice::from_raw_parts_mut(out_t.as_mut_ptr() as *mut f32, out_t.len())
                };
                crate::simd_ops_f32::leaky_relu_assign_f32(out_f32, slope_f32);
                Ok(())
            })?;
            return Ok(self);
        }

        let zero = <T as scirs2_core::numeric::Zero>::zero();
        let len = self.storage.len();

        for i in 0..len {
            let x = self.storage.get(i)?;
            if x < zero {
                self.storage.set(i, negative_slope * x)?;
            }
        }

        Ok(self)
    }

    /// In-place clamp of every element into `[min, max]`.
    pub fn clamp_(&mut self, min: T, max: T) -> Result<&mut Self>
    where
        T: std::cmp::PartialOrd,
    {
        if self.requires_grad {
            return Err(TorshError::InvalidArgument(
                "In-place operation on tensor that requires grad is not allowed".to_string(),
            ));
        }

        if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() && self.numel() >= 1024 {
            let min_f32: f32 = unsafe { std::mem::transmute_copy::<T, f32>(&min) };
            let max_f32: f32 = unsafe { std::mem::transmute_copy::<T, f32>(&max) };
            self.make_unique()?;
            self.storage.with_slice_mut(|out_t: &mut [T]| {
                let out_f32: &mut [f32] = unsafe {
                    std::slice::from_raw_parts_mut(out_t.as_mut_ptr() as *mut f32, out_t.len())
                };
                crate::simd_ops_f32::clamp_assign_f32(out_f32, min_f32, max_f32);
                Ok(())
            })?;
            return Ok(self);
        }

        let len = self.storage.len();

        for i in 0..len {
            let x = self.storage.get(i)?;
            let clamped = if x < min {
                min
            } else if x > max {
                max
            } else {
                x
            };
            self.storage.set(i, clamped)?;
        }

        Ok(self)
    }
}

#[cfg(test)]
#[path = "math_ops_tests.rs"]
mod tests;