1use super::semiring::{Tropical, TropicalMin};
10
/// Dense matrix over the max-plus (tropical) semiring, stored row-major.
///
/// Semiring "addition" is `max` and "multiplication" is `+`. The additive
/// identity ("zero") is `f64::NEG_INFINITY` and the multiplicative identity
/// ("one") is `0.0`.
#[derive(Debug, Clone, PartialEq)]
pub struct TropicalMatrix {
    rows: usize,
    cols: usize,
    /// Row-major entries; `data.len() == rows * cols`.
    data: Vec<f64>,
}

impl TropicalMatrix {
    /// Creates a `rows x cols` matrix filled with the tropical zero (`-inf`).
    pub fn zeros(rows: usize, cols: usize) -> Self {
        Self {
            rows,
            cols,
            data: vec![f64::NEG_INFINITY; rows * cols],
        }
    }

    /// Creates the `n x n` tropical identity: `0.0` on the diagonal, `-inf` elsewhere.
    pub fn identity(n: usize) -> Self {
        let mut m = Self::zeros(n, n);
        for i in 0..n {
            m.set(i, i, 0.0);
        }
        m
    }

    /// Builds a matrix from a vector of rows.
    ///
    /// The column count is taken from the first row. An empty input yields a
    /// `0 x 0` matrix.
    ///
    /// # Panics
    ///
    /// Panics if the rows do not all have the same length. A ragged input
    /// would otherwise leave the flat storage inconsistent with `rows * cols`.
    pub fn from_rows(data: Vec<Vec<f64>>) -> Self {
        let rows = data.len();
        let cols = data.first().map_or(0, Vec::len);
        assert!(
            data.iter().all(|row| row.len() == cols),
            "All rows must have the same length"
        );
        let flat: Vec<f64> = data.into_iter().flatten().collect();
        Self { rows, cols, data: flat }
    }

    /// Returns entry `(i, j)`, or `-inf` (the tropical zero) when the index is
    /// out of bounds.
    #[inline]
    pub fn get(&self, i: usize, j: usize) -> f64 {
        if i >= self.rows || j >= self.cols {
            return f64::NEG_INFINITY;
        }
        self.data[i * self.cols + j]
    }

    /// Sets entry `(i, j)`. Out-of-bounds writes are silently ignored.
    #[inline]
    pub fn set(&mut self, i: usize, j: usize, val: f64) {
        if i >= self.rows || j >= self.cols {
            return;
        }
        self.data[i * self.cols + j] = val;
    }

    /// Returns `(rows, cols)`.
    pub fn dims(&self) -> (usize, usize) {
        (self.rows, self.cols)
    }

    /// Max-plus matrix product: `C[i][k] = max_j (A[i][j] + B[j][k])`.
    ///
    /// Terms where either factor is `-inf` are skipped, so a `-inf + +inf`
    /// pair never injects NaN into the result.
    ///
    /// # Panics
    ///
    /// Panics if `self.cols != other.rows`.
    pub fn mul(&self, other: &Self) -> Self {
        assert_eq!(self.cols, other.rows, "Dimension mismatch");

        let mut result = Self::zeros(self.rows, other.cols);

        for i in 0..self.rows {
            for k in 0..other.cols {
                let mut max_val = f64::NEG_INFINITY;
                for j in 0..self.cols {
                    let a = self.get(i, j);
                    let b = other.get(j, k);
                    if a != f64::NEG_INFINITY && b != f64::NEG_INFINITY {
                        max_val = max_val.max(a + b);
                    }
                }
                result.set(i, k, max_val);
            }
        }

        result
    }

