use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2, Axis};
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Predict, PredictProba, Untrained},
    types::Float,
};

/// Semi-supervised label propagation with local and global consistency
/// (Zhou et al., 2004). Labels are spread over a graph built from the input
/// samples; unlabeled samples are marked with `-1`.
#[derive(Debug, Clone)]
pub struct LocalGlobalConsistency<S = Untrained> {
    state: S,
    /// Graph construction kernel: `"rbf"` or `"knn"`.
    kernel: String,
    /// RBF kernel bandwidth parameter.
    gamma: f64,
    /// Number of neighbors used by the `"knn"` kernel.
    n_neighbors: usize,
    /// Clamping factor in `(0, 1)`: weight of propagated vs. initial labels.
    alpha: f64,
    /// Maximum number of propagation iterations.
    max_iter: usize,
    /// Convergence tolerance on the change of the label distributions.
    tol: f64,
}
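// Construction follows the builder pattern used below; a minimal, illustrative
// sketch (parameter values are arbitrary, not recommendations):
//
//     let model = LocalGlobalConsistency::new()
//         .kernel("knn".to_string())
//         .n_neighbors(10)
//         .alpha(0.9);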

impl LocalGlobalConsistency<Untrained> {
    /// Creates an estimator with default hyperparameters (RBF kernel,
    /// `gamma = 20.0`, `alpha = 0.99`, `max_iter = 1000`, `tol = 1e-6`).
    pub fn new() -> Self {
        Self {
            state: Untrained,
            kernel: "rbf".to_string(),
            gamma: 20.0,
            n_neighbors: 7,
            alpha: 0.99,
            max_iter: 1000,
            tol: 1e-6,
        }
    }

    /// Sets the graph construction kernel (`"rbf"` or `"knn"`).
    pub fn kernel(mut self, kernel: String) -> Self {
        self.kernel = kernel;
        self
    }

    /// Sets the RBF kernel bandwidth.
    pub fn gamma(mut self, gamma: f64) -> Self {
        self.gamma = gamma;
        self
    }

    /// Sets the number of neighbors for the `"knn"` kernel.
    pub fn n_neighbors(mut self, n_neighbors: usize) -> Self {
        self.n_neighbors = n_neighbors;
        self
    }

    /// Sets the clamping factor `alpha` in `(0, 1)`.
    pub fn alpha(mut self, alpha: f64) -> Self {
        self.alpha = alpha;
        self
    }

    /// Sets the maximum number of propagation iterations.
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Sets the convergence tolerance.
    pub fn tol(mut self, tol: f64) -> Self {
        self.tol = tol;
        self
    }

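    /// Builds the symmetric affinity matrix `W` over the training samples.
    /// With the `"rbf"` kernel, `W[i, j] = exp(-gamma * ||x_i - x_j||^2)` for
    /// `i != j`; with the `"knn"` kernel, `W` is a binary adjacency matrix
    /// connecting each sample to its `n_neighbors` nearest neighbors,
    /// symmetrized so that `W[i, j] = W[j, i]`.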
    fn build_affinity_matrix(&self, X: &Array2<f64>) -> SklResult<Array2<f64>> {
        let n_samples = X.nrows();
        let mut W = Array2::zeros((n_samples, n_samples));

        match self.kernel.as_str() {
            "rbf" => {
                // Fully connected graph with Gaussian weights.
                for i in 0..n_samples {
                    for j in 0..n_samples {
                        if i != j {
                            let diff = &X.row(i) - &X.row(j);
                            let dist_sq = diff.mapv(|x| x * x).sum();
                            W[[i, j]] = (-self.gamma * dist_sq).exp();
                        }
                    }
                }
            }
            "knn" => {
                // Symmetric k-nearest-neighbor connectivity graph.
                for i in 0..n_samples {
                    let mut distances: Vec<(usize, f64)> = Vec::new();
                    for j in 0..n_samples {
                        if i != j {
                            let diff = &X.row(i) - &X.row(j);
                            let dist = diff.mapv(|x: f64| x * x).sum().sqrt();
                            distances.push((j, dist));
                        }
                    }

                    distances
                        .sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));

                    for &(j, _) in distances.iter().take(self.n_neighbors) {
                        W[[i, j]] = 1.0;
                        W[[j, i]] = 1.0;
                    }
                }
            }
            _ => {
                return Err(SklearsError::InvalidInput(format!(
                    "Unknown kernel: {}",
                    self.kernel
                )));
            }
        }

        Ok(W)
    }

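    /// Runs the local and global consistency propagation. With the degree
    /// matrix `D = diag(W · 1)` and the normalized graph `S = D^{-1/2} W D^{-1/2}`,
    /// the label distributions are iterated as
    ///
    ///     F_{t+1} = alpha * S * F_t + (1 - alpha) * Y
    ///
    /// where `Y` is the one-hot encoding of the known labels. The loop stops
    /// after `max_iter` steps or once the absolute change falls below `tol`;
    /// in the limit this iteration converges to `(1 - alpha) * (I - alpha * S)^{-1} Y`.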
    #[allow(non_snake_case)]
    fn solve_lgc_function(
        &self,
        W: &Array2<f64>,
        labeled_indices: &[usize],
        y: &Array1<i32>,
        classes: &[i32],
    ) -> SklResult<Array2<f64>> {
        let n_samples = W.nrows();
        let n_classes = classes.len();

        // Symmetric normalization S = D^{-1/2} W D^{-1/2}.
        let D = W.sum_axis(Axis(1));
        let mut D_sqrt_inv = Array2::zeros((n_samples, n_samples));
        for i in 0..n_samples {
            if D[i] > 0.0 {
                D_sqrt_inv[[i, i]] = 1.0 / D[i].sqrt();
            }
        }

        let S = D_sqrt_inv.dot(W).dot(&D_sqrt_inv);

        // One-hot initial label matrix Y; unlabeled rows stay all-zero.
        let mut Y = Array2::zeros((n_samples, n_classes));
        for &idx in labeled_indices {
            if let Some(class_idx) = classes.iter().position(|&c| c == y[idx]) {
                Y[[idx, class_idx]] = 1.0;
            }
        }

        let Y_static = Y.clone();

        // Fixed-point iteration F <- alpha * S * F + (1 - alpha) * Y.
        let mut prev_F = Y.clone();
        for _iter in 0..self.max_iter {
            Y = self.alpha * S.dot(&Y) + (1.0 - self.alpha) * &Y_static;

            let diff = (&Y - &prev_F).mapv(|x| x.abs()).sum();
            if diff < self.tol {
                break;
            }
            prev_F = Y.clone();
        }

        Ok(Y)
    }
}

impl Default for LocalGlobalConsistency<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for LocalGlobalConsistency<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for LocalGlobalConsistency<Untrained> {
    type Fitted = LocalGlobalConsistency<LocalGlobalConsistencyTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let y = y.to_owned();

        let (_n_samples, _n_features) = X.dim();

        // Collect labeled sample indices and the set of observed classes;
        // `-1` marks an unlabeled sample.
        let mut labeled_indices = Vec::new();
        let mut classes = std::collections::HashSet::new();

        for (i, &label) in y.iter().enumerate() {
            if label != -1 {
                labeled_indices.push(i);
                classes.insert(label);
            }
        }

        if labeled_indices.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No labeled samples provided".to_string(),
            ));
        }

        // Sort so that class ordering (and thus the column order of the label
        // distributions) is deterministic across runs.
        let mut classes: Vec<i32> = classes.into_iter().collect();
        classes.sort_unstable();

        let W = self.build_affinity_matrix(&X)?;

        let F = self.solve_lgc_function(&W, &labeled_indices, &y, &classes)?;

        Ok(LocalGlobalConsistency {
            state: LocalGlobalConsistencyTrained {
                X_train: X,
                y_train: y,
                classes: Array1::from(classes),
                label_distributions: F,
                affinity_matrix: W,
            },
            kernel: self.kernel,
            gamma: self.gamma,
            n_neighbors: self.n_neighbors,
            alpha: self.alpha,
            max_iter: self.max_iter,
            tol: self.tol,
        })
    }
}

impl Predict<ArrayView2<'_, Float>, Array1<i32>>
    for LocalGlobalConsistency<LocalGlobalConsistencyTrained>
{
    #[allow(non_snake_case)]
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
        let X = X.to_owned();
        let n_test = X.nrows();
        let mut predictions = Array1::zeros(n_test);

        for i in 0..n_test {
            // Find the nearest training sample (1-NN by Euclidean distance).
            let mut min_dist = f64::INFINITY;
            let mut best_idx = 0;

            for j in 0..self.state.X_train.nrows() {
                let diff = &X.row(i) - &self.state.X_train.row(j);
                let dist = diff.mapv(|x: f64| x * x).sum().sqrt();
                if dist < min_dist {
                    min_dist = dist;
                    best_idx = j;
                }
            }

            // Assign the class with the largest propagated label mass at that
            // training sample.
            let distributions = self.state.label_distributions.row(best_idx);
            let max_idx = distributions
                .iter()
                .enumerate()
                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
                .unwrap()
                .0;

            predictions[i] = self.state.classes[max_idx];
        }

        Ok(predictions)
    }
}

impl PredictProba<ArrayView2<'_, Float>, Array2<f64>>
    for LocalGlobalConsistency<LocalGlobalConsistencyTrained>
{
    #[allow(non_snake_case)]
    fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
        let X = X.to_owned();
        let n_test = X.nrows();
        let n_classes = self.state.classes.len();
        let mut probas = Array2::zeros((n_test, n_classes));

        for i in 0..n_test {
            // Find the nearest training sample (1-NN by Euclidean distance).
            let mut min_dist = f64::INFINITY;
            let mut best_idx = 0;

            for j in 0..self.state.X_train.nrows() {
                let diff = &X.row(i) - &self.state.X_train.row(j);
                let dist = diff.mapv(|x: f64| x * x).sum().sqrt();
                if dist < min_dist {
                    min_dist = dist;
                    best_idx = j;
                }
            }

            // Copy the propagated label distribution of the nearest training
            // sample as-is; rows are not renormalized here.
            for k in 0..n_classes {
                probas[[i, k]] = self.state.label_distributions[[best_idx, k]];
            }
        }

        Ok(probas)
    }
}

#[derive(Debug, Clone)]
#[allow(non_snake_case)]
pub struct LocalGlobalConsistencyTrained {
    /// Training feature matrix.
    pub X_train: Array2<f64>,
    /// Training labels (`-1` for unlabeled samples).
    pub y_train: Array1<i32>,
    /// Distinct class labels, in the column order of `label_distributions`.
    pub classes: Array1<i32>,
    /// Propagated label distributions, one row per training sample.
    pub label_distributions: Array2<f64>,
    /// Affinity matrix used for propagation.
    pub affinity_matrix: Array2<f64>,
}
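
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal, illustrative smoke test rather than an exhaustive check: two
    /// well-separated clusters on a line, one labeled point per cluster, the
    /// remaining samples unlabeled (`-1`). Assumes `Float` is `f64`, as the
    /// rest of this module already does.
    #[test]
    fn propagates_labels_across_each_cluster() {
        let x = Array2::from_shape_vec((6, 1), vec![0.0, 0.1, 0.2, 5.0, 5.1, 5.2]).unwrap();
        let y = Array1::from(vec![0, -1, -1, 1, -1, -1]);

        let model = LocalGlobalConsistency::new()
            .gamma(1.0)
            .fit(&x.view(), &y.view())
            .unwrap();

        // Unlabeled points should inherit the label of their own cluster.
        let pred = model.predict(&x.view()).unwrap();
        assert_eq!(pred[1], 0);
        assert_eq!(pred[2], 0);
        assert_eq!(pred[4], 1);
        assert_eq!(pred[5], 1);

        // The probability matrix has one column per observed class.
        let proba = model.predict_proba(&x.view()).unwrap();
        assert_eq!(proba.dim(), (6, 2));
    }
}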