//! Tensor types for VSF: contiguous, strided, and bitpacked

/// Layout order for tensor data
///
/// NOTE(review): not referenced by the types in this file — `Tensor` is
/// implicitly row-major and `StridedTensor` encodes layout via its strides.
/// Presumably consumed elsewhere in the crate; confirm before removing.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LayoutOrder {
    RowMajor,    // C-style (default for Tensor)
    ColumnMajor, // Fortran-style
}
9
/// Contiguous tensor with row-major layout (no stride stored)
///
/// Binary format: `[t][dim_count][type][shape...][data...]`
/// - Always row-major, contiguous in memory
/// - No stride information stored (implicitly computed)
/// - 95% of use cases (normal arrays, images, ML tensors)
/// - Dynamic dimensionality (1D, 2D, 3D, 4D, or more)
///
/// # Examples
/// ```
/// use vsf::Tensor;
///
/// // 2D image: 1920×1080 u16 pixels
/// let img = Tensor::new(
///     vec![1920, 1080],
///     vec![0u16; 1920 * 1080]
/// );
///
/// // 3D tensor: 100×200×3 RGB
/// let rgb = Tensor::new(
///     vec![100, 200, 3],
///     vec![0u8; 100 * 200 * 3]
/// );
/// ```
#[derive(Debug, Clone, PartialEq)]
pub struct Tensor<T> {
    /// Dimension extents, outermost dimension first (row-major order).
    pub shape: Vec<usize>,
    /// Flat element storage. `Tensor::new` asserts its length equals the
    /// product of `shape`; both fields are public, so that invariant is not
    /// enforced after construction.
    pub data: Vec<T>,
}
39
/// 1D tensor (vector)
///
/// Binary format: `[t][n][count][type][data...]`
///
/// Where:
/// - `t` = tensor/vector marker
/// - `n` = count marker (indicates 1D vector vs multi-dimensional tensor)
/// - `count` = number of elements (encoded number)
/// - `type` = element type marker (e.g., `u3`, `i4`, `f5`, `j6`)
/// - `data` = raw element bytes
///
/// Examples:
/// - Vector<u8> with 10 elements: `[t][n][0x0A][u][3][bytes...]`
/// - IPv4 address: `[t][n][0x04][u][3][4 bytes]`
/// - IPv6 address: `[t][n][0x10][u][3][16 bytes]`
#[derive(Debug, Clone, PartialEq)]
pub struct Vector<T> {
    /// Element storage; the element count is simply `data.len()`.
    pub data: Vec<T>,
}
59
/// Tensor with explicit stride for non-contiguous layouts
///
/// Binary format: `[q][dim_count][type][shape...][stride...][data...]`
/// - Supports arbitrary memory layouts (column-major, slices, views)
/// - Stores explicit stride information
/// - Use for: slices, transposed views, column-major matrices
///
/// # Examples
/// ```
/// use vsf::StridedTensor;
///
/// // Column-major 1000×1000 matrix
/// let mat = StridedTensor::new(
///     vec![1000, 1000],
///     vec![1, 1000],  // Column-major stride
///     vec![0.0f64; 1_000_000]
/// );
///
/// // 2D slice with custom stride
/// let slice = StridedTensor::new(
///     vec![100, 50],
///     vec![200, 2],  // Every other element
///     vec![0u8; 20_000]  // must cover the last offset: 99*200 + 49*2 = 19898
/// );
/// ```
#[derive(Debug, Clone, PartialEq)]
pub struct StridedTensor<T> {
    /// Dimension extents.
    pub shape: Vec<usize>,
    /// Per-dimension step in *elements* (not bytes) — `is_contiguous`
    /// compares these directly against element counts.
    pub stride: Vec<usize>,
    /// Backing element storage. NOTE(review): `new` does not validate that
    /// `data` is large enough for the given shape/stride; out-of-range
    /// strides only surface when elements are addressed — confirm intent.
    pub data: Vec<T>,
}
91
92impl<T> Tensor<T> {
93    /// Create a new contiguous tensor with given shape and data
94    pub fn new(shape: Vec<usize>, data: Vec<T>) -> Self {
95        let expected_len: usize = shape.iter().product();
96        assert_eq!(
97            data.len(),
98            expected_len,
99            "Data length {} doesn't match shape {:?} (expected {})",
100            data.len(),
101            shape,
102            expected_len
103        );
104        Tensor { shape, data }
105    }
106
107    /// Get number of dimensions
108    pub fn ndim(&self) -> usize {
109        self.shape.len()
110    }
111
112    /// Calculate total number of elements
113    pub fn len(&self) -> usize {
114        self.shape.iter().product()
115    }
116
117    /// Check if tensor is empty
118    pub fn is_empty(&self) -> bool {
119        self.len() == 0
120    }
121}
122
123impl<T> StridedTensor<T> {
124    /// Create a new strided tensor with given shape, stride, and data
125    pub fn new(shape: Vec<usize>, stride: Vec<usize>, data: Vec<T>) -> Self {
126        assert_eq!(
127            shape.len(),
128            stride.len(),
129            "Shape and stride must have same number of dimensions"
130        );
131        StridedTensor {
132            shape,
133            stride,
134            data,
135        }
136    }
137
138    /// Get number of dimensions
139    pub fn ndim(&self) -> usize {
140        self.shape.len()
141    }
142
143    /// Calculate total number of elements
144    pub fn len(&self) -> usize {
145        self.shape.iter().product()
146    }
147
148    /// Check if tensor is empty
149    pub fn is_empty(&self) -> bool {
150        self.len() == 0
151    }
152
153    /// Check if tensor is contiguous (row-major)
154    pub fn is_contiguous(&self) -> bool {
155        let ndim = self.ndim();
156        let mut expected_stride = 1;
157        for i in (0..ndim).rev() {
158            if self.stride[i] != expected_stride {
159                return false;
160            }
161            expected_stride *= self.shape[i];
162        }
163        true
164    }
165}
166
/// Bitpacked tensor for arbitrary bit depths (1-128 bits per sample)
///
/// Binary format: `[p][dim_count][bit_depth][shape...][packed_data...]`
/// - bit_depth stored as u8: 0x01-0x80 (1-128 bits)
/// - Samples packed MSB-first, big-endian across byte boundaries
/// - Only low `bit_depth` bits of input values are packed; high bits ignored
/// - Final byte zero-padded to align to 8-bit boundary
/// - Row-major storage (like Tensor<T>)
///
/// # Use Cases
/// - Camera RAW data (10-bit, 12-bit, 14-bit sensors)
/// - Compressed representations
/// - Scientific instruments with non-standard bit depths
///
/// # Examples
/// ```
/// use vsf::BitPackedTensor;
///
/// // Option 1: Generic "just works" (most common)
/// let samples: Vec<u16> = vec![2048; 1920 * 1080];  // 12-bit values (0-4095)
/// let tensor = BitPackedTensor::pack(12, vec![1920, 1080], &samples);
/// let unpacked = tensor.unpack().into_u64();  // Auto-sized, then promoted
///
/// // Option 2: Explicit type control (when you need guarantees)
/// let tensor = BitPackedTensor::pack_u16(12, vec![1920, 1080], &samples);
/// let unpacked: Vec<u16> = tensor.unpack_u16();  // Explicit, no enum
/// ```
#[derive(Debug, Clone, PartialEq)]
pub struct BitPackedTensor {
    /// Bits per sample (1-128): 0x01-0x80.
    /// 0 is reserved for a future 256-bit width and rejected by pack_*.
    pub bit_depth: u8,
    /// Tensor dimensions (row-major)
    pub shape: Vec<usize>,
    /// Packed bytes: (total_elements * bit_depth + 7) / 8 bytes
    pub data: Vec<u8>,
}
203
/// Trait for types that can be packed into a BitPackedTensor
///
/// This exists solely to enable the generic pack() convenience method.
/// For explicit type handling, use pack_u8(), pack_u16(), etc. directly.
pub trait PackableUnsigned: Copy {
    /// Pack `samples` at `bit_depth` bits each; each implementation forwards
    /// to the matching inherent `BitPackedTensor::pack_u*` constructor.
    fn pack_samples(bit_depth: u8, shape: Vec<usize>, samples: &[Self]) -> BitPackedTensor;
}
211
// Macro to generate explicit pack_u* methods on BitPackedTensor.
// Each invocation emits a separate inherent `impl BitPackedTensor` block
// (Rust permits multiple inherent impls). `$work_t` is the wider integer the
// sample is cast to before bit extraction.
macro_rules! impl_bitpack {
    ($fn_name:ident, $t:ty, $work_t:ty) => {
        impl BitPackedTensor {
            #[doc = concat!("Pack ", stringify!($t), " samples into bitpacked tensor\n\n")]
            #[doc = "# Arguments\n"]
            #[doc = "* `bit_depth` - Bits per sample (1-128 supported, 0 reserved for future 256-bit)\n"]
            #[doc = "* `shape` - Tensor dimensions\n"]
            #[doc = "* `samples` - Sample values (only low `bit_depth` bits are packed, high bits ignored)\n\n"]
            #[doc = "# Panics\n"]
            #[doc = "* If bit_depth exceeds the type's bit width\n"]
            #[doc = "* If bit_depth > 128 (256-bit not yet supported)\n"]
            #[doc = "* If samples.len() doesn't match shape product\n"]
            pub fn $fn_name(bit_depth: u8, shape: Vec<usize>, samples: &[$t]) -> Self {
                let total_elements: usize = shape.iter().product();
                assert_eq!(
                    samples.len(),
                    total_elements,
                    "Sample count {} doesn't match shape {:?} (expected {})",
                    samples.len(),
                    shape,
                    total_elements
                );

                // bit_depth 0 is the reserved 256-bit encoding; reject it.
                let bits_per_sample = if bit_depth == 0 {
                    panic!("bit_depth=0 (256-bit) not yet supported - use 1-128");
                } else {
                    bit_depth as usize
                };

                // Reject >128 bit depths until hardware support
                if bits_per_sample > 128 {
                    panic!("bit_depth > 128 not yet supported (waiting for native u256 support)");
                }

                // Type-level check: can this type hold bit_depth bits?
                if bits_per_sample > <$t>::BITS as usize {
                    panic!(
                        "Cannot pack {}-bit values into {}-bit type {}",
                        bits_per_sample,
                        <$t>::BITS,
                        std::any::type_name::<$t>()
                    );
                }

                // Calculate total bits and bytes needed. The buffer starts
                // zeroed, which also provides the final byte's zero padding.
                let total_bits = total_elements * bits_per_sample;
                let byte_count = (total_bits + 7) / 8;
                let mut data = vec![0u8; byte_count];

                // Pack samples MSB-first, big-endian across byte boundaries.
                // Only low bit_depth bits are read; high bits are ignored
                // because bit_idx never reaches them.
                let mut bit_offset = 0;
                for &sample in samples {
                    // Widen once per sample. NOTE(review): assumes $work_t is
                    // at least as wide as $t — true for every invocation below.
                    let value = sample as $work_t;
                    for bit_idx in (0..bits_per_sample).rev() {
                        let bit = if (value >> bit_idx) & 1 == 1 { 1u8 } else { 0u8 };
                        let byte_idx = bit_offset / 8;
                        let bit_pos = 7 - (bit_offset % 8);
                        data[byte_idx] |= bit << bit_pos;
                        bit_offset += 1;
                    }
                }

                BitPackedTensor {
                    bit_depth,
                    shape,
                    data,
                }
            }
        }
    };
}
285
// Generate pack_* methods for each unsigned type
// Use u64 as work type for u8/u16/u32/u64 (native 64-bit ops)
// Use u128 for u128 (unavoidably emulated until hardware u256 support)
// NOTE(review): pack_usize casts usize to u64 — lossless on 16/32/64-bit
// targets; would truncate on a hypothetical wider-usize platform.
impl_bitpack!(pack_u8, u8, u64);
impl_bitpack!(pack_u16, u16, u64);
impl_bitpack!(pack_u32, u32, u64);
impl_bitpack!(pack_u64, u64, u64);
impl_bitpack!(pack_u128, u128, u128);
impl_bitpack!(pack_usize, usize, u64);
295
296// Trait implementations that delegate to explicit pack_u* methods
297impl PackableUnsigned for u8 {
298    fn pack_samples(bit_depth: u8, shape: Vec<usize>, samples: &[Self]) -> BitPackedTensor {
299        BitPackedTensor::pack_u8(bit_depth, shape, samples)
300    }
301}
302
303impl PackableUnsigned for u16 {
304    fn pack_samples(bit_depth: u8, shape: Vec<usize>, samples: &[Self]) -> BitPackedTensor {
305        BitPackedTensor::pack_u16(bit_depth, shape, samples)
306    }
307}
308
309impl PackableUnsigned for u32 {
310    fn pack_samples(bit_depth: u8, shape: Vec<usize>, samples: &[Self]) -> BitPackedTensor {
311        BitPackedTensor::pack_u32(bit_depth, shape, samples)
312    }
313}
314
315impl PackableUnsigned for u64 {
316    fn pack_samples(bit_depth: u8, shape: Vec<usize>, samples: &[Self]) -> BitPackedTensor {
317        BitPackedTensor::pack_u64(bit_depth, shape, samples)
318    }
319}
320
321impl PackableUnsigned for u128 {
322    fn pack_samples(bit_depth: u8, shape: Vec<usize>, samples: &[Self]) -> BitPackedTensor {
323        BitPackedTensor::pack_u128(bit_depth, shape, samples)
324    }
325}
326
327impl PackableUnsigned for usize {
328    fn pack_samples(bit_depth: u8, shape: Vec<usize>, samples: &[Self]) -> BitPackedTensor {
329        BitPackedTensor::pack_usize(bit_depth, shape, samples)
330    }
331}
332
/// Unpacked samples from a BitPackedTensor
///
/// The variant carries the smallest unsigned type wide enough for the
/// tensor's bit depth. For explicit type control, use unpack_u8(),
/// unpack_u16(), etc. directly.
#[derive(Debug, Clone, PartialEq)]
pub enum UnpackedSamples {
    U8(Vec<u8>),     // 1-8 bit depths
    U16(Vec<u16>),   // 9-16 bit depths
    U32(Vec<u32>),   // 17-32 bit depths
    U64(Vec<u64>),   // 33-64 bit depths
    U128(Vec<u128>), // 65-128 bit depths
}