    /// Max-plus matrix power `A^n`. `A^0` is the tropical identity.
    ///
    /// Uses exponentiation by squaring, so only `O(log n)` matrix products are
    /// performed. For non-integer weights the summation order differs from
    /// repeated multiplication, so results may differ in the last ulp.
    ///
    /// # Panics
    ///
    /// Panics if the matrix is not square.
    pub fn pow(&self, n: usize) -> Self {
        assert_eq!(self.rows, self.cols, "Must be square");

        let mut result = Self::identity(self.rows);
        let mut base = self.clone();
        let mut exp = n;
        while exp > 0 {
            if exp & 1 == 1 {
                result = result.mul(&base);
            }
            exp >>= 1;
            // Skip the final squaring; its result would never be used.
            if exp > 0 {
                base = base.mul(&base);
            }
        }
        result
    }

    /// Truncated Kleene closure `I (+) A (+) A^2 (+) ... (+) A^n`, where `(+)`
    /// is element-wise `max` and `n` is the matrix dimension.
    ///
    /// Entry `(i, j)` is the best weight over walks from `i` to `j` with at
    /// most `n` edges (the empty walk contributes `0` on the diagonal). When
    /// the graph has no positive-weight cycle, this equals the Kleene star `A*`.
    ///
    /// # Panics
    ///
    /// Panics if the matrix is not square.
    pub fn closure(&self) -> Self {
        assert_eq!(self.rows, self.cols, "Must be square");
        let n = self.rows;

        let mut result = Self::identity(n);
        let mut power = self.clone();

        for step in 0..n {
            // Both matrices are n x n, so their flat storages line up element-wise.
            for (acc, &p) in result.data.iter_mut().zip(&power.data) {
                *acc = acc.max(p);
            }
            // The power after the last accumulation would be discarded.
            if step + 1 < n {
                power = power.mul(self);
            }
        }

        result
    }

    /// Maximum cycle mean of the weighted digraph with edge `i -> j` of weight
    /// `A[i][j]` (`-inf` means "no edge"), computed with Karp's algorithm in
    /// `O(n^3)` time and `O(n^2)` memory.
    ///
    /// Returns `-inf` when the graph has no cycle, including for the empty matrix.
    ///
    /// # Panics
    ///
    /// Panics if the matrix is not square.
    pub fn max_cycle_mean(&self) -> f64 {
        assert_eq!(self.rows, self.cols, "Must be square");
        let n = self.rows;

        // walk[k][i] is the maximum weight of a walk with exactly k edges that
        // starts at i. Seeding every vertex with 0 at k = 0 plays the role of
        // Karp's virtual source that is connected to all vertices.
        let mut walk = vec![vec![f64::NEG_INFINITY; n]; n + 1];
        walk[0].fill(0.0);

        for k in 1..=n {
            for i in 0..n {
                let mut best = f64::NEG_INFINITY;
                for j in 0..n {
                    let w = self.get(i, j);
                    let tail = walk[k - 1][j];
                    if w != f64::NEG_INFINITY && tail != f64::NEG_INFINITY {
                        best = best.max(w + tail);
                    }
                }
                walk[k][i] = best;
            }
        }

        // Karp: lambda = max_i min_{0 <= k < n} (walk[n][i] - walk[k][i]) / (n - k),
        // taken over vertices with a finite n-edge walk. n - k >= 1 in this range.
        let last = &walk[n];
        (0..n)
            .filter(|&i| last[i] != f64::NEG_INFINITY)
            .map(|i| {
                (0..n)
                    .filter(|&k| walk[k][i] != f64::NEG_INFINITY)
                    .map(|k| (last[i] - walk[k][i]) / (n - k) as f64)
                    .fold(f64::INFINITY, f64::min)
            })
            .fold(f64::NEG_INFINITY, f64::max)
    }
}
180
181#[derive(Debug, Clone)]
183pub struct TropicalEigen {
184 pub eigenvalue: f64,
186 pub eigenvector: Vec<f64>,
188}
189
190impl TropicalEigen {
191 pub fn power_iteration(matrix: &TropicalMatrix, max_iters: usize) -> Option<Self> {
194 let n = matrix.rows;
195 if n == 0 {
196 return None;
197 }
198
199 let mut v: Vec<f64> = vec![0.0; n];
201 let mut eigenvalue = 0.0f64;
202
203 for _ in 0..max_iters {
204 let mut av = vec![f64::NEG_INFINITY; n];
206 for i in 0..n {
207 for j in 0..n {
208 let aij = matrix.get(i, j);
209 if aij != f64::NEG_INFINITY && v[j] != f64::NEG_INFINITY {
210 av[i] = av[i].max(aij + v[j]);
211 }
212 }
213 }
214
215 let max_av = av.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
217 if max_av == f64::NEG_INFINITY {
218 return None;
219 }
220
221 let new_eigenvalue = max_av - v.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
223
224 for i in 0..n {
226 v[i] = av[i] - max_av;
227 }
228
229 if (new_eigenvalue - eigenvalue).abs() < 1e-10 {
231 return Some(TropicalEigen {
232 eigenvalue: new_eigenvalue,
233 eigenvector: v,
234 });
235 }
236
237 eigenvalue = new_eigenvalue;
238 }
239
240 Some(TropicalEigen { eigenvalue, eigenvector: v })
241 }
242}
243
/// Dense matrix over the min-plus semiring (`min` as addition, `+` as
/// multiplication), stored row-major. `f64::INFINITY` means "no edge".
#[derive(Debug, Clone, PartialEq)]
pub struct MinPlusMatrix {
    rows: usize,
    cols: usize,
    /// Row-major entries; `data.len() == rows * cols`.
    data: Vec<f64>,
}

