1use crate::error::{DataFrameError, DataFrameResult};
4use crate::series::Series;
5use indexmap::IndexMap;
6use serde_json::Value as JsonValue;
7use std::collections::HashMap;
8use xdl_core::{XdlResult, XdlValue};
9
/// A column-oriented table: named [`Series`] columns, all of the same length.
///
/// The equal-length invariant is enforced by [`DataFrame::from_columns`] and
/// [`DataFrame::add_column`].
#[derive(Debug, Clone)]
pub struct DataFrame {
    /// Columns keyed by name; `IndexMap` keeps them in insertion order.
    columns: IndexMap<String, Series>,
    /// Row count; every column holds exactly this many values.
    nrows: usize,
}
18
19impl DataFrame {
20 pub fn new() -> Self {
22 Self {
23 columns: IndexMap::new(),
24 nrows: 0,
25 }
26 }
27
28 pub fn from_columns(columns: IndexMap<String, Series>) -> DataFrameResult<Self> {
30 if columns.is_empty() {
31 return Ok(Self::new());
32 }
33
34 let nrows = columns.values().next().unwrap().len();
36 for (name, series) in &columns {
37 if series.len() != nrows {
38 return Err(DataFrameError::DimensionMismatch(format!(
39 "Column '{}' has length {} but expected {}",
40 name,
41 series.len(),
42 nrows
43 )));
44 }
45 }
46
47 Ok(Self { columns, nrows })
48 }
49
50 pub fn from_map(data: HashMap<String, Vec<XdlValue>>) -> DataFrameResult<Self> {
52 let mut columns = IndexMap::new();
53
54 for (name, values) in data {
55 columns.insert(name, Series::from_vec(values)?);
56 }
57
58 Self::from_columns(columns)
59 }
60
61 pub fn nrows(&self) -> usize {
63 self.nrows
64 }
65
66 pub fn ncols(&self) -> usize {
68 self.columns.len()
69 }
70
71 pub fn column_names(&self) -> Vec<String> {
73 self.columns.keys().cloned().collect()
74 }
75
76 pub fn column(&self, name: &str) -> DataFrameResult<&Series> {
78 self.columns
79 .get(name)
80 .ok_or_else(|| DataFrameError::ColumnNotFound(name.to_string()))
81 }
82
83 pub fn column_mut(&mut self, name: &str) -> DataFrameResult<&mut Series> {
85 self.columns
86 .get_mut(name)
87 .ok_or_else(|| DataFrameError::ColumnNotFound(name.to_string()))
88 }
89
90 pub fn add_column(&mut self, name: String, series: Series) -> DataFrameResult<()> {
92 if !self.columns.is_empty() && series.len() != self.nrows {
93 return Err(DataFrameError::DimensionMismatch(format!(
94 "Series has length {} but DataFrame has {} rows",
95 series.len(),
96 self.nrows
97 )));
98 }
99
100 if self.columns.is_empty() {
101 self.nrows = series.len();
102 }
103
104 self.columns.insert(name, series);
105 Ok(())
106 }
107
108 pub fn remove_column(&mut self, name: &str) -> DataFrameResult<Series> {
110 self.columns
111 .shift_remove(name)
112 .ok_or_else(|| DataFrameError::ColumnNotFound(name.to_string()))
113 }
114
115 pub fn select(&self, column_names: &[&str]) -> DataFrameResult<DataFrame> {
117 let mut new_columns = IndexMap::new();
118
119 for name in column_names {
120 let series = self.column(name)?.clone();
121 new_columns.insert(name.to_string(), series);
122 }
123
124 Self::from_columns(new_columns)
125 }
126
127 pub fn filter<F>(&self, predicate: F) -> DataFrameResult<DataFrame>
129 where
130 F: Fn(usize, &HashMap<String, &XdlValue>) -> bool,
131 {
132 let mut selected_rows = Vec::new();
133
134 for row_idx in 0..self.nrows {
136 let mut row_map = HashMap::new();
137 for (col_name, series) in &self.columns {
138 if let Ok(value) = series.get(row_idx) {
139 row_map.insert(col_name.clone(), value);
140 }
141 }
142
143 if predicate(row_idx, &row_map) {
144 selected_rows.push(row_idx);
145 }
146 }
147
148 let mut new_columns = IndexMap::new();
150 for (col_name, series) in &self.columns {
151 let filtered_values: Vec<XdlValue> = selected_rows
152 .iter()
153 .filter_map(|&idx| series.get(idx).ok().cloned())
154 .collect();
155 new_columns.insert(col_name.clone(), Series::from_vec(filtered_values)?);
156 }
157
158 Self::from_columns(new_columns)
159 }
160
161 pub fn row(&self, index: usize) -> DataFrameResult<HashMap<String, XdlValue>> {
163 if index >= self.nrows {
164 return Err(DataFrameError::IndexOutOfBounds(index, self.nrows));
165 }
166
167 let mut row = HashMap::new();
168 for (col_name, series) in &self.columns {
169 row.insert(col_name.clone(), series.get(index)?.clone());
170 }
171
172 Ok(row)
173 }
174
175 pub fn shape(&self) -> (usize, usize) {
177 (self.nrows, self.ncols())
178 }
179
180 pub fn info(&self) -> String {
182 let mut info = String::new();
183 info.push_str(&format!(
184 "DataFrame: {} rows × {} columns\n",
185 self.nrows,
186 self.ncols()
187 ));
188 info.push_str("\nColumns:\n");
189 for (name, series) in &self.columns {
190 info.push_str(&format!(" {} ({})\n", name, series.dtype()));
191 }
192 info
193 }
194
195 pub fn head(&self, n: usize) -> DataFrameResult<DataFrame> {
197 let n = n.min(self.nrows);
198 let mut new_columns = IndexMap::new();
199
200 for (col_name, series) in &self.columns {
201 new_columns.insert(col_name.clone(), series.head(n)?);
202 }
203
204 Self::from_columns(new_columns)
205 }
206
207 pub fn tail(&self, n: usize) -> DataFrameResult<DataFrame> {
209 let n = n.min(self.nrows);
210 let mut new_columns = IndexMap::new();
211
212 for (col_name, series) in &self.columns {
213 new_columns.insert(col_name.clone(), series.tail(n)?);
214 }
215
216 Self::from_columns(new_columns)
217 }
218
219 pub fn describe(&self) -> DataFrameResult<HashMap<String, HashMap<String, f64>>> {
221 let mut stats = HashMap::new();
222
223 for (col_name, series) in &self.columns {
224 if let Ok(col_stats) = series.describe() {
225 stats.insert(col_name.clone(), col_stats);
226 }
227 }
228
229 Ok(stats)
230 }
231
232 pub fn to_json(&self) -> Vec<JsonValue> {
234 let mut rows = Vec::new();
235
236 for row_idx in 0..self.nrows {
237 let mut row_obj = serde_json::Map::new();
238 for (col_name, series) in &self.columns {
239 if let Ok(value) = series.get(row_idx) {
240 row_obj.insert(col_name.clone(), xdl_value_to_json(value));
241 }
242 }
243 rows.push(JsonValue::Object(row_obj));
244 }
245
246 rows
247 }
248
249 pub fn to_xdl_value(&self) -> XdlResult<XdlValue> {
251 let mut rows = Vec::new();
252
253 for row_idx in 0..self.nrows {
254 let mut row_values = Vec::new();
255 for series in self.columns.values() {
256 if let Ok(value) = series.get(row_idx) {
257 row_values.push(value.clone());
258 }
259 }
260 rows.push(XdlValue::NestedArray(row_values));
261 }
262
263 Ok(XdlValue::NestedArray(rows))
264 }
265
266 pub fn sort_by(&self, column_names: &[&str], ascending: bool) -> DataFrameResult<DataFrame> {
268 if column_names.is_empty() {
269 return Ok(self.clone());
270 }
271
272 let mut indices: Vec<usize> = (0..self.nrows).collect();
274
275 indices.sort_by(|&a, &b| {
277 for &col_name in column_names {
278 if let Ok(series) = self.column(col_name) {
279 if let (Ok(val_a), Ok(val_b)) = (series.get(a), series.get(b)) {
280 let cmp = compare_xdl_values(val_a, val_b);
281 if cmp != std::cmp::Ordering::Equal {
282 return if ascending { cmp } else { cmp.reverse() };
283 }
284 }
285 }
286 }
287 std::cmp::Ordering::Equal
288 });
289
290 let mut new_columns = IndexMap::new();
292 for (col_name, series) in &self.columns {
293 let sorted_values: Vec<XdlValue> = indices
294 .iter()
295 .filter_map(|&idx| series.get(idx).ok().cloned())
296 .collect();
297 new_columns.insert(col_name.clone(), Series::from_vec(sorted_values)?);
298 }
299
300 Self::from_columns(new_columns)
301 }
302
303 pub fn groupby(&self, column_names: &[&str]) -> DataFrameResult<GroupBy> {
305 GroupBy::new(
306 self.clone(),
307 column_names.iter().map(|s| s.to_string()).collect(),
308 )
309 }
310}
311
312impl Default for DataFrame {
313 fn default() -> Self {
314 Self::new()
315 }
316}
317
/// Rows of a [`DataFrame`] partitioned by the values of one or more columns.
///
/// Produced by [`DataFrame::groupby`].
#[derive(Debug, Clone)]
pub struct GroupBy {
    /// The grouped frame (`DataFrame::groupby` passes in a clone).
    dataframe: DataFrame,
    /// Names of the columns whose values form the group key.
    group_columns: Vec<String>,
    /// Group key (one `to_string_repr()` string per grouping column) -> indices of
    /// the rows in that group.
    groups: HashMap<Vec<String>, Vec<usize>>,
}
325
326impl GroupBy {
327 fn new(dataframe: DataFrame, group_columns: Vec<String>) -> DataFrameResult<Self> {
328 let mut groups: HashMap<Vec<String>, Vec<usize>> = HashMap::new();
329
330 for row_idx in 0..dataframe.nrows() {
332 let mut key = Vec::new();
333 for col_name in &group_columns {
334 if let Ok(value) = dataframe.column(col_name)?.get(row_idx) {
335 key.push(value.to_string_repr());
336 }
337 }
338
339 groups.entry(key).or_default().push(row_idx);
340 }
341
342 Ok(Self {
343 dataframe,
344 group_columns,
345 groups,
346 })
347 }
348
349 pub fn count(&self) -> DataFrameResult<DataFrame> {
351 let mut columns = IndexMap::new();
352
353 let mut group_keys: Vec<_> = self.groups.keys().collect();
355 group_keys.sort();
356
357 for (i, col_name) in self.group_columns.iter().enumerate() {
358 let values: Vec<XdlValue> = group_keys
359 .iter()
360 .map(|key| XdlValue::String(key[i].clone()))
361 .collect();
362 columns.insert(col_name.clone(), Series::from_vec(values)?);
363 }
364
365 let counts: Vec<XdlValue> = group_keys
367 .iter()
368 .map(|key| XdlValue::Long(self.groups[*key].len() as i32))
369 .collect();
370 columns.insert("count".to_string(), Series::from_vec(counts)?);
371
372 DataFrame::from_columns(columns)
373 }
374
375 pub fn mean(&self) -> DataFrameResult<DataFrame> {
377 self.aggregate("mean", |values| {
378 let nums: Vec<f64> = values.iter().filter_map(|v| v.to_double().ok()).collect();
379 if nums.is_empty() {
380 XdlValue::Undefined
381 } else {
382 XdlValue::Double(nums.iter().sum::<f64>() / nums.len() as f64)
383 }
384 })
385 }
386
387 pub fn sum(&self) -> DataFrameResult<DataFrame> {
389 self.aggregate("sum", |values| {
390 let sum: f64 = values.iter().filter_map(|v| v.to_double().ok()).sum();
391 XdlValue::Double(sum)
392 })
393 }
394
395 fn aggregate<F>(&self, _agg_name: &str, agg_fn: F) -> DataFrameResult<DataFrame>
397 where
398 F: Fn(&[XdlValue]) -> XdlValue,
399 {
400 let mut columns = IndexMap::new();
401 let mut group_keys: Vec<_> = self.groups.keys().collect();
402 group_keys.sort();
403
404 for (i, col_name) in self.group_columns.iter().enumerate() {
406 let values: Vec<XdlValue> = group_keys
407 .iter()
408 .map(|key| XdlValue::String(key[i].clone()))
409 .collect();
410 columns.insert(col_name.clone(), Series::from_vec(values)?);
411 }
412
413 for (col_name, _series) in &self.dataframe.columns {
415 if self.group_columns.contains(col_name) {
416 continue;
417 }
418
419 let values: Vec<XdlValue> = group_keys
420 .iter()
421 .map(|key| {
422 let indices = &self.groups[*key];
423 let col_values: Vec<XdlValue> = indices
424 .iter()
425 .filter_map(|&idx| {
426 self.dataframe.column(col_name).ok()?.get(idx).ok().cloned()
427 })
428 .collect();
429 agg_fn(&col_values)
430 })
431 .collect();
432
433 columns.insert(col_name.clone(), Series::from_vec(values)?);
434 }
435
436 DataFrame::from_columns(columns)
437 }
438}
439
440fn xdl_value_to_json(value: &XdlValue) -> JsonValue {
442 match value {
443 XdlValue::Undefined => JsonValue::Null,
444 XdlValue::Int(i) => JsonValue::from(*i),
445 XdlValue::Long(l) => JsonValue::from(*l),
446 XdlValue::Long64(l) => JsonValue::from(*l),
447 XdlValue::Float(f) => JsonValue::from(*f),
448 XdlValue::Double(d) => JsonValue::from(*d),
449 XdlValue::String(s) => JsonValue::from(s.clone()),
450 XdlValue::NestedArray(arr) => JsonValue::Array(arr.iter().map(xdl_value_to_json).collect()),
451 _ => JsonValue::String(value.to_string_repr()),
452 }
453}
454
455fn compare_xdl_values(a: &XdlValue, b: &XdlValue) -> std::cmp::Ordering {
457 use std::cmp::Ordering;
458
459 match (a, b) {
460 (XdlValue::Int(a), XdlValue::Int(b)) => a.cmp(b),
461 (XdlValue::Long(a), XdlValue::Long(b)) => a.cmp(b),
462 (XdlValue::Long64(a), XdlValue::Long64(b)) => a.cmp(b),
463 (XdlValue::Float(a), XdlValue::Float(b)) => a.partial_cmp(b).unwrap_or(Ordering::Equal),
464 (XdlValue::Double(a), XdlValue::Double(b)) => a.partial_cmp(b).unwrap_or(Ordering::Equal),
465 (XdlValue::String(a), XdlValue::String(b)) => a.cmp(b),
466 _ => {
467 if let (Ok(a_f), Ok(b_f)) = (a.to_double(), b.to_double()) {
469 a_f.partial_cmp(&b_f).unwrap_or(Ordering::Equal)
470 } else {
471 a.to_string_repr().cmp(&b.to_string_repr())
472 }
473 }
474 }
475}
476
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns `(name, values)` pairs into the map shape `DataFrame::from_map` expects.
    fn columns_of(pairs: Vec<(&str, Vec<XdlValue>)>) -> HashMap<String, Vec<XdlValue>> {
        pairs
            .into_iter()
            .map(|(name, values)| (name.to_string(), values))
            .collect()
    }

    #[test]
    fn test_empty_dataframe() {
        let df = DataFrame::new();
        assert_eq!((df.nrows(), df.ncols()), (0, 0));
    }

    #[test]
    fn test_from_map() {
        let longs: Vec<XdlValue> = (1..=3).map(XdlValue::Long).collect();
        let strings: Vec<XdlValue> = ["a", "b", "c"]
            .iter()
            .map(|s| XdlValue::String(s.to_string()))
            .collect();

        let df = DataFrame::from_map(columns_of(vec![("col1", longs), ("col2", strings)]))
            .unwrap();

        assert_eq!(df.nrows(), 3);
        assert_eq!(df.ncols(), 2);
    }

    #[test]
    fn test_select() {
        let df = DataFrame::from_map(columns_of(vec![
            ("col1", vec![XdlValue::Long(1)]),
            ("col2", vec![XdlValue::Long(2)]),
            ("col3", vec![XdlValue::Long(3)]),
        ]))
        .unwrap();

        let selected = df.select(&["col1", "col3"]).unwrap();

        assert_eq!(selected.ncols(), 2);
        for kept in &["col1", "col3"] {
            assert!(selected.column(kept).is_ok());
        }
        assert!(selected.column("col2").is_err());
    }
}