impl UnpackedSamples {
    /// Widen every sample to u64. Panics on the U128 variant, since those
    /// values may not fit.
    pub fn into_u64(self) -> Vec<u64> {
        use UnpackedSamples::*;
        match self {
            U8(v) => v.into_iter().map(u64::from).collect(),
            U16(v) => v.into_iter().map(u64::from).collect(),
            U32(v) => v.into_iter().map(u64::from).collect(),
            U64(v) => v,
            U128(_) => panic!("Cannot convert >64 bit samples to u64 (would truncate)"),
        }
    }

    /// Widen every sample to u128; always succeeds for any variant.
    pub fn into_u128(self) -> Vec<u128> {
        use UnpackedSamples::*;
        match self {
            U8(v) => v.into_iter().map(u128::from).collect(),
            U16(v) => v.into_iter().map(u128::from).collect(),
            U32(v) => v.into_iter().map(u128::from).collect(),
            U64(v) => v.into_iter().map(u128::from).collect(),
            U128(v) => v,
        }
    }

    /// Number of samples held by whichever variant this is.
    pub fn len(&self) -> usize {
        use UnpackedSamples::*;
        match self {
            U8(v) => v.len(),
            U16(v) => v.len(),
            U32(v) => v.len(),
            U64(v) => v.len(),
            U128(v) => v.len(),
        }
    }

