1use serde::{Deserialize, Serialize};
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct Float64Column {
19 pub name: String,
21 pub values: Vec<f64>,
23 pub shape: Vec<usize>,
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct Int32Column {
30 pub name: String,
32 pub values: Vec<i32>,
34 pub shape: Vec<usize>,
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct Uint8Column {
41 pub name: String,
43 pub values: Vec<u8>,
45 pub shape: Vec<usize>,
47}
48
49#[derive(Debug, Clone, Serialize, Deserialize, Default)]
54pub struct RecordBatch {
55 pub num_rows: usize,
57 pub schema: Vec<ColumnSchema>,
59 pub float_columns: Vec<Float64Column>,
61 pub int_columns: Vec<Int32Column>,
63 pub uint8_columns: Vec<Uint8Column>,
65}
66
67#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct ColumnSchema {
70 pub name: String,
71 pub dtype: DataType,
72 pub shape: Vec<usize>,
73}
74
75#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
77pub enum DataType {
78 Float64,
79 Int32,
80 Uint8,
81}
82
83impl RecordBatch {
84 pub fn new() -> Self {
86 Self {
87 num_rows: 0,
88 schema: Vec::new(),
89 float_columns: Vec::new(),
90 int_columns: Vec::new(),
91 uint8_columns: Vec::new(),
92 }
93 }
94
95 pub fn add_float64(&mut self, name: &str, values: Vec<f64>, shape: Vec<usize>) {
97 if !shape.is_empty() {
98 self.num_rows = shape[0];
99 }
100 self.schema.push(ColumnSchema {
101 name: name.to_string(),
102 dtype: DataType::Float64,
103 shape: shape.clone(),
104 });
105 self.float_columns.push(Float64Column {
106 name: name.to_string(),
107 values,
108 shape,
109 });
110 }
111
112 pub fn add_int32(&mut self, name: &str, values: Vec<i32>, shape: Vec<usize>) {
114 if !shape.is_empty() {
115 self.num_rows = shape[0];
116 }
117 self.schema.push(ColumnSchema {
118 name: name.to_string(),
119 dtype: DataType::Int32,
120 shape: shape.clone(),
121 });
122 self.int_columns.push(Int32Column {
123 name: name.to_string(),
124 values,
125 shape,
126 });
127 }
128
129 pub fn add_uint8(&mut self, name: &str, values: Vec<u8>, shape: Vec<usize>) {
131 if !shape.is_empty() {
132 self.num_rows = shape[0];
133 }
134 self.schema.push(ColumnSchema {
135 name: name.to_string(),
136 dtype: DataType::Uint8,
137 shape: shape.clone(),
138 });
139 self.uint8_columns.push(Uint8Column {
140 name: name.to_string(),
141 values,
142 shape,
143 });
144 }
145
146 pub fn byte_size(&self) -> usize {
148 let f64_bytes: usize = self.float_columns.iter().map(|c| c.values.len() * 8).sum();
149 let i32_bytes: usize = self.int_columns.iter().map(|c| c.values.len() * 4).sum();
150 let u8_bytes: usize = self.uint8_columns.iter().map(|c| c.values.len()).sum();
151 f64_bytes + i32_bytes + u8_bytes
152 }
153
154 pub fn num_columns(&self) -> usize {
156 self.schema.len()
157 }
158}
159
160pub fn pack_conformers(results: &[crate::ConformerResult]) -> RecordBatch {
164 let n = results.len();
165 let mut batch = RecordBatch::new();
166
167 let success: Vec<i32> = results
169 .iter()
170 .map(|r| if r.error.is_none() { 1 } else { 0 })
171 .collect();
172 batch.add_int32("success", success, vec![n]);
173
174 let num_atoms: Vec<i32> = results.iter().map(|r| r.num_atoms as i32).collect();
176 batch.add_int32("num_atoms", num_atoms, vec![n]);
177
178 let times: Vec<f64> = results.iter().map(|r| r.time_ms).collect();
180 batch.add_float64("time_ms", times, vec![n]);
181
182 let all_coords: Vec<f64> = results
184 .iter()
185 .flat_map(|r| r.coords.iter().copied())
186 .collect();
187 let total_atoms: usize = results.iter().map(|r| r.num_atoms).sum();
188 batch.add_float64("coords", all_coords, vec![total_atoms, 3]);
189
190 let all_elements: Vec<u8> = results
192 .iter()
193 .flat_map(|r| r.elements.iter().copied())
194 .collect();
195 batch.add_uint8("elements", all_elements, vec![total_atoms]);
196
197 batch
198}
199
200pub fn pack_esp_grid(grid: &crate::esp::EspGrid) -> RecordBatch {
202 let mut batch = RecordBatch::new();
203
204 batch.add_float64(
205 "values",
206 grid.values.clone(),
207 vec![grid.dims[0], grid.dims[1], grid.dims[2]],
208 );
209 batch.add_float64("origin", grid.origin.to_vec(), vec![3]);
210 batch.add_int32(
211 "dims",
212 grid.dims.iter().map(|&d| d as i32).collect(),
213 vec![3],
214 );
215
216 batch
217}
218
219pub fn pack_dos(dos: &crate::dos::DosResult) -> RecordBatch {
221 let n_points = dos.energies.len();
222 let n_atoms = dos.pdos.len();
223 let mut batch = RecordBatch::new();
224
225 batch.add_float64("energies", dos.energies.clone(), vec![n_points]);
226 batch.add_float64("total_dos", dos.total_dos.clone(), vec![n_points]);
227
228 let flat_pdos: Vec<f64> = dos.pdos.iter().flat_map(|p| p.iter().copied()).collect();
230 batch.add_float64("pdos", flat_pdos, vec![n_atoms, n_points]);
231
232 batch
233}
234
#[cfg(test)]
mod tests {
    use super::*;

    /// A single successful five-atom conformer shared by the packing tests.
    fn methane() -> crate::ConformerResult {
        crate::ConformerResult {
            smiles: "C".to_string(),
            num_atoms: 5,
            coords: vec![0.0; 15],
            elements: vec![6, 1, 1, 1, 1],
            bonds: vec![],
            error: None,
            time_ms: 1.5,
        }
    }

    /// Collects schema column names in insertion order.
    fn names(batch: &RecordBatch) -> Vec<&str> {
        batch.schema.iter().map(|c| c.name.as_str()).collect()
    }

    #[test]
    fn test_record_batch_empty() {
        let batch = RecordBatch::new();
        assert_eq!(batch.num_rows, 0);
        assert_eq!(batch.num_columns(), 0);
        assert_eq!(batch.byte_size(), 0);
    }

    #[test]
    fn test_record_batch_add_columns() {
        let mut batch = RecordBatch::new();
        batch.add_float64("x", vec![1.0, 2.0, 3.0], vec![3]);
        batch.add_int32("id", vec![0, 1, 2], vec![3]);
        batch.add_uint8("flags", vec![1, 0, 1], vec![3]);

        assert_eq!(batch.num_rows, 3);
        assert_eq!(batch.num_columns(), 3);
        assert_eq!(batch.byte_size(), 3 * 8 + 3 * 4 + 3);
        assert_eq!(names(&batch), vec!["x", "id", "flags"]);
        assert_eq!(batch.schema[1].dtype, DataType::Int32);
        assert_eq!(batch.schema[2].dtype, DataType::Uint8);
    }

    #[test]
    fn test_scalar_column_keeps_num_rows() {
        let mut batch = RecordBatch::new();
        batch.add_int32("id", vec![0, 1, 2], vec![3]);
        batch.add_float64("scale", vec![2.5], vec![]);
        assert_eq!(batch.num_rows, 3);
        assert_eq!(batch.num_columns(), 2);
        assert!(batch.schema[1].shape.is_empty());
    }

    #[test]
    fn test_pack_conformers() {
        let batch = pack_conformers(&[methane()]);
        assert_eq!(batch.num_columns(), 5);
        assert_eq!(
            names(&batch),
            vec!["success", "num_atoms", "time_ms", "coords", "elements"]
        );
        // i32: success + num_atoms; f64: time_ms + 15 coords; u8: 5 elements.
        assert_eq!(batch.byte_size(), 2 * 4 + 16 * 8 + 5);
        assert_eq!(batch.int_columns[0].values, vec![1]);
        assert_eq!(batch.int_columns[1].values, vec![5]);
        assert_eq!(batch.schema[3].shape, vec![5, 3]);
        assert_eq!(batch.schema[4].shape, vec![5]);
    }

    #[test]
    fn test_pack_conformers_empty() {
        let batch = pack_conformers(&[]);
        assert_eq!(batch.num_columns(), 5);
        assert_eq!(batch.byte_size(), 0);
        assert_eq!(batch.schema[0].shape, vec![0]);
        assert_eq!(batch.schema[3].shape, vec![0, 3]);
    }

    #[test]
    fn test_pack_esp_grid() {
        let grid = crate::esp::EspGrid {
            origin: [0.0, 0.0, 0.0],
            spacing: 0.5,
            dims: [3, 3, 3],
            values: vec![0.1; 27],
        };
        let batch = pack_esp_grid(&grid);
        assert_eq!(names(&batch), vec!["values", "origin", "dims"]);
        assert_eq!(batch.float_columns[0].values.len(), 27);
        assert_eq!(batch.schema[0].shape, vec![3, 3, 3]);
        assert_eq!(batch.float_columns[1].values, vec![0.0; 3]);
        assert_eq!(batch.int_columns[0].values, vec![3, 3, 3]);
    }

    #[test]
    fn test_column_schema() {
        let mut batch = RecordBatch::new();
        batch.add_float64("coords", vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        assert_eq!(batch.schema[0].dtype, DataType::Float64);
        assert_eq!(batch.schema[0].shape, vec![2, 3]);
    }
}