scirs2_integrate/symbolic/
conversion.rs1use super::expression::{SymbolicExpression, Variable};
8use crate::common::IntegrateFloat;
9use crate::error::{IntegrateError, IntegrateResult};
10use scirs2_core::ndarray::{Array1, ArrayView1};
11use std::collections::HashMap;
12use SymbolicExpression::{Add, Constant, Cos, Div, Exp, Ln, Mul, Neg, Pow, Sin, Sqrt, Sub, Var};
13
14pub struct HigherOrderODE<F: IntegrateFloat> {
16 pub order: usize,
18 pub dependent_var: String,
20 pub independent_var: String,
22 pub expression: SymbolicExpression<F>,
25}
26
27impl<F: IntegrateFloat> HigherOrderODE<F> {
28 pub fn new(
30 order: usize,
31 dependent_var: impl Into<String>,
32 independent_var: impl Into<String>,
33 expression: SymbolicExpression<F>,
34 ) -> IntegrateResult<Self> {
35 if order == 0 {
36 return Err(IntegrateError::ValueError(
37 "ODE order must be at least 1".to_string(),
38 ));
39 }
40
41 Ok(HigherOrderODE {
42 order,
43 dependent_var: dependent_var.into(),
44 independent_var: independent_var.into(),
45 expression,
46 })
47 }
48
49 pub fn state_variables(&self) -> Vec<Variable> {
51 (0..self.order)
52 .map(|i| Variable::indexed(&self.dependent_var, i))
53 .collect()
54 }
55}
56
57pub struct FirstOrderSystem<F: IntegrateFloat> {
59 pub state_vars: Vec<Variable>,
61 pub expressions: Vec<SymbolicExpression<F>>,
63 pub variable_map: HashMap<String, Variable>,
65}
66
67impl<F: IntegrateFloat> FirstOrderSystem<F> {
68 pub fn to_function(&self) -> impl Fn(F, ArrayView1<F>) -> IntegrateResult<Array1<F>> {
70 let expressions = self.expressions.clone();
71 let state_vars = self.state_vars.clone();
72
73 move |t: F, y: ArrayView1<F>| {
74 if y.len() != state_vars.len() {
75 return Err(IntegrateError::DimensionMismatch(format!(
76 "Expected {} states, got {}",
77 state_vars.len(),
78 y.len()
79 )));
80 }
81
82 let mut values = HashMap::new();
84 for (i, var) in state_vars.iter().enumerate() {
85 values.insert(var.clone(), y[i]);
86 }
87 values.insert(Variable::new("t"), t);
88
89 let mut result = Array1::zeros(expressions.len());
91 for (i, expr) in expressions.iter().enumerate() {
92 result[i] = expr.evaluate(&values)?;
93 }
94
95 Ok(result)
96 }
97 }
98}
99
100#[allow(dead_code)]
116pub fn higher_order_to_first_order<F: IntegrateFloat>(
117 ode: &HigherOrderODE<F>,
118) -> IntegrateResult<FirstOrderSystem<F>> {
119 use SymbolicExpression::*;
120
121 let mut state_vars = Vec::new();
122 let mut expressions = Vec::new();
123 let mut variable_map = HashMap::new();
124
125 for i in 0..ode.order {
127 let var = Variable::indexed(&ode.dependent_var, i);
128 state_vars.push(var.clone());
129
130 let deriv_notation = match i {
132 0 => ode.dependent_var.clone(),
133 1 => format!("{}'", ode.dependent_var),
134 n => format!("{}^({})", ode.dependent_var, n),
135 };
136 variable_map.insert(deriv_notation, var);
137 }
138
139 for i in 0..ode.order - 1 {
142 expressions.push(Var(state_vars[i + 1].clone()));
143 }
144
145 let mut highest_deriv_expr = ode.expression.clone();
147 highest_deriv_expr = substitute_derivatives(&highest_deriv_expr, &variable_map);
148 expressions.push(highest_deriv_expr);
149
150 Ok(FirstOrderSystem {
151 state_vars,
152 expressions,
153 variable_map,
154 })
155}
156
157#[allow(dead_code)]
159fn substitute_derivatives<F: IntegrateFloat>(
160 expr: &SymbolicExpression<F>,
161 variable_map: &HashMap<String, Variable>,
162) -> SymbolicExpression<F> {
163 match expr {
164 Var(v) => {
165 if let Some(state_var) = variable_map.get(&v.name) {
167 Var(state_var.clone())
168 } else {
169 expr.clone()
170 }
171 }
172 Add(a, b) => Add(
173 Box::new(substitute_derivatives(a, variable_map)),
174 Box::new(substitute_derivatives(b, variable_map)),
175 ),
176 Sub(a, b) => Sub(
177 Box::new(substitute_derivatives(a, variable_map)),
178 Box::new(substitute_derivatives(b, variable_map)),
179 ),
180 Mul(a, b) => Mul(
181 Box::new(substitute_derivatives(a, variable_map)),
182 Box::new(substitute_derivatives(b, variable_map)),
183 ),
184 Div(a, b) => Div(
185 Box::new(substitute_derivatives(a, variable_map)),
186 Box::new(substitute_derivatives(b, variable_map)),
187 ),
188 Pow(a, b) => Pow(
189 Box::new(substitute_derivatives(a, variable_map)),
190 Box::new(substitute_derivatives(b, variable_map)),
191 ),
192 Neg(a) => Neg(Box::new(substitute_derivatives(a, variable_map))),
193 Sin(a) => Sin(Box::new(substitute_derivatives(a, variable_map))),
194 Cos(a) => Cos(Box::new(substitute_derivatives(a, variable_map))),
195 Exp(a) => Exp(Box::new(substitute_derivatives(a, variable_map))),
196 Ln(a) => Ln(Box::new(substitute_derivatives(a, variable_map))),
197 Sqrt(a) => Sqrt(Box::new(substitute_derivatives(a, variable_map))),
198 _ => expr.clone(),
199 }
200}
201
202#[allow(dead_code)]
204pub fn example_damped_oscillator<F: IntegrateFloat>(
205 omega: F,
206 damping: F,
207) -> IntegrateResult<FirstOrderSystem<F>> {
208 let x = Var(Variable::new("x"));
212 let x_prime = Var(Variable::new("x'"));
213
214 let expression = Neg(Box::new(Add(
215 Box::new(Mul(
216 Box::new(Mul(
217 Box::new(Constant(
218 F::from(2.0).expect("Failed to convert constant to float"),
219 )),
220 Box::new(Constant(damping)),
221 )),
222 Box::new(x_prime),
223 )),
224 Box::new(Mul(
225 Box::new(Pow(
226 Box::new(Constant(omega)),
227 Box::new(Constant(
228 F::from(2.0).expect("Failed to convert constant to float"),
229 )),
230 )),
231 Box::new(x),
232 )),
233 )));
234
235 let ode = HigherOrderODE::new(2, "x", "t", expression)?;
236 higher_order_to_first_order(&ode)
237}
238
239#[allow(dead_code)]
241pub fn example_driven_pendulum<F: IntegrateFloat>(
242 g: F, l: F, gamma: F, a: F, omega: F, ) -> IntegrateResult<FirstOrderSystem<F>> {
248 let theta = SymbolicExpression::var("θ");
252 let theta_prime = SymbolicExpression::var("θ'");
253 let t = SymbolicExpression::var("t");
254
255 let g_over_l = SymbolicExpression::constant(g / l);
256 let gamma_const = SymbolicExpression::constant(gamma);
257 let a_const = SymbolicExpression::constant(a);
258 let omega_const = SymbolicExpression::constant(omega);
259
260 let damping_term = -gamma_const * theta_prime;
262 let gravity_term = -g_over_l * SymbolicExpression::Sin(Box::new(theta));
263 let driving_term = a_const * SymbolicExpression::Cos(Box::new(omega_const * t));
264
265 let expression = damping_term + gravity_term + driving_term;
266
267 let ode = HigherOrderODE::new(2, "θ", "t", expression)?;
268 higher_order_to_first_order(&ode)
269}
270
271#[allow(dead_code)]
273pub fn example_euler_bernoulli_beam<F: IntegrateFloat>(
274 ei: F, _rho_a: F, f: F, ) -> IntegrateResult<FirstOrderSystem<F>> {
278 let f_over_ei = SymbolicExpression::constant(f / ei);
283
284 let ode = HigherOrderODE::new(4, "w", "x", f_over_ei)?;
285 higher_order_to_first_order(&ode)
286}
287
288pub struct SystemConverter<F: IntegrateFloat> {
290 odes: Vec<HigherOrderODE<F>>,
291 total_states: usize,
292}
293
294impl<F: IntegrateFloat> SystemConverter<F> {
295 pub fn new() -> Self {
297 SystemConverter {
298 odes: Vec::new(),
299 total_states: 0,
300 }
301 }
302
303 pub fn add_ode(&mut self, ode: HigherOrderODE<F>) -> &mut Self {
305 self.total_states += ode.order;
306 self.odes.push(ode);
307 self
308 }
309
310 pub fn convert(&self) -> IntegrateResult<FirstOrderSystem<F>> {
312 let mut all_state_vars = Vec::new();
313 let mut all_expressions = Vec::new();
314 let mut all_variable_map = HashMap::new();
315
316 for ode in &self.odes {
317 let system = higher_order_to_first_order(ode)?;
318 all_state_vars.extend(system.state_vars);
319 all_expressions.extend(system.expressions);
320 all_variable_map.extend(system.variable_map);
321 }
322
323 Ok(FirstOrderSystem {
324 state_vars: all_state_vars,
325 expressions: all_expressions,
326 variable_map: all_variable_map,
327 })
328 }
329}
330
331impl<F: IntegrateFloat> Default for SystemConverter<F> {
332 fn default() -> Self {
333 Self::new()
334 }
335}
336
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        higher_order_to_first_order, HigherOrderODE, SymbolicExpression,
        SymbolicExpression::{Neg, Var},
        Variable,
    };

    #[test]
    fn test_second_order_conversion() {
        // x'' = -x
        let expr: SymbolicExpression<f64> = Neg(Box::new(Var(Variable::new("x"))));

        let ode = HigherOrderODE::new(2, "x", "t", expr).expect("Operation failed");
        let system = higher_order_to_first_order(&ode).expect("Operation failed");

        assert_eq!(system.state_vars.len(), 2);
        assert_eq!(system.expressions.len(), 2);

        // The first row is the chain-rule identity x_0' = x_1.
        match &system.expressions[0] {
            Var(v) => {
                assert_eq!(v.name, "x");
                assert_eq!(v.index, Some(1));
            }
            other => panic!("Expected variable expression, got {:?}", other),
        }
    }
}