torsh_backend/quantization/tensor.rs
1//! Quantized tensor representation and operations
2//!
3//! This module provides the QuantizedTensor struct which represents tensors
4//! that have been quantized to lower-precision integer formats. It includes
5//! memory-efficient storage, shape management, and basic tensor operations
6//! optimized for quantized data.
7
8use super::params::QuantizationParams;
9use super::types::QuantizedDType;
10use crate::{BackendResult, Device};
11
12#[cfg(not(feature = "std"))]
13use alloc::vec::Vec;
14
/// Quantized tensor representation
///
/// Represents a tensor that has been quantized to a lower-precision format.
/// The data is stored as raw bytes with associated quantization parameters
/// that define how to interpret and convert the data back to floating-point.
///
/// Invariant: `data.len()` must equal the packed byte size implied by `shape`
/// and `params.dtype` (see `validate`). The constructors uphold this; mutating
/// the public fields directly can break it.
#[derive(Debug, Clone)]
pub struct QuantizedTensor {
    /// Quantized data stored as raw bytes
    ///
    /// The data layout depends on the quantization type:
    /// - For 8-bit and 16-bit types: one value per element
    /// - For 4-bit types: two values packed per byte
    /// - For binary: eight values packed per byte
    ///
    /// For sub-byte types the final byte may be only partially used.
    pub data: Vec<u8>,

    /// Original tensor shape
    ///
    /// Maintains the logical shape of the tensor for operations.
    /// The total number of elements is the product of all dimensions.
    pub shape: Vec<usize>,

    /// Quantization parameters
    ///
    /// Contains all information needed to convert between quantized
    /// and floating-point representations, including scale factors,
    /// zero points, and metadata about the quantization scheme.
    pub params: QuantizationParams,

