Skip to main content

wls_alloc/
setup.rs

1use nalgebra::{allocator::Allocator, Const, DefaultAllocator, DimMin, DimName};
2
3use crate::types::{MatA, VecN, MIN_DIAG_CLAMP};
4
5#[allow(clippy::needless_range_loop)] // 2D symmetric matrix access a2[i][j]
6fn gamma_estimator<const NV: usize>(a2: &[[f32; NV]; NV], cond_target: f32) -> (f32, f32) {
7    let mut max_sig: f32 = 0.0;
8    for i in 0..NV {
9        let mut r: f32 = 0.0;
10        for j in 0..NV {
11            if j != i {
12                r += libm::fabsf(a2[i][j]);
13            }
14        }
15        let disk = a2[i][i] + r;
16        if max_sig < disk {
17            max_sig = disk;
18        }
19    }
20    (libm::sqrtf(max_sig / cond_target), max_sig)
21}
22
23/// Convert WLS control allocation to a least-squares problem `min ||Au - b||`.
24///
25/// `wu` is **normalized in-place** by its minimum value (matching the C code).
26/// Returns `(A, gamma)`.
27#[allow(clippy::needless_range_loop)] // symmetric matrix fill uses a2[i][j] and a2[j][i]
28pub fn setup_a<const NU: usize, const NV: usize, const NC: usize>(
29    b_mat: &MatA<NV, NU>,
30    wv: &VecN<NV>,
31    wu: &mut VecN<NU>,
32    theta: f32,
33    cond_bound: f32,
34) -> (MatA<NC, NU>, f32)
35where
36    Const<NC>: DimName + DimMin<Const<NU>, Output = Const<NU>>,
37    Const<NU>: DimName,
38    Const<NV>: DimName,
39    DefaultAllocator: Allocator<Const<NC>, Const<NU>>
40        + Allocator<Const<NC>, Const<NC>>
41        + Allocator<Const<NU>, Const<NU>>
42        + Allocator<Const<NC>>
43        + Allocator<Const<NU>>
44        + Allocator<Const<NV>>,
45{
46    debug_assert_eq!(NC, NU + NV);
47
48    // Compute A2[i][j] — symmetric NV×NV Gershgorin scratch
49    let mut a2 = [[0.0f32; NV]; NV];
50    for i in 0..NV {
51        for j in i..NV {
52            let mut sum = 0.0f32;
53            for k in 0..NU {
54                sum += b_mat[(i, k)] * b_mat[(j, k)];
55            }
56            a2[i][j] = sum * wv[i] * wv[i];
57            if i != j {
58                a2[j][i] = a2[i][j];
59            }
60        }
61    }
62
63    // Normalise Wu
64    let mut min_diag: f32 = f32::INFINITY;
65    let mut max_diag: f32 = 0.0;
66    for i in 0..NU {
67        if wu[i] < min_diag {
68            min_diag = wu[i];
69        }
70        if wu[i] > max_diag {
71            max_diag = wu[i];
72        }
73    }
74    if min_diag < MIN_DIAG_CLAMP {
75        min_diag = MIN_DIAG_CLAMP;
76    }
77    let inv = 1.0 / min_diag;
78    for i in 0..NU {
79        wu[i] *= inv;
80    }
81    max_diag *= inv;
82
83    // Compute gamma
84    let gamma = if cond_bound > 0.0 {
85        let (ge, ms) = gamma_estimator(&a2, cond_bound);
86        let gt = libm::sqrtf(ms) * theta / max_diag;
87        if ge > gt {
88            ge
89        } else {
90            gt
91        }
92    } else {
93        let (_, ms) = gamma_estimator(&a2, 1.0);
94        libm::sqrtf(ms) * theta / max_diag
95    };
96
97    // Build A via nalgebra
98    let mut a: MatA<NC, NU> = MatA::zeros();
99    for j in 0..NU {
100        for i in 0..NV {
101            a[(i, j)] = wv[i] * b_mat[(i, j)];
102        }
103        a[(NV + j, j)] = gamma * wu[j];
104    }
105
106    (a, gamma)
107}
108
109/// Build the unregularised coefficient matrix `A = Wv · G`.
110///
111/// Returns an `NV × NU` matrix suitable for [`solve_cls`] with `NC = NV`.
112/// No regularisation rows are appended, so this is appropriate when the
113/// regularisation term `γ ‖Wu (u − u_pref)‖²` is not desired.
114///
115/// [`solve_cls`]: crate::solve_cls
116pub fn setup_a_unreg<const NU: usize, const NV: usize>(
117    b_mat: &MatA<NV, NU>,
118    wv: &VecN<NV>,
119) -> MatA<NV, NU>
120where
121    Const<NV>: DimName + DimMin<Const<NU>, Output = Const<NU>>,
122    Const<NU>: DimName,
123    DefaultAllocator: Allocator<Const<NV>, Const<NU>>
124        + Allocator<Const<NV>, Const<NV>>
125        + Allocator<Const<NU>, Const<NU>>
126        + Allocator<Const<NV>>
127        + Allocator<Const<NU>>,
128{
129    let mut a: MatA<NV, NU> = MatA::zeros();
130    for j in 0..NU {
131        for i in 0..NV {
132            a[(i, j)] = wv[i] * b_mat[(i, j)];
133        }
134    }
135    a
136}
137
138/// Build the unregularised right-hand side `b = Wv · v`.
139///
140/// Returns an `NV`-element vector suitable for [`solve_cls`] with `NC = NV`.
141///
142/// [`solve_cls`]: crate::solve_cls
143pub fn setup_b_unreg<const NV: usize>(v: &VecN<NV>, wv: &VecN<NV>) -> VecN<NV>
144where
145    Const<NV>: DimName,
146    DefaultAllocator: Allocator<Const<NV>>,
147{
148    let mut b: VecN<NV> = VecN::zeros();
149    for i in 0..NV {
150        b[i] = wv[i] * v[i];
151    }
152    b
153}
154
155/// Compute the right-hand side `b` for the LS problem.
156pub fn setup_b<const NU: usize, const NV: usize, const NC: usize>(
157    v: &VecN<NV>,
158    ud: &VecN<NU>,
159    wv: &VecN<NV>,
160    wu_norm: &VecN<NU>,
161    gamma: f32,
162) -> VecN<NC>
163where
164    Const<NC>: DimName + DimMin<Const<NU>, Output = Const<NU>>,
165    Const<NU>: DimName,
166    Const<NV>: DimName,
167    DefaultAllocator: Allocator<Const<NC>, Const<NU>>
168        + Allocator<Const<NC>, Const<NC>>
169        + Allocator<Const<NU>, Const<NU>>
170        + Allocator<Const<NC>>
171        + Allocator<Const<NU>>
172        + Allocator<Const<NV>>,
173{
174    debug_assert_eq!(NC, NU + NV);
175    let mut b: VecN<NC> = VecN::zeros();
176    for i in 0..NV {
177        b[i] = wv[i] * v[i];
178    }
179    for i in 0..NU {
180        b[NV + i] = gamma * wu_norm[i] * ud[i];
181    }
182    b
183}