1use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2, Axis};
4use sklears_core::{
5 error::{Result as SklResult, SklearsError},
6 traits::{Estimator, Fit, Predict, PredictProba, Untrained},
7 types::Float,
8};
9use std::collections::BTreeSet;
10
/// Label propagation semi-supervised classifier.
///
/// Propagates labels from labeled samples (label != -1) to unlabeled samples
/// (label == -1) over an affinity graph built with the configured kernel.
/// `S` is a typestate marker: `Untrained` before `fit`, `LabelPropagationTrained` after.
#[derive(Debug, Clone)]
pub struct LabelPropagation<S = Untrained> {
    // Typestate payload: `Untrained` (unit) or the fitted model data.
    state: S,
    // Affinity kernel name; `build_affinity_matrix` accepts "rbf" or "knn".
    kernel: String,
    // RBF kernel bandwidth: W[i,j] = exp(-gamma * ||x_i - x_j||^2).
    gamma: f64,
    // Number of neighbors used by the "knn" kernel.
    n_neighbors: usize,
    // NOTE(review): stored via the builder but never read by `fit` —
    // the current implementation always hard-clamps labeled rows (alpha = 1).
    alpha: f64,
    // Maximum number of propagation iterations.
    max_iter: usize,
    // Convergence threshold on the L1 change of the label matrix.
    tol: f64,
}
53
54impl LabelPropagation<Untrained> {
55 pub fn new() -> Self {
57 Self {
58 state: Untrained,
59 kernel: "rbf".to_string(),
60 gamma: 20.0,
61 n_neighbors: 7,
62 alpha: 1.0,
63 max_iter: 1000,
64 tol: 1e-3,
65 }
66 }
67
68 pub fn kernel(mut self, kernel: String) -> Self {
70 self.kernel = kernel;
71 self
72 }
73
74 pub fn gamma(mut self, gamma: f64) -> Self {
76 self.gamma = gamma;
77 self
78 }
79
80 pub fn n_neighbors(mut self, n_neighbors: usize) -> Self {
82 self.n_neighbors = n_neighbors;
83 self
84 }
85
86 pub fn alpha(mut self, alpha: f64) -> Self {
88 self.alpha = alpha;
89 self
90 }
91
92 pub fn max_iter(mut self, max_iter: usize) -> Self {
94 self.max_iter = max_iter;
95 self
96 }
97
98 pub fn tol(mut self, tol: f64) -> Self {
100 self.tol = tol;
101 self
102 }
103}
104
impl Default for LabelPropagation<Untrained> {
    /// Equivalent to [`LabelPropagation::new`].
    fn default() -> Self {
        Self::new()
    }
}
110
impl Estimator for LabelPropagation<Untrained> {
    // This estimator carries no separate configuration object.
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        // `()` is a zero-sized type, so `&()` is promoted to a `'static`
        // reference; no dangling borrow is created here.
        &()
    }
}
120
121impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for LabelPropagation<Untrained> {
122 type Fitted = LabelPropagation<LabelPropagationTrained>;
123
124 #[allow(non_snake_case)]
125 fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
126 let X = X.to_owned();
127 let y = y.to_owned();
128
129 let (n_samples, _n_features) = X.dim();
130
131 let mut labeled_indices = Vec::new();
133 let mut unlabeled_indices = Vec::new();
134 let mut classes = BTreeSet::new();
135
136 for (i, &label) in y.iter().enumerate() {
137 if label == -1 {
138 unlabeled_indices.push(i);
139 } else {
140 labeled_indices.push(i);
141 classes.insert(label);
142 }
143 }
144
145 if labeled_indices.is_empty() {
146 return Err(SklearsError::InvalidInput(
147 "No labeled samples provided".to_string(),
148 ));
149 }
150
151 let classes: Vec<i32> = classes.into_iter().collect();
152 let n_classes = classes.len();
153
154 let W = self.build_affinity_matrix(&X)?;
156
157 let D = W.sum_axis(Axis(1));
159 let mut P = Array2::zeros((n_samples, n_samples));
160 for i in 0..n_samples {
161 if D[i] > 0.0 {
162 for j in 0..n_samples {
163 P[[i, j]] = W[[i, j]] / D[i];
164 }
165 }
166 }
167
168 let mut Y = Array2::zeros((n_samples, n_classes));
170 for &idx in &labeled_indices {
171 if let Some(class_idx) = classes.iter().position(|&c| c == y[idx]) {
172 Y[[idx, class_idx]] = 1.0;
173 }
174 }
175
176 let mut prev_Y = Y.clone();
178 for _iter in 0..self.max_iter {
179 Y = P.dot(&Y);
181
182 for &idx in &labeled_indices {
184 if let Some(class_idx) = classes.iter().position(|&c| c == y[idx]) {
185 for k in 0..n_classes {
186 Y[[idx, k]] = if k == class_idx { 1.0 } else { 0.0 };
187 }
188 }
189 }
190
191 let diff = (&Y - &prev_Y).mapv(|x| x.abs()).sum();
193 if diff < self.tol {
194 break;
195 }
196 prev_Y = Y.clone();
197 }
198
199 Ok(LabelPropagation {
200 state: LabelPropagationTrained {
201 X_train: X.clone(),
202 y_train: y,
203 classes: Array1::from(classes),
204 label_distributions: Y,
205 affinity_matrix: W,
206 },
207 kernel: self.kernel,
208 gamma: self.gamma,
209 n_neighbors: self.n_neighbors,
210 alpha: self.alpha,
211 max_iter: self.max_iter,
212 tol: self.tol,
213 })
214 }
215}
216
217impl LabelPropagation<Untrained> {
218 fn build_affinity_matrix(&self, X: &Array2<f64>) -> SklResult<Array2<f64>> {
219 let n_samples = X.nrows();
220 let mut W = Array2::zeros((n_samples, n_samples));
221
222 match self.kernel.as_str() {
223 "rbf" => {
224 for i in 0..n_samples {
225 for j in 0..n_samples {
226 if i != j {
227 let diff = &X.row(i) - &X.row(j);
228 let dist_sq = diff.mapv(|x| x * x).sum();
229 W[[i, j]] = (-self.gamma * dist_sq).exp();
230 }
231 }
232 }
233 }
234 "knn" => {
235 for i in 0..n_samples {
236 let mut distances: Vec<(usize, f64)> = Vec::new();
237 for j in 0..n_samples {
238 if i != j {
239 let diff = &X.row(i) - &X.row(j);
240 let dist = diff.mapv(|x: f64| x * x).sum().sqrt();
241 distances.push((j, dist));
242 }
243 }
244
245 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
246
247 for &(j, _) in distances.iter().take(self.n_neighbors) {
248 W[[i, j]] = 1.0;
249 W[[j, i]] = 1.0; }
251 }
252 }
253 _ => {
254 return Err(SklearsError::InvalidInput(format!(
255 "Unknown kernel: {}",
256 self.kernel
257 )));
258 }
259 }
260
261 Ok(W)
262 }
263}
264
265impl Predict<ArrayView2<'_, Float>, Array1<i32>> for LabelPropagation<LabelPropagationTrained> {
266 #[allow(non_snake_case)]
267 fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
268 let X = X.to_owned();
269 let n_test = X.nrows();
270 let mut predictions = Array1::zeros(n_test);
271
272 for i in 0..n_test {
273 let mut min_dist = f64::INFINITY;
275 let mut best_idx = 0;
276
277 for j in 0..self.state.X_train.nrows() {
278 let diff = &X.row(i) - &self.state.X_train.row(j);
279 let dist = diff.mapv(|x: f64| x * x).sum().sqrt();
280 if dist < min_dist {
281 min_dist = dist;
282 best_idx = j;
283 }
284 }
285
286 let distributions = self.state.label_distributions.row(best_idx);
288 let max_idx = distributions
289 .iter()
290 .enumerate()
291 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
292 .unwrap()
293 .0;
294
295 predictions[i] = self.state.classes[max_idx];
296 }
297
298 Ok(predictions)
299 }
300}
301
302impl PredictProba<ArrayView2<'_, Float>, Array2<f64>>
303 for LabelPropagation<LabelPropagationTrained>
304{
305 #[allow(non_snake_case)]
306 fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
307 let X = X.to_owned();
308 let n_test = X.nrows();
309 let n_classes = self.state.classes.len();
310 let mut probas = Array2::zeros((n_test, n_classes));
311
312 for i in 0..n_test {
313 let mut min_dist = f64::INFINITY;
315 let mut best_idx = 0;
316
317 for j in 0..self.state.X_train.nrows() {
318 let diff = &X.row(i) - &self.state.X_train.row(j);
319 let dist = diff.mapv(|x: f64| x * x).sum().sqrt();
320 if dist < min_dist {
321 min_dist = dist;
322 best_idx = j;
323 }
324 }
325
326 for k in 0..n_classes {
328 probas[[i, k]] = self.state.label_distributions[[best_idx, k]];
329 }
330 }
331
332 Ok(probas)
333 }
334}
335
/// Fitted-state payload for [`LabelPropagation`].
#[derive(Debug, Clone)]
pub struct LabelPropagationTrained {
    // Training feature matrix, kept for nearest-neighbor prediction.
    pub X_train: Array2<f64>,
    // Original training labels (`-1` marks unlabeled samples).
    pub y_train: Array1<i32>,
    // Sorted distinct class labels observed during fitting.
    pub classes: Array1<i32>,
    // Propagated label scores, shape (n_samples, n_classes).
    pub label_distributions: Array2<f64>,
    // Affinity matrix W built during fitting.
    pub affinity_matrix: Array2<f64>,
}