1use super::{DType, Tensor};
6use crate::errors::{Result, TrustformersError};
7use scirs2_core::ndarray::{ArrayD, IxDyn};
8use std::collections::HashMap;
9use std::sync::atomic::AtomicU64;
10use std::sync::{Arc, RwLock};
11
/// Monotonic counter reserved for handing out unique tensor ids.
/// Currently unused — `Tensor::tensor_id` derives ids from storage addresses
/// instead; kept for a possible future switch to explicit id allocation.
#[allow(dead_code)] static TENSOR_ID_COUNTER: AtomicU64 = AtomicU64::new(0);
15
lazy_static::lazy_static! {
    /// Process-global map from tensor id (see `Tensor::tensor_id`) to its
    /// gradient tensor, shared by all threads and guarded by the `RwLock`.
    /// NOTE(review): the `Arc` is redundant for a `static` item (which is
    /// already `'static`); all call sites only use `.read()`/`.write()`, so
    /// this could be simplified to a bare `RwLock` in a follow-up.
    static ref GRADIENT_REGISTRY: Arc<RwLock<HashMap<u64, Tensor>>> = Arc::new(RwLock::new(HashMap::new()));
}
20
thread_local! {
    /// Per-thread flag controlling whether gradient tracking is active.
    /// Defaults to off; toggled via `enable_grad`/`disable_grad`. Because it
    /// is thread-local, enabling gradients on one thread does not affect
    /// other threads (unlike the registry above, which is process-global).
    static GRADIENT_MODE: std::cell::Cell<bool> = const { std::cell::Cell::new(false) };
}
25
26pub fn enable_grad() {
28 GRADIENT_MODE.with(|mode| mode.set(true));
29}
30
31pub fn disable_grad() {
33 GRADIENT_MODE.with(|mode| mode.set(false));
34}
35
36pub fn is_grad_enabled() -> bool {
38 GRADIENT_MODE.with(|mode| mode.get())
39}
40
41pub fn clear_gradients() {
43 if let Ok(mut registry) = GRADIENT_REGISTRY.write() {
44 registry.clear();
45 }
46}
47
48impl Tensor {
49 fn tensor_id(&self) -> u64 {
51 use std::collections::hash_map::DefaultHasher;
53 use std::hash::{Hash, Hasher};
54
55 let mut hasher = DefaultHasher::new();
56 self.shape().hash(&mut hasher);
57
58 match self {
59 Tensor::F32(arr) => {
60 arr.as_ptr().hash(&mut hasher);
61 arr.len().hash(&mut hasher);
62 },
63 Tensor::F64(arr) => {
64 arr.as_ptr().hash(&mut hasher);
65 arr.len().hash(&mut hasher);
66 },
67 Tensor::I64(arr) => {
68 arr.as_ptr().hash(&mut hasher);
69 arr.len().hash(&mut hasher);
70 },
71 #[cfg(all(target_os = "macos", feature = "metal"))]
72 Tensor::Metal(data) => {
73 data.buffer_id.hash(&mut hasher);
75 self.len().hash(&mut hasher);
76 },
77 #[cfg(feature = "cuda")]
78 Tensor::CUDA(data) => {
79 data.buffer_id.hash(&mut hasher);
81 self.len().hash(&mut hasher);
82 },
83 _ => {
84 self.len().hash(&mut hasher);
86 },
87 }
88
89 hasher.finish()
90 }
    /// Returns the tensor's dimensions as an owned `Vec<usize>`.
    ///
    /// Allocates a fresh `Vec` on every call. Backend variants
    /// (Torch/Candle/Metal/CUDA) only exist behind their cargo features.
    pub fn shape(&self) -> Vec<usize> {
        match self {
            Tensor::F32(a) => a.shape().to_vec(),
            Tensor::F64(a) => a.shape().to_vec(),
            Tensor::F16(a) => a.shape().to_vec(),
            Tensor::BF16(a) => a.shape().to_vec(),
            Tensor::I64(a) => a.shape().to_vec(),
            Tensor::C32(a) => a.shape().to_vec(),
            Tensor::C64(a) => a.shape().to_vec(),
            Tensor::CF16(a) => a.shape().to_vec(),
            Tensor::CBF16(a) => a.shape().to_vec(),
            Tensor::Sparse(s) => s.shape().to_vec(),
            // Torch reports sizes as i64; cast each dim to usize.
            #[cfg(feature = "torch")]
            Tensor::Torch(t) => t.size().iter().map(|&d| d as usize).collect(),
            #[cfg(feature = "candle")]
            Tensor::Candle(t) => t.shape().dims().to_vec(),
            // GPU variants carry their shape alongside the buffer handle.
            #[cfg(all(target_os = "macos", feature = "metal"))]
            Tensor::Metal(data) => data.shape.clone(),
            #[cfg(feature = "cuda")]
            Tensor::CUDA(data) => data.shape.clone(),
        }
    }
118
    /// Returns the number of elements held by the tensor.
    ///
    /// NOTE(review): for `Sparse` tensors this returns the number of *stored*
    /// (non-zero) entries, not the product of the logical shape — so `len()`
    /// can disagree with `size()` (which multiplies the shape). Confirm that
    /// callers expect nnz here.
    pub fn len(&self) -> usize {
        match self {
            Tensor::F32(a) => a.len(),
            Tensor::F64(a) => a.len(),
            Tensor::F16(a) => a.len(),
            Tensor::BF16(a) => a.len(),
            Tensor::I64(a) => a.len(),
            Tensor::C32(a) => a.len(),
            Tensor::C64(a) => a.len(),
            Tensor::CF16(a) => a.len(),
            Tensor::CBF16(a) => a.len(),
            // Sparse: stored entries only (see NOTE above).
            Tensor::Sparse(s) => s.nnz(),
            #[cfg(feature = "torch")]
            Tensor::Torch(t) => t.numel(),
            #[cfg(feature = "candle")]
            Tensor::Candle(t) => t.elem_count(),
            // GPU variants: derive the count from the stored shape.
            #[cfg(all(target_os = "macos", feature = "metal"))]
            Tensor::Metal(data) => data.shape.iter().product(),
            #[cfg(feature = "cuda")]
            Tensor::CUDA(data) => data.shape.iter().product(),
        }
    }
146
    /// Returns `true` when the tensor holds no elements.
    ///
    /// Delegates to `len()`, so for `Sparse` tensors "empty" means zero
    /// stored (non-zero) entries rather than a zero-sized logical shape.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
155
    /// Number of dimensions (rank) of the tensor.
    pub fn ndim(&self) -> usize {
        self.shape().len()
    }
164
    /// Size of the tensor's element data in bytes.
    ///
    /// NOTE(review): `Sparse`, `Torch` and `Candle` arms assume 4 bytes per
    /// element (f32) regardless of actual dtype — confirm before using this
    /// for precise accounting on those backends. Metal/CUDA arms use the
    /// recorded dtype.
    pub fn size_bytes(&self) -> usize {
        match self {
            Tensor::F32(a) => a.len() * std::mem::size_of::<f32>(),
            Tensor::F64(a) => a.len() * std::mem::size_of::<f64>(),
            Tensor::F16(a) => a.len() * std::mem::size_of::<half::f16>(),
            Tensor::BF16(a) => a.len() * std::mem::size_of::<half::bf16>(),
            Tensor::I64(a) => a.len() * std::mem::size_of::<i64>(),
            Tensor::C32(a) => a.len() * std::mem::size_of::<scirs2_core::Complex32>(),
            Tensor::C64(a) => a.len() * std::mem::size_of::<scirs2_core::Complex64>(),
            Tensor::CF16(a) => a.len() * std::mem::size_of::<scirs2_core::Complex<half::f16>>(),
            Tensor::CBF16(a) => a.len() * std::mem::size_of::<scirs2_core::Complex<half::bf16>>(),
            // Counts stored entries only; f32 assumed (see NOTE above).
            Tensor::Sparse(s) => s.nnz() * std::mem::size_of::<f32>(),
            #[cfg(feature = "torch")]
            Tensor::Torch(t) => t.numel() * std::mem::size_of::<f32>(),
            #[cfg(feature = "candle")]
            Tensor::Candle(t) => t.elem_count() * std::mem::size_of::<f32>(),
            #[cfg(all(target_os = "macos", feature = "metal"))]
            Tensor::Metal(data) => {
                let num_elements: usize = data.shape.iter().product();
                num_elements * data.dtype.size_in_bytes()
            },
            #[cfg(feature = "cuda")]
            Tensor::CUDA(data) => {
                let num_elements: usize = data.shape.iter().product();
                num_elements * data.dtype.size_in_bytes()
            },
        }
    }
198
    /// Moves the tensor to the device named by a string such as `"cpu"` or
    /// `"cuda:0"` (case-insensitive).
    ///
    /// In this build only `"cpu"` (index 0 or no index) actually succeeds —
    /// it returns a clone. Every other recognized device type produces a
    /// descriptive "not available in this build" error; unknown types list
    /// the supported names.
    ///
    /// # Errors
    /// Malformed device strings (more than one `:`, non-numeric index),
    /// non-zero CPU index, or any non-CPU device.
    pub fn to_device(&self, device: &str) -> Result<Tensor> {
        let device_lower = device.to_lowercase();

        // Accept "type" or "type:index"; more than one ':' is rejected.
        let (device_type, device_index) = if device_lower.contains(':') {
            let parts: Vec<&str> = device_lower.split(':').collect();
            if parts.len() != 2 {
                return Err(TrustformersError::tensor_op_error(
                    &format!("Invalid device format '{}'. Expected format: 'device_type' or 'device_type:index'", device),
                    "to_device"
                ));
            }

            // usize parse also rejects negative indices.
            let index = parts[1].parse::<usize>().map_err(|_| {
                TrustformersError::tensor_op_error(
                    &format!(
                        "Invalid device index '{}'. Expected a non-negative integer",
                        parts[1]
                    ),
                    "to_device",
                )
            })?;

            (parts[0], Some(index))
        } else {
            (device_lower.as_str(), None)
        };

        match device_type {
            // CPU: the only supported target; index, if given, must be 0.
            "cpu" => {
                if let Some(index) = device_index {
                    if index > 0 {
                        return Err(TrustformersError::tensor_op_error(
                            &format!("CPU device index {} not supported. CPU only supports index 0 or no index", index),
                            "to_device"
                        ));
                    }
                }
                Ok(self.clone())
            },
            // The remaining arms exist purely to give actionable errors for
            // device types this build was not compiled to support.
            "cuda" => {
                if let Some(index) = device_index {
                    Err(TrustformersError::tensor_op_error(
                        &format!("CUDA device cuda:{} not available. This build doesn't support CUDA. Consider using CPU instead with device='cpu'", index),
                        "to_device"
                    ))
                } else {
                    Err(TrustformersError::tensor_op_error(
                        "CUDA devices not available. This build doesn't support CUDA. Consider using CPU instead with device='cpu'",
                        "to_device"
                    ))
                }
            },
            "mps" => {
                Err(TrustformersError::tensor_op_error(
                    "MPS device not available. This build doesn't support Metal Performance Shaders. Consider using CPU instead with device='cpu'",
                    "to_device"
                ))
            },
            "tpu" => {
                Err(TrustformersError::tensor_op_error(
                    "TPU devices not available. This build doesn't support TPU. Consider using CPU instead with device='cpu'",
                    "to_device"
                ))
            },
            "xpu" | "intel" => {
                Err(TrustformersError::tensor_op_error(
                    "Intel XPU devices not available. This build doesn't support Intel XPU. Consider using CPU instead with device='cpu'",
                    "to_device"
                ))
            },
            "npu" => {
                Err(TrustformersError::tensor_op_error(
                    "NPU devices not available. This build doesn't support NPU. Consider using CPU instead with device='cpu'",
                    "to_device"
                ))
            },
            _ => {
                Err(TrustformersError::tensor_op_error(
                    &format!("Unknown device type '{}'. Supported device types: cpu, cuda, mps, tpu, xpu, npu. For this build, only 'cpu' is supported", device_type),
                    "to_device"
                ))
            },
        }
    }
302
    /// Transfers the tensor between devices using the typed `Device` enum.
    ///
    /// Supported moves (feature-gated): F32/F64 → Metal (F64 is downcast to
    /// f32 on upload), Metal → CPU (F32 only), Metal → Metal (buffer handle
    /// clone, no copy), F32/F64 → CUDA (F64 downcast), CUDA → CPU (F32 only),
    /// CUDA → CUDA (handle clone), and CPU → CPU identity clones for every
    /// CPU-resident variant. ROCm and WebGPU are not implemented.
    ///
    /// # Errors
    /// Unavailable backend (feature not compiled in), unsupported dtype on
    /// download, or an unsupported (variant, device) pair.
    pub fn to_device_enum(&self, device: &crate::device::Device) -> Result<Tensor> {
        match (self, device) {
            // ---- F32 → Metal: upload to a persistent GPU buffer. ----
            #[cfg(all(target_os = "macos", feature = "metal"))]
            (Tensor::F32(arr), crate::device::Device::Metal(_)) => {
                use crate::gpu_ops::metal::get_metal_backend;
                let backend = get_metal_backend()?;
                let data_vec: Vec<f32> = arr.iter().copied().collect();

                // Debug-only diagnostics; compiled out of release builds.
                #[cfg(debug_assertions)]
                {
                    eprintln!(
                        "🔍 to_device_enum(F32→Metal): data_vec.len()={}",
                        data_vec.len()
                    );
                    if !data_vec.is_empty() {
                        eprintln!(
                            "🔍 to_device_enum: first 10 values: {:?}",
                            &data_vec[..10.min(data_vec.len())]
                        );
                        eprintln!(
                            "🔍 to_device_enum: stats - min={:.4}, max={:.4}, mean={:.4}",
                            data_vec.iter().fold(f32::INFINITY, |a, &b| a.min(b)),
                            data_vec.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b)),
                            data_vec.iter().sum::<f32>() / data_vec.len() as f32
                        );
                    }
                }

                let buffer_id = backend.create_persistent_buffer(&data_vec)?;

                // Debug-only round-trip verification of the uploaded buffer.
                #[cfg(debug_assertions)]
                {
                    eprintln!("🔍 to_device_enum: Created buffer_id={:?}", buffer_id);

                    let verify_data = backend.download_buffer_to_vec(&buffer_id)?;
                    eprintln!(
                        "🔍 to_device_enum: Verification download - len={}, first 10: {:?}",
                        verify_data.len(),
                        &verify_data[..10.min(verify_data.len())]
                    );
                }

                Ok(Tensor::Metal(super::MetalTensorData {
                    buffer_id,
                    shape: arr.shape().to_vec(),
                    dtype: DType::F32,
                }))
            },

            // ---- F64 → Metal: lossy f64→f32 downcast before upload. ----
            #[cfg(all(target_os = "macos", feature = "metal"))]
            (Tensor::F64(arr), crate::device::Device::Metal(_)) => {
                use crate::gpu_ops::metal::get_metal_backend;
                let backend = get_metal_backend()?;
                let data_vec: Vec<f32> = arr.iter().map(|&x| x as f32).collect();
                let buffer_id = backend.create_persistent_buffer(&data_vec)?;
                Ok(Tensor::Metal(super::MetalTensorData {
                    buffer_id,
                    shape: arr.shape().to_vec(),
                    dtype: DType::F32,
                }))
            },

            // ---- Metal → CPU: download the buffer into an ndarray. ----
            #[cfg(all(target_os = "macos", feature = "metal"))]
            (Tensor::Metal(metal_data), crate::device::Device::CPU) => {
                use crate::gpu_ops::metal::get_metal_backend;
                let backend = get_metal_backend()?;
                let buffer = backend.get_persistent_buffer(&metal_data.buffer_id)?;

                let size: usize = metal_data.shape.iter().product();

                match metal_data.dtype {
                    DType::F32 => {
                        let ptr = buffer.contents() as *const f32;
                        // SAFETY(review): assumes the Metal buffer holds at
                        // least `size` f32 values (it was created from a slice
                        // of that length on upload) — confirm the backend
                        // never shrinks persistent buffers.
                        let data_vec = unsafe { std::slice::from_raw_parts(ptr, size) }.to_vec();

                        use scirs2_core::ndarray::ArrayD;
                        let arr = ArrayD::from_shape_vec(
                            scirs2_core::ndarray::IxDyn(&metal_data.shape),
                            data_vec,
                        )
                        .map_err(|e| {
                            TrustformersError::tensor_op_error(
                                &format!("Failed to create array from shape: {}", e),
                                "to_device_enum",
                            )
                        })?;
                        Ok(Tensor::F32(arr))
                    },
                    _ => Err(TrustformersError::tensor_op_error(
                        &format!("Unsupported Metal tensor dtype: {:?}", metal_data.dtype),
                        "to_device_enum",
                    )),
                }
            },

            // ---- Metal → Metal: cheap handle clone, no data movement. ----
            #[cfg(all(target_os = "macos", feature = "metal"))]
            (Tensor::Metal(metal_data), crate::device::Device::Metal(_)) => {
                Ok(Tensor::Metal(metal_data.clone()))
            },

            // ---- CPU → CPU: identity clones for all CPU variants. ----
            (Tensor::F32(_), crate::device::Device::CPU) => Ok(self.clone()),
            (Tensor::F64(_), crate::device::Device::CPU) => Ok(self.clone()),
            (Tensor::F16(_), crate::device::Device::CPU) => Ok(self.clone()),
            (Tensor::BF16(_), crate::device::Device::CPU) => Ok(self.clone()),
            (Tensor::I64(_), crate::device::Device::CPU) => Ok(self.clone()),
            (Tensor::C32(_), crate::device::Device::CPU) => Ok(self.clone()),
            (Tensor::C64(_), crate::device::Device::CPU) => Ok(self.clone()),
            (Tensor::CF16(_), crate::device::Device::CPU) => Ok(self.clone()),
            (Tensor::CBF16(_), crate::device::Device::CPU) => Ok(self.clone()),
            (Tensor::Sparse(_), crate::device::Device::CPU) => Ok(self.clone()),

            // Metal requested but not compiled in.
            #[cfg(not(feature = "metal"))]
            (_, crate::device::Device::Metal(_)) => Err(TrustformersError::hardware_error(
                "Metal not available. Compile with --features metal",
                "to_device_enum",
            )),

            // ---- F32 → CUDA: upload (Linux/Windows only). ----
            #[cfg(feature = "cuda")]
            #[allow(unused_variables)]
            (Tensor::F32(arr), crate::device::Device::CUDA(device_id)) => {
                #[cfg(any(target_os = "linux", target_os = "windows"))]
                {
                    use crate::gpu_ops::cuda::get_cuda_backend;
                    let backend = get_cuda_backend(*device_id)?;
                    let data_vec: Vec<f32> = arr.iter().copied().collect();
                    let buffer_id = backend.create_persistent_buffer(&data_vec)?;
                    Ok(Tensor::CUDA(super::CudaTensorData {
                        buffer_id,
                        shape: arr.shape().to_vec(),
                        dtype: DType::F32,
                    }))
                }
                #[cfg(not(any(target_os = "linux", target_os = "windows")))]
                {
                    Err(TrustformersError::hardware_error(
                        "CUDA is only supported on Linux and Windows",
                        "to_device_enum",
                    ))
                }
            },

            // ---- F64 → CUDA: lossy f64→f32 downcast before upload. ----
            #[cfg(feature = "cuda")]
            #[allow(unused_variables)]
            (Tensor::F64(arr), crate::device::Device::CUDA(device_id)) => {
                #[cfg(any(target_os = "linux", target_os = "windows"))]
                {
                    use crate::gpu_ops::cuda::get_cuda_backend;
                    let backend = get_cuda_backend(*device_id)?;
                    let data_vec: Vec<f32> = arr.iter().map(|&x| x as f32).collect();
                    let buffer_id = backend.create_persistent_buffer(&data_vec)?;
                    Ok(Tensor::CUDA(super::CudaTensorData {
                        buffer_id,
                        shape: arr.shape().to_vec(),
                        dtype: DType::F32,
                    }))
                }
                #[cfg(not(any(target_os = "linux", target_os = "windows")))]
                {
                    Err(TrustformersError::hardware_error(
                        "CUDA is only supported on Linux and Windows",
                        "to_device_enum",
                    ))
                }
            },

            // ---- CUDA → CPU: download buffer into an ndarray. ----
            // NOTE(review): always uses backend device 0, not the device the
            // buffer was created on — confirm this is intended.
            #[cfg(feature = "cuda")]
            #[allow(unused_variables)]
            (Tensor::CUDA(cuda_data), crate::device::Device::CPU) => {
                #[cfg(any(target_os = "linux", target_os = "windows"))]
                {
                    use crate::gpu_ops::cuda::get_cuda_backend;
                    let backend = get_cuda_backend(0)?;

                    match cuda_data.dtype {
                        DType::F32 => {
                            let data_vec = backend.download_buffer(&cuda_data.buffer_id)?;

                            use scirs2_core::ndarray::ArrayD;
                            let arr = ArrayD::from_shape_vec(
                                scirs2_core::ndarray::IxDyn(&cuda_data.shape),
                                data_vec,
                            )
                            .map_err(|e| {
                                TrustformersError::tensor_op_error(
                                    &format!("Failed to create array from shape: {}", e),
                                    "to_device_enum",
                                )
                            })?;
                            Ok(Tensor::F32(arr))
                        },
                        _ => Err(TrustformersError::tensor_op_error(
                            &format!("Unsupported CUDA tensor dtype: {:?}", cuda_data.dtype),
                            "to_device_enum",
                        )),
                    }
                }
                #[cfg(not(any(target_os = "linux", target_os = "windows")))]
                {
                    Err(TrustformersError::hardware_error(
                        "CUDA is only supported on Linux and Windows",
                        "to_device_enum",
                    ))
                }
            },

            // ---- CUDA → CUDA: cheap handle clone, no data movement. ----
            #[cfg(feature = "cuda")]
            (Tensor::CUDA(cuda_data), crate::device::Device::CUDA(_)) => {
                Ok(Tensor::CUDA(cuda_data.clone()))
            },

            // CUDA requested but not compiled in.
            #[cfg(not(feature = "cuda"))]
            (_, crate::device::Device::CUDA(_)) => Err(TrustformersError::hardware_error(
                "CUDA not available. Compile with --features cuda",
                "to_device_enum",
            )),

            (_, crate::device::Device::ROCm(_)) => Err(TrustformersError::hardware_error(
                "ROCm transfer not implemented yet",
                "to_device_enum",
            )),

            (_, crate::device::Device::WebGPU) => Err(TrustformersError::hardware_error(
                "WebGPU transfer not implemented yet",
                "to_device_enum",
            )),

            // Catch-all for combinations excluded by cfg features above.
            #[allow(unreachable_patterns)]
            _ => Err(TrustformersError::tensor_op_error(
                &format!(
                    "Unsupported device transfer from {:?} to {:?}",
                    self.dtype(),
                    device
                ),
                "to_device_enum",
            )),
        }
    }
