Skip to main content

tensorlogic_compiler/symbolic_diff/
api.rs

//! Top-level public entry points (differentiate / jacobian).
2
3use tensorlogic_ir::TLExpr;
4
5use super::diff_core::diff_expr;
6use super::helpers::simplify_derivative;
7use super::types::{DiffConfig, DiffContext, DiffError, DiffResult};
8
9/// Symbolically differentiate `expr` with respect to the variable named `var`.
10///
11/// # Differentiation rules
12///
13/// - `d(c)/dx = 0` for any constant `c`
14/// - `d(x)/dx = 1`
15/// - `d(y)/dx = 0` for `y ≠ x`
16/// - Sum rule: `d(a + b)/dx = d(a)/dx + d(b)/dx`
17/// - Product rule: `d(a * b)/dx = a * d(b)/dx + b * d(a)/dx`
18/// - Quotient rule: `d(a / b)/dx = (d(a)/dx * b − a * d(b)/dx) / b²`
19/// - Power rule: `d(a^n)/dx = n * a^(n−1) * d(a)/dx` (when exponent is a constant)
20/// - Chain rule applies to transcendental unary functions
21/// - Logical AND: `d(AND(a,b))/dx = AND(d(a)/dx, b) OR AND(a, d(b)/dx)`
22/// - Logical OR: `d(OR(a,b))/dx = OR(d(a)/dx, d(b)/dx)`
23/// - Logical NOT: `d(NOT(a))/dx = NOT(d(a)/dx)`
24/// - Implication: expanded as `NOT(a) OR b` before differentiating
25/// - Quantifiers: bound variable shadowed; derivative of body is returned
26/// - Let-binding: full chain-rule expansion via d(body)/d(bound) * d(value)/dx
27///
28/// # Errors
29///
30/// Returns `DiffError::MaxDepthExceeded` if the expression tree exceeds
31/// `config.max_expr_depth`. Returns `DiffError::ExprTooComplex` if an
32/// unsupported node is encountered and `config.error_on_unsupported` is true.
33pub fn differentiate(
34    expr: &TLExpr,
35    var: &str,
36    config: &DiffConfig,
37) -> Result<DiffResult, DiffError> {
38    let mut ctx = DiffContext {
39        var: var.to_string(),
40        config,
41        depth: 0,
42        unsupported_nodes: Vec::new(),
43    };
44    let derivative = diff_expr(expr, &mut ctx)?;
45    let (final_expr, simplified) = if config.simplify_result {
46        (simplify_derivative(derivative), true)
47    } else {
48        (derivative, false)
49    };
50    Ok(DiffResult {
51        derivative: final_expr,
52        simplified,
53        unsupported_nodes: ctx.unsupported_nodes,
54    })
55}
56
57/// Compute the Jacobian: differentiate `expr` with respect to each variable in `vars`.
58///
59/// Returns a vector of `(variable_name, DiffResult)` pairs in the same order as `vars`.
60pub fn jacobian(
61    expr: &TLExpr,
62    vars: &[&str],
63    config: &DiffConfig,
64) -> Result<Vec<(String, DiffResult)>, DiffError> {
65    let mut results = Vec::with_capacity(vars.len());
66    for &v in vars {
67        let result = differentiate(expr, v, config)?;
68        results.push((v.to_string(), result));
69    }
70    Ok(results)
71}