impl MinPlusMatrix {
    /// Builds a matrix from adjacency rows (`adj[i][j]` is the weight of `i -> j`).
    ///
    /// The column count is taken from the first row. An empty input yields a
    /// `0 x 0` matrix.
    ///
    /// # Panics
    ///
    /// Panics if the rows do not all have the same length.
    pub fn from_adjacency(adj: Vec<Vec<f64>>) -> Self {
        let rows = adj.len();
        let cols = adj.first().map_or(0, Vec::len);
        assert!(
            adj.iter().all(|row| row.len() == cols),
            "All rows must have the same length"
        );
        let data: Vec<f64> = adj.into_iter().flatten().collect();
        Self { rows, cols, data }
    }

    /// Returns entry `(i, j)`, or `+inf` ("unreachable") when the index is out
    /// of bounds.
    #[inline]
    pub fn get(&self, i: usize, j: usize) -> f64 {
        if i >= self.rows || j >= self.cols {
            return f64::INFINITY;
        }
        self.data[i * self.cols + j]
    }

    /// Sets entry `(i, j)`. Out-of-bounds writes are silently ignored.
    #[inline]
    pub fn set(&mut self, i: usize, j: usize, val: f64) {
        if i >= self.rows || j >= self.cols {
            return;
        }
        self.data[i * self.cols + j] = val;
    }

    /// Returns `(rows, cols)`.
    pub fn dims(&self) -> (usize, usize) {
        (self.rows, self.cols)
    }

    /// All-pairs shortest path distances via Floyd–Warshall in `O(n^3)` time.
    ///
    /// Entry `(i, j)` of the result is the shortest distance from `i` to `j`,
    /// or `+inf` if `j` is unreachable. Negative cycles are not detected.
    /// If one exists, some diagonal entries of the result come out negative.
    ///
    /// # Panics
    ///
    /// Panics if the matrix is not square.
    pub fn all_pairs_shortest_paths(&self) -> Self {
        assert_eq!(self.rows, self.cols, "Must be square");
        let n = self.rows;
        let mut dist = self.clone();

        for k in 0..n {
            for i in 0..n {
                for j in 0..n {
                    // dist(i, k) is re-read every iteration on purpose: with a
                    // negative cycle through k it can change while j == k.
                    let via_k = dist.get(i, k) + dist.get(k, j);
                    if via_k < dist.get(i, j) {
                        dist.set(i, j, via_k);
                    }
                }
            }
        }

        dist
    }
}
298
#[cfg(test)]
mod tests {
    use super::*;

