Skip to main content

so_core/
formula.rs

//! R-style formula parser and design matrix builder
//!
//! This module provides formula parsing and evaluation similar to R's
//! formula interface. Currently supported:
//!
//! - Basic terms: `y ~ x1 + x2`
//! - Interaction terms: `y ~ x1:x2`; `y ~ x1*x2` is accepted but is not yet
//!   expanded to `x1 + x2 + x1:x2` (it produces only the `x1:x2` column)
//! - Polynomial terms: `y ~ x1^2`
//! - Special functions: `y ~ log(x1) + sqrt(x2)`
//!
//! Planned but not yet implemented:
//!
//! - Factor expansion: `y ~ factor(x1)`
//! - Random effects (for mixed models): `y ~ (1 | group)`
12
13use super::data::DataFrame;
14use ndarray::{Array1, Array2};
15use nom::{
16    IResult,
17    branch::alt,
18    bytes::complete::tag,
19    character::complete::{alpha1, alphanumeric1, char, digit1, space0},
20    combinator::{map, recognize},
21    multi::{many0, separated_list1},
22    sequence::{delimited, pair, tuple},
23};
24use std::collections::HashSet;
25
26// ============================================================================
27// Formula AST
28// ============================================================================
29
/// A single node of a parsed formula expression.
///
/// Terms nest: the argument of a function, the operands of an interaction and
/// the base of a polynomial are themselves `Term`s.
///
/// `Eq` and `Hash` are derived in addition to `PartialEq` so terms can be used
/// as keys in hash-based collections (for example to de-duplicate predictors).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Term {
    /// A plain variable reference, e.g. `x`.
    Variable(String),
    /// A named function applied to an inner term, e.g. `log(x)`.
    Function(String, Box<Term>),
    /// The interaction of two terms, e.g. `x:y`.
    Interaction(Box<Term>, Box<Term>),
    /// A term raised to a non-negative integer power, e.g. `x^2`.
    Polynomial(Box<Term>, u32),
}
42
43/// A formula expression (response ~ predictors)
44#[derive(Debug, Clone)]
45pub struct Formula {
46    /// Response variable (left side of ~)
47    pub response: Option<Term>,
48    /// Predictor terms (right side of ~)
49    pub predictors: Vec<Term>,
50    /// Whether to include intercept (default: true)
51    pub intercept: bool,
52}
53
54impl Formula {
55    /// Create a new formula from a string
56    pub fn parse(input: &str) -> Result<Self, super::error::Error> {
57        parse_formula(input)
58            .map_err(|e| super::error::Error::FormulaError(format!("Parse error: {:?}", e)))
59    }
60
61    /// Create a formula with no intercept
62    pub fn no_intercept(mut self) -> Self {
63        self.intercept = false;
64        self
65    }
66
67    /// Get all variable names in the formula
68    pub fn variables(&self) -> HashSet<String> {
69        let mut vars = HashSet::new();
70
71        if let Some(ref resp) = self.response {
72            collect_variables(resp, &mut vars);
73        }
74
75        for pred in &self.predictors {
76            collect_variables(pred, &mut vars);
77        }
78
79        vars
80    }
81
82    /// Build design matrix from DataFrame
83    pub fn build_matrix(&self, df: &DataFrame) -> Result<Array2<f64>, super::error::Error> {
84        let n_rows = df.n_rows();
85        let vars = self.variables();
86
87        // Validate all variables exist in DataFrame
88        for var in &vars {
89            if !df.column_names().contains(var) {
90                return Err(super::error::Error::FormulaError(format!(
91                    "Variable '{}' not found in DataFrame",
92                    var
93                )));
94            }
95        }
96
97        // Start with intercept if requested
98        let mut columns = Vec::new();
99        if self.intercept {
100            columns.push(vec![1.0; n_rows]);
101        }
102
103        // Process each predictor term
104        for term in &self.predictors {
105            let term_cols = build_term_matrix(term, df)?;
106            columns.extend(term_cols);
107        }
108
109        // Convert to Array2
110        let n_cols = columns.len();
111        let mut matrix = Array2::zeros((n_rows, n_cols));
112
113        for (j, col_data) in columns.into_iter().enumerate() {
114            for (i, &val) in col_data.iter().enumerate() {
115                matrix[(i, j)] = val;
116            }
117        }
118
119        Ok(matrix)
120    }
121
122    /// Get response variable as array (if specified)
123    pub fn response_vector(
124        &self,
125        df: &DataFrame,
126    ) -> Result<Option<Array1<f64>>, super::error::Error> {
127        if let Some(ref resp) = self.response {
128            let resp_name = match resp {
129                Term::Variable(name) => name,
130                _ => {
131                    return Err(super::error::Error::FormulaError(
132                        "Complex response terms not yet supported".to_string(),
133                    ));
134                }
135            };
136
137            let series = df.column(resp_name).ok_or_else(|| {
138                super::error::Error::FormulaError(format!(
139                    "Response variable '{}' not found",
140                    resp_name
141                ))
142            })?;
143
144            Ok(Some(series.data().to_owned()))
145        } else {
146            Ok(None)
147        }
148    }
149}
150
151// ============================================================================
152// Formula Parser (using nom)
153// ============================================================================
154
155fn parse_formula(input: &str) -> Result<Formula, String> {
156    let (rest, (response, predictors)) =
157        formula_parser(input).map_err(|e| format!("Parse error: {:?}", e))?;
158
159    if !rest.trim().is_empty() {
160        return Err(format!("Unexpected input after formula: '{}'", rest));
161    }
162
163    Ok(Formula {
164        response,
165        predictors,
166        intercept: true,
167    })
168}
169
170fn formula_parser(input: &str) -> IResult<&str, (Option<Term>, Vec<Term>)> {
171    let (input, _) = space0(input)?;
172
173    // Parse response ~ predictors or just predictors
174    let (input, result) = alt((
175        // With response: y ~ x1 + x2
176        map(
177            tuple((term_parser, space0, tag("~"), space0, predictors_parser)),
178            |(resp, _, _, _, preds)| (Some(resp), preds),
179        ),
180        // Without response: ~ x1 + x2
181        map(
182            tuple((tag("~"), space0, predictors_parser)),
183            |(_, _, preds)| (None, preds),
184        ),
185        // Just predictors (implied ~)
186        map(predictors_parser, |preds| (None, preds)),
187    ))(input)?;
188
189    Ok((input, result))
190}
191
192fn predictors_parser(input: &str) -> IResult<&str, Vec<Term>> {
193    separated_list1(delimited(space0, tag("+"), space0), term_parser)(input)
194}
195
196fn term_parser(input: &str) -> IResult<&str, Term> {
197    let (input, term) = alt((
198        // Function call: log(x)
199        map(
200            tuple((alpha1, char('('), term_parser, char(')'))),
201            |(func, _, arg, _)| Term::Function(func.to_string(), Box::new(arg)),
202        ),
203        // Interaction: x:y or x*y
204        interaction_parser,
205        // Base term (variable or number)
206        base_term_parser,
207    ))(input)?;
208
209    // Handle polynomial: x^2
210    let (input, term) = many0(map(tuple((char('^'), digit1)), |(_, exp): (_, &str)| {
211        exp.parse::<u32>().unwrap_or(1)
212    }))(input)
213    .map(|(rest, exponents)| {
214        let mut current = term;
215        for exp in exponents {
216            current = Term::Polynomial(Box::new(current), exp);
217        }
218        (rest, current)
219    })?;
220
221    Ok((input, term))
222}
223
224fn interaction_parser(input: &str) -> IResult<&str, Term> {
225    let (input, left) = base_term_parser(input)?;
226    let (input, _) = space0(input)?;
227    let (input, op) = alt((tag(":"), tag("*")))(input)?;
228    let (input, _) = space0(input)?;
229    let (input, right) = term_parser(input)?;
230
231    let term = if op == "*" {
232        // x*y expands to x + y + x:y
233        // We'll handle expansion later
234        Term::Interaction(Box::new(left), Box::new(right))
235    } else {
236        Term::Interaction(Box::new(left), Box::new(right))
237    };
238
239    Ok((input, term))
240}
241
242fn base_term_parser(input: &str) -> IResult<&str, Term> {
243    map(
244        recognize(pair(
245            alt((alpha1, tag("_"))),
246            many0(alt((alphanumeric1, tag("_"), tag(".")))),
247        )),
248        |name: &str| Term::Variable(name.to_string()),
249    )(input)
250}
251
252// ============================================================================
253// Formula Evaluation
254// ============================================================================
255
256fn collect_variables(term: &Term, vars: &mut HashSet<String>) {
257    match term {
258        Term::Variable(name) => {
259            vars.insert(name.clone());
260        }
261        Term::Function(_, arg) => {
262            collect_variables(arg, vars);
263        }
264        Term::Interaction(left, right) => {
265            collect_variables(left, vars);
266            collect_variables(right, vars);
267        }
268        Term::Polynomial(base, _) => {
269            collect_variables(base, vars);
270        }
271    }
272}
273
274fn build_term_matrix(term: &Term, df: &DataFrame) -> Result<Vec<Vec<f64>>, super::error::Error> {
275    match term {
276        Term::Variable(name) => {
277            let series = df.column(name).ok_or_else(|| {
278                super::error::Error::FormulaError(format!("Variable '{}' not found", name))
279            })?;
280            Ok(vec![series.data().to_vec()])
281        }
282        Term::Function(func, arg) => {
283            let base_cols = build_term_matrix(arg, df)?;
284            if base_cols.len() != 1 {
285                return Err(super::error::Error::FormulaError(
286                    "Functions can only be applied to single variables".to_string(),
287                ));
288            }
289
290            let base_data = &base_cols[0];
291            let transformed: Vec<f64> = match func.as_str() {
292                "log" => base_data.iter().map(|&x| x.ln()).collect(),
293                "log10" => base_data.iter().map(|&x| x.log10()).collect(),
294                "log2" => base_data.iter().map(|&x| x.log2()).collect(),
295                "sqrt" => base_data.iter().map(|&x| x.sqrt()).collect(),
296                "exp" => base_data.iter().map(|&x| x.exp()).collect(),
297                "abs" => base_data.iter().map(|&x| x.abs()).collect(),
298                "sin" => base_data.iter().map(|&x| x.sin()).collect(),
299                "cos" => base_data.iter().map(|&x| x.cos()).collect(),
300                "tan" => base_data.iter().map(|&x| x.tan()).collect(),
301                _ => {
302                    return Err(super::error::Error::FormulaError(format!(
303                        "Unsupported function: {}",
304                        func
305                    )));
306                }
307            };
308
309            Ok(vec![transformed])
310        }
311        Term::Interaction(left, right) => {
312            let left_cols = build_term_matrix(left, df)?;
313            let right_cols = build_term_matrix(right, df)?;
314
315            // Simple interaction: multiply corresponding columns
316            let mut result = Vec::new();
317            for lcol in &left_cols {
318                for rcol in &right_cols {
319                    let interacted: Vec<f64> =
320                        lcol.iter().zip(rcol.iter()).map(|(&l, &r)| l * r).collect();
321                    result.push(interacted);
322                }
323            }
324
325            Ok(result)
326        }
327        Term::Polynomial(base, power) => {
328            let base_cols = build_term_matrix(base, df)?;
329            if base_cols.len() != 1 {
330                return Err(super::error::Error::FormulaError(
331                    "Polynomial can only be applied to single variables".to_string(),
332                ));
333            }
334
335            let base_data = &base_cols[0];
336            let powered: Vec<f64> = base_data.iter().map(|&x| x.powi(*power as i32)).collect();
337
338            Ok(vec![powered])
339        }
340    }
341}
342
343// ============================================================================
344// Tests
345// ============================================================================
346
#[cfg(test)]
mod tests {
    // `Series` is used by `test_design_matrix` but was never imported.
    // NOTE(review): assumes `Series` is exported from the `data` module next
    // to `DataFrame` - confirm the path.
    use super::super::data::{DataFrame, Series};
    use super::*;
    use ndarray::arr1;
    use std::collections::HashMap;

