Skip to main content

so_core/
data.rs

1//! Data structures for statistical computing
2//!
3//! This module provides columnar data structures optimized for
4//! statistical operations, with interoperability with numpy and pandas.
5
6use ndarray::{Array1, Array2, ArrayView1};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10use super::formula::Formula;
11
12// ============================================================================
13// Series - Vector with metadata
14// ============================================================================
15
16/// A Series represents a single column of data with a name and dtype
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct Series {
19    name: String,
20    data: Array1<f64>,
21    /// Optional factor/categorical encoding
22    levels: Option<Vec<String>>,
23}
24
25impl Series {
26    /// Create a new numeric series
27    pub fn new(name: impl Into<String>, data: Array1<f64>) -> Self {
28        Self {
29            name: name.into(),
30            data,
31            levels: None,
32        }
33    }
34
35    /// Create a factor/categorical series
36    pub fn factor(name: impl Into<String>, levels: Vec<String>, indices: Array1<usize>) -> Self {
37        Self {
38            name: name.into(),
39            data: indices.mapv(|i| i as f64),
40            levels: Some(levels),
41        }
42    }
43
44    /// Get series name
45    pub fn name(&self) -> &str {
46        &self.name
47    }
48
49    /// Get data as array view
50    pub fn data(&self) -> ArrayView1<'_, f64> {
51        self.data.view()
52    }
53
54    /// Get length of series
55    pub fn len(&self) -> usize {
56        self.data.len()
57    }
58
59    /// Check if series is empty
60    pub fn is_empty(&self) -> bool {
61        self.data.is_empty()
62    }
63
64    /// Compute mean
65    pub fn mean(&self) -> Option<f64> {
66        if self.is_empty() {
67            None
68        } else {
69            Some(self.data.mean().unwrap_or(f64::NAN))
70        }
71    }
72
73    /// Compute variance
74    pub fn var(&self, ddof: f64) -> Option<f64> {
75        if self.len() <= 1 {
76            None
77        } else {
78            Some(self.data.var(ddof))
79        }
80    }
81
82    /// Compute standard deviation
83    pub fn std(&self, ddof: f64) -> Option<f64> {
84        self.var(ddof).map(|v| v.sqrt())
85    }
86
87    /// Compute minimum value
88    pub fn min(&self) -> Option<f64> {
89        self.data.fold(f64::INFINITY, |a, &b| a.min(b)).into()
90    }
91
92    /// Compute maximum value
93    pub fn max(&self) -> Option<f64> {
94        self.data.fold(-f64::INFINITY, |a, &b| a.max(b)).into()
95    }
96
97    /// Compute quantile using R's type 7 method
98    pub fn quantile(&self, q: f64) -> Option<f64> {
99        if self.is_empty() || !(0.0..=1.0).contains(&q) {
100            return None;
101        }
102
103        let mut sorted = self.data.to_vec();
104        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
105
106        let n = sorted.len();
107        let index = (n - 1) as f64 * q;
108        let lower = index.floor() as usize;
109        let upper = index.ceil() as usize;
110
111        if lower == upper {
112            Some(sorted[lower])
113        } else {
114            let weight = index - lower as f64;
115            Some((1.0 - weight) * sorted[lower] + weight * sorted[upper])
116        }
117    }
118
119    /// Apply a function element-wise
120    pub fn map(&self, f: impl Fn(f64) -> f64) -> Self {
121        Self {
122            name: self.name.clone(),
123            data: self.data.mapv(f),
124            levels: self.levels.clone(),
125        }
126    }
127
128    /// Standardize (z-score normalization)
129    pub fn standardize(&self) -> Option<Self> {
130        let mean = self.mean()?;
131        let std = self.std(1.0)?;
132
133        if std == 0.0 {
134            return None;
135        }
136
137        Some(self.map(|x| (x - mean) / std))
138    }
139}
140
141// ============================================================================
142// DataFrame - Collection of Series
143// ============================================================================
144
145/// A DataFrame represents a collection of named Series (columns)
146#[derive(Debug, Clone, Serialize, Deserialize)]
147pub struct DataFrame {
148    columns: HashMap<String, Series>,
149    n_rows: usize,
150}
151
152impl DataFrame {
153    /// Create a new empty DataFrame
154    pub fn new() -> Self {
155        Self {
156            columns: HashMap::new(),
157            n_rows: 0,
158        }
159    }
160
161    /// Create a DataFrame from a map of column names to Series
162    pub fn from_series(columns: HashMap<String, Series>) -> Result<Self, super::error::Error> {
163        let mut n_rows = 0;
164        for (name, series) in &columns {
165            if n_rows == 0 {
166                n_rows = series.len();
167            } else if series.len() != n_rows {
168                return Err(super::error::Error::DimensionMismatch(format!(
169                    "Column '{}' has length {}, expected {}",
170                    name,
171                    series.len(),
172                    n_rows
173                )));
174            }
175        }
176
177        Ok(Self { columns, n_rows })
178    }
179
180    /// Get number of rows
181    pub fn n_rows(&self) -> usize {
182        self.n_rows
183    }
184
185    /// Get number of columns
186    pub fn n_cols(&self) -> usize {
187        self.columns.len()
188    }
189
190    /// Get column names
191    pub fn column_names(&self) -> Vec<String> {
192        self.columns.keys().cloned().collect()
193    }
194
195    /// Get a column by name
196    pub fn column(&self, name: &str) -> Option<&Series> {
197        self.columns.get(name)
198    }
199
200    /// Get a mutable reference to a column
201    pub fn column_mut(&mut self, name: &str) -> Option<&mut Series> {
202        self.columns.get_mut(name)
203    }
204
205    /// Add a column to the DataFrame
206    pub fn with_column(mut self, series: Series) -> Result<Self, super::error::Error> {
207        let name = series.name().to_string();
208
209        if self.n_rows == 0 {
210            self.n_rows = series.len();
211        } else if series.len() != self.n_rows {
212            return Err(super::error::Error::DimensionMismatch(format!(
213                "Column '{}' has length {}, expected {}",
214                name,
215                series.len(),
216                self.n_rows
217            )));
218        }
219
220        self.columns.insert(name, series);
221        Ok(self)
222    }
223
224    /// Remove a column
225    pub fn drop_column(mut self, name: &str) -> Self {
226        self.columns.remove(name);
227        self
228    }
229
230    /// Select specific columns
231    pub fn select(&self, col_names: &[&str]) -> Result<Self, super::error::Error> {
232        let mut new_columns = HashMap::new();
233        for &name in col_names {
234            if let Some(series) = self.columns.get(name) {
235                new_columns.insert(name.to_string(), series.clone());
236            } else {
237                return Err(super::error::Error::Message(format!(
238                    "Column '{}' not found",
239                    name
240                )));
241            }
242        }
243        Self::from_series(new_columns)
244    }
245
246    /// Filter rows based on a boolean mask
247    pub fn filter(&self, mask: &[bool]) -> Result<Self, super::error::Error> {
248        if mask.len() != self.n_rows {
249            return Err(super::error::Error::DimensionMismatch(format!(
250                "Mask length {} doesn't match DataFrame rows {}",
251                mask.len(),
252                self.n_rows
253            )));
254        }
255
256        let mut new_columns = HashMap::new();
257        for (name, series) in &self.columns {
258            let mut filtered_data = Vec::new();
259            for (i, &value) in series.data.iter().enumerate() {
260                if mask[i] {
261                    filtered_data.push(value);
262                }
263            }
264            new_columns.insert(
265                name.clone(),
266                Series::new(name.clone(), Array1::from_vec(filtered_data)),
267            );
268        }
269
270        Self::from_series(new_columns)
271    }
272
273    /// Get design matrix for regression
274    pub fn design_matrix(&self, formula: &Formula) -> Result<Array2<f64>, super::error::Error> {
275        formula.build_matrix(self)
276    }
277}
278
279impl Default for DataFrame {
280    fn default() -> Self {
281        Self::new()
282    }
283}
284
285// ============================================================================
286// Data loading and conversion utilities
287// ============================================================================
288
289/// Trait for converting external data formats to StatOxide structures
290pub trait FromData<T> {
291    fn from_data(data: T) -> Result<Self, super::error::Error>
292    where
293        Self: Sized;
294}
295
296/// Trait for exporting StatOxide structures to external formats
297pub trait ToData<T> {
298    fn to_data(&self) -> Result<T, super::error::Error>;
299}
300
301impl FromData<Vec<Vec<f64>>> for DataFrame {
302    fn from_data(data: Vec<Vec<f64>>) -> Result<Self, super::error::Error> {
303        if data.is_empty() {
304            return Ok(Self::new());
305        }
306
307        let n_rows = data[0].len();
308        let mut columns = HashMap::new();
309
310        for (i, column_data) in data.iter().enumerate() {
311            if column_data.len() != n_rows {
312                return Err(super::error::Error::DimensionMismatch(format!(
313                    "Column {} has length {}, expected {}",
314                    i,
315                    column_data.len(),
316                    n_rows
317                )));
318            }
319
320            columns.insert(
321                format!("x{}", i),
322                Series::new(format!("x{}", i), Array1::from_vec(column_data.clone())),
323            );
324        }
325
326        Self::from_series(columns)
327    }
328}
329
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr1;

    /// Absolute tolerance for floating-point comparisons.
    const EPS: f64 = 1e-12;

    #[test]
    fn test_series_basic() {
        let data = arr1(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        let series = Series::new("test", data);

        assert_eq!(series.name(), "test");
        assert_eq!(series.len(), 5);
        assert_eq!(series.mean(), Some(3.0));
        // The sample variance of 1..=5 is 2.5. Compare with a tolerance rather
        // than a bit-exact float literal.
        let std = series.std(1.0).unwrap();
        assert!((std - 2.5_f64.sqrt()).abs() < EPS);
        assert_eq!(series.min(), Some(1.0));
        assert_eq!(series.max(), Some(5.0));
    }

    #[test]
    fn test_series_quantile_type7() {
        let series = Series::new("q", arr1(&[5.0, 1.0, 4.0, 2.0, 3.0]));

        assert_eq!(series.quantile(0.0), Some(1.0));
        assert_eq!(series.quantile(0.5), Some(3.0));
        assert_eq!(series.quantile(1.0), Some(5.0));
        // h = (n - 1) * q = 0.4, so the result is 1.0 + 0.4 * (2.0 - 1.0).
        let q = series.quantile(0.1).unwrap();
        assert!((q - 1.4).abs() < EPS);
        assert_eq!(series.quantile(1.5), None);
    }

    #[test]
    fn test_factor_codes() {
        let levels = vec!["a".to_string(), "b".to_string()];
        let series = Series::factor("g", levels, arr1(&[0usize, 1, 0]));

        assert_eq!(series.data().to_vec(), vec![0.0, 1.0, 0.0]);
    }

    #[test]
    fn test_dataframe_basic() {
        let mut columns = HashMap::new();
        columns.insert("x".to_string(), Series::new("x", arr1(&[1.0, 2.0, 3.0])));
        columns.insert("y".to_string(), Series::new("y", arr1(&[4.0, 5.0, 6.0])));

        let df = DataFrame::from_series(columns).unwrap();
        assert_eq!(df.n_rows(), 3);
        assert_eq!(df.n_cols(), 2);
        assert!(df.column("x").is_some());
        assert!(df.column("z").is_none());
    }

    #[test]
    fn test_dataframe_select_and_filter() {
        let df = DataFrame::new()
            .with_column(Series::new("x", arr1(&[1.0, 2.0, 3.0])))
            .unwrap()
            .with_column(Series::new("y", arr1(&[4.0, 5.0, 6.0])))
            .unwrap();

        let selected = df.select(&["x"]).unwrap();
        assert_eq!(selected.n_cols(), 1);
        assert!(df.select(&["z"]).is_err());

        let filtered = df.filter(&[true, false, true]).unwrap();
        assert_eq!(filtered.n_rows(), 2);
        assert_eq!(
            filtered.column("x").unwrap().data().to_vec(),
            vec![1.0, 3.0]
        );
        assert!(df.filter(&[true]).is_err());
    }

    #[test]
    fn test_with_column_length_mismatch() {
        let df = DataFrame::new()
            .with_column(Series::new("x", arr1(&[1.0, 2.0])))
            .unwrap();

        assert!(df.with_column(Series::new("y", arr1(&[1.0, 2.0, 3.0]))).is_err());
    }
}