use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2, Axis};
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Predict, PredictProba, Untrained},
    types::Float,
};

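/// Semi-supervised classifier that spreads labels over a similarity graph.
///
/// An affinity matrix is built over all training samples (RBF or
/// k-nearest-neighbours kernel), symmetrically normalized, and label
/// distributions are propagated through it iteratively. Labeled samples are
/// softly clamped back toward their original labels with weight `1 - alpha`;
/// unlabeled samples are marked with `-1` in the target vector passed to
/// `fit`.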
#[derive(Debug, Clone)]
pub struct LabelSpreading<S = Untrained> {
    state: S,
    /// Affinity kernel: "rbf" or "knn".
    kernel: String,
    /// Width parameter of the RBF kernel.
    gamma: f64,
    /// Number of neighbours used by the "knn" kernel.
    n_neighbors: usize,
    /// Clamping factor: weight given to propagated label information.
    alpha: f64,
    /// Maximum number of propagation iterations.
    max_iter: usize,
    /// Convergence tolerance on the change in label distributions.
    tol: f64,
}

impl LabelSpreading<Untrained> {
    /// Create a new untrained estimator with default hyperparameters.
    pub fn new() -> Self {
        Self {
            state: Untrained,
            kernel: "rbf".to_string(),
            gamma: 20.0,
            n_neighbors: 7,
            alpha: 0.2,
            max_iter: 30,
            tol: 1e-3,
        }
    }

    /// Set the affinity kernel: "rbf" or "knn".
    pub fn kernel(mut self, kernel: String) -> Self {
        self.kernel = kernel;
        self
    }

    /// Set the width parameter of the RBF kernel.
    pub fn gamma(mut self, gamma: f64) -> Self {
        self.gamma = gamma;
        self
    }

    /// Set the number of neighbours used by the "knn" kernel.
    pub fn n_neighbors(mut self, n_neighbors: usize) -> Self {
        self.n_neighbors = n_neighbors;
        self
    }

    /// Set the clamping factor `alpha` (typically in (0, 1)).
    pub fn alpha(mut self, alpha: f64) -> Self {
        self.alpha = alpha;
        self
    }

    /// Set the maximum number of propagation iterations.
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set the convergence tolerance.
    pub fn tol(mut self, tol: f64) -> Self {
        self.tol = tol;
        self
    }
}

impl Default for LabelSpreading<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for LabelSpreading<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

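// Fitting follows the standard label-spreading iteration:
//   1. build the affinity matrix `W` over all samples,
//   2. symmetrically normalize it: S = D^{-1/2} W D^{-1/2},
//   3. iterate Y <- alpha * S * Y + (1 - alpha) * Y0 until the change in Y
//      falls below `tol` or `max_iter` is reached.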
impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for LabelSpreading<Untrained> {
    type Fitted = LabelSpreading<LabelSpreadingTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let y = y.to_owned();

        let (n_samples, _n_features) = X.dim();

        // Split samples into labeled (label >= 0) and unlabeled (label == -1).
        let mut labeled_indices = Vec::new();
        let mut unlabeled_indices = Vec::new();
        let mut classes = std::collections::HashSet::new();

        for (i, &label) in y.iter().enumerate() {
            if label == -1 {
                unlabeled_indices.push(i);
            } else {
                labeled_indices.push(i);
                classes.insert(label);
            }
        }

        if labeled_indices.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No labeled samples provided".to_string(),
            ));
        }

        // Sort the classes so the column order of the label distributions is
        // deterministic (HashSet iteration order is not).
        let mut classes: Vec<i32> = classes.into_iter().collect();
        classes.sort_unstable();
        let n_classes = classes.len();

        // Build the affinity matrix W.
        let W = self.build_affinity_matrix(&X)?;

        // Symmetric normalization: S = D^{-1/2} W D^{-1/2}.
        let D = W.sum_axis(Axis(1));
        let mut D_sqrt_inv = Array2::zeros((n_samples, n_samples));
        for i in 0..n_samples {
            if D[i] > 0.0 {
                D_sqrt_inv[[i, i]] = 1.0 / D[i].sqrt();
            }
        }

        let S = D_sqrt_inv.dot(&W).dot(&D_sqrt_inv);

        // Initialize the label distribution matrix: one-hot rows for labeled
        // samples, zero rows for unlabeled ones.
        let mut Y = Array2::zeros((n_samples, n_classes));
        for &idx in &labeled_indices {
            if let Some(class_idx) = classes.iter().position(|&c| c == y[idx]) {
                Y[[idx, class_idx]] = 1.0;
            }
        }

        let Y_static = Y.clone();

        // Propagate: Y <- alpha * S * Y + (1 - alpha) * Y0, clamping labeled
        // samples back toward their original distributions.
        let mut prev_Y = Y.clone();
        for _iter in 0..self.max_iter {
            Y = self.alpha * S.dot(&Y) + (1.0 - self.alpha) * &Y_static;

            // Stop once the total absolute change falls below the tolerance.
            let diff = (&Y - &prev_Y).mapv(|x| x.abs()).sum();
            if diff < self.tol {
                break;
            }
            prev_Y = Y.clone();
        }

        Ok(LabelSpreading {
            state: LabelSpreadingTrained {
                X_train: X,
                y_train: y,
                classes: Array1::from(classes),
                label_distributions: Y,
                affinity_matrix: W,
            },
            kernel: self.kernel,
            gamma: self.gamma,
            n_neighbors: self.n_neighbors,
            alpha: self.alpha,
            max_iter: self.max_iter,
            tol: self.tol,
        })
    }
}

impl LabelSpreading<Untrained> {
    /// Build the affinity matrix for the training data.
    ///
    /// For the "rbf" kernel, `W[[i, j]] = exp(-gamma * ||x_i - x_j||^2)`;
    /// for the "knn" kernel, `W[[i, j]] = 1.0` if `j` is among the
    /// `n_neighbors` nearest neighbours of `i` (symmetrized), else `0.0`.
    #[allow(non_snake_case)]
    fn build_affinity_matrix(&self, X: &Array2<f64>) -> SklResult<Array2<f64>> {
        let n_samples = X.nrows();
        let mut W = Array2::zeros((n_samples, n_samples));

        match self.kernel.as_str() {
            "rbf" => {
                for i in 0..n_samples {
                    for j in 0..n_samples {
                        if i != j {
                            let diff = &X.row(i) - &X.row(j);
                            let dist_sq = diff.mapv(|x| x * x).sum();
                            W[[i, j]] = (-self.gamma * dist_sq).exp();
                        }
                    }
                }
            }
            "knn" => {
                for i in 0..n_samples {
                    // Compute distances from sample `i` to every other sample.
                    let mut distances: Vec<(usize, f64)> = Vec::new();
                    for j in 0..n_samples {
                        if i != j {
                            let diff = &X.row(i) - &X.row(j);
                            let dist = diff.mapv(|x: f64| x * x).sum().sqrt();
                            distances.push((j, dist));
                        }
                    }

                    distances
                        .sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));

                    // Connect `i` to its nearest neighbours and keep W symmetric.
                    for &(j, _) in distances.iter().take(self.n_neighbors) {
                        W[[i, j]] = 1.0;
                        W[[j, i]] = 1.0;
                    }
                }
            }
            _ => {
                return Err(SklearsError::InvalidInput(format!(
                    "Unknown kernel: {}",
                    self.kernel
                )));
            }
        }

        Ok(W)
    }
}

impl Predict<ArrayView2<'_, Float>, Array1<i32>> for LabelSpreading<LabelSpreadingTrained> {
    /// Predict class labels for `X` by assigning each query point the class
    /// with the highest learned label distribution at its nearest training
    /// sample (1-nearest-neighbour lookup in Euclidean distance).
    #[allow(non_snake_case)]
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
        let X = X.to_owned();
        let n_test = X.nrows();
        let mut predictions = Array1::zeros(n_test);

        for i in 0..n_test {
            // Find the nearest training sample to the query point.
            let mut min_dist = f64::INFINITY;
            let mut best_idx = 0;

            for j in 0..self.state.X_train.nrows() {
                let diff = &X.row(i) - &self.state.X_train.row(j);
                let dist = diff.mapv(|x: f64| x * x).sum().sqrt();
                if dist < min_dist {
                    min_dist = dist;
                    best_idx = j;
                }
            }

            // Pick the class with the largest propagated label mass.
            let distributions = self.state.label_distributions.row(best_idx);
            let max_idx = distributions
                .iter()
                .enumerate()
                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
                .unwrap()
                .0;

            predictions[i] = self.state.classes[max_idx];
        }

        Ok(predictions)
    }
}

impl PredictProba<ArrayView2<'_, Float>, Array2<f64>> for LabelSpreading<LabelSpreadingTrained> {
    /// Return the propagated label distribution of each query point's nearest
    /// training sample, one column per known class. Rows are the raw
    /// distributions and are not re-normalized to sum to one.
    #[allow(non_snake_case)]
    fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
        let X = X.to_owned();
        let n_test = X.nrows();
        let n_classes = self.state.classes.len();
        let mut probas = Array2::zeros((n_test, n_classes));

        for i in 0..n_test {
            // Find the nearest training sample to the query point.
            let mut min_dist = f64::INFINITY;
            let mut best_idx = 0;

            for j in 0..self.state.X_train.nrows() {
                let diff = &X.row(i) - &self.state.X_train.row(j);
                let dist = diff.mapv(|x: f64| x * x).sum().sqrt();
                if dist < min_dist {
                    min_dist = dist;
                    best_idx = j;
                }
            }

            // Copy that sample's label distribution as the class scores.
            for k in 0..n_classes {
                probas[[i, k]] = self.state.label_distributions[[best_idx, k]];
            }
        }

        Ok(probas)
    }
}

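/// Trained state produced by fitting a [`LabelSpreading`] model, holding the
/// training data and the label distributions learned by propagation.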
#[derive(Debug, Clone)]
#[allow(non_snake_case)]
pub struct LabelSpreadingTrained {
    /// Training feature matrix.
    pub X_train: Array2<f64>,
    /// Original training labels (`-1` = unlabeled).
    pub y_train: Array1<i32>,
    /// Distinct class labels, in the column order of `label_distributions`.
    pub classes: Array1<i32>,
    /// Propagated label distribution per training sample and class.
    pub label_distributions: Array2<f64>,
    /// Affinity matrix built over the training samples.
    pub affinity_matrix: Array2<f64>,
}
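
// A minimal usage sketch, kept as a unit test. Assumptions not guaranteed by
// this file: `Float` is `f64`, and `Array2::from_shape_vec` / `Array1::from`
// are re-exported by `scirs2_core::ndarray_ext` with their usual `ndarray`
// signatures. Adjust the constructors if those assumptions do not hold.
#[cfg(test)]
mod tests {
    use super::*;
    // Explicit trait imports so `fit`, `predict`, and `predict_proba` resolve.
    use sklears_core::traits::{Fit, Predict, PredictProba};

    #[test]
    fn spreads_labels_to_unlabeled_points() {
        // Two well-separated clusters; one labeled point per cluster,
        // the other two points unlabeled (-1).
        let x = Array2::from_shape_vec(
            (4, 2),
            vec![0.0, 0.0, 0.1, 0.1, 5.0, 5.0, 5.1, 5.1],
        )
        .unwrap();
        let y = Array1::from(vec![0, -1, 1, -1]);

        let model = LabelSpreading::new()
            .kernel("rbf".to_string())
            .gamma(1.0)
            .fit(&x.view(), &y.view())
            .unwrap();

        // One predicted label per training sample.
        let predictions = model.predict(&x.view()).unwrap();
        assert_eq!(predictions.len(), 4);

        // One score per known class (here: 0 and 1) for each query point.
        let probas = model.predict_proba(&x.view()).unwrap();
        assert_eq!(probas.dim(), (4, 2));
    }
}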