tensorlogic_infer/low_rank/
approximation.rs1use tensorlogic_ir::EinsumGraph;
4
5use super::config::LowRankConfig;
6use super::error::LowRankError;
7use super::svd::{SvdResult, TruncatedSvd};
8
9pub struct LowRankApproximation {
15 config: LowRankConfig,
16 svd: TruncatedSvd,
17}
18
19impl LowRankApproximation {
20 pub fn new(config: LowRankConfig) -> Self {
22 let svd = TruncatedSvd::new(config.clone());
23 LowRankApproximation { config, svd }
24 }
25
26 pub fn approximate_matrix(
28 &self,
29 data: &[f64],
30 rows: usize,
31 cols: usize,
32 ) -> Result<SvdResult, LowRankError> {
33 self.svd.decompose(data, rows, cols)
34 }
35
36 pub fn approximate_matmul(
42 &self,
43 a: &[f64],
44 a_rows: usize,
45 a_cols: usize,
46 b: &[f64],
47 b_rows: usize,
48 b_cols: usize,
49 ) -> Result<Vec<f64>, LowRankError> {
50 if a_cols != b_rows {
51 return Err(LowRankError::InvalidInput(format!(
52 "inner dimensions mismatch: a_cols={} != b_rows={}",
53 a_cols, b_rows
54 )));
55 }
56
57 let svd_result = self.svd.decompose(a, a_rows, a_cols)?;
58 let rank = svd_result.rank_used;
59
60 let mut m = vec![0.0_f64; rank * b_cols];
62 for k in 0..rank {
63 for j in 0..b_cols {
64 let mut val = 0.0_f64;
65 for l in 0..b_rows {
66 val += svd_result.vt[k * svd_result.vt_cols + l] * b[l * b_cols + j];
68 }
69 m[k * b_cols + j] = val;
70 }
71 }
72
73 let mut c = vec![0.0_f64; a_rows * b_cols];
75 for i in 0..a_rows {
76 for j in 0..b_cols {
77 let mut val = 0.0_f64;
78 for k in 0..rank {
79 let u_ik = svd_result.u[i * svd_result.u_cols + k];
81 val += u_ik * svd_result.singular_values[k] * m[k * b_cols + j];
82 }
83 c[i * b_cols + j] = val;
84 }
85 }
86
87 Ok(c)
88 }
89
90 pub fn is_candidate(&self, rows: usize, cols: usize) -> bool {
93 rows >= self.config.min_dimension && cols >= self.config.min_dimension
94 }
95
96 pub fn optimal_rank(singular_values: &[f64], energy_threshold: f64) -> usize {
102 if singular_values.is_empty() {
103 return 0;
104 }
105 let total: f64 = singular_values.iter().map(|s| s * s).sum();
106 if total == 0.0 {
107 return 1;
108 }
109 let mut cumulative = 0.0_f64;
110 for (k, &sv) in singular_values.iter().enumerate() {
111 cumulative += sv * sv;
112 if cumulative / total >= energy_threshold {
113 return k + 1;
114 }
115 }
116 singular_values.len()
117 }
118}
119
/// One graph node flagged by [`LowRankInferencePass`] as worth a low-rank rewrite.
#[derive(Clone, Debug)]
pub struct LowRankCandidate {
    /// Index of the flagged node inside the graph's `nodes` list.
    pub node_index: usize,
    /// Human-readable explanation of why the node was flagged.
    pub reason: String,
    /// Heuristic savings estimate, clamped to `[0, 1]` by the pass.
    pub estimated_savings_ratio: f64,
}
134
/// Summary counters reported by one run of [`LowRankInferencePass`].
#[derive(Default, Clone, Debug)]
pub struct LowRankPassStats {
    /// How many nodes were reported as low-rank candidates.
    pub candidates_found: usize,
    /// How many graph nodes were examined in total.
    pub nodes_inspected: usize,
    /// Sum of the candidates' savings ratios (unweighted).
    pub estimated_total_flop_reduction: f64,
}
142
143#[derive(Debug)]
149pub struct LowRankInferencePass {
150 config: LowRankConfig,
151}
152
153impl LowRankInferencePass {
154 pub fn new(config: LowRankConfig) -> Self {
156 LowRankInferencePass { config }
157 }
158
159 pub fn find_candidates(&self, graph: &EinsumGraph) -> Vec<LowRankCandidate> {
161 let mut candidates = Vec::new();
162
163 for (idx, node) in graph.nodes.iter().enumerate() {
164 if let tensorlogic_ir::OpType::Einsum { spec } = &node.op {
165 if node.inputs.len() >= 2 && self.is_matmul_like(spec) {
169 let savings = self.estimate_savings(spec);
170 candidates.push(LowRankCandidate {
171 node_index: idx,
172 reason: format!(
173 "Einsum '{}' has {} inputs and matmul-like contraction",
174 spec,
175 node.inputs.len()
176 ),
177 estimated_savings_ratio: savings,
178 });
179 }
180 }
181 }
182
183 candidates
184 }
185
186 pub fn apply_annotations(&self, graph: &EinsumGraph) -> LowRankPassStats {
192 let candidates = self.find_candidates(graph);
193 let estimated_total_flop_reduction: f64 =
194 candidates.iter().map(|c| c.estimated_savings_ratio).sum();
195 LowRankPassStats {
196 candidates_found: candidates.len(),
197 nodes_inspected: graph.nodes.len(),
198 estimated_total_flop_reduction,
199 }
200 }
201
202 fn is_matmul_like(&self, spec: &str) -> bool {
209 if let Some(arrow_pos) = spec.find("->") {
210 let inputs_part = &spec[..arrow_pos];
211 let operands: Vec<&str> = inputs_part.split(',').collect();
212 if operands.len() < 2 {
213 return false;
214 }
215 let a_chars: std::collections::HashSet<char> =
217 operands[0].chars().filter(|c| c.is_alphabetic()).collect();
218 let b_chars: std::collections::HashSet<char> =
219 operands[1].chars().filter(|c| c.is_alphabetic()).collect();
220 let output_chars: std::collections::HashSet<char> = spec[arrow_pos + 2..]
221 .chars()
222 .filter(|c| c.is_alphabetic())
223 .collect();
224 let contracted: std::collections::HashSet<char> = a_chars
226 .intersection(&b_chars)
227 .copied()
228 .filter(|c| !output_chars.contains(c))
229 .collect();
230 return contracted.len() >= 1
231 && self.config.rank < self.min_contracted_dim_estimate(spec);
232 }
233 false
234 }
235
236 fn min_contracted_dim_estimate(&self, spec: &str) -> usize {
239 let contracted = spec.chars().filter(|c| c.is_alphabetic()).count();
242 contracted.max(1)
244 }
245
246 fn estimate_savings(&self, spec: &str) -> f64 {
248 let contracted_dims = self.min_contracted_dim_estimate(spec).max(1) as f64;
250 let rank = self.config.rank as f64;
251 (1.0 - (2.0 * rank) / contracted_dims).clamp(0.0, 1.0)
252 }
253}
254
#[cfg(test)]
mod tests {
    use super::*;

