// shape_runtime/binary_reader.rs
1//! Binary reader for converting binary columnar data to DataTable
2//!
3//! Reads binary data produced by plugins and converts it to DataTable.
4//! Helper functions for reading typed columns from binary format are available
5//! for use by the DataTable migration.
6
7use std::sync::Arc;
8
9use arrow_array::{
10    ArrayRef, BooleanArray, Float64Array, Int64Array, StringArray, TimestampMicrosecondArray,
11};
12use arrow_schema::{DataType, Field, TimeUnit};
13use shape_abi_v1::binary_format::{
14    BinaryDataHeader, BinaryFormatError, ColumnDescriptor, ColumnType, StringEntry,
15};
16use shape_ast::error::{Result, ShapeError};
17use shape_value::DataTableBuilder;
18
19/// Read binary columnar data and convert to a DataTable.
20pub fn read_binary_to_datatable(data: &[u8]) -> Result<shape_value::DataTable> {
21    let header = BinaryDataHeader::from_bytes(data).map_err(format_error)?;
22    let column_count = header.get_column_count() as usize;
23    let row_count = header.get_row_count() as usize;
24
25    // Read column descriptors (starts after 16-byte header)
26    let desc_start = BinaryDataHeader::SIZE;
27    let desc_end = desc_start + column_count * ColumnDescriptor::SIZE;
28    if data.len() < desc_end {
29        return Err(format_error(BinaryFormatError::InsufficientData {
30            expected: desc_end,
31            actual: data.len(),
32        }));
33    }
34
35    let mut descriptors = Vec::with_capacity(column_count);
36    for i in 0..column_count {
37        let offset = desc_start + i * ColumnDescriptor::SIZE;
38        let desc =
39            unsafe { std::ptr::read_unaligned(data[offset..].as_ptr() as *const ColumnDescriptor) };
40        descriptors.push(desc);
41    }
42
43    // String table starts after descriptors
44    let string_table_start = desc_end;
45
46    // Read column names from string table
47    let mut names = Vec::with_capacity(column_count);
48    for desc in &descriptors {
49        let name_off = desc.name_offset as usize;
50        let abs_off = string_table_start + name_off;
51        // Find null terminator
52        let end = data[abs_off..]
53            .iter()
54            .position(|&b| b == 0)
55            .ok_or_else(|| {
56                format_error(BinaryFormatError::ColumnNameNotFound {
57                    offset: desc.name_offset,
58                })
59            })?;
60        let name = std::str::from_utf8(&data[abs_off..abs_off + end])
61            .map_err(|_| format_error(BinaryFormatError::InvalidUtf8))?
62            .to_string();
63        names.push(name);
64    }
65
66    // Build Arrow fields and arrays
67    let mut fields = Vec::with_capacity(column_count);
68    let mut arrays: Vec<ArrayRef> = Vec::with_capacity(column_count);
69
70    for (i, desc) in descriptors.iter().enumerate() {
71        let col_type = desc
72            .column_type()
73            .ok_or_else(|| format_error(BinaryFormatError::InvalidColumnType(desc.data_type)))?;
74        let data_offset = desc.data_offset as usize;
75        let data_len = desc.data_len as usize;
76        let col_data = &data[data_offset..data_offset + data_len];
77
78        match col_type {
79            ColumnType::Float64 => {
80                fields.push(Field::new(&names[i], DataType::Float64, desc.is_nullable()));
81                let values = read_f64_column(col_data, row_count)?;
82                arrays.push(Arc::new(Float64Array::from(values)) as ArrayRef);
83            }
84            ColumnType::Int64 => {
85                fields.push(Field::new(&names[i], DataType::Int64, desc.is_nullable()));
86                let values = read_i64_column(col_data, row_count)?;
87                arrays.push(Arc::new(Int64Array::from(values)) as ArrayRef);
88            }
89            ColumnType::String => {
90                fields.push(Field::new(&names[i], DataType::Utf8, desc.is_nullable()));
91                let values = read_string_column(col_data, row_count)?;
92                let refs: Vec<&str> = values.iter().map(|s| s.as_str()).collect();
93                arrays.push(Arc::new(StringArray::from(refs)) as ArrayRef);
94            }
95            ColumnType::Bool => {
96                fields.push(Field::new(&names[i], DataType::Boolean, desc.is_nullable()));
97                let values = read_bool_column(col_data, row_count)?;
98                arrays.push(Arc::new(BooleanArray::from(values)) as ArrayRef);
99            }
100            ColumnType::Timestamp => {
101                fields.push(Field::new(
102                    &names[i],
103                    DataType::Timestamp(TimeUnit::Microsecond, None),
104                    desc.is_nullable(),
105                ));
106                let values = read_i64_column(col_data, row_count)?;
107                arrays.push(Arc::new(TimestampMicrosecondArray::from(values)) as ArrayRef);
108            }
109        }
110    }
111
112    let mut builder = DataTableBuilder::with_fields(fields);
113    for array in arrays {
114        builder.add_column(array);
115    }
116    builder.finish().map_err(|e| ShapeError::RuntimeError {
117        message: format!("Failed to build DataTable: {}", e),
118        location: None,
119    })
120}
121
122/// Read f64 column data
123fn read_f64_column(data: &[u8], count: usize) -> Result<Vec<f64>> {
124    let expected_size = count * 8;
125    if data.len() < expected_size {
126        return Err(format_error(BinaryFormatError::InsufficientData {
127            expected: expected_size,
128            actual: data.len(),
129        }));
130    }
131
132    let mut values = Vec::with_capacity(count);
133    for i in 0..count {
134        let offset = i * 8;
135        let value = f64::from_le_bytes([
136            data[offset],
137            data[offset + 1],
138            data[offset + 2],
139            data[offset + 3],
140            data[offset + 4],
141            data[offset + 5],
142            data[offset + 6],
143            data[offset + 7],
144        ]);
145        values.push(value);
146    }
147    Ok(values)
148}
149
150/// Read i64 column data
151fn read_i64_column(data: &[u8], count: usize) -> Result<Vec<i64>> {
152    let expected_size = count * 8;
153    if data.len() < expected_size {
154        return Err(format_error(BinaryFormatError::InsufficientData {
155            expected: expected_size,
156            actual: data.len(),
157        }));
158    }
159
160    let mut values = Vec::with_capacity(count);
161    for i in 0..count {
162        let offset = i * 8;
163        let value = i64::from_le_bytes([
164            data[offset],
165            data[offset + 1],
166            data[offset + 2],
167            data[offset + 3],
168            data[offset + 4],
169            data[offset + 5],
170            data[offset + 6],
171            data[offset + 7],
172        ]);
173        values.push(value);
174    }
175    Ok(values)
176}
177
/// Read string column data
///
/// Layout: `count` fixed-size `StringEntry` records, each an (offset, length)
/// pair, followed by a shared string pool that the entries index into.
///
/// # Errors
/// Returns `InsufficientData` when the entry array or a referenced string
/// range falls outside the buffer, and `InvalidUtf8` when referenced pool
/// bytes are not valid UTF-8.
fn read_string_column(data: &[u8], count: usize) -> Result<Vec<String>> {
    // String column format:
    // [StringEntry array (count * 8 bytes)] [String pool]

    let entries_size = count * StringEntry::SIZE;
    if data.len() < entries_size {
        return Err(format_error(BinaryFormatError::InsufficientData {
            expected: entries_size,
            actual: data.len(),
        }));
    }

    // Everything after the entry array is the string pool.
    let pool_start = entries_size;
    let pool = &data[pool_start..];

    let mut values = Vec::with_capacity(count);
    for i in 0..count {
        let entry_offset = i * StringEntry::SIZE;
        // SAFETY: `entry_offset + StringEntry::SIZE <= entries_size <= data.len()`
        // by the length check above, so the read stays in bounds;
        // `read_unaligned` tolerates the unaligned source pointer.
        let entry = unsafe {
            std::ptr::read_unaligned(data[entry_offset..].as_ptr() as *const StringEntry)
        };

        // Entry offsets are relative to the pool, not to `data`.
        // NOTE(review): assumes `offset + length` cannot overflow usize (true
        // for u32-sized fields on 64-bit targets) — confirm the field widths.
        let str_start = entry.offset as usize;
        let str_end = str_start + entry.length as usize;

        if str_end > pool.len() {
            return Err(format_error(BinaryFormatError::InsufficientData {
                expected: str_end,
                actual: pool.len(),
            }));
        }

        let s = std::str::from_utf8(&pool[str_start..str_end])
            .map_err(|_| format_error(BinaryFormatError::InvalidUtf8))?
            .to_string();

        values.push(s);
    }

    Ok(values)
}
220
221/// Read bool column data
222fn read_bool_column(data: &[u8], count: usize) -> Result<Vec<bool>> {
223    if data.len() < count {
224        return Err(format_error(BinaryFormatError::InsufficientData {
225            expected: count,
226            actual: data.len(),
227        }));
228    }
229
230    let values: Vec<bool> = data[..count].iter().map(|&b| b != 0).collect();
231    Ok(values)
232}
233
234/// Convert BinaryFormatError to ShapeError
235fn format_error(e: BinaryFormatError) -> ShapeError {
236    ShapeError::RuntimeError {
237        message: format!("Binary format error: {}", e),
238        location: None,
239    }
240}
241
#[cfg(test)]
mod tests {
    use super::*;

