Skip to main content

torsh_backend/quantization/
tensor.rs

1//! Quantized tensor representation and operations
2//!
3//! This module provides the QuantizedTensor struct which represents tensors
4//! that have been quantized to lower-precision integer formats. It includes
5//! memory-efficient storage, shape management, and basic tensor operations
6//! optimized for quantized data.
7
8use super::params::QuantizationParams;
9use super::types::QuantizedDType;
10use crate::{BackendResult, Device};
11
12#[cfg(not(feature = "std"))]
13use alloc::vec::Vec;
14
/// Quantized tensor representation
///
/// Represents a tensor that has been quantized to a lower-precision format.
/// The data is stored as raw bytes with associated quantization parameters
/// that define how to interpret and convert the data back to floating-point.
///
/// All fields are public, so external code can construct or mutate a tensor
/// directly; `validate()` checks that a tensor's fields are still mutually
/// consistent after such manipulation.
#[derive(Debug, Clone)]
pub struct QuantizedTensor {
    /// Quantized data stored as raw bytes
    ///
    /// The data layout depends on the quantization type:
    /// - For 8-bit and 16-bit types: one value per element
    /// - For 4-bit types: two values packed per byte
    /// - For binary: eight values packed per byte
    pub data: Vec<u8>,

    /// Original tensor shape
    ///
    /// Maintains the logical shape of the tensor for operations.
    /// The total number of elements is the product of all dimensions.
    pub shape: Vec<usize>,

    /// Quantization parameters
    ///
    /// Contains all information needed to convert between quantized
    /// and floating-point representations, including scale factors,
    /// zero points, and metadata about the quantization scheme.
    pub params: QuantizationParams,