    /// Config with the given rank plus the solver settings shared by these tests.
    fn make_config(rank: usize) -> LowRankConfig {
        LowRankConfig::new(rank)
            .with_tolerance(1e-8)
            .with_max_iterations(300)
            .with_min_dimension(8)
    }

    #[test]
    fn test_approximation_4x4_matrix() {
        // Entries 1..=16: each row (or column) differs from the first by a
        // multiple of [1, 1, 1, 1], so the matrix has exact rank 2.
        let m: Vec<f64> = (1..=16).map(|x| x as f64).collect();
        let cfg = make_config(2);
        let approx = LowRankApproximation::new(cfg);
        let svd = approx
            .approximate_matrix(&m, 4, 4)
            .expect("approximation should succeed for a valid 4x4 matrix");
        assert!(svd.rank_used >= 1);
        assert!(svd.frobenius_error.is_finite());
    }

    #[test]
    fn test_approximate_matmul_rejects_inner_dimension_mismatch() {
        let approx = LowRankApproximation::new(make_config(1));
        let a = vec![0.0_f64; 6]; // 2 x 3
        let b = vec![0.0_f64; 4]; // 2 x 2
        let result = approx.approximate_matmul(&a, 2, 3, &b, 2, 2);
        assert!(matches!(result, Err(LowRankError::InvalidInput(_))));
    }

