// trustformers_core/tensor/mod.rs
//! Core tensor abstraction for TrustformeRS.
//!
//! This module provides the fundamental `Tensor` type that serves as the backbone
//! for all numerical computations in TrustformeRS. It offers a unified interface
//! over different backend implementations (ndarray, PyTorch, Candle) while
//! maintaining high performance through SIMD optimizations.
//!
//! # Overview
//!
//! The `Tensor` enum provides:
//! - Multi-backend support (CPU via ndarray, GPU via PyTorch/Candle)
//! - Common tensor operations (matmul, add, softmax, etc.)
//! - Broadcasting and shape manipulation
//! - Gradient-related operations for training
//! - Serialization support for model persistence
//!
//! # Example
//!
//! ```no_run
//! use trustformers_core::tensor::Tensor;
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! // Create tensors
//! let a = Tensor::randn(&[2, 3])?;
//! let b = Tensor::randn(&[3, 4])?;
//!
//! // Perform operations
//! let c = a.matmul(&b)?;  // Matrix multiplication
//! let d = c.relu()?;       // ReLU activation
//! let e = d.softmax(-1)?;  // Softmax along last dimension
//! # Ok(())
//! # }
//! ```
//!
//! # Performance Notes
//!
//! - SIMD operations are used where available for better performance
//! - Tensor operations are optimized for common transformer patterns
//! - GPU operations are available when compiled with appropriate features

41mod activations;
42mod complex;
43pub mod constructors;
44mod conversions;
45mod expression;
46mod math_ops;
47mod sparse;
48pub mod transformations;
49mod utils;
50
51#[cfg(test)]
52mod property_tests;
53
54use crate::errors::Result;
55use scirs2_core::ndarray::{ArrayBase, ArrayD, Dim, IxDynImpl, OwnedRepr};
56use scirs2_core::Complex;
57use scirs2_core::{Complex32, Complex64};
58use serde::{Deserialize, Serialize};
59
60/// Data types supported by tensors
61#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
62pub enum DType {
63    /// 32-bit floating point
64    F32,
65    /// 16-bit floating point
66    F16,
67    /// Brain floating point 16
68    BF16,
69    /// 64-bit floating point
70    F64,
71    /// 32-bit complex number (two 32-bit floats)
72    C32,
73    /// 64-bit complex number (two 64-bit floats)
74    C64,
75    /// 16-bit complex number (two 16-bit floats)
76    CF16,
77    /// Brain floating point 16 complex number (two BF16 floats)
78    CBF16,
79    /// 8-bit unsigned integer
80    U8,
81    /// 16-bit unsigned integer
82    U16,
83    /// 32-bit unsigned integer
84    U32,
85    /// 64-bit unsigned integer
86    U64,
87    /// 8-bit signed integer
88    I8,
89    /// 16-bit signed integer
90    I16,
91    /// 32-bit signed integer
92    I32,
93    /// 64-bit signed integer
94    I64,
95    /// Boolean
96    Bool,
97}
98
99impl DType {
100    /// Returns the size in bytes of an element of this data type
101    pub fn size_in_bytes(&self) -> usize {
102        match self {
103            DType::F32 => 4,
104            DType::F16 => 2,
105            DType::BF16 => 2,
106            DType::F64 => 8,
107            DType::C32 => 8,   // Two 32-bit floats
108            DType::C64 => 16,  // Two 64-bit floats
109            DType::CF16 => 4,  // Two 16-bit floats
110            DType::CBF16 => 4, // Two BF16 floats
111            DType::U8 => 1,
112            DType::U16 => 2,
113            DType::U32 => 4,
114            DType::U64 => 8,
115            DType::I8 => 1,
116            DType::I16 => 2,
117            DType::I32 => 4,
118            DType::I64 => 8,
119            DType::Bool => 1,
120        }
121    }
122}
123
/// Metal GPU buffer wrapper for GPU-resident tensors.
///
/// Stores a handle to a buffer that lives in GPU memory, together with the
/// logical shape and element dtype needed to interpret it on the host side.
///
/// (The long documentation block that previously sat here described the
/// `Tensor` enum; since doc comments attach to the *next* item, rustdoc was
/// attributing all of it to this struct. It has been removed from here.)
#[cfg(all(target_os = "macos", feature = "metal"))]
#[derive(Debug)]
pub struct MetalTensorData {
    /// Handle identifying the GPU-side buffer (managed by the Metal backend).
    pub buffer_id: crate::gpu_ops::metal::BufferId,
    /// Logical tensor shape (dimension sizes).
    pub shape: Vec<usize>,
    /// Element data type of the buffer contents.
    pub dtype: DType,
}