    /// Device where tensor is stored
    ///
    /// Indicates whether the tensor data resides in CPU memory,
    /// GPU memory, or other accelerator memory.
    pub device: Device,
}
49
50impl QuantizedTensor {
51 /// Create a new quantized tensor with zero-initialized data
52 ///
53 /// Allocates memory for a quantized tensor with the specified shape
54 /// and quantization parameters. The data is initialized to zeros.
55 ///
56 /// # Arguments
57 ///
58 /// * `shape` - Dimensions of the tensor
59 /// * `params` - Quantization parameters defining the format
60 /// * `device` - Target device for tensor storage
61 ///
62 /// # Examples
63 ///
64 /// ```
65 /// use torsh_backend::quantization::{QuantizedTensor, QuantizationParams};
66 /// use torsh_backend::Device;
67 ///
68 /// let shape = vec![2, 3, 4];
69 /// let params = QuantizationParams::int8_symmetric();
70 /// let device = Device::cpu().unwrap();
71 /// let tensor = QuantizedTensor::new(shape, params, device);
72 /// assert_eq!(tensor.num_elements(), 24);
73 /// ```
74 pub fn new(shape: Vec<usize>, params: QuantizationParams, device: Device) -> Self {
75 let total_elements: usize = shape.iter().product();
76 let data_size = Self::calculate_data_size(total_elements, ¶ms.dtype);
77
78 Self {
79 data: vec![0; data_size],
80 shape,
81 params,
82 device,
83 }
84 }
85
86 /// Create a quantized tensor from existing data
87 ///
88 /// Creates a quantized tensor using pre-existing quantized data.
89 /// The data length must match the expected size for the given
90 /// shape and quantization type.
91 ///
92 /// # Arguments
93 ///
94 /// * `data` - Pre-quantized data bytes
95 /// * `shape` - Dimensions of the tensor
96 /// * `params` - Quantization parameters
97 /// * `device` - Target device for tensor storage
98 ///
99 /// # Returns
100 ///
101 /// Returns `Ok(QuantizedTensor)` if the data size matches expectations,
102 /// or an error if the sizes are incompatible.
103 pub fn from_data(
104 data: Vec<u8>,
105 shape: Vec<usize>,
106 params: QuantizationParams,
107 device: Device,
108 ) -> BackendResult<Self> {
109 let total_elements: usize = shape.iter().product();
110 let expected_size = Self::calculate_data_size(total_elements, ¶ms.dtype);
111
112 if data.len() != expected_size {
113 return Err(torsh_core::error::TorshError::InvalidArgument(format!(
114 "Data size mismatch: expected {} bytes for shape {:?} with dtype {:?}, got {} bytes",
115 expected_size, shape, params.dtype, data.len()
116 )));
117 }
118
119 Ok(Self {
120 data,
121 shape,
122 params,
123 device,
124 })
125 }
126
127 /// Calculate the required data size in bytes for a given element count and dtype
128 fn calculate_data_size(num_elements: usize, dtype: &QuantizedDType) -> usize {
129 match dtype {
130 QuantizedDType::Int4 | QuantizedDType::UInt4 => {
131 // 4-bit types: 2 elements per byte, round up for odd counts
132 (num_elements + 1) / 2
133 }
134 QuantizedDType::Binary => {
135 // Binary: 8 elements per byte, round up
136 (num_elements + 7) / 8
137 }
138 _ => {
139 // 8-bit and 16-bit types: standard byte alignment
140 num_elements * (dtype.bits() as usize / 8)
141 }
142 }
143 }
144
145 /// Get the number of elements in the tensor
146 ///
147 /// Returns the total number of logical elements in the tensor,
148 /// which is the product of all dimensions in the shape.
149 ///
150 /// # Examples
151 ///
152 /// ```
153 /// # use torsh_backend::quantization::{QuantizedTensor, QuantizationParams};
154 /// # use torsh_backend::Device;
155 /// let tensor = QuantizedTensor::new(vec![2, 3, 4], QuantizationParams::default(), Device::cpu().unwrap());
156 /// assert_eq!(tensor.num_elements(), 24);
157 /// ```
158 pub fn num_elements(&self) -> usize {
159 self.shape.iter().product()
160 }
161
162 /// Get the memory usage in bytes
163 ///
164 /// Returns the actual number of bytes used to store the quantized data.
165 /// This may be less than `num_elements()` for sub-byte quantization types.
166 ///
167 /// # Examples
168 ///
169 /// ```
170 /// # use torsh_backend::quantization::{QuantizedTensor, QuantizationParams};
171 /// # use torsh_backend::Device;
172 /// let params = QuantizationParams::int4_symmetric();
173 /// let tensor = QuantizedTensor::new(vec![8], params, Device::cpu().unwrap());
174 /// assert_eq!(tensor.memory_usage(), 4); // 8 elements, 2 per byte = 4 bytes
175 /// ```
176 pub fn memory_usage(&self) -> usize {
177 self.data.len()
178 }
179
180 /// Get the shape of the tensor
181 ///
182 /// Returns a reference to the shape vector. This is the logical
183 /// shape of the tensor, not the storage layout.
184 pub fn shape(&self) -> &[usize] {
185 &self.shape
186 }
187
188 /// Get the number of dimensions
189 ///
190 /// Returns the number of dimensions (rank) of the tensor.
191 pub fn ndim(&self) -> usize {
192 self.shape.len()
193 }
194
195 /// Check if the tensor is empty (has zero elements)
196 pub fn is_empty(&self) -> bool {
197 self.num_elements() == 0
198 }
199
200 /// Get the size of a specific dimension
201 ///
202 /// Returns the size of the dimension at the given index,
203 /// or an error if the index is out of bounds.
204 pub fn size(&self, dim: usize) -> BackendResult<usize> {
205 self.shape.get(dim).copied().ok_or_else(|| {
206 torsh_core::error::TorshError::InvalidArgument(format!(
207 "Dimension {} is out of bounds for tensor with {} dimensions",
208 dim,
209 self.ndim()
210 ))
211 })
212 }
213
214 /// Reshape the tensor to a new shape
215 ///
216 /// Returns a new tensor with the same data but a different shape.
217 /// The total number of elements must remain the same.
218 ///
219 /// # Arguments
220 ///
221 /// * `new_shape` - New shape for the tensor
222 ///
223 /// # Returns
224 ///
225 /// Returns `Ok(QuantizedTensor)` with the new shape, or an error
226 /// if the total number of elements doesn't match.
227 pub fn reshape(&self, new_shape: Vec<usize>) -> BackendResult<QuantizedTensor> {
228 let new_num_elements: usize = new_shape.iter().product();
229 if new_num_elements != self.num_elements() {
230 return Err(torsh_core::error::TorshError::InvalidArgument(format!(
231 "Cannot reshape tensor with {} elements to shape with {} elements",
232 self.num_elements(),
233 new_num_elements
234 )));
235 }
236
237 Ok(QuantizedTensor {
238 data: self.data.clone(),
239 shape: new_shape,
240 params: self.params.clone(),
241 device: self.device.clone(),
242 })
243 }
244
245 /// Create a view with a new shape (zero-copy reshape)
246 ///
247 /// Similar to reshape, but returns a view that shares the same data.
248 /// This is more memory-efficient but creates aliasing.
249 pub fn view(&self, new_shape: Vec<usize>) -> BackendResult<QuantizedTensorView<'_>> {
250 let new_num_elements: usize = new_shape.iter().product();
251 if new_num_elements != self.num_elements() {
252 return Err(torsh_core::error::TorshError::InvalidArgument(format!(
253 "Cannot view tensor with {} elements as shape with {} elements",
254 self.num_elements(),
255 new_num_elements
256 )));
257 }
258
259 Ok(QuantizedTensorView {
260 data: &self.data,
261 shape: new_shape,
262 params: &self.params,
263 device: &self.device,
264 })
265 }
266
267 /// Move tensor to a different device
268 ///
269 /// Creates a copy of the tensor on the specified device.
270 /// If the source and destination devices are the same, returns a copy without transfer.
271 /// For different devices, performs a data transfer and creates a new tensor.
272 pub fn to_device(&self, device: Device) -> BackendResult<QuantizedTensor> {
273 // If the devices are the same, no transfer is needed
274 if self.device.device_type() == device.device_type() && self.device.id() == device.id() {
275 return Ok(self.clone());
276 }
277
278 // For cross-device transfers, we need to handle the data movement
279 // Currently implementing basic data copy - can be enhanced with optimized
280 // backend-specific transfers in the future
281 let transferred_data = self.transfer_data_to_device(&device)?;
282
283 Ok(QuantizedTensor {
284 data: transferred_data,
285 shape: self.shape.clone(),
286 params: self.params.clone(),
287 device,
288 })
289 }
290
291 /// Transfer tensor data to a different device
292 ///
293 /// This is a helper method that handles the actual data transfer.
294 /// Currently implements basic data copying, but can be enhanced with
295 /// backend-specific optimizations for different device types.
296 fn transfer_data_to_device(&self, target_device: &Device) -> BackendResult<Vec<u8>> {
297 use torsh_core::device::DeviceType;
298
299 // For now, implement basic data copying across device types
300 // This provides functional cross-device support while maintaining
301 // simplicity and can be optimized in future iterations
302
303 match (self.device.device_type(), target_device.device_type()) {
304 // Same device type transfers (different device IDs)
305 (DeviceType::Cpu, DeviceType::Cpu) => {
306 // CPU to CPU: simple memory copy
307 Ok(self.data.clone())
308 }
309 (DeviceType::Cuda(_), DeviceType::Cuda(_)) => {
310 // CUDA to CUDA: device-to-device copy
311 // For now, copy through host memory
312 Ok(self.data.clone())
313 }
314 (DeviceType::Metal(_), DeviceType::Metal(_)) => {
315 // Metal to Metal: device-to-device copy
316 Ok(self.data.clone())
317 }
318
319 // Cross-device type transfers
320 (DeviceType::Cpu, DeviceType::Cuda(_)) => {
321 // CPU to CUDA: host to device transfer
322 Ok(self.data.clone())
323 }
324 (DeviceType::Cuda(_), DeviceType::Cpu) => {
325 // CUDA to CPU: device to host transfer
326 Ok(self.data.clone())
327 }
328 (DeviceType::Cpu, DeviceType::Metal(_)) => {
329 // CPU to Metal: host to device transfer
330 Ok(self.data.clone())
331 }
332 (DeviceType::Metal(_), DeviceType::Cpu) => {
333 // Metal to CPU: device to host transfer
334 Ok(self.data.clone())
335 }
336 (DeviceType::Cuda(_), DeviceType::Metal(_)) => {
337 // CUDA to Metal: cross-device transfer via host
338 Ok(self.data.clone())
339 }
340 (DeviceType::Metal(_), DeviceType::Cuda(_)) => {
341 // Metal to CUDA: cross-device transfer via host
342 Ok(self.data.clone())
343 }
344
345 // Future device types can be added here
346 _ => {
347 // Fallback: basic data copy for any unsupported device combinations
348 Ok(self.data.clone())
349 }
350 }
351 }
352
353 /// Get a slice of the raw data
354 ///
355 /// Returns a reference to a portion of the underlying byte data.
356 /// This is useful for low-level operations and custom kernels.
357 ///
358 /// # Arguments
359 ///
360 /// * `start` - Starting byte index
361 /// * `len` - Number of bytes to include
362 ///
363 /// # Safety
364 ///
365 /// The caller must ensure that the slice boundaries are valid
366 /// and aligned with the quantization format.
367 pub fn data_slice(&self, start: usize, len: usize) -> BackendResult<&[u8]> {
368 if start + len > self.data.len() {
369 return Err(torsh_core::error::TorshError::InvalidArgument(format!(
370 "Slice [{}..{}] is out of bounds for data of length {}",
371 start,
372 start + len,
373 self.data.len()
374 )));
375 }
376
377 Ok(&self.data[start..start + len])
378 }
379
380 /// Get a mutable slice of the raw data
381 ///
382 /// Returns a mutable reference to a portion of the underlying byte data.
383 /// This allows in-place modifications of the quantized data.
384 ///
385 /// # Arguments
386 ///
387 /// * `start` - Starting byte index
388 /// * `len` - Number of bytes to include
389 ///
390 /// # Safety
391 ///
392 /// The caller must ensure that any modifications maintain the
393 /// integrity of the quantized representation.
394 pub fn data_slice_mut(&mut self, start: usize, len: usize) -> BackendResult<&mut [u8]> {
395 if start + len > self.data.len() {
396 return Err(torsh_core::error::TorshError::InvalidArgument(format!(
397 "Slice [{}..{}] is out of bounds for data of length {}",
398 start,
399 start + len,
400 self.data.len()
401 )));
402 }
403
404 Ok(&mut self.data[start..start + len])
405 }
406
407 /// Calculate storage efficiency compared to FP32
408 ///
409 /// Returns the ratio of this tensor's memory usage to what
410 /// an equivalent FP32 tensor would require.
411 pub fn storage_efficiency(&self) -> f32 {
412 let fp32_size = self.num_elements() * 4; // 4 bytes per FP32
413 if fp32_size == 0 {
414 return 1.0;
415 }
416 self.memory_usage() as f32 / fp32_size as f32
417 }
418
419 /// Get compression ratio compared to FP32
420 ///
421 /// Returns how many times smaller this tensor is compared to FP32.
422 pub fn compression_ratio(&self) -> f32 {
423 1.0 / self.storage_efficiency()
424 }
425
426 /// Validate tensor consistency
427 ///
428 /// Checks that the tensor's data size, shape, and parameters
429 /// are all consistent with each other.
430 pub fn validate(&self) -> BackendResult<()> {
431 // Validate quantization parameters
432 self.params.validate()?;
433
434 // Check data size consistency
435 let expected_size = Self::calculate_data_size(self.num_elements(), &self.params.dtype);
436 if self.data.len() != expected_size {
437 return Err(torsh_core::error::TorshError::InvalidArgument(format!(
438 "Data size inconsistency: expected {} bytes, actual {} bytes",
439 expected_size,
440 self.data.len()
441 )));
442 }
443
444 // Check for empty shape
445 if self.shape.is_empty() {
446 return Err(torsh_core::error::TorshError::InvalidArgument(
447 "Tensor shape cannot be empty".to_string(),
448 ));
449 }
450
451 // Check for zero dimensions
452 for (i, &dim) in self.shape.iter().enumerate() {
453 if dim == 0 {
454 return Err(torsh_core::error::TorshError::InvalidArgument(format!(
455 "Dimension {} cannot be zero",
456 i
457 )));
458 }
459 }
460
461 Ok(())
462 }
463}
464
/// Read-only view of a quantized tensor with different shape
///
/// Provides a view into an existing quantized tensor with a potentially
/// different shape, without copying the underlying data. Constructed via
/// `QuantizedTensor::view`, which checks that the element count matches.
#[derive(Debug)]
pub struct QuantizedTensorView<'a> {
    /// Reference to the original data (borrowed, not copied)
    pub data: &'a [u8],
    /// View shape (may differ from original; element count must match)
    pub shape: Vec<usize>,
    /// Reference to quantization parameters of the source tensor
    pub params: &'a QuantizationParams,
    /// Reference to device information of the source tensor
    pub device: &'a Device,
}
480
481impl<'a> QuantizedTensorView<'a> {
482 /// Get the number of elements in the view
483 pub fn num_elements(&self) -> usize {
484 self.shape.iter().product()
485 }
486
487 /// Get the memory usage in bytes
488 pub fn memory_usage(&self) -> usize {
489 self.data.len()
490 }
491
492 /// Get the shape of the view
493 pub fn shape(&self) -> &[usize] {
494 &self.shape
495 }
496
497 /// Get the number of dimensions
498 pub fn ndim(&self) -> usize {
499 self.shape.len()
500 }
501
502 /// Convert view to owned tensor
503 pub fn to_owned(&self) -> QuantizedTensor {
504 QuantizedTensor {
505 data: self.data.to_vec(),
506 shape: self.shape.clone(),
507 params: self.params.clone(),
508 device: self.device.clone(),
509 }
510 }
511}
512
#[cfg(test)]
mod tests {
    use super::*;
    use crate::quantization::QuantizationParams;

