xdl_dataframe/
dataframe.rs

1//! DataFrame - pandas/Spark-style data structure for XDL
2
3use crate::error::{DataFrameError, DataFrameResult};
4use crate::series::Series;
5use indexmap::IndexMap;
6use serde_json::Value as JsonValue;
7use std::collections::HashMap;
8use xdl_core::{XdlResult, XdlValue};
9
10/// DataFrame - A tabular data structure with labeled columns
11#[derive(Debug, Clone)]
12pub struct DataFrame {
13    /// Column data (column_name -> Series)
14    columns: IndexMap<String, Series>,
15    /// Number of rows
16    nrows: usize,
17}
18
19impl DataFrame {
20    /// Create a new empty DataFrame
21    pub fn new() -> Self {
22        Self {
23            columns: IndexMap::new(),
24            nrows: 0,
25        }
26    }
27
28    /// Create DataFrame from columns
29    pub fn from_columns(columns: IndexMap<String, Series>) -> DataFrameResult<Self> {
30        if columns.is_empty() {
31            return Ok(Self::new());
32        }
33
34        // Verify all columns have same length
35        let nrows = columns.values().next().unwrap().len();
36        for (name, series) in &columns {
37            if series.len() != nrows {
38                return Err(DataFrameError::DimensionMismatch(format!(
39                    "Column '{}' has length {} but expected {}",
40                    name,
41                    series.len(),
42                    nrows
43                )));
44            }
45        }
46
47        Ok(Self { columns, nrows })
48    }
49
50    /// Create DataFrame from a HashMap of column names to data vectors
51    pub fn from_map(data: HashMap<String, Vec<XdlValue>>) -> DataFrameResult<Self> {
52        let mut columns = IndexMap::new();
53
54        for (name, values) in data {
55            columns.insert(name, Series::from_vec(values)?);
56        }
57
58        Self::from_columns(columns)
59    }
60
61    /// Get number of rows
62    pub fn nrows(&self) -> usize {
63        self.nrows
64    }
65
66    /// Get number of columns
67    pub fn ncols(&self) -> usize {
68        self.columns.len()
69    }
70
71    /// Get column names
72    pub fn column_names(&self) -> Vec<String> {
73        self.columns.keys().cloned().collect()
74    }
75
76    /// Get a column by name
77    pub fn column(&self, name: &str) -> DataFrameResult<&Series> {
78        self.columns
79            .get(name)
80            .ok_or_else(|| DataFrameError::ColumnNotFound(name.to_string()))
81    }
82
83    /// Get a mutable column by name
84    pub fn column_mut(&mut self, name: &str) -> DataFrameResult<&mut Series> {
85        self.columns
86            .get_mut(name)
87            .ok_or_else(|| DataFrameError::ColumnNotFound(name.to_string()))
88    }
89
90    /// Add a new column
91    pub fn add_column(&mut self, name: String, series: Series) -> DataFrameResult<()> {
92        if !self.columns.is_empty() && series.len() != self.nrows {
93            return Err(DataFrameError::DimensionMismatch(format!(
94                "Series has length {} but DataFrame has {} rows",
95                series.len(),
96                self.nrows
97            )));
98        }
99
100        if self.columns.is_empty() {
101            self.nrows = series.len();
102        }
103
104        self.columns.insert(name, series);
105        Ok(())
106    }
107
108    /// Remove a column
109    pub fn remove_column(&mut self, name: &str) -> DataFrameResult<Series> {
110        self.columns
111            .shift_remove(name)
112            .ok_or_else(|| DataFrameError::ColumnNotFound(name.to_string()))
113    }
114
115    /// Select specific columns
116    pub fn select(&self, column_names: &[&str]) -> DataFrameResult<DataFrame> {
117        let mut new_columns = IndexMap::new();
118
119        for name in column_names {
120            let series = self.column(name)?.clone();
121            new_columns.insert(name.to_string(), series);
122        }
123
124        Self::from_columns(new_columns)
125    }
126
127    /// Filter rows based on a predicate function
128    pub fn filter<F>(&self, predicate: F) -> DataFrameResult<DataFrame>
129    where
130        F: Fn(usize, &HashMap<String, &XdlValue>) -> bool,
131    {
132        let mut selected_rows = Vec::new();
133
134        // Find which rows satisfy the predicate
135        for row_idx in 0..self.nrows {
136            let mut row_map = HashMap::new();
137            for (col_name, series) in &self.columns {
138                if let Ok(value) = series.get(row_idx) {
139                    row_map.insert(col_name.clone(), value);
140                }
141            }
142
143            if predicate(row_idx, &row_map) {
144                selected_rows.push(row_idx);
145            }
146        }
147
148        // Create new DataFrame with selected rows
149        let mut new_columns = IndexMap::new();
150        for (col_name, series) in &self.columns {
151            let filtered_values: Vec<XdlValue> = selected_rows
152                .iter()
153                .filter_map(|&idx| series.get(idx).ok().cloned())
154                .collect();
155            new_columns.insert(col_name.clone(), Series::from_vec(filtered_values)?);
156        }
157
158        Self::from_columns(new_columns)
159    }
160
161    /// Get a row as a HashMap
162    pub fn row(&self, index: usize) -> DataFrameResult<HashMap<String, XdlValue>> {
163        if index >= self.nrows {
164            return Err(DataFrameError::IndexOutOfBounds(index, self.nrows));
165        }
166
167        let mut row = HashMap::new();
168        for (col_name, series) in &self.columns {
169            row.insert(col_name.clone(), series.get(index)?.clone());
170        }
171
172        Ok(row)
173    }
174
175    /// Get shape as (nrows, ncols)
176    pub fn shape(&self) -> (usize, usize) {
177        (self.nrows, self.ncols())
178    }
179
180    /// Get DataFrame info summary
181    pub fn info(&self) -> String {
182        let mut info = String::new();
183        info.push_str(&format!(
184            "DataFrame: {} rows × {} columns\n",
185            self.nrows,
186            self.ncols()
187        ));
188        info.push_str("\nColumns:\n");
189        for (name, series) in &self.columns {
190            info.push_str(&format!("  {} ({})\n", name, series.dtype()));
191        }
192        info
193    }
194
195    /// Head - get first n rows
196    pub fn head(&self, n: usize) -> DataFrameResult<DataFrame> {
197        let n = n.min(self.nrows);
198        let mut new_columns = IndexMap::new();
199
200        for (col_name, series) in &self.columns {
201            new_columns.insert(col_name.clone(), series.head(n)?);
202        }
203
204        Self::from_columns(new_columns)
205    }
206
207    /// Tail - get last n rows
208    pub fn tail(&self, n: usize) -> DataFrameResult<DataFrame> {
209        let n = n.min(self.nrows);
210        let mut new_columns = IndexMap::new();
211
212        for (col_name, series) in &self.columns {
213            new_columns.insert(col_name.clone(), series.tail(n)?);
214        }
215
216        Self::from_columns(new_columns)
217    }
218
219    /// Describe - get statistical summary
220    pub fn describe(&self) -> DataFrameResult<HashMap<String, HashMap<String, f64>>> {
221        let mut stats = HashMap::new();
222
223        for (col_name, series) in &self.columns {
224            if let Ok(col_stats) = series.describe() {
225                stats.insert(col_name.clone(), col_stats);
226            }
227        }
228
229        Ok(stats)
230    }
231
232    /// Convert to JSON representation
233    pub fn to_json(&self) -> Vec<JsonValue> {
234        let mut rows = Vec::new();
235
236        for row_idx in 0..self.nrows {
237            let mut row_obj = serde_json::Map::new();
238            for (col_name, series) in &self.columns {
239                if let Ok(value) = series.get(row_idx) {
240                    row_obj.insert(col_name.clone(), xdl_value_to_json(value));
241                }
242            }
243            rows.push(JsonValue::Object(row_obj));
244        }
245
246        rows
247    }
248
249    /// Convert to XdlValue (nested array)
250    pub fn to_xdl_value(&self) -> XdlResult<XdlValue> {
251        let mut rows = Vec::new();
252
253        for row_idx in 0..self.nrows {
254            let mut row_values = Vec::new();
255            for series in self.columns.values() {
256                if let Ok(value) = series.get(row_idx) {
257                    row_values.push(value.clone());
258                }
259            }
260            rows.push(XdlValue::NestedArray(row_values));
261        }
262
263        Ok(XdlValue::NestedArray(rows))
264    }
265
266    /// Sort by column(s)
267    pub fn sort_by(&self, column_names: &[&str], ascending: bool) -> DataFrameResult<DataFrame> {
268        if column_names.is_empty() {
269            return Ok(self.clone());
270        }
271
272        // Create index vector
273        let mut indices: Vec<usize> = (0..self.nrows).collect();
274
275        // Sort indices based on column values
276        indices.sort_by(|&a, &b| {
277            for &col_name in column_names {
278                if let Ok(series) = self.column(col_name) {
279                    if let (Ok(val_a), Ok(val_b)) = (series.get(a), series.get(b)) {
280                        let cmp = compare_xdl_values(val_a, val_b);
281                        if cmp != std::cmp::Ordering::Equal {
282                            return if ascending { cmp } else { cmp.reverse() };
283                        }
284                    }
285                }
286            }
287            std::cmp::Ordering::Equal
288        });
289
290        // Create new DataFrame with sorted rows
291        let mut new_columns = IndexMap::new();
292        for (col_name, series) in &self.columns {
293            let sorted_values: Vec<XdlValue> = indices
294                .iter()
295                .filter_map(|&idx| series.get(idx).ok().cloned())
296                .collect();
297            new_columns.insert(col_name.clone(), Series::from_vec(sorted_values)?);
298        }
299
300        Self::from_columns(new_columns)
301    }
302
303    /// Group by column(s) - returns grouped data for aggregation
304    pub fn groupby(&self, column_names: &[&str]) -> DataFrameResult<GroupBy> {
305        GroupBy::new(
306            self.clone(),
307            column_names.iter().map(|s| s.to_string()).collect(),
308        )
309    }
310}
311
312impl Default for DataFrame {
313    fn default() -> Self {
314        Self::new()
315    }
316}
317
318/// GroupBy structure for aggregations
319#[derive(Debug, Clone)]
320pub struct GroupBy {
321    dataframe: DataFrame,
322    group_columns: Vec<String>,
323    groups: HashMap<Vec<String>, Vec<usize>>, // group keys -> row indices
324}
325
326impl GroupBy {
327    fn new(dataframe: DataFrame, group_columns: Vec<String>) -> DataFrameResult<Self> {
328        let mut groups: HashMap<Vec<String>, Vec<usize>> = HashMap::new();
329
330        // Build groups
331        for row_idx in 0..dataframe.nrows() {
332            let mut key = Vec::new();
333            for col_name in &group_columns {
334                if let Ok(value) = dataframe.column(col_name)?.get(row_idx) {
335                    key.push(value.to_string_repr());
336                }
337            }
338
339            groups.entry(key).or_default().push(row_idx);
340        }
341
342        Ok(Self {
343            dataframe,
344            group_columns,
345            groups,
346        })
347    }
348
349    /// Count rows in each group
350    pub fn count(&self) -> DataFrameResult<DataFrame> {
351        let mut columns = IndexMap::new();
352
353        // Add group key columns
354        let mut group_keys: Vec<_> = self.groups.keys().collect();
355        group_keys.sort();
356
357        for (i, col_name) in self.group_columns.iter().enumerate() {
358            let values: Vec<XdlValue> = group_keys
359                .iter()
360                .map(|key| XdlValue::String(key[i].clone()))
361                .collect();
362            columns.insert(col_name.clone(), Series::from_vec(values)?);
363        }
364
365        // Add count column
366        let counts: Vec<XdlValue> = group_keys
367            .iter()
368            .map(|key| XdlValue::Long(self.groups[*key].len() as i32))
369            .collect();
370        columns.insert("count".to_string(), Series::from_vec(counts)?);
371
372        DataFrame::from_columns(columns)
373    }
374
375    /// Compute mean for numeric columns in each group
376    pub fn mean(&self) -> DataFrameResult<DataFrame> {
377        self.aggregate("mean", |values| {
378            let nums: Vec<f64> = values.iter().filter_map(|v| v.to_double().ok()).collect();
379            if nums.is_empty() {
380                XdlValue::Undefined
381            } else {
382                XdlValue::Double(nums.iter().sum::<f64>() / nums.len() as f64)
383            }
384        })
385    }
386
387    /// Compute sum for numeric columns in each group
388    pub fn sum(&self) -> DataFrameResult<DataFrame> {
389        self.aggregate("sum", |values| {
390            let sum: f64 = values.iter().filter_map(|v| v.to_double().ok()).sum();
391            XdlValue::Double(sum)
392        })
393    }
394
395    /// Generic aggregation function
396    fn aggregate<F>(&self, _agg_name: &str, agg_fn: F) -> DataFrameResult<DataFrame>
397    where
398        F: Fn(&[XdlValue]) -> XdlValue,
399    {
400        let mut columns = IndexMap::new();
401        let mut group_keys: Vec<_> = self.groups.keys().collect();
402        group_keys.sort();
403
404        // Add group key columns
405        for (i, col_name) in self.group_columns.iter().enumerate() {
406            let values: Vec<XdlValue> = group_keys
407                .iter()
408                .map(|key| XdlValue::String(key[i].clone()))
409                .collect();
410            columns.insert(col_name.clone(), Series::from_vec(values)?);
411        }
412
413        // Aggregate value columns
414        for (col_name, _series) in &self.dataframe.columns {
415            if self.group_columns.contains(col_name) {
416                continue;
417            }
418
419            let values: Vec<XdlValue> = group_keys
420                .iter()
421                .map(|key| {
422                    let indices = &self.groups[*key];
423                    let col_values: Vec<XdlValue> = indices
424                        .iter()
425                        .filter_map(|&idx| {
426                            self.dataframe.column(col_name).ok()?.get(idx).ok().cloned()
427                        })
428                        .collect();
429                    agg_fn(&col_values)
430                })
431                .collect();
432
433            columns.insert(col_name.clone(), Series::from_vec(values)?);
434        }
435
436        DataFrame::from_columns(columns)
437    }
438}
439
440/// Helper function to convert XdlValue to JsonValue
441fn xdl_value_to_json(value: &XdlValue) -> JsonValue {
442    match value {
443        XdlValue::Undefined => JsonValue::Null,
444        XdlValue::Int(i) => JsonValue::from(*i),
445        XdlValue::Long(l) => JsonValue::from(*l),
446        XdlValue::Long64(l) => JsonValue::from(*l),
447        XdlValue::Float(f) => JsonValue::from(*f),
448        XdlValue::Double(d) => JsonValue::from(*d),
449        XdlValue::String(s) => JsonValue::from(s.clone()),
450        XdlValue::NestedArray(arr) => JsonValue::Array(arr.iter().map(xdl_value_to_json).collect()),
451        _ => JsonValue::String(value.to_string_repr()),
452    }
453}
454
455/// Helper function to compare XdlValues for sorting
456fn compare_xdl_values(a: &XdlValue, b: &XdlValue) -> std::cmp::Ordering {
457    use std::cmp::Ordering;
458
459    match (a, b) {
460        (XdlValue::Int(a), XdlValue::Int(b)) => a.cmp(b),
461        (XdlValue::Long(a), XdlValue::Long(b)) => a.cmp(b),
462        (XdlValue::Long64(a), XdlValue::Long64(b)) => a.cmp(b),
463        (XdlValue::Float(a), XdlValue::Float(b)) => a.partial_cmp(b).unwrap_or(Ordering::Equal),
464        (XdlValue::Double(a), XdlValue::Double(b)) => a.partial_cmp(b).unwrap_or(Ordering::Equal),
465        (XdlValue::String(a), XdlValue::String(b)) => a.cmp(b),
466        _ => {
467            // Try to compare as doubles
468            if let (Ok(a_f), Ok(b_f)) = (a.to_double(), b.to_double()) {
469                a_f.partial_cmp(&b_f).unwrap_or(Ordering::Equal)
470            } else {
471                a.to_string_repr().cmp(&b.to_string_repr())
472            }
473        }
474    }
475}
476
#[cfg(test)]
mod tests {
    use super::*;

