sklears_semi_supervised/
local_global_consistency.rs

1//! Local and Global Consistency algorithm implementation
2
3use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2, Axis};
4use sklears_core::{
5    error::{Result as SklResult, SklearsError},
6    traits::{Estimator, Fit, Predict, PredictProba, Untrained},
7    types::Float,
8};
9
10/// Local and Global Consistency classifier
11///
12/// Local and Global Consistency is a graph-based semi-supervised learning method that
13/// finds a function which is smooth with respect to the intrinsic structure (local consistency)
14/// and close to the given labels (global consistency). It minimizes:
15/// min f^T * L * f + α * ||f - y||²
16///
17/// # Parameters
18///
19/// * `kernel` - Kernel function ('knn' or 'rbf')
20/// * `gamma` - Parameter for RBF kernel
21/// * `n_neighbors` - Number of neighbors for KNN kernel
22/// * `alpha` - Regularization parameter balancing smoothness and label fitting
23/// * `max_iter` - Maximum number of iterations for iterative solution
24/// * `tol` - Convergence tolerance
25///
26/// # Examples
27///
28/// ```
29/// use scirs2_core::array;
30/// use sklears_semi_supervised::LocalGlobalConsistency;
31/// use sklears_core::traits::{Predict, Fit};
32///
33///
34/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
35/// let y = array![0, 1, -1, -1]; // -1 indicates unlabeled
36///
37/// let lgc = LocalGlobalConsistency::new()
38///     .kernel("rbf".to_string())
39///     .gamma(20.0)
40///     .alpha(0.99);
41/// let fitted = lgc.fit(&X.view(), &y.view()).unwrap();
42/// let predictions = fitted.predict(&X.view()).unwrap();
43/// ```
#[derive(Debug, Clone)]
pub struct LocalGlobalConsistency<S = Untrained> {
    // Typestate marker: `Untrained` before `fit`, `LocalGlobalConsistencyTrained` after.
    state: S,
    // Affinity-graph kernel selector: "rbf" or "knn".
    kernel: String,
    // RBF kernel width; larger gamma makes similarity more localized.
    gamma: f64,
    // Number of nearest neighbors used when `kernel == "knn"`.
    n_neighbors: usize,
    // Propagation/clamping trade-off: weight of neighbor information vs. the
    // initial labels in the iterative update F = alpha*S*F + (1-alpha)*Y.
    alpha: f64,
    // Upper bound on propagation iterations.
    max_iter: usize,
    // L1 convergence threshold on successive label-distribution iterates.
    tol: f64,
}
54
55impl LocalGlobalConsistency<Untrained> {
56    /// Create a new LocalGlobalConsistency instance
57    pub fn new() -> Self {
58        Self {
59            state: Untrained,
60            kernel: "rbf".to_string(),
61            gamma: 20.0,
62            n_neighbors: 7,
63            alpha: 0.99,
64            max_iter: 1000,
65            tol: 1e-6,
66        }
67    }
68
69    /// Set the kernel function
70    pub fn kernel(mut self, kernel: String) -> Self {
71        self.kernel = kernel;
72        self
73    }
74
75    /// Set the gamma parameter for RBF kernel
76    pub fn gamma(mut self, gamma: f64) -> Self {
77        self.gamma = gamma;
78        self
79    }
80
81    /// Set the number of neighbors for KNN kernel
82    pub fn n_neighbors(mut self, n_neighbors: usize) -> Self {
83        self.n_neighbors = n_neighbors;
84        self
85    }
86
87    /// Set the regularization parameter
88    pub fn alpha(mut self, alpha: f64) -> Self {
89        self.alpha = alpha;
90        self
91    }
92
93    /// Set the maximum number of iterations
94    pub fn max_iter(mut self, max_iter: usize) -> Self {
95        self.max_iter = max_iter;
96        self
97    }
98
99    /// Set the convergence tolerance
100    pub fn tol(mut self, tol: f64) -> Self {
101        self.tol = tol;
102        self
103    }
104
105    fn build_affinity_matrix(&self, X: &Array2<f64>) -> SklResult<Array2<f64>> {
106        let n_samples = X.nrows();
107        let mut W = Array2::zeros((n_samples, n_samples));
108
109        match self.kernel.as_str() {
110            "rbf" => {
111                for i in 0..n_samples {
112                    for j in 0..n_samples {
113                        if i != j {
114                            let diff = &X.row(i) - &X.row(j);
115                            let dist_sq = diff.mapv(|x| x * x).sum();
116                            W[[i, j]] = (-self.gamma * dist_sq).exp();
117                        }
118                    }
119                }
120            }
121            "knn" => {
122                for i in 0..n_samples {
123                    let mut distances: Vec<(usize, f64)> = Vec::new();
124                    for j in 0..n_samples {
125                        if i != j {
126                            let diff = &X.row(i) - &X.row(j);
127                            let dist = diff.mapv(|x: f64| x * x).sum().sqrt();
128                            distances.push((j, dist));
129                        }
130                    }
131
132                    distances
133                        .sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
134
135                    for &(j, _) in distances.iter().take(self.n_neighbors) {
136                        W[[i, j]] = 1.0;
137                        W[[j, i]] = 1.0; // Make symmetric
138                    }
139                }
140            }
141            _ => {
142                return Err(SklearsError::InvalidInput(format!(
143                    "Unknown kernel: {}",
144                    self.kernel
145                )));
146            }
147        }
148
149        Ok(W)
150    }
151
152    #[allow(non_snake_case)]
153    fn solve_lgc_function(
154        &self,
155        W: &Array2<f64>,
156        labeled_indices: &[usize],
157        y: &Array1<i32>,
158        classes: &[i32],
159    ) -> SklResult<Array2<f64>> {
160        let n_samples = W.nrows();
161        let n_classes = classes.len();
162
163        // Compute degree matrix and normalized Laplacian
164        let D = W.sum_axis(Axis(1));
165        let mut D_sqrt_inv = Array2::zeros((n_samples, n_samples));
166        for i in 0..n_samples {
167            if D[i] > 0.0 {
168                D_sqrt_inv[[i, i]] = 1.0 / D[i].sqrt();
169            }
170        }
171
172        // Symmetric normalized Laplacian: L_sym = I - D^(-1/2) * W * D^(-1/2)
173        let S = D_sqrt_inv.dot(W).dot(&D_sqrt_inv);
174
175        // Initialize label matrix Y
176        let mut Y = Array2::zeros((n_samples, n_classes));
177        for &idx in labeled_indices {
178            if let Some(class_idx) = classes.iter().position(|&c| c == y[idx]) {
179                Y[[idx, class_idx]] = 1.0;
180            }
181        }
182
183        let Y_static = Y.clone();
184
185        // Iteratively solve: F = α * S * F + (1 - α) * Y
186        let mut prev_F = Y.clone();
187        for _iter in 0..self.max_iter {
188            Y = self.alpha * S.dot(&Y) + (1.0 - self.alpha) * &Y_static;
189
190            // Check convergence
191            let diff = (&Y - &prev_F).mapv(|x| x.abs()).sum();
192            if diff < self.tol {
193                break;
194            }
195            prev_F = Y.clone();
196        }
197
198        Ok(Y)
199    }
200}
201
202impl Default for LocalGlobalConsistency<Untrained> {
203    fn default() -> Self {
204        Self::new()
205    }
206}
207
impl Estimator for LocalGlobalConsistency<Untrained> {
    // This estimator carries its hyperparameters inline, so its Config is unit.
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        // `&()` is a reference to a promoted 'static unit value.
        &()
    }
}
217
218impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for LocalGlobalConsistency<Untrained> {
219    type Fitted = LocalGlobalConsistency<LocalGlobalConsistencyTrained>;
220
221    #[allow(non_snake_case)]
222    fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
223        let X = X.to_owned();
224        let y = y.to_owned();
225
226        let (n_samples, _n_features) = X.dim();
227
228        // Identify labeled and unlabeled samples
229        let mut labeled_indices = Vec::new();
230        let mut classes = std::collections::HashSet::new();
231
232        for (i, &label) in y.iter().enumerate() {
233            if label != -1 {
234                labeled_indices.push(i);
235                classes.insert(label);
236            }
237        }
238
239        if labeled_indices.is_empty() {
240            return Err(SklearsError::InvalidInput(
241                "No labeled samples provided".to_string(),
242            ));
243        }
244
245        let classes: Vec<i32> = classes.into_iter().collect();
246
247        // Build affinity matrix
248        let W = self.build_affinity_matrix(&X)?;
249
250        // Solve local global consistency function
251        let F = self.solve_lgc_function(&W, &labeled_indices, &y, &classes)?;
252
253        Ok(LocalGlobalConsistency {
254            state: LocalGlobalConsistencyTrained {
255                X_train: X.clone(),
256                y_train: y,
257                classes: Array1::from(classes),
258                label_distributions: F,
259                affinity_matrix: W,
260            },
261            kernel: self.kernel,
262            gamma: self.gamma,
263            n_neighbors: self.n_neighbors,
264            alpha: self.alpha,
265            max_iter: self.max_iter,
266            tol: self.tol,
267        })
268    }
269}
270
271impl Predict<ArrayView2<'_, Float>, Array1<i32>>
272    for LocalGlobalConsistency<LocalGlobalConsistencyTrained>
273{
274    #[allow(non_snake_case)]
275    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
276        let X = X.to_owned();
277        let n_test = X.nrows();
278        let mut predictions = Array1::zeros(n_test);
279
280        for i in 0..n_test {
281            // Find most similar training sample
282            let mut min_dist = f64::INFINITY;
283            let mut best_idx = 0;
284
285            for j in 0..self.state.X_train.nrows() {
286                let diff = &X.row(i) - &self.state.X_train.row(j);
287                let dist = diff.mapv(|x: f64| x * x).sum().sqrt();
288                if dist < min_dist {
289                    min_dist = dist;
290                    best_idx = j;
291                }
292            }
293
294            // Use the label distribution of the most similar sample
295            let distributions = self.state.label_distributions.row(best_idx);
296            let max_idx = distributions
297                .iter()
298                .enumerate()
299                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
300                .unwrap()
301                .0;
302
303            predictions[i] = self.state.classes[max_idx];
304        }
305
306        Ok(predictions)
307    }
308}
309
310impl PredictProba<ArrayView2<'_, Float>, Array2<f64>>
311    for LocalGlobalConsistency<LocalGlobalConsistencyTrained>
312{
313    #[allow(non_snake_case)]
314    fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
315        let X = X.to_owned();
316        let n_test = X.nrows();
317        let n_classes = self.state.classes.len();
318        let mut probas = Array2::zeros((n_test, n_classes));
319
320        for i in 0..n_test {
321            // Find most similar training sample
322            let mut min_dist = f64::INFINITY;
323            let mut best_idx = 0;
324
325            for j in 0..self.state.X_train.nrows() {
326                let diff = &X.row(i) - &self.state.X_train.row(j);
327                let dist = diff.mapv(|x: f64| x * x).sum().sqrt();
328                if dist < min_dist {
329                    min_dist = dist;
330                    best_idx = j;
331                }
332            }
333
334            // Copy the label distribution
335            for k in 0..n_classes {
336                probas[[i, k]] = self.state.label_distributions[[best_idx, k]];
337            }
338        }
339
340        Ok(probas)
341    }
342}
343
/// Trained state for LocalGlobalConsistency
#[derive(Debug, Clone)]
pub struct LocalGlobalConsistencyTrained {
    /// Training feature matrix, retained for nearest-neighbor prediction
    pub X_train: Array2<f64>,
    /// Training labels as given to `fit` (-1 marks unlabeled samples)
    pub y_train: Array1<i32>,
    /// Distinct observed class labels; defines the column order of
    /// `label_distributions`
    pub classes: Array1<i32>,
    /// Propagated per-sample class scores, shape (n_samples, n_classes)
    pub label_distributions: Array2<f64>,
    /// Pairwise affinity matrix built during fitting, shape (n_samples, n_samples)
    pub affinity_matrix: Array2<f64>,
}