1use crate::{DType, Result, Shape, TensorError};
25use std::collections::HashMap;
26use std::fmt;
27
28#[cfg(feature = "serde")]
29use serde::{Deserialize, Serialize};
30
/// Describes one named, typed field of a structured-array record.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct FieldDescriptor {
    /// Field name; `StructuredArray` uses it as the key for lookups by name.
    pub name: String,
    /// Element type stored in this field.
    pub dtype: DType,
    /// Optional byte width. `byte_size` consults it only for `DType::String`
    /// fields; every other dtype has a width fixed by its type.
    pub size: Option<usize>,
    /// Byte offset of the field inside a record. `FieldDescriptor::new` sets
    /// it to 0 and `StructuredArray` overwrites it when it lays out a record.
    pub offset: usize,
}
44
45impl FieldDescriptor {
46 pub fn new(name: impl Into<String>, dtype: DType, size: Option<usize>) -> Self {
48 Self {
49 name: name.into(),
50 dtype,
51 size,
52 offset: 0, }
54 }
55
56 pub fn byte_size(&self) -> usize {
58 match self.dtype {
59 DType::Float32 => 4,
60 DType::Float64 => 8,
61 DType::Int32 => 4,
62 DType::Int64 => 8,
63 DType::UInt32 => 4,
64 DType::UInt64 => 8,
65 DType::Int16 => 2,
66 DType::UInt16 => 2,
67 DType::Int8 => 1,
68 DType::UInt8 => 1,
69 DType::Bool => 1,
70 DType::String => self.size.unwrap_or(64), _ => 8, }
73 }
74}
75
/// A single dynamically typed value written to, or read from, a record field.
///
/// `PartialEq` is derived so values can be compared directly (for example in
/// `assert_eq!`). Float variants follow IEEE semantics, so a `NaN` payload
/// never compares equal to itself.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum FieldValue {
    /// 32-bit floating point value.
    Float32(f32),
    /// 64-bit floating point value.
    Float64(f64),
    /// Signed 32-bit integer.
    Int32(i32),
    /// Signed 64-bit integer.
    Int64(i64),
    /// Unsigned 32-bit integer.
    UInt32(u32),
    /// Unsigned 64-bit integer.
    UInt64(u64),
    /// Signed 16-bit integer.
    Int16(i16),
    /// Unsigned 16-bit integer.
    UInt16(u16),
    /// Signed 8-bit integer.
    Int8(i8),
    /// Unsigned 8-bit integer.
    UInt8(u8),
    /// Boolean flag.
    Bool(bool),
    /// Owned text.
    String(String),
    /// Raw byte payload; `dtype()` reports it as `DType::UInt8`.
    Bytes(Vec<u8>),
}
94
95impl FieldValue {
96 pub fn dtype(&self) -> DType {
98 match self {
99 FieldValue::Float32(_) => DType::Float32,
100 FieldValue::Float64(_) => DType::Float64,
101 FieldValue::Int32(_) => DType::Int32,
102 FieldValue::Int64(_) => DType::Int64,
103 FieldValue::UInt32(_) => DType::UInt32,
104 FieldValue::UInt64(_) => DType::UInt64,
105 FieldValue::Int16(_) => DType::Int16,
106 FieldValue::UInt16(_) => DType::UInt16,
107 FieldValue::Int8(_) => DType::Int8,
108 FieldValue::UInt8(_) => DType::UInt8,
109 FieldValue::Bool(_) => DType::Bool,
110 FieldValue::String(_) => DType::String,
111 FieldValue::Bytes(_) => DType::UInt8, }
113 }
114
115 pub fn to_bytes(&self, expected_size: usize) -> Vec<u8> {
117 match self {
118 FieldValue::Float32(v) => v.to_le_bytes().to_vec(),
119 FieldValue::Float64(v) => v.to_le_bytes().to_vec(),
120 FieldValue::Int32(v) => v.to_le_bytes().to_vec(),
121 FieldValue::Int64(v) => v.to_le_bytes().to_vec(),
122 FieldValue::UInt32(v) => v.to_le_bytes().to_vec(),
123 FieldValue::UInt64(v) => v.to_le_bytes().to_vec(),
124 FieldValue::Int16(v) => v.to_le_bytes().to_vec(),
125 FieldValue::UInt16(v) => v.to_le_bytes().to_vec(),
126 FieldValue::Int8(v) => vec![*v as u8],
127 FieldValue::UInt8(v) => vec![*v],
128 FieldValue::Bool(v) => vec![if *v { 1 } else { 0 }],
129 FieldValue::String(s) => {
130 let mut bytes = s.as_bytes().to_vec();
131 bytes.resize(expected_size, 0); bytes
133 }
134 FieldValue::Bytes(b) => {
135 let mut bytes = b.clone();
136 bytes.resize(expected_size, 0); bytes
138 }
139 }
140 }
141
142 pub fn from_bytes(bytes: &[u8], dtype: DType) -> Result<Self> {
144 match dtype {
145 DType::Float32 => {
146 if bytes.len() >= 4 {
147 Ok(FieldValue::Float32(f32::from_le_bytes([
148 bytes[0], bytes[1], bytes[2], bytes[3],
149 ])))
150 } else {
151 Err(TensorError::invalid_argument(
152 "Insufficient bytes for f32".to_string(),
153 ))
154 }
155 }
156 DType::Float64 => {
157 if bytes.len() >= 8 {
158 Ok(FieldValue::Float64(f64::from_le_bytes([
159 bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6],
160 bytes[7],
161 ])))
162 } else {
163 Err(TensorError::invalid_argument(
164 "Insufficient bytes for f64".to_string(),
165 ))
166 }
167 }
168 DType::Int32 => {
169 if bytes.len() >= 4 {
170 Ok(FieldValue::Int32(i32::from_le_bytes([
171 bytes[0], bytes[1], bytes[2], bytes[3],
172 ])))
173 } else {
174 Err(TensorError::invalid_argument(
175 "Insufficient bytes for i32".to_string(),
176 ))
177 }
178 }
179 DType::Int64 => {
180 if bytes.len() >= 8 {
181 Ok(FieldValue::Int64(i64::from_le_bytes([
182 bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6],
183 bytes[7],
184 ])))
185 } else {
186 Err(TensorError::invalid_argument(
187 "Insufficient bytes for i64".to_string(),
188 ))
189 }
190 }
191 DType::UInt32 => {
192 if bytes.len() >= 4 {
193 Ok(FieldValue::UInt32(u32::from_le_bytes([
194 bytes[0], bytes[1], bytes[2], bytes[3],
195 ])))
196 } else {
197 Err(TensorError::invalid_argument(
198 "Insufficient bytes for u32".to_string(),
199 ))
200 }
201 }
202 DType::UInt64 => {
203 if bytes.len() >= 8 {
204 Ok(FieldValue::UInt64(u64::from_le_bytes([
205 bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6],
206 bytes[7],
207 ])))
208 } else {
209 Err(TensorError::invalid_argument(
210 "Insufficient bytes for u64".to_string(),
211 ))
212 }
213 }
214 DType::Int16 => {
215 if bytes.len() >= 2 {
216 Ok(FieldValue::Int16(i16::from_le_bytes([bytes[0], bytes[1]])))
217 } else {
218 Err(TensorError::invalid_argument(
219 "Insufficient bytes for i16".to_string(),
220 ))
221 }
222 }
223 DType::UInt16 => {
224 if bytes.len() >= 2 {
225 Ok(FieldValue::UInt16(u16::from_le_bytes([bytes[0], bytes[1]])))
226 } else {
227 Err(TensorError::invalid_argument(
228 "Insufficient bytes for u16".to_string(),
229 ))
230 }
231 }
232 DType::Int8 => {
233 if !bytes.is_empty() {
234 Ok(FieldValue::Int8(bytes[0] as i8))
235 } else {
236 Err(TensorError::invalid_argument(
237 "Insufficient bytes for i8".to_string(),
238 ))
239 }
240 }
241 DType::UInt8 => {
242 if !bytes.is_empty() {
243 Ok(FieldValue::UInt8(bytes[0]))
244 } else {
245 Err(TensorError::invalid_argument(
246 "Insufficient bytes for u8".to_string(),
247 ))
248 }
249 }
250 DType::Bool => {
251 if !bytes.is_empty() {
252 Ok(FieldValue::Bool(bytes[0] != 0))
253 } else {
254 Err(TensorError::invalid_argument(
255 "Insufficient bytes for bool".to_string(),
256 ))
257 }
258 }
259 DType::String => {
260 let null_pos = bytes.iter().position(|&b| b == 0).unwrap_or(bytes.len());
261 let string_bytes = &bytes[..null_pos];
262 let s = String::from_utf8_lossy(string_bytes).to_string();
263 Ok(FieldValue::String(s))
264 }
265 _ => Err(TensorError::not_implemented_simple(
266 "Unsupported dtype for structured arrays".to_string(),
267 )),
268 }
269 }
270}
271
272impl From<f32> for FieldValue {
273 fn from(v: f32) -> Self {
274 FieldValue::Float32(v)
275 }
276}
277
278impl From<f64> for FieldValue {
279 fn from(v: f64) -> Self {
280 FieldValue::Float64(v)
281 }
282}
283
284impl From<i32> for FieldValue {
285 fn from(v: i32) -> Self {
286 FieldValue::Int32(v)
287 }
288}
289
290impl From<i64> for FieldValue {
291 fn from(v: i64) -> Self {
292 FieldValue::Int64(v)
293 }
294}
295
296impl From<u32> for FieldValue {
297 fn from(v: u32) -> Self {
298 FieldValue::UInt32(v)
299 }
300}
301
302impl From<u64> for FieldValue {
303 fn from(v: u64) -> Self {
304 FieldValue::UInt64(v)
305 }
306}
307
308impl From<i16> for FieldValue {
309 fn from(v: i16) -> Self {
310 FieldValue::Int16(v)
311 }
312}
313
314impl From<u16> for FieldValue {
315 fn from(v: u16) -> Self {
316 FieldValue::UInt16(v)
317 }
318}
319
320impl From<i8> for FieldValue {
321 fn from(v: i8) -> Self {
322 FieldValue::Int8(v)
323 }
324}
325
326impl From<u8> for FieldValue {
327 fn from(v: u8) -> Self {
328 FieldValue::UInt8(v)
329 }
330}
331
332impl From<bool> for FieldValue {
333 fn from(v: bool) -> Self {
334 FieldValue::Bool(v)
335 }
336}
337
338impl From<String> for FieldValue {
339 fn from(v: String) -> Self {
340 FieldValue::String(v)
341 }
342}
343
344impl From<&str> for FieldValue {
345 fn from(v: &str) -> Self {
346 FieldValue::String(v.to_string())
347 }
348}
349
350impl From<Vec<u8>> for FieldValue {
351 fn from(v: Vec<u8>) -> Self {
352 FieldValue::Bytes(v)
353 }
354}
355
/// A table of fixed-layout records stored back to back in one byte buffer.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct StructuredArray {
    /// Field descriptors in declaration order, with their offsets assigned.
    fields: Vec<FieldDescriptor>,
    /// Maps a field name to its index in `fields`.
    field_map: HashMap<String, usize>,
    /// Size of one record in bytes (the sum of all field byte sizes).
    record_size: usize,
    /// Raw record storage; constructors size it to `record_size * len` bytes.
    data: Vec<u8>,
    /// Number of records.
    len: usize,
    /// Logical shape; `[len]` unless the array was built with `with_shape`.
    shape: Shape,
}
373
374impl StructuredArray {
375 pub fn new(mut fields: Vec<FieldDescriptor>, len: usize) -> Self {
377 let mut offset = 0;
379 for field in &mut fields {
380 field.offset = offset;
381 offset += field.byte_size();
382 }
383 let record_size = offset;
384
385 let field_map: HashMap<String, usize> = fields
387 .iter()
388 .enumerate()
389 .map(|(i, field)| (field.name.clone(), i))
390 .collect();
391
392 let data = vec![0u8; record_size * len];
394
395 Self {
396 fields,
397 field_map,
398 record_size,
399 data,
400 len,
401 shape: Shape::from_slice(&[len]),
402 }
403 }
404
405 pub fn with_shape(fields: Vec<FieldDescriptor>, shape: &[usize]) -> Self {
407 let total_len = shape.iter().product::<usize>();
408 let mut array = Self::new(fields, total_len);
409 array.shape = Shape::from_slice(shape);
410 array
411 }
412
413 pub fn len(&self) -> usize {
415 self.len
416 }
417
418 pub fn is_empty(&self) -> bool {
420 self.len == 0
421 }
422
423 pub fn shape(&self) -> &Shape {
425 &self.shape
426 }
427
428 pub fn fields(&self) -> &[FieldDescriptor] {
430 &self.fields
431 }
432
433 pub fn field(&self, name: &str) -> Option<&FieldDescriptor> {
435 self.field_map.get(name).map(|&i| &self.fields[i])
436 }
437
438 pub fn field_names(&self) -> Vec<&str> {
440 self.fields.iter().map(|f| f.name.as_str()).collect()
441 }
442
443 pub fn set_field_value(
445 &mut self,
446 record_idx: usize,
447 field_name: &str,
448 value: FieldValue,
449 ) -> Result<()> {
450 if record_idx >= self.len {
451 return Err(TensorError::invalid_argument(format!(
452 "Record index {record_idx} out of bounds"
453 )));
454 }
455
456 let field_idx = self
457 .field_map
458 .get(field_name)
459 .ok_or_else(|| TensorError::invalid_argument(format!("Unknown field: {field_name}")))?;
460
461 let field = &self.fields[*field_idx];
462
463 if value.dtype() != field.dtype && field.dtype != DType::String {
465 return Err(TensorError::invalid_argument(format!(
466 "Type mismatch: expected {:?}, got {:?}",
467 field.dtype,
468 value.dtype()
469 )));
470 }
471
472 let value_bytes = value.to_bytes(field.byte_size());
474 let record_start = record_idx * self.record_size;
475 let field_start = record_start + field.offset;
476 let field_end = field_start + field.byte_size();
477
478 self.data[field_start..field_end].copy_from_slice(&value_bytes);
479 Ok(())
480 }
481
482 pub fn get_field_value(&self, record_idx: usize, field_name: &str) -> Result<FieldValue> {
484 if record_idx >= self.len {
485 return Err(TensorError::invalid_argument(format!(
486 "Record index {record_idx} out of bounds"
487 )));
488 }
489
490 let field_idx = self
491 .field_map
492 .get(field_name)
493 .ok_or_else(|| TensorError::invalid_argument(format!("Unknown field: {field_name}")))?;
494
495 let field = &self.fields[*field_idx];
496 let record_start = record_idx * self.record_size;
497 let field_start = record_start + field.offset;
498 let field_end = field_start + field.byte_size();
499
500 let field_bytes = &self.data[field_start..field_end];
501 FieldValue::from_bytes(field_bytes, field.dtype)
502 }
503
504 pub fn get_record(&self, record_idx: usize) -> Result<HashMap<String, FieldValue>> {
506 if record_idx >= self.len {
507 return Err(TensorError::invalid_argument(format!(
508 "Record index {record_idx} out of bounds"
509 )));
510 }
511
512 let mut record = HashMap::new();
513 for field in &self.fields {
514 let value = self.get_field_value(record_idx, &field.name)?;
515 record.insert(field.name.clone(), value);
516 }
517 Ok(record)
518 }
519
520 pub fn set_record(
522 &mut self,
523 record_idx: usize,
524 values: HashMap<String, FieldValue>,
525 ) -> Result<()> {
526 for (field_name, value) in values {
527 self.set_field_value(record_idx, &field_name, value)?;
528 }
529 Ok(())
530 }
531
532 pub fn get_column(&self, field_name: &str) -> Result<Vec<FieldValue>> {
534 let mut values = Vec::with_capacity(self.len);
535 for i in 0..self.len {
536 values.push(self.get_field_value(i, field_name)?);
537 }
538 Ok(values)
539 }
540
541 pub fn slice(&self, start: usize, end: usize) -> Result<StructuredArray> {
543 if start >= self.len || end > self.len || start >= end {
544 return Err(TensorError::invalid_argument(
545 "Invalid slice range".to_string(),
546 ));
547 }
548
549 let slice_len = end - start;
550 let mut sliced = StructuredArray::new(self.fields.clone(), slice_len);
551
552 let start_byte = start * self.record_size;
553 let end_byte = end * self.record_size;
554 sliced
555 .data
556 .copy_from_slice(&self.data[start_byte..end_byte]);
557
558 Ok(sliced)
559 }
560
561 pub fn resize(&mut self, new_len: usize) {
563 if new_len != self.len {
564 self.data.resize(new_len * self.record_size, 0);
565 self.len = new_len;
566 self.shape = Shape::from_slice(&[new_len]);
567 }
568 }
569
570 pub fn as_bytes(&self) -> &[u8] {
572 &self.data
573 }
574
575 pub fn from_bytes(fields: Vec<FieldDescriptor>, data: Vec<u8>, len: usize) -> Result<Self> {
577 let mut field_map = HashMap::new();
578 let mut offset = 0;
579
580 let mut corrected_fields = fields;
581 for (i, field) in corrected_fields.iter_mut().enumerate() {
582 field.offset = offset;
583 offset += field.byte_size();
584 field_map.insert(field.name.clone(), i);
585 }
586
587 let record_size = offset;
588
589 if data.len() != record_size * len {
590 return Err(TensorError::invalid_argument(
591 "Data size doesn't match expected record structure".to_string(),
592 ));
593 }
594
595 Ok(Self {
596 fields: corrected_fields,
597 field_map,
598 record_size,
599 data,
600 len,
601 shape: Shape::from_slice(&[len]),
602 })
603 }
604}
605
606impl fmt::Display for StructuredArray {
607 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
608 writeln!(
609 f,
610 "StructuredArray(len={}, fields=[{}])",
611 self.len,
612 self.fields
613 .iter()
614 .map(|f| format!("{}:{:?}", f.name, f.dtype))
615 .collect::<Vec<_>>()
616 .join(", ")
617 )?;
618
619 let show_count = std::cmp::min(5, self.len);
621 for i in 0..show_count {
622 if let Ok(record) = self.get_record(i) {
623 write!(f, " [{i}]: ")?;
624 let field_strs: Vec<String> = self
625 .fields
626 .iter()
627 .map(|field| {
628 if let Some(value) = record.get(&field.name) {
629 format!("{}={:?}", field.name, value)
630 } else {
631 format!("{}=<missing>", field.name)
632 }
633 })
634 .collect();
635 writeln!(f, "{{{}}}", field_strs.join(", "))?;
636 }
637 }
638
639 if self.len > show_count {
640 writeln!(f, " ... ({} more records)", self.len - show_count)?;
641 }
642
643 Ok(())
644 }
645}
646
#[cfg(test)]
mod tests {
    use super::*;

