// ruut_functions/derivation.rs

1use crate::{FType, Func, F1D, F2D, F3D, FND};
2
3impl F1D {
4    /// Computes the nth-derivative
5    /// ```
6    /// use ruut_functions::{f1d, F1D};
7    ///
8    /// let f = f1d!("x^log(x)");
9    /// let df = f.derive(1);
10    /// assert_eq!(df, f1d!("2ln(x)x^ln(x)/x"));
11    /// assert_eq!(df.derive(1), f.derive(2));
12    /// ```
13    pub fn derive(&self, order: usize) -> Self {
14        F1D(self.0.derive_nth('x', order))
15    }
16}
17impl F2D {
18    /// Computes the nth-derivative
19    /// ```
20    /// use ruut_functions::{f2d,F2D};
21    /// assert_eq!(f2d!("x+y^2").derive('y', 2), f2d!("2"));
22    /// ```
23    pub fn derive(&self, var: char, order: usize) -> Self {
24        F2D(self.0.derive_nth(var, order))
25    }
26    /// Computes the gradient
27    /// ```
28    /// use ruut_functions::{f2d, F2D};
29    /// assert_eq!(f2d!("x+y^2").gradient(), vec![f2d!("1"), f2d!("2y")]);
30    /// ```
31    pub fn gradient(&self) -> Vec<Self> {
32        vec![self.derive('x', 1), self.derive('y', 1)]
33    }
34
35    /// Computes the hessian matrix
36    /// ```
37    /// use ruut_functions::{f2d, F2D};
38    /// let f = f2d!("x^3+y^2");
39    /// let hessian = f.hessian();
40    /// assert_eq!(hessian, vec![vec![f2d!("6x"), f2d!("0")],
41    ///                          vec![f2d!("0"), f2d!("2")]]);
42    /// ```
43    pub fn hessian(&self) -> Vec<Vec<Self>> {
44        vec![
45            vec![self.derive('x', 2), self.derive('x', 1).derive('y', 1)],
46            vec![self.derive('y', 1).derive('x', 1), self.derive('y', 2)],
47        ]
48    }
49}
50
51impl F3D {
52    /// Computes the nth-derivative
53    /// ```
54    /// use ruut_functions::{f3d, F3D};
55    /// assert_eq!(f3d!("x+zy^2").derive('y', 2), f3d!("2z"));
56    /// ```
57    pub fn derive(&self, var: char, order: usize) -> Self {
58        F3D(self.0.derive_nth(var, order))
59    }
60    /// Computes the gradient
61    /// ```
62    /// use ruut_functions::{f3d, F3D};
63    /// assert_eq!(f3d!("x+zy^2").gradient(), vec![f3d!("1"), f3d!("2yz"), f3d!("y^2")]);
64    /// ```
65    pub fn gradient(&self) -> Vec<Self> {
66        vec![
67            self.derive('x', 1),
68            self.derive('y', 1),
69            self.derive('z', 1),
70        ]
71    }
72    /// Computes the hessian
73    /// ```
74    /// use ruut_functions::{f3d, F3D};
75    /// assert_eq!(f3d!("x^3+zy^2").hessian(), vec![vec![f3d!("6x"), f3d!("0"), f3d!("0")],
76    ///                                             vec![f3d!("0"), f3d!("2z"), f3d!("2y")],
77    ///                                             vec![f3d!("0"), f3d!("2y"), f3d!("0")]]);
78    /// ```
79    pub fn hessian(&self) -> Vec<Vec<Self>> {
80        let dx = self.derive('x', 1);
81        let dy = self.derive('y', 1);
82        let dz = self.derive('z', 1);
83        vec![
84            vec![dx.derive('x', 1), dx.derive('y', 1), dx.derive('z', 1)],
85            vec![dy.derive('x', 1), dy.derive('y', 1), dy.derive('z', 1)],
86            vec![dz.derive('x', 1), dz.derive('y', 1), dz.derive('z', 1)],
87        ]
88    }
89}
90
91impl FND {
92    /// Computes the nth-derivative wrt a variable
93    /// ```
94    /// use ruut_functions::{fnd,FND};
95    /// let vars = ['f', 'z'];
96    /// assert_eq!(fnd!("f^2+z", &vars).derive('f', 1), fnd!("2f", &vars));
97    /// ```
98    pub fn derive(&self, var: char, order: usize) -> Self {
99        FND {
100            vars: self.vars.clone(),
101            func: self.func.derive_nth(var, order),
102        }
103    }
104
105    /// Computes the gradient
106    /// ```
107    /// use ruut_functions::{fnd,FND};
108    /// let vars = ['f', 'z'];
109    /// assert_eq!(fnd!("f^2+z", &vars).gradient(), vec![fnd!("2f", &vars), fnd!("1", &vars)]);
110    /// ```
111    pub fn gradient(&self) -> Vec<Self> {
112        let mut result = Vec::with_capacity(self.vars.len());
113        for var in &self.vars {
114            result.push(self.derive(*var, 1));
115        }
116        result
117    }
118    /// Computes the hessian
119    /// ```
120    /// use ruut_functions::{fnd,FND};
121    /// let vars = ['f', 'z'];
122    /// assert_eq!(fnd!("f^3+zf", &vars).hessian(), vec![vec![fnd!("6f", &vars), fnd!("1", &vars)],
123    ///                                                  vec![fnd!("1", &vars), fnd!("0", &vars)]]);
124    /// ```
125    pub fn hessian(&self) -> Vec<Vec<Self>> {
126        let mut result = Vec::new();
127
128        // first derivative
129        let mut first_deriv = Vec::with_capacity(self.vars.len());
130        for var in &self.vars {
131            first_deriv.push(self.derive(*var, 1));
132        }
133
134        for el in first_deriv {
135            let mut gradient = Vec::with_capacity(self.vars.len());
136
137            for var in &self.vars {
138                gradient.push(el.derive(*var, 1))
139            }
140            result.push(gradient);
141        }
142
143        result
144    }
145}
146
147impl Func {
148    fn derive_nth(&self, var: char, order: usize) -> Self {
149        let mut result = self.clone();
150        for _ in 1..=order {
151            result = result.derive(var);
152        }
153        result
154    }
155    fn derive(&self, var: char) -> Self {
156        let res = match self {
157            Self::Var(char) => {
158                if *char == var {
159                    Self::Num(1)
160                } else {
161                    Self::Num(0)
162                }
163            }
164            Self::Num(_) | Self::Param(..) => Self::Num(0),
165            Self::E | Self::PI => Self::Num(0),
166            Self::Add(add) => add.iter().map(|term| term.derive(var)).sum::<Self>(),
167            Self::Mul(mul) => {
168                let mut result = Func::Num(0);
169                for (i, term) in mul.iter().enumerate() {
170                    let mut multipliers = term.derive(var);
171                    for (j, other) in mul.iter().enumerate() {
172                        if i != j {
173                            multipliers *= other.clone()
174                        }
175                    }
176                    result += multipliers;
177                }
178
179                result
180            }
181            Self::Pow(base, exp) => {
182                if let Func::E = **base {
183                    return exp.derive(var) * self.clone();
184                }
185                if let Func::Num(exp_val) = **exp {
186                    return exp_val * base.derive(var) * base.clone().powi(exp_val - 1);
187                }
188                (Func::E.pow(*exp.clone() * Self::S(FType::Ln, base.clone()))).derive(var)
189            }
190            Self::S(kind, argument) => {
191                let argument = Box::new(*argument.clone());
192                let arg = argument.derive(var);
193
194                match kind {
195                    FType::Ln => arg / *argument,
196                    FType::Sin => arg * Func::S(FType::Cos, argument),
197                    FType::Cos => -1 * arg * Func::S(FType::Sin, argument),
198                    FType::Tan => arg * Func::S(FType::Sec, argument).powi(2),
199                    FType::Cot => -1 * arg * (Func::S(FType::Csc, argument)).powi(2),
200                    FType::Sec => {
201                        arg * Func::S(FType::Sec, argument.clone()) * Func::S(FType::Tan, argument)
202                    }
203                    FType::Csc => {
204                        -1 * arg
205                            * Func::S(FType::Cot, argument.clone())
206                            * Func::S(FType::Csc, argument)
207                    }
208                    FType::ASin => arg / (1 - argument.powi(2)).pow(Func::Num(1) / Func::Num(2)),
209                    FType::ACos => {
210                        -1 * arg / (1 - argument.powi(2)).pow(Func::Num(1) / Func::Num(2))
211                    }
212                    FType::ATan => arg / (1 + argument.powi(2)),
213                    FType::Sinh => arg * Func::S(FType::Cosh, argument),
214                    FType::Cosh => arg * Func::S(FType::Sinh, argument),
215                    FType::Tanh => arg * Func::S(FType::Sech, argument).powi(2),
216                    FType::Coth => -1 * arg * Func::S(FType::Csch, argument).powi(2),
217                    FType::Sech => {
218                        -1 * arg
219                            * Func::S(FType::Sech, argument.clone())
220                            * Func::S(FType::Tanh, argument)
221                    }
222                    FType::Csch => {
223                        -1 * arg
224                            * Func::S(FType::Csch, argument.clone())
225                            * Func::S(FType::Coth, argument)
226                    }
227                    FType::ASinh => arg / (1 + argument.powi(2)).pow(Func::Num(1) / Func::Num(2)),
228                    FType::ACosh => arg / (argument.powi(2) - 1).pow(Func::Num(1) / Func::Num(2)),
229                    FType::ATanh => arg / (1 - argument.powi(2)),
230                    FType::Abs => arg * *argument.clone() / Func::S(FType::Abs, argument),
231                }
232            }
233        };
234        res
235    }
236}
237
#[test]
fn test_derive() {
    use crate::{f1d, f2d, f3d, fnd};

    // First derivative of a one-variable function; keeps the F1D assertions short.
    let d1 = |f: F1D| f.derive(1);

    // F1D: sums, products, powers and constants.
    assert_eq!(
        d1(f1d!("x+ln(x)+x^2+sin(2x)")),
        f1d!("1+1/x+2cos(2x)+2x")
    );
    assert_eq!(d1(f1d!("3x+7+e")), f1d!("3"));
    assert_eq!(d1(f1d!("xsin(x)")), f1d!("sin(x)+xcos(x)"));
    assert_eq!(d1(f1d!("tan(x^2)")), f1d!("2xsec(x^2)^2"));
    assert_eq!(d1(f1d!("x^x")), f1d!("(ln(x)+1)e^(xln(x))"));
    assert_eq!(
        f3d!("xyz^2").gradient(),
        vec![f3d!("yz^2"), f3d!("xz^2"), f3d!("2xyz")]
    );
    // NOTE(review): currently disabled; presumably the result is not simplified to this
    // exact form yet — TODO confirm and re-enable once it is.
    // assert_eq!(d1(f1d!("x/(x+1)")), f1d!("1/(x+1)^2"));
    assert_eq!(d1(f1d!("1/(3e*x^2)")), f1d!("-2/(3e*x^3)"));

    // F1D: trigonometric and inverse trigonometric functions.
    assert_eq!(d1(f1d!("cos(x)")), f1d!("-sin(x)"));
    assert_eq!(d1(f1d!("sin(x)")), f1d!("cos(x)"));
    assert_eq!(d1(f1d!("cot(x)")), f1d!("-csc(x)^2"));
    assert_eq!(d1(f1d!("sec(x)")), f1d!("sec(x)tan(x)"));
    assert_eq!(d1(f1d!("csc(x)")), f1d!("-csc(x)cot(x)"));
    assert_eq!(d1(f1d!("asin(x)")), f1d!("1/(1-x^2)^(1/2)"));
    assert_eq!(d1(f1d!("acos(x)")), f1d!("-1/(1-x^2)^(1/2)"));
    assert_eq!(d1(f1d!("atan(x)")), f1d!("1/(1+x^2)"));

    // F1D: hyperbolic and inverse hyperbolic functions, plus abs.
    assert_eq!(d1(f1d!("sinh(x)")), f1d!("cosh(x)"));
    assert_eq!(d1(f1d!("cosh(x)")), f1d!("sinh(x)"));
    assert_eq!(d1(f1d!("tanh(x)")), f1d!("sech(x)^2"));
    assert_eq!(d1(f1d!("coth(x)")), f1d!("-csch(x)^2"));
    assert_eq!(d1(f1d!("sech(x)")), f1d!("-tanh(x)sech(x)"));
    assert_eq!(d1(f1d!("csch(x)")), f1d!("-csch(x)coth(x)"));
    assert_eq!(d1(f1d!("asinh(x)")), f1d!("1/(1+x^2)^(1/2)"));
    assert_eq!(d1(f1d!("acosh(x)")), f1d!("1/(x^2-1)^(1/2)"));
    assert_eq!(d1(f1d!("atanh(x)")), f1d!("1/(1-x^2)"));
    assert_eq!(d1(f1d!("abs(x)")), f1d!("x/abs(x)"));

    // F2D: gradient and Hessian.
    let g = f2d!("xy+y^2");
    assert_eq!(g.gradient(), vec![f2d!("y"), f2d!("x+2y")]);
    assert_eq!(
        g.hessian(),
        vec![vec![f2d!("0"), f2d!("1")], vec![f2d!("1"), f2d!("2")]]
    );

    // F3D: gradient and Hessian.
    let h = f3d!("xy+y^2+1/z");
    assert_eq!(h.gradient(), vec![f3d!("y"), f3d!("x+2y"), f3d!("-1/z^2")]);
    assert_eq!(
        h.hessian(),
        vec![
            vec![f3d!("0"), f3d!("1"), f3d!("0")],
            vec![f3d!("1"), f3d!("2"), f3d!("0")],
            vec![f3d!("0"), f3d!("0"), f3d!("2/z^3")]
        ]
    );

    // FND: custom variable names.
    let v = ['w', 'f'];
    let n = fnd!("w+f^2", &v);
    assert_eq!(n.gradient(), vec![fnd!("1", &v), fnd!("2f", &v)]);
    assert_eq!(
        n.hessian(),
        vec![
            vec![fnd!("0", &v), fnd!("0", &v)],
            vec![fnd!("0", &v), fnd!("2", &v)]
        ]
    );
}