Skip to main content

sci_form/transport/
arrow.rs

//! Arrow-compatible columnar memory layouts for zero-copy data transfer.
//!
//! Provides flat, typed-array compatible buffers that can be directly
//! transferred to JavaScript TypedArrays (`Float64Array`, `Int32Array`,
//! `Uint8Array`) without serialization overhead.
//!
//! The memory layout is modelled on the Apache Arrow columnar format:
//! - Values buffer: contiguous typed array, stored with an explicit shape
//! - Offsets buffer: for variable-length data (strings, nested arrays);
//!   not implemented by the column types in this file
//! - Null bitmap: optional validity buffer; not implemented by the column
//!   types in this file
11
12use serde::{Deserialize, Serialize};
13
14/// A columnar buffer of f64 values with shape metadata.
15///
16/// Designed for zero-copy transfer to JavaScript `Float64Array`.
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct Float64Column {
19    /// Column name/label.
20    pub name: String,
21    /// Flat f64 values.
22    pub values: Vec<f64>,
23    /// Shape of the data: e.g., \[n_rows\] for 1D, \[n_rows, n_cols\] for 2D.
24    pub shape: Vec<usize>,
25}
26
27/// A columnar buffer of i32 values.
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct Int32Column {
30    /// Column name/label.
31    pub name: String,
32    /// Flat i32 values.
33    pub values: Vec<i32>,
34    /// Shape of the data.
35    pub shape: Vec<usize>,
36}
37
38/// A columnar buffer of u8 values (for element arrays, flags, etc.).
39#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct Uint8Column {
41    /// Column name/label.
42    pub name: String,
43    /// Flat u8 values.
44    pub values: Vec<u8>,
45    /// Shape of the data.
46    pub shape: Vec<usize>,
47}
48
49/// A record batch: a named collection of typed columns.
50///
51/// This is the primary unit for zero-copy data transfer.
52/// Analogous to Arrow RecordBatch.
53#[derive(Debug, Clone, Serialize, Deserialize, Default)]
54pub struct RecordBatch {
55    /// Number of rows in the batch.
56    pub num_rows: usize,
57    /// Schema: column names and types.
58    pub schema: Vec<ColumnSchema>,
59    /// Float64 columns.
60    pub float_columns: Vec<Float64Column>,
61    /// Int32 columns.
62    pub int_columns: Vec<Int32Column>,
63    /// Uint8 columns.
64    pub uint8_columns: Vec<Uint8Column>,
65}
66
67/// Schema entry for a column.
68#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct ColumnSchema {
70    pub name: String,
71    pub dtype: DataType,
72    pub shape: Vec<usize>,
73}
74
75/// Supported data types.
76#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
77pub enum DataType {
78    Float64,
79    Int32,
80    Uint8,
81}
82
83impl RecordBatch {
84    /// Create an empty batch.
85    pub fn new() -> Self {
86        Self {
87            num_rows: 0,
88            schema: Vec::new(),
89            float_columns: Vec::new(),
90            int_columns: Vec::new(),
91            uint8_columns: Vec::new(),
92        }
93    }
94
95    /// Add a Float64 column.
96    pub fn add_float64(&mut self, name: &str, values: Vec<f64>, shape: Vec<usize>) {
97        if !shape.is_empty() {
98            self.num_rows = shape[0];
99        }
100        self.schema.push(ColumnSchema {
101            name: name.to_string(),
102            dtype: DataType::Float64,
103            shape: shape.clone(),
104        });
105        self.float_columns.push(Float64Column {
106            name: name.to_string(),
107            values,
108            shape,
109        });
110    }
111
112    /// Add an Int32 column.
113    pub fn add_int32(&mut self, name: &str, values: Vec<i32>, shape: Vec<usize>) {
114        if !shape.is_empty() {
115            self.num_rows = shape[0];
116        }
117        self.schema.push(ColumnSchema {
118            name: name.to_string(),
119            dtype: DataType::Int32,
120            shape: shape.clone(),
121        });
122        self.int_columns.push(Int32Column {
123            name: name.to_string(),
124            values,
125            shape,
126        });
127    }
128
129    /// Add a Uint8 column.
130    pub fn add_uint8(&mut self, name: &str, values: Vec<u8>, shape: Vec<usize>) {
131        if !shape.is_empty() {
132            self.num_rows = shape[0];
133        }
134        self.schema.push(ColumnSchema {
135            name: name.to_string(),
136            dtype: DataType::Uint8,
137            shape: shape.clone(),
138        });
139        self.uint8_columns.push(Uint8Column {
140            name: name.to_string(),
141            values,
142            shape,
143        });
144    }
145
146    /// Total byte size of all buffers (for transfer cost estimation).
147    pub fn byte_size(&self) -> usize {
148        let f64_bytes: usize = self.float_columns.iter().map(|c| c.values.len() * 8).sum();
149        let i32_bytes: usize = self.int_columns.iter().map(|c| c.values.len() * 4).sum();
150        let u8_bytes: usize = self.uint8_columns.iter().map(|c| c.values.len()).sum();
151        f64_bytes + i32_bytes + u8_bytes
152    }
153
154    /// Number of columns.
155    pub fn num_columns(&self) -> usize {
156        self.schema.len()
157    }
158}
159
160/// Pack conformer results into an Arrow-compatible batch.
161///
162/// Columns: elements (u8), coords_x/y/z (f64), success (i32), time_ms (f64)
163pub fn pack_conformers(results: &[crate::ConformerResult]) -> RecordBatch {
164    let n = results.len();
165    let mut batch = RecordBatch::new();
166
167    // Success flag (1=ok, 0=failed)
168    let success: Vec<i32> = results
169        .iter()
170        .map(|r| if r.error.is_none() { 1 } else { 0 })
171        .collect();
172    batch.add_int32("success", success, vec![n]);
173
174    // Number of atoms per molecule
175    let num_atoms: Vec<i32> = results.iter().map(|r| r.num_atoms as i32).collect();
176    batch.add_int32("num_atoms", num_atoms, vec![n]);
177
178    // Timing
179    let times: Vec<f64> = results.iter().map(|r| r.time_ms).collect();
180    batch.add_float64("time_ms", times, vec![n]);
181
182    // All coordinates flattened (for successful molecules)
183    let all_coords: Vec<f64> = results
184        .iter()
185        .flat_map(|r| r.coords.iter().copied())
186        .collect();
187    let total_atoms: usize = results.iter().map(|r| r.num_atoms).sum();
188    batch.add_float64("coords", all_coords, vec![total_atoms, 3]);
189
190    // All elements flattened
191    let all_elements: Vec<u8> = results
192        .iter()
193        .flat_map(|r| r.elements.iter().copied())
194        .collect();
195    batch.add_uint8("elements", all_elements, vec![total_atoms]);
196
197    batch
198}
199
200/// Pack an ESP grid into an Arrow-compatible batch.
201pub fn pack_esp_grid(grid: &crate::esp::EspGrid) -> RecordBatch {
202    let mut batch = RecordBatch::new();
203
204    batch.add_float64(
205        "values",
206        grid.values.clone(),
207        vec![grid.dims[0], grid.dims[1], grid.dims[2]],
208    );
209    batch.add_float64("origin", grid.origin.to_vec(), vec![3]);
210    batch.add_int32(
211        "dims",
212        grid.dims.iter().map(|&d| d as i32).collect(),
213        vec![3],
214    );
215
216    batch
217}
218
219/// Pack DOS data into an Arrow-compatible batch.
220pub fn pack_dos(dos: &crate::dos::DosResult) -> RecordBatch {
221    let n_points = dos.energies.len();
222    let n_atoms = dos.pdos.len();
223    let mut batch = RecordBatch::new();
224
225    batch.add_float64("energies", dos.energies.clone(), vec![n_points]);
226    batch.add_float64("total_dos", dos.total_dos.clone(), vec![n_points]);
227
228    // PDOS flattened: [atom0_point0, atom0_point1, ..., atom1_point0, ...]
229    let flat_pdos: Vec<f64> = dos.pdos.iter().flat_map(|p| p.iter().copied()).collect();
230    batch.add_float64("pdos", flat_pdos, vec![n_atoms, n_points]);
231
232    batch
233}
234
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created batch has no rows, no columns and no payload bytes.
    #[test]
    fn test_record_batch_empty() {
        let empty = RecordBatch::new();
        assert_eq!(empty.num_rows, 0);
        assert_eq!(empty.num_columns(), 0);
        assert_eq!(empty.byte_size(), 0);
    }

    /// One column of each element type: row count, column count and byte
    /// total must all agree with what was added.
    #[test]
    fn test_record_batch_add_columns() {
        let mut mixed = RecordBatch::new();
        mixed.add_float64("x", vec![1.0, 2.0, 3.0], vec![3]);
        mixed.add_int32("id", vec![0, 1, 2], vec![3]);
        mixed.add_uint8("flags", vec![1, 0, 1], vec![3]);

        // 3 f64 (24 bytes) + 3 i32 (12 bytes) + 3 u8 (3 bytes) = 39 bytes.
        let expected_bytes = 3 * 8 + 3 * 4 + 3;
        assert_eq!(mixed.num_rows, 3);
        assert_eq!(mixed.num_columns(), 3);
        assert_eq!(mixed.byte_size(), expected_bytes);
    }

    /// Packing a single successful result yields the five expected columns
    /// and a non-empty payload.
    #[test]
    fn test_pack_conformers() {
        let methane = crate::ConformerResult {
            smiles: "C".to_string(),
            num_atoms: 5,
            coords: vec![0.0; 15],
            elements: vec![6, 1, 1, 1, 1],
            bonds: vec![],
            error: None,
            time_ms: 1.5,
        };
        let packed = pack_conformers(&[methane]);
        assert_eq!(packed.num_columns(), 5);
        assert!(packed.byte_size() > 0);
    }

    /// The first float column of a packed ESP grid carries every grid value.
    #[test]
    fn test_pack_esp_grid() {
        let cube = crate::esp::EspGrid {
            origin: [0.0, 0.0, 0.0],
            spacing: 0.5,
            dims: [3, 3, 3],
            values: vec![0.1; 27],
        };
        let packed = pack_esp_grid(&cube);
        assert_eq!(packed.float_columns[0].values.len(), 27);
    }

    /// The schema entry mirrors the element type and shape of the column.
    #[test]
    fn test_column_schema() {
        let mut shaped = RecordBatch::new();
        shaped.add_float64("coords", vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);

        let entry = &shaped.schema[0];
        assert_eq!(entry.dtype, DataType::Float64);
        assert_eq!(entry.shape, vec![2, 3]);
    }
}