    /// Mixed-type schema shared by the creation and field-operation tests.
    fn person_fields() -> Vec<FieldDescriptor> {
        vec![
            FieldDescriptor::new("id", DType::Int32, None),
            FieldDescriptor::new("score", DType::Float32, None),
            FieldDescriptor::new("name", DType::String, Some(16)),
        ]
    }

    #[test]
    fn test_field_descriptor() {
        let descriptor = FieldDescriptor::new("test", DType::Float32, None);
        assert_eq!(descriptor.name, "test");
        assert_eq!(descriptor.dtype, DType::Float32);
        assert_eq!(descriptor.byte_size(), 4);
    }

    #[test]
    fn test_field_value_conversions() {
        let original = FieldValue::Float32(3.15);
        assert_eq!(original.dtype(), DType::Float32);

        let encoded = original.to_bytes(4);
        assert_eq!(encoded.len(), 4);

        // Encoding then decoding must give back the same float.
        let decoded = FieldValue::from_bytes(&encoded, DType::Float32)
            .expect("test: from_bytes should succeed");
        match decoded {
            FieldValue::Float32(v) => assert!((v - 3.15).abs() < 1e-6),
            _ => panic!("Wrong type recovered"),
        }
    }

    #[test]
    fn test_structured_array_creation() {
        let array = StructuredArray::new(person_fields(), 10);
        assert_eq!(array.len(), 10);
        assert_eq!(array.fields().len(), 3);
        assert!(array.field("id").is_some());
        assert!(array.field("unknown").is_none());
    }

