1use std::sync::Arc;
8
9use arrow_array::{
10 ArrayRef, BooleanArray, Float64Array, Int64Array, StringArray, TimestampMicrosecondArray,
11};
12use arrow_schema::{DataType, Field, TimeUnit};
13use shape_abi_v1::binary_format::{
14 BinaryDataHeader, BinaryFormatError, ColumnDescriptor, ColumnType, StringEntry,
15};
16use shape_ast::error::{Result, ShapeError};
17use shape_value::DataTableBuilder;
18
19pub fn read_binary_to_datatable(data: &[u8]) -> Result<shape_value::DataTable> {
21 let header = BinaryDataHeader::from_bytes(data).map_err(format_error)?;
22 let column_count = header.get_column_count() as usize;
23 let row_count = header.get_row_count() as usize;
24
25 let desc_start = BinaryDataHeader::SIZE;
27 let desc_end = desc_start + column_count * ColumnDescriptor::SIZE;
28 if data.len() < desc_end {
29 return Err(format_error(BinaryFormatError::InsufficientData {
30 expected: desc_end,
31 actual: data.len(),
32 }));
33 }
34
35 let mut descriptors = Vec::with_capacity(column_count);
36 for i in 0..column_count {
37 let offset = desc_start + i * ColumnDescriptor::SIZE;
38 let desc =
39 unsafe { std::ptr::read_unaligned(data[offset..].as_ptr() as *const ColumnDescriptor) };
40 descriptors.push(desc);
41 }
42
43 let string_table_start = desc_end;
45
46 let mut names = Vec::with_capacity(column_count);
48 for desc in &descriptors {
49 let name_off = desc.name_offset as usize;
50 let abs_off = string_table_start + name_off;
51 let end = data[abs_off..]
53 .iter()
54 .position(|&b| b == 0)
55 .ok_or_else(|| {
56 format_error(BinaryFormatError::ColumnNameNotFound {
57 offset: desc.name_offset,
58 })
59 })?;
60 let name = std::str::from_utf8(&data[abs_off..abs_off + end])
61 .map_err(|_| format_error(BinaryFormatError::InvalidUtf8))?
62 .to_string();
63 names.push(name);
64 }
65
66 let mut fields = Vec::with_capacity(column_count);
68 let mut arrays: Vec<ArrayRef> = Vec::with_capacity(column_count);
69
70 for (i, desc) in descriptors.iter().enumerate() {
71 let col_type = desc
72 .column_type()
73 .ok_or_else(|| format_error(BinaryFormatError::InvalidColumnType(desc.data_type)))?;
74 let data_offset = desc.data_offset as usize;
75 let data_len = desc.data_len as usize;
76 let col_data = &data[data_offset..data_offset + data_len];
77
78 match col_type {
79 ColumnType::Float64 => {
80 fields.push(Field::new(&names[i], DataType::Float64, desc.is_nullable()));
81 let values = read_f64_column(col_data, row_count)?;
82 arrays.push(Arc::new(Float64Array::from(values)) as ArrayRef);
83 }
84 ColumnType::Int64 => {
85 fields.push(Field::new(&names[i], DataType::Int64, desc.is_nullable()));
86 let values = read_i64_column(col_data, row_count)?;
87 arrays.push(Arc::new(Int64Array::from(values)) as ArrayRef);
88 }
89 ColumnType::String => {
90 fields.push(Field::new(&names[i], DataType::Utf8, desc.is_nullable()));
91 let values = read_string_column(col_data, row_count)?;
92 let refs: Vec<&str> = values.iter().map(|s| s.as_str()).collect();
93 arrays.push(Arc::new(StringArray::from(refs)) as ArrayRef);
94 }
95 ColumnType::Bool => {
96 fields.push(Field::new(&names[i], DataType::Boolean, desc.is_nullable()));
97 let values = read_bool_column(col_data, row_count)?;
98 arrays.push(Arc::new(BooleanArray::from(values)) as ArrayRef);
99 }
100 ColumnType::Timestamp => {
101 fields.push(Field::new(
102 &names[i],
103 DataType::Timestamp(TimeUnit::Microsecond, None),
104 desc.is_nullable(),
105 ));
106 let values = read_i64_column(col_data, row_count)?;
107 arrays.push(Arc::new(TimestampMicrosecondArray::from(values)) as ArrayRef);
108 }
109 }
110 }
111
112 let mut builder = DataTableBuilder::with_fields(fields);
113 for array in arrays {
114 builder.add_column(array);
115 }
116 builder.finish().map_err(|e| ShapeError::RuntimeError {
117 message: format!("Failed to build DataTable: {}", e),
118 location: None,
119 })
120}
121
122fn read_f64_column(data: &[u8], count: usize) -> Result<Vec<f64>> {
124 let expected_size = count * 8;
125 if data.len() < expected_size {
126 return Err(format_error(BinaryFormatError::InsufficientData {
127 expected: expected_size,
128 actual: data.len(),
129 }));
130 }
131
132 let mut values = Vec::with_capacity(count);
133 for i in 0..count {
134 let offset = i * 8;
135 let value = f64::from_le_bytes([
136 data[offset],
137 data[offset + 1],
138 data[offset + 2],
139 data[offset + 3],
140 data[offset + 4],
141 data[offset + 5],
142 data[offset + 6],
143 data[offset + 7],
144 ]);
145 values.push(value);
146 }
147 Ok(values)
148}
149
150fn read_i64_column(data: &[u8], count: usize) -> Result<Vec<i64>> {
152 let expected_size = count * 8;
153 if data.len() < expected_size {
154 return Err(format_error(BinaryFormatError::InsufficientData {
155 expected: expected_size,
156 actual: data.len(),
157 }));
158 }
159
160 let mut values = Vec::with_capacity(count);
161 for i in 0..count {
162 let offset = i * 8;
163 let value = i64::from_le_bytes([
164 data[offset],
165 data[offset + 1],
166 data[offset + 2],
167 data[offset + 3],
168 data[offset + 4],
169 data[offset + 5],
170 data[offset + 6],
171 data[offset + 7],
172 ]);
173 values.push(value);
174 }
175 Ok(values)
176}
177
178fn read_string_column(data: &[u8], count: usize) -> Result<Vec<String>> {
180 let entries_size = count * StringEntry::SIZE;
184 if data.len() < entries_size {
185 return Err(format_error(BinaryFormatError::InsufficientData {
186 expected: entries_size,
187 actual: data.len(),
188 }));
189 }
190
191 let pool_start = entries_size;
192 let pool = &data[pool_start..];
193
194 let mut values = Vec::with_capacity(count);
195 for i in 0..count {
196 let entry_offset = i * StringEntry::SIZE;
197 let entry = unsafe {
198 std::ptr::read_unaligned(data[entry_offset..].as_ptr() as *const StringEntry)
199 };
200
201 let str_start = entry.offset as usize;
202 let str_end = str_start + entry.length as usize;
203
204 if str_end > pool.len() {
205 return Err(format_error(BinaryFormatError::InsufficientData {
206 expected: str_end,
207 actual: pool.len(),
208 }));
209 }
210
211 let s = std::str::from_utf8(&pool[str_start..str_end])
212 .map_err(|_| format_error(BinaryFormatError::InvalidUtf8))?
213 .to_string();
214
215 values.push(s);
216 }
217
218 Ok(values)
219}
220
221fn read_bool_column(data: &[u8], count: usize) -> Result<Vec<bool>> {
223 if data.len() < count {
224 return Err(format_error(BinaryFormatError::InsufficientData {
225 expected: count,
226 actual: data.len(),
227 }));
228 }
229
230 let values: Vec<bool> = data[..count].iter().map(|&b| b != 0).collect();
231 Ok(values)
232}
233
234fn format_error(e: BinaryFormatError) -> ShapeError {
236 ShapeError::RuntimeError {
237 message: format!("Binary format error: {}", e),
238 location: None,
239 }
240}
241
#[cfg(test)]
mod tests {
    use super::*;