    /// Convenience helper: a fresh CPU device handle for tests.
    fn cpu() -> Device {
        Device::cpu().unwrap()
    }

    #[test]
    fn test_tensor_creation() {
        let dims = vec![2, 3, 4];
        let device = cpu();
        let tensor = QuantizedTensor::new(
            dims.clone(),
            QuantizationParams::int8_symmetric(),
            device.clone(),
        );

        assert_eq!(tensor.shape(), dims.as_slice());
        assert_eq!(tensor.num_elements(), 24);
        // Int8 stores one element per byte.
        assert_eq!(tensor.memory_usage(), 24);
        assert_eq!(tensor.device, device);
    }

    #[test]
    fn test_int4_tensor_size() {
        let tensor = QuantizedTensor::new(vec![8], QuantizationParams::int4_symmetric(), cpu());

        assert_eq!(tensor.num_elements(), 8);
        // Int4 packs two elements per byte: 8 / 2 = 4 bytes.
        assert_eq!(tensor.memory_usage(), 4);
    }

    #[test]
    fn test_binary_tensor_size() {
        let mut params = QuantizationParams::default();
        params.dtype = QuantizedDType::Binary;
        let tensor = QuantizedTensor::new(vec![16], params, cpu());

        assert_eq!(tensor.num_elements(), 16);
        // Binary packs eight elements per byte: 16 / 8 = 2 bytes.
        assert_eq!(tensor.memory_usage(), 2);
    }

