Skip to main content

shape_value/
datatable.rs

1//! Columnar DataTable backed by Arrow RecordBatch.
2//!
3//! DataTable is a high-performance columnar data structure wrapping Arrow's `RecordBatch`.
4//! It provides zero-copy slicing, typed column access, and efficient batch operations.
5
6use arrow_array::{
7    Array, ArrayRef, BooleanArray, Float64Array, Int64Array, RecordBatch, StringArray,
8    TimestampMicrosecondArray,
9};
10use arrow_schema::{DataType, Field, Schema};
11use std::sync::Arc;
12
13
14/// Raw pointers to Arrow column buffers for zero-cost field access.
15///
16/// These pointers are derived from the underlying Arrow arrays and remain
17/// valid as long as the parent `DataTable` (and its `RecordBatch`) is alive.
18#[derive(Debug, Clone)]
19pub struct ColumnPtrs {
20    /// Pointer to the values buffer (f64, i64, bool bytes, etc.)
21    pub values_ptr: *const u8,
22    /// Pointer to the offsets buffer (for variable-length types like Utf8)
23    pub offsets_ptr: *const u8,
24    /// Pointer to the validity bitmap (null tracking)
25    pub validity_ptr: *const u8,
26    /// Stride in bytes between consecutive values (0 for variable-length)
27    pub stride: usize,
28    /// Arrow data type for this column
29    pub data_type: DataType,
30}
31
32// SAFETY: ColumnPtrs are derived from Arc<RecordBatch> which is Send+Sync.
33// The pointers remain valid as long as the DataTable lives.
34unsafe impl Send for ColumnPtrs {}
35unsafe impl Sync for ColumnPtrs {}
36
37impl ColumnPtrs {
38    /// Build ColumnPtrs from an Arrow ArrayRef.
39    fn from_array(array: &ArrayRef) -> Self {
40        let data = array.to_data();
41        let data_type = data.data_type().clone();
42
43        // Get values buffer pointer and stride
44        let (values_ptr, stride) = match &data_type {
45            DataType::Float64 => {
46                let ptr = if !data.buffers().is_empty() {
47                    data.buffers()[0].as_ptr().wrapping_add(data.offset() * 8)
48                } else {
49                    std::ptr::null()
50                };
51                (ptr, 8)
52            }
53            DataType::Int64 | DataType::Timestamp(_, _) => {
54                let ptr = if !data.buffers().is_empty() {
55                    data.buffers()[0].as_ptr().wrapping_add(data.offset() * 8)
56                } else {
57                    std::ptr::null()
58                };
59                (ptr, 8)
60            }
61            DataType::Int32 | DataType::Float32 => {
62                let ptr = if !data.buffers().is_empty() {
63                    data.buffers()[0].as_ptr().wrapping_add(data.offset() * 4)
64                } else {
65                    std::ptr::null()
66                };
67                (ptr, 4)
68            }
69            DataType::Boolean => {
70                // Boolean uses bit-packed storage; stride=0 signals bit access
71                let ptr = if !data.buffers().is_empty() {
72                    data.buffers()[0].as_ptr()
73                } else {
74                    std::ptr::null()
75                };
76                (ptr, 0)
77            }
78            DataType::Utf8 => {
79                // Utf8 has offsets buffer[0] and values buffer[1]
80                let ptr = if data.buffers().len() > 1 {
81                    data.buffers()[1].as_ptr()
82                } else {
83                    std::ptr::null()
84                };
85                (ptr, 0) // Variable-length
86            }
87            _ => (std::ptr::null(), 0),
88        };
89
90        // Get offsets buffer for variable-length types
91        let offsets_ptr = match &data_type {
92            DataType::Utf8 => {
93                if !data.buffers().is_empty() {
94                    data.buffers()[0].as_ptr().wrapping_add(data.offset() * 4)
95                } else {
96                    std::ptr::null()
97                }
98            }
99            _ => std::ptr::null(),
100        };
101
102        // Get validity bitmap
103        let validity_ptr = data
104            .nulls()
105            .map(|nulls| nulls.buffer().as_ptr())
106            .unwrap_or(std::ptr::null());
107
108        ColumnPtrs {
109            values_ptr,
110            offsets_ptr,
111            validity_ptr,
112            stride,
113            data_type,
114        }
115    }
116}
117
118/// A columnar data table backed by Arrow RecordBatch.
119///
120/// DataTable wraps an Arrow `RecordBatch` and provides typed column access,
121/// zero-copy slicing, and interop with the Shape type system.
122#[derive(Debug, Clone)]
123pub struct DataTable {
124    batch: RecordBatch,
125    /// Optional type name for Shape type system integration
126    type_name: Option<String>,
127    /// Optional schema ID for typed tables (Table<T>)
128    schema_id: Option<u32>,
129    /// Pre-computed column pointers for zero-cost access
130    column_ptrs: Vec<ColumnPtrs>,
131    /// Index column name (set by index_by(), preserved across operations)
132    index_col: Option<String>,
133}
134
135impl DataTable {
136    /// Build column pointer table from a RecordBatch.
137    fn build_column_ptrs(batch: &RecordBatch) -> Vec<ColumnPtrs> {
138        (0..batch.num_columns())
139            .map(|i| ColumnPtrs::from_array(batch.column(i)))
140            .collect()
141    }
142
143    /// Create a new DataTable from an Arrow RecordBatch.
144    pub fn new(batch: RecordBatch) -> Self {
145        let column_ptrs = Self::build_column_ptrs(&batch);
146        Self {
147            batch,
148            type_name: None,
149            schema_id: None,
150            column_ptrs,
151            index_col: None,
152        }
153    }
154
155    /// Create a new DataTable with an associated type name.
156    pub fn with_type_name(batch: RecordBatch, type_name: String) -> Self {
157        let column_ptrs = Self::build_column_ptrs(&batch);
158        Self {
159            batch,
160            type_name: Some(type_name),
161            schema_id: None,
162            column_ptrs,
163            index_col: None,
164        }
165    }
166
167    /// Set the schema ID for typed table access.
168    pub fn with_schema_id(mut self, schema_id: u32) -> Self {
169        self.schema_id = Some(schema_id);
170        self
171    }
172
173    /// Set the index column name (from index_by()).
174    pub fn with_index_col(mut self, name: String) -> Self {
175        self.index_col = Some(name);
176        self
177    }
178
179    /// Get the schema ID if this is a typed table.
180    pub fn schema_id(&self) -> Option<u32> {
181        self.schema_id
182    }
183
184    /// Get the index column name if set.
185    pub fn index_col(&self) -> Option<&str> {
186        self.index_col.as_deref()
187    }
188
189    /// Get column pointers for a column by index.
190    pub fn column_ptr(&self, index: usize) -> Option<&ColumnPtrs> {
191        self.column_ptrs.get(index)
192    }
193
194    /// Get all column pointers.
195    pub fn column_ptrs(&self) -> &[ColumnPtrs] {
196        &self.column_ptrs
197    }
198
199    /// Number of rows in the table.
200    pub fn row_count(&self) -> usize {
201        self.batch.num_rows()
202    }
203
204    /// Number of columns in the table.
205    pub fn column_count(&self) -> usize {
206        self.batch.num_columns()
207    }
208
209    /// Column names in order.
210    pub fn column_names(&self) -> Vec<String> {
211        self.batch
212            .schema()
213            .fields()
214            .iter()
215            .map(|f| f.name().clone())
216            .collect()
217    }
218
219    /// The Arrow schema.
220    pub fn schema(&self) -> Arc<Schema> {
221        self.batch.schema()
222    }
223
224    /// The optional Shape type name.
225    pub fn type_name(&self) -> Option<&str> {
226        self.type_name.as_deref()
227    }
228
229    /// Get a column by name as a generic ArrayRef.
230    pub fn column_by_name(&self, name: &str) -> Option<&ArrayRef> {
231        let idx = self.batch.schema().index_of(name).ok()?;
232        Some(self.batch.column(idx))
233    }
234
235    /// Get a Float64 column by name.
236    pub fn get_f64_column(&self, name: &str) -> Option<&Float64Array> {
237        self.column_by_name(name)?
238            .as_any()
239            .downcast_ref::<Float64Array>()
240    }
241
242    /// Get an Int64 column by name.
243    pub fn get_i64_column(&self, name: &str) -> Option<&Int64Array> {
244        self.column_by_name(name)?
245            .as_any()
246            .downcast_ref::<Int64Array>()
247    }
248
249    /// Get a String (Utf8) column by name.
250    pub fn get_string_column(&self, name: &str) -> Option<&StringArray> {
251        self.column_by_name(name)?
252            .as_any()
253            .downcast_ref::<StringArray>()
254    }
255
256    /// Get a Boolean column by name.
257    pub fn get_bool_column(&self, name: &str) -> Option<&BooleanArray> {
258        self.column_by_name(name)?
259            .as_any()
260            .downcast_ref::<BooleanArray>()
261    }
262
263    /// Get a TimestampMicrosecond column by name.
264    pub fn get_timestamp_column(&self, name: &str) -> Option<&TimestampMicrosecondArray> {
265        self.column_by_name(name)?
266            .as_any()
267            .downcast_ref::<TimestampMicrosecondArray>()
268    }
269
270    /// Zero-copy slice of the DataTable.
271    pub fn slice(&self, offset: usize, length: usize) -> Self {
272        let sliced = self.batch.slice(offset, length);
273        let column_ptrs = Self::build_column_ptrs(&sliced);
274        Self {
275            batch: sliced,
276            type_name: self.type_name.clone(),
277            schema_id: self.schema_id,
278            column_ptrs,
279            index_col: self.index_col.clone(),
280        }
281    }
282
283    /// Borrow the inner RecordBatch.
284    pub fn inner(&self) -> &RecordBatch {
285        &self.batch
286    }
287
288    /// Consume and return the inner RecordBatch.
289    pub fn into_inner(self) -> RecordBatch {
290        self.batch
291    }
292
293    /// Check if the table is empty.
294    pub fn is_empty(&self) -> bool {
295        self.batch.num_rows() == 0
296    }
297}
298
299impl std::fmt::Display for DataTable {
300    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
301        let name = self.type_name.as_deref().unwrap_or("DataTable");
302        write!(
303            f,
304            "{}({} rows x {} cols: [{}])",
305            name,
306            self.row_count(),
307            self.column_count(),
308            self.column_names().join(", "),
309        )
310    }
311}
312
313impl PartialEq for DataTable {
314    fn eq(&self, other: &Self) -> bool {
315        self.batch == other.batch
316    }
317}
318
319/// Builder for constructing a DataTable column-by-column.
320///
321/// Collects columns (as Arrow arrays) and a schema, then builds a RecordBatch.
322pub struct DataTableBuilder {
323    schema: Schema,
324    columns: Vec<ArrayRef>,
325}
326
327impl DataTableBuilder {
328    /// Create a builder from an Arrow schema.
329    pub fn new(schema: Schema) -> Self {
330        Self {
331            schema,
332            columns: Vec::new(),
333        }
334    }
335
336    /// Create a builder with just field definitions (convenience).
337    pub fn with_fields(fields: Vec<Field>) -> Self {
338        Self {
339            schema: Schema::new(fields),
340            columns: Vec::new(),
341        }
342    }
343
344    /// Add a Float64 column.
345    pub fn add_f64_column(&mut self, values: Vec<f64>) -> &mut Self {
346        self.columns
347            .push(Arc::new(Float64Array::from(values)) as ArrayRef);
348        self
349    }
350
351    /// Add an Int64 column.
352    pub fn add_i64_column(&mut self, values: Vec<i64>) -> &mut Self {
353        self.columns
354            .push(Arc::new(Int64Array::from(values)) as ArrayRef);
355        self
356    }
357
358    /// Add a String column.
359    pub fn add_string_column(&mut self, values: Vec<&str>) -> &mut Self {
360        self.columns
361            .push(Arc::new(StringArray::from(values)) as ArrayRef);
362        self
363    }
364
365    /// Add a Boolean column.
366    pub fn add_bool_column(&mut self, values: Vec<bool>) -> &mut Self {
367        self.columns
368            .push(Arc::new(BooleanArray::from(values)) as ArrayRef);
369        self
370    }
371
372    /// Add a TimestampMicrosecond column.
373    pub fn add_timestamp_column(&mut self, values: Vec<i64>) -> &mut Self {
374        self.columns
375            .push(Arc::new(TimestampMicrosecondArray::from(values)) as ArrayRef);
376        self
377    }
378
379    /// Add a pre-built Arrow array column.
380    pub fn add_column(&mut self, array: ArrayRef) -> &mut Self {
381        self.columns.push(array);
382        self
383    }
384
385    /// Build the DataTable. Returns an error if schema/column mismatch.
386    pub fn finish(self) -> Result<DataTable, arrow_schema::ArrowError> {
387        let batch = RecordBatch::try_new(Arc::new(self.schema), self.columns)?;
388        Ok(DataTable::new(batch))
389    }
390
391    /// Build a DataTable with an associated type name.
392    pub fn finish_with_type_name(
393        self,
394        type_name: String,
395    ) -> Result<DataTable, arrow_schema::ArrowError> {
396        let batch = RecordBatch::try_new(Arc::new(self.schema), self.columns)?;
397        Ok(DataTable::with_type_name(batch, type_name))
398    }
399
400    /// Build a DataTable with schema ID for typed tables.
401    pub fn finish_with_schema_id(
402        self,
403        schema_id: u32,
404    ) -> Result<DataTable, arrow_schema::ArrowError> {
405        let batch = RecordBatch::try_new(Arc::new(self.schema), self.columns)?;
406        Ok(DataTable::new(batch).with_schema_id(schema_id))
407    }
408}
409
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_schema::{DataType, TimeUnit};

    /// Schema shared by most tests: price (f64), volume (i64), symbol (utf8).
    fn sample_schema() -> Schema {
        let fields = vec![
            Field::new("price", DataType::Float64, false),
            Field::new("volume", DataType::Int64, false),
            Field::new("symbol", DataType::Utf8, false),
        ];
        Schema::new(fields)
    }

    /// Three-row table matching `sample_schema`.
    fn sample_datatable() -> DataTable {
        let mut b = DataTableBuilder::new(sample_schema());
        b.add_f64_column(vec![100.0, 101.5, 99.8]);
        b.add_i64_column(vec![1000, 2000, 1500]);
        b.add_string_column(vec!["AAPL", "AAPL", "AAPL"]);
        b.finish().unwrap()
    }

    #[test]
    fn test_creation_and_basic_accessors() {
        let table = sample_datatable();
        assert_eq!(table.row_count(), 3);
        assert_eq!(table.column_count(), 3);
        assert_eq!(table.column_names(), vec!["price", "volume", "symbol"]);
        assert!(!table.is_empty());
    }

    #[test]
    fn test_typed_column_access() {
        let table = sample_datatable();

        let price = table.get_f64_column("price").unwrap();
        assert_eq!(price.value(0), 100.0);
        assert_eq!(price.value(2), 99.8);

        let volume = table.get_i64_column("volume").unwrap();
        assert_eq!(volume.value(1), 2000);

        let symbol = table.get_string_column("symbol").unwrap();
        assert_eq!(symbol.value(0), "AAPL");

        // A column of the wrong type is reported as absent...
        assert!(table.get_f64_column("symbol").is_none());
        // ...and so is a column that does not exist at all.
        assert!(table.get_f64_column("nonexistent").is_none());
    }

    #[test]
    fn test_bool_column() {
        let fields = vec![Field::new("flag", DataType::Boolean, false)];
        let mut b = DataTableBuilder::new(Schema::new(fields));
        b.add_bool_column(vec![true, false, true]);
        let table = b.finish().unwrap();

        let flag = table.get_bool_column("flag").unwrap();
        assert!(flag.value(0));
        assert!(!flag.value(1));
    }

    #[test]
    fn test_timestamp_column() {
        let ts_field = Field::new(
            "ts",
            DataType::Timestamp(TimeUnit::Microsecond, None),
            false,
        );
        let mut b = DataTableBuilder::new(Schema::new(vec![ts_field]));
        b.add_timestamp_column(vec![1_000_000, 2_000_000, 3_000_000]);
        let table = b.finish().unwrap();

        let stamps = table.get_timestamp_column("ts").unwrap();
        assert_eq!(stamps.value(0), 1_000_000);
        assert_eq!(stamps.value(2), 3_000_000);
    }

    #[test]
    fn test_zero_copy_slice() {
        let tail = sample_datatable().slice(1, 2);

        assert_eq!(tail.row_count(), 2);
        assert_eq!(tail.column_count(), 3);

        let price = tail.get_f64_column("price").unwrap();
        assert_eq!(price.value(0), 101.5);
        assert_eq!(price.value(1), 99.8);
    }

    #[test]
    fn test_empty_datatable() {
        let fields = vec![Field::new("x", DataType::Float64, false)];
        let mut b = DataTableBuilder::new(Schema::new(fields));
        b.add_f64_column(vec![]);
        let table = b.finish().unwrap();

        assert!(table.is_empty());
        assert_eq!(table.row_count(), 0);
    }

    #[test]
    fn test_display() {
        let rendered = sample_datatable().to_string();
        assert!(rendered.contains("DataTable"));
        assert!(rendered.contains("3 rows"));
        assert!(rendered.contains("price"));
    }

    #[test]
    fn test_type_name() {
        // Untagged tables report no type name.
        let plain = sample_datatable();
        assert!(plain.type_name().is_none());

        let mut b = DataTableBuilder::new(sample_schema());
        b.add_f64_column(vec![1.0]);
        b.add_i64_column(vec![10]);
        b.add_string_column(vec!["X"]);
        let tagged = b.finish_with_type_name("Candle".to_string()).unwrap();

        assert_eq!(tagged.type_name(), Some("Candle"));
        let rendered = tagged.to_string();
        assert!(rendered.starts_with("Candle("));
    }

    #[test]
    fn test_builder_schema_mismatch_errors() {
        let fields = vec![
            Field::new("a", DataType::Float64, false),
            Field::new("b", DataType::Int64, false),
        ];
        let mut b = DataTableBuilder::new(Schema::new(fields));
        // The schema declares two columns, but only one is supplied.
        b.add_f64_column(vec![1.0]);
        assert!(b.finish().is_err());
    }

    #[test]
    fn test_inner_and_into_inner() {
        let borrowed_from = sample_datatable();
        assert_eq!(borrowed_from.inner().num_rows(), 3);

        let owned_batch = sample_datatable().into_inner();
        assert_eq!(owned_batch.num_rows(), 3);
    }

    #[test]
    fn test_partial_eq() {
        let left = sample_datatable();
        let right = sample_datatable();
        assert_eq!(left, right);

        let shorter = left.slice(0, 2);
        assert_ne!(shorter, right);
    }

    #[test]
    fn test_column_by_name() {
        let table = sample_datatable();
        assert!(table.column_by_name("price").is_some());
        assert!(table.column_by_name("missing").is_none());
    }

    #[test]
    fn test_column_ptrs_constructed() {
        let table = sample_datatable();
        // One pointer entry per column.
        assert_eq!(table.column_ptrs().len(), 3);

        // price: Float64, 8-byte stride, non-null values pointer.
        let price = table.column_ptr(0).unwrap();
        assert_eq!(price.stride, 8);
        assert!(matches!(price.data_type, DataType::Float64));
        assert!(!price.values_ptr.is_null());

        // volume: Int64, 8-byte stride.
        let volume = table.column_ptr(1).unwrap();
        assert_eq!(volume.stride, 8);
        assert!(matches!(volume.data_type, DataType::Int64));

        // symbol: Utf8, variable-length (stride 0) with an offsets buffer.
        let symbol = table.column_ptr(2).unwrap();
        assert_eq!(symbol.stride, 0);
        assert!(matches!(symbol.data_type, DataType::Utf8));
        assert!(!symbol.offsets_ptr.is_null());
    }

    #[test]
    fn test_column_ptrs_f64_read() {
        let table = sample_datatable();
        let price = table.column_ptr(0).unwrap();

        // SAFETY: column 0 is the three-row Float64 column built by
        // `sample_datatable`, and `table` is still alive, so three f64 values
        // are readable starting at `values_ptr`.
        let seen = unsafe { std::slice::from_raw_parts(price.values_ptr as *const f64, 3) };
        assert_eq!(seen, &[100.0, 101.5, 99.8][..]);
    }

    #[test]
    fn test_column_ptrs_i64_read() {
        let table = sample_datatable();
        let volume = table.column_ptr(1).unwrap();

        // SAFETY: column 1 is the three-row Int64 column built by
        // `sample_datatable`, and `table` is still alive, so three i64 values
        // are readable starting at `values_ptr`.
        let seen = unsafe { std::slice::from_raw_parts(volume.values_ptr as *const i64, 3) };
        assert_eq!(seen, &[1000, 2000, 1500][..]);
    }

    #[test]
    fn test_schema_id() {
        let untyped = sample_datatable();
        assert!(untyped.schema_id().is_none());

        let typed = sample_datatable().with_schema_id(42);
        assert_eq!(typed.schema_id(), Some(42));
    }

    #[test]
    fn test_column_ptr_out_of_bounds() {
        let table = sample_datatable();
        assert!(table.column_ptr(99).is_none());
    }
}