1use crate::graph::{graph_laplacian, knn_graph};
8use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2};
9use sklears_core::{
10 error::{Result as SklResult, SklearsError},
11 traits::{Estimator, Fit, Predict, Untrained},
12 types::Float,
13};
14use std::collections::HashSet;
15
/// Semi-supervised classifier based on manifold regularization.
///
/// Type parameter `S` encodes the training state: `Untrained` before
/// fitting, `ManifoldRegularizationTrained` afterwards.
#[derive(Debug, Clone)]
pub struct ManifoldRegularization<S = Untrained> {
    // Training state marker (Untrained) or learned parameters.
    state: S,
    // Ambient regularization weight (ridge term added to the diagonal of the system matrix).
    lambda_a: f64,
    // Intrinsic (manifold) regularization weight applied to the graph Laplacian term.
    lambda_i: f64,
    // Kernel name: "rbf", "linear", or "polynomial".
    kernel: String,
    // Kernel coefficient used by the "rbf" and "polynomial" kernels.
    gamma: f64,
    // Degree of the "polynomial" kernel.
    degree: usize,
    // Graph construction method: "knn" or "rbf".
    graph_kernel: String,
    // Number of neighbors when graph_kernel == "knn".
    n_neighbors: usize,
    // Maximum number of iterations for the iterative linear solver.
    max_iter: usize,
    // Convergence tolerance: solver stops when the L1 change in coefficients drops below this.
    tol: f64,
}
75
76impl ManifoldRegularization<Untrained> {
77 pub fn new() -> Self {
79 Self {
80 state: Untrained,
81 lambda_a: 0.01,
82 lambda_i: 0.1,
83 kernel: "rbf".to_string(),
84 gamma: 1.0,
85 degree: 3,
86 graph_kernel: "knn".to_string(),
87 n_neighbors: 7,
88 max_iter: 1000,
89 tol: 1e-6,
90 }
91 }
92
93 pub fn lambda_a(mut self, lambda_a: f64) -> Self {
95 self.lambda_a = lambda_a;
96 self
97 }
98
99 pub fn lambda_i(mut self, lambda_i: f64) -> Self {
101 self.lambda_i = lambda_i;
102 self
103 }
104
105 pub fn kernel(mut self, kernel: String) -> Self {
107 self.kernel = kernel;
108 self
109 }
110
111 pub fn gamma(mut self, gamma: f64) -> Self {
113 self.gamma = gamma;
114 self
115 }
116
117 pub fn degree(mut self, degree: usize) -> Self {
119 self.degree = degree;
120 self
121 }
122
123 pub fn graph_kernel(mut self, graph_kernel: String) -> Self {
125 self.graph_kernel = graph_kernel;
126 self
127 }
128
129 pub fn n_neighbors(mut self, n_neighbors: usize) -> Self {
131 self.n_neighbors = n_neighbors;
132 self
133 }
134
135 pub fn max_iter(mut self, max_iter: usize) -> Self {
137 self.max_iter = max_iter;
138 self
139 }
140
141 pub fn tol(mut self, tol: f64) -> Self {
143 self.tol = tol;
144 self
145 }
146
147 fn compute_kernel_matrix(&self, X1: &Array2<f64>, X2: &Array2<f64>) -> SklResult<Array2<f64>> {
148 let n1 = X1.nrows();
149 let n2 = X2.nrows();
150 let mut K = Array2::zeros((n1, n2));
151
152 match self.kernel.as_str() {
153 "rbf" => {
154 for i in 0..n1 {
155 for j in 0..n2 {
156 let diff = &X1.row(i) - &X2.row(j);
157 let dist_sq = diff.mapv(|x| x * x).sum();
158 K[[i, j]] = (-self.gamma * dist_sq).exp();
159 }
160 }
161 }
162 "linear" => {
163 K = X1.dot(&X2.t());
164 }
165 "polynomial" => {
166 let linear_kernel = X1.dot(&X2.t());
167 for i in 0..n1 {
168 for j in 0..n2 {
169 K[[i, j]] =
170 (self.gamma * linear_kernel[[i, j]] + 1.0).powi(self.degree as i32);
171 }
172 }
173 }
174 _ => {
175 return Err(SklearsError::InvalidInput(format!(
176 "Unknown kernel: {}",
177 self.kernel
178 )));
179 }
180 }
181
182 Ok(K)
183 }
184
185 fn build_manifold_graph(&self, X: &Array2<f64>) -> SklResult<Array2<f64>> {
186 match self.graph_kernel.as_str() {
187 "knn" => knn_graph(X, self.n_neighbors, "connectivity"),
188 "rbf" => {
189 let n_samples = X.nrows();
190 let mut W = Array2::zeros((n_samples, n_samples));
191 for i in 0..n_samples {
192 for j in 0..n_samples {
193 if i != j {
194 let diff = &X.row(i) - &X.row(j);
195 let dist_sq = diff.mapv(|x| x * x).sum();
196 W[[i, j]] = (-self.gamma * dist_sq).exp();
197 }
198 }
199 }
200 Ok(W)
201 }
202 _ => Err(SklearsError::InvalidInput(format!(
203 "Unknown graph kernel: {}",
204 self.graph_kernel
205 ))),
206 }
207 }
208
209 fn solve_manifold_regularized_problem(
210 &self,
211 K: &Array2<f64>,
212 L: &Array2<f64>,
213 labeled_indices: &[usize],
214 y_labeled: &Array1<i32>,
215 classes: &[i32],
216 ) -> SklResult<Array2<f64>> {
217 let n_samples = K.nrows();
218 let n_classes = classes.len();
219 let n_labeled = labeled_indices.len();
220
221 let mut Y_l = Array2::zeros((n_labeled, n_classes));
223 for (i, &idx) in labeled_indices.iter().enumerate() {
224 if let Some(class_idx) = classes.iter().position(|&c| c == y_labeled[i]) {
225 Y_l[[i, class_idx]] = 1.0;
226 }
227 }
228
229 let mut K_ll = Array2::zeros((n_labeled, n_labeled));
231 for (i, &idx_i) in labeled_indices.iter().enumerate() {
232 for (j, &idx_j) in labeled_indices.iter().enumerate() {
233 K_ll[[i, j]] = K[[idx_i, idx_j]];
234 }
235 }
236
237 let mut L_ll = Array2::zeros((n_labeled, n_labeled));
240 for (i, &idx_i) in labeled_indices.iter().enumerate() {
241 for (j, &idx_j) in labeled_indices.iter().enumerate() {
242 L_ll[[i, j]] = L[[idx_i, idx_j]];
243 }
244 }
245
246 let mut A = K_ll.clone();
248
249 for i in 0..n_labeled {
251 A[[i, i]] += self.lambda_a;
252 }
253
254 A = A + self.lambda_i * &L_ll;
256
257 let mut alpha = Array2::zeros((n_labeled, n_classes));
259
260 for _iter in 0..self.max_iter {
262 let mut new_alpha = Array2::zeros((n_labeled, n_classes));
263
264 for i in 0..n_labeled {
265 for k in 0..n_classes {
266 let mut sum = Y_l[[i, k]];
267 for j in 0..n_labeled {
268 if i != j {
269 sum -= A[[i, j]] * alpha[[j, k]];
270 }
271 }
272 new_alpha[[i, k]] = sum / A[[i, i]];
273 }
274 }
275
276 let diff = (&new_alpha - &alpha).mapv(|x| x.abs()).sum();
278 alpha = new_alpha;
279
280 if diff < self.tol {
281 break;
282 }
283 }
284
285 let mut F = Array2::zeros((n_samples, n_classes));
287
288 for i in 0..n_samples {
289 for k in 0..n_classes {
290 let mut sum = 0.0;
291 for (j, &labeled_idx) in labeled_indices.iter().enumerate() {
292 sum += K[[i, labeled_idx]] * alpha[[j, k]];
293 }
294 F[[i, k]] = sum;
295 }
296 }
297
298 Ok(F)
299 }
300}
301
302impl Default for ManifoldRegularization<Untrained> {
303 fn default() -> Self {
304 Self::new()
305 }
306}
307
impl Estimator for ManifoldRegularization<Untrained> {
    // Hyperparameters live directly on the struct (builder-style setters),
    // so the associated config type is unit.
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        // `&()` is a valid reference to the unit value.
        &()
    }
}
317
318impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for ManifoldRegularization<Untrained> {
319 type Fitted = ManifoldRegularization<ManifoldRegularizationTrained>;
320
321 #[allow(non_snake_case)]
322 fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
323 let X = X.to_owned();
324 let y = y.to_owned();
325
326 let mut labeled_indices = Vec::new();
328 let mut y_labeled = Vec::new();
329 let mut classes = HashSet::new();
330
331 for (i, &label) in y.iter().enumerate() {
332 if label != -1 {
333 labeled_indices.push(i);
334 y_labeled.push(label);
335 classes.insert(label);
336 }
337 }
338
339 if labeled_indices.is_empty() {
340 return Err(SklearsError::InvalidInput(
341 "No labeled samples provided".to_string(),
342 ));
343 }
344
345 let classes: Vec<i32> = classes.into_iter().collect();
346 let y_labeled = Array1::from(y_labeled);
347
348 let K = self.compute_kernel_matrix(&X, &X)?;
350
351 let W = self.build_manifold_graph(&X)?;
353 let L = graph_laplacian(&W, false)?;
354
355 let F = self.solve_manifold_regularized_problem(
357 &K,
358 &L,
359 &labeled_indices,
360 &y_labeled,
361 &classes,
362 )?;
363
364 Ok(ManifoldRegularization {
365 state: ManifoldRegularizationTrained {
366 X_train: X.clone(),
367 y_train: y,
368 classes: Array1::from(classes),
369 alpha: F,
370 kernel_matrix: K,
371 manifold_laplacian: L,
372 labeled_indices,
373 },
374 lambda_a: self.lambda_a,
375 lambda_i: self.lambda_i,
376 kernel: self.kernel,
377 gamma: self.gamma,
378 degree: self.degree,
379 graph_kernel: self.graph_kernel,
380 n_neighbors: self.n_neighbors,
381 max_iter: self.max_iter,
382 tol: self.tol,
383 })
384 }
385}
386
387impl Predict<ArrayView2<'_, Float>, Array1<i32>>
388 for ManifoldRegularization<ManifoldRegularizationTrained>
389{
390 #[allow(non_snake_case)]
391 fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
392 let X = X.to_owned();
393 let n_test = X.nrows();
394 let mut predictions = Array1::zeros(n_test);
395
396 let K_test = self.compute_kernel_matrix(&X, &self.state.X_train)?;
398
399 for i in 0..n_test {
400 let mut max_value = f64::NEG_INFINITY;
402 let mut best_class_idx = 0;
403
404 for k in 0..self.state.classes.len() {
405 let mut value = 0.0;
406 for j in 0..self.state.X_train.nrows() {
407 value += K_test[[i, j]] * self.state.alpha[[j, k]];
408 }
409
410 if value > max_value {
411 max_value = value;
412 best_class_idx = k;
413 }
414 }
415
416 predictions[i] = self.state.classes[best_class_idx];
417 }
418
419 Ok(predictions)
420 }
421}
422
423impl ManifoldRegularization<ManifoldRegularizationTrained> {
424 fn compute_kernel_matrix(&self, X1: &Array2<f64>, X2: &Array2<f64>) -> SklResult<Array2<f64>> {
425 let n1 = X1.nrows();
426 let n2 = X2.nrows();
427 let mut K = Array2::zeros((n1, n2));
428
429 match self.kernel.as_str() {
430 "rbf" => {
431 for i in 0..n1 {
432 for j in 0..n2 {
433 let diff = &X1.row(i) - &X2.row(j);
434 let dist_sq = diff.mapv(|x| x * x).sum();
435 K[[i, j]] = (-self.gamma * dist_sq).exp();
436 }
437 }
438 }
439 "linear" => {
440 K = X1.dot(&X2.t());
441 }
442 "polynomial" => {
443 let linear_kernel = X1.dot(&X2.t());
444 for i in 0..n1 {
445 for j in 0..n2 {
446 K[[i, j]] =
447 (self.gamma * linear_kernel[[i, j]] + 1.0).powi(self.degree as i32);
448 }
449 }
450 }
451 _ => {
452 return Err(SklearsError::InvalidInput(format!(
453 "Unknown kernel: {}",
454 self.kernel
455 )));
456 }
457 }
458
459 Ok(K)
460 }
461}
462
/// Learned state of a fitted [`ManifoldRegularization`] model.
#[derive(Debug, Clone)]
pub struct ManifoldRegularizationTrained {
    /// Training feature matrix (all samples, labeled and unlabeled).
    pub X_train: Array2<f64>,
    /// Original training labels; `-1` marks unlabeled samples.
    pub y_train: Array1<i32>,
    /// Distinct class labels observed among the labeled samples.
    pub classes: Array1<i32>,
    /// Decision values / expansion coefficients, one column per class,
    /// one row per training sample.
    pub alpha: Array2<f64>,
    /// Gram matrix of the training data under the configured kernel.
    pub kernel_matrix: Array2<f64>,
    /// Graph Laplacian of the manifold adjacency graph.
    pub manifold_laplacian: Array2<f64>,
    /// Row indices of the labeled training samples.
    pub labeled_indices: Vec<usize>,
}