1use ndarray::{Array1, Array2, ArrayView1};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10use super::formula::Formula;
11
/// A named, one-dimensional column of `f64` values.
///
/// Numeric columns are built with [`Series::new`]. Categorical columns are built with
/// [`Series::factor`]: for those, each stored value is the position of its label in
/// `levels`, converted to `f64`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Series {
    /// Column name.
    name: String,
    /// The column values (level positions as `f64` for categorical columns).
    data: Array1<f64>,
    /// Level labels for a categorical column; `None` for a numeric column.
    levels: Option<Vec<String>>,
}
24
25impl Series {
26 pub fn new(name: impl Into<String>, data: Array1<f64>) -> Self {
28 Self {
29 name: name.into(),
30 data,
31 levels: None,
32 }
33 }
34
35 pub fn factor(name: impl Into<String>, levels: Vec<String>, indices: Array1<usize>) -> Self {
37 Self {
38 name: name.into(),
39 data: indices.mapv(|i| i as f64),
40 levels: Some(levels),
41 }
42 }
43
44 pub fn name(&self) -> &str {
46 &self.name
47 }
48
49 pub fn data(&self) -> ArrayView1<'_, f64> {
51 self.data.view()
52 }
53
54 pub fn len(&self) -> usize {
56 self.data.len()
57 }
58
59 pub fn is_empty(&self) -> bool {
61 self.data.is_empty()
62 }
63
64 pub fn mean(&self) -> Option<f64> {
66 if self.is_empty() {
67 None
68 } else {
69 Some(self.data.mean().unwrap_or(f64::NAN))
70 }
71 }
72
73 pub fn var(&self, ddof: f64) -> Option<f64> {
75 if self.len() <= 1 {
76 None
77 } else {
78 Some(self.data.var(ddof))
79 }
80 }
81
82 pub fn std(&self, ddof: f64) -> Option<f64> {
84 self.var(ddof).map(|v| v.sqrt())
85 }
86
87 pub fn min(&self) -> Option<f64> {
89 self.data.fold(f64::INFINITY, |a, &b| a.min(b)).into()
90 }
91
92 pub fn max(&self) -> Option<f64> {
94 self.data.fold(-f64::INFINITY, |a, &b| a.max(b)).into()
95 }
96
97 pub fn quantile(&self, q: f64) -> Option<f64> {
99 if self.is_empty() || !(0.0..=1.0).contains(&q) {
100 return None;
101 }
102
103 let mut sorted = self.data.to_vec();
104 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
105
106 let n = sorted.len();
107 let index = (n - 1) as f64 * q;
108 let lower = index.floor() as usize;
109 let upper = index.ceil() as usize;
110
111 if lower == upper {
112 Some(sorted[lower])
113 } else {
114 let weight = index - lower as f64;
115 Some((1.0 - weight) * sorted[lower] + weight * sorted[upper])
116 }
117 }
118
119 pub fn map(&self, f: impl Fn(f64) -> f64) -> Self {
121 Self {
122 name: self.name.clone(),
123 data: self.data.mapv(f),
124 levels: self.levels.clone(),
125 }
126 }
127
128 pub fn standardize(&self) -> Option<Self> {
130 let mean = self.mean()?;
131 let std = self.std(1.0)?;
132
133 if std == 0.0 {
134 return None;
135 }
136
137 Some(self.map(|x| (x - mean) / std))
138 }
139}
140
/// A collection of named [`Series`] columns that are meant to share one row count.
///
/// Columns are stored in a `HashMap`, so their iteration order is unspecified.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DataFrame {
    /// Columns keyed by name.
    columns: HashMap<String, Series>,
    /// Row count that the constructors in this module check each column against
    /// (`0` for a frame created with `DataFrame::new`).
    n_rows: usize,
}
151
152impl DataFrame {
153 pub fn new() -> Self {
155 Self {
156 columns: HashMap::new(),
157 n_rows: 0,
158 }
159 }
160
161 pub fn from_series(columns: HashMap<String, Series>) -> Result<Self, super::error::Error> {
163 let mut n_rows = 0;
164 for (name, series) in &columns {
165 if n_rows == 0 {
166 n_rows = series.len();
167 } else if series.len() != n_rows {
168 return Err(super::error::Error::DimensionMismatch(format!(
169 "Column '{}' has length {}, expected {}",
170 name,
171 series.len(),
172 n_rows
173 )));
174 }
175 }
176
177 Ok(Self { columns, n_rows })
178 }
179
180 pub fn n_rows(&self) -> usize {
182 self.n_rows
183 }
184
185 pub fn n_cols(&self) -> usize {
187 self.columns.len()
188 }
189
190 pub fn column_names(&self) -> Vec<String> {
192 self.columns.keys().cloned().collect()
193 }
194
195 pub fn column(&self, name: &str) -> Option<&Series> {
197 self.columns.get(name)
198 }
199
200 pub fn column_mut(&mut self, name: &str) -> Option<&mut Series> {
202 self.columns.get_mut(name)
203 }
204
205 pub fn with_column(mut self, series: Series) -> Result<Self, super::error::Error> {
207 let name = series.name().to_string();
208
209 if self.n_rows == 0 {
210 self.n_rows = series.len();
211 } else if series.len() != self.n_rows {
212 return Err(super::error::Error::DimensionMismatch(format!(
213 "Column '{}' has length {}, expected {}",
214 name,
215 series.len(),
216 self.n_rows
217 )));
218 }
219
220 self.columns.insert(name, series);
221 Ok(self)
222 }
223
224 pub fn drop_column(mut self, name: &str) -> Self {
226 self.columns.remove(name);
227 self
228 }
229
230 pub fn select(&self, col_names: &[&str]) -> Result<Self, super::error::Error> {
232 let mut new_columns = HashMap::new();
233 for &name in col_names {
234 if let Some(series) = self.columns.get(name) {
235 new_columns.insert(name.to_string(), series.clone());
236 } else {
237 return Err(super::error::Error::Message(format!(
238 "Column '{}' not found",
239 name
240 )));
241 }
242 }
243 Self::from_series(new_columns)
244 }
245
246 pub fn filter(&self, mask: &[bool]) -> Result<Self, super::error::Error> {
248 if mask.len() != self.n_rows {
249 return Err(super::error::Error::DimensionMismatch(format!(
250 "Mask length {} doesn't match DataFrame rows {}",
251 mask.len(),
252 self.n_rows
253 )));
254 }
255
256 let mut new_columns = HashMap::new();
257 for (name, series) in &self.columns {
258 let mut filtered_data = Vec::new();
259 for (i, &value) in series.data.iter().enumerate() {
260 if mask[i] {
261 filtered_data.push(value);
262 }
263 }
264 new_columns.insert(
265 name.clone(),
266 Series::new(name.clone(), Array1::from_vec(filtered_data)),
267 );
268 }
269
270 Self::from_series(new_columns)
271 }
272
273 pub fn design_matrix(&self, formula: &Formula) -> Result<Array2<f64>, super::error::Error> {
275 formula.build_matrix(self)
276 }
277}
278
279impl Default for DataFrame {
280 fn default() -> Self {
281 Self::new()
282 }
283}
284
/// Fallible conversion from a raw data representation `T` into `Self`.
pub trait FromData<T> {
    /// Builds `Self` from `data`.
    ///
    /// # Errors
    ///
    /// Returns an error when `data` cannot be converted. What counts as invalid is up to
    /// each implementation; the `Vec<Vec<f64>>` impl for `DataFrame` rejects columns of
    /// unequal length.
    fn from_data(data: T) -> Result<Self, super::error::Error>
    where
        // Needed to return `Self` by value; scoping it to this method keeps the trait
        // itself usable as a trait object.
        Self: Sized;
}
295
/// Fallible conversion from `Self` into a raw data representation `T`.
pub trait ToData<T> {
    /// Converts `self` into `T`, returning an error if the conversion is not possible.
    fn to_data(&self) -> Result<T, super::error::Error>;
}
300
301impl FromData<Vec<Vec<f64>>> for DataFrame {
302 fn from_data(data: Vec<Vec<f64>>) -> Result<Self, super::error::Error> {
303 if data.is_empty() {
304 return Ok(Self::new());
305 }
306
307 let n_rows = data[0].len();
308 let mut columns = HashMap::new();
309
310 for (i, column_data) in data.iter().enumerate() {
311 if column_data.len() != n_rows {
312 return Err(super::error::Error::DimensionMismatch(format!(
313 "Column {} has length {}, expected {}",
314 i,
315 column_data.len(),
316 n_rows
317 )));
318 }
319
320 columns.insert(
321 format!("x{}", i),
322 Series::new(format!("x{}", i), Array1::from_vec(column_data.clone())),
323 );
324 }
325
326 Self::from_series(columns)
327 }
328}
329
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr1;

    /// Name, length and summary statistics of a simple numeric series.
    #[test]
    fn test_series_basic() {
        let s = Series::new("test", arr1(&[1.0, 2.0, 3.0, 4.0, 5.0]));

        assert_eq!(s.name(), "test");
        assert_eq!(s.len(), 5);
        assert_eq!(s.mean(), Some(3.0));
        assert_eq!(s.std(1.0), Some(1.5811388300841898));
        assert_eq!(s.min(), Some(1.0));
        assert_eq!(s.max(), Some(5.0));
    }

    /// Shape and column lookup of a frame built from two equal-length series.
    #[test]
    fn test_dataframe_basic() {
        let mut cols: HashMap<String, Series> = HashMap::new();
        for (key, values) in [("x", [1.0, 2.0, 3.0]), ("y", [4.0, 5.0, 6.0])].iter() {
            cols.insert(key.to_string(), Series::new(*key, arr1(values)));
        }

        let frame = DataFrame::from_series(cols).unwrap();

        assert_eq!((frame.n_rows(), frame.n_cols()), (3, 2));
        assert!(frame.column("x").is_some());
        assert!(frame.column("z").is_none());
    }
}