Skip to main content

sci_form/transport/
arrow.rs

1//! Arrow-compatible columnar memory layouts for zero-copy data transfer.
2//!
3//! Provides flat, typed-array compatible buffers that can be directly
4//! transferred to JavaScript TypedArrays (Float64Array, Int32Array, etc.)
5//! without serialization overhead.
6//!
7//! **Endianness**: All buffers use **native byte order** (little-endian on x86/ARM,
8//! the dominant platforms for WASM and server targets). This is consistent with
9//! the Arrow IPC specification which requires little-endian. Big-endian hosts
10//! would need byte-swapping before interop with Arrow consumers.
11//!
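//! Column organization is Arrow-like for interoperability at the schema level,
//! but this module does not emit raw Arrow IPC byte streams. An Arrow column is
//! made of the buffers listed below; the column types defined in this file carry
//! only the values buffer (plus a `shape` vector):
//! - Values buffer: contiguous typed array
//! - Offsets buffer: for variable-length data (strings, nested arrays) — not represented here
//! - Null bitmap: optional validity buffer — not represented here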
17
18use serde::{Deserialize, Serialize};
19
20/// A columnar buffer of f64 values with shape metadata.
21///
22/// Designed for zero-copy transfer to JavaScript `Float64Array`.
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct Float64Column {
25    /// Column name/label.
26    pub name: String,
27    /// Flat f64 values.
28    pub values: Vec<f64>,
29    /// Shape of the data: e.g., \[n_rows\] for 1D, \[n_rows, n_cols\] for 2D.
30    pub shape: Vec<usize>,
31}
32
33/// A columnar buffer of i32 values.
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct Int32Column {
36    /// Column name/label.
37    pub name: String,
38    /// Flat i32 values.
39    pub values: Vec<i32>,
40    /// Shape of the data.
41    pub shape: Vec<usize>,
42}
43
44/// A columnar buffer of u8 values (for element arrays, flags, etc.).
45#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct Uint8Column {
47    /// Column name/label.
48    pub name: String,
49    /// Flat u8 values.
50    pub values: Vec<u8>,
51    /// Shape of the data.
52    pub shape: Vec<usize>,
53}
54
55/// A record batch: a named collection of typed columns.
56///
57/// This is the primary unit for zero-copy data transfer.
58/// Analogous to Arrow RecordBatch.
59#[derive(Debug, Clone, Serialize, Deserialize, Default)]
60pub struct RecordBatch {
61    /// Number of rows in the batch.
62    pub num_rows: usize,
63    /// Schema: column names and types.
64    pub schema: Vec<ColumnSchema>,
65    /// Float64 columns.
66    pub float_columns: Vec<Float64Column>,
67    /// Int32 columns.
68    pub int_columns: Vec<Int32Column>,
69    /// Uint8 columns.
70    pub uint8_columns: Vec<Uint8Column>,
71}
72
73/// Schema entry for a column.
74#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct ColumnSchema {
76    pub name: String,
77    pub dtype: DataType,
78    pub shape: Vec<usize>,
79}
80
81/// Supported data types.
82#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
83pub enum DataType {
84    Float64,
85    Int32,
86    Uint8,
87}
88
89impl RecordBatch {
90    /// Create an empty batch.
91    pub fn new() -> Self {
92        Self {
93            num_rows: 0,
94            schema: Vec::new(),
95            float_columns: Vec::new(),
96            int_columns: Vec::new(),
97            uint8_columns: Vec::new(),
98        }
99    }
100
101    /// Add a Float64 column.
102    pub fn add_float64(&mut self, name: &str, values: Vec<f64>, shape: Vec<usize>) {
103        if !shape.is_empty() {
104            self.num_rows = shape[0];
105        }
106        self.schema.push(ColumnSchema {
107            name: name.to_string(),
108            dtype: DataType::Float64,
109            shape: shape.clone(),
110        });
111        self.float_columns.push(Float64Column {
112            name: name.to_string(),
113            values,
114            shape,
115        });
116    }
117
118    /// Add an Int32 column.
119    pub fn add_int32(&mut self, name: &str, values: Vec<i32>, shape: Vec<usize>) {
120        if !shape.is_empty() {
121            self.num_rows = shape[0];
122        }
123        self.schema.push(ColumnSchema {
124            name: name.to_string(),
125            dtype: DataType::Int32,
126            shape: shape.clone(),
127        });
128        self.int_columns.push(Int32Column {
129            name: name.to_string(),
130            values,
131            shape,
132        });
133    }
134
135    /// Add a Uint8 column.
136    pub fn add_uint8(&mut self, name: &str, values: Vec<u8>, shape: Vec<usize>) {
137        if !shape.is_empty() {
138            self.num_rows = shape[0];
139        }
140        self.schema.push(ColumnSchema {
141            name: name.to_string(),
142            dtype: DataType::Uint8,
143            shape: shape.clone(),
144        });
145        self.uint8_columns.push(Uint8Column {
146            name: name.to_string(),
147            values,
148            shape,
149        });
150    }
151
152    /// Total byte size of all buffers (for transfer cost estimation).
153    pub fn byte_size(&self) -> usize {
154        let f64_bytes: usize = self.float_columns.iter().map(|c| c.values.len() * 8).sum();
155        let i32_bytes: usize = self.int_columns.iter().map(|c| c.values.len() * 4).sum();
156        let u8_bytes: usize = self.uint8_columns.iter().map(|c| c.values.len()).sum();
157        f64_bytes + i32_bytes + u8_bytes
158    }
159
160    /// Number of columns.
161    pub fn num_columns(&self) -> usize {
162        self.schema.len()
163    }
164}
165
166/// Pack conformer results into an Arrow-compatible batch.
167///
168/// Columns: elements (u8), coords_x/y/z (f64), success (i32), time_ms (f64)
169pub fn pack_conformers(results: &[crate::ConformerResult]) -> RecordBatch {
170    // Validate shape consistency: coords.len() == num_atoms * 3, elements.len() == num_atoms
171    let results: Vec<&crate::ConformerResult> = results
172        .iter()
173        .filter(|r| r.coords.len() == r.num_atoms * 3 && r.elements.len() == r.num_atoms)
174        .collect();
175
176    let n = results.len();
177    let mut batch = RecordBatch::new();
178
179    // Success flag (1=ok, 0=failed)
180    let success: Vec<i32> = results
181        .iter()
182        .map(|r| if r.error.is_none() { 1 } else { 0 })
183        .collect();
184    batch.add_int32("success", success, vec![n]);
185
186    // Number of atoms per molecule
187    let num_atoms: Vec<i32> = results.iter().map(|r| r.num_atoms as i32).collect();
188    batch.add_int32("num_atoms", num_atoms, vec![n]);
189
190    // Timing
191    let times: Vec<f64> = results.iter().map(|r| r.time_ms).collect();
192    batch.add_float64("time_ms", times, vec![n]);
193
194    // All coordinates flattened (for successful molecules)
195    let all_coords: Vec<f64> = results
196        .iter()
197        .flat_map(|r| r.coords.iter().copied())
198        .collect();
199    let total_atoms: usize = results.iter().map(|r| r.num_atoms).sum();
200    batch.add_float64("coords", all_coords, vec![total_atoms, 3]);
201
202    // All elements flattened
203    let all_elements: Vec<u8> = results
204        .iter()
205        .flat_map(|r| r.elements.iter().copied())
206        .collect();
207    batch.add_uint8("elements", all_elements, vec![total_atoms]);
208
209    batch
210}
211
212/// Pack an ESP grid into an Arrow-compatible batch.
213pub fn pack_esp_grid(grid: &crate::esp::EspGrid) -> RecordBatch {
214    let mut batch = RecordBatch::new();
215
216    batch.add_float64(
217        "values",
218        grid.values.clone(),
219        vec![grid.dims[0], grid.dims[1], grid.dims[2]],
220    );
221    batch.add_float64("origin", grid.origin.to_vec(), vec![3]);
222    batch.add_int32(
223        "dims",
224        grid.dims.iter().map(|&d| d as i32).collect(),
225        vec![3],
226    );
227
228    batch
229}
230
231/// Pack DOS data into an Arrow-compatible batch.
232pub fn pack_dos(dos: &crate::dos::DosResult) -> RecordBatch {
233    let n_points = dos.energies.len();
234    let n_atoms = dos.pdos.len();
235    let mut batch = RecordBatch::new();
236
237    batch.add_float64("energies", dos.energies.clone(), vec![n_points]);
238    batch.add_float64("total_dos", dos.total_dos.clone(), vec![n_points]);
239
240    // PDOS flattened: [atom0_point0, atom0_point1, ..., atom1_point0, ...]
241    let flat_pdos: Vec<f64> = dos.pdos.iter().flat_map(|p| p.iter().copied()).collect();
242    batch.add_float64("pdos", flat_pdos, vec![n_atoms, n_points]);
243
244    batch
245}
246
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created batch has no rows, no columns, and no payload bytes.
    #[test]
    fn test_record_batch_empty() {
        let empty = RecordBatch::new();
        assert_eq!(empty.num_rows, 0);
        assert_eq!(empty.num_columns(), 0);
        assert_eq!(empty.byte_size(), 0);
    }

    /// One column of each type, three elements each.
    #[test]
    fn test_record_batch_add_columns() {
        let mut mixed = RecordBatch::new();
        mixed.add_float64("x", vec![1.0, 2.0, 3.0], vec![3]);
        mixed.add_int32("id", vec![0, 1, 2], vec![3]);
        mixed.add_uint8("flags", vec![1, 0, 1], vec![3]);

        assert_eq!(mixed.num_rows, 3);
        assert_eq!(mixed.num_columns(), 3);
        // 3 f64 (24 bytes) + 3 i32 (12 bytes) + 3 u8 (3 bytes) = 39 bytes.
        assert_eq!(mixed.byte_size(), 3 * 8 + 3 * 4 + 3);
    }

    /// A single consistent five-atom result produces all five columns.
    #[test]
    fn test_pack_conformers() {
        let methane = crate::ConformerResult {
            smiles: "C".to_string(),
            num_atoms: 5,
            coords: vec![0.0; 15],
            elements: vec![6, 1, 1, 1, 1],
            bonds: vec![],
            error: None,
            time_ms: 1.5,
        };
        let packed = pack_conformers(&[methane]);
        assert_eq!(packed.num_columns(), 5);
        assert!(packed.byte_size() > 0);
    }

    /// The first float column of a packed 3x3x3 grid holds all 27 samples.
    #[test]
    fn test_pack_esp_grid() {
        let cube = crate::esp::EspGrid {
            origin: [0.0, 0.0, 0.0],
            spacing: 0.5,
            dims: [3, 3, 3],
            values: vec![0.1; 27],
        };
        let packed = pack_esp_grid(&cube);
        assert_eq!(packed.float_columns[0].values.len(), 27);
    }

    /// The schema entry records the dtype and the multi-dimensional shape.
    #[test]
    fn test_column_schema() {
        let mut coords_batch = RecordBatch::new();
        coords_batch.add_float64("coords", vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);

        let entry = &coords_batch.schema[0];
        assert_eq!(entry.dtype, DataType::Float64);
        assert_eq!(entry.shape, vec![2, 3]);
    }
}