#[cfg(all(target_os = "macos", feature = "metal"))]
impl Clone for MetalTensorData {
    /// Shallow clone: the GPU buffer itself is NOT copied. The clone refers
    /// to the same reference-counted buffer; only the host-side metadata
    /// (shape vector, dtype) is duplicated.
    fn clone(&self) -> Self {
        MetalTensorData {
            buffer_id: self.buffer_id,
            shape: self.shape.to_vec(),
            dtype: self.dtype,
        }
    }
}

/// CUDA GPU buffer wrapper for GPU-resident tensors.
///
/// Mirrors `MetalTensorData`: a handle to GPU-resident memory plus the
/// shape/dtype metadata needed to interpret the buffer on the host side.
#[cfg(feature = "cuda")]
#[derive(Debug)]
pub struct CudaTensorData {
    /// Handle identifying the GPU-side buffer (managed by the CUDA backend).
    pub buffer_id: crate::gpu_ops::cuda::BufferId,
    /// Logical tensor shape (dimension sizes).
    pub shape: Vec<usize>,
    /// Element data type of the buffer contents.
    pub dtype: DType,
}

#[cfg(feature = "cuda")]
impl Clone for CudaTensorData {
    /// Shallow clone: the GPU buffer itself is NOT copied. The clone refers
    /// to the same reference-counted buffer; only the host-side metadata
    /// (shape vector, dtype) is duplicated.
    fn clone(&self) -> Self {
        CudaTensorData {
            buffer_id: self.buffer_id,
            shape: self.shape.to_vec(),
            dtype: self.dtype,
        }
    }
}

200pub enum Tensor {
201    // Standard ndarray types
202    F32(ArrayD<f32>),
203    F64(ArrayD<f64>),
204    F16(ArrayD<half::f16>),
205    BF16(ArrayD<half::bf16>),
206    I64(ArrayD<i64>),
207    // Complex number types
208    C32(ArrayD<Complex32>),
209    C64(ArrayD<Complex64>),
210    CF16(ArrayD<Complex<half::f16>>),
211    CBF16(ArrayD<Complex<half::bf16>>),
212    // Sparse tensor variant
213    Sparse(crate::sparse_tensor::SparseTensor),
214    // GPU support available via hardware acceleration module (CUDA, ROCm, Intel OneAPI, Vulkan, Metal)
215    // and backend-specific implementations (Torch, Candle)
216    #[cfg(feature = "torch")]
217    Torch(tch::Tensor),
218    #[cfg(feature = "candle")]
219    Candle(candle_core::Tensor),
220    // Metal GPU-resident tensor (data lives on GPU)
221    #[cfg(all(target_os = "macos", feature = "metal"))]
222    Metal(MetalTensorData),
223    // CUDA GPU-resident tensor (data lives on GPU)
224    #[cfg(feature = "cuda")]
225    CUDA(CudaTensorData),
226}
227
228// Manual Clone implementation because tch::Tensor doesn't implement Clone
229impl Clone for Tensor {
230    fn clone(&self) -> Self {
231        match self {
232            Tensor::F32(arr) => Tensor::F32(arr.clone()),
233            Tensor::F64(arr) => Tensor::F64(arr.clone()),
234            Tensor::F16(arr) => Tensor::F16(arr.clone()),
235            Tensor::BF16(arr) => Tensor::BF16(arr.clone()),
236            Tensor::I64(arr) => Tensor::I64(arr.clone()),
237            Tensor::C32(arr) => Tensor::C32(arr.clone()),
238            Tensor::C64(arr) => Tensor::C64(arr.clone()),
239            Tensor::CF16(arr) => Tensor::CF16(arr.clone()),
240            Tensor::CBF16(arr) => Tensor::CBF16(arr.clone()),
241            Tensor::Sparse(s) => Tensor::Sparse(s.clone()),
242            #[cfg(feature = "torch")]
243            Tensor::Torch(t) => Tensor::Torch(t.shallow_clone()),
244            #[cfg(feature = "candle")]
245            Tensor::Candle(t) => Tensor::Candle(t.clone()),
246            #[cfg(all(target_os = "macos", feature = "metal"))]
247            Tensor::Metal(data) => Tensor::Metal(data.clone()),
248            #[cfg(feature = "cuda")]
249            Tensor::CUDA(data) => Tensor::CUDA(data.clone()),
250        }
251    }
252}
253
254// Manual Debug implementation for Tensor
255impl std::fmt::Debug for Tensor {
256    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
257        match self {
258            Tensor::F32(_) => write!(f, "Tensor::F32(shape: {:?}, dtype: F32)", self.shape()),
259            Tensor::F64(_) => write!(f, "Tensor::F64(shape: {:?}, dtype: F64)", self.shape()),
260            Tensor::F16(_) => write!(f, "Tensor::F16(shape: {:?}, dtype: F16)", self.shape()),
261            Tensor::BF16(_) => write!(f, "Tensor::BF16(shape: {:?}, dtype: BF16)", self.shape()),
262            Tensor::I64(_) => write!(f, "Tensor::I64(shape: {:?}, dtype: I64)", self.shape()),
263            Tensor::C32(_) => write!(f, "Tensor::C32(shape: {:?}, dtype: C32)", self.shape()),
264            Tensor::C64(_) => write!(f, "Tensor::C64(shape: {:?}, dtype: C64)", self.shape()),
265            Tensor::CF16(_) => write!(f, "Tensor::CF16(shape: {:?}, dtype: CF16)", self.shape()),
266            Tensor::CBF16(_) => write!(f, "Tensor::CBF16(shape: {:?}, dtype: CBF16)", self.shape()),
267            Tensor::Sparse(s) => write!(f, "Tensor::Sparse({:?})", s),
268            #[cfg(feature = "torch")]
269            Tensor::Torch(_) => write!(f, "Tensor::Torch(shape: {:?})", self.shape()),
270            #[cfg(feature = "candle")]
271            Tensor::Candle(_) => write!(f, "Tensor::Candle(shape: {:?})", self.shape()),
272            #[cfg(all(target_os = "macos", feature = "metal"))]
273            Tensor::Metal(data) => write!(
274                f,
275                "Tensor::Metal(shape: {:?}, dtype: {:?}, buffer_id: {:?})",
276                data.shape, data.dtype, data.buffer_id
277            ),
278            #[cfg(feature = "cuda")]
279            Tensor::CUDA(data) => write!(
280                f,
281                "Tensor::CUDA(shape: {:?}, dtype: {:?}, buffer_id: {:?})",
282                data.shape, data.dtype, data.buffer_id
283            ),
284        }
285    }
286}
287
// Safety: Both PyTorch and Candle backends are internally thread-safe:
// - PyTorch: The tch::Tensor uses reference counting and the underlying data is managed
//   by PyTorch's thread-safe memory allocator. The raw pointer is just an FFI wrapper.
// - Candle: Tensors are designed to be thread-safe with reference-counted storage.
// Multiple threads can safely hold references to the same tensor.
// NOTE(review): this impl is only needed when a backend variant contains a
// non-Sync type; without `torch`/`candle` the compiler derives Sync itself.
// Verify whether a matching `unsafe impl Send` is also required for these
// backends, or whether Send is still auto-derived.
#[cfg(any(feature = "torch", feature = "candle"))]
unsafe impl Sync for Tensor {}