    #[test]
    fn test_field_operations() {
        let mut array = StructuredArray::new(person_fields(), 2);

        let rows = [(42i32, 95.5f32, "Alice"), (43i32, 87.2f32, "Bob")];
        for (index, row) in rows.iter().enumerate() {
            let (id, score, name) = *row;
            array
                .set_field_value(index, "id", id.into())
                .expect("test: operation should succeed");
            array
                .set_field_value(index, "score", score.into())
                .expect("test: operation should succeed");
            array
                .set_field_value(index, "name", name.into())
                .expect("test: operation should succeed");
        }

        let first_id = array
            .get_field_value(0, "id")
            .expect("test: get_field_value should succeed");
        match first_id {
            FieldValue::Int32(v) => assert_eq!(v, 42),
            _ => panic!("Wrong type"),
        }

        let second_name = array
            .get_field_value(1, "name")
            .expect("test: get_field_value should succeed");
        match second_name {
            FieldValue::String(s) => assert_eq!(s, "Bob"),
            _ => panic!("Wrong type"),
        }
    }

    #[test]
    fn test_record_operations() {
        let fields = vec![
            FieldDescriptor::new("x", DType::Float32, None),
            FieldDescriptor::new("y", DType::Float32, None),
        ];
        let mut array = StructuredArray::new(fields, 1);

        let record: HashMap<String, FieldValue> = [("x", 1.0f32), ("y", 2.0f32)]
            .iter()
            .map(|(name, value)| ((*name).to_string(), FieldValue::from(*value)))
            .collect();

        array
            .set_record(0, record)
            .expect("test: set_record should succeed");

        let retrieved = array
            .get_record(0)
            .expect("test: get_record should succeed");
        assert_eq!(retrieved.len(), 2);

        match retrieved.get("x") {
            Some(FieldValue::Float32(x)) => assert_eq!(*x, 1.0),
            _ => panic!("Wrong value for x"),
        }
    }

