// sklears_semi_supervised/few_shot/prototypical_networks.rs

use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2, Axis};
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Predict, PredictProba, Untrained},
    types::Float,
};
use std::collections::HashMap;

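/// Prototypical networks for few-shot classification (Snell et al., 2017).
///
/// An embedding network maps samples into a metric space in which each class
/// is summarized by the mean ("prototype") of its embedded support examples;
/// queries are classified by a softmax over negative distances to the
/// prototypes. The label `-1` marks unlabeled samples, which are ignored
/// during episodic training.
///
/// A minimal usage sketch (hypothetical call site; assumes `X` and `X_test`
/// are `Array2<f64>` and `y` is `Array1<i32>`):
///
/// ```ignore
/// let model = PrototypicalNetworks::new()
///     .n_way(3)
///     .n_shot(5)
///     .fit(&X.view(), &y.view())?;
/// let predictions = model.predict(&X_test.view())?;
/// ```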
#[derive(Debug, Clone)]
pub struct PrototypicalNetworks<S = Untrained> {
    state: S,
    /// Output dimensionality of the embedding network.
    embedding_dim: usize,
    /// Hidden layer sizes of the embedding network.
    hidden_layers: Vec<usize>,
    /// Distance metric: "euclidean", "cosine", or "manhattan".
    distance_metric: String,
    /// Learning rate for the episodic updates.
    learning_rate: f64,
    /// Number of training episodes.
    n_episodes: usize,
    /// Number of classes per episode.
    n_way: usize,
    /// Number of support samples per class.
    n_shot: usize,
    /// Number of query samples per class.
    n_query: usize,
    /// Softmax temperature for converting distances to probabilities.
    temperature: f64,
}

impl PrototypicalNetworks<Untrained> {
    /// Create an untrained model with defaults: 64-dimensional embeddings,
    /// two hidden layers, Euclidean distance, and 5-way 1-shot episodes
    /// with 15 query samples per class.
    pub fn new() -> Self {
        Self {
            state: Untrained,
            embedding_dim: 64,
            hidden_layers: vec![128, 64],
            distance_metric: "euclidean".to_string(),
            learning_rate: 0.001,
            n_episodes: 100,
            n_way: 5,
            n_shot: 1,
            n_query: 15,
            temperature: 1.0,
        }
    }

    /// Set the output dimensionality of the embedding network.
    pub fn embedding_dim(mut self, embedding_dim: usize) -> Self {
        self.embedding_dim = embedding_dim;
        self
    }

    /// Set the hidden layer sizes of the embedding network.
    pub fn hidden_layers(mut self, hidden_layers: Vec<usize>) -> Self {
        self.hidden_layers = hidden_layers;
        self
    }

    /// Set the distance metric: "euclidean", "cosine", or "manhattan".
    pub fn distance_metric(mut self, metric: String) -> Self {
        self.distance_metric = metric;
        self
    }

    /// Set the learning rate for the episodic updates.
    pub fn learning_rate(mut self, learning_rate: f64) -> Self {
        self.learning_rate = learning_rate;
        self
    }

    /// Set the number of training episodes.
    pub fn n_episodes(mut self, n_episodes: usize) -> Self {
        self.n_episodes = n_episodes;
        self
    }

    /// Set the number of classes per episode.
    pub fn n_way(mut self, n_way: usize) -> Self {
        self.n_way = n_way;
        self
    }

    /// Set the number of support samples per class.
    pub fn n_shot(mut self, n_shot: usize) -> Self {
        self.n_shot = n_shot;
        self
    }

    /// Set the number of query samples per class.
    pub fn n_query(mut self, n_query: usize) -> Self {
        self.n_query = n_query;
        self
    }

    /// Set the softmax temperature used when converting distances
    /// to probabilities.
    pub fn temperature(mut self, temperature: f64) -> Self {
        self.temperature = temperature;
        self
    }
}

// Embedding and episode utilities, shared by the untrained and trained
// states so the fitted model can reuse the forward pass at prediction time.
impl<S> PrototypicalNetworks<S> {
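    /// Forward pass through the fully connected embedding network: affine
    /// layers with a ReLU activation on every layer except the output.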
    #[allow(non_snake_case)]
    fn compute_embedding(
        &self,
        X: &Array2<f64>,
        weights: &[Array2<f64>],
        biases: &[Array1<f64>],
    ) -> Array2<f64> {
        let mut current = X.clone();

        for (i, (w, b)) in weights.iter().zip(biases.iter()).enumerate() {
            current = current.dot(w);

            // Add the layer bias to every row.
            for mut row in current.axis_iter_mut(Axis(0)) {
                for (j, &bias_val) in b.iter().enumerate() {
                    row[j] += bias_val;
                }
            }

            // ReLU activation on every layer except the last.
            if i < weights.len() - 1 {
                current.mapv_inplace(|x| x.max(0.0));
            }
        }

        current
    }

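    /// Distance between two embedded vectors under the configured metric;
    /// unrecognized metric names fall back to Euclidean distance.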
    fn compute_distance(&self, a: &Array1<f64>, b: &Array1<f64>) -> f64 {
        match self.distance_metric.as_str() {
            "euclidean" => {
                let diff = a - b;
                diff.mapv(|x| x * x).sum().sqrt()
            }
            "cosine" => {
                let dot_product = a.dot(b);
                let norm_a = a.mapv(|x| x * x).sum().sqrt();
                let norm_b = b.mapv(|x| x * x).sum().sqrt();
                1.0 - (dot_product / (norm_a * norm_b))
            }
            "manhattan" => {
                let diff = a - b;
                diff.mapv(|x| x.abs()).sum()
            }
            _ => {
                // Unknown metric name: fall back to Euclidean.
                let diff = a - b;
                diff.mapv(|x| x * x).sum().sqrt()
            }
        }
    }

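    /// Compute one prototype per class as the per-dimension mean of that
    /// class's support embeddings.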
    fn compute_prototypes(
        &self,
        support_embeddings: &Array2<f64>,
        support_labels: &Array1<i32>,
        classes: &[i32],
    ) -> Array2<f64> {
        let n_classes = classes.len();
        let embedding_dim = support_embeddings.ncols();
        let mut prototypes = Array2::zeros((n_classes, embedding_dim));

        for (class_idx, &class_label) in classes.iter().enumerate() {
            // Gather the embeddings belonging to this class.
            let mut class_embeddings = Vec::new();
            for (sample_idx, &label) in support_labels.iter().enumerate() {
                if label == class_label {
                    class_embeddings.push(support_embeddings.row(sample_idx).to_owned());
                }
            }

            // The prototype is the per-dimension mean of those embeddings.
            if !class_embeddings.is_empty() {
                for dim in 0..embedding_dim {
                    let mean_val: f64 = class_embeddings.iter().map(|emb| emb[dim]).sum::<f64>()
                        / class_embeddings.len() as f64;
                    prototypes[[class_idx, dim]] = mean_val;
                }
            }
        }

        prototypes
    }

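    /// Convert distances to class probabilities: p_k ∝ exp(-d_k / T),
    /// computed with max-subtraction for numerical stability.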
    fn softmax_distances(&self, distances: &Array1<f64>) -> Array1<f64> {
        // Smaller distance means higher probability; the temperature
        // controls how sharp the distribution is.
        let scaled_distances = distances.mapv(|d| -d / self.temperature);

        // Subtract the maximum before exponentiating to avoid overflow.
        let max_dist = scaled_distances
            .iter()
            .fold(f64::NEG_INFINITY, |acc, &x| acc.max(x));
        let exp_distances = scaled_distances.mapv(|d| (d - max_dist).exp());
        let sum_exp = exp_distances.sum();

        exp_distances.mapv(|x| x / sum_exp)
    }

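    /// Assemble one episode: an `n_way * n_shot` support set and an
    /// `n_way * n_query` query set, with labels remapped to the episode-local
    /// range `0..n_way`. Selection is deterministic (the first samples of
    /// each class); a randomized sampler would be the natural extension.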
    #[allow(clippy::type_complexity)]
    #[allow(non_snake_case)]
    fn sample_episode(
        &self,
        X: &Array2<f64>,
        y: &Array1<i32>,
        classes: &[i32],
    ) -> SklResult<(Array2<f64>, Array1<i32>, Array2<f64>, Array1<i32>)> {
        let n_features = X.ncols();

        // Index samples by class label.
        let mut class_samples: HashMap<i32, Vec<usize>> = HashMap::new();
        for (i, &label) in y.iter().enumerate() {
            class_samples.entry(label).or_default().push(i);
        }

        // Every requested class must have enough samples for the episode.
        for &class_label in classes {
            if let Some(samples) = class_samples.get(&class_label) {
                if samples.len() < self.n_shot + self.n_query {
                    return Err(SklearsError::InvalidInput(format!(
                        "Not enough samples for class {}: need {}, have {}",
                        class_label,
                        self.n_shot + self.n_query,
                        samples.len()
                    )));
                }
            } else {
                return Err(SklearsError::InvalidInput(format!(
                    "Class {} not found in data",
                    class_label
                )));
            }
        }

        let total_support = self.n_way * self.n_shot;
        let total_query = self.n_way * self.n_query;

        let mut support_X = Array2::zeros((total_support, n_features));
        let mut support_y = Array1::zeros(total_support);
        let mut query_X = Array2::zeros((total_query, n_features));
        let mut query_y = Array1::zeros(total_query);

        let mut support_idx = 0;
        let mut query_idx = 0;

        for (class_idx, &class_label) in classes.iter().take(self.n_way).enumerate() {
            if let Some(samples) = class_samples.get(&class_label) {
                // Deterministic selection: the first n_shot + n_query samples.
                let selected_samples: Vec<usize> = samples
                    .iter()
                    .take(self.n_shot + self.n_query)
                    .cloned()
                    .collect();

                // The first n_shot samples form the support set; labels are
                // remapped to the episode-local range 0..n_way.
                #[allow(clippy::needless_range_loop)]
                for i in 0..self.n_shot {
                    let sample_idx = selected_samples[i];
                    support_X.row_mut(support_idx).assign(&X.row(sample_idx));
                    support_y[support_idx] = class_idx as i32;
                    support_idx += 1;
                }

                // The remaining samples form the query set.
                #[allow(clippy::needless_range_loop)]
                for i in self.n_shot..self.n_shot + self.n_query {
                    let sample_idx = selected_samples[i];
                    query_X.row_mut(query_idx).assign(&X.row(sample_idx));
                    query_y[query_idx] = class_idx as i32;
                    query_idx += 1;
                }
            }
        }

        Ok((support_X, support_y, query_X, query_y))
    }
}

impl Default for PrototypicalNetworks<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for PrototypicalNetworks<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

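// Episodic training: each episode draws a support/query split, computes class
// prototypes from the support embeddings, and scores the queries against
// them. The weight update below is a simplified surrogate applied to the
// final layer only, not full backpropagation through the network.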
impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for PrototypicalNetworks<Untrained> {
    type Fitted = PrototypicalNetworks<PrototypicalNetworksTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let y = y.to_owned();

        let n_features = X.ncols();

        // Collect the distinct labeled classes; -1 marks unlabeled samples.
        let mut classes = std::collections::HashSet::new();
        for &label in y.iter() {
            if label != -1 {
                classes.insert(label);
            }
        }
        // Sort so the class order (and thus the prototype order) is
        // deterministic; HashSet iteration order is not.
        let mut classes: Vec<i32> = classes.into_iter().collect();
        classes.sort_unstable();

        if classes.len() < self.n_way {
            return Err(SklearsError::InvalidInput(format!(
                "Need at least {} classes for {}-way classification, found {}",
                self.n_way,
                self.n_way,
                classes.len()
            )));
        }

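        // Build the layer sizes and initialize Xavier/Glorot-scaled weights.
        // A deterministic sin-based pattern stands in for a random number
        // generator; biases start at zero.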
        let mut layer_sizes = vec![n_features];
        layer_sizes.extend(&self.hidden_layers);
        layer_sizes.push(self.embedding_dim);

        let mut weights = Vec::new();
        let mut biases = Vec::new();

        for layer in 0..layer_sizes.len() - 1 {
            let in_size = layer_sizes[layer];
            let out_size = layer_sizes[layer + 1];

            let scale = (2.0 / (in_size + out_size) as f64).sqrt();
            let mut w = Array2::zeros((in_size, out_size));
            let b = Array1::zeros(out_size);

            for r in 0..in_size {
                for c in 0..out_size {
                    w[[r, c]] = scale * ((r + c) as f64 * 0.1).sin();
                }
            }

            weights.push(w);
            biases.push(b);
        }

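        // Episodic training loop. Episode classes are the first n_way labels
        // in sorted order rather than a randomly sampled subset.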
        for episode in 0..self.n_episodes {
            let episode_classes: Vec<i32> = classes.iter().take(self.n_way).cloned().collect();

            let (support_X, support_y, query_X, query_y) =
                self.sample_episode(&X, &y, &episode_classes)?;

            let support_embeddings = self.compute_embedding(&support_X, &weights, &biases);
            let query_embeddings = self.compute_embedding(&query_X, &weights, &biases);

            // sample_episode remaps labels to 0..n_way, so the episode's
            // class indices are simply 0..n_way.
            let episode_class_indices: Vec<i32> = (0..self.n_way as i32).collect();
            let prototypes =
                self.compute_prototypes(&support_embeddings, &support_y, &episode_class_indices);

            let n_query_samples = query_embeddings.nrows();
            let mut total_loss = 0.0;

            for query_idx in 0..n_query_samples {
                let query_embedding = query_embeddings.row(query_idx);
                let true_class = query_y[query_idx] as usize;

                if true_class >= self.n_way {
                    continue;
                }

                // Distance from this query to every class prototype.
                let mut distances = Array1::zeros(self.n_way);
                for class_idx in 0..self.n_way {
                    let prototype = prototypes.row(class_idx);
                    distances[class_idx] =
                        self.compute_distance(&query_embedding.to_owned(), &prototype.to_owned());
                }

                let probabilities = self.softmax_distances(&distances);

                // Cross-entropy on the true class, clamped for stability.
                let prob = probabilities[true_class].max(1e-10);
                total_loss -= prob.ln();

                // Learning rate decays over episodes.
                let lr = self.learning_rate / (episode + 1) as f64;

                // Crude surrogate update: shift only the last layer's weights
                // and biases along (p_true - 1), scaled by the raw inputs.
                if let (Some(last_w), Some(last_b)) = (weights.last_mut(), biases.last_mut()) {
                    let max_features = query_X.ncols().min(last_w.nrows());
                    for i in 0..max_features {
                        for j in 0..last_w.ncols() {
                            let grad_w =
                                (probabilities[true_class] - 1.0) * query_X[[query_idx, i]];
                            last_w[[i, j]] -= lr * grad_w;
                        }
                    }

                    for j in 0..last_b.len() {
                        let grad_b = probabilities[true_class] - 1.0;
                        last_b[j] -= lr * grad_b;
                    }
                }
            }

            // Average episode loss, kept as a hook for optional logging.
            if episode % 20 == 0 {
                let avg_loss = total_loss / n_query_samples as f64;
                let _ = avg_loss;
            }
        }

        // Compute the final class prototypes from every labeled sample so
        // that prediction can compare test embeddings against them directly.
        let labeled: Vec<usize> = (0..y.len()).filter(|&i| y[i] != -1).collect();
        let mut labeled_X = Array2::zeros((labeled.len(), n_features));
        let mut labeled_y = Array1::zeros(labeled.len());
        for (row, &idx) in labeled.iter().enumerate() {
            labeled_X.row_mut(row).assign(&X.row(idx));
            labeled_y[row] = y[idx];
        }
        let embeddings = self.compute_embedding(&labeled_X, &weights, &biases);
        let prototypes = self.compute_prototypes(&embeddings, &labeled_y, &classes);

        Ok(PrototypicalNetworks {
            state: PrototypicalNetworksTrained {
                weights,
                biases,
                classes: Array1::from(classes),
                prototypes,
            },
            embedding_dim: self.embedding_dim,
            hidden_layers: self.hidden_layers,
            distance_metric: self.distance_metric,
            learning_rate: self.learning_rate,
            n_episodes: self.n_episodes,
            n_way: self.n_way,
            n_shot: self.n_shot,
            n_query: self.n_query,
            temperature: self.temperature,
        })
    }
}

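// Class prediction: the argmax of the per-class probabilities.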
impl Predict<ArrayView2<'_, Float>, Array1<i32>>
    for PrototypicalNetworks<PrototypicalNetworksTrained>
{
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
        let probabilities = self.predict_proba(X)?;
        let n_test = X.nrows();
        let mut predictions = Array1::zeros(n_test);

        for i in 0..n_test {
            let max_idx = probabilities
                .row(i)
                .iter()
                .enumerate()
                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                .unwrap()
                .0;
            predictions[i] = self.state.classes[max_idx];
        }

        Ok(predictions)
    }
}

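// Class probabilities: a temperature-scaled softmax over negative distances
// between each test embedding and the stored class prototypes.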
impl PredictProba<ArrayView2<'_, Float>, Array2<f64>>
    for PrototypicalNetworks<PrototypicalNetworksTrained>
{
    #[allow(non_snake_case)]
    fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
        let X = X.to_owned();
        let n_test = X.nrows();
        let n_classes = self.state.classes.len();

        // Embed the test samples with the trained network, then turn the
        // distances to each class prototype into softmax probabilities.
        let embeddings = self.compute_embedding(&X, &self.state.weights, &self.state.biases);
        let mut probabilities = Array2::zeros((n_test, n_classes));

        for i in 0..n_test {
            let embedding = embeddings.row(i).to_owned();
            let mut distances = Array1::zeros(n_classes);
            for j in 0..n_classes {
                distances[j] =
                    self.compute_distance(&embedding, &self.state.prototypes.row(j).to_owned());
            }
            probabilities
                .row_mut(i)
                .assign(&self.softmax_distances(&distances));
        }

        Ok(probabilities)
    }
}

#[derive(Debug, Clone)]
pub struct PrototypicalNetworksTrained {
    /// Weight matrices of the trained embedding network, one per layer.
    pub weights: Vec<Array2<f64>>,
    /// Bias vectors of the trained embedding network, one per layer.
    pub biases: Vec<Array1<f64>>,
    /// Class labels in sorted order, aligned with the rows of `prototypes`.
    pub classes: Array1<i32>,
    /// Embedded prototype (class mean) for each class.
    pub prototypes: Array2<f64>,
}
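
// A minimal smoke-test sketch: hypothetical toy data, assuming `Float` is an
// alias for `f64` and that episode sizes can be shrunk so two samples per
// class suffice for one support/query split.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[allow(non_snake_case)]
    fn fit_predict_smoke() {
        // Two well-separated classes in 2-D, two samples each.
        let X = Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 0.1, 0.0, 5.0, 5.0, 5.1, 5.0])
            .unwrap();
        let y = Array1::from(vec![0, 0, 1, 1]);

        let model = PrototypicalNetworks::new()
            .n_way(2)
            .n_shot(1)
            .n_query(1)
            .n_episodes(10)
            .fit(&X.view(), &y.view())
            .expect("fit should succeed on this toy dataset");

        let predictions = model.predict(&X.view()).expect("predict should succeed");
        assert_eq!(predictions.len(), 4);
    }
}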