    /// Little-endian `f64` bytes decode back to the original values.
    #[test]
    fn test_read_f64_column() {
        let expected = vec![1.0_f64, 2.0, 3.0];
        let mut bytes = Vec::new();
        for value in &expected {
            bytes.extend_from_slice(&value.to_le_bytes());
        }

        let decoded = read_f64_column(&bytes, 3).unwrap();
        assert_eq!(decoded, expected);
    }

    /// Little-endian `i64` bytes decode back to the original values.
    #[test]
    fn test_read_i64_column() {
        let expected = vec![100_i64, 200, 300];
        let mut bytes = Vec::new();
        for value in &expected {
            bytes.extend_from_slice(&value.to_le_bytes());
        }

        let decoded = read_i64_column(&bytes, 3).unwrap();
        assert_eq!(decoded, expected);
    }

    /// Each byte maps to one boolean: zero is `false`, non-zero is `true`.
    #[test]
    fn test_read_bool_column() {
        let bytes = [1u8, 0, 1, 1, 0];
        let decoded = read_bool_column(&bytes, 5).unwrap();
        assert_eq!(decoded, vec![true, false, true, true, false]);
    }

    /// Builds a two-column blob by hand (header, descriptors, name table,
    /// aligned column data) and checks that it decodes to the expected table.
    #[test]
    fn test_read_binary_to_datatable() {
        use shape_abi_v1::binary_format::{
            BinaryDataHeader, ColumnDescriptor, ColumnType, DATA_ALIGNMENT, align_up,
        };

        /// Copies `desc` into `blob` at byte position `at`.
        fn put_descriptor(blob: &mut [u8], at: usize, desc: ColumnDescriptor) {
            let slot = &mut blob[at..at + ColumnDescriptor::SIZE];
            // SAFETY: `slot` is `ColumnDescriptor::SIZE` writable bytes and
            // `write_unaligned` has no alignment requirement. NOTE(review):
            // assumes `ColumnDescriptor::SIZE` equals
            // `size_of::<ColumnDescriptor>()` — confirm against the ABI crate.
            unsafe {
                std::ptr::write_unaligned(slot.as_mut_ptr() as *mut ColumnDescriptor, desc);
            }
        }

        let prices = [100.5_f64, 200.75, 300.0];
        let volumes = [1000_i64, 2000, 3000];

        let row_count: u64 = 3;
        let column_count: u16 = 2;

        // Name table: "price" at offset 0, "volume" right after its NUL.
        let string_table = b"price\0volume\0";
        let name_offset_price: u16 = 0;
        let name_offset_volume: u16 = 6;

        // Section boundaries, in file order.
        let desc_start = BinaryDataHeader::SIZE;
        let desc_end = desc_start + (column_count as usize) * ColumnDescriptor::SIZE;
        let string_table_end = desc_end + string_table.len();
        let data_section_start = align_up(string_table_end, DATA_ALIGNMENT);

        let price_data_offset = data_section_start;
        let price_data_len = prices.len() * 8;
        let volume_data_offset = price_data_offset + price_data_len;
        let volume_data_len = volumes.len() * 8;

        let mut blob = vec![0u8; volume_data_offset + volume_data_len];

        // Header.
        let header = BinaryDataHeader::new(column_count, row_count, false, false);
        blob[..BinaryDataHeader::SIZE].copy_from_slice(&header.to_bytes());

        // Descriptors.
        put_descriptor(
            &mut blob,
            desc_start,
            ColumnDescriptor::new(
                name_offset_price,
                ColumnType::Float64,
                price_data_offset as u64,
                price_data_len as u64,
                false,
            ),
        );
        put_descriptor(
            &mut blob,
            desc_start + ColumnDescriptor::SIZE,
            ColumnDescriptor::new(
                name_offset_volume,
                ColumnType::Int64,
                volume_data_offset as u64,
                volume_data_len as u64,
                false,
            ),
        );

        // Name table.
        blob[desc_end..string_table_end].copy_from_slice(string_table);

        // Column payloads.
        for (row, price) in prices.iter().enumerate() {
            let at = price_data_offset + row * 8;
            blob[at..at + 8].copy_from_slice(&price.to_le_bytes());
        }
        for (row, volume) in volumes.iter().enumerate() {
            let at = volume_data_offset + row * 8;
            blob[at..at + 8].copy_from_slice(&volume.to_le_bytes());
        }

        let dt = read_binary_to_datatable(&blob).unwrap();
        assert_eq!(dt.row_count(), 3);
        assert_eq!(dt.column_names(), vec!["price", "volume"]);

        let price_col = dt.get_f64_column("price").unwrap();
        assert_eq!(price_col.value(0), 100.5);
        assert_eq!(price_col.value(1), 200.75);
        assert_eq!(price_col.value(2), 300.0);

        let volume_col = dt.get_i64_column("volume").unwrap();
        assert_eq!(volume_col.value(0), 1000);
        assert_eq!(volume_col.value(1), 2000);
        assert_eq!(volume_col.value(2), 3000);
    }
}