sklears_semi_supervised/
label_propagation.rs

//! Label Propagation algorithm implementation
2
3use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2, Axis};
4use sklears_core::{
5    error::{Result as SklResult, SklearsError},
6    traits::{Estimator, Fit, Predict, PredictProba, Untrained},
7    types::Float,
8};
9use std::collections::BTreeSet;
10
/// Label Propagation classifier
///
/// Label propagation is a graph-based semi-supervised learning algorithm that
/// propagates labels through a neighborhood graph. It uses the graph Laplacian
/// to propagate labels from labeled to unlabeled points.
///
/// # Parameters
///
/// * `kernel` - Kernel function ('knn' or 'rbf')
/// * `gamma` - Parameter for RBF kernel
/// * `n_neighbors` - Number of neighbors for KNN kernel
/// * `alpha` - Clamping factor (0.0 = hard clamping, 1.0 = soft clamping)
/// * `max_iter` - Maximum number of iterations
/// * `tol` - Convergence tolerance
///
/// # Examples
///
/// ```
/// use scirs2_core::array;
/// use sklears_semi_supervised::LabelPropagation;
/// use sklears_core::traits::{Predict, Fit};
///
///
/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
/// let y = array![0, 1, -1, -1]; // -1 indicates unlabeled
///
/// let lp = LabelPropagation::new()
///     .kernel("rbf".to_string())
///     .gamma(20.0);
/// let fitted = lp.fit(&X.view(), &y.view()).unwrap();
/// let predictions = fitted.predict(&X.view()).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct LabelPropagation<S = Untrained> {
    // Type-state marker: `Untrained` before fit, `LabelPropagationTrained` after.
    state: S,
    // Kernel name ("rbf" or "knn"); validated when the affinity matrix is built.
    kernel: String,
    // Bandwidth for the RBF kernel: W[i,j] = exp(-gamma * ||xi - xj||^2).
    gamma: f64,
    // Neighbor count for the KNN graph kernel.
    n_neighbors: usize,
    // Clamping factor. NOTE(review): stored but not currently read by `fit`,
    // which always hard-clamps labeled rows — confirm intended behavior.
    alpha: f64,
    // Upper bound on propagation iterations.
    max_iter: usize,
    // L1 convergence threshold on the change in label distributions.
    tol: f64,
}
53
54impl LabelPropagation<Untrained> {
55    /// Create a new LabelPropagation instance
56    pub fn new() -> Self {
57        Self {
58            state: Untrained,
59            kernel: "rbf".to_string(),
60            gamma: 20.0,
61            n_neighbors: 7,
62            alpha: 1.0,
63            max_iter: 1000,
64            tol: 1e-3,
65        }
66    }
67
68    /// Set the kernel function
69    pub fn kernel(mut self, kernel: String) -> Self {
70        self.kernel = kernel;
71        self
72    }
73
74    /// Set the gamma parameter for RBF kernel
75    pub fn gamma(mut self, gamma: f64) -> Self {
76        self.gamma = gamma;
77        self
78    }
79
80    /// Set the number of neighbors for KNN kernel
81    pub fn n_neighbors(mut self, n_neighbors: usize) -> Self {
82        self.n_neighbors = n_neighbors;
83        self
84    }
85
86    /// Set the clamping factor
87    pub fn alpha(mut self, alpha: f64) -> Self {
88        self.alpha = alpha;
89        self
90    }
91
92    /// Set the maximum number of iterations
93    pub fn max_iter(mut self, max_iter: usize) -> Self {
94        self.max_iter = max_iter;
95        self
96    }
97
98    /// Set the convergence tolerance
99    pub fn tol(mut self, tol: f64) -> Self {
100        self.tol = tol;
101        self
102    }
103}
104
impl Default for LabelPropagation<Untrained> {
    /// Equivalent to [`LabelPropagation::new`].
    fn default() -> Self {
        Self::new()
    }
}
110
impl Estimator for LabelPropagation<Untrained> {
    // This estimator carries no separate config object.
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        // `&()` on a unit value is const-promoted to a 'static reference,
        // so returning it here is sound.
        &()
    }
}
120
121impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for LabelPropagation<Untrained> {
122    type Fitted = LabelPropagation<LabelPropagationTrained>;
123
124    #[allow(non_snake_case)]
125    fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
126        let X = X.to_owned();
127        let y = y.to_owned();
128
129        let (n_samples, _n_features) = X.dim();
130
131        // Identify labeled and unlabeled samples
132        let mut labeled_indices = Vec::new();
133        let mut unlabeled_indices = Vec::new();
134        let mut classes = BTreeSet::new();
135
136        for (i, &label) in y.iter().enumerate() {
137            if label == -1 {
138                unlabeled_indices.push(i);
139            } else {
140                labeled_indices.push(i);
141                classes.insert(label);
142            }
143        }
144
145        if labeled_indices.is_empty() {
146            return Err(SklearsError::InvalidInput(
147                "No labeled samples provided".to_string(),
148            ));
149        }
150
151        let classes: Vec<i32> = classes.into_iter().collect();
152        let n_classes = classes.len();
153
154        // Build affinity matrix
155        let W = self.build_affinity_matrix(&X)?;
156
157        // Normalize the matrix (row-wise normalization)
158        let D = W.sum_axis(Axis(1));
159        let mut P = Array2::zeros((n_samples, n_samples));
160        for i in 0..n_samples {
161            if D[i] > 0.0 {
162                for j in 0..n_samples {
163                    P[[i, j]] = W[[i, j]] / D[i];
164                }
165            }
166        }
167
168        // Initialize label distributions
169        let mut Y = Array2::zeros((n_samples, n_classes));
170        for &idx in &labeled_indices {
171            if let Some(class_idx) = classes.iter().position(|&c| c == y[idx]) {
172                Y[[idx, class_idx]] = 1.0;
173            }
174        }
175
176        // Label propagation iterations
177        let mut prev_Y = Y.clone();
178        for _iter in 0..self.max_iter {
179            // Propagate: Y = P * Y
180            Y = P.dot(&Y);
181
182            // Clamp labeled nodes
183            for &idx in &labeled_indices {
184                if let Some(class_idx) = classes.iter().position(|&c| c == y[idx]) {
185                    for k in 0..n_classes {
186                        Y[[idx, k]] = if k == class_idx { 1.0 } else { 0.0 };
187                    }
188                }
189            }
190
191            // Check convergence
192            let diff = (&Y - &prev_Y).mapv(|x| x.abs()).sum();
193            if diff < self.tol {
194                break;
195            }
196            prev_Y = Y.clone();
197        }
198
199        Ok(LabelPropagation {
200            state: LabelPropagationTrained {
201                X_train: X.clone(),
202                y_train: y,
203                classes: Array1::from(classes),
204                label_distributions: Y,
205                affinity_matrix: W,
206            },
207            kernel: self.kernel,
208            gamma: self.gamma,
209            n_neighbors: self.n_neighbors,
210            alpha: self.alpha,
211            max_iter: self.max_iter,
212            tol: self.tol,
213        })
214    }
215}
216
217impl LabelPropagation<Untrained> {
218    fn build_affinity_matrix(&self, X: &Array2<f64>) -> SklResult<Array2<f64>> {
219        let n_samples = X.nrows();
220        let mut W = Array2::zeros((n_samples, n_samples));
221
222        match self.kernel.as_str() {
223            "rbf" => {
224                for i in 0..n_samples {
225                    for j in 0..n_samples {
226                        if i != j {
227                            let diff = &X.row(i) - &X.row(j);
228                            let dist_sq = diff.mapv(|x| x * x).sum();
229                            W[[i, j]] = (-self.gamma * dist_sq).exp();
230                        }
231                    }
232                }
233            }
234            "knn" => {
235                for i in 0..n_samples {
236                    let mut distances: Vec<(usize, f64)> = Vec::new();
237                    for j in 0..n_samples {
238                        if i != j {
239                            let diff = &X.row(i) - &X.row(j);
240                            let dist = diff.mapv(|x: f64| x * x).sum().sqrt();
241                            distances.push((j, dist));
242                        }
243                    }
244
245                    distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
246
247                    for &(j, _) in distances.iter().take(self.n_neighbors) {
248                        W[[i, j]] = 1.0;
249                        W[[j, i]] = 1.0; // Make symmetric
250                    }
251                }
252            }
253            _ => {
254                return Err(SklearsError::InvalidInput(format!(
255                    "Unknown kernel: {}",
256                    self.kernel
257                )));
258            }
259        }
260
261        Ok(W)
262    }
263}
264
265impl Predict<ArrayView2<'_, Float>, Array1<i32>> for LabelPropagation<LabelPropagationTrained> {
266    #[allow(non_snake_case)]
267    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
268        let X = X.to_owned();
269        let n_test = X.nrows();
270        let mut predictions = Array1::zeros(n_test);
271
272        for i in 0..n_test {
273            // Find most similar training sample
274            let mut min_dist = f64::INFINITY;
275            let mut best_idx = 0;
276
277            for j in 0..self.state.X_train.nrows() {
278                let diff = &X.row(i) - &self.state.X_train.row(j);
279                let dist = diff.mapv(|x: f64| x * x).sum().sqrt();
280                if dist < min_dist {
281                    min_dist = dist;
282                    best_idx = j;
283                }
284            }
285
286            // Use the label distribution of the most similar sample
287            let distributions = self.state.label_distributions.row(best_idx);
288            let max_idx = distributions
289                .iter()
290                .enumerate()
291                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
292                .unwrap()
293                .0;
294
295            predictions[i] = self.state.classes[max_idx];
296        }
297
298        Ok(predictions)
299    }
300}
301
302impl PredictProba<ArrayView2<'_, Float>, Array2<f64>>
303    for LabelPropagation<LabelPropagationTrained>
304{
305    #[allow(non_snake_case)]
306    fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
307        let X = X.to_owned();
308        let n_test = X.nrows();
309        let n_classes = self.state.classes.len();
310        let mut probas = Array2::zeros((n_test, n_classes));
311
312        for i in 0..n_test {
313            // Find most similar training sample
314            let mut min_dist = f64::INFINITY;
315            let mut best_idx = 0;
316
317            for j in 0..self.state.X_train.nrows() {
318                let diff = &X.row(i) - &self.state.X_train.row(j);
319                let dist = diff.mapv(|x: f64| x * x).sum().sqrt();
320                if dist < min_dist {
321                    min_dist = dist;
322                    best_idx = j;
323                }
324            }
325
326            // Copy the label distribution
327            for k in 0..n_classes {
328                probas[[i, k]] = self.state.label_distributions[[best_idx, k]];
329            }
330        }
331
332        Ok(probas)
333    }
334}
335
/// Trained state for LabelPropagation
#[derive(Debug, Clone)]
pub struct LabelPropagationTrained {
    /// Training feature matrix (n_samples x n_features), retained for
    /// nearest-neighbor lookup at prediction time.
    pub X_train: Array2<f64>,
    /// Original training labels as passed to `fit` (-1 marks an unlabeled sample).
    pub y_train: Array1<i32>,
    /// Sorted, distinct class labels observed among the labeled samples.
    pub classes: Array1<i32>,
    /// Propagated label distributions: one row per training sample,
    /// one column per entry of `classes`.
    pub label_distributions: Array2<f64>,
    /// Pairwise affinity matrix computed during `fit`.
    pub affinity_matrix: Array2<f64>,
}