    #[test]
    fn test_is_candidate_small_matrix() {
        let cfg = LowRankConfig::new(2).with_min_dimension(32);
        let approx = LowRankApproximation::new(cfg);
        assert!(!approx.is_candidate(4, 4));
    }

    #[test]
    fn test_is_candidate_large_matrix() {
        let cfg = LowRankConfig::new(4).with_min_dimension(32);
        let approx = LowRankApproximation::new(cfg);
        assert!(approx.is_candidate(64, 64));
    }

    #[test]
    fn test_optimal_rank_energy_threshold() {
        // Energies 100, 25, 4, 1 (total 130): the cumulative fractions are
        // ~0.769, ~0.962, ~0.992, 1.0.
        let svs = vec![10.0_f64, 5.0, 2.0, 1.0];
        let r = LowRankApproximation::optimal_rank(&svs, 0.90);
        assert_eq!(r, 2, "optimal rank for 90% energy should be 2, got {r}");

        let r2 = LowRankApproximation::optimal_rank(&svs, 0.99);
        assert_eq!(r2, 3, "optimal rank for 99% energy should be 3, got {r2}");
    }

    #[test]
    fn test_optimal_rank_edge_cases() {
        assert_eq!(LowRankApproximation::optimal_rank(&[], 0.9), 0);
        assert_eq!(LowRankApproximation::optimal_rank(&[0.0, 0.0], 0.9), 1);
        // A threshold above 1.0 is never reached, so every value is kept.
        assert_eq!(LowRankApproximation::optimal_rank(&[3.0, 1.0], 1.5), 2);
    }

    #[test]
    fn test_inference_pass_empty_graph() {
        let graph = EinsumGraph::new();
        let pass = LowRankInferencePass::new(LowRankConfig::default());
        let candidates = pass.find_candidates(&graph);
        assert!(
            candidates.is_empty(),
            "empty graph should yield no candidates"
        );
    }

    #[test]
    fn test_is_matmul_like_rejects_non_contractions() {
        let pass = LowRankInferencePass::new(LowRankConfig::new(2));
        assert!(pass.is_matmul_like("ij,jk->ik"));
        // Elementwise product: every shared index survives into the output.
        assert!(!pass.is_matmul_like("ij,ij->ij"));
        // Only one operand.
        assert!(!pass.is_matmul_like("ij->ji"));
        // No output arrow.
        assert!(!pass.is_matmul_like("ij,jk"));
    }

    #[test]
    fn test_inference_pass_stats() {
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("A");
        let t1 = graph.add_tensor("B");
        let t2 = graph.add_tensor("C");
        let node = tensorlogic_ir::EinsumNode::einsum("ij,jk->ik", vec![t0, t1], vec![t2]);
        graph.add_node(node).expect("add_node ok");

        let pass = LowRankInferencePass::new(LowRankConfig::new(2));
        let stats = pass.apply_annotations(&graph);
        assert_eq!(stats.nodes_inspected, 1);
        // "ij,jk->ik" sums over `j`, and rank 2 is below the spec's
        // letter-count estimate (6), so the single node is a candidate.
        assert_eq!(stats.candidates_found, 1);
        assert!(stats.estimated_total_flop_reduction > 0.0);
    }
}