597
598 pub fn grad(&self) -> Result<Tensor> {
622 if !is_grad_enabled() {
623 return Err(TrustformersError::tensor_op_error(
624 "Gradient tracking is not enabled. Use enable_grad() to enable gradient tracking.",
625 "grad",
626 ));
627 }
628
629 let tensor_id = self.tensor_id();
630
631 if let Ok(registry) = GRADIENT_REGISTRY.read() {
632 if let Some(grad_tensor) = registry.get(&tensor_id) {
633 Ok(grad_tensor.clone())
634 } else {
635 Err(TrustformersError::tensor_op_error(
636 "No gradient found for this tensor. Gradients are set during backward pass.",
637 "grad",
638 ))
639 }
640 } else {
641 Err(TrustformersError::tensor_op_error(
642 "Failed to access gradient registry.",
643 "grad",
644 ))
645 }
646 }
647
648 pub fn set_grad(&mut self, grad: Tensor) -> Result<()> {
673 if !is_grad_enabled() {
674 return Err(TrustformersError::tensor_op_error(
675 "Gradient tracking is not enabled. Use enable_grad() to enable gradient tracking.",
676 "set_grad",
677 ));
678 }
679
680 if self.shape() != grad.shape() {
682 return Err(TrustformersError::tensor_op_error(
683 &format!(
684 "Gradient shape {:?} doesn't match tensor shape {:?}",
685 grad.shape(),
686 self.shape()
687 ),
688 "set_grad",
689 ));
690 }
691
692 let tensor_id = self.tensor_id();
693
694 if let Ok(mut registry) = GRADIENT_REGISTRY.write() {
695 registry.insert(tensor_id, grad);
696 Ok(())
697 } else {
698 Err(TrustformersError::tensor_op_error(
699 "Failed to access gradient registry.",
700 "set_grad",
701 ))
702 }
703 }
704
    /// Materializes the tensor's contents as a flat `Vec<f32>` in logical
    /// iteration order. GPU-resident tensors are downloaded to the CPU first.
    ///
    /// F64 and I64 values are narrowed with `as` (lossy: precision loss for
    /// large f64, rounding for large i64). Other variants are unsupported.
    pub fn data(&self) -> Result<Vec<f32>> {
        match self {
            Tensor::F32(a) => Ok(a.iter().cloned().collect()),
            // Lossy narrowing casts (see doc comment).
            Tensor::F64(a) => Ok(a.iter().map(|&x| x as f32).collect()),
            Tensor::I64(a) => Ok(a.iter().map(|&x| x as f32).collect()),
            // GPU variants: round-trip through a CPU copy.
            #[cfg(all(target_os = "macos", feature = "metal"))]
            Tensor::Metal(_) => {
                let cpu_tensor = self.to_device_enum(&crate::device::Device::CPU)?;
                cpu_tensor.data()
            },
            #[cfg(feature = "cuda")]
            Tensor::CUDA(_) => {
                let cpu_tensor = self.to_device_enum(&crate::device::Device::CPU)?;
                cpu_tensor.data()
            },
            _ => Err(TrustformersError::tensor_op_error(
                "Unsupported tensor type for data conversion",
                "data_conversion",
            )),
        }
    }
