1use serde::{Deserialize, Serialize};
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct Float64Column {
25 pub name: String,
27 pub values: Vec<f64>,
29 pub shape: Vec<usize>,
31}
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct Int32Column {
36 pub name: String,
38 pub values: Vec<i32>,
40 pub shape: Vec<usize>,
42}
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct Uint8Column {
47 pub name: String,
49 pub values: Vec<u8>,
51 pub shape: Vec<usize>,
53}
54
55#[derive(Debug, Clone, Serialize, Deserialize, Default)]
60pub struct RecordBatch {
61 pub num_rows: usize,
63 pub schema: Vec<ColumnSchema>,
65 pub float_columns: Vec<Float64Column>,
67 pub int_columns: Vec<Int32Column>,
69 pub uint8_columns: Vec<Uint8Column>,
71}
72
73#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct ColumnSchema {
76 pub name: String,
77 pub dtype: DataType,
78 pub shape: Vec<usize>,
79}
80
81#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
83pub enum DataType {
84 Float64,
85 Int32,
86 Uint8,
87}
88
89impl RecordBatch {
90 pub fn new() -> Self {
92 Self {
93 num_rows: 0,
94 schema: Vec::new(),
95 float_columns: Vec::new(),
96 int_columns: Vec::new(),
97 uint8_columns: Vec::new(),
98 }
99 }
100
101 pub fn add_float64(&mut self, name: &str, values: Vec<f64>, shape: Vec<usize>) {
103 if !shape.is_empty() {
104 self.num_rows = shape[0];
105 }
106 self.schema.push(ColumnSchema {
107 name: name.to_string(),
108 dtype: DataType::Float64,
109 shape: shape.clone(),
110 });
111 self.float_columns.push(Float64Column {
112 name: name.to_string(),
113 values,
114 shape,
115 });
116 }
117
118 pub fn add_int32(&mut self, name: &str, values: Vec<i32>, shape: Vec<usize>) {
120 if !shape.is_empty() {
121 self.num_rows = shape[0];
122 }
123 self.schema.push(ColumnSchema {
124 name: name.to_string(),
125 dtype: DataType::Int32,
126 shape: shape.clone(),
127 });
128 self.int_columns.push(Int32Column {
129 name: name.to_string(),
130 values,
131 shape,
132 });
133 }
134
135 pub fn add_uint8(&mut self, name: &str, values: Vec<u8>, shape: Vec<usize>) {
137 if !shape.is_empty() {
138 self.num_rows = shape[0];
139 }
140 self.schema.push(ColumnSchema {
141 name: name.to_string(),
142 dtype: DataType::Uint8,
143 shape: shape.clone(),
144 });
145 self.uint8_columns.push(Uint8Column {
146 name: name.to_string(),
147 values,
148 shape,
149 });
150 }
151
152 pub fn byte_size(&self) -> usize {
154 let f64_bytes: usize = self.float_columns.iter().map(|c| c.values.len() * 8).sum();
155 let i32_bytes: usize = self.int_columns.iter().map(|c| c.values.len() * 4).sum();
156 let u8_bytes: usize = self.uint8_columns.iter().map(|c| c.values.len()).sum();
157 f64_bytes + i32_bytes + u8_bytes
158 }
159
160 pub fn num_columns(&self) -> usize {
162 self.schema.len()
163 }
164}
165
166pub fn pack_conformers(results: &[crate::ConformerResult]) -> RecordBatch {
170 let results: Vec<&crate::ConformerResult> = results
172 .iter()
173 .filter(|r| r.coords.len() == r.num_atoms * 3 && r.elements.len() == r.num_atoms)
174 .collect();
175
176 let n = results.len();
177 let mut batch = RecordBatch::new();
178
179 let success: Vec<i32> = results
181 .iter()
182 .map(|r| if r.error.is_none() { 1 } else { 0 })
183 .collect();
184 batch.add_int32("success", success, vec![n]);
185
186 let num_atoms: Vec<i32> = results.iter().map(|r| r.num_atoms as i32).collect();
188 batch.add_int32("num_atoms", num_atoms, vec![n]);
189
190 let times: Vec<f64> = results.iter().map(|r| r.time_ms).collect();
192 batch.add_float64("time_ms", times, vec![n]);
193
194 let all_coords: Vec<f64> = results
196 .iter()
197 .flat_map(|r| r.coords.iter().copied())
198 .collect();
199 let total_atoms: usize = results.iter().map(|r| r.num_atoms).sum();
200 batch.add_float64("coords", all_coords, vec![total_atoms, 3]);
201
202 let all_elements: Vec<u8> = results
204 .iter()
205 .flat_map(|r| r.elements.iter().copied())
206 .collect();
207 batch.add_uint8("elements", all_elements, vec![total_atoms]);
208
209 batch
210}
211
212pub fn pack_esp_grid(grid: &crate::esp::EspGrid) -> RecordBatch {
214 let mut batch = RecordBatch::new();
215
216 batch.add_float64(
217 "values",
218 grid.values.clone(),
219 vec![grid.dims[0], grid.dims[1], grid.dims[2]],
220 );
221 batch.add_float64("origin", grid.origin.to_vec(), vec![3]);
222 batch.add_int32(
223 "dims",
224 grid.dims.iter().map(|&d| d as i32).collect(),
225 vec![3],
226 );
227
228 batch
229}
230
231pub fn pack_dos(dos: &crate::dos::DosResult) -> RecordBatch {
233 let n_points = dos.energies.len();
234 let n_atoms = dos.pdos.len();
235 let mut batch = RecordBatch::new();
236
237 batch.add_float64("energies", dos.energies.clone(), vec![n_points]);
238 batch.add_float64("total_dos", dos.total_dos.clone(), vec![n_points]);
239
240 let flat_pdos: Vec<f64> = dos.pdos.iter().flat_map(|p| p.iter().copied()).collect();
242 batch.add_float64("pdos", flat_pdos, vec![n_atoms, n_points]);
243
244 batch
245}
246
#[cfg(test)]
mod tests {
    use super::*;

