// ruvector_math/optimization/sos.rs
use super::polynomial::{Monomial, Polynomial, Term};

/// Tuning parameters for [`SOSChecker`].
#[derive(Debug, Clone)]
pub struct SOSConfig {
    /// Iteration budget of the Gram-matrix search.
    pub max_iters: usize,
    /// Convergence threshold of the Gram-matrix search (largest allowed coefficient
    /// residual) and numerical tolerance of the positive-semidefiniteness test.
    pub tolerance: f64,
    /// Floor applied to every diagonal entry of the Gram matrix after each search
    /// iteration.
    pub regularization: f64,
}

impl Default for SOSConfig {
    /// `max_iters = 100`, `tolerance = 1e-8`, `regularization = 1e-6`.
    fn default() -> Self {
        const DEFAULT_MAX_ITERS: usize = 100;
        const DEFAULT_TOLERANCE: f64 = 1e-8;
        const DEFAULT_REGULARIZATION: f64 = 1e-6;

        SOSConfig {
            max_iters: DEFAULT_MAX_ITERS,
            tolerance: DEFAULT_TOLERANCE,
            regularization: DEFAULT_REGULARIZATION,
        }
    }
}

28#[derive(Debug, Clone)]
30pub enum SOSResult {
31 IsSOS(SOSDecomposition),
33 Unknown,
35 NotSOS { witness: Vec<f64> },
37}
38
39#[derive(Debug, Clone)]
41pub struct SOSDecomposition {
42 pub squares: Vec<Polynomial>,
44 pub gram_matrix: Vec<f64>,
46 pub basis: Vec<Monomial>,
48}
49
50impl SOSDecomposition {
51 pub fn verify(&self, original: &Polynomial, tol: f64) -> bool {
53 let reconstructed = self.reconstruct();
54
55 for (m, &c) in original.terms() {
57 let c_rec = reconstructed.coeff(m);
58 if (c - c_rec).abs() > tol {
59 return false;
60 }
61 }
62
63 for (m, &c) in reconstructed.terms() {
65 if c.abs() > tol && original.coeff(m).abs() < tol {
66 return false;
67 }
68 }
69
70 true
71 }
72
73 pub fn reconstruct(&self) -> Polynomial {
75 let mut result = Polynomial::zero();
76 for q in &self.squares {
77 result = result.add(&q.square());
78 }
79 result
80 }
81
82 pub fn lower_bound(&self) -> f64 {
84 0.0 }
86}
87
88pub struct SOSChecker {
90 config: SOSConfig,
91}
92
93impl SOSChecker {
94 pub fn new(config: SOSConfig) -> Self {
96 Self { config }
97 }
98
99 pub fn default() -> Self {
101 Self::new(SOSConfig::default())
102 }
103
104 pub fn check(&self, p: &Polynomial) -> SOSResult {
106 let degree = p.degree();
107 if degree == 0 {
108 let c = p.eval(&[]);
110 if c >= 0.0 {
111 return SOSResult::IsSOS(SOSDecomposition {
112 squares: vec![Polynomial::constant(c.sqrt())],
113 gram_matrix: vec![c],
114 basis: vec![Monomial::one()],
115 });
116 } else {
117 return SOSResult::NotSOS { witness: vec![] };
118 }
119 }
120
121 if degree % 2 == 1 {
122 let witness = self.find_negative_witness(p);
125 if let Some(w) = witness {
126 return SOSResult::NotSOS { witness: w };
127 }
128 return SOSResult::Unknown;
129 }
130
131 let half_degree = degree / 2;
133 let num_vars = p.num_variables();
134
135 let basis = Polynomial::monomials_up_to_degree(num_vars, half_degree);
137 let n = basis.len();
138
139 if n == 0 {
140 return SOSResult::Unknown;
141 }
142
143 match self.find_gram_matrix(p, &basis) {
146 Some(gram) => {
147 if self.is_psd(&gram, n) {
149 let squares = self.extract_squares(&gram, &basis, n);
150 SOSResult::IsSOS(SOSDecomposition {
151 squares,
152 gram_matrix: gram,
153 basis,
154 })
155 } else {
156 SOSResult::Unknown
157 }
158 }
159 None => {
160 let witness = self.find_negative_witness(p);
162 if let Some(w) = witness {
163 SOSResult::NotSOS { witness: w }
164 } else {
165 SOSResult::Unknown
166 }
167 }
168 }
169 }
170
171 fn find_gram_matrix(&self, p: &Polynomial, basis: &[Monomial]) -> Option<Vec<f64>> {
173 let n = basis.len();
174
175 if n <= 10 {
184 return self.find_gram_direct(p, basis);
185 }
186
187 self.find_gram_iterative(p, basis)
188 }
189
190 fn find_gram_direct(&self, p: &Polynomial, basis: &[Monomial]) -> Option<Vec<f64>> {
192 let n = basis.len();
193
194 let c0 = p.coeff(&Monomial::one());
196 let scale = (c0.abs() + 1.0) / n as f64;
197
198 let mut gram = vec![0.0; n * n];
199 for i in 0..n {
200 gram[i * n + i] = scale;
201 }
202
203 for _ in 0..self.config.max_iters {
205 let mut recon_terms = std::collections::HashMap::new();
207 for i in 0..n {
208 for j in 0..n {
209 let m = basis[i].mul(&basis[j]);
210 *recon_terms.entry(m).or_insert(0.0) += gram[i * n + j];
211 }
212 }
213
214 let mut max_err = 0.0f64;
216 for (m, &c_target) in p.terms() {
217 let c_current = *recon_terms.get(m).unwrap_or(&0.0);
218 max_err = max_err.max((c_target - c_current).abs());
219 }
220
221 if max_err < self.config.tolerance {
222 return Some(gram);
223 }
224
225 let step = 0.1;
227 for i in 0..n {
228 for j in 0..n {
229 let m = basis[i].mul(&basis[j]);
230 let c_target = p.coeff(&m);
231 let c_current = *recon_terms.get(&m).unwrap_or(&0.0);
232 let err = c_target - c_current;
233
234 let count = self.count_pairs(&basis, &m);
236 if count > 0 {
237 gram[i * n + j] += step * err / count as f64;
238 }
239 }
240 }
241
242 for i in 0..n {
244 for j in i + 1..n {
245 let avg = (gram[i * n + j] + gram[j * n + i]) / 2.0;
246 gram[i * n + j] = avg;
247 gram[j * n + i] = avg;
248 }
249 }
250
251 for i in 0..n {
253 gram[i * n + i] = gram[i * n + i].max(self.config.regularization);
254 }
255 }
256
257 None
258 }
259
260 fn find_gram_iterative(&self, p: &Polynomial, basis: &[Monomial]) -> Option<Vec<f64>> {
261 self.find_gram_direct(p, basis)
263 }
264
265 fn count_pairs(&self, basis: &[Monomial], target: &Monomial) -> usize {
266 let n = basis.len();
267 let mut count = 0;
268 for i in 0..n {
269 for j in 0..n {
270 if basis[i].mul(&basis[j]) == *target {
271 count += 1;
272 }
273 }
274 }
275 count
276 }
277
278 fn is_psd(&self, gram: &[f64], n: usize) -> bool {
280 let mut l = vec![0.0; n * n];
282
283 for i in 0..n {
284 for j in 0..=i {
285 let mut sum = gram[i * n + j];
286 for k in 0..j {
287 sum -= l[i * n + k] * l[j * n + k];
288 }
289
290 if i == j {
291 if sum < -self.config.tolerance {
292 return false;
293 }
294 l[i * n + j] = sum.max(0.0).sqrt();
295 } else {
296 let ljj = l[j * n + j];
297 l[i * n + j] = if ljj > self.config.tolerance {
298 sum / ljj
299 } else {
300 0.0
301 };
302 }
303 }
304 }
305
306 true
307 }
308
309 fn extract_squares(&self, gram: &[f64], basis: &[Monomial], n: usize) -> Vec<Polynomial> {
311 let mut l = vec![0.0; n * n];
313
314 for i in 0..n {
315 for j in 0..=i {
316 let mut sum = gram[i * n + j];
317 for k in 0..j {
318 sum -= l[i * n + k] * l[j * n + k];
319 }
320
321 if i == j {
322 l[i * n + j] = sum.max(0.0).sqrt();
323 } else {
324 let ljj = l[j * n + j];
325 l[i * n + j] = if ljj > 1e-15 { sum / ljj } else { 0.0 };
326 }
327 }
328 }
329
330 let mut squares = Vec::new();
332 for j in 0..n {
333 let terms: Vec<Term> = (0..n)
334 .filter(|&i| l[i * n + j].abs() > 1e-15)
335 .map(|i| Term {
336 coeff: l[i * n + j],
337 monomial: basis[i].clone(),
338 })
339 .collect();
340
341 if !terms.is_empty() {
342 squares.push(Polynomial::from_terms(terms));
343 }
344 }
345
346 squares
347 }
348
349 fn find_negative_witness(&self, p: &Polynomial) -> Option<Vec<f64>> {
351 let n = p.num_variables().max(1);
352
353 let grid = [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0];
355
356 fn recurse(
357 p: &Polynomial,
358 current: &mut Vec<f64>,
359 depth: usize,
360 n: usize,
361 grid: &[f64],
362 ) -> Option<Vec<f64>> {
363 if depth == n {
364 if p.eval(current) < -1e-10 {
365 return Some(current.clone());
366 }
367 return None;
368 }
369
370 for &v in grid {
371 current.push(v);
372 if let Some(w) = recurse(p, current, depth + 1, n, grid) {
373 return Some(w);
374 }
375 current.pop();
376 }
377
378 None
379 }
380
381 let mut current = Vec::new();
382 recurse(p, &mut current, 0, n, &grid)
383 }
384}
385
#[cfg(test)]
mod tests {
    use super::*;

