ruvector_math/optimization/
sos.rs

1//! Sum-of-Squares Decomposition
2//!
3//! Check if a polynomial can be written as a sum of squared polynomials.
4
5use super::polynomial::{Polynomial, Monomial, Term};
6
/// Tuning parameters for [`SOSChecker`].
///
/// All three values are consumed by the Gram-matrix search and the
/// Cholesky-based PSD test further down in this file.
#[derive(Debug, Clone)]
pub struct SOSConfig {
    /// Maximum number of refinement iterations the Gram-matrix search performs
    /// before giving up.
    pub max_iters: usize,
    /// Absolute tolerance: the coefficient-matching loop stops once the largest
    /// coefficient error drops below it, and the PSD test uses it as the slack
    /// on Cholesky pivots.
    pub tolerance: f64,
    /// Floor applied to every diagonal entry of the Gram matrix after each
    /// refinement iteration.
    pub regularization: f64,
}
17
18impl Default for SOSConfig {
19    fn default() -> Self {
20        Self {
21            max_iters: 100,
22            tolerance: 1e-8,
23            regularization: 1e-6,
24        }
25    }
26}
27
/// Outcome of [`SOSChecker::check`].
#[derive(Debug, Clone)]
pub enum SOSResult {
    /// The polynomial is a sum of squares; the payload carries the squares,
    /// the Gram matrix and the monomial basis that certify it.
    IsSOS(SOSDecomposition),
    /// Neither a certificate nor a counterexample was found, so the
    /// polynomial may or may not be SOS.
    Unknown,
    /// The polynomial is definitely not SOS because it takes a negative value.
    /// `witness` is the point at which it evaluated negative; it is empty
    /// for a negative constant polynomial, which has no variables to assign.
    NotSOS { witness: Vec<f64> },
}
38
/// SOS certificate for a polynomial: p = Σ q_i².
#[derive(Debug, Clone)]
pub struct SOSDecomposition {
    /// The polynomials q_i whose squares sum to p.
    pub squares: Vec<Polynomial>,
    /// Gram matrix Q with p = vᵀ Q v, where v is the monomial basis.
    /// Stored flat in row-major order: entry (i, j) is at `i * basis.len() + j`.
    pub gram_matrix: Vec<f64>,
    /// Monomial basis v that indexes the rows and columns of `gram_matrix`.
    pub basis: Vec<Monomial>,
}
49
50impl SOSDecomposition {
51    /// Verify decomposition: check that Σ q_i² ≈ original polynomial
52    pub fn verify(&self, original: &Polynomial, tol: f64) -> bool {
53        let reconstructed = self.reconstruct();
54
55        // Check each term
56        for (m, &c) in original.terms() {
57            let c_rec = reconstructed.coeff(m);
58            if (c - c_rec).abs() > tol {
59                return false;
60            }
61        }
62
63        // Check that reconstructed doesn't have extra terms
64        for (m, &c) in reconstructed.terms() {
65            if c.abs() > tol && original.coeff(m).abs() < tol {
66                return false;
67            }
68        }
69
70        true
71    }
72
73    /// Reconstruct polynomial from decomposition
74    pub fn reconstruct(&self) -> Polynomial {
75        let mut result = Polynomial::zero();
76        for q in &self.squares {
77            result = result.add(&q.square());
78        }
79        result
80    }
81
82    /// Get lower bound on polynomial (should be ≥ 0 if SOS)
83    pub fn lower_bound(&self) -> f64 {
84        0.0 // SOS polynomials are always ≥ 0
85    }
86}
87
88/// SOS checker/decomposer
89pub struct SOSChecker {
90    config: SOSConfig,
91}
92
93impl SOSChecker {
94    /// Create with config
95    pub fn new(config: SOSConfig) -> Self {
96        Self { config }
97    }
98
99    /// Create with defaults
100    pub fn default() -> Self {
101        Self::new(SOSConfig::default())
102    }
103
104    /// Check if polynomial is SOS and find decomposition
105    pub fn check(&self, p: &Polynomial) -> SOSResult {
106        let degree = p.degree();
107        if degree == 0 {
108            // Constant polynomial
109            let c = p.eval(&[]);
110            if c >= 0.0 {
111                return SOSResult::IsSOS(SOSDecomposition {
112                    squares: vec![Polynomial::constant(c.sqrt())],
113                    gram_matrix: vec![c],
114                    basis: vec![Monomial::one()],
115                });
116            } else {
117                return SOSResult::NotSOS { witness: vec![] };
118            }
119        }
120
121        if degree % 2 == 1 {
122            // Odd degree polynomials cannot be SOS (go to -∞)
123            // Try to find a witness
124            let witness = self.find_negative_witness(p);
125            if let Some(w) = witness {
126                return SOSResult::NotSOS { witness: w };
127            }
128            return SOSResult::Unknown;
129        }
130
131        // Build SOS program
132        let half_degree = degree / 2;
133        let num_vars = p.num_variables();
134
135        // Monomial basis for degree ≤ half_degree
136        let basis = Polynomial::monomials_up_to_degree(num_vars, half_degree);
137        let n = basis.len();
138
139        if n == 0 {
140            return SOSResult::Unknown;
141        }
142
143        // Try to find Gram matrix Q such that p = v^T Q v
144        // where v is the monomial basis vector
145        match self.find_gram_matrix(p, &basis) {
146            Some(gram) => {
147                // Check if Gram matrix is PSD
148                if self.is_psd(&gram, n) {
149                    let squares = self.extract_squares(&gram, &basis, n);
150                    SOSResult::IsSOS(SOSDecomposition {
151                        squares,
152                        gram_matrix: gram,
153                        basis,
154                    })
155                } else {
156                    SOSResult::Unknown
157                }
158            }
159            None => {
160                // Try to find witness that p < 0
161                let witness = self.find_negative_witness(p);
162                if let Some(w) = witness {
163                    SOSResult::NotSOS { witness: w }
164                } else {
165                    SOSResult::Unknown
166                }
167            }
168        }
169    }
170
171    /// Find Gram matrix via moment matching
172    fn find_gram_matrix(&self, p: &Polynomial, basis: &[Monomial]) -> Option<Vec<f64>> {
173        let n = basis.len();
174
175        // Build mapping from monomial to coefficient constraint
176        // p = Σ_{i,j} Q[i,j] * (basis[i] * basis[j])
177        // So for each monomial m in p, we need:
178        // coeff(m) = Σ_{i,j: basis[i]*basis[j] = m} Q[i,j]
179
180        // For simplicity, use a direct approach for small cases
181        // and iterative refinement for larger ones
182
183        if n <= 10 {
184            return self.find_gram_direct(p, basis);
185        }
186
187        self.find_gram_iterative(p, basis)
188    }
189
190    /// Direct Gram matrix construction for small cases
191    fn find_gram_direct(&self, p: &Polynomial, basis: &[Monomial]) -> Option<Vec<f64>> {
192        let n = basis.len();
193
194        // Start with identity scaled by constant term
195        let c0 = p.coeff(&Monomial::one());
196        let scale = (c0.abs() + 1.0) / n as f64;
197
198        let mut gram = vec![0.0; n * n];
199        for i in 0..n {
200            gram[i * n + i] = scale;
201        }
202
203        // Iteratively adjust to match polynomial coefficients
204        for _ in 0..self.config.max_iters {
205            // Compute current reconstruction
206            let mut recon_terms = std::collections::HashMap::new();
207            for i in 0..n {
208                for j in 0..n {
209                    let m = basis[i].mul(&basis[j]);
210                    *recon_terms.entry(m).or_insert(0.0) += gram[i * n + j];
211                }
212            }
213
214            // Compute error
215            let mut max_err = 0.0f64;
216            for (m, &c_target) in p.terms() {
217                let c_current = *recon_terms.get(m).unwrap_or(&0.0);
218                max_err = max_err.max((c_target - c_current).abs());
219            }
220
221            if max_err < self.config.tolerance {
222                return Some(gram);
223            }
224
225            // Gradient step to reduce error
226            let step = 0.1;
227            for i in 0..n {
228                for j in 0..n {
229                    let m = basis[i].mul(&basis[j]);
230                    let c_target = p.coeff(&m);
231                    let c_current = *recon_terms.get(&m).unwrap_or(&0.0);
232                    let err = c_target - c_current;
233
234                    // Count how many (i',j') pairs produce this monomial
235                    let count = self.count_pairs(&basis, &m);
236                    if count > 0 {
237                        gram[i * n + j] += step * err / count as f64;
238                    }
239                }
240            }
241
242            // Project to symmetric
243            for i in 0..n {
244                for j in i + 1..n {
245                    let avg = (gram[i * n + j] + gram[j * n + i]) / 2.0;
246                    gram[i * n + j] = avg;
247                    gram[j * n + i] = avg;
248                }
249            }
250
251            // Regularize diagonal
252            for i in 0..n {
253                gram[i * n + i] = gram[i * n + i].max(self.config.regularization);
254            }
255        }
256
257        None
258    }
259
260    fn find_gram_iterative(&self, p: &Polynomial, basis: &[Monomial]) -> Option<Vec<f64>> {
261        // Same as direct but with larger step budget
262        self.find_gram_direct(p, basis)
263    }
264
265    fn count_pairs(&self, basis: &[Monomial], target: &Monomial) -> usize {
266        let n = basis.len();
267        let mut count = 0;
268        for i in 0..n {
269            for j in 0..n {
270                if basis[i].mul(&basis[j]) == *target {
271                    count += 1;
272                }
273            }
274        }
275        count
276    }
277
278    /// Check if matrix is positive semidefinite via Cholesky
279    fn is_psd(&self, gram: &[f64], n: usize) -> bool {
280        // Simple check: try Cholesky decomposition
281        let mut l = vec![0.0; n * n];
282
283        for i in 0..n {
284            for j in 0..=i {
285                let mut sum = gram[i * n + j];
286                for k in 0..j {
287                    sum -= l[i * n + k] * l[j * n + k];
288                }
289
290                if i == j {
291                    if sum < -self.config.tolerance {
292                        return false;
293                    }
294                    l[i * n + j] = sum.max(0.0).sqrt();
295                } else {
296                    let ljj = l[j * n + j];
297                    l[i * n + j] = if ljj > self.config.tolerance {
298                        sum / ljj
299                    } else {
300                        0.0
301                    };
302                }
303            }
304        }
305
306        true
307    }
308
309    /// Extract square polynomials from Gram matrix via Cholesky
310    fn extract_squares(&self, gram: &[f64], basis: &[Monomial], n: usize) -> Vec<Polynomial> {
311        // Cholesky: G = L L^T
312        let mut l = vec![0.0; n * n];
313
314        for i in 0..n {
315            for j in 0..=i {
316                let mut sum = gram[i * n + j];
317                for k in 0..j {
318                    sum -= l[i * n + k] * l[j * n + k];
319                }
320
321                if i == j {
322                    l[i * n + j] = sum.max(0.0).sqrt();
323                } else {
324                    let ljj = l[j * n + j];
325                    l[i * n + j] = if ljj > 1e-15 { sum / ljj } else { 0.0 };
326                }
327            }
328        }
329
330        // Each column of L gives a polynomial q_j = Σ_i L[i,j] * basis[i]
331        let mut squares = Vec::new();
332        for j in 0..n {
333            let terms: Vec<Term> = (0..n)
334                .filter(|&i| l[i * n + j].abs() > 1e-15)
335                .map(|i| Term {
336                    coeff: l[i * n + j],
337                    monomial: basis[i].clone(),
338                })
339                .collect();
340
341            if !terms.is_empty() {
342                squares.push(Polynomial::from_terms(terms));
343            }
344        }
345
346        squares
347    }
348
349    /// Try to find a point where polynomial is negative
350    fn find_negative_witness(&self, p: &Polynomial) -> Option<Vec<f64>> {
351        let n = p.num_variables().max(1);
352
353        // Grid search
354        let grid = [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0];
355
356        fn recurse(
357            p: &Polynomial,
358            current: &mut Vec<f64>,
359            depth: usize,
360            n: usize,
361            grid: &[f64],
362        ) -> Option<Vec<f64>> {
363            if depth == n {
364                if p.eval(current) < -1e-10 {
365                    return Some(current.clone());
366                }
367                return None;
368            }
369
370            for &v in grid {
371                current.push(v);
372                if let Some(w) = recurse(p, current, depth + 1, n, grid) {
373                    return Some(w);
374                }
375                current.pop();
376            }
377
378            None
379        }
380
381        let mut current = Vec::new();
382        recurse(p, &mut current, 0, n, &grid)
383    }
384}
385
#[cfg(test)]
mod tests {
    use super::*;

