sparse_ir/sve/
strategy.rs

1//! SVE computation strategies
2
3use crate::gauss::{Rule, legendre_generic};
4use crate::kernel::{AbstractKernel, CentrosymmKernel, KernelProperties, SVEHints, SymmetryType};
5use crate::kernelmatrix::{matrix_from_gauss_noncentrosymmetric, matrix_from_gauss_with_segments};
6use crate::numeric::CustomNumeric;
7use crate::poly::PiecewiseLegendrePolyVector;
8use mdarray::DTensor;
9
10use super::result::SVEResult;
11use super::utils::{extend_to_full_domain, merge_results, remove_weights, svd_to_polynomials};
12
13/// Trait for SVE computation strategies
14pub trait SVEStrategy<T: CustomNumeric> {
15    /// Compute the discretized matrices for SVD
16    fn matrices(&self) -> Vec<DTensor<T, 2>>;
17
18    /// Post-process SVD results to create SVEResult
19    fn postprocess(
20        &self,
21        u_list: Vec<DTensor<T, 2>>,
22        s_list: Vec<Vec<T>>,
23        v_list: Vec<DTensor<T, 2>>,
24    ) -> SVEResult;
25}
26
27/// Sampling-based SVE computation
28///
29/// This is the general SVE computation strategy that works with discretized kernels.
30/// It does NOT know about symmetry - it just processes a given discretized kernel matrix.
31///
32/// # Responsibility
33///
34/// - Remove weights from SVD results
35/// - Convert to polynomials on the domain specified by segments
36/// - Domain extension is caller's responsibility
37pub struct SamplingSVE<T>
38where
39    T: CustomNumeric + Send + Sync + 'static,
40{
41    segments_x: Vec<T>,
42    segments_y: Vec<T>,
43    gauss_x: Rule<T>,
44    gauss_y: Rule<T>,
45    #[allow(dead_code)]
46    epsilon: f64,
47    n_gauss: usize,
48}
49
50impl<T> SamplingSVE<T>
51where
52    T: CustomNumeric + Send + Sync + 'static,
53{
54    /// Create a new SamplingSVE
55    ///
56    /// This takes only the geometric information needed for polynomial conversion,
57    /// not the kernel itself.
58    pub fn new(
59        segments_x: Vec<T>,
60        segments_y: Vec<T>,
61        gauss_x: Rule<T>,
62        gauss_y: Rule<T>,
63        epsilon: f64,
64        n_gauss: usize,
65    ) -> Self {
66        Self {
67            segments_x,
68            segments_y,
69            gauss_x,
70            gauss_y,
71            epsilon,
72            n_gauss,
73        }
74    }
75
76    /// Post-process a single SVD result to create polynomials
77    ///
78    /// This converts SVD results to piecewise Legendre polynomials
79    /// on the domain specified by segments (e.g., [0, xmax] for reduced kernels).
80    pub fn postprocess_single(
81        &self,
82        u: &DTensor<T, 2>,
83        s: &[T],
84        v: &DTensor<T, 2>,
85    ) -> (
86        PiecewiseLegendrePolyVector,
87        Vec<f64>,
88        PiecewiseLegendrePolyVector,
89    ) {
90        // 1. Remove weights
91        // Both U and V have rows corresponding to Gauss points, so is_row=true for both
92        let u_unweighted = remove_weights(u, self.gauss_x.w.as_slice(), true);
93        let v_unweighted = remove_weights(v, self.gauss_y.w.as_slice(), true);
94
95        // 2. Convert to polynomials
96        let gauss_rule_f64 = legendre_generic::<f64>(self.n_gauss);
97        let u_polys = svd_to_polynomials(
98            &u_unweighted,
99            &self.segments_x,
100            &gauss_rule_f64,
101            self.n_gauss,
102        );
103        let v_polys = svd_to_polynomials(
104            &v_unweighted,
105            &self.segments_y,
106            &gauss_rule_f64,
107            self.n_gauss,
108        );
109
110        // Note: No domain extension here - that's the caller's responsibility
111        (
112            PiecewiseLegendrePolyVector::new(u_polys),
113            s.iter().map(|&x| x.to_f64()).collect(),
114            PiecewiseLegendrePolyVector::new(v_polys),
115        )
116    }
117}
118
119/// Centrosymmetric SVE computation
120///
121/// Exploits even/odd symmetry for efficient computation.
122/// This manages the symmetry: creating reduced kernels, extending to full domain, and merging.
123pub struct CentrosymmSVE<T, K>
124where
125    T: CustomNumeric + Send + Sync + 'static,
126    K: CentrosymmKernel + KernelProperties,
127{
128    kernel: K,
129    epsilon: f64,
130    hints: K::SVEHintsType<T>,
131    #[allow(dead_code)]
132    n_gauss: usize,
133
134    // Geometric information (positive domain [0, xmax])
135    #[allow(dead_code)]
136    segments_x: Vec<T>,
137    #[allow(dead_code)]
138    segments_y: Vec<T>,
139    gauss_x: Rule<T>,
140    gauss_y: Rule<T>,
141
142    // The general SVE processor (no symmetry knowledge)
143    sampling_sve: SamplingSVE<T>,
144}
145
146impl<T, K> CentrosymmSVE<T, K>
147where
148    T: CustomNumeric + Send + Sync + Clone + 'static,
149    K: CentrosymmKernel + KernelProperties + Clone,
150    K::SVEHintsType<T>: SVEHints<T> + Clone,
151{
152    /// Create a new CentrosymmSVE
153    pub fn new(kernel: K, epsilon: f64) -> Self {
154        let hints = kernel.sve_hints::<T>(epsilon);
155
156        // Get segments for positive domain [0, xmax]
157        let segments_x = hints.segments_x();
158        let segments_y = hints.segments_y();
159        let n_gauss = hints.ngauss();
160
161        // Create composite Gauss rules
162        let rule = legendre_generic::<T>(n_gauss);
163        let gauss_x = rule.piecewise(&segments_x);
164        let gauss_y = rule.piecewise(&segments_y);
165
166        // Create the general SVE processor
167        let sampling_sve = SamplingSVE::new(
168            segments_x.clone(),
169            segments_y.clone(),
170            gauss_x.clone(),
171            gauss_y.clone(),
172            epsilon,
173            n_gauss,
174        );
175
176        Self {
177            kernel,
178            epsilon,
179            hints,
180            n_gauss,
181            segments_x,
182            segments_y,
183            gauss_x,
184            gauss_y,
185            sampling_sve,
186        }
187    }
188
189    /// Compute reduced kernel matrix for given symmetry
190    fn compute_reduced_matrix(&self, symmetry: SymmetryType) -> DTensor<T, 2> {
191        // Compute K_red(x, y) = K(x, y) + sign * K(x, -y)
192        // where x, y are in [0, xmax] and [0, ymax]
193        let discretized = matrix_from_gauss_with_segments(
194            &self.kernel,
195            &self.gauss_x,
196            &self.gauss_y,
197            symmetry,
198            &self.hints,
199        );
200
201        // Apply weights for SVE
202
203        discretized.apply_weights_for_sve()
204    }
205
206    /// Extend polynomials from [0, xmax] to [-xmax, xmax]
207    fn extend_result_to_full_domain(
208        &self,
209        result: (
210            PiecewiseLegendrePolyVector,
211            Vec<f64>,
212            PiecewiseLegendrePolyVector,
213        ),
214        symmetry: SymmetryType,
215    ) -> (
216        PiecewiseLegendrePolyVector,
217        Vec<f64>,
218        PiecewiseLegendrePolyVector,
219    ) {
220        let (u, s, v) = result;
221
222        // Extend u and v from [0, xmax] to [-xmax, xmax]
223        let u_full = extend_to_full_domain(u.get_polys().to_vec(), symmetry, self.kernel.xmax());
224        let v_full = extend_to_full_domain(v.get_polys().to_vec(), symmetry, self.kernel.ymax());
225
226        (
227            PiecewiseLegendrePolyVector::new(u_full),
228            s,
229            PiecewiseLegendrePolyVector::new(v_full),
230        )
231    }
232}
233
234impl<T, K> SVEStrategy<T> for CentrosymmSVE<T, K>
235where
236    T: CustomNumeric + Send + Sync + Clone + 'static,
237    K: CentrosymmKernel + KernelProperties + Clone,
238    K::SVEHintsType<T>: SVEHints<T> + Clone,
239{
240    fn matrices(&self) -> Vec<DTensor<T, 2>> {
241        // Compute reduced kernels for even and odd symmetries
242        let even_matrix = self.compute_reduced_matrix(SymmetryType::Even);
243        let odd_matrix = self.compute_reduced_matrix(SymmetryType::Odd);
244
245        vec![even_matrix, odd_matrix]
246    }
247
248    fn postprocess(
249        &self,
250        u_list: Vec<DTensor<T, 2>>,
251        s_list: Vec<Vec<T>>,
252        v_list: Vec<DTensor<T, 2>>,
253    ) -> SVEResult {
254        // Process even and odd results using SamplingSVE (which doesn't know about symmetry)
255        let result_even = self
256            .sampling_sve
257            .postprocess_single(&u_list[0], &s_list[0], &v_list[0]);
258        let result_odd = self
259            .sampling_sve
260            .postprocess_single(&u_list[1], &s_list[1], &v_list[1]);
261
262        // Now extend to full domain (this is where symmetry comes in)
263        let result_even_full = self.extend_result_to_full_domain(result_even, SymmetryType::Even);
264        let result_odd_full = self.extend_result_to_full_domain(result_odd, SymmetryType::Odd);
265
266        // Merge the results
267        merge_results(result_even_full, result_odd_full, self.epsilon)
268    }
269}
270
271/// Non-centrosymmetric SVE computation
272///
273/// This strategy works with non-centrosymmetric kernels by directly computing
274/// the kernel matrix over the full domain [-xmax, xmax] × [-ymax, ymax].
275/// No symmetry exploitation is performed.
276#[allow(dead_code)]
277pub struct NonCentrosymmSVE<T, K>
278where
279    T: CustomNumeric + Send + Sync + 'static,
280    K: AbstractKernel + KernelProperties,
281{
282    kernel: K,
283    epsilon: f64,
284    hints: K::SVEHintsType<T>,
285    n_gauss: usize,
286
287    // Geometric information (full domain [-xmax, xmax])
288    segments_x: Vec<T>,
289    segments_y: Vec<T>,
290    gauss_x: Rule<T>,
291    gauss_y: Rule<T>,
292
293    // The general SVE processor
294    sampling_sve: SamplingSVE<T>,
295}
296
297impl<T, K> NonCentrosymmSVE<T, K>
298where
299    T: CustomNumeric + Send + Sync + Clone + 'static,
300    K: AbstractKernel + KernelProperties + Clone,
301    K::SVEHintsType<T>: SVEHints<T> + Clone,
302{
303    /// Create a new NonCentrosymmSVE
304    pub fn new(kernel: K, epsilon: f64) -> Self {
305        let hints = kernel.sve_hints::<T>(epsilon);
306
307        // Get segments for full domain [-xmax, xmax]
308        let segments_x = hints.segments_x();
309        let segments_y = hints.segments_y();
310        let n_gauss = hints.ngauss();
311
312        // Create composite Gauss rules for full domain
313        let rule = legendre_generic::<T>(n_gauss);
314        let gauss_x = rule.piecewise(&segments_x);
315        let gauss_y = rule.piecewise(&segments_y);
316
317        // Create the general SVE processor
318        let sampling_sve = SamplingSVE::new(
319            segments_x.clone(),
320            segments_y.clone(),
321            gauss_x.clone(),
322            gauss_y.clone(),
323            epsilon,
324            n_gauss,
325        );
326
327        Self {
328            kernel,
329            epsilon,
330            hints,
331            n_gauss,
332            segments_x,
333            segments_y,
334            gauss_x,
335            gauss_y,
336            sampling_sve,
337        }
338    }
339
340    /// Compute kernel matrix for non-centrosymmetric kernel
341    fn compute_matrix(&self) -> DTensor<T, 2> {
342        // Compute K(x, y) directly over full domain
343        let discretized = matrix_from_gauss_noncentrosymmetric(
344            &self.kernel,
345            &self.gauss_x,
346            &self.gauss_y,
347            &self.hints,
348        );
349
350        // Apply weights for SVE
351        discretized.apply_weights_for_sve()
352    }
353}
354
355impl<T, K> SVEStrategy<T> for NonCentrosymmSVE<T, K>
356where
357    T: CustomNumeric + Send + Sync + Clone + 'static,
358    K: AbstractKernel + KernelProperties + Clone,
359    K::SVEHintsType<T>: SVEHints<T> + Clone,
360{
361    fn matrices(&self) -> Vec<DTensor<T, 2>> {
362        // Single matrix for non-centrosymmetric kernel
363        vec![self.compute_matrix()]
364    }
365
366    fn postprocess(
367        &self,
368        u_list: Vec<DTensor<T, 2>>,
369        s_list: Vec<Vec<T>>,
370        v_list: Vec<DTensor<T, 2>>,
371    ) -> SVEResult {
372        // Process single result using SamplingSVE
373        let (u_polys, s, v_polys) = self
374            .sampling_sve
375            .postprocess_single(&u_list[0], &s_list[0], &v_list[0]);
376
377        // No domain extension needed - already on full domain
378        SVEResult::new(u_polys, s, v_polys, self.epsilon)
379    }
380}