733
    /// Alias for `data()`: materializes the tensor as a flat `Vec<f32>`.
    pub fn data_f32(&self) -> Result<Vec<f32>> {
        self.data()
    }
742
743 pub fn set_data_f32(&mut self, data: &[f32]) -> Result<()> {
753 match self {
754 Tensor::F32(a) => {
755 let shape = a.shape().to_vec();
756 let expected_len: usize = shape.iter().product();
757 if data.len() != expected_len {
758 return Err(TrustformersError::tensor_op_error(
759 &format!(
760 "Data length {} does not match tensor size {}",
761 data.len(),
762 expected_len
763 ),
764 "set_data_f32",
765 ));
766 }
767 *a = ArrayD::from_shape_vec(IxDyn(&shape), data.to_vec()).map_err(|e| {
768 TrustformersError::tensor_op_error(&e.to_string(), "set_data_f32")
769 })?;
770 Ok(())
771 },
772 _ => Err(TrustformersError::tensor_op_error(
773 "set_data_f32 only supported for F32 tensors",
774 "set_data_f32",
775 )),
776 }
777 }
778
779 pub fn data_mut(&mut self) -> Result<&mut [f32]> {
785 match self {
786 Tensor::F32(a) => a.as_slice_mut().ok_or_else(|| {
787 TrustformersError::tensor_op_error(
788 "Tensor data must be contiguous for mutable access",
789 "data_mut",
790 )
791 }),
792 _ => Err(TrustformersError::tensor_op_error(
793 "Mutable data access only supported for F32 tensors",
794 "data_mut",
795 )),
796 }
797 }
798
799 pub fn modify_data<F>(&mut self, f: F) -> Result<()>
809 where
810 F: FnOnce(&mut [f32]),
811 {
812 match self {
813 Tensor::F32(a) => {
814 if let Some(slice) = a.as_slice_mut() {
815 f(slice);
816 Ok(())
817 } else {
818 Err(TrustformersError::tensor_op_error(
819 "Cannot get mutable slice",
820 "modify_data",
821 ))
822 }
823 },
824 _ => Err(TrustformersError::tensor_op_error(
825 "Modify data only supported for F32 tensors",
826 "modify_data",
827 )),
828 }
829 }
830
    /// Human-readable name of the device this tensor currently lives on
    /// ("cpu", "metal", "cuda", or the backend's own debug string for
    /// Torch/Candle tensors).
    pub fn device(&self) -> String {
        match self {
            // All ndarray-backed variants live in host memory.
            Tensor::F32(_)
            | Tensor::F64(_)
            | Tensor::F16(_)
            | Tensor::BF16(_)
            | Tensor::I64(_)
            | Tensor::C32(_)
            | Tensor::C64(_)
            | Tensor::CF16(_)
            | Tensor::CBF16(_) => "cpu".to_string(),
            Tensor::Sparse(_) => "cpu".to_string(),
            #[cfg(feature = "torch")]
            Tensor::Torch(t) => format!("{:?}", t.device()),
            #[cfg(feature = "candle")]
            Tensor::Candle(t) => format!("{:?}", t.device()),
            #[cfg(all(target_os = "macos", feature = "metal"))]
            Tensor::Metal(_) => "metal".to_string(),
            #[cfg(feature = "cuda")]
            Tensor::CUDA(_) => "cuda".to_string(),
        }
    }
