1use nalgebra::{allocator::Allocator, Const, DefaultAllocator, DimMin, DimName};
2
3use crate::types::{MatA, VecN, MIN_DIAG_CLAMP};
4
5#[allow(clippy::needless_range_loop)] fn gamma_estimator<const NV: usize>(a2: &[[f32; NV]; NV], cond_target: f32) -> (f32, f32) {
7 let mut max_sig: f32 = 0.0;
8 for i in 0..NV {
9 let mut r: f32 = 0.0;
10 for j in 0..NV {
11 if j != i {
12 r += libm::fabsf(a2[i][j]);
13 }
14 }
15 let disk = a2[i][i] + r;
16 if max_sig < disk {
17 max_sig = disk;
18 }
19 }
20 (libm::sqrtf(max_sig / cond_target), max_sig)
21}
22
23#[allow(clippy::needless_range_loop)] pub fn setup_a<const NU: usize, const NV: usize, const NC: usize>(
29 b_mat: &MatA<NV, NU>,
30 wv: &VecN<NV>,
31 wu: &mut VecN<NU>,
32 theta: f32,
33 cond_bound: f32,
34) -> (MatA<NC, NU>, f32)
35where
36 Const<NC>: DimName + DimMin<Const<NU>, Output = Const<NU>>,
37 Const<NU>: DimName,
38 Const<NV>: DimName,
39 DefaultAllocator: Allocator<Const<NC>, Const<NU>>
40 + Allocator<Const<NC>, Const<NC>>
41 + Allocator<Const<NU>, Const<NU>>
42 + Allocator<Const<NC>>
43 + Allocator<Const<NU>>
44 + Allocator<Const<NV>>,
45{
46 debug_assert_eq!(NC, NU + NV);
47
48 let mut a2 = [[0.0f32; NV]; NV];
50 for i in 0..NV {
51 for j in i..NV {
52 let mut sum = 0.0f32;
53 for k in 0..NU {
54 sum += b_mat[(i, k)] * b_mat[(j, k)];
55 }
56 a2[i][j] = sum * wv[i] * wv[i];
57 if i != j {
58 a2[j][i] = a2[i][j];
59 }
60 }
61 }
62
63 let mut min_diag: f32 = f32::INFINITY;
65 let mut max_diag: f32 = 0.0;
66 for i in 0..NU {
67 if wu[i] < min_diag {
68 min_diag = wu[i];
69 }
70 if wu[i] > max_diag {
71 max_diag = wu[i];
72 }
73 }
74 if min_diag < MIN_DIAG_CLAMP {
75 min_diag = MIN_DIAG_CLAMP;
76 }
77 let inv = 1.0 / min_diag;
78 for i in 0..NU {
79 wu[i] *= inv;
80 }
81 max_diag *= inv;
82
83 let gamma = if cond_bound > 0.0 {
85 let (ge, ms) = gamma_estimator(&a2, cond_bound);
86 let gt = libm::sqrtf(ms) * theta / max_diag;
87 if ge > gt {
88 ge
89 } else {
90 gt
91 }
92 } else {
93 let (_, ms) = gamma_estimator(&a2, 1.0);
94 libm::sqrtf(ms) * theta / max_diag
95 };
96
97 let mut a: MatA<NC, NU> = MatA::zeros();
99 for j in 0..NU {
100 for i in 0..NV {
101 a[(i, j)] = wv[i] * b_mat[(i, j)];
102 }
103 a[(NV + j, j)] = gamma * wu[j];
104 }
105
106 (a, gamma)
107}
108
109pub fn setup_b<const NU: usize, const NV: usize, const NC: usize>(
111 v: &VecN<NV>,
112 ud: &VecN<NU>,
113 wv: &VecN<NV>,
114 wu_norm: &VecN<NU>,
115 gamma: f32,
116) -> VecN<NC>
117where
118 Const<NC>: DimName + DimMin<Const<NU>, Output = Const<NU>>,
119 Const<NU>: DimName,
120 Const<NV>: DimName,
121 DefaultAllocator: Allocator<Const<NC>, Const<NU>>
122 + Allocator<Const<NC>, Const<NC>>
123 + Allocator<Const<NU>, Const<NU>>
124 + Allocator<Const<NC>>
125 + Allocator<Const<NU>>
126 + Allocator<Const<NV>>,
127{
128 debug_assert_eq!(NC, NU + NV);
129 let mut b: VecN<NC> = VecN::zeros();
130 for i in 0..NV {
131 b[i] = wv[i] * v[i];
132 }
133 for i in 0..NU {
134 b[NV + i] = gamma * wu_norm[i] * ud[i];
135 }
136 b
137}