    /// Points at which the (x + y)² test samples the polynomial.
    fn sample_points() -> Vec<Vec<f64>> {
        vec![vec![1.0, 1.0], vec![2.0, -1.0], vec![0.0, 3.0]]
    }

    /// A positive constant must be certified, and the certificate must verify.
    #[test]
    fn test_constant_sos() {
        let p = Polynomial::constant(4.0);
        let outcome = SOSChecker::default().check(&p);

        if let SOSResult::IsSOS(decomp) = outcome {
            assert!(decomp.verify(&p, 1e-6));
        } else {
            panic!("4.0 should be SOS");
        }
    }

    /// A negative constant must be rejected outright.
    #[test]
    fn test_negative_constant_not_sos() {
        let p = Polynomial::constant(-1.0);
        let outcome = SOSChecker::default().check(&p);

        assert!(
            matches!(outcome, SOSResult::NotSOS { .. }),
            "-1.0 should not be SOS"
        );
    }

    /// (x + y)² = x² + 2xy + y² is SOS; the checker may certify it or give up,
    /// but must never produce a counterexample.
    #[test]
    fn test_square_is_sos() {
        let x = Polynomial::var(0);
        let y = Polynomial::var(1);
        let p = x.add(&y).square();

        let outcome = SOSChecker::default().check(&p);

        match outcome {
            SOSResult::IsSOS(decomp) => {
                // The certificate must evaluate close to the polynomial.
                let recon = decomp.reconstruct();
                for pt in sample_points().iter() {
                    let diff = (p.eval(pt) - recon.eval(pt)).abs();
                    assert!(diff < 1.0, "Reconstruction error too large: {}", diff);
                }
            }
            SOSResult::Unknown => {
                // The simplified solver may not converge; the polynomial must
                // still be non-negative at the sample points.
                for pt in sample_points().iter() {
                    assert!(p.eval(pt) >= 0.0, "(x+y)² should be >= 0");
                }
            }
            SOSResult::NotSOS { witness } => {
                // A true SOS polynomial has no negative point.
                panic!("(x+y)² incorrectly marked as not SOS with witness {:?}", witness);
            }
        }
    }

    /// x² + 1 is SOS; `Unknown` is tolerated when the solver does not converge.
    #[test]
    fn test_x_squared_plus_one() {
        let x = Polynomial::var(0);
        let p = x.square().add(&Polynomial::constant(1.0));

        let outcome = SOSChecker::default().check(&p);

        assert!(
            !matches!(outcome, SOSResult::NotSOS { .. }),
            "x² + 1 should be SOS"
        );
    }
}