    /// Device where tensor is stored
    ///
    /// Indicates whether the tensor data resides in CPU memory,
    /// GPU memory, or other accelerator memory.
    pub device: Device,
}
49
50impl QuantizedTensor {
51    /// Create a new quantized tensor with zero-initialized data
52    ///
53    /// Allocates memory for a quantized tensor with the specified shape
54    /// and quantization parameters. The data is initialized to zeros.
55    ///
56    /// # Arguments
57    ///
58    /// * `shape` - Dimensions of the tensor
59    /// * `params` - Quantization parameters defining the format
60    /// * `device` - Target device for tensor storage
61    ///
62    /// # Examples
63    ///
64    /// ```
65    /// use torsh_backend::quantization::{QuantizedTensor, QuantizationParams};
66    /// use torsh_backend::Device;
67    ///
68    /// let shape = vec![2, 3, 4];
69    /// let params = QuantizationParams::int8_symmetric();
70    /// let device = Device::cpu().unwrap();
71    /// let tensor = QuantizedTensor::new(shape, params, device);
72    /// assert_eq!(tensor.num_elements(), 24);
73    /// ```
74    pub fn new(shape: Vec<usize>, params: QuantizationParams, device: Device) -> Self {
75        let total_elements: usize = shape.iter().product();
76        let data_size = Self::calculate_data_size(total_elements, &params.dtype);
77
78        Self {
79            data: vec![0; data_size],
80            shape,
81            params,
82            device,
83        }
84    }
85
86    /// Create a quantized tensor from existing data
87    ///
88    /// Creates a quantized tensor using pre-existing quantized data.
89    /// The data length must match the expected size for the given
90    /// shape and quantization type.
91    ///
92    /// # Arguments
93    ///
94    /// * `data` - Pre-quantized data bytes
95    /// * `shape` - Dimensions of the tensor
96    /// * `params` - Quantization parameters
97    /// * `device` - Target device for tensor storage
98    ///
99    /// # Returns
100    ///
101    /// Returns `Ok(QuantizedTensor)` if the data size matches expectations,
102    /// or an error if the sizes are incompatible.
103    pub fn from_data(
104        data: Vec<u8>,
105        shape: Vec<usize>,
106        params: QuantizationParams,
107        device: Device,
108    ) -> BackendResult<Self> {
109        let total_elements: usize = shape.iter().product();
110        let expected_size = Self::calculate_data_size(total_elements, &params.dtype);
111
112        if data.len() != expected_size {
113            return Err(torsh_core::error::TorshError::InvalidArgument(format!(
114                "Data size mismatch: expected {} bytes for shape {:?} with dtype {:?}, got {} bytes",
115                expected_size, shape, params.dtype, data.len()
116            )));
117        }
118
119        Ok(Self {
120            data,
121            shape,
122            params,
123            device,
124        })
125    }
126
127    /// Calculate the required data size in bytes for a given element count and dtype
128    fn calculate_data_size(num_elements: usize, dtype: &QuantizedDType) -> usize {
129        match dtype {
130            QuantizedDType::Int4 | QuantizedDType::UInt4 => {
131                // 4-bit types: 2 elements per byte, round up for odd counts
132                (num_elements + 1) / 2
133            }
134            QuantizedDType::Binary => {
135                // Binary: 8 elements per byte, round up
136                (num_elements + 7) / 8
137            }
138            _ => {
139                // 8-bit and 16-bit types: standard byte alignment
140                num_elements * (dtype.bits() as usize / 8)
141            }
142        }
143    }
144
145    /// Get the number of elements in the tensor
146    ///
147    /// Returns the total number of logical elements in the tensor,
148    /// which is the product of all dimensions in the shape.
149    ///
150    /// # Examples
151    ///
152    /// ```
153    /// # use torsh_backend::quantization::{QuantizedTensor, QuantizationParams};
154    /// # use torsh_backend::Device;
155    /// let tensor = QuantizedTensor::new(vec![2, 3, 4], QuantizationParams::default(), Device::cpu().unwrap());
156    /// assert_eq!(tensor.num_elements(), 24);
157    /// ```
158    pub fn num_elements(&self) -> usize {
159        self.shape.iter().product()
160    }
161
162    /// Get the memory usage in bytes
163    ///
164    /// Returns the actual number of bytes used to store the quantized data.
165    /// This may be less than `num_elements()` for sub-byte quantization types.
166    ///
167    /// # Examples
168    ///
169    /// ```
170    /// # use torsh_backend::quantization::{QuantizedTensor, QuantizationParams};
171    /// # use torsh_backend::Device;
172    /// let params = QuantizationParams::int4_symmetric();
173    /// let tensor = QuantizedTensor::new(vec![8], params, Device::cpu().unwrap());
174    /// assert_eq!(tensor.memory_usage(), 4); // 8 elements, 2 per byte = 4 bytes
175    /// ```
176    pub fn memory_usage(&self) -> usize {
177        self.data.len()
178    }
179
180    /// Get the shape of the tensor
181    ///
182    /// Returns a reference to the shape vector. This is the logical
183    /// shape of the tensor, not the storage layout.
184    pub fn shape(&self) -> &[usize] {
185        &self.shape
186    }
187
188    /// Get the number of dimensions
189    ///
190    /// Returns the number of dimensions (rank) of the tensor.
191    pub fn ndim(&self) -> usize {
192        self.shape.len()
193    }
194
195    /// Check if the tensor is empty (has zero elements)
196    pub fn is_empty(&self) -> bool {
197        self.num_elements() == 0
198    }
199
200    /// Get the size of a specific dimension
201    ///
202    /// Returns the size of the dimension at the given index,
203    /// or an error if the index is out of bounds.
204    pub fn size(&self, dim: usize) -> BackendResult<usize> {
205        self.shape.get(dim).copied().ok_or_else(|| {
206            torsh_core::error::TorshError::InvalidArgument(format!(
207                "Dimension {} is out of bounds for tensor with {} dimensions",
208                dim,
209                self.ndim()
210            ))
211        })
212    }
213
214    /// Reshape the tensor to a new shape
215    ///
216    /// Returns a new tensor with the same data but a different shape.
217    /// The total number of elements must remain the same.
218    ///
219    /// # Arguments
220    ///
221    /// * `new_shape` - New shape for the tensor
222    ///
223    /// # Returns
224    ///
225    /// Returns `Ok(QuantizedTensor)` with the new shape, or an error
226    /// if the total number of elements doesn't match.
227    pub fn reshape(&self, new_shape: Vec<usize>) -> BackendResult<QuantizedTensor> {
228        let new_num_elements: usize = new_shape.iter().product();
229        if new_num_elements != self.num_elements() {
230            return Err(torsh_core::error::TorshError::InvalidArgument(format!(
231                "Cannot reshape tensor with {} elements to shape with {} elements",
232                self.num_elements(),
233                new_num_elements
234            )));
235        }
236
237        Ok(QuantizedTensor {
238            data: self.data.clone(),
239            shape: new_shape,
240            params: self.params.clone(),
241            device: self.device.clone(),
242        })
243    }
244
245    /// Create a view with a new shape (zero-copy reshape)
246    ///
247    /// Similar to reshape, but returns a view that shares the same data.
248    /// This is more memory-efficient but creates aliasing.
249    pub fn view(&self, new_shape: Vec<usize>) -> BackendResult<QuantizedTensorView<'_>> {
250        let new_num_elements: usize = new_shape.iter().product();
251        if new_num_elements != self.num_elements() {
252            return Err(torsh_core::error::TorshError::InvalidArgument(format!(
253                "Cannot view tensor with {} elements as shape with {} elements",
254                self.num_elements(),
255                new_num_elements
256            )));
257        }
258
259        Ok(QuantizedTensorView {
260            data: &self.data,
261            shape: new_shape,
262            params: &self.params,
263            device: &self.device,
264        })
265    }
266
267    /// Move tensor to a different device
268    ///
269    /// Creates a copy of the tensor on the specified device.
270    /// If the source and destination devices are the same, returns a copy without transfer.
271    /// For different devices, performs a data transfer and creates a new tensor.
272    pub fn to_device(&self, device: Device) -> BackendResult<QuantizedTensor> {
273        // If the devices are the same, no transfer is needed
274        if self.device.device_type() == device.device_type() && self.device.id() == device.id() {
275            return Ok(self.clone());
276        }
277
278        // For cross-device transfers, we need to handle the data movement
279        // Currently implementing basic data copy - can be enhanced with optimized
280        // backend-specific transfers in the future
281        let transferred_data = self.transfer_data_to_device(&device)?;
282
283        Ok(QuantizedTensor {
284            data: transferred_data,
285            shape: self.shape.clone(),
286            params: self.params.clone(),
287            device,
288        })
289    }
290
291    /// Transfer tensor data to a different device
292    ///
293    /// This is a helper method that handles the actual data transfer.
294    /// Currently implements basic data copying, but can be enhanced with
295    /// backend-specific optimizations for different device types.
296    fn transfer_data_to_device(&self, target_device: &Device) -> BackendResult<Vec<u8>> {
297        use torsh_core::device::DeviceType;
298
299        // For now, implement basic data copying across device types
300        // This provides functional cross-device support while maintaining
301        // simplicity and can be optimized in future iterations
302
303        match (self.device.device_type(), target_device.device_type()) {
304            // Same device type transfers (different device IDs)
305            (DeviceType::Cpu, DeviceType::Cpu) => {
306                // CPU to CPU: simple memory copy
307                Ok(self.data.clone())
308            }
309            (DeviceType::Cuda(_), DeviceType::Cuda(_)) => {
310                // CUDA to CUDA: device-to-device copy
311                // For now, copy through host memory
312                Ok(self.data.clone())
313            }
314            (DeviceType::Metal(_), DeviceType::Metal(_)) => {
315                // Metal to Metal: device-to-device copy
316                Ok(self.data.clone())
317            }
318
319            // Cross-device type transfers
320            (DeviceType::Cpu, DeviceType::Cuda(_)) => {
321                // CPU to CUDA: host to device transfer
322                Ok(self.data.clone())
323            }
324            (DeviceType::Cuda(_), DeviceType::Cpu) => {
325                // CUDA to CPU: device to host transfer
326                Ok(self.data.clone())
327            }
328            (DeviceType::Cpu, DeviceType::Metal(_)) => {
329                // CPU to Metal: host to device transfer
330                Ok(self.data.clone())
331            }
332            (DeviceType::Metal(_), DeviceType::Cpu) => {
333                // Metal to CPU: device to host transfer
334                Ok(self.data.clone())
335            }
336            (DeviceType::Cuda(_), DeviceType::Metal(_)) => {
337                // CUDA to Metal: cross-device transfer via host
338                Ok(self.data.clone())
339            }
340            (DeviceType::Metal(_), DeviceType::Cuda(_)) => {
341                // Metal to CUDA: cross-device transfer via host
342                Ok(self.data.clone())
343            }
344
345            // Future device types can be added here
346            _ => {
347                // Fallback: basic data copy for any unsupported device combinations
348                Ok(self.data.clone())
349            }
350        }
351    }
352
353    /// Get a slice of the raw data
354    ///
355    /// Returns a reference to a portion of the underlying byte data.
356    /// This is useful for low-level operations and custom kernels.
357    ///
358    /// # Arguments
359    ///
360    /// * `start` - Starting byte index
361    /// * `len` - Number of bytes to include
362    ///
363    /// # Safety
364    ///
365    /// The caller must ensure that the slice boundaries are valid
366    /// and aligned with the quantization format.
367    pub fn data_slice(&self, start: usize, len: usize) -> BackendResult<&[u8]> {
368        if start + len > self.data.len() {
369            return Err(torsh_core::error::TorshError::InvalidArgument(format!(
370                "Slice [{}..{}] is out of bounds for data of length {}",
371                start,
372                start + len,
373                self.data.len()
374            )));
375        }
376
377        Ok(&self.data[start..start + len])
378    }
379
380    /// Get a mutable slice of the raw data
381    ///
382    /// Returns a mutable reference to a portion of the underlying byte data.
383    /// This allows in-place modifications of the quantized data.
384    ///
385    /// # Arguments
386    ///
387    /// * `start` - Starting byte index
388    /// * `len` - Number of bytes to include
389    ///
390    /// # Safety
391    ///
392    /// The caller must ensure that any modifications maintain the
393    /// integrity of the quantized representation.
394    pub fn data_slice_mut(&mut self, start: usize, len: usize) -> BackendResult<&mut [u8]> {
395        if start + len > self.data.len() {
396            return Err(torsh_core::error::TorshError::InvalidArgument(format!(
397                "Slice [{}..{}] is out of bounds for data of length {}",
398                start,
399                start + len,
400                self.data.len()
401            )));
402        }
403
404        Ok(&mut self.data[start..start + len])
405    }
406
407    /// Calculate storage efficiency compared to FP32
408    ///
409    /// Returns the ratio of this tensor's memory usage to what
410    /// an equivalent FP32 tensor would require.
411    pub fn storage_efficiency(&self) -> f32 {
412        let fp32_size = self.num_elements() * 4; // 4 bytes per FP32
413        if fp32_size == 0 {
414            return 1.0;
415        }
416        self.memory_usage() as f32 / fp32_size as f32
417    }
418
419    /// Get compression ratio compared to FP32
420    ///
421    /// Returns how many times smaller this tensor is compared to FP32.
422    pub fn compression_ratio(&self) -> f32 {
423        1.0 / self.storage_efficiency()
424    }
425
426    /// Validate tensor consistency
427    ///
428    /// Checks that the tensor's data size, shape, and parameters
429    /// are all consistent with each other.
430    pub fn validate(&self) -> BackendResult<()> {
431        // Validate quantization parameters
432        self.params.validate()?;
433
434        // Check data size consistency
435        let expected_size = Self::calculate_data_size(self.num_elements(), &self.params.dtype);
436        if self.data.len() != expected_size {
437            return Err(torsh_core::error::TorshError::InvalidArgument(format!(
438                "Data size inconsistency: expected {} bytes, actual {} bytes",
439                expected_size,
440                self.data.len()
441            )));
442        }
443
444        // Check for empty shape
445        if self.shape.is_empty() {
446            return Err(torsh_core::error::TorshError::InvalidArgument(
447                "Tensor shape cannot be empty".to_string(),
448            ));
449        }
450
451        // Check for zero dimensions
452        for (i, &dim) in self.shape.iter().enumerate() {
453            if dim == 0 {
454                return Err(torsh_core::error::TorshError::InvalidArgument(format!(
455                    "Dimension {} cannot be zero",
456                    i
457                )));
458            }
459        }
460
461        Ok(())
462    }
463}
464
/// Read-only view of a quantized tensor with different shape
///
/// Provides a view into an existing quantized tensor with a potentially
/// different shape, without copying the underlying data. The view borrows
/// the source tensor's data, parameters, and device for lifetime `'a`;
/// only the shape is owned by the view.
#[derive(Debug)]
pub struct QuantizedTensorView<'a> {
    /// Reference to the original data
    pub data: &'a [u8],
    /// View shape (may differ from original)
    pub shape: Vec<usize>,
    /// Reference to quantization parameters
    pub params: &'a QuantizationParams,
    /// Reference to device information
    pub device: &'a Device,
}
480
481impl<'a> QuantizedTensorView<'a> {
482    /// Get the number of elements in the view
483    pub fn num_elements(&self) -> usize {
484        self.shape.iter().product()
485    }
486
487    /// Get the memory usage in bytes
488    pub fn memory_usage(&self) -> usize {
489        self.data.len()
490    }
491
492    /// Get the shape of the view
493    pub fn shape(&self) -> &[usize] {
494        &self.shape
495    }
496
497    /// Get the number of dimensions
498    pub fn ndim(&self) -> usize {
499        self.shape.len()
500    }
501
502    /// Convert view to owned tensor
503    pub fn to_owned(&self) -> QuantizedTensor {
504        QuantizedTensor {
505            data: self.data.to_vec(),
506            shape: self.shape.clone(),
507            params: self.params.clone(),
508            device: self.device.clone(),
509        }
510    }
511}
512
#[cfg(test)]
mod tests {
    use super::*;
    use crate::quantization::QuantizationParams;

