// sklears_semi_supervised/label_spreading.rs

1//! Label Spreading algorithm implementation
2
3use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2, Axis};
4use sklears_core::{
5    error::{Result as SklResult, SklearsError},
6    traits::{Estimator, Fit, Predict, PredictProba, Untrained},
7    types::Float,
8};
9
/// Label Spreading classifier
///
/// Label spreading is similar to label propagation but uses a different
/// normalization of the graph Laplacian and includes a regularization parameter.
///
/// The type parameter `S` is a typestate marker: `Untrained` before `fit`,
/// `LabelSpreadingTrained` afterwards.
///
/// # Parameters
///
/// * `kernel` - Kernel function ('knn' or 'rbf')
/// * `gamma` - Parameter for RBF kernel
/// * `n_neighbors` - Number of neighbors for KNN kernel
/// * `alpha` - Clamping factor
/// * `max_iter` - Maximum number of iterations
/// * `tol` - Convergence tolerance
///
/// # Examples
///
/// ```
/// use scirs2_core::array;
/// use sklears_semi_supervised::LabelSpreading;
/// use sklears_core::traits::{Predict, Fit};
///
///
/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
/// let y = array![0, 1, -1, -1]; // -1 indicates unlabeled
///
/// let ls = LabelSpreading::new()
///     .kernel("rbf".to_string())
///     .gamma(20.0)
///     .alpha(0.2);
/// let fitted = ls.fit(&X.view(), &y.view()).unwrap();
/// let predictions = fitted.predict(&X.view()).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct LabelSpreading<S = Untrained> {
    // Typestate: Untrained or LabelSpreadingTrained (holds fitted data).
    state: S,
    // Kernel name; `build_affinity_matrix` accepts "rbf" or "knn".
    kernel: String,
    // Width parameter of the RBF kernel: W[i,j] = exp(-gamma * ||xi - xj||^2).
    gamma: f64,
    // Number of neighbors connected per sample when kernel == "knn".
    n_neighbors: usize,
    // Clamping factor: weight of propagated labels vs. the original one-hot labels.
    alpha: f64,
    // Upper bound on spreading iterations.
    max_iter: usize,
    // Convergence threshold on the total absolute change of the label matrix.
    tol: f64,
}
52
53impl LabelSpreading<Untrained> {
54    /// Create a new LabelSpreading instance
55    pub fn new() -> Self {
56        Self {
57            state: Untrained,
58            kernel: "rbf".to_string(),
59            gamma: 20.0,
60            n_neighbors: 7,
61            alpha: 0.2,
62            max_iter: 30,
63            tol: 1e-3,
64        }
65    }
66
67    /// Set the kernel function
68    pub fn kernel(mut self, kernel: String) -> Self {
69        self.kernel = kernel;
70        self
71    }
72
73    /// Set the gamma parameter for RBF kernel
74    pub fn gamma(mut self, gamma: f64) -> Self {
75        self.gamma = gamma;
76        self
77    }
78
79    /// Set the number of neighbors for KNN kernel
80    pub fn n_neighbors(mut self, n_neighbors: usize) -> Self {
81        self.n_neighbors = n_neighbors;
82        self
83    }
84
85    /// Set the clamping factor
86    pub fn alpha(mut self, alpha: f64) -> Self {
87        self.alpha = alpha;
88        self
89    }
90
91    /// Set the maximum number of iterations
92    pub fn max_iter(mut self, max_iter: usize) -> Self {
93        self.max_iter = max_iter;
94        self
95    }
96
97    /// Set the convergence tolerance
98    pub fn tol(mut self, tol: f64) -> Self {
99        self.tol = tol;
100        self
101    }
102}
103
104impl Default for LabelSpreading<Untrained> {
105    fn default() -> Self {
106        Self::new()
107    }
108}
109
impl Estimator for LabelSpreading<Untrained> {
    // This estimator carries no separate config object; hyperparameters live
    // directly on the struct, so the associated Config type is the unit type.
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        // `&()` refers to a zero-sized constant that is promoted to 'static,
        // so returning this reference is sound despite the local-looking expression.
        &()
    }
}
119
120impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for LabelSpreading<Untrained> {
121    type Fitted = LabelSpreading<LabelSpreadingTrained>;
122
123    #[allow(non_snake_case)]
124    fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
125        let X = X.to_owned();
126        let y = y.to_owned();
127
128        let (n_samples, _n_features) = X.dim();
129
130        // Identify labeled and unlabeled samples
131        let mut labeled_indices = Vec::new();
132        let mut unlabeled_indices = Vec::new();
133        let mut classes = std::collections::HashSet::new();
134
135        for (i, &label) in y.iter().enumerate() {
136            if label == -1 {
137                unlabeled_indices.push(i);
138            } else {
139                labeled_indices.push(i);
140                classes.insert(label);
141            }
142        }
143
144        if labeled_indices.is_empty() {
145            return Err(SklearsError::InvalidInput(
146                "No labeled samples provided".to_string(),
147            ));
148        }
149
150        let classes: Vec<i32> = classes.into_iter().collect();
151        let n_classes = classes.len();
152
153        // Build affinity matrix
154        let W = self.build_affinity_matrix(&X)?;
155
156        // Compute normalized Laplacian
157        let D = W.sum_axis(Axis(1));
158        let mut D_sqrt_inv = Array2::zeros((n_samples, n_samples));
159        for i in 0..n_samples {
160            if D[i] > 0.0 {
161                D_sqrt_inv[[i, i]] = 1.0 / D[i].sqrt();
162            }
163        }
164
165        let S = D_sqrt_inv.dot(&W).dot(&D_sqrt_inv);
166
167        // Initialize label matrix
168        let mut Y = Array2::zeros((n_samples, n_classes));
169        for &idx in &labeled_indices {
170            if let Some(class_idx) = classes.iter().position(|&c| c == y[idx]) {
171                Y[[idx, class_idx]] = 1.0;
172            }
173        }
174
175        let Y_static = Y.clone();
176
177        // Label spreading iterations
178        let mut prev_Y = Y.clone();
179        for _iter in 0..self.max_iter {
180            // Update: Y = alpha * S * Y + (1 - alpha) * Y_static
181            Y = self.alpha * S.dot(&Y) + (1.0 - self.alpha) * &Y_static;
182
183            // Check convergence
184            let diff = (&Y - &prev_Y).mapv(|x| x.abs()).sum();
185            if diff < self.tol {
186                break;
187            }
188            prev_Y = Y.clone();
189        }
190
191        Ok(LabelSpreading {
192            state: LabelSpreadingTrained {
193                X_train: X.clone(),
194                y_train: y,
195                classes: Array1::from(classes),
196                label_distributions: Y,
197                affinity_matrix: W,
198            },
199            kernel: self.kernel,
200            gamma: self.gamma,
201            n_neighbors: self.n_neighbors,
202            alpha: self.alpha,
203            max_iter: self.max_iter,
204            tol: self.tol,
205        })
206    }
207}
208
209impl LabelSpreading<Untrained> {
210    fn build_affinity_matrix(&self, X: &Array2<f64>) -> SklResult<Array2<f64>> {
211        let n_samples = X.nrows();
212        let mut W = Array2::zeros((n_samples, n_samples));
213
214        match self.kernel.as_str() {
215            "rbf" => {
216                for i in 0..n_samples {
217                    for j in 0..n_samples {
218                        if i != j {
219                            let diff = &X.row(i) - &X.row(j);
220                            let dist_sq = diff.mapv(|x| x * x).sum();
221                            W[[i, j]] = (-self.gamma * dist_sq).exp();
222                        }
223                    }
224                }
225            }
226            "knn" => {
227                for i in 0..n_samples {
228                    let mut distances: Vec<(usize, f64)> = Vec::new();
229                    for j in 0..n_samples {
230                        if i != j {
231                            let diff = &X.row(i) - &X.row(j);
232                            let dist = diff.mapv(|x: f64| x * x).sum().sqrt();
233                            distances.push((j, dist));
234                        }
235                    }
236
237                    distances
238                        .sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
239
240                    for &(j, _) in distances.iter().take(self.n_neighbors) {
241                        W[[i, j]] = 1.0;
242                        W[[j, i]] = 1.0; // Make symmetric
243                    }
244                }
245            }
246            _ => {
247                return Err(SklearsError::InvalidInput(format!(
248                    "Unknown kernel: {}",
249                    self.kernel
250                )));
251            }
252        }
253
254        Ok(W)
255    }
256}
257
258impl Predict<ArrayView2<'_, Float>, Array1<i32>> for LabelSpreading<LabelSpreadingTrained> {
259    #[allow(non_snake_case)]
260    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
261        let X = X.to_owned();
262        let n_test = X.nrows();
263        let mut predictions = Array1::zeros(n_test);
264
265        for i in 0..n_test {
266            // Find most similar training sample
267            let mut min_dist = f64::INFINITY;
268            let mut best_idx = 0;
269
270            for j in 0..self.state.X_train.nrows() {
271                let diff = &X.row(i) - &self.state.X_train.row(j);
272                let dist = diff.mapv(|x: f64| x * x).sum().sqrt();
273                if dist < min_dist {
274                    min_dist = dist;
275                    best_idx = j;
276                }
277            }
278
279            // Use the label distribution of the most similar sample
280            let distributions = self.state.label_distributions.row(best_idx);
281            let max_idx = distributions
282                .iter()
283                .enumerate()
284                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
285                .unwrap()
286                .0;
287
288            predictions[i] = self.state.classes[max_idx];
289        }
290
291        Ok(predictions)
292    }
293}
294
295impl PredictProba<ArrayView2<'_, Float>, Array2<f64>> for LabelSpreading<LabelSpreadingTrained> {
296    #[allow(non_snake_case)]
297    fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
298        let X = X.to_owned();
299        let n_test = X.nrows();
300        let n_classes = self.state.classes.len();
301        let mut probas = Array2::zeros((n_test, n_classes));
302
303        for i in 0..n_test {
304            // Find most similar training sample
305            let mut min_dist = f64::INFINITY;
306            let mut best_idx = 0;
307
308            for j in 0..self.state.X_train.nrows() {
309                let diff = &X.row(i) - &self.state.X_train.row(j);
310                let dist = diff.mapv(|x: f64| x * x).sum().sqrt();
311                if dist < min_dist {
312                    min_dist = dist;
313                    best_idx = j;
314                }
315            }
316
317            // Copy the label distribution
318            for k in 0..n_classes {
319                probas[[i, k]] = self.state.label_distributions[[best_idx, k]];
320            }
321        }
322
323        Ok(probas)
324    }
325}
326
/// Trained state for LabelSpreading
#[derive(Debug, Clone)]
pub struct LabelSpreadingTrained {
    /// Training feature matrix, kept so `predict`/`predict_proba` can do
    /// nearest-neighbor lookups against the training set.
    pub X_train: Array2<f64>,
    /// Original training labels as passed to `fit`; `-1` marks samples that
    /// were unlabeled.
    pub y_train: Array1<i32>,
    /// Distinct class labels; their order fixes the column order of
    /// `label_distributions` and of `predict_proba` output.
    pub classes: Array1<i32>,
    /// Label matrix after the spreading iterations
    /// (one row per training sample, one column per class).
    pub label_distributions: Array2<f64>,
    /// Affinity matrix `W` built by the configured kernel during `fit`.
    pub affinity_matrix: Array2<f64>,
}