Skip to main content

tensorlogic_infer/low_rank/
approximation.rs

1//! High-level low-rank approximation API and inference pass.
2
3use tensorlogic_ir::EinsumGraph;
4
5use super::config::LowRankConfig;
6use super::error::LowRankError;
7use super::svd::{SvdResult, TruncatedSvd};
8
9// ---------------------------------------------------------------------------
10// LowRankApproximation
11// ---------------------------------------------------------------------------
12
13/// High-level API for low-rank matrix approximation.
14pub struct LowRankApproximation {
15    config: LowRankConfig,
16    svd: TruncatedSvd,
17}
18
19impl LowRankApproximation {
20    /// Create a new approximation engine with the given configuration.
21    pub fn new(config: LowRankConfig) -> Self {
22        let svd = TruncatedSvd::new(config.clone());
23        LowRankApproximation { config, svd }
24    }
25
26    /// Approximate a 2-D matrix stored in row-major order.
27    pub fn approximate_matrix(
28        &self,
29        data: &[f64],
30        rows: usize,
31        cols: usize,
32    ) -> Result<SvdResult, LowRankError> {
33        self.svd.decompose(data, rows, cols)
34    }
35
36    /// Approximate the matrix product `A @ B` using low-rank factors of `A`.
37    ///
38    /// Rather than computing the full product `C = A B` (which is O(a_rows · a_cols · b_cols)),
39    /// we approximate `A ≈ U Σ Vᵀ` (rank-k) and then compute `C ≈ (U Σ) (Vᵀ B)`.
40    /// This can be cheaper when `rank << min(a_rows, a_cols)`.
41    pub fn approximate_matmul(
42        &self,
43        a: &[f64],
44        a_rows: usize,
45        a_cols: usize,
46        b: &[f64],
47        b_rows: usize,
48        b_cols: usize,
49    ) -> Result<Vec<f64>, LowRankError> {
50        if a_cols != b_rows {
51            return Err(LowRankError::InvalidInput(format!(
52                "inner dimensions mismatch: a_cols={} != b_rows={}",
53                a_cols, b_rows
54            )));
55        }
56
57        let svd_result = self.svd.decompose(a, a_rows, a_cols)?;
58        let rank = svd_result.rank_used;
59
60        // Compute intermediate: M = Vᵀ B   [rank × b_cols]
61        let mut m = vec![0.0_f64; rank * b_cols];
62        for k in 0..rank {
63            for j in 0..b_cols {
64                let mut val = 0.0_f64;
65                for l in 0..b_rows {
66                    // vt[k, l] * b[l, j]
67                    val += svd_result.vt[k * svd_result.vt_cols + l] * b[l * b_cols + j];
68                }
69                m[k * b_cols + j] = val;
70            }
71        }
72
73        // Compute result: C = (U Σ) M   [a_rows × b_cols]
74        let mut c = vec![0.0_f64; a_rows * b_cols];
75        for i in 0..a_rows {
76            for j in 0..b_cols {
77                let mut val = 0.0_f64;
78                for k in 0..rank {
79                    // u[i, k] * sigma[k] * m[k, j]
80                    let u_ik = svd_result.u[i * svd_result.u_cols + k];
81                    val += u_ik * svd_result.singular_values[k] * m[k * b_cols + j];
82                }
83                c[i * b_cols + j] = val;
84            }
85        }
86
87        Ok(c)
88    }
89
90    /// Return `true` if this matrix is large enough to be a candidate for
91    /// low-rank approximation (based on `min_dimension` in the config).
92    pub fn is_candidate(&self, rows: usize, cols: usize) -> bool {
93        rows >= self.config.min_dimension && cols >= self.config.min_dimension
94    }
95
96    /// Compute the smallest rank `k` such that the top-`k` singular values
97    /// capture at least `energy_threshold` fraction of the total singular energy.
98    ///
99    /// `energy_threshold` should be in `[0, 1]`.  If the slice is empty the
100    /// function returns `0`.
101    pub fn optimal_rank(singular_values: &[f64], energy_threshold: f64) -> usize {
102        if singular_values.is_empty() {
103            return 0;
104        }
105        let total: f64 = singular_values.iter().map(|s| s * s).sum();
106        if total == 0.0 {
107            return 1;
108        }
109        let mut cumulative = 0.0_f64;
110        for (k, &sv) in singular_values.iter().enumerate() {
111            cumulative += sv * sv;
112            if cumulative / total >= energy_threshold {
113                return k + 1;
114            }
115        }
116        singular_values.len()
117    }
118}
119
120// ---------------------------------------------------------------------------
121// LowRankInferencePass
122// ---------------------------------------------------------------------------
123
/// An Einsum node of an `EinsumGraph` that was flagged as a candidate for
/// low-rank approximation.
#[derive(Clone, Debug)]
pub struct LowRankCandidate {
    /// Position of the flagged node within `EinsumGraph::nodes`.
    pub node_index: usize,
    /// Human-readable explanation of why the node was selected.
    pub reason: String,
    /// Heuristic savings estimate as a ratio (0–1); larger means more savings.
    pub estimated_savings_ratio: f64,
}
134
/// Aggregated statistics from a single pass over an `EinsumGraph`.
#[derive(Debug, Clone, Default)]
pub struct LowRankPassStats {
    /// Number of nodes flagged as low-rank candidates.
    pub candidates_found: usize,
    /// Total number of nodes in the scanned graph (every node, not only Einsum nodes).
    pub nodes_inspected: usize,
    /// Sum of the candidates' `estimated_savings_ratio` values.
    ///
    /// Despite the name, this is not a FLOP count. It is a unitless score that
    /// grows with the number of candidates and can exceed 1.
    pub estimated_total_flop_reduction: f64,
}
142
143/// Optimization pass that scans an `EinsumGraph` and annotates Einsum nodes
144/// with low-rank approximation candidates.
145///
146/// Currently uses a heuristic based on einsum spec complexity (number of
147/// unique contracted indices) to identify potential candidates.
148#[derive(Debug)]
149pub struct LowRankInferencePass {
150    config: LowRankConfig,
151}
152
153impl LowRankInferencePass {
154    /// Create a new pass with the given configuration.
155    pub fn new(config: LowRankConfig) -> Self {
156        LowRankInferencePass { config }
157    }
158
159    /// Scan the graph and return a list of low-rank candidates.
160    pub fn find_candidates(&self, graph: &EinsumGraph) -> Vec<LowRankCandidate> {
161        let mut candidates = Vec::new();
162
163        for (idx, node) in graph.nodes.iter().enumerate() {
164            if let tensorlogic_ir::OpType::Einsum { spec } = &node.op {
165                // Heuristic: if the einsum spec suggests a matmul-like pattern
166                // (two inputs, contracted indices) and would benefit from low-rank
167                // approximation, flag it.
168                if node.inputs.len() >= 2 && self.is_matmul_like(spec) {
169                    let savings = self.estimate_savings(spec);
170                    candidates.push(LowRankCandidate {
171                        node_index: idx,
172                        reason: format!(
173                            "Einsum '{}' has {} inputs and matmul-like contraction",
174                            spec,
175                            node.inputs.len()
176                        ),
177                        estimated_savings_ratio: savings,
178                    });
179                }
180            }
181        }
182
183        candidates
184    }
185
186    /// Apply annotations and return aggregate stats.
187    ///
188    /// In this implementation the "annotation" is a dry-run analysis only —
189    /// the graph is not mutated (annotation requires mutable access and is
190    /// outside the scope of a read-only pass).
191    pub fn apply_annotations(&self, graph: &EinsumGraph) -> LowRankPassStats {
192        let candidates = self.find_candidates(graph);
193        let estimated_total_flop_reduction: f64 =
194            candidates.iter().map(|c| c.estimated_savings_ratio).sum();
195        LowRankPassStats {
196            candidates_found: candidates.len(),
197            nodes_inspected: graph.nodes.len(),
198            estimated_total_flop_reduction,
199        }
200    }
201
202    // ------------------------------------------------------------------
203    // Internal helpers
204    // ------------------------------------------------------------------
205
206    /// Very light heuristic: treat any spec with `->` and two operands that
207    /// share at least one contracted index as matmul-like.
208    fn is_matmul_like(&self, spec: &str) -> bool {
209        if let Some(arrow_pos) = spec.find("->") {
210            let inputs_part = &spec[..arrow_pos];
211            let operands: Vec<&str> = inputs_part.split(',').collect();
212            if operands.len() < 2 {
213                return false;
214            }
215            // Check for shared characters (contracted indices)
216            let a_chars: std::collections::HashSet<char> =
217                operands[0].chars().filter(|c| c.is_alphabetic()).collect();
218            let b_chars: std::collections::HashSet<char> =
219                operands[1].chars().filter(|c| c.is_alphabetic()).collect();
220            let output_chars: std::collections::HashSet<char> = spec[arrow_pos + 2..]
221                .chars()
222                .filter(|c| c.is_alphabetic())
223                .collect();
224            // A contracted index appears in inputs but not in output
225            let contracted: std::collections::HashSet<char> = a_chars
226                .intersection(&b_chars)
227                .copied()
228                .filter(|c| !output_chars.contains(c))
229                .collect();
230            return contracted.len() >= 1
231                && self.config.rank < self.min_contracted_dim_estimate(spec);
232        }
233        false
234    }
235
236    /// Estimate the number of contracted dimensions from the spec string.
237    /// Used as a rough proxy for matrix size.
238    fn min_contracted_dim_estimate(&self, spec: &str) -> usize {
239        // Use rank as a stand-in; if rank < estimated contracted dims → candidate
240        // Here we just count contracted chars as a size proxy
241        let contracted = spec.chars().filter(|c| c.is_alphabetic()).count();
242        // If the spec has at least 4 unique index chars assume "large enough"
243        contracted.max(1)
244    }
245
246    /// Estimate FLOP savings ratio for a candidate node.
247    fn estimate_savings(&self, spec: &str) -> f64 {
248        // Heuristic: savings = 1 - (2*rank) / (contracted dims)
249        let contracted_dims = self.min_contracted_dim_estimate(spec).max(1) as f64;
250        let rank = self.config.rank as f64;
251        (1.0 - (2.0 * rank) / contracted_dims).clamp(0.0, 1.0)
252    }
253}
254
255// ---------------------------------------------------------------------------
256// Tests
257// ---------------------------------------------------------------------------
258
#[cfg(test)]
mod tests {
    use super::*;