858
    /// Total number of logical elements (product of the shape's dimensions).
    ///
    /// NOTE(review): for `Sparse` tensors this differs from `len()`, which
    /// counts stored non-zero entries only.
    pub fn size(&self) -> usize {
        self.shape().iter().product()
    }
867
    /// Approximate memory consumed by the tensor's data, in bytes.
    ///
    /// NOTE(review): Torch/Candle/Metal/CUDA arms hard-code 4 bytes per
    /// element, while `size_bytes` consults the recorded dtype for
    /// Metal/CUDA — the two methods disagree for non-F32 GPU tensors;
    /// confirm which is authoritative.
    pub fn memory_usage(&self) -> usize {
        match self {
            Tensor::F32(a) => a.len() * std::mem::size_of::<f32>(),
            Tensor::F64(a) => a.len() * std::mem::size_of::<f64>(),
            Tensor::F16(a) => a.len() * std::mem::size_of::<half::f16>(),
            Tensor::BF16(a) => a.len() * std::mem::size_of::<half::bf16>(),
            Tensor::I64(a) => a.len() * std::mem::size_of::<i64>(),
            Tensor::C32(a) => a.len() * std::mem::size_of::<scirs2_core::Complex32>(),
            Tensor::C64(a) => a.len() * std::mem::size_of::<scirs2_core::Complex64>(),
            Tensor::CF16(a) => a.len() * std::mem::size_of::<scirs2_core::Complex<half::f16>>(),
            Tensor::CBF16(a) => a.len() * std::mem::size_of::<scirs2_core::Complex<half::bf16>>(),
            // Sparse delegates to its own accounting (indices + values).
            Tensor::Sparse(s) => s.memory_usage(),
            // The remaining arms assume 4 bytes/element (see NOTE above).
            #[cfg(feature = "torch")]
            Tensor::Torch(t) => t.numel() * 4,
            #[cfg(feature = "candle")]
            Tensor::Candle(t) => t.elem_count() * 4,
            #[cfg(all(target_os = "macos", feature = "metal"))]
            Tensor::Metal(m) => m.shape.iter().product::<usize>() * 4,
            #[cfg(feature = "cuda")]
            Tensor::CUDA(c) => c.shape.iter().product::<usize>() * 4,
        }
    }