    const NEG: f64 = f64::NEG_INFINITY;

    /// Asserts that two tropical matrices have the same shape and equal
    /// entries (within 1e-10); infinite entries must match exactly.
    fn assert_matrix_eq(actual: &TropicalMatrix, expected: &TropicalMatrix) {
        assert_eq!(actual.dims(), expected.dims());
        let (rows, cols) = expected.dims();
        for i in 0..rows {
            for j in 0..cols {
                let (a, e) = (actual.get(i, j), expected.get(i, j));
                assert!(
                    a == e || (a - e).abs() < 1e-10,
                    "entry ({}, {}): {} != {}",
                    i,
                    j,
                    a,
                    e
                );
            }
        }
    }

    #[test]
    fn test_tropical_matrix_mul() {
        let a = TropicalMatrix::from_rows(vec![vec![0.0, 1.0], vec![NEG, 2.0]]);

        let a2 = a.mul(&a);

        assert!((a2.get(0, 1) - 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_tropical_identity() {
        let i = TropicalMatrix::identity(3);
        let a = TropicalMatrix::from_rows(vec![
            vec![1.0, 2.0, 3.0],
            vec![4.0, 5.0, 6.0],
            vec![7.0, 8.0, 9.0],
        ]);

        assert_matrix_eq(&i.mul(&a), &a);
    }

    #[test]
    fn test_pow_matches_repeated_mul() {
        let a = TropicalMatrix::from_rows(vec![vec![0.0, 1.0], vec![NEG, 2.0]]);

        assert_matrix_eq(&a.pow(0), &TropicalMatrix::identity(2));
        assert_matrix_eq(&a.pow(1), &a);
        assert_matrix_eq(&a.pow(3), &a.mul(&a).mul(&a));
    }

    #[test]
    fn test_closure() {
        let a = TropicalMatrix::from_rows(vec![vec![NEG, -1.0], vec![-2.0, NEG]]);
        let expected = TropicalMatrix::from_rows(vec![vec![0.0, -1.0], vec![-2.0, 0.0]]);

        assert_matrix_eq(&a.closure(), &expected);
    }

    #[test]
    fn test_max_cycle_mean() {
        let a = TropicalMatrix::from_rows(vec![vec![NEG, 3.0], vec![1.0, NEG]]);

        let mcm = a.max_cycle_mean();
        assert!((mcm - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_max_cycle_mean_acyclic() {
        let dag = TropicalMatrix::from_rows(vec![vec![NEG, 1.0], vec![NEG, NEG]]);

        assert_eq!(dag.max_cycle_mean(), NEG);
    }

    #[test]
    fn test_power_iteration() {
        let a = TropicalMatrix::from_rows(vec![vec![2.0, 1.0], vec![1.0, 0.0]]);

        let eig = TropicalEigen::power_iteration(&a, 100).expect("should converge");
        assert!((eig.eigenvalue - 2.0).abs() < 1e-10);
        assert!((eig.eigenvector[0] - 0.0).abs() < 1e-10);
        assert!((eig.eigenvector[1] + 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_power_iteration_without_edges() {
        let a = TropicalMatrix::from_rows(vec![vec![NEG, NEG], vec![NEG, NEG]]);

        assert!(TropicalEigen::power_iteration(&a, 10).is_none());
        assert!(TropicalEigen::power_iteration(&TropicalMatrix::zeros(0, 0), 10).is_none());
    }

    #[test]
    fn test_floyd_warshall() {
        let adj = MinPlusMatrix::from_adjacency(vec![
            vec![0.0, 1.0, 5.0],
            vec![f64::INFINITY, 0.0, 2.0],
            vec![f64::INFINITY, f64::INFINITY, 0.0],
        ]);

        let dist = adj.all_pairs_shortest_paths();

        assert!((dist.get(0, 2) - 3.0).abs() < 1e-10);
        // No path leads back from 1 to 0.
        assert_eq!(dist.get(1, 0), f64::INFINITY);
    }
}