    #[test]
    fn test_tensor_from_data() {
        let bytes = vec![1, 2, 3, 4];
        let tensor = QuantizedTensor::from_data(
            bytes.clone(),
            vec![4],
            QuantizationParams::int8_symmetric(),
            cpu(),
        )
        .unwrap();

        assert_eq!(tensor.data, bytes);
        assert_eq!(tensor.num_elements(), 4);
    }

    #[test]
    fn test_tensor_from_data_size_mismatch() {
        // Three bytes cannot back a 4-element Int8 tensor.
        let outcome = QuantizedTensor::from_data(
            vec![1, 2, 3],
            vec![4],
            QuantizationParams::int8_symmetric(),
            cpu(),
        );

        assert!(outcome.is_err());
    }

    #[test]
    fn test_tensor_reshape() {
        let original = QuantizedTensor::new(vec![2, 6], QuantizationParams::default(), cpu());
        let reshaped = original.reshape(vec![3, 4]).unwrap();

        assert_eq!(reshaped.shape(), &[3, 4]);
        assert_eq!(reshaped.num_elements(), 12);
        // The byte buffer is copied unchanged.
        assert_eq!(reshaped.data.len(), original.data.len());
    }

    #[test]
    fn test_tensor_reshape_invalid() {
        let tensor = QuantizedTensor::new(vec![2, 6], QuantizationParams::default(), cpu());

        // 3 * 5 = 15 elements cannot hold the original 12.
        assert!(tensor.reshape(vec![3, 5]).is_err());
    }

