1use super::data::DataFrame;
14use ndarray::{Array1, Array2};
15use nom::{
16 IResult,
17 branch::alt,
18 bytes::complete::tag,
19 character::complete::{alpha1, alphanumeric1, char, digit1, space0},
20 combinator::{map, recognize},
21 multi::{many0, separated_list1},
22 sequence::{delimited, pair, tuple},
23};
24use std::collections::HashSet;
25
/// One term of a model formula, forming a small expression tree.
///
/// Leaves are [`Term::Variable`]s; the other variants wrap one or two
/// sub-terms. `Eq` and `Hash` are derived so terms can be compared exactly
/// and collected into hash-based sets/maps (e.g. to deduplicate terms).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Term {
    /// A bare column reference, e.g. `x1`.
    Variable(String),
    /// A named function applied to a sub-term, e.g. `log(x1)`.
    Function(String, Box<Term>),
    /// An interaction of two sub-terms, e.g. `x1:x2` (also produced by `x1*x2`).
    Interaction(Box<Term>, Box<Term>),
    /// A sub-term raised to a non-negative integer power, e.g. `x1^2`.
    Polynomial(Box<Term>, u32),
}
42
43#[derive(Debug, Clone)]
45pub struct Formula {
46 pub response: Option<Term>,
48 pub predictors: Vec<Term>,
50 pub intercept: bool,
52}
53
54impl Formula {
55 pub fn parse(input: &str) -> Result<Self, super::error::Error> {
57 parse_formula(input)
58 .map_err(|e| super::error::Error::FormulaError(format!("Parse error: {:?}", e)))
59 }
60
61 pub fn no_intercept(mut self) -> Self {
63 self.intercept = false;
64 self
65 }
66
67 pub fn variables(&self) -> HashSet<String> {
69 let mut vars = HashSet::new();
70
71 if let Some(ref resp) = self.response {
72 collect_variables(resp, &mut vars);
73 }
74
75 for pred in &self.predictors {
76 collect_variables(pred, &mut vars);
77 }
78
79 vars
80 }
81
82 pub fn build_matrix(&self, df: &DataFrame) -> Result<Array2<f64>, super::error::Error> {
84 let n_rows = df.n_rows();
85 let vars = self.variables();
86
87 for var in &vars {
89 if !df.column_names().contains(var) {
90 return Err(super::error::Error::FormulaError(format!(
91 "Variable '{}' not found in DataFrame",
92 var
93 )));
94 }
95 }
96
97 let mut columns = Vec::new();
99 if self.intercept {
100 columns.push(vec![1.0; n_rows]);
101 }
102
103 for term in &self.predictors {
105 let term_cols = build_term_matrix(term, df)?;
106 columns.extend(term_cols);
107 }
108
109 let n_cols = columns.len();
111 let mut matrix = Array2::zeros((n_rows, n_cols));
112
113 for (j, col_data) in columns.into_iter().enumerate() {
114 for (i, &val) in col_data.iter().enumerate() {
115 matrix[(i, j)] = val;
116 }
117 }
118
119 Ok(matrix)
120 }
121
122 pub fn response_vector(
124 &self,
125 df: &DataFrame,
126 ) -> Result<Option<Array1<f64>>, super::error::Error> {
127 if let Some(ref resp) = self.response {
128 let resp_name = match resp {
129 Term::Variable(name) => name,
130 _ => {
131 return Err(super::error::Error::FormulaError(
132 "Complex response terms not yet supported".to_string(),
133 ));
134 }
135 };
136
137 let series = df.column(resp_name).ok_or_else(|| {
138 super::error::Error::FormulaError(format!(
139 "Response variable '{}' not found",
140 resp_name
141 ))
142 })?;
143
144 Ok(Some(series.data().to_owned()))
145 } else {
146 Ok(None)
147 }
148 }
149}
150
151fn parse_formula(input: &str) -> Result<Formula, String> {
156 let (rest, (response, predictors)) =
157 formula_parser(input).map_err(|e| format!("Parse error: {:?}", e))?;
158
159 if !rest.trim().is_empty() {
160 return Err(format!("Unexpected input after formula: '{}'", rest));
161 }
162
163 Ok(Formula {
164 response,
165 predictors,
166 intercept: true,
167 })
168}
169
170fn formula_parser(input: &str) -> IResult<&str, (Option<Term>, Vec<Term>)> {
171 let (input, _) = space0(input)?;
172
173 let (input, result) = alt((
175 map(
177 tuple((term_parser, space0, tag("~"), space0, predictors_parser)),
178 |(resp, _, _, _, preds)| (Some(resp), preds),
179 ),
180 map(
182 tuple((tag("~"), space0, predictors_parser)),
183 |(_, _, preds)| (None, preds),
184 ),
185 map(predictors_parser, |preds| (None, preds)),
187 ))(input)?;
188
189 Ok((input, result))
190}
191
192fn predictors_parser(input: &str) -> IResult<&str, Vec<Term>> {
193 separated_list1(delimited(space0, tag("+"), space0), term_parser)(input)
194}
195
196fn term_parser(input: &str) -> IResult<&str, Term> {
197 let (input, term) = alt((
198 map(
200 tuple((alpha1, char('('), term_parser, char(')'))),
201 |(func, _, arg, _)| Term::Function(func.to_string(), Box::new(arg)),
202 ),
203 interaction_parser,
205 base_term_parser,
207 ))(input)?;
208
209 let (input, term) = many0(map(tuple((char('^'), digit1)), |(_, exp): (_, &str)| {
211 exp.parse::<u32>().unwrap_or(1)
212 }))(input)
213 .map(|(rest, exponents)| {
214 let mut current = term;
215 for exp in exponents {
216 current = Term::Polynomial(Box::new(current), exp);
217 }
218 (rest, current)
219 })?;
220
221 Ok((input, term))
222}
223
224fn interaction_parser(input: &str) -> IResult<&str, Term> {
225 let (input, left) = base_term_parser(input)?;
226 let (input, _) = space0(input)?;
227 let (input, op) = alt((tag(":"), tag("*")))(input)?;
228 let (input, _) = space0(input)?;
229 let (input, right) = term_parser(input)?;
230
231 let term = if op == "*" {
232 Term::Interaction(Box::new(left), Box::new(right))
235 } else {
236 Term::Interaction(Box::new(left), Box::new(right))
237 };
238
239 Ok((input, term))
240}
241
242fn base_term_parser(input: &str) -> IResult<&str, Term> {
243 map(
244 recognize(pair(
245 alt((alpha1, tag("_"))),
246 many0(alt((alphanumeric1, tag("_"), tag(".")))),
247 )),
248 |name: &str| Term::Variable(name.to_string()),
249 )(input)
250}
251
252fn collect_variables(term: &Term, vars: &mut HashSet<String>) {
257 match term {
258 Term::Variable(name) => {
259 vars.insert(name.clone());
260 }
261 Term::Function(_, arg) => {
262 collect_variables(arg, vars);
263 }
264 Term::Interaction(left, right) => {
265 collect_variables(left, vars);
266 collect_variables(right, vars);
267 }
268 Term::Polynomial(base, _) => {
269 collect_variables(base, vars);
270 }
271 }
272}
273
274fn build_term_matrix(term: &Term, df: &DataFrame) -> Result<Vec<Vec<f64>>, super::error::Error> {
275 match term {
276 Term::Variable(name) => {
277 let series = df.column(name).ok_or_else(|| {
278 super::error::Error::FormulaError(format!("Variable '{}' not found", name))
279 })?;
280 Ok(vec![series.data().to_vec()])
281 }
282 Term::Function(func, arg) => {
283 let base_cols = build_term_matrix(arg, df)?;
284 if base_cols.len() != 1 {
285 return Err(super::error::Error::FormulaError(
286 "Functions can only be applied to single variables".to_string(),
287 ));
288 }
289
290 let base_data = &base_cols[0];
291 let transformed: Vec<f64> = match func.as_str() {
292 "log" => base_data.iter().map(|&x| x.ln()).collect(),
293 "log10" => base_data.iter().map(|&x| x.log10()).collect(),
294 "log2" => base_data.iter().map(|&x| x.log2()).collect(),
295 "sqrt" => base_data.iter().map(|&x| x.sqrt()).collect(),
296 "exp" => base_data.iter().map(|&x| x.exp()).collect(),
297 "abs" => base_data.iter().map(|&x| x.abs()).collect(),
298 "sin" => base_data.iter().map(|&x| x.sin()).collect(),
299 "cos" => base_data.iter().map(|&x| x.cos()).collect(),
300 "tan" => base_data.iter().map(|&x| x.tan()).collect(),
301 _ => {
302 return Err(super::error::Error::FormulaError(format!(
303 "Unsupported function: {}",
304 func
305 )));
306 }
307 };
308
309 Ok(vec![transformed])
310 }
311 Term::Interaction(left, right) => {
312 let left_cols = build_term_matrix(left, df)?;
313 let right_cols = build_term_matrix(right, df)?;
314
315 let mut result = Vec::new();
317 for lcol in &left_cols {
318 for rcol in &right_cols {
319 let interacted: Vec<f64> =
320 lcol.iter().zip(rcol.iter()).map(|(&l, &r)| l * r).collect();
321 result.push(interacted);
322 }
323 }
324
325 Ok(result)
326 }
327 Term::Polynomial(base, power) => {
328 let base_cols = build_term_matrix(base, df)?;
329 if base_cols.len() != 1 {
330 return Err(super::error::Error::FormulaError(
331 "Polynomial can only be applied to single variables".to_string(),
332 ));
333 }
334
335 let base_data = &base_cols[0];
336 let powered: Vec<f64> = base_data.iter().map(|&x| x.powi(*power as i32)).collect();
337
338 Ok(vec![powered])
339 }
340 }
341}
342
#[cfg(test)]
mod tests {
    // `Series` is not imported by the parent module, so `test_design_matrix`
    // cannot compile without bringing it into scope here.
    // NOTE(review): assumes `Series` lives next to `DataFrame` in `data`.
    use super::super::data::{DataFrame, Series};
    use super::*;
    use ndarray::arr1;
    use std::collections::HashMap;