    #[test]
    fn test_column_extraction() {
        let fields = vec![FieldDescriptor::new("values", DType::Float32, None)];
        let mut array = StructuredArray::new(fields, 3);

        for (index, value) in [1.0f32, 2.0f32, 3.0f32].iter().enumerate() {
            array
                .set_field_value(index, "values", (*value).into())
                .expect("test: operation should succeed");
        }

        let column = array
            .get_column("values")
            .expect("test: get_column should succeed");
        assert_eq!(column.len(), 3);

        match &column[1] {
            FieldValue::Float32(v) => assert_eq!(*v, 2.0),
            _ => panic!("Wrong type"),
        }
    }

    #[test]
    fn test_array_slice() {
        let fields = vec![FieldDescriptor::new("id", DType::Int32, None)];
        let mut array = StructuredArray::new(fields, 5);

        for index in 0..5 {
            array
                .set_field_value(index, "id", (index as i32).into())
                .expect("test: operation should succeed");
        }

        let slice = array.slice(1, 4).expect("test: slice should succeed");
        assert_eq!(slice.len(), 3);

        // The first record of the slice is the source's record 1.
        let first_id = slice
            .get_field_value(0, "id")
            .expect("test: get_field_value should succeed");
        match first_id {
            FieldValue::Int32(v) => assert_eq!(v, 1),
            _ => panic!("Wrong type"),
        }
    }
}