    #[test]
    fn test_tensor_view() {
        let tensor = QuantizedTensor::new(vec![2, 6], QuantizationParams::default(), cpu());
        let view = tensor.view(vec![4, 3]).unwrap();

        assert_eq!(view.shape(), &[4, 3]);
        assert_eq!(view.num_elements(), 12);
        // The view borrows the full underlying buffer.
        assert_eq!(view.data.len(), tensor.data.len());
    }

    #[test]
    fn test_storage_efficiency() {
        let int8 = QuantizedTensor::new(vec![10], QuantizationParams::int8_symmetric(), cpu());
        // 1 byte per element vs 4 bytes for FP32.
        assert_eq!(int8.storage_efficiency(), 0.25);

        let int4 = QuantizedTensor::new(vec![10], QuantizationParams::int4_symmetric(), cpu());
        // Half a byte per element vs 4 bytes for FP32.
        assert_eq!(int4.storage_efficiency(), 0.125);
    }

    #[test]
    fn test_compression_ratio() {
        let int8 = QuantizedTensor::new(vec![10], QuantizationParams::int8_symmetric(), cpu());
        assert_eq!(int8.compression_ratio(), 4.0); // 4x smaller than FP32

        let int4 = QuantizedTensor::new(vec![10], QuantizationParams::int4_symmetric(), cpu());
        assert_eq!(int4.compression_ratio(), 8.0); // 8x smaller than FP32
    }