    /// True when the sample list is empty.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
387
388impl BitPackedTensor {
389    /// Pack samples with automatic type dispatch (convenience wrapper)
390    ///
391    /// Calls the appropriate pack_u* method based on the input type.
392    /// Use explicit pack_u8(), pack_u16() etc. methods if you need compile-time guarantees.
393    ///
394    /// # Examples
395    /// ```ignore
396    /// // Generic - works with any unsigned type
397    /// let samples: Vec<u16> = vec![2048; 1920 * 1080];
398    /// let tensor = BitPackedTensor::pack(12, vec![1920, 1080], &samples);
399    /// ```
400    pub fn pack<T: PackableUnsigned>(bit_depth: u8, shape: Vec<usize>, samples: &[T]) -> Self {
401        T::pack_samples(bit_depth, shape, samples)
402    }
403
404    /// Unpack to the minimal type that fits bit_depth
405    ///
406    /// Returns:
407    /// - Vec<u8> for 1-8 bit depths
408    /// - Vec<u16> for 9-16 bit depths
409    /// - Vec<u32> for 17-32 bit depths
410    /// - Vec<u64> for 33-64 bit depths
411    /// - Vec<u128> for 65-128 bit depths
412    ///
413    /// Use explicit unpack_u8(), unpack_u16() etc. if you need a specific type.
414    ///
415    /// # Examples
416    /// ```ignore
417    /// let tensor = BitPackedTensor::pack(12, vec![100, 100], &samples);
418    /// // Auto-sized, then promoted to u64
419    /// let unpacked = tensor.unpack().into_u64();
420    /// ```
421    pub fn unpack(&self) -> UnpackedSamples {
422        let bits = self.bit_depth as usize;
423        match bits {
424            1..=8 => UnpackedSamples::U8(self.unpack_to_u8()),
425            9..=16 => UnpackedSamples::U16(self.unpack_to_u16()),
426            17..=32 => UnpackedSamples::U32(self.unpack_to_u32()),
427            33..=64 => UnpackedSamples::U64(self.unpack_to_u64()),
428            65..=128 => UnpackedSamples::U128(self.unpack_to_u128()),
429            _ => panic!("bit_depth {} not supported (max 128)", self.bit_depth),
430        }
431    }
432
433    /// Unpack to u8 samples
434    ///
435    /// # Panics
436    /// Panics if bit_depth > 8 (data wouldn't fit)
437    pub fn unpack_u8(&self) -> Vec<u8> {
438        if self.bit_depth > 8 {
439            panic!(
440                "Cannot unpack {}-bit data into u8 (would truncate)",
441                self.bit_depth
442            );
443        }
444        self.unpack_to_u8()
445    }
446
447    /// Unpack to u16 samples
448    ///
449    /// # Panics
450    /// Panics if bit_depth > 16 (data wouldn't fit)
451    pub fn unpack_u16(&self) -> Vec<u16> {
452        if self.bit_depth > 16 {
453            panic!(
454                "Cannot unpack {}-bit data into u16 (would truncate)",
455                self.bit_depth
456            );
457        }
458        self.unpack_to_u16()
459    }
460
461    /// Unpack to u32 samples
462    ///
463    /// # Panics
464    /// Panics if bit_depth > 32 (data wouldn't fit)
465    pub fn unpack_u32(&self) -> Vec<u32> {
466        if self.bit_depth > 32 {
467            panic!(
468                "Cannot unpack {}-bit data into u32 (would truncate)",
469                self.bit_depth
470            );
471        }
472        self.unpack_to_u32()
473    }
474
475    /// Unpack to u64 samples
476    ///
477    /// # Panics
478    /// Panics if bit_depth > 64 (data wouldn't fit)
479    pub fn unpack_u64(&self) -> Vec<u64> {
480        if self.bit_depth > 64 {
481            panic!(
482                "Cannot unpack {}-bit data into u64 (would truncate)",
483                self.bit_depth
484            );
485        }
486        self.unpack_to_u64()
487    }
488
489    /// Unpack to u128 samples (works for all current bit depths 1-128)
490    pub fn unpack_u128(&self) -> Vec<u128> {
491        self.unpack_to_u128()
492    }
493
494    // Private unpack helpers for each type
495    fn unpack_to_u8(&self) -> Vec<u8> {
496        let total_elements: usize = self.shape.iter().product();
497        let bits_per_sample = self.bit_depth as usize;
498        let mut samples = Vec::with_capacity(total_elements);
499
500        let mut bit_offset = 0;
501        for _ in 0..total_elements {
502            let mut sample = 0u8;
503            for _ in 0..bits_per_sample {
504                let byte_idx = bit_offset / 8;
505                let bit_pos = 7 - (bit_offset % 8);
506                let bit = (self.data[byte_idx] >> bit_pos) & 1;
507                sample = (sample << 1) | bit;
508                bit_offset += 1;
509            }
510            samples.push(sample);
511        }
512        samples
513    }
514
515    fn unpack_to_u16(&self) -> Vec<u16> {
516        let total_elements: usize = self.shape.iter().product();
517        let bits_per_sample = self.bit_depth as usize;
518        let mut samples = Vec::with_capacity(total_elements);
519
520        let mut bit_offset = 0;
521        for _ in 0..total_elements {
522            let mut sample = 0u16;
523            for _ in 0..bits_per_sample {
524                let byte_idx = bit_offset / 8;
525                let bit_pos = 7 - (bit_offset % 8);
526                let bit = (self.data[byte_idx] >> bit_pos) & 1;
527                sample = (sample << 1) | (bit as u16);
528                bit_offset += 1;
529            }
530            samples.push(sample);
531        }
532        samples
533    }
534
535    fn unpack_to_u32(&self) -> Vec<u32> {
536        let total_elements: usize = self.shape.iter().product();
537        let bits_per_sample = self.bit_depth as usize;
538        let mut samples = Vec::with_capacity(total_elements);
539
540        let mut bit_offset = 0;
541        for _ in 0..total_elements {
542            let mut sample = 0u32;
543            for _ in 0..bits_per_sample {
544                let byte_idx = bit_offset / 8;
545                let bit_pos = 7 - (bit_offset % 8);
546                let bit = (self.data[byte_idx] >> bit_pos) & 1;
547                sample = (sample << 1) | (bit as u32);
548                bit_offset += 1;
549            }
550            samples.push(sample);
551        }
552        samples
553    }
554
555    fn unpack_to_u64(&self) -> Vec<u64> {
556        let total_elements: usize = self.shape.iter().product();
557        let bits_per_sample = self.bit_depth as usize;
558        let mut samples = Vec::with_capacity(total_elements);
559
560        let mut bit_offset = 0;
561        for _ in 0..total_elements {
562            let mut sample = 0u64;
563            for _ in 0..bits_per_sample {
564                let byte_idx = bit_offset / 8;
565                let bit_pos = 7 - (bit_offset % 8);
566                let bit = (self.data[byte_idx] >> bit_pos) & 1;
567                sample = (sample << 1) | (bit as u64);
568                bit_offset += 1;
569            }
570            samples.push(sample);
571        }
572        samples
573    }
574
575    fn unpack_to_u128(&self) -> Vec<u128> {
576        let total_elements: usize = self.shape.iter().product();
577        let bits_per_sample = self.bit_depth as usize;
578        let mut samples = Vec::with_capacity(total_elements);
579
580        let mut bit_offset = 0;
581        for _ in 0..total_elements {
582            let mut sample = 0u128;
583            for _ in 0..bits_per_sample {
584                let byte_idx = bit_offset / 8;
585                let bit_pos = 7 - (bit_offset % 8);
586                let bit = (self.data[byte_idx] >> bit_pos) & 1;
587                sample = (sample << 1) | (bit as u128);
588                bit_offset += 1;
589            }
590            samples.push(sample);
591        }
592        samples
593    }
594
595    /// Get number of dimensions
596    pub fn ndim(&self) -> usize {
597        self.shape.len()
598    }
599
600    /// Calculate total number of elements
601    pub fn len(&self) -> usize {
602        self.shape.iter().product()
603    }
604
605    /// Check if tensor is empty
606    pub fn is_empty(&self) -> bool {
607        self.len() == 0
608    }
609}