    /// Shorthand for the CPU device used by every test.
    fn cpu() -> Device {
        Device::cpu().unwrap()
    }

    #[test]
    fn test_tensor_creation() {
        let dims = vec![2, 3, 4];
        let dev = cpu();
        let t = QuantizedTensor::new(
            dims.clone(),
            QuantizationParams::int8_symmetric(),
            dev.clone(),
        );

        assert_eq!(t.shape(), &dims);
        assert_eq!(t.num_elements(), 24);
        assert_eq!(t.memory_usage(), 24); // Int8: one byte per element
        assert_eq!(t.device, dev);
    }

    #[test]
    fn test_int4_tensor_size() {
        let t = QuantizedTensor::new(vec![8], QuantizationParams::int4_symmetric(), cpu());

        assert_eq!(t.num_elements(), 8);
        assert_eq!(t.memory_usage(), 4); // Int4: two elements per byte
    }

    #[test]
    fn test_binary_tensor_size() {
        let mut params = QuantizationParams::default();
        params.dtype = QuantizedDType::Binary;
        let t = QuantizedTensor::new(vec![16], params, cpu());

        assert_eq!(t.num_elements(), 16);
        assert_eq!(t.memory_usage(), 2); // Binary: eight elements per byte
    }

    #[test]
    fn test_tensor_from_data() {
        let bytes = vec![1, 2, 3, 4];
        let t = QuantizedTensor::from_data(
            bytes.clone(),
            vec![4],
            QuantizationParams::int8_symmetric(),
            cpu(),
        )
        .unwrap();

        assert_eq!(t.data, bytes);
        assert_eq!(t.num_elements(), 4);
    }