    /// Evaluation points shared by the `(x + y)²` checks.
    fn sample_points() -> [Vec<f64>; 3] {
        [vec![1.0, 1.0], vec![2.0, -1.0], vec![0.0, 3.0]]
    }

    #[test]
    fn test_constant_sos() {
        let p = Polynomial::constant(4.0);
        let outcome = SOSChecker::default().check(&p);

        if let SOSResult::IsSOS(decomp) = outcome {
            assert!(decomp.verify(&p, 1e-6));
        } else {
            panic!("4.0 should be SOS");
        }
    }

    #[test]
    fn test_negative_constant_not_sos() {
        let p = Polynomial::constant(-1.0);
        let outcome = SOSChecker::default().check(&p);

        assert!(
            matches!(outcome, SOSResult::NotSOS { .. }),
            "-1.0 should not be SOS"
        );
    }

    #[test]
    fn test_square_is_sos() {
        let p = Polynomial::var(0).add(&Polynomial::var(1)).square();

        match SOSChecker::default().check(&p) {
            SOSResult::IsSOS(decomp) => {
                // The numeric certificate only needs to be roughly right.
                let recon = decomp.reconstruct();
                for pt in sample_points() {
                    let diff = (p.eval(&pt) - recon.eval(&pt)).abs();
                    assert!(diff < 1.0, "Reconstruction error too large: {}", diff);
                }
            }
            SOSResult::Unknown => {
                // Inconclusive is acceptable as long as p really is non-negative.
                for pt in sample_points() {
                    assert!(p.eval(&pt) >= 0.0, "(x+y)² should be >= 0");
                }
            }
            SOSResult::NotSOS { witness } => {
                panic!("(x+y)² incorrectly marked as not SOS with witness {:?}", witness)
            }
        }
    }

    #[test]
    fn test_x_squared_plus_one() {
        let p = Polynomial::var(0).square().add(&Polynomial::constant(1.0));
        let outcome = SOSChecker::default().check(&p);

        // Either a certificate or an inconclusive answer is fine; a refutation is not.
        assert!(
            !matches!(outcome, SOSResult::NotSOS { .. }),
            "x² + 1 should be SOS"
        );
    }
}