    #[test]
    fn test_formula_parsing() {
        // (formula, has_response, number of predictor terms)
        let cases = vec![
            ("y ~ x1 + x2", true, 2),
            // `*` currently yields a single interaction term.
            ("y ~ x1 * x2", true, 1),
            ("y ~ log(x1) + sqrt(x2)", true, 2),
            ("~ x1 + x2", false, 2),
            ("x1 + x2", false, 2),
        ];

        for (input, has_response, pred_count) in cases {
            let formula = Formula::parse(input).unwrap();
            assert_eq!(formula.response.is_some(), has_response, "input: {}", input);
            assert_eq!(formula.predictors.len(), pred_count, "input: {}", input);
        }
    }

    #[test]
    fn test_variable_extraction() {
        let formula = Formula::parse("y ~ x1 + log(x2) + x3:x4").unwrap();
        let vars = formula.variables();

        assert!(vars.contains("y"));
        assert!(vars.contains("x1"));
        assert!(vars.contains("x2"));
        assert!(vars.contains("x3"));
        assert!(vars.contains("x4"));
        assert_eq!(vars.len(), 5);
    }

    fn sample_frame() -> DataFrame {
        let mut columns = HashMap::new();
        columns.insert("y".to_string(), Series::new("y", arr1(&[1.0, 2.0, 3.0])));
        columns.insert("x1".to_string(), Series::new("x1", arr1(&[1.0, 2.0, 3.0])));
        columns.insert("x2".to_string(), Series::new("x2", arr1(&[4.0, 5.0, 6.0])));
        DataFrame::from_series(columns).unwrap()
    }

    #[test]
    fn test_design_matrix() {
        let df = sample_frame();
        let formula = Formula::parse("y ~ x1 + x2").unwrap();

        let matrix = formula.build_matrix(&df).unwrap();
        // Intercept column followed by x1 and x2.
        assert_eq!(matrix.shape(), &[3, 3]);
        assert_eq!(matrix.column(0).to_vec(), vec![1.0, 1.0, 1.0]);
        assert_eq!(matrix.column(1).to_vec(), vec![1.0, 2.0, 3.0]);
        assert_eq!(matrix.column(2).to_vec(), vec![4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_missing_predictor_is_error() {
        let df = sample_frame();
        let formula = Formula::parse("y ~ x1 + z").unwrap();
        assert!(formula.build_matrix(&df).is_err());
    }
}