895
    /// Element type of the tensor.
    ///
    /// NOTE(review): `Sparse`, `Torch` and `Candle` variants always report
    /// `F32` here regardless of their actual element type — confirm before
    /// relying on this for those backends. Metal/CUDA report their recorded
    /// dtype.
    pub fn dtype(&self) -> DType {
        match self {
            Tensor::F32(_) => DType::F32,
            Tensor::F64(_) => DType::F64,
            Tensor::F16(_) => DType::F16,
            Tensor::BF16(_) => DType::BF16,
            Tensor::I64(_) => DType::I64,
            Tensor::C32(_) => DType::C32,
            Tensor::C64(_) => DType::C64,
            Tensor::CF16(_) => DType::CF16,
            Tensor::CBF16(_) => DType::CBF16,
            // Reported as F32 regardless of actual storage (see NOTE).
            Tensor::Sparse(_) => DType::F32,
            #[cfg(feature = "torch")]
            Tensor::Torch(_) => DType::F32,
            #[cfg(feature = "candle")]
            Tensor::Candle(_) => DType::F32,
            #[cfg(all(target_os = "macos", feature = "metal"))]
            Tensor::Metal(data) => data.dtype,
            #[cfg(feature = "cuda")]
            Tensor::CUDA(data) => data.dtype,
        }
    }