    /// Each case is (formula text, expects a response, number of predictors).
    #[test]
    fn test_formula_parsing() {
        let cases = vec![
            ("y ~ x1 + x2", true, 2),
            ("y ~ x1 * x2", true, 1), // TODO: Should expand to x1 + x2 + x1:x2 (currently just x1:x2)
            ("y ~ log(x1) + sqrt(x2)", true, 2),
            ("~ x1 + x2", false, 2),
            ("x1 + x2", false, 2),
        ];

        for (input, has_response, pred_count) in cases {
            let formula = Formula::parse(input).unwrap();
            assert_eq!(formula.response.is_some(), has_response);
            assert_eq!(formula.predictors.len(), pred_count);
        }
    }

    /// Variables are gathered from the response, function arguments and both
    /// sides of an interaction.
    #[test]
    fn test_variable_extraction() {
        let formula = Formula::parse("y ~ x1 + log(x2) + x3:x4").unwrap();
        let vars = formula.variables();

        assert!(vars.contains("y"));
        assert!(vars.contains("x1"));
        assert!(vars.contains("x2"));
        assert!(vars.contains("x3"));
        assert!(vars.contains("x4"));
        assert_eq!(vars.len(), 5);
    }

    /// The design matrix holds the intercept first, then one column per
    /// predictor in formula order.
    #[test]
    fn test_design_matrix() {
        let mut columns = HashMap::new();
        columns.insert("y".to_string(), Series::new("y", arr1(&[1.0, 2.0, 3.0])));
        columns.insert("x1".to_string(), Series::new("x1", arr1(&[1.0, 2.0, 3.0])));
        columns.insert("x2".to_string(), Series::new("x2", arr1(&[4.0, 5.0, 6.0])));

        let df = DataFrame::from_series(columns).unwrap();
        let formula = Formula::parse("y ~ x1 + x2").unwrap();

        let matrix = formula.build_matrix(&df).unwrap();
        assert_eq!(matrix.shape(), &[3, 3]); // intercept + x1 + x2

        // Check intercept column
        assert_eq!(matrix.column(0).to_vec(), vec![1.0, 1.0, 1.0]);
        // Check x1 column
        assert_eq!(matrix.column(1).to_vec(), vec![1.0, 2.0, 3.0]);
        // Check x2 column
        assert_eq!(matrix.column(2).to_vec(), vec![4.0, 5.0, 6.0]);
    }
}