1use nalgebra::{allocator::Allocator, Const, DefaultAllocator, DimMin, DimName};
2
3use crate::types::{MatA, VecN, MIN_DIAG_CLAMP};
4
5#[allow(clippy::needless_range_loop)] fn gamma_estimator<const NV: usize>(a2: &[[f32; NV]; NV], cond_target: f32) -> (f32, f32) {
7 let mut max_sig: f32 = 0.0;
8 for i in 0..NV {
9 let mut r: f32 = 0.0;
10 for j in 0..NV {
11 if j != i {
12 r += libm::fabsf(a2[i][j]);
13 }
14 }
15 let disk = a2[i][i] + r;
16 if max_sig < disk {
17 max_sig = disk;
18 }
19 }
20 (libm::sqrtf(max_sig / cond_target), max_sig)
21}
22
23#[allow(clippy::needless_range_loop)] pub fn setup_a<const NU: usize, const NV: usize, const NC: usize>(
29 b_mat: &MatA<NV, NU>,
30 wv: &VecN<NV>,
31 wu: &mut VecN<NU>,
32 theta: f32,
33 cond_bound: f32,
34) -> (MatA<NC, NU>, f32)
35where
36 Const<NC>: DimName + DimMin<Const<NU>, Output = Const<NU>>,
37 Const<NU>: DimName,
38 Const<NV>: DimName,
39 DefaultAllocator: Allocator<Const<NC>, Const<NU>>
40 + Allocator<Const<NC>, Const<NC>>
41 + Allocator<Const<NU>, Const<NU>>
42 + Allocator<Const<NC>>
43 + Allocator<Const<NU>>
44 + Allocator<Const<NV>>,
45{
46 debug_assert_eq!(NC, NU + NV);
47
48 let mut a2 = [[0.0f32; NV]; NV];
50 for i in 0..NV {
51 for j in i..NV {
52 let mut sum = 0.0f32;
53 for k in 0..NU {
54 sum += b_mat[(i, k)] * b_mat[(j, k)];
55 }
56 a2[i][j] = sum * wv[i] * wv[i];
57 if i != j {
58 a2[j][i] = a2[i][j];
59 }
60 }
61 }
62
63 let mut min_diag: f32 = f32::INFINITY;
65 let mut max_diag: f32 = 0.0;
66 for i in 0..NU {
67 if wu[i] < min_diag {
68 min_diag = wu[i];
69 }
70 if wu[i] > max_diag {
71 max_diag = wu[i];
72 }
73 }
74 if min_diag < MIN_DIAG_CLAMP {
75 min_diag = MIN_DIAG_CLAMP;
76 }
77 let inv = 1.0 / min_diag;
78 for i in 0..NU {
79 wu[i] *= inv;
80 }
81 max_diag *= inv;
82
83 let gamma = if cond_bound > 0.0 {
85 let (ge, ms) = gamma_estimator(&a2, cond_bound);
86 let gt = libm::sqrtf(ms) * theta / max_diag;
87 if ge > gt {
88 ge
89 } else {
90 gt
91 }
92 } else {
93 let (_, ms) = gamma_estimator(&a2, 1.0);
94 libm::sqrtf(ms) * theta / max_diag
95 };
96
97 let mut a: MatA<NC, NU> = MatA::zeros();
99 for j in 0..NU {
100 for i in 0..NV {
101 a[(i, j)] = wv[i] * b_mat[(i, j)];
102 }
103 a[(NV + j, j)] = gamma * wu[j];
104 }
105
106 (a, gamma)
107}
108
109pub fn setup_a_unreg<const NU: usize, const NV: usize>(
117 b_mat: &MatA<NV, NU>,
118 wv: &VecN<NV>,
119) -> MatA<NV, NU>
120where
121 Const<NV>: DimName + DimMin<Const<NU>, Output = Const<NU>>,
122 Const<NU>: DimName,
123 DefaultAllocator: Allocator<Const<NV>, Const<NU>>
124 + Allocator<Const<NV>, Const<NV>>
125 + Allocator<Const<NU>, Const<NU>>
126 + Allocator<Const<NV>>
127 + Allocator<Const<NU>>,
128{
129 let mut a: MatA<NV, NU> = MatA::zeros();
130 for j in 0..NU {
131 for i in 0..NV {
132 a[(i, j)] = wv[i] * b_mat[(i, j)];
133 }
134 }
135 a
136}
137
138pub fn setup_b_unreg<const NV: usize>(v: &VecN<NV>, wv: &VecN<NV>) -> VecN<NV>
144where
145 Const<NV>: DimName,
146 DefaultAllocator: Allocator<Const<NV>>,
147{
148 let mut b: VecN<NV> = VecN::zeros();
149 for i in 0..NV {
150 b[i] = wv[i] * v[i];
151 }
152 b
153}
154
155pub fn setup_b<const NU: usize, const NV: usize, const NC: usize>(
157 v: &VecN<NV>,
158 ud: &VecN<NU>,
159 wv: &VecN<NV>,
160 wu_norm: &VecN<NU>,
161 gamma: f32,
162) -> VecN<NC>
163where
164 Const<NC>: DimName + DimMin<Const<NU>, Output = Const<NU>>,
165 Const<NU>: DimName,
166 Const<NV>: DimName,
167 DefaultAllocator: Allocator<Const<NC>, Const<NU>>
168 + Allocator<Const<NC>, Const<NC>>
169 + Allocator<Const<NU>, Const<NU>>
170 + Allocator<Const<NC>>
171 + Allocator<Const<NU>>
172 + Allocator<Const<NV>>,
173{
174 debug_assert_eq!(NC, NU + NV);
175 let mut b: VecN<NC> = VecN::zeros();
176 for i in 0..NV {
177 b[i] = wv[i] * v[i];
178 }
179 for i in 0..NU {
180 b[NV + i] = gamma * wu_norm[i] * ud[i];
181 }
182 b
183}