1use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2};
4use scirs2_core::random::Random;
5use sklears_core::{
6 error::{Result as SklResult, SklearsError},
7 traits::{Estimator, Fit, Predict, Untrained},
8 types::Float,
9};
10use std::collections::{HashMap, HashSet};
11
12#[derive(Debug, Clone)]
41pub struct TriTraining<S = Untrained> {
42 state: S,
43 max_iter: usize,
44 verbose: bool,
45 theta: f64,
46}
47
48impl TriTraining<Untrained> {
49 pub fn new() -> Self {
51 Self {
52 state: Untrained,
53 max_iter: 30,
54 verbose: false,
55 theta: 0.1,
56 }
57 }
58
59 pub fn max_iter(mut self, max_iter: usize) -> Self {
61 self.max_iter = max_iter;
62 self
63 }
64
65 pub fn verbose(mut self, verbose: bool) -> Self {
67 self.verbose = verbose;
68 self
69 }
70
71 pub fn theta(mut self, theta: f64) -> Self {
73 self.theta = theta;
74 self
75 }
76
77 fn bootstrap_sample(
78 &self,
79 X: &Array2<f64>,
80 y: &Array1<i32>,
81 labeled_indices: &[usize],
82 ) -> (Array2<f64>, Array1<i32>) {
83 let n_labeled = labeled_indices.len();
84 let mut bootstrap_X = Array2::zeros((n_labeled, X.ncols()));
85 let mut bootstrap_y = Array1::zeros(n_labeled);
86
87 let mut rng = Random::seed(42);
89 for i in 0..n_labeled {
90 let random_idx = rng.gen_range(0..n_labeled);
91 let idx = labeled_indices[random_idx];
92 bootstrap_X.row_mut(i).assign(&X.row(idx));
93 bootstrap_y[i] = y[idx];
94 }
95
96 (bootstrap_X, bootstrap_y)
97 }
98
99 fn simple_classifier_fit_predict(
100 &self,
101 X_train: &Array2<f64>,
102 y_train: &Array1<i32>,
103 X_test: &Array2<f64>,
104 ) -> Array1<i32> {
105 let n_test = X_test.nrows();
106 let mut predictions = Array1::zeros(n_test);
107
108 for i in 0..n_test {
109 let mut distances: Vec<(f64, i32)> = Vec::new();
111 for j in 0..X_train.nrows() {
112 let diff = &X_test.row(i) - &X_train.row(j);
113 let dist = diff.mapv(|x| x * x).sum().sqrt();
114 distances.push((dist, y_train[j]));
115 }
116
117 distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
118
119 let k = distances.len().clamp(1, 5);
121 let mut class_votes: HashMap<i32, usize> = HashMap::new();
122
123 for &(_, label) in distances.iter().take(k) {
124 *class_votes.entry(label).or_insert(0) += 1;
125 }
126
127 let best_class = class_votes
128 .iter()
129 .max_by_key(|(_, &count)| count)
130 .map(|(&class, _)| class)
131 .unwrap_or(y_train[0]);
132
133 predictions[i] = best_class;
134 }
135
136 predictions
137 }
138
139 fn estimate_error_rate(
140 &self,
141 classifier_i: &Array2<f64>,
142 y_i: &Array1<i32>,
143 classifier_j: &Array2<f64>,
144 y_j: &Array1<i32>,
145 X_labeled: &Array2<f64>,
146 y_labeled: &Array1<i32>,
147 ) -> f64 {
148 let n_labeled = X_labeled.nrows();
149 let mut errors = 0;
150 let mut total = 0;
151
152 for k in 0..n_labeled {
153 let test_sample = X_labeled
154 .row(k)
155 .to_owned()
156 .insert_axis(scirs2_core::ndarray::Axis(0));
157 let pred_i = self.simple_classifier_fit_predict(classifier_i, y_i, &test_sample);
158 let pred_j = self.simple_classifier_fit_predict(classifier_j, y_j, &test_sample);
159
160 if pred_i[0] == pred_j[0] {
161 total += 1;
162 if pred_i[0] != y_labeled[k] {
163 errors += 1;
164 }
165 }
166 }
167
168 if total > 0 {
169 errors as f64 / total as f64
170 } else {
171 1.0 }
173 }
174}
175
176impl Default for TriTraining<Untrained> {
177 fn default() -> Self {
178 Self::new()
179 }
180}
181
182impl Estimator for TriTraining<Untrained> {
183 type Config = ();
184 type Error = SklearsError;
185 type Float = Float;
186
187 fn config(&self) -> &Self::Config {
188 &()
189 }
190}
191
192impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for TriTraining<Untrained> {
193 type Fitted = TriTraining<TriTrainingTrained>;
194
195 #[allow(non_snake_case)]
196 fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
197 let X = X.to_owned();
198 let mut y = y.to_owned();
199
200 let mut labeled_indices: Vec<usize> = y
202 .iter()
203 .enumerate()
204 .filter(|(_, &label)| label != -1)
205 .map(|(i, _)| i)
206 .collect();
207
208 let mut unlabeled_indices: Vec<usize> = y
209 .iter()
210 .enumerate()
211 .filter(|(_, &label)| label == -1)
212 .map(|(i, _)| i)
213 .collect();
214
215 if labeled_indices.is_empty() {
216 return Err(SklearsError::InvalidInput(
217 "No labeled samples provided".to_string(),
218 ));
219 }
220
221 let mut classes = HashSet::new();
222 for &idx in &labeled_indices {
223 classes.insert(y[idx]);
224 }
225 let classes: Vec<i32> = classes.into_iter().collect();
226
227 let mut classifiers: Vec<(Array2<f64>, Array1<i32>)> = Vec::new();
229 for _ in 0..3 {
230 let (bootstrap_X, bootstrap_y) = self.bootstrap_sample(&X, &y, &labeled_indices);
231 classifiers.push((bootstrap_X, bootstrap_y));
232 }
233
234 let mut e_prime = [0.5; 3]; let mut l_prime = [0; 3]; for iter in 0..self.max_iter {
239 let mut any_changes = false;
240
241 for i in 0..3 {
242 let j = (i + 1) % 3;
243 let k = (i + 2) % 3;
244
245 let X_labeled_i: Vec<Vec<f64>> = labeled_indices
247 .iter()
248 .map(|&idx| X.row(idx).to_vec())
249 .collect();
250 let y_labeled_i: Vec<i32> = labeled_indices.iter().map(|&idx| y[idx]).collect();
251
252 let X_labeled_array = Array2::from_shape_vec(
253 (X_labeled_i.len(), X.ncols()),
254 X_labeled_i.into_iter().flatten().collect(),
255 )
256 .map_err(|_| {
257 SklearsError::InvalidInput("Failed to create labeled training data".to_string())
258 })?;
259
260 let y_labeled_array = Array1::from(y_labeled_i);
261
262 let e_jk = self.estimate_error_rate(
264 &classifiers[j].0,
265 &classifiers[j].1,
266 &classifiers[k].0,
267 &classifiers[k].1,
268 &X_labeled_array,
269 &y_labeled_array,
270 );
271
272 if e_jk < e_prime[i] && e_jk < self.theta {
273 if !unlabeled_indices.is_empty() {
275 let X_unlabeled: Vec<Vec<f64>> = unlabeled_indices
276 .iter()
277 .map(|&idx| X.row(idx).to_vec())
278 .collect();
279
280 let X_unlabeled_array = Array2::from_shape_vec(
281 (X_unlabeled.len(), X.ncols()),
282 X_unlabeled.into_iter().flatten().collect(),
283 )
284 .map_err(|_| {
285 SklearsError::InvalidInput(
286 "Failed to create unlabeled data".to_string(),
287 )
288 })?;
289
290 let pred_j = self.simple_classifier_fit_predict(
291 &classifiers[j].0,
292 &classifiers[j].1,
293 &X_unlabeled_array,
294 );
295 let pred_k = self.simple_classifier_fit_predict(
296 &classifiers[k].0,
297 &classifiers[k].1,
298 &X_unlabeled_array,
299 );
300
301 let mut new_labeled_for_i = Vec::new();
303 for (idx, (&p_j, &p_k)) in pred_j.iter().zip(pred_k.iter()).enumerate() {
304 if p_j == p_k {
305 let original_idx = unlabeled_indices[idx];
306 new_labeled_for_i.push((original_idx, p_j));
307 }
308 }
309
310 if !new_labeled_for_i.is_empty() {
311 for (idx, label) in new_labeled_for_i {
313 y[idx] = label;
314 labeled_indices.push(idx);
315 any_changes = true;
316 }
317
318 unlabeled_indices.retain(|&idx| y[idx] == -1);
320
321 let (new_bootstrap_X, new_bootstrap_y) =
323 self.bootstrap_sample(&X, &y, &labeled_indices);
324 classifiers[i] = (new_bootstrap_X, new_bootstrap_y);
325
326 e_prime[i] = e_jk;
327 l_prime[i] = unlabeled_indices.len();
328 }
329 }
330 }
331 }
332
333 if !any_changes {
334 if self.verbose {
335 println!("Iteration {}: No changes, stopping", iter + 1);
336 }
337 break;
338 }
339
340 if self.verbose {
341 let n_labeled = labeled_indices.len();
342 let n_unlabeled = unlabeled_indices.len();
343 println!(
344 "Iteration {}: {} labeled, {} unlabeled",
345 iter + 1,
346 n_labeled,
347 n_unlabeled
348 );
349 }
350
351 if unlabeled_indices.is_empty() {
352 if self.verbose {
353 println!("All samples labeled, stopping");
354 }
355 break;
356 }
357 }
358
359 Ok(TriTraining {
360 state: TriTrainingTrained {
361 X_train: X.clone(),
362 y_train: y,
363 classes: Array1::from(classes),
364 classifiers,
365 },
366 max_iter: self.max_iter,
367 verbose: self.verbose,
368 theta: self.theta,
369 })
370 }
371}
372
373impl Predict<ArrayView2<'_, Float>, Array1<i32>> for TriTraining<TriTrainingTrained> {
374 #[allow(non_snake_case)]
375 fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
376 let X = X.to_owned();
377 let n_test = X.nrows();
378 let mut predictions = Array1::zeros(n_test);
379
380 for i in 0..n_test {
382 let test_sample = X
383 .row(i)
384 .to_owned()
385 .insert_axis(scirs2_core::ndarray::Axis(0));
386 let mut votes: HashMap<i32, usize> = HashMap::new();
387
388 for (classifier_X, classifier_y) in &self.state.classifiers {
390 let pred = TriTraining::<TriTrainingTrained>::simple_classifier_fit_predict_static(
391 classifier_X,
392 classifier_y,
393 &test_sample,
394 );
395 *votes.entry(pred[0]).or_insert(0) += 1;
396 }
397
398 let best_class = votes
400 .iter()
401 .max_by_key(|(_, &count)| count)
402 .map(|(&class, _)| class)
403 .unwrap_or(self.state.classes[0]);
404
405 predictions[i] = best_class;
406 }
407
408 Ok(predictions)
409 }
410}
411
412impl TriTraining<TriTrainingTrained> {
413 fn simple_classifier_fit_predict_static(
415 X_train: &Array2<f64>,
416 y_train: &Array1<i32>,
417 X_test: &Array2<f64>,
418 ) -> Array1<i32> {
419 let n_test = X_test.nrows();
420 let mut predictions = Array1::zeros(n_test);
421
422 for i in 0..n_test {
423 let mut distances: Vec<(f64, i32)> = Vec::new();
425 for j in 0..X_train.nrows() {
426 let diff = &X_test.row(i) - &X_train.row(j);
427 let dist = diff.mapv(|x| x * x).sum().sqrt();
428 distances.push((dist, y_train[j]));
429 }
430
431 distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
432
433 let k = distances.len().clamp(1, 5);
435 let mut class_votes: HashMap<i32, usize> = HashMap::new();
436
437 for &(_, label) in distances.iter().take(k) {
438 *class_votes.entry(label).or_insert(0) += 1;
439 }
440
441 let best_class = class_votes
442 .iter()
443 .max_by_key(|(_, &count)| count)
444 .map(|(&class, _)| class)
445 .unwrap_or(y_train[0]);
446
447 predictions[i] = best_class;
448 }
449
450 predictions
451 }
452}
453
454#[derive(Debug, Clone)]
456pub struct TriTrainingTrained {
457 pub X_train: Array2<f64>,
459 pub y_train: Array1<i32>,
461 pub classes: Array1<i32>,
463 pub classifiers: Vec<(Array2<f64>, Array1<i32>)>,
465}