    #[test]
    fn test_data_slice() {
        let tensor = QuantizedTensor::new(vec![4], QuantizationParams::int8_symmetric(), cpu());

        let middle = tensor.data_slice(1, 2).unwrap();
        assert_eq!(middle.len(), 2);

        // Requesting bytes past the end of the buffer must fail.
        assert!(tensor.data_slice(3, 3).is_err());
    }

    #[test]
    fn test_tensor_validation() {
        let good = QuantizedTensor::new(vec![2, 3], QuantizationParams::default(), cpu());
        assert!(good.validate().is_ok());

        // Shrinking the buffer breaks the data-size invariant.
        let mut bad = good.clone();
        bad.data.truncate(1);
        assert!(bad.validate().is_err());
    }

    #[test]
    fn test_tensor_properties() {
        let tensor = QuantizedTensor::new(vec![2, 3, 4], QuantizationParams::default(), cpu());

        assert_eq!(tensor.ndim(), 3);
        for (dim, expected) in [(0usize, 2usize), (1, 3), (2, 4)] {
            assert_eq!(tensor.size(dim).unwrap(), expected);
        }
        // Index past the last dimension is rejected.
        assert!(tensor.size(3).is_err());

        assert!(!tensor.is_empty());
    }

    #[test]
    fn test_empty_tensor() {
        let tensor = QuantizedTensor::new(vec![0], QuantizationParams::default(), cpu());

        // A zero-sized dimension fails validation.
        assert!(tensor.validate().is_err());
    }
}