1use crate::ast::*;
2use crate::calc::{calc_function_call, CalcError, Env};
3
4use thiserror::Error;
5
6#[derive(Debug, PartialEq)]
9struct NormForm {
10    a1: Number,
11    a0: Number,
12}
13
14#[derive(Debug, PartialEq, Eq, Error)]
15pub enum SolverError {
16    #[error("Unknown variable `{0}` in `solve ... for ...`")]
17    UnknownVariable(String),
18    #[error("Unsupported `^2` of variable to solve for in `solve ... for ...`")]
19    UnsupportedXSquare,
20    #[error("Unsupported variable in denominator in `solve ... for ...`")]
21    UnsupportedXDenominator,
22    #[error("Unsupported % with solve for variable in `solve ... for ...`")]
23    UnsupportedRemainder,
24    #[error("Unsupported power in `solve ... for ...`")]
25    UnsupportedPower,
26    #[error("`solve ... for ...` contains no variable (after simplification)")]
27    NoVariable,
28    #[error(transparent)]
29    FunctionCallError(#[from] CalcError),
30}
31
32fn normalize_term(term: &Term, sym: &str, env: &dyn Env) -> Result<NormForm, SolverError> {
33    let lhs = normalize(&term.lhs, sym, env)?;
34    let rhs = normalize(&term.rhs, sym, env)?;
35    match term.op {
36        Operation::Add => Ok({
37            let factor = lhs.a1 + rhs.a1;
38            let summand = lhs.a0 + rhs.a0;
39            NormForm {
40                a1: factor,
41                a0: summand,
42            }
43        }),
44        Operation::Sub => Ok({
45            let factor = lhs.a1 - rhs.a1;
46            let summand = lhs.a0 - rhs.a0;
47            NormForm {
48                a1: factor,
49                a0: summand,
50            }
51        }),
52        Operation::Mul => {
53            let a2 = lhs.a1 * rhs.a1;
54            let a1 = lhs.a1 * rhs.a0 + rhs.a1 * lhs.a0;
55            let a0 = lhs.a0 * rhs.a0;
56            if a2 != 0.0 {
57                Err(SolverError::UnsupportedXSquare)
58            } else {
59                Ok(NormForm { a1, a0 })
60            }
61        }
62        Operation::Div => {
63            if rhs.a1 != 0.0 {
64                Err(SolverError::UnsupportedXDenominator)
65            } else {
66                let a1 = lhs.a1 / rhs.a0;
67                let a0 = lhs.a0 / rhs.a0;
68                Ok(NormForm { a1, a0 })
69            }
70        }
71        Operation::Rem => {
72            if (lhs.a1 != 0.0) || (rhs.a1 != 0.0) {
73                Err(SolverError::UnsupportedRemainder)
74            } else {
75                Ok(NormForm {
76                    a1: 0.0,
77                    a0: (lhs.a0 % rhs.a0),
78                })
79            }
80        }
81        Operation::Pow => {
82            if (lhs.a1 != 0.0) || (rhs.a1 != 0.0) {
83                Err(SolverError::UnsupportedPower)
84            } else {
85                Ok(NormForm {
86                    a1: 0.0,
87                    a0: (lhs.a0.powf(rhs.a0)),
88                })
89            }
90        }
91    }
92}
93
94fn normalize(op: &Operand, sym: &str, env: &dyn Env) -> Result<NormForm, SolverError> {
95    match op {
96        Operand::Number(num) => Ok(NormForm { a1: 0.0, a0: *num }),
97        Operand::Symbol(s) => {
98            if op.is_symbol(sym) {
99                Ok(NormForm { a1: 1.0, a0: 0.0 })
100            } else {
101                let num = env
102                    .get(s)
103                    .ok_or_else(|| SolverError::UnknownVariable(s.clone()))?;
104                Ok(NormForm { a1: 0.0, a0: *num })
105            }
106        }
107        Operand::Term(term) => normalize_term(&*term, sym, env),
108        Operand::FunCall(fun_call) => {
109            let num = calc_function_call(fun_call, env)?;
110            Ok(NormForm { a1: 0.0, a0: num })
111        }
112    }
113}
114
115pub fn solve_for(
116    lhs: &Operand,
117    rhs: &Operand,
118    sym: &str,
119    env: &dyn Env,
120) -> Result<Number, SolverError> {
121    let norm_form_lhs = normalize(lhs, sym, env)?;
122    let norm_form_rhs = normalize(rhs, sym, env)?;
123    let denominator = norm_form_lhs.a1 - norm_form_rhs.a1;
124    if 0.0 == denominator {
125        Err(SolverError::NoVariable)
126    } else {
127        let nominator = norm_form_rhs.a0 - norm_form_lhs.a0;
128        Ok(nominator / denominator)
129    }
130}
131
#[cfg(test)]
mod tests {
    mod helpers {
        use crate::ast::{Operand, Statement};
        use crate::parser::parse;

        /// Parses `s` and returns the operand of the resulting expression
        /// statement.
        ///
        /// Panics if `s` cannot be parsed or is not an expression statement.
        pub fn parse_expression(s: &str) -> Operand {
            let statement = parse(s).unwrap();
            if let Statement::Expression { op } = statement {
                op
            } else {
                panic!("string is not a valid expression")
            }
        }

        #[test]
        fn parse_expression_success() {
            assert_eq!(Operand::Number(1.0), parse_expression("1"));
        }

        #[test]
        #[should_panic(expected = "string is not a valid expression")]
        fn parse_expression_failed_assignment() {
            parse_expression("x:=1");
        }

        #[test]
        #[should_panic(expected = "InvalidExpression")]
        fn parse_expression_failed_equation() {
            parse_expression("1 @");
        }
    }
    use self::helpers::parse_expression;
    use super::*;
    use crate::ast::CustomFunction;
    use crate::calc::TopLevelEnv;
    use crate::parse;

    /// Normalizes `expr` with respect to `x` in an empty default environment.
    fn normalize_x(expr: &str) -> Result<NormForm, SolverError> {
        normalize(&parse_expression(expr), "x", &TopLevelEnv::default())
    }

    /// Parses `input` as a `solve ... for ...` statement and solves it in `env`.
    ///
    /// Panics with the offending input if it parses to a different kind of
    /// statement, so a failing test points directly at the bad input.
    fn solve(input: &str, env: &dyn Env) -> Result<Number, SolverError> {
        match parse(input).unwrap() {
            Statement::SolveFor { lhs, rhs, sym } => solve_for(&lhs, &rhs, &sym, env),
            _ => panic!("`{}` is not a `solve ... for ...` statement", input),
        }
    }

    #[test]
    fn normalize_operand_number() {
        assert_eq!(Ok(NormForm { a1: 0.0, a0: 1.2 }), normalize_x("1.2"));
    }

    #[test]
    fn normalize_operand_symbol_x() {
        assert_eq!(Ok(NormForm { a1: 1.0, a0: 0.0 }), normalize_x("x"));
    }

    #[test]
    fn normalize_operand_symbol_y_unknown() {
        assert_eq!(
            Err(SolverError::UnknownVariable("y".to_string())),
            normalize_x("y")
        );
    }

    #[test]
    fn normalize_operand_symbol_y() {
        let mut env = TopLevelEnv::default();
        env.put("y".to_string(), 12.0).unwrap();
        let act = normalize(&parse_expression("y"), "x", &env);
        assert_eq!(Ok(NormForm { a1: 0.0, a0: 12.0 }), act);
    }

    #[test]
    fn normalize_operand_simple_add() {
        assert_eq!(Ok(NormForm { a1: 1.0, a0: 1.0 }), normalize_x("x + 1"));
    }

    #[test]
    fn normalize_operand_simple_sub() {
        assert_eq!(Ok(NormForm { a1: 1.0, a0: -12.0 }), normalize_x("x - 12"));
    }

    #[test]
    fn normalize_operand_simple_mul() {
        assert_eq!(Ok(NormForm { a1: 2.0, a0: 0.0 }), normalize_x("x * 2"));
    }

    #[test]
    fn normalize_operand_simple_rem() {
        assert_eq!(Ok(NormForm { a1: 0.0, a0: 1.0 }), normalize_x("7 % 3"));
    }

    #[test]
    fn normalize_operand_simple_pow() {
        assert_eq!(Ok(NormForm { a1: 0.0, a0: 27.0 }), normalize_x("3 ^ 3"));
    }

    #[test]
    fn normalize_operand_simple_norm_form() {
        assert_eq!(Ok(NormForm { a1: 3.0, a0: 2.0 }), normalize_x("3 * x + 2"));
    }

    #[test]
    fn normalize_operand_simple_norm_sub() {
        assert_eq!(Ok(NormForm { a1: 3.0, a0: -2.0 }), normalize_x("3 * x - 2"));
    }

    #[test]
    fn normalize_operand_div() {
        assert_eq!(
            Ok(NormForm { a1: 4.0, a0: -5.0 }),
            normalize_x("(12 * x - 15) / 3")
        );
    }

    #[test]
    fn normalize_operand_x_square_unsupported() {
        assert_eq!(Err(SolverError::UnsupportedXSquare), normalize_x("x * x"));
    }

    #[test]
    fn normalize_operand_x_in_denominator_unsupported() {
        assert_eq!(
            Err(SolverError::UnsupportedXDenominator),
            normalize_x("1 / x")
        );
    }

    #[test]
    fn normalize_operand_x_in_remainder_unsupported() {
        assert_eq!(Err(SolverError::UnsupportedRemainder), normalize_x("x % 2"));
    }

    #[test]
    fn normalize_operand_x_in_power_unsupported() {
        assert_eq!(Err(SolverError::UnsupportedPower), normalize_x("x ^ 2"));
    }

    #[test]
    fn solve_for_simple() {
        assert_eq!(
            Ok(10.0),
            solve("solve x = 10 for x", &TopLevelEnv::default())
        );
    }

    #[test]
    fn solve_for_complex() {
        assert_eq!(
            Ok(1.5),
            solve(
                "solve 5 + 2 * x + 12 = 22 - 6 * x + 7 for x",
                &TopLevelEnv::default()
            )
        );
    }

    #[test]
    fn solve_for_variable_cancels_out() {
        // Both sides have the same coefficient for `x`, so nothing is left to
        // solve for.
        assert_eq!(
            Err(SolverError::NoVariable),
            solve("solve 2 * x = 2 * x + 1 for x", &TopLevelEnv::default())
        );
    }

    #[test]
    fn solve_for_with_function_call() {
        let mut env = TopLevelEnv::default();
        env.put_fun(
            "add".to_string(),
            Function::Custom(CustomFunction {
                args: vec!["x".to_string(), "y".to_string()],
                body: Operand::Term(Box::new(Term {
                    lhs: Operand::Symbol("x".to_string()),
                    rhs: Operand::Symbol("y".to_string()),
                    op: Operation::Add,
                })),
            }),
        );
        assert_eq!(
            Ok(1.5),
            solve("solve 2 * x + add(5, 12) = 22 - 6 * x + 7 for x", &env)
        );
    }
}