// The implementations are in separate modules but the methods are part of the Tensor impl blocks

298impl From<ArrayBase<OwnedRepr<f32>, Dim<IxDynImpl>>> for Tensor {
299    fn from(arr: ArrayD<f32>) -> Self {
300        Tensor::F32(arr)
301    }
302}
303
304impl From<ArrayBase<OwnedRepr<f64>, Dim<IxDynImpl>>> for Tensor {
305    fn from(arr: ArrayD<f64>) -> Self {
306        Tensor::F64(arr)
307    }
308}
309
// Additional math operations for trait compatibility.
//
// All arithmetic operators on `Tensor` return `Result<Tensor>` rather than
// `Tensor`: the underlying element-wise ops are fallible (e.g. shape
// mismatch), so callers write `(&a + &b)?`. Each impl delegates to the
// inherent `Tensor::add` via a fully-qualified call (method-call syntax here
// would be ambiguous with the trait method itself).
impl std::ops::Add for Tensor {
    type Output = Result<Tensor>;

    // By-value form: consumes both operands.
    fn add(self, other: Tensor) -> Self::Output {
        Tensor::add(&self, &other)
    }
}

impl std::ops::Add for &Tensor {
    type Output = Result<Tensor>;

    // Borrowing form: neither operand is consumed.
    fn add(self, other: &Tensor) -> Self::Output {
        Tensor::add(self, other)
    }
}

// The `&&Tensor` impls let double-references (e.g. items yielded by
// iterating a slice of `&Tensor`) be added without explicit dereferencing;
// deref coercion flattens `&&Tensor` to `&Tensor` at the call.
impl std::ops::Add<&&Tensor> for &Tensor {
    type Output = Result<Tensor>;

    fn add(self, other: &&Tensor) -> Self::Output {
        Tensor::add(self, other)
    }
}

