Skip to main content

least_squares/
least_squares.rs

1use trellis_runner::{
2    CancellationGuard, GenerateBuilder, MaxIterationPolicy, Procedure, Progress,
3    ProgressDiagnostics, StagnationPolicy, UserState,
4};
5
/// Data set for fitting a straight line `y = a * x + b`.
///
/// Derives `Debug` and `PartialEq` in addition to `Clone` so that problems can
/// be logged and compared in tests.
#[derive(Clone, Debug, PartialEq)]
pub struct LinearRegressionProblem {
    /// Observed samples, each stored as an `(x, y)` pair.
    pub data: Vec<(f64, f64)>,
}
10
11#[derive(Clone, Debug)]
12pub struct LSState {
13    a: f64,
14    b: f64,
15    loss: f64,
16}
17
18impl Default for LSState {
19    fn default() -> Self {
20        Self {
21            a: 0.0,
22            b: 0.0,
23            loss: f64::INFINITY,
24        }
25    }
26}
27
28impl UserState for LSState {
29    type Float = f64;
30
31    fn progress(&self) -> Progress<Self::Float> {
32        Progress::Report {
33            measure: self.loss,
34            diagnostics: ProgressDiagnostics {
35                gradient_norm: Some((self.a.powi(2) + self.b.powi(2)).sqrt()),
36                step_size: Some(0.01),
37                ..Default::default()
38            },
39        }
40    }
41}
42
43/// Simple linear regression via gradient descent
44pub struct LeastSquares;
45
46impl Procedure<LinearRegressionProblem> for LeastSquares {
47    type State = LSState;
48    type Output = (f64, f64);
49
50    const NAME: &'static str = "Least Squares Optimisation";
51
52    fn step(
53        &self,
54        problem: &mut LinearRegressionProblem,
55        state: &mut Self::State,
56        _guard: CancellationGuard<'_>,
57    ) {
58        let lr = 0.01;
59        let mut da = 0.0;
60        let mut db = 0.0;
61        let mut loss = 0.0;
62
63        for (x, y) in &problem.data {
64            let pred = state.a * *x + state.b;
65            let err = pred - *y;
66
67            loss += err * err;
68            da += err * *x;
69            db += err;
70        }
71
72        state.a -= lr * da;
73        state.b -= lr * db;
74        state.loss = loss;
75    }
76
77    fn finalise(&self, _: &mut LinearRegressionProblem, state: &Self::State) -> Self::Output {
78        (state.a, state.b)
79    }
80}
81
82fn main() {
83    let problem = LinearRegressionProblem {
84        data: vec![(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)],
85    };
86
87    let result = LeastSquares
88        .build_for(problem)
89        .with_initial_state(LSState::default())
90        .and_policy(MaxIterationPolicy::new(3000))
91        .and_policy(StagnationPolicy::new(10))
92        .finalise()
93        .run();
94
95    println!("fit: {:?}", result);
96}