1use arrow_array::builder::{
8 ArrayBuilder, BooleanBuilder, Float64Builder, Int64Builder, StringBuilder,
9};
10use arrow_array::{ArrayRef, RecordBatch};
11use arrow_schema::{ArrowError, DataType, Field, Schema, SchemaRef};
12use std::sync::Arc;
13
/// A single non-null cell value that can be appended to a column builder.
///
/// Each variant corresponds to one of the Arrow data types this module
/// supports (`Int64`, `Float64`, `Boolean`, `Utf8`).
///
/// Equality is structural: two values are equal when they are the same
/// variant holding equal payloads. Because the `Float64` payload is an
/// `f64`, `Float64(NAN)` is not equal to itself.
#[derive(Debug, Clone, PartialEq)]
pub enum ArrowValue {
    /// A 64-bit signed integer.
    Int64(i64),
    /// A 64-bit floating point number.
    Float64(f64),
    /// A boolean.
    Boolean(bool),
    /// An owned UTF-8 string.
    Utf8(String),
}
29
30pub struct ArrowBatchBuilder {
37 schema: SchemaRef,
38 builders: Vec<Box<dyn ArrayBuilder>>,
39 len: usize,
40}
41
42impl ArrowBatchBuilder {
43 pub fn new(schema: SchemaRef) -> Result<Self, ArrowError> {
47 let builders = build_builders(&schema)?;
48 Ok(Self {
49 schema,
50 builders,
51 len: 0,
52 })
53 }
54
55 pub fn schema(&self) -> &SchemaRef {
57 &self.schema
58 }
59
60 pub fn len(&self) -> usize {
62 self.len
63 }
64
65 pub fn is_empty(&self) -> bool {
67 self.len == 0
68 }
69
70 pub fn push_row(&mut self, values: &[ArrowValue]) -> Result<(), ArrowError> {
75 if values.len() != self.builders.len() {
76 return Err(ArrowError::SchemaError(
77 "row length does not match schema".to_string(),
78 ));
79 }
80
81 for (value, builder) in values.iter().zip(self.builders.iter_mut()) {
82 append_value(builder.as_mut(), value)?;
83 }
84
85 self.len += 1;
86 Ok(())
87 }
88
89 pub fn finish(&mut self) -> Result<RecordBatch, ArrowError> {
93 let arrays = self
94 .builders
95 .iter_mut()
96 .map(|builder| builder.finish())
97 .collect::<Vec<ArrayRef>>();
98
99 let batch = RecordBatch::try_new(self.schema.clone(), arrays)?;
100 self.builders = build_builders(&self.schema)?;
101 self.len = 0;
102 Ok(batch)
103 }
104}
105
106#[derive(Debug, Default)]
108pub struct RecordBatchCollector {
109 batches: Vec<RecordBatch>,
110}
111
112impl RecordBatchCollector {
113 pub fn new() -> Self {
115 Self {
116 batches: Vec::new(),
117 }
118 }
119
120 pub fn push(&mut self, batch: RecordBatch) {
122 self.batches.push(batch);
123 }
124
125 pub fn take(&mut self) -> Vec<RecordBatch> {
127 std::mem::take(&mut self.batches)
128 }
129
130 pub fn batches(&self) -> &[RecordBatch] {
132 &self.batches
133 }
134}
135
136pub fn schema_from_fields(fields: Vec<Field>) -> SchemaRef {
138 Arc::new(Schema::new(fields))
139}
140
141fn build_builders(schema: &SchemaRef) -> Result<Vec<Box<dyn ArrayBuilder>>, ArrowError> {
142 schema
143 .fields()
144 .iter()
145 .map(|field| builder_for_field(field))
146 .collect()
147}
148
149fn builder_for_field(field: &Field) -> Result<Box<dyn ArrayBuilder>, ArrowError> {
150 match field.data_type() {
151 DataType::Int64 => Ok(Box::new(Int64Builder::new())),
152 DataType::Float64 => Ok(Box::new(Float64Builder::new())),
153 DataType::Boolean => Ok(Box::new(BooleanBuilder::new())),
154 DataType::Utf8 => Ok(Box::new(StringBuilder::new())),
155 other => Err(ArrowError::SchemaError(format!(
156 "unsupported data type {other:?}"
157 ))),
158 }
159}
160
161fn append_value(builder: &mut dyn ArrayBuilder, value: &ArrowValue) -> Result<(), ArrowError> {
162 if let (Some(builder), ArrowValue::Int64(value)) =
163 (builder.as_any_mut().downcast_mut::<Int64Builder>(), value)
164 {
165 builder.append_value(*value);
166 return Ok(());
167 }
168
169 if let (Some(builder), ArrowValue::Float64(value)) =
170 (builder.as_any_mut().downcast_mut::<Float64Builder>(), value)
171 {
172 builder.append_value(*value);
173 return Ok(());
174 }
175
176 if let (Some(builder), ArrowValue::Boolean(value)) =
177 (builder.as_any_mut().downcast_mut::<BooleanBuilder>(), value)
178 {
179 builder.append_value(*value);
180 return Ok(());
181 }
182
183 if let (Some(builder), ArrowValue::Utf8(value)) =
184 (builder.as_any_mut().downcast_mut::<StringBuilder>(), value)
185 {
186 builder.append_value(value);
187 return Ok(());
188 }
189
190 Err(ArrowError::SchemaError(
191 "value does not match builder type".to_string(),
192 ))
193}