scirs2_graph/alignment/
types.rs1use scirs2_core::ndarray::Array2;
8
9use crate::error::{GraphError, Result};
10
/// Tuning parameters for a graph alignment run.
///
/// Start from [`AlignmentConfig::default`] and override individual fields with
/// struct-update syntax, e.g. `AlignmentConfig { max_iter: 200, ..Default::default() }`.
#[derive(Clone, Debug)]
pub struct AlignmentConfig {
    /// Mixing weight used by the solver (default `0.6`).
    ///
    /// NOTE(review): presumably balances topological vs. prior similarity and is
    /// expected to lie in `[0, 1]`; nothing here validates it — confirm with the solver.
    pub alpha: f64,
    /// Upper bound on the number of solver iterations (default `100`).
    pub max_iter: usize,
    /// Convergence threshold (default `1e-8`).
    ///
    /// NOTE(review): presumably compared against the change between successive
    /// iterates — confirm with the solver.
    pub tolerance: f64,
    /// Candidate count for the greedy matching phase (default `5`).
    ///
    /// NOTE(review): presumably the number of best-scoring partners considered
    /// per node — confirm with the matcher.
    pub greedy_candidates: usize,
    /// Limit for the local-search refinement phase (default `50`).
    ///
    /// NOTE(review): the unit (moves vs. rounds) is not visible here — confirm.
    pub local_search_depth: usize,
}