    #[test]
    fn test_tensor_from_data_size_mismatch() {
        // Three bytes cannot back a 4-element Int8 tensor.
        let result = QuantizedTensor::from_data(
            vec![1, 2, 3],
            vec![4],
            QuantizationParams::int8_symmetric(),
            cpu(),
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_tensor_reshape() {
        let t = QuantizedTensor::new(vec![2, 6], QuantizationParams::default(), cpu());
        let reshaped = t.reshape(vec![3, 4]).unwrap();

        assert_eq!(reshaped.shape(), &[3, 4]);
        assert_eq!(reshaped.num_elements(), 12);
        assert_eq!(reshaped.data.len(), t.data.len());
    }

    #[test]
    fn test_tensor_reshape_invalid() {
        let t = QuantizedTensor::new(vec![2, 6], QuantizationParams::default(), cpu());
        // 15 elements != 12
        assert!(t.reshape(vec![3, 5]).is_err());
    }

    #[test]
    fn test_tensor_view() {
        let t = QuantizedTensor::new(vec![2, 6], QuantizationParams::default(), cpu());
        let v = t.view(vec![4, 3]).unwrap();

        assert_eq!(v.shape(), &[4, 3]);
        assert_eq!(v.num_elements(), 12);
        assert_eq!(v.data.len(), t.data.len());
    }

    #[test]
    fn test_storage_efficiency() {
        // Int8: 1 byte vs 4 bytes per element.
        let int8 = QuantizedTensor::new(vec![10], QuantizationParams::int8_symmetric(), cpu());
        assert_eq!(int8.storage_efficiency(), 0.25);

        // Int4: 0.5 bytes vs 4 bytes per element.
        let int4 = QuantizedTensor::new(vec![10], QuantizationParams::int4_symmetric(), cpu());
        assert_eq!(int4.storage_efficiency(), 0.125);
    }

    #[test]
    fn test_compression_ratio() {
        let int8 = QuantizedTensor::new(vec![10], QuantizationParams::int8_symmetric(), cpu());
        assert_eq!(int8.compression_ratio(), 4.0); // 4x compression

        let int4 = QuantizedTensor::new(vec![10], QuantizationParams::int4_symmetric(), cpu());
        assert_eq!(int4.compression_ratio(), 8.0); // 8x compression
    }

    #[test]
    fn test_data_slice() {
        let t = QuantizedTensor::new(vec![4], QuantizationParams::int8_symmetric(), cpu());

        assert_eq!(t.data_slice(1, 2).unwrap().len(), 2);
        // Requests past the end of the buffer must fail.
        assert!(t.data_slice(3, 3).is_err());
    }

    #[test]
    fn test_tensor_validation() {
        let t = QuantizedTensor::new(vec![2, 3], QuantizationParams::default(), cpu());
        assert!(t.validate().is_ok());

        // Truncating the buffer breaks the size invariant.
        let mut broken = t.clone();
        broken.data.truncate(1);
        assert!(broken.validate().is_err());
    }

    #[test]
    fn test_tensor_properties() {
        let t = QuantizedTensor::new(vec![2, 3, 4], QuantizationParams::default(), cpu());

        assert_eq!(t.ndim(), 3);
        assert_eq!(t.size(0).unwrap(), 2);
        assert_eq!(t.size(1).unwrap(), 3);
        assert_eq!(t.size(2).unwrap(), 4);
        assert!(t.size(3).is_err()); // out of bounds

        assert!(!t.is_empty());
    }

    #[test]
    fn test_empty_tensor() {
        // A zero-length dimension must be rejected by validation.
        let t = QuantizedTensor::new(vec![0], QuantizationParams::default(), cpu());
        assert!(t.validate().is_err());
    }
}