923
    /// Alias for `dtype()`, kept for API compatibility.
    pub fn get_dtype(&self) -> DType {
        self.dtype()
    }
928
    /// Returns the element at flat position `index` (logical iteration
    /// order), converted to f32. GPU tensors are copied to the CPU first.
    ///
    /// NOTE: `iter().nth(index)` walks the iterator, so this is O(index) per
    /// call — fine for occasional lookups, avoid in tight loops.
    ///
    /// # Errors
    /// Out-of-bounds index, or an unsupported tensor variant.
    pub fn get_float(&self, index: usize) -> Result<f32> {
        match self {
            Tensor::F32(a) => {
                if index >= a.len() {
                    return Err(TrustformersError::tensor_op_error(
                        &format!(
                            "Index {} out of bounds for tensor of size {}",
                            index,
                            a.len()
                        ),
                        "get_float",
                    ));
                }
                // Bounds were checked above; the unwrap_or is defensive only.
                Ok(a.iter().nth(index).copied().unwrap_or(0.0))
            },
            Tensor::F64(a) => {
                if index >= a.len() {
                    return Err(TrustformersError::tensor_op_error(
                        &format!(
                            "Index {} out of bounds for tensor of size {}",
                            index,
                            a.len()
                        ),
                        "get_float",
                    ));
                }
                // Lossy f64 -> f32 narrowing.
                Ok(a.iter().nth(index).copied().unwrap_or(0.0) as f32)
            },
            // GPU variants: download to CPU, then retry there.
            #[cfg(all(target_os = "macos", feature = "metal"))]
            Tensor::Metal(_) => {
                let cpu_tensor = self.to_device_enum(&crate::device::Device::CPU)?;
                cpu_tensor.get_float(index)
            },
            #[cfg(feature = "cuda")]
            Tensor::CUDA(_) => {
                let cpu_tensor = self.to_device_enum(&crate::device::Device::CPU)?;
                cpu_tensor.get_float(index)
            },
            _ => Err(TrustformersError::tensor_op_error(
                "Get float not supported for this tensor type",
                "get_float",
            )),
        }
    }