    // Round-trips three little-endian f64s through the column reader.
    #[test]
    fn test_read_f64_column() {
        let data: Vec<u8> = [1.0_f64, 2.0, 3.0]
            .iter()
            .flat_map(|v| v.to_le_bytes())
            .collect();
        let values = read_f64_column(&data, 3).unwrap();
        assert_eq!(values, vec![1.0, 2.0, 3.0]);
    }

    // Round-trips three little-endian i64s through the column reader.
    #[test]
    fn test_read_i64_column() {
        let data: Vec<u8> = [100_i64, 200, 300]
            .iter()
            .flat_map(|v| v.to_le_bytes())
            .collect();
        let values = read_i64_column(&data, 3).unwrap();
        assert_eq!(values, vec![100, 200, 300]);
    }

    // One byte per bool: zero maps to false, nonzero to true.
    #[test]
    fn test_read_bool_column() {
        let data = vec![1u8, 0, 1, 1, 0];
        let values = read_bool_column(&data, 5).unwrap();
        assert_eq!(values, vec![true, false, true, true, false]);
    }

    // End-to-end: hand-assembles a full binary blob (header, descriptors,
    // string table, aligned data section) and verifies the parsed DataTable.
    #[test]
    fn test_read_binary_to_datatable() {
        use shape_abi_v1::binary_format::{
            BinaryDataHeader, ColumnDescriptor, ColumnType, DATA_ALIGNMENT, align_up,
        };

        let row_count: u64 = 3;
        let column_count: u16 = 2;

        // Build string table: "price\0volume\0"
        let string_table = b"price\0volume\0";
        let name_offset_price: u16 = 0;
        let name_offset_volume: u16 = 6; // byte index just past "price\0"

        // Calculate offsets mirroring the reader's layout expectations.
        let desc_start = BinaryDataHeader::SIZE;
        let desc_end = desc_start + (column_count as usize) * ColumnDescriptor::SIZE;
        let string_table_end = desc_end + string_table.len();
        // Data section is aligned up so column payloads start on a boundary.
        let data_section_start = align_up(string_table_end, DATA_ALIGNMENT);

        // Column 0: f64 "price" — 3 rows * 8 bytes = 24 bytes
        let price_data_offset = data_section_start;
        let price_data_len = 3 * 8;
        // Column 1: i64 "volume" — 3 rows * 8 bytes = 24 bytes
        let volume_data_offset = price_data_offset + price_data_len;
        let volume_data_len = 3 * 8;

        let total_size = volume_data_offset + volume_data_len;
        let mut blob = vec![0u8; total_size];

        // Write header
        let header = BinaryDataHeader::new(column_count, row_count, false, false);
        blob[..BinaryDataHeader::SIZE].copy_from_slice(&header.to_bytes());

        // Write column descriptors (offsets are absolute within the blob).
        let desc0 = ColumnDescriptor::new(
            name_offset_price,
            ColumnType::Float64,
            price_data_offset as u64,
            price_data_len as u64,
            false,
        );
        let desc1 = ColumnDescriptor::new(
            name_offset_volume,
            ColumnType::Int64,
            volume_data_offset as u64,
            volume_data_len as u64,
            false,
        );
        // SAFETY: blob is sized to hold both descriptors at these offsets, and
        // write_unaligned tolerates the unaligned destination pointers.
        unsafe {
            let p0 = blob[desc_start..].as_mut_ptr() as *mut ColumnDescriptor;
            std::ptr::write_unaligned(p0, desc0);
            let p1 =
                blob[desc_start + ColumnDescriptor::SIZE..].as_mut_ptr() as *mut ColumnDescriptor;
            std::ptr::write_unaligned(p1, desc1);
        }

        // Write string table
        blob[desc_end..desc_end + string_table.len()].copy_from_slice(string_table);

        // Write price column data: [100.5, 200.75, 300.0]
        let prices = [100.5_f64, 200.75, 300.0];
        for (i, v) in prices.iter().enumerate() {
            let off = price_data_offset + i * 8;
            blob[off..off + 8].copy_from_slice(&v.to_le_bytes());
        }

        // Write volume column data: [1000, 2000, 3000]
        let volumes = [1000_i64, 2000, 3000];
        for (i, v) in volumes.iter().enumerate() {
            let off = volume_data_offset + i * 8;
            blob[off..off + 8].copy_from_slice(&v.to_le_bytes());
        }

        // Parse and verify
        let dt = read_binary_to_datatable(&blob).unwrap();
        assert_eq!(dt.row_count(), 3);
        assert_eq!(dt.column_names(), vec!["price", "volume"]);

        let price_col = dt.get_f64_column("price").unwrap();
        assert_eq!(price_col.value(0), 100.5);
        assert_eq!(price_col.value(1), 200.75);
        assert_eq!(price_col.value(2), 300.0);

        let volume_col = dt.get_i64_column("volume").unwrap();
        assert_eq!(volume_col.value(0), 1000);
        assert_eq!(volume_col.value(1), 2000);
        assert_eq!(volume_col.value(2), 3000);
    }
}
362}