    /// Wrap plain integers as `XdlValue::Long` cells.
    fn longs(values: &[i32]) -> Vec<XdlValue> {
        values.iter().map(|&v| XdlValue::Long(v)).collect()
    }

    #[test]
    fn test_empty_dataframe() {
        let empty = DataFrame::new();
        assert_eq!((empty.nrows(), empty.ncols()), (0, 0));
    }

    #[test]
    fn test_from_map() {
        let letters: Vec<XdlValue> = ["a", "b", "c"]
            .iter()
            .map(|s| XdlValue::String(s.to_string()))
            .collect();
        let data: HashMap<String, Vec<XdlValue>> = vec![
            ("col1".to_string(), longs(&[1, 2, 3])),
            ("col2".to_string(), letters),
        ]
        .into_iter()
        .collect();

        let df = DataFrame::from_map(data).expect("equal-length columns should build");
        assert_eq!(df.nrows(), 3);
        assert_eq!(df.ncols(), 2);
    }

    #[test]
    fn test_select() {
        let data: HashMap<String, Vec<XdlValue>> =
            (1..=3).map(|i| (format!("col{}", i), longs(&[i]))).collect();
        let df = DataFrame::from_map(data).expect("single-row columns should build");

        let selected = df.select(&["col1", "col3"]).expect("both columns exist");
        assert_eq!(selected.ncols(), 2);
        for name in &["col1", "col3"] {
            assert!(selected.column(name).is_ok(), "{} should be selected", name);
        }
        assert!(selected.column("col2").is_err());
    }
}