1use arrow_array::{
7 Array, ArrayRef, BooleanArray, Float64Array, Int64Array, RecordBatch, StringArray,
8 TimestampMicrosecondArray,
9};
10use arrow_schema::{DataType, Field, Schema};
11use std::sync::Arc;
12
13use crate::ValueWord;
14
15#[derive(Debug, Clone)]
20pub struct ColumnPtrs {
21 pub values_ptr: *const u8,
23 pub offsets_ptr: *const u8,
25 pub validity_ptr: *const u8,
27 pub stride: usize,
29 pub data_type: DataType,
31}
32
33unsafe impl Send for ColumnPtrs {}
36unsafe impl Sync for ColumnPtrs {}
37
38impl ColumnPtrs {
39 fn from_array(array: &ArrayRef) -> Self {
41 let data = array.to_data();
42 let data_type = data.data_type().clone();
43
44 let (values_ptr, stride) = match &data_type {
46 DataType::Float64 => {
47 let ptr = if !data.buffers().is_empty() {
48 data.buffers()[0].as_ptr().wrapping_add(data.offset() * 8)
49 } else {
50 std::ptr::null()
51 };
52 (ptr, 8)
53 }
54 DataType::Int64 | DataType::Timestamp(_, _) => {
55 let ptr = if !data.buffers().is_empty() {
56 data.buffers()[0].as_ptr().wrapping_add(data.offset() * 8)
57 } else {
58 std::ptr::null()
59 };
60 (ptr, 8)
61 }
62 DataType::Int32 | DataType::Float32 => {
63 let ptr = if !data.buffers().is_empty() {
64 data.buffers()[0].as_ptr().wrapping_add(data.offset() * 4)
65 } else {
66 std::ptr::null()
67 };
68 (ptr, 4)
69 }
70 DataType::Boolean => {
71 let ptr = if !data.buffers().is_empty() {
73 data.buffers()[0].as_ptr()
74 } else {
75 std::ptr::null()
76 };
77 (ptr, 0)
78 }
79 DataType::Utf8 => {
80 let ptr = if data.buffers().len() > 1 {
82 data.buffers()[1].as_ptr()
83 } else {
84 std::ptr::null()
85 };
86 (ptr, 0) }
88 _ => (std::ptr::null(), 0),
89 };
90
91 let offsets_ptr = match &data_type {
93 DataType::Utf8 => {
94 if !data.buffers().is_empty() {
95 data.buffers()[0].as_ptr().wrapping_add(data.offset() * 4)
96 } else {
97 std::ptr::null()
98 }
99 }
100 _ => std::ptr::null(),
101 };
102
103 let validity_ptr = data
105 .nulls()
106 .map(|nulls| nulls.buffer().as_ptr())
107 .unwrap_or(std::ptr::null());
108
109 ColumnPtrs {
110 values_ptr,
111 offsets_ptr,
112 validity_ptr,
113 stride,
114 data_type,
115 }
116 }
117}
118
119#[derive(Debug, Clone)]
124pub struct DataTable {
125 batch: RecordBatch,
126 type_name: Option<String>,
128 schema_id: Option<u32>,
130 column_ptrs: Vec<ColumnPtrs>,
132 index_col: Option<String>,
134 origin: Option<(ValueWord, ValueWord)>,
136}
137
138impl DataTable {
139 fn build_column_ptrs(batch: &RecordBatch) -> Vec<ColumnPtrs> {
141 (0..batch.num_columns())
142 .map(|i| ColumnPtrs::from_array(batch.column(i)))
143 .collect()
144 }
145
146 pub fn new(batch: RecordBatch) -> Self {
148 let column_ptrs = Self::build_column_ptrs(&batch);
149 Self {
150 batch,
151 type_name: None,
152 schema_id: None,
153 column_ptrs,
154 index_col: None,
155 origin: None,
156 }
157 }
158
159 pub fn with_type_name(batch: RecordBatch, type_name: String) -> Self {
161 let column_ptrs = Self::build_column_ptrs(&batch);
162 Self {
163 batch,
164 type_name: Some(type_name),
165 schema_id: None,
166 column_ptrs,
167 index_col: None,
168 origin: None,
169 }
170 }
171
172 pub fn with_schema_id(mut self, schema_id: u32) -> Self {
174 self.schema_id = Some(schema_id);
175 self
176 }
177
178 pub fn with_index_col(mut self, name: String) -> Self {
180 self.index_col = Some(name);
181 self
182 }
183
184 pub fn set_origin(&mut self, source: ValueWord, params: ValueWord) {
186 self.origin = Some((source, params));
187 }
188
189 pub fn origin(&self) -> ValueWord {
192 use crate::heap_value::HeapValue;
193 use crate::slot::ValueSlot;
194 use std::sync::atomic::{AtomicU64, Ordering};
195 static ORIGIN_SCHEMA_ID: AtomicU64 = AtomicU64::new(0);
196
197 match &self.origin {
198 Some((source, params)) => {
199 let schema_id = ORIGIN_SCHEMA_ID.load(Ordering::Relaxed);
201 let schema_id = if schema_id == 0 {
202 let id = 0xFFFF_FF00_u64;
204 ORIGIN_SCHEMA_ID.store(id, Ordering::Relaxed);
205 id
206 } else {
207 schema_id
208 };
209 let nb_to_slot = |nb: &ValueWord| -> (ValueSlot, bool) {
212 use crate::value_word::NanTag;
213 match nb.tag() {
214 NanTag::Heap => {
215 let hv = nb.as_heap_ref().cloned().unwrap_or_else(|| {
216 HeapValue::String(std::sync::Arc::new(String::new()))
217 });
218 (ValueSlot::from_heap(hv), true)
219 }
220 NanTag::F64 => (ValueSlot::from_number(nb.as_f64().unwrap_or(0.0)), false),
221 NanTag::I48 => (ValueSlot::from_int(nb.as_i64().unwrap_or(0)), false),
222 NanTag::Bool => {
223 (ValueSlot::from_bool(nb.as_bool().unwrap_or(false)), false)
224 }
225 NanTag::None | NanTag::Unit | NanTag::Ref => (ValueSlot::none(), false),
226 NanTag::Function | NanTag::ModuleFunction => {
227 (ValueSlot::from_raw(nb.raw_bits()), false)
228 }
229 }
230 };
231 let (slot0, heap0) = nb_to_slot(source);
232 let (slot1, heap1) = nb_to_slot(params);
233 let heap_mask = (heap0 as u64) | ((heap1 as u64) << 1);
234 let slots = Box::new([slot0, slot1]);
235 ValueWord::from_heap_value(HeapValue::TypedObject {
236 schema_id,
237 slots,
238 heap_mask,
239 })
240 }
241 None => ValueWord::none(),
242 }
243 }
244
245 pub fn schema_id(&self) -> Option<u32> {
247 self.schema_id
248 }
249
250 pub fn index_col(&self) -> Option<&str> {
252 self.index_col.as_deref()
253 }
254
255 pub fn column_ptr(&self, index: usize) -> Option<&ColumnPtrs> {
257 self.column_ptrs.get(index)
258 }
259
260 pub fn column_ptrs(&self) -> &[ColumnPtrs] {
262 &self.column_ptrs
263 }
264
265 pub fn row_count(&self) -> usize {
267 self.batch.num_rows()
268 }
269
270 pub fn column_count(&self) -> usize {
272 self.batch.num_columns()
273 }
274
275 pub fn column_names(&self) -> Vec<String> {
277 self.batch
278 .schema()
279 .fields()
280 .iter()
281 .map(|f| f.name().clone())
282 .collect()
283 }
284
285 pub fn schema(&self) -> Arc<Schema> {
287 self.batch.schema()
288 }
289
290 pub fn type_name(&self) -> Option<&str> {
292 self.type_name.as_deref()
293 }
294
295 pub fn column_by_name(&self, name: &str) -> Option<&ArrayRef> {
297 let idx = self.batch.schema().index_of(name).ok()?;
298 Some(self.batch.column(idx))
299 }
300
301 pub fn get_f64_column(&self, name: &str) -> Option<&Float64Array> {
303 self.column_by_name(name)?
304 .as_any()
305 .downcast_ref::<Float64Array>()
306 }
307
308 pub fn get_i64_column(&self, name: &str) -> Option<&Int64Array> {
310 self.column_by_name(name)?
311 .as_any()
312 .downcast_ref::<Int64Array>()
313 }
314
315 pub fn get_string_column(&self, name: &str) -> Option<&StringArray> {
317 self.column_by_name(name)?
318 .as_any()
319 .downcast_ref::<StringArray>()
320 }
321
322 pub fn get_bool_column(&self, name: &str) -> Option<&BooleanArray> {
324 self.column_by_name(name)?
325 .as_any()
326 .downcast_ref::<BooleanArray>()
327 }
328
329 pub fn get_timestamp_column(&self, name: &str) -> Option<&TimestampMicrosecondArray> {
331 self.column_by_name(name)?
332 .as_any()
333 .downcast_ref::<TimestampMicrosecondArray>()
334 }
335
336 pub fn slice(&self, offset: usize, length: usize) -> Self {
338 let sliced = self.batch.slice(offset, length);
339 let column_ptrs = Self::build_column_ptrs(&sliced);
340 Self {
341 batch: sliced,
342 type_name: self.type_name.clone(),
343 schema_id: self.schema_id,
344 column_ptrs,
345 index_col: self.index_col.clone(),
346 origin: self.origin.clone(),
347 }
348 }
349
350 pub fn inner(&self) -> &RecordBatch {
352 &self.batch
353 }
354
355 pub fn into_inner(self) -> RecordBatch {
357 self.batch
358 }
359
360 pub fn is_empty(&self) -> bool {
362 self.batch.num_rows() == 0
363 }
364}
365
366impl std::fmt::Display for DataTable {
367 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
368 let name = self.type_name.as_deref().unwrap_or("DataTable");
369 write!(
370 f,
371 "{}({} rows x {} cols: [{}])",
372 name,
373 self.row_count(),
374 self.column_count(),
375 self.column_names().join(", "),
376 )
377 }
378}
379
380impl PartialEq for DataTable {
381 fn eq(&self, other: &Self) -> bool {
382 self.batch == other.batch
383 }
384}
385
386pub struct DataTableBuilder {
390 schema: Schema,
391 columns: Vec<ArrayRef>,
392}
393
394impl DataTableBuilder {
395 pub fn new(schema: Schema) -> Self {
397 Self {
398 schema,
399 columns: Vec::new(),
400 }
401 }
402
403 pub fn with_fields(fields: Vec<Field>) -> Self {
405 Self {
406 schema: Schema::new(fields),
407 columns: Vec::new(),
408 }
409 }
410
411 pub fn add_f64_column(&mut self, values: Vec<f64>) -> &mut Self {
413 self.columns
414 .push(Arc::new(Float64Array::from(values)) as ArrayRef);
415 self
416 }
417
418 pub fn add_i64_column(&mut self, values: Vec<i64>) -> &mut Self {
420 self.columns
421 .push(Arc::new(Int64Array::from(values)) as ArrayRef);
422 self
423 }
424
425 pub fn add_string_column(&mut self, values: Vec<&str>) -> &mut Self {
427 self.columns
428 .push(Arc::new(StringArray::from(values)) as ArrayRef);
429 self
430 }
431
432 pub fn add_bool_column(&mut self, values: Vec<bool>) -> &mut Self {
434 self.columns
435 .push(Arc::new(BooleanArray::from(values)) as ArrayRef);
436 self
437 }
438
439 pub fn add_timestamp_column(&mut self, values: Vec<i64>) -> &mut Self {
441 self.columns
442 .push(Arc::new(TimestampMicrosecondArray::from(values)) as ArrayRef);
443 self
444 }
445
446 pub fn add_column(&mut self, array: ArrayRef) -> &mut Self {
448 self.columns.push(array);
449 self
450 }
451
452 pub fn finish(self) -> Result<DataTable, arrow_schema::ArrowError> {
454 let batch = RecordBatch::try_new(Arc::new(self.schema), self.columns)?;
455 Ok(DataTable::new(batch))
456 }
457
458 pub fn finish_with_type_name(
460 self,
461 type_name: String,
462 ) -> Result<DataTable, arrow_schema::ArrowError> {
463 let batch = RecordBatch::try_new(Arc::new(self.schema), self.columns)?;
464 Ok(DataTable::with_type_name(batch, type_name))
465 }
466
467 pub fn finish_with_schema_id(
469 self,
470 schema_id: u32,
471 ) -> Result<DataTable, arrow_schema::ArrowError> {
472 let batch = RecordBatch::try_new(Arc::new(self.schema), self.columns)?;
473 Ok(DataTable::new(batch).with_schema_id(schema_id))
474 }
475}
476
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_schema::{DataType, TimeUnit};

    /// Schema matching `sample_datatable`: price (f64), volume (i64), symbol (utf8).
    fn sample_schema() -> Schema {
        Schema::new(vec![
            Field::new("price", DataType::Float64, false),
            Field::new("volume", DataType::Int64, false),
            Field::new("symbol", DataType::Utf8, false),
        ])
    }

    /// Three-row table used by most tests.
    fn sample_datatable() -> DataTable {
        let mut builder = DataTableBuilder::new(sample_schema());
        builder
            .add_f64_column(vec![100.0, 101.5, 99.8])
            .add_i64_column(vec![1000, 2000, 1500])
            .add_string_column(vec!["AAPL", "AAPL", "AAPL"]);
        builder.finish().unwrap()
    }

    #[test]
    fn test_creation_and_basic_accessors() {
        let dt = sample_datatable();
        assert_eq!(dt.row_count(), 3);
        assert_eq!(dt.column_count(), 3);
        assert_eq!(dt.column_names(), vec!["price", "volume", "symbol"]);
        assert!(!dt.is_empty());
        assert!(dt.index_col().is_none());
    }

    #[test]
    fn test_typed_column_access() {
        let dt = sample_datatable();

        let prices = dt.get_f64_column("price").unwrap();
        assert_eq!(prices.value(0), 100.0);
        assert_eq!(prices.value(2), 99.8);

        let volumes = dt.get_i64_column("volume").unwrap();
        assert_eq!(volumes.value(1), 2000);

        let symbols = dt.get_string_column("symbol").unwrap();
        assert_eq!(symbols.value(0), "AAPL");

        // Wrong type and missing column both yield None.
        assert!(dt.get_f64_column("symbol").is_none());
        assert!(dt.get_f64_column("nonexistent").is_none());
    }

    #[test]
    fn test_bool_column() {
        let schema = Schema::new(vec![Field::new("flag", DataType::Boolean, false)]);
        let mut builder = DataTableBuilder::new(schema);
        builder.add_bool_column(vec![true, false, true]);
        let dt = builder.finish().unwrap();

        let flags = dt.get_bool_column("flag").unwrap();
        assert!(flags.value(0));
        assert!(!flags.value(1));
    }

    #[test]
    fn test_bool_column_ptrs() {
        let schema = Schema::new(vec![Field::new("flag", DataType::Boolean, false)]);
        let mut builder = DataTableBuilder::new(schema);
        builder.add_bool_column(vec![true, false, true]);
        let dt = builder.finish().unwrap();

        let ptrs = dt.column_ptr(0).unwrap();
        assert_eq!(ptrs.stride, 0);
        assert!(matches!(ptrs.data_type, DataType::Boolean));
        // No nulls were supplied, so there is no validity bitmap.
        assert!(ptrs.validity_ptr.is_null());
        unsafe {
            // Arrow packs booleans LSB-first: [true, false, true] -> 0b101.
            assert_eq!(*ptrs.values_ptr & 0b111, 0b101);
        }
    }

    #[test]
    fn test_timestamp_column() {
        let schema = Schema::new(vec![Field::new(
            "ts",
            DataType::Timestamp(TimeUnit::Microsecond, None),
            false,
        )]);
        let mut builder = DataTableBuilder::new(schema);
        builder.add_timestamp_column(vec![1_000_000, 2_000_000, 3_000_000]);
        let dt = builder.finish().unwrap();

        let ts = dt.get_timestamp_column("ts").unwrap();
        assert_eq!(ts.value(0), 1_000_000);
        assert_eq!(ts.value(2), 3_000_000);
    }

    #[test]
    fn test_zero_copy_slice() {
        let dt = sample_datatable();
        let sliced = dt.slice(1, 2);

        assert_eq!(sliced.row_count(), 2);
        assert_eq!(sliced.column_count(), 3);

        let prices = sliced.get_f64_column("price").unwrap();
        assert_eq!(prices.value(0), 101.5);
        assert_eq!(prices.value(1), 99.8);
    }

    #[test]
    fn test_slice_preserves_metadata() {
        let dt = sample_datatable()
            .with_schema_id(7)
            .with_index_col("price".to_string());
        let sliced = dt.slice(1, 1);
        assert_eq!(sliced.row_count(), 1);
        assert_eq!(sliced.schema_id(), Some(7));
        assert_eq!(sliced.index_col(), Some("price"));
    }

    #[test]
    fn test_sliced_column_ptrs_point_at_first_row() {
        let sliced = sample_datatable().slice(1, 2);

        let price = sliced.column_ptr(0).unwrap();
        let volume = sliced.column_ptr(1).unwrap();
        let symbol = sliced.column_ptr(2).unwrap();
        unsafe {
            let f64_ptr = price.values_ptr as *const f64;
            assert_eq!(*f64_ptr, 101.5);
            assert_eq!(*f64_ptr.add(1), 99.8);

            assert_eq!(*(volume.values_ptr as *const i64), 2000);

            // Offsets are absolute positions into the string data buffer.
            let offsets = symbol.offsets_ptr as *const i32;
            let start = *offsets as usize;
            let end = *offsets.add(1) as usize;
            let bytes = std::slice::from_raw_parts(symbol.values_ptr.add(start), end - start);
            assert_eq!(bytes, b"AAPL");
        }
    }

    #[test]
    fn test_empty_datatable() {
        let schema = Schema::new(vec![Field::new("x", DataType::Float64, false)]);
        let mut builder = DataTableBuilder::new(schema);
        builder.add_f64_column(vec![]);
        let dt = builder.finish().unwrap();

        assert!(dt.is_empty());
        assert_eq!(dt.row_count(), 0);
    }

    #[test]
    fn test_display() {
        let dt = sample_datatable();
        let s = format!("{}", dt);
        assert!(s.contains("DataTable"));
        assert!(s.contains("3 rows"));
        assert!(s.contains("price"));
    }

    #[test]
    fn test_type_name() {
        let dt = sample_datatable();
        assert!(dt.type_name().is_none());

        let schema = sample_schema();
        let mut builder = DataTableBuilder::new(schema);
        builder
            .add_f64_column(vec![1.0])
            .add_i64_column(vec![10])
            .add_string_column(vec!["X"]);
        let dt = builder.finish_with_type_name("Candle".to_string()).unwrap();
        assert_eq!(dt.type_name(), Some("Candle"));
        let s = format!("{}", dt);
        assert!(s.starts_with("Candle("));
    }

    #[test]
    fn test_builder_with_fields_and_schema_id() {
        let mut builder =
            DataTableBuilder::with_fields(vec![Field::new("x", DataType::Int64, false)]);
        builder.add_i64_column(vec![1, 2]);
        let dt = builder.finish_with_schema_id(9).unwrap();
        assert_eq!(dt.schema_id(), Some(9));
        assert!(dt.type_name().is_none());
        assert_eq!(dt.get_i64_column("x").unwrap().value(1), 2);
    }

    #[test]
    fn test_builder_schema_mismatch_errors() {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Float64, false),
            Field::new("b", DataType::Int64, false),
        ]);
        let mut builder = DataTableBuilder::new(schema);
        builder.add_f64_column(vec![1.0]);
        // Only one of the two declared columns was supplied.
        assert!(builder.finish().is_err());
    }

    #[test]
    fn test_inner_and_into_inner() {
        let dt = sample_datatable();
        let batch_ref = dt.inner();
        assert_eq!(batch_ref.num_rows(), 3);

        let dt2 = sample_datatable();
        let batch = dt2.into_inner();
        assert_eq!(batch.num_rows(), 3);
    }

    #[test]
    fn test_partial_eq() {
        let dt1 = sample_datatable();
        let dt2 = sample_datatable();
        assert_eq!(dt1, dt2);

        let sliced = dt1.slice(0, 2);
        assert_ne!(sliced, dt2);
    }

    #[test]
    fn test_column_by_name() {
        let dt = sample_datatable();
        assert!(dt.column_by_name("price").is_some());
        assert!(dt.column_by_name("missing").is_none());
    }

    #[test]
    fn test_column_ptrs_constructed() {
        let dt = sample_datatable();
        assert_eq!(dt.column_ptrs().len(), 3);

        let price_ptrs = dt.column_ptr(0).unwrap();
        assert_eq!(price_ptrs.stride, 8);
        assert!(matches!(price_ptrs.data_type, DataType::Float64));
        assert!(!price_ptrs.values_ptr.is_null());

        let vol_ptrs = dt.column_ptr(1).unwrap();
        assert_eq!(vol_ptrs.stride, 8);
        assert!(matches!(vol_ptrs.data_type, DataType::Int64));

        let sym_ptrs = dt.column_ptr(2).unwrap();
        assert_eq!(sym_ptrs.stride, 0);
        assert!(matches!(sym_ptrs.data_type, DataType::Utf8));
        assert!(!sym_ptrs.offsets_ptr.is_null());
    }

    #[test]
    fn test_column_ptrs_f64_read() {
        let dt = sample_datatable();
        let ptrs = dt.column_ptr(0).unwrap();

        unsafe {
            let f64_ptr = ptrs.values_ptr as *const f64;
            assert_eq!(*f64_ptr, 100.0);
            assert_eq!(*f64_ptr.add(1), 101.5);
            assert_eq!(*f64_ptr.add(2), 99.8);
        }
    }

    #[test]
    fn test_column_ptrs_i64_read() {
        let dt = sample_datatable();
        let ptrs = dt.column_ptr(1).unwrap();

        unsafe {
            let i64_ptr = ptrs.values_ptr as *const i64;
            assert_eq!(*i64_ptr, 1000);
            assert_eq!(*i64_ptr.add(1), 2000);
            assert_eq!(*i64_ptr.add(2), 1500);
        }
    }

    #[test]
    fn test_schema_id() {
        let dt = sample_datatable();
        assert!(dt.schema_id().is_none());

        let dt_typed = sample_datatable().with_schema_id(42);
        assert_eq!(dt_typed.schema_id(), Some(42));
    }

    #[test]
    fn test_column_ptr_out_of_bounds() {
        let dt = sample_datatable();
        assert!(dt.column_ptr(99).is_none());
    }
}