984
    /// Extracts the single element of a one-element tensor, converting it to
    /// `T` via `num_traits::NumCast`. GPU tensors are copied to the CPU first.
    ///
    /// # Errors
    /// Fails when the tensor does not hold exactly one element, when the
    /// numeric conversion to `T` is not possible, or for unsupported
    /// variants.
    pub fn item<T>(&self) -> Result<T>
    where
        T: num_traits::NumCast,
    {
        if self.len() != 1 {
            return Err(TrustformersError::tensor_op_error(
                &format!(
                    "item() requires a single-element tensor, but got {} elements",
                    self.len()
                ),
                "item",
            ));
        }

        match self {
            Tensor::F32(a) => {
                // len == 1 was verified above, so next() always yields a
                // value; unwrap_or is defensive only.
                let val = a.iter().next().copied().unwrap_or(0.0);
                T::from(val).ok_or_else(|| {
                    TrustformersError::tensor_op_error(
                        "Failed to convert f32 to target type",
                        "item",
                    )
                })
            },
            Tensor::F64(a) => {
                let val = a.iter().next().copied().unwrap_or(0.0);
                T::from(val).ok_or_else(|| {
                    TrustformersError::tensor_op_error(
                        "Failed to convert f64 to target type",
                        "item",
                    )
                })
            },
            Tensor::I64(a) => {
                let val = a.iter().next().copied().unwrap_or(0);
                T::from(val).ok_or_else(|| {
                    TrustformersError::tensor_op_error(
                        "Failed to convert i64 to target type",
                        "item",
                    )
                })
            },
            // GPU variants: download to CPU, then retry there.
            #[cfg(all(target_os = "macos", feature = "metal"))]
            Tensor::Metal(_) => {
                let cpu_tensor = self.to_device_enum(&crate::device::Device::CPU)?;
                cpu_tensor.item::<T>()
            },
            #[cfg(feature = "cuda")]
            Tensor::CUDA(_) => {
                let cpu_tensor = self.to_device_enum(&crate::device::Device::CPU)?;
                cpu_tensor.item::<T>()
            },
            _ => Err(TrustformersError::tensor_op_error(
                "item() not supported for this tensor type",
                "item",
            )),
        }
    }
1054
    /// Extracts the single element of a one-element tensor as an `i64`.
    /// Errors when the tensor has more (or fewer) than one element.
    pub fn get_scalar_i64(&self) -> Result<i64> {
        self.item::<i64>()
    }