    fn make_config(rank: usize) -> LowRankConfig {
        LowRankConfig::new(rank)
            .with_tolerance(1e-8)
            .with_max_iterations(300)
            .with_min_dimension(8)
    }

    // -----------------------------------------------------------------------
    // LowRankApproximation tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_approximation_4x4_matrix() {
        // Use a rank-2 approximation on a 4×4 matrix
        let m: Vec<f64> = (1..=16).map(|x| x as f64).collect();
        let cfg = make_config(2);
        let approx = LowRankApproximation::new(cfg);
        let svd = approx
            .approximate_matrix(&m, 4, 4)
            .expect("approximation should succeed for a valid 4x4 matrix");
        assert!(svd.rank_used >= 1);
        // Frobenius error should be a valid number
        assert!(svd.frobenius_error.is_finite());
    }

    #[test]
    fn test_approximate_matmul_inner_dim_mismatch() {
        let approx = LowRankApproximation::new(make_config(1));
        let a = vec![1.0_f64; 6]; // 2×3
        let b = vec![1.0_f64; 4]; // 2×2 → inner dims 3 vs 2
        let result = approx.approximate_matmul(&a, 2, 3, &b, 2, 2);
        assert!(
            matches!(result, Err(LowRankError::InvalidInput(_))),
            "mismatched inner dimensions must be rejected"
        );
    }

    #[test]
    fn test_is_candidate_small_matrix() {
        let cfg = LowRankConfig::new(2).with_min_dimension(32);
        let approx = LowRankApproximation::new(cfg);
        // 4×4 is below min_dimension=32
        assert!(!approx.is_candidate(4, 4));
    }

    #[test]
    fn test_is_candidate_large_matrix() {
        let cfg = LowRankConfig::new(4).with_min_dimension(32);
        let approx = LowRankApproximation::new(cfg);
        // 64×64 is above min_dimension=32
        assert!(approx.is_candidate(64, 64));
    }

    #[test]
    fn test_optimal_rank_energy_threshold() {
        // Singular values: [10, 5, 2, 1]  → squared energies: [100, 25, 4, 1], total=130
        // 0.9 threshold → need cumulative >= 117 → first two give 125 >= 117 → rank=2
        let svs = vec![10.0_f64, 5.0, 2.0, 1.0];
        let r = LowRankApproximation::optimal_rank(&svs, 0.90);
        assert_eq!(r, 2, "optimal rank for 90% energy should be 2, got {r}");

        // 0.99 threshold → 100+25+4=129 >= 0.99*130=128.7 → rank=3
        let r2 = LowRankApproximation::optimal_rank(&svs, 0.99);
        assert_eq!(r2, 3, "optimal rank for 99% energy should be 3, got {r2}");
    }

    #[test]
    fn test_optimal_rank_edge_cases() {
        assert_eq!(LowRankApproximation::optimal_rank(&[], 0.9), 0);
        // All-zero spectrum: documented to return 1.
        assert_eq!(LowRankApproximation::optimal_rank(&[0.0, 0.0], 0.9), 1);
        // Full energy: 9/10 < 1.0, 10/10 >= 1.0 → rank 2.
        assert_eq!(LowRankApproximation::optimal_rank(&[3.0, 1.0], 1.0), 2);
    }

    // -----------------------------------------------------------------------
    // LowRankInferencePass tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_inference_pass_empty_graph() {
        let graph = EinsumGraph::new();
        let pass = LowRankInferencePass::new(LowRankConfig::default());
        let candidates = pass.find_candidates(&graph);
        assert!(
            candidates.is_empty(),
            "empty graph should yield no candidates"
        );
    }

    #[test]
    fn test_inference_pass_stats() {
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("A");
        let t1 = graph.add_tensor("B");
        let t2 = graph.add_tensor("C");
        let node = tensorlogic_ir::EinsumNode::einsum("ij,jk->ik", vec![t0, t1], vec![t2]);
        graph.add_node(node).expect("add_node ok");

        let pass = LowRankInferencePass::new(LowRankConfig::new(2));
        let stats = pass.apply_annotations(&graph);
        assert_eq!(stats.nodes_inspected, 1);
        // "ij,jk->ik" contracts `j` and contains 8 index letters; rank 2 < 8,
        // so the heuristic deterministically flags this node.
        assert_eq!(stats.candidates_found, 1);
        // savings = 1 - 2*2/8 = 0.5
        assert!(
            (stats.estimated_total_flop_reduction - 0.5).abs() < 1e-12,
            "expected savings 0.5, got {}",
            stats.estimated_total_flop_reduction
        );
    }

    #[test]
    fn test_inference_pass_ignores_non_contracting_einsum() {
        // Element-wise product: shared indices all survive in the output,
        // so there is no contracted index and no candidate.
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("A");
        let t1 = graph.add_tensor("B");
        let t2 = graph.add_tensor("C");
        let node = tensorlogic_ir::EinsumNode::einsum("ij,ij->ij", vec![t0, t1], vec![t2]);
        graph.add_node(node).expect("add_node ok");

        let pass = LowRankInferencePass::new(LowRankConfig::new(2));
        assert!(pass.find_candidates(&graph).is_empty());
    }
}