impl std::ops::Add<&Tensor> for &&Tensor {
    type Output = Result<Tensor>;

    fn add(self, other: &Tensor) -> Self::Output {
        Tensor::add(self, other)
    }
}

// Tensor - Tensor (by value). Returns `Result` because the inherent
// `Tensor::sub` is fallible (e.g. on shape mismatch).
impl std::ops::Sub for Tensor {
    type Output = Result<Tensor>;

    fn sub(self, other: Tensor) -> Self::Output {
        Tensor::sub(&self, &other)
    }
}

// Scalar multiplication operators
impl std::ops::Mul<f32> for Tensor {
    type Output = Result<Tensor>;

    fn mul(self, scalar: f32) -> Self::Output {
        self.scalar_mul(scalar)
    }
}

impl std::ops::Mul<f32> for &Tensor {
    type Output = Result<Tensor>;

    fn mul(self, scalar: f32) -> Self::Output {
        self.scalar_mul(scalar)
    }
}

impl std::ops::Mul<f64> for Tensor {
    type Output = Result<Tensor>;

    // NOTE(review): the f64 scalar is narrowed to f32 before multiplying
    // (`scalar_mul` only accepts f32), losing precision for scalars outside
    // f32 range/precision — confirm this is acceptable for F64 tensors.
    fn mul(self, scalar: f64) -> Self::Output {
        self.scalar_mul(scalar as f32)
    }
}

impl std::ops::Mul<f64> for &Tensor {
    type Output = Result<Tensor>;

    // Same f64 -> f32 narrowing cast as the by-value impl above.
    fn mul(self, scalar: f64) -> Self::Output {
        self.scalar_mul(scalar as f32)
    }
}

// Element-wise multiplication with another tensor; delegates to the inherent
// `Tensor::mul` (fully-qualified to avoid recursing into the trait method).
impl std::ops::Mul<&Tensor> for &Tensor {
    type Output = Result<Tensor>;

    fn mul(self, other: &Tensor) -> Self::Output {
        Tensor::mul(self, other)
    }
}

impl std::ops::Mul<Tensor> for &Tensor {
    type Output = Result<Tensor>;

    // Mixed form: borrowed LHS, owned RHS.
    fn mul(self, other: Tensor) -> Self::Output {
        Tensor::mul(self, &other)
    }
}

impl std::ops::Mul<&Tensor> for Tensor {
    type Output = Result<Tensor>;

    // Mixed form: owned LHS, borrowed RHS.
    fn mul(self, other: &Tensor) -> Self::Output {
        Tensor::mul(&self, other)
    }
}

// Scalar division operators
impl std::ops::Div<f32> for Tensor {
    type Output = Result<Tensor>;

    fn div(self, scalar: f32) -> Self::Output {
        self.scalar_div(scalar)
    }
}

impl std::ops::Div<f32> for &Tensor {
    type Output = Result<Tensor>;

    fn div(self, scalar: f32) -> Self::Output {
        self.scalar_div(scalar)
    }
}

impl std::ops::Div<f64> for Tensor {
    type Output = Result<Tensor>;

    // NOTE(review): the f64 divisor is narrowed to f32 (`scalar_div` only
    // accepts f32), losing precision — confirm acceptable, mirroring the
    // `Mul<f64>` impls above.
    fn div(self, scalar: f64) -> Self::Output {
        self.scalar_div(scalar as f32)
    }
}

impl std::ops::Div<f64> for &Tensor {
    type Output = Result<Tensor>;

    // Same f64 -> f32 narrowing cast as the by-value impl above.
    fn div(self, scalar: f64) -> Self::Output {
        self.scalar_div(scalar as f32)
    }
}

// Double-reference form so `&&Tensor` (e.g. iterator items) can be divided
// directly; one explicit deref reaches the `&Tensor` receiver.
impl std::ops::Div<f64> for &&Tensor {
    type Output = Result<Tensor>;

    fn div(self, scalar: f64) -> Self::Output {
        (*self).scalar_div(scalar as f32)
    }
}

// Tensor subtraction operators
impl std::ops::Sub for &Tensor {
    type Output = Result<Tensor>;

    // Borrowing form: neither operand is consumed.
    fn sub(self, other: &Tensor) -> Self::Output {
        Tensor::sub(self, other)
    }
}

// Type alias for backward compatibility: older code refers to `TensorType`.
pub type TensorType = DType;

// Re-export expression template types
pub use expression::{EvalContext, ExprNode, OpType, OptimizationHints, TensorExpr};

// Re-export gradient tracking utilities
pub use utils::{clear_gradients, disable_grad, enable_grad, is_grad_enabled};