1063
1064 pub fn eq_scalar(&self, scalar: f64) -> Result<Tensor> {
1075 match self {
1076 Tensor::F32(a) => {
1077 let scalar_f32 = scalar as f32;
1078 let result =
1079 a.mapv(|x| if (x - scalar_f32).abs() < 1e-6 { 1.0f32 } else { 0.0f32 });
1080 Ok(Tensor::F32(result))
1081 },
1082 Tensor::F64(a) => {
1083 let result = a.mapv(|x| if (x - scalar).abs() < 1e-9 { 1.0f64 } else { 0.0f64 });
1084 Ok(Tensor::F64(result))
1085 },
1086 Tensor::I64(a) => {
1087 let scalar_i64 = scalar as i64;
1088 let result = a.mapv(|x| if x == scalar_i64 { 1i64 } else { 0i64 });
1089 Ok(Tensor::I64(result))
1090 },
1091 _ => Err(TrustformersError::tensor_op_error(
1092 "eq_scalar not supported for this tensor type",
1093 "eq_scalar",
1094 )),
1095 }
1096 }
1097
1098 pub fn batch_split(&self, batch_size: usize) -> Result<Vec<Tensor>> {
1121 if batch_size == 0 {
1122 return Err(TrustformersError::tensor_op_error(
1123 "Batch size must be greater than 0",
1124 "batch_split",
1125 ));
1126 }
1127
1128 let shape = self.shape();
1129 if shape.is_empty() {
1130 return Err(TrustformersError::tensor_op_error(
1131 "Cannot batch split a scalar tensor",
1132 "batch_split",
1133 ));
1134 }
1135
1136 let total_size = shape[0];
1137 let mut batches = Vec::new();
1138
1139 for start in (0..total_size).step_by(batch_size) {
1140 let end = std::cmp::min(start + batch_size, total_size);
1141 let batch = self.slice(0, start, end)?;
1142 batches.push(batch);
1143 }
1144
1145 Ok(batches)
1146 }
1147
1148 pub fn batch_stack(tensors: &[&Tensor]) -> Result<Tensor> {
1171 if tensors.is_empty() {
1172 return Err(TrustformersError::tensor_op_error(
1173 "Cannot stack empty tensor list",
1174 "batch_stack",
1175 ));
1176 }
1177
1178 let reference_shape = tensors[0].shape();
1180 for (i, tensor) in tensors.iter().enumerate() {
1181 if tensor.shape() != reference_shape {
1182 return Err(TrustformersError::tensor_op_error(
1183 &format!(
1184 "Tensor {} has shape {:?}, expected {:?}",
1185 i,
1186 tensor.shape(),
1187 reference_shape
1188 ),
1189 "batch_stack",
1190 ));
1191 }
1192 }
1193
1194 let mut new_shape = vec![tensors.len()];
1196 new_shape.extend_from_slice(&reference_shape);
1197
1198 match tensors[0] {
1199 Tensor::F32(_) => {
1200 let mut result_data = Vec::new();
1201 for tensor in tensors {
1202 if let Tensor::F32(arr) = tensor {
1203 result_data.extend(arr.iter().copied());
1204 }
1205 }
1206 Tensor::from_vec(result_data, &new_shape)
1207 },
1208 _ => Err(TrustformersError::tensor_op_error(
1209 "Batch stacking currently only implemented for F32 tensors",
1210 "batch_stack",
1211 )),
1212 }
1213 }
1214
1215 pub fn unbatch(&self) -> Result<Vec<Tensor>> {
1232 let shape = self.shape();
1233 if shape.is_empty() {
1234 return Err(TrustformersError::tensor_op_error(
1235 "Cannot unbatch a scalar tensor",
1236 "unbatch",
1237 ));
1238 }
1239
1240 let batch_size = shape[0];
1241 let mut items = Vec::with_capacity(batch_size);
1242
1243 for i in 0..batch_size {
1244 let item = self.slice(0, i, i + 1)?;
1245 let squeezed = item.squeeze(0)?;
1247 items.push(squeezed);
1248 }
1249
1250 Ok(items)
1251 }
1252}
1253
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes the tests in this module.
    ///
    /// FIX: GRADIENT_REGISTRY is process-global while Rust runs tests on
    /// parallel threads, so `clear_gradients()` in one test could wipe a
    /// gradient another test had just stored (the GRADIENT_MODE flag itself
    /// is thread-local and was never the problem). Taking this lock in every
    /// test removes the flakiness.
    static TEST_GUARD: Mutex<()> = Mutex::new(());

    /// Acquires the serialization guard, recovering from poisoning so one
    /// failed test does not cascade into spurious failures elsewhere.
    fn serial() -> MutexGuard<'static, ()> {
        TEST_GUARD.lock().unwrap_or_else(|e| e.into_inner())
    }

    #[test]
    fn test_gradient_tracking_basic() {
        let _guard = serial();
        enable_grad();

        let mut x = Tensor::ones(&[2, 3]).expect("Failed to create ones tensor");
        let grad = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3])
            .expect("Tensor from_vec failed");

        assert!(x.set_grad(grad.clone()).is_ok());

        let retrieved_grad = x.grad().expect("operation failed in test");
        assert_eq!(retrieved_grad.shape(), vec![2, 3]);

        disable_grad();
    }

    #[test]
    fn test_gradient_tracking_disabled() {
        let _guard = serial();
        disable_grad();

        let mut x = Tensor::ones(&[2, 3]).expect("Failed to create ones tensor");
        let grad = Tensor::ones(&[2, 3]).expect("Failed to create ones tensor");

        // With tracking off, both setting and reading gradients must fail.
        assert!(x.set_grad(grad).is_err());
        assert!(x.grad().is_err());
    }

    #[test]
    fn test_gradient_shape_validation() {
        let _guard = serial();
        enable_grad();

        let mut x = Tensor::ones(&[2, 3]).expect("Failed to create ones tensor");
        let wrong_shape_grad = Tensor::ones(&[3, 2]).expect("Failed to create ones tensor");

        // A gradient whose shape differs from the tensor's is rejected.
        assert!(x.set_grad(wrong_shape_grad).is_err());

        disable_grad();
    }

    #[test]
    fn test_clear_gradients() {
        let _guard = serial();
        enable_grad();

        let mut x = Tensor::ones(&[2, 3]).expect("Failed to create ones tensor");
        let grad = Tensor::ones(&[2, 3]).expect("Failed to create ones tensor");

        x.set_grad(grad).expect("operation failed in test");

        assert!(x.grad().is_ok());

        clear_gradients();

        // The stored gradient must be gone after the global clear.
        assert!(x.grad().is_err());

        disable_grad();
    }

    #[test]
    fn test_gradient_mode_functions() {
        let _guard = serial();
        disable_grad();
        assert!(!is_grad_enabled());

        enable_grad();
        assert!(is_grad_enabled());

        disable_grad();
        assert!(!is_grad_enabled());
    }
}