impl Default for AlignmentConfig {
    /// Returns the stock settings: `alpha = 0.6`, `max_iter = 100`,
    /// `tolerance = 1e-8`, `greedy_candidates = 5`, `local_search_depth = 50`.
    fn default() -> Self {
        const ALPHA: f64 = 0.6;
        const MAX_ITER: usize = 100;
        const TOLERANCE: f64 = 1.0e-8;
        const GREEDY_CANDIDATES: usize = 5;
        const LOCAL_SEARCH_DEPTH: usize = 50;

        AlignmentConfig {
            alpha: ALPHA,
            max_iter: MAX_ITER,
            tolerance: TOLERANCE,
            greedy_candidates: GREEDY_CANDIDATES,
            local_search_depth: LOCAL_SEARCH_DEPTH,
        }
    }
}
41
/// Outcome of a graph alignment run.
#[derive(Clone, Debug)]
pub struct AlignmentResult {
    /// Matched node pairs `(i, j)`.
    ///
    /// NOTE(review): presumably `i` indexes the first graph and `j` the second — confirm.
    pub mapping: Vec<(usize, usize)>,
    /// Overall quality score of the alignment (scale defined by the producing solver).
    pub score: f64,
    /// Edge-conservation measure of the mapping.
    ///
    /// NOTE(review): presumably the fraction of edges preserved by `mapping` — confirm.
    pub edge_conservation: f64,
    /// Whether the iterative solver reached its tolerance before the iteration cap.
    pub converged: bool,
    /// Number of iterations the solver performed.
    pub iterations: usize,
}
58
59#[derive(Debug, Clone)]
64pub struct SimilarityMatrix {
65 data: Array2<f64>,
66 n1: usize,
67 n2: usize,
68}
69
70impl SimilarityMatrix {
71 pub fn new(n1: usize, n2: usize) -> Result<Self> {
77 if n1 == 0 || n2 == 0 {
78 return Err(GraphError::InvalidParameter {
79 param: "dimensions".to_string(),
80 value: format!("({}, {})", n1, n2),
81 expected: "both dimensions > 0".to_string(),
82 context: "SimilarityMatrix::new".to_string(),
83 });
84 }
85 let val = 1.0 / (n1 as f64 * n2 as f64);
86 let data = Array2::from_elem((n1, n2), val);
87 Ok(Self { data, n1, n2 })
88 }
89
90 pub fn from_prior(prior: Array2<f64>) -> Result<Self> {
98 let shape = prior.shape();
99 let n1 = shape[0];
100 let n2 = shape[1];
101 if n1 == 0 || n2 == 0 {
102 return Err(GraphError::InvalidParameter {
103 param: "prior dimensions".to_string(),
104 value: format!("({}, {})", n1, n2),
105 expected: "both dimensions > 0".to_string(),
106 context: "SimilarityMatrix::from_prior".to_string(),
107 });
108 }
109 let mut sm = Self {
110 data: prior,
111 n1,
112 n2,
113 };
114 sm.normalize();
115 Ok(sm)
116 }
117
118 pub fn get(&self, i: usize, j: usize) -> f64 {
122 if i < self.n1 && j < self.n2 {
123 self.data[[i, j]]
124 } else {
125 0.0
126 }
127 }
128
129 pub fn set(&mut self, i: usize, j: usize, value: f64) {
133 if i < self.n1 && j < self.n2 {
134 self.data[[i, j]] = value;
135 }
136 }
137
138 pub fn normalize(&mut self) {
142 let sum: f64 = self.data.iter().sum();
143 if sum.abs() < f64::EPSILON {
144 let val = 1.0 / (self.n1 as f64 * self.n2 as f64);
145 self.data.fill(val);
146 } else {
147 self.data /= sum;
148 }
149 }
150
151 pub fn as_array(&self) -> &Array2<f64> {
153 &self.data
154 }
155
156 pub fn n1(&self) -> usize {
158 self.n1
159 }
160
161 pub fn n2(&self) -> usize {
163 self.n2
164 }
165}
166
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_alignment_config_default() {
        let cfg = AlignmentConfig::default();
        assert!((cfg.alpha - 0.6).abs() < f64::EPSILON);
        assert_eq!(cfg.max_iter, 100);
        assert!((cfg.tolerance - 1e-8).abs() < f64::EPSILON);
        assert_eq!(cfg.greedy_candidates, 5);
        assert_eq!(cfg.local_search_depth, 50);
    }

    #[test]
    fn test_similarity_matrix_new() {
        let sm = SimilarityMatrix::new(3, 4).expect("should create matrix");
        let expected = 1.0 / 12.0;
        for i in 0..3 {
            for j in 0..4 {
                assert!((sm.get(i, j) - expected).abs() < f64::EPSILON);
            }
        }
    }

    #[test]
    fn test_similarity_matrix_dimensions() {
        let sm = SimilarityMatrix::new(3, 4).expect("should create matrix");
        assert_eq!(sm.n1(), 3);
        assert_eq!(sm.n2(), 4);
        assert_eq!(sm.as_array().dim(), (3, 4));
    }

    #[test]
    fn test_similarity_matrix_zero_dim() {
        assert!(SimilarityMatrix::new(0, 5).is_err());
        assert!(SimilarityMatrix::new(5, 0).is_err());
    }

    #[test]
    fn test_similarity_matrix_from_prior() {
        let prior = array![[1.0, 2.0], [3.0, 4.0]];
        let sm = SimilarityMatrix::from_prior(prior).expect("should create from prior");
        let sum: f64 = sm.as_array().iter().sum();
        assert!((sum - 1.0).abs() < 1e-12);
        assert!((sm.get(0, 0) - 0.1).abs() < 1e-12);
        assert!((sm.get(1, 1) - 0.4).abs() < 1e-12);
    }

    #[test]
    fn test_similarity_matrix_from_prior_zero_dim() {
        let prior: Array2<f64> = Array2::zeros((0, 3));
        assert!(SimilarityMatrix::from_prior(prior).is_err());
        let prior: Array2<f64> = Array2::zeros((3, 0));
        assert!(SimilarityMatrix::from_prior(prior).is_err());
    }

    #[test]
    fn test_similarity_matrix_set_get() {
        let mut sm = SimilarityMatrix::new(2, 2).expect("should create matrix");
        sm.set(0, 1, 0.99);
        assert!((sm.get(0, 1) - 0.99).abs() < f64::EPSILON);
        // Out-of-range reads return 0.0 rather than panicking.
        assert!((sm.get(10, 10)).abs() < f64::EPSILON);
    }

    #[test]
    fn test_similarity_matrix_set_out_of_bounds_is_noop() {
        let mut sm = SimilarityMatrix::new(2, 2).expect("should create matrix");
        sm.set(5, 0, 1.0);
        sm.set(0, 5, 1.0);
        assert!(sm
            .as_array()
            .iter()
            .all(|&v| (v - 0.25).abs() < f64::EPSILON));
    }

    #[test]
    fn test_similarity_matrix_normalize_after_set() {
        let mut sm = SimilarityMatrix::new(2, 2).expect("should create matrix");
        sm.set(0, 0, 1.0);
        sm.normalize();
        let sum: f64 = sm.as_array().iter().sum();
        assert!((sum - 1.0).abs() < 1e-12);
        assert!((sm.get(0, 0) - 1.0 / 1.75).abs() < 1e-12);
        assert!((sm.get(1, 1) - 0.25 / 1.75).abs() < 1e-12);
    }

    #[test]
    fn test_similarity_matrix_normalize_zero() {
        let prior = Array2::zeros((3, 3));
        let sm = SimilarityMatrix::from_prior(prior).expect("should create from zero prior");
        let expected = 1.0 / 9.0;
        assert!((sm.get(0, 0) - expected).abs() < 1e-12);
    }
}