/// Memory-layout convention for multi-dimensional data.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LayoutOrder {
    /// Last index varies fastest (C-style layout).
    RowMajor,
    /// First index varies fastest (Fortran-style layout).
    ColumnMajor,
}
9
/// Dense n-dimensional array: a shape plus a flat element buffer.
///
/// Invariant (enforced by `Tensor::new`, not by the type): `data.len()`
/// equals the product of `shape`.
#[derive(Debug, Clone, PartialEq)]
pub struct Tensor<T> {
    /// Extent of each dimension.
    pub shape: Vec<usize>,
    /// Elements in a single flat buffer.
    pub data: Vec<T>,
}
39
/// One-dimensional array: just a flat element buffer.
#[derive(Debug, Clone, PartialEq)]
pub struct Vector<T> {
    /// The elements, in order.
    pub data: Vec<T>,
}
59
/// N-dimensional array with an explicit per-dimension stride, allowing
/// non-contiguous views over a flat buffer.
///
/// Invariant (enforced by `StridedTensor::new`): `shape.len() == stride.len()`.
#[derive(Debug, Clone, PartialEq)]
pub struct StridedTensor<T> {
    /// Extent of each dimension.
    pub shape: Vec<usize>,
    /// Element step per dimension (in elements, not bytes).
    pub stride: Vec<usize>,
    /// Backing element buffer.
    pub data: Vec<T>,
}
91
92impl<T> Tensor<T> {
93 pub fn new(shape: Vec<usize>, data: Vec<T>) -> Self {
95 let expected_len: usize = shape.iter().product();
96 assert_eq!(
97 data.len(),
98 expected_len,
99 "Data length {} doesn't match shape {:?} (expected {})",
100 data.len(),
101 shape,
102 expected_len
103 );
104 Tensor { shape, data }
105 }
106
107 pub fn ndim(&self) -> usize {
109 self.shape.len()
110 }
111
112 pub fn len(&self) -> usize {
114 self.shape.iter().product()
115 }
116
117 pub fn is_empty(&self) -> bool {
119 self.len() == 0
120 }
121}
122
123impl<T> StridedTensor<T> {
124 pub fn new(shape: Vec<usize>, stride: Vec<usize>, data: Vec<T>) -> Self {
126 assert_eq!(
127 shape.len(),
128 stride.len(),
129 "Shape and stride must have same number of dimensions"
130 );
131 StridedTensor {
132 shape,
133 stride,
134 data,
135 }
136 }
137
138 pub fn ndim(&self) -> usize {
140 self.shape.len()
141 }
142
143 pub fn len(&self) -> usize {
145 self.shape.iter().product()
146 }
147
148 pub fn is_empty(&self) -> bool {
150 self.len() == 0
151 }
152
153 pub fn is_contiguous(&self) -> bool {
155 let ndim = self.ndim();
156 let mut expected_stride = 1;
157 for i in (0..ndim).rev() {
158 if self.stride[i] != expected_stride {
159 return false;
160 }
161 expected_stride *= self.shape[i];
162 }
163 true
164 }
165}
166
/// Tensor whose unsigned-integer samples are stored `bit_depth` bits each,
/// MSB-first, tightly packed into a byte buffer (last byte zero-padded).
#[derive(Debug, Clone, PartialEq)]
pub struct BitPackedTensor {
    /// Bits per sample (1-128; 0 is reserved).
    pub bit_depth: u8,
    /// Extent of each dimension.
    pub shape: Vec<usize>,
    /// Packed sample bits.
    pub data: Vec<u8>,
}
203
204pub trait PackableUnsigned: Copy {
209 fn pack_samples(bit_depth: u8, shape: Vec<usize>, samples: &[Self]) -> BitPackedTensor;
210}
211
/// Generates an inherent `BitPackedTensor` packer for one sample type.
/// `$t` is the public sample type; `$work_t` is the (at least as wide)
/// integer used internally for bit extraction.
macro_rules! impl_bitpack {
    ($fn_name:ident, $t:ty, $work_t:ty) => {
        impl BitPackedTensor {
            #[doc = concat!("Pack ", stringify!($t), " samples into bitpacked tensor\n\n")]
            #[doc = "# Arguments\n"]
            #[doc = "* `bit_depth` - Bits per sample (1-128 supported, 0 reserved for future 256-bit)\n"]
            #[doc = "* `shape` - Tensor dimensions\n"]
            #[doc = "* `samples` - Sample values (only low `bit_depth` bits are packed, high bits ignored)\n\n"]
            #[doc = "# Panics\n"]
            #[doc = "* If bit_depth exceeds the type's bit width\n"]
            #[doc = "* If bit_depth > 128 (256-bit not yet supported)\n"]
            #[doc = "* If samples.len() doesn't match shape product\n"]
            pub fn $fn_name(bit_depth: u8, shape: Vec<usize>, samples: &[$t]) -> Self {
                let element_count: usize = shape.iter().product();
                assert_eq!(
                    samples.len(),
                    element_count,
                    "Sample count {} doesn't match shape {:?} (expected {})",
                    samples.len(),
                    shape,
                    element_count
                );

                // bit_depth == 0 is reserved as a marker for future 256-bit samples.
                if bit_depth == 0 {
                    panic!("bit_depth=0 (256-bit) not yet supported - use 1-128");
                }
                let width = bit_depth as usize;

                if width > 128 {
                    panic!("bit_depth > 128 not yet supported (waiting for native u256 support)");
                }

                if width > <$t>::BITS as usize {
                    panic!(
                        "Cannot pack {}-bit values into {}-bit type {}",
                        width,
                        <$t>::BITS,
                        std::any::type_name::<$t>()
                    );
                }

                // Pack one bit at a time, MSB-first within each sample; bytes
                // fill from bit 7 down. Final partial byte stays zero-padded.
                let mut packed = vec![0u8; (element_count * width + 7) / 8];
                let mut cursor = 0usize;
                for &sample in samples {
                    let value = sample as $work_t;
                    for shift in (0..width).rev() {
                        // The buffer starts zeroed, so only set bits need a write.
                        if (value >> shift) & 1 == 1 {
                            packed[cursor / 8] |= 1u8 << (7 - cursor % 8);
                        }
                        cursor += 1;
                    }
                }

                BitPackedTensor {
                    bit_depth,
                    shape,
                    data: packed,
                }
            }
        }
    };
}
285
// Instantiate the concrete packers. The third argument is the working type
// used for bit extraction: u64 covers every sample type up to 64 bits, and
// u128 is needed only for u128 samples.
// NOTE(review): pack_usize assumes usize is at most 64 bits wide — true on
// all current mainstream targets, but confirm if exotic targets are in scope.
impl_bitpack!(pack_u8, u8, u64);
impl_bitpack!(pack_u16, u16, u64);
impl_bitpack!(pack_u32, u32, u64);
impl_bitpack!(pack_u64, u64, u64);
impl_bitpack!(pack_u128, u128, u128);
impl_bitpack!(pack_usize, usize, u64);
295
296impl PackableUnsigned for u8 {
298 fn pack_samples(bit_depth: u8, shape: Vec<usize>, samples: &[Self]) -> BitPackedTensor {
299 BitPackedTensor::pack_u8(bit_depth, shape, samples)
300 }
301}
302
303impl PackableUnsigned for u16 {
304 fn pack_samples(bit_depth: u8, shape: Vec<usize>, samples: &[Self]) -> BitPackedTensor {
305 BitPackedTensor::pack_u16(bit_depth, shape, samples)
306 }
307}
308
309impl PackableUnsigned for u32 {
310 fn pack_samples(bit_depth: u8, shape: Vec<usize>, samples: &[Self]) -> BitPackedTensor {
311 BitPackedTensor::pack_u32(bit_depth, shape, samples)
312 }
313}
314
315impl PackableUnsigned for u64 {
316 fn pack_samples(bit_depth: u8, shape: Vec<usize>, samples: &[Self]) -> BitPackedTensor {
317 BitPackedTensor::pack_u64(bit_depth, shape, samples)
318 }
319}
320
321impl PackableUnsigned for u128 {
322 fn pack_samples(bit_depth: u8, shape: Vec<usize>, samples: &[Self]) -> BitPackedTensor {
323 BitPackedTensor::pack_u128(bit_depth, shape, samples)
324 }
325}
326
327impl PackableUnsigned for usize {
328 fn pack_samples(bit_depth: u8, shape: Vec<usize>, samples: &[Self]) -> BitPackedTensor {
329 BitPackedTensor::pack_usize(bit_depth, shape, samples)
330 }
331}
332
/// Samples recovered from a `BitPackedTensor`, held in the narrowest
/// standard unsigned type that fits the tensor's bit depth.
#[derive(Debug, Clone, PartialEq)]
pub enum UnpackedSamples {
    /// Bit depths 1-8.
    U8(Vec<u8>),
    /// Bit depths 9-16.
    U16(Vec<u16>),
    /// Bit depths 17-32.
    U32(Vec<u32>),
    /// Bit depths 33-64.
    U64(Vec<u64>),
    /// Bit depths 65-128.
    U128(Vec<u128>),
}
345
346impl UnpackedSamples {
347 pub fn into_u64(self) -> Vec<u64> {
349 match self {
350 UnpackedSamples::U8(v) => v.into_iter().map(|x| x as u64).collect(),
351 UnpackedSamples::U16(v) => v.into_iter().map(|x| x as u64).collect(),
352 UnpackedSamples::U32(v) => v.into_iter().map(|x| x as u64).collect(),
353 UnpackedSamples::U64(v) => v,
354 UnpackedSamples::U128(_) => {
355 panic!("Cannot convert >64 bit samples to u64 (would truncate)")
356 }
357 }
358 }
359
360 pub fn into_u128(self) -> Vec<u128> {
362 match self {
363 UnpackedSamples::U8(v) => v.into_iter().map(|x| x as u128).collect(),
364 UnpackedSamples::U16(v) => v.into_iter().map(|x| x as u128).collect(),
365 UnpackedSamples::U32(v) => v.into_iter().map(|x| x as u128).collect(),
366 UnpackedSamples::U64(v) => v.into_iter().map(|x| x as u128).collect(),
367 UnpackedSamples::U128(v) => v,
368 }
369 }
370
371 pub fn len(&self) -> usize {
373 match self {
374 UnpackedSamples::U8(v) => v.len(),
375 UnpackedSamples::U16(v) => v.len(),
376 UnpackedSamples::U32(v) => v.len(),
377 UnpackedSamples::U64(v) => v.len(),
378 UnpackedSamples::U128(v) => v.len(),
379 }
380 }
381
382 pub fn is_empty(&self) -> bool {
384 self.len() == 0
385 }
386}
387
388impl BitPackedTensor {
389 pub fn pack<T: PackableUnsigned>(bit_depth: u8, shape: Vec<usize>, samples: &[T]) -> Self {
401 T::pack_samples(bit_depth, shape, samples)
402 }
403
404 pub fn unpack(&self) -> UnpackedSamples {
422 let bits = self.bit_depth as usize;
423 match bits {
424 1..=8 => UnpackedSamples::U8(self.unpack_to_u8()),
425 9..=16 => UnpackedSamples::U16(self.unpack_to_u16()),
426 17..=32 => UnpackedSamples::U32(self.unpack_to_u32()),
427 33..=64 => UnpackedSamples::U64(self.unpack_to_u64()),
428 65..=128 => UnpackedSamples::U128(self.unpack_to_u128()),
429 _ => panic!("bit_depth {} not supported (max 128)", self.bit_depth),
430 }
431 }
432
433 pub fn unpack_u8(&self) -> Vec<u8> {
438 if self.bit_depth > 8 {
439 panic!(
440 "Cannot unpack {}-bit data into u8 (would truncate)",
441 self.bit_depth
442 );
443 }
444 self.unpack_to_u8()
445 }
446
447 pub fn unpack_u16(&self) -> Vec<u16> {
452 if self.bit_depth > 16 {
453 panic!(
454 "Cannot unpack {}-bit data into u16 (would truncate)",
455 self.bit_depth
456 );
457 }
458 self.unpack_to_u16()
459 }
460
461 pub fn unpack_u32(&self) -> Vec<u32> {
466 if self.bit_depth > 32 {
467 panic!(
468 "Cannot unpack {}-bit data into u32 (would truncate)",
469 self.bit_depth
470 );
471 }
472 self.unpack_to_u32()
473 }
474
475 pub fn unpack_u64(&self) -> Vec<u64> {
480 if self.bit_depth > 64 {
481 panic!(
482 "Cannot unpack {}-bit data into u64 (would truncate)",
483 self.bit_depth
484 );
485 }
486 self.unpack_to_u64()
487 }
488
489 pub fn unpack_u128(&self) -> Vec<u128> {
491 self.unpack_to_u128()
492 }
493
494 fn unpack_to_u8(&self) -> Vec<u8> {
496 let total_elements: usize = self.shape.iter().product();
497 let bits_per_sample = self.bit_depth as usize;
498 let mut samples = Vec::with_capacity(total_elements);
499
500 let mut bit_offset = 0;
501 for _ in 0..total_elements {
502 let mut sample = 0u8;
503 for _ in 0..bits_per_sample {
504 let byte_idx = bit_offset / 8;
505 let bit_pos = 7 - (bit_offset % 8);
506 let bit = (self.data[byte_idx] >> bit_pos) & 1;
507 sample = (sample << 1) | bit;
508 bit_offset += 1;
509 }
510 samples.push(sample);
511 }
512 samples
513 }
514
515 fn unpack_to_u16(&self) -> Vec<u16> {
516 let total_elements: usize = self.shape.iter().product();
517 let bits_per_sample = self.bit_depth as usize;
518 let mut samples = Vec::with_capacity(total_elements);
519
520 let mut bit_offset = 0;
521 for _ in 0..total_elements {
522 let mut sample = 0u16;
523 for _ in 0..bits_per_sample {
524 let byte_idx = bit_offset / 8;
525 let bit_pos = 7 - (bit_offset % 8);
526 let bit = (self.data[byte_idx] >> bit_pos) & 1;
527 sample = (sample << 1) | (bit as u16);
528 bit_offset += 1;
529 }
530 samples.push(sample);
531 }
532 samples
533 }
534
535 fn unpack_to_u32(&self) -> Vec<u32> {
536 let total_elements: usize = self.shape.iter().product();
537 let bits_per_sample = self.bit_depth as usize;
538 let mut samples = Vec::with_capacity(total_elements);
539
540 let mut bit_offset = 0;
541 for _ in 0..total_elements {
542 let mut sample = 0u32;
543 for _ in 0..bits_per_sample {
544 let byte_idx = bit_offset / 8;
545 let bit_pos = 7 - (bit_offset % 8);
546 let bit = (self.data[byte_idx] >> bit_pos) & 1;
547 sample = (sample << 1) | (bit as u32);
548 bit_offset += 1;
549 }
550 samples.push(sample);
551 }
552 samples
553 }
554
555 fn unpack_to_u64(&self) -> Vec<u64> {
556 let total_elements: usize = self.shape.iter().product();
557 let bits_per_sample = self.bit_depth as usize;
558 let mut samples = Vec::with_capacity(total_elements);
559
560 let mut bit_offset = 0;
561 for _ in 0..total_elements {
562 let mut sample = 0u64;
563 for _ in 0..bits_per_sample {
564 let byte_idx = bit_offset / 8;
565 let bit_pos = 7 - (bit_offset % 8);
566 let bit = (self.data[byte_idx] >> bit_pos) & 1;
567 sample = (sample << 1) | (bit as u64);
568 bit_offset += 1;
569 }
570 samples.push(sample);
571 }
572 samples
573 }
574
575 fn unpack_to_u128(&self) -> Vec<u128> {
576 let total_elements: usize = self.shape.iter().product();
577 let bits_per_sample = self.bit_depth as usize;
578 let mut samples = Vec::with_capacity(total_elements);
579
580 let mut bit_offset = 0;
581 for _ in 0..total_elements {
582 let mut sample = 0u128;
583 for _ in 0..bits_per_sample {
584 let byte_idx = bit_offset / 8;
585 let bit_pos = 7 - (bit_offset % 8);
586 let bit = (self.data[byte_idx] >> bit_pos) & 1;
587 sample = (sample << 1) | (bit as u128);
588 bit_offset += 1;
589 }
590 samples.push(sample);
591 }
592 samples
593 }
594
595 pub fn ndim(&self) -> usize {
597 self.shape.len()
598 }
599
600 pub fn len(&self) -> usize {
602 self.shape.iter().product()
603 }
604
605 pub fn is_empty(&self) -> bool {
607 self.len() == 0
608 }
609}