    /// A single well-formed conformer for SMILES `"C"`: 5 atoms, 15 coordinates.
    fn methane() -> crate::ConformerResult {
        crate::ConformerResult {
            smiles: "C".to_string(),
            num_atoms: 5,
            coords: vec![0.0; 15],
            elements: vec![6, 1, 1, 1, 1],
            bonds: vec![],
            error: None,
            time_ms: 1.5,
        }
    }

    #[test]
    fn test_record_batch_empty() {
        let batch = RecordBatch::new();
        assert_eq!(batch.num_rows, 0);
        assert_eq!(batch.num_columns(), 0);
        assert_eq!(batch.byte_size(), 0);
    }

    #[test]
    fn test_record_batch_add_columns() {
        let mut batch = RecordBatch::new();
        batch.add_float64("x", vec![1.0, 2.0, 3.0], vec![3]);
        batch.add_int32("id", vec![0, 1, 2], vec![3]);
        batch.add_uint8("flags", vec![1, 0, 1], vec![3]);

        assert_eq!(batch.num_rows, 3);
        assert_eq!(batch.num_columns(), 3);
        assert_eq!(batch.byte_size(), 3 * 8 + 3 * 4 + 3);
    }

    #[test]
    fn test_num_rows_tracks_last_shaped_column() {
        let mut batch = RecordBatch::new();
        batch.add_float64("a", vec![0.0; 4], vec![4]);
        batch.add_int32("b", vec![0; 6], vec![2, 3]);
        assert_eq!(batch.num_rows, 2);

        // An empty shape leaves `num_rows` untouched.
        batch.add_uint8("scalar", vec![1], vec![]);
        assert_eq!(batch.num_rows, 2);
    }

    #[test]
    fn test_pack_conformers() {
        let batch = pack_conformers(&[methane()]);
        assert_eq!(batch.num_columns(), 5);
        // success + num_atoms (i32), time_ms (f64), 15 coords (f64), 5 elements (u8)
        assert_eq!(batch.byte_size(), 4 + 4 + 8 + 15 * 8 + 5);
        assert_eq!(batch.int_columns[0].values, vec![1]);
        assert_eq!(batch.float_columns[1].shape, vec![5, 3]);
        assert_eq!(batch.uint8_columns[0].values, vec![6, 1, 1, 1, 1]);
    }

    #[test]
    fn test_pack_conformers_skips_inconsistent_results() {
        let malformed = crate::ConformerResult {
            // One coordinate short of `num_atoms * 3`.
            coords: vec![0.0; 14],
            ..methane()
        };
        let batch = pack_conformers(&[malformed]);
        assert_eq!(batch.num_columns(), 5);
        assert_eq!(batch.byte_size(), 0);
        assert_eq!(batch.num_rows, 0);
    }

    #[test]
    fn test_pack_esp_grid() {
        let grid = crate::esp::EspGrid {
            origin: [0.0, 0.0, 0.0],
            spacing: 0.5,
            dims: [3, 3, 3],
            values: vec![0.1; 27],
        };
        let batch = pack_esp_grid(&grid);
        assert_eq!(batch.float_columns[0].values.len(), 27);
        assert_eq!(batch.float_columns[0].shape, vec![3, 3, 3]);
        assert_eq!(batch.float_columns[1].values, vec![0.0, 0.0, 0.0]);
        assert_eq!(batch.int_columns[0].values, vec![3, 3, 3]);
    }

    #[test]
    fn test_column_schema() {
        let mut batch = RecordBatch::new();
        batch.add_float64("coords", vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        assert_eq!(batch.schema[0].dtype, DataType::Float64);
        assert_eq!(batch.schema[0].shape, vec![2, 3]);
    }
}