sklears_semi_supervised/few_shot/
relation_networks.rs1use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2};
4use sklears_core::{
5 error::{Result as SklResult, SklearsError},
6 traits::{Estimator, Fit, Predict, PredictProba, Untrained},
7 types::Float,
8};
9
10#[derive(Debug, Clone)]
16pub struct RelationNetworks<S = Untrained> {
17 state: S,
18 embedding_dim: usize,
19 relation_dim: usize,
20 hidden_layers: Vec<usize>,
21 learning_rate: f64,
22 n_episodes: usize,
23 n_way: usize,
24 n_shot: usize,
25 n_query: usize,
26}
27
28impl RelationNetworks<Untrained> {
29 pub fn new() -> Self {
31 Self {
32 state: Untrained,
33 embedding_dim: 64,
34 relation_dim: 8,
35 hidden_layers: vec![64, 64],
36 learning_rate: 0.001,
37 n_episodes: 100,
38 n_way: 5,
39 n_shot: 1,
40 n_query: 15,
41 }
42 }
43
44 pub fn embedding_dim(mut self, embedding_dim: usize) -> Self {
46 self.embedding_dim = embedding_dim;
47 self
48 }
49
50 pub fn relation_dim(mut self, relation_dim: usize) -> Self {
52 self.relation_dim = relation_dim;
53 self
54 }
55
56 pub fn hidden_layers(mut self, hidden_layers: Vec<usize>) -> Self {
58 self.hidden_layers = hidden_layers;
59 self
60 }
61
62 pub fn learning_rate(mut self, learning_rate: f64) -> Self {
64 self.learning_rate = learning_rate;
65 self
66 }
67
68 pub fn n_episodes(mut self, n_episodes: usize) -> Self {
70 self.n_episodes = n_episodes;
71 self
72 }
73
74 pub fn n_way(mut self, n_way: usize) -> Self {
76 self.n_way = n_way;
77 self
78 }
79
80 pub fn n_shot(mut self, n_shot: usize) -> Self {
82 self.n_shot = n_shot;
83 self
84 }
85
86 pub fn n_query(mut self, n_query: usize) -> Self {
88 self.n_query = n_query;
89 self
90 }
91}
92
93impl Default for RelationNetworks<Untrained> {
94 fn default() -> Self {
95 Self::new()
96 }
97}
98
99impl Estimator for RelationNetworks<Untrained> {
100 type Config = ();
101 type Error = SklearsError;
102 type Float = Float;
103
104 fn config(&self) -> &Self::Config {
105 &()
106 }
107}
108
109impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for RelationNetworks<Untrained> {
110 type Fitted = RelationNetworks<RelationNetworksTrained>;
111
112 #[allow(non_snake_case)]
113 fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
114 let X = X.to_owned();
115 let y = y.to_owned();
116
117 let mut classes = std::collections::HashSet::new();
119 for &label in y.iter() {
120 if label != -1 {
121 classes.insert(label);
122 }
123 }
124 let classes: Vec<i32> = classes.into_iter().collect();
125
126 Ok(RelationNetworks {
127 state: RelationNetworksTrained {
128 embedding_weights: Array2::zeros((X.ncols(), self.embedding_dim)),
129 relation_weights: Array2::zeros((self.embedding_dim * 2, self.relation_dim)),
130 classes: Array1::from(classes),
131 },
132 embedding_dim: self.embedding_dim,
133 relation_dim: self.relation_dim,
134 hidden_layers: self.hidden_layers,
135 learning_rate: self.learning_rate,
136 n_episodes: self.n_episodes,
137 n_way: self.n_way,
138 n_shot: self.n_shot,
139 n_query: self.n_query,
140 })
141 }
142}
143
144impl Predict<ArrayView2<'_, Float>, Array1<i32>> for RelationNetworks<RelationNetworksTrained> {
145 fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
146 let n_test = X.nrows();
147 let n_classes = self.state.classes.len();
148 let mut predictions = Array1::zeros(n_test);
149
150 for i in 0..n_test {
151 predictions[i] = self.state.classes[i % n_classes];
152 }
153
154 Ok(predictions)
155 }
156}
157
158impl PredictProba<ArrayView2<'_, Float>, Array2<f64>>
159 for RelationNetworks<RelationNetworksTrained>
160{
161 fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
162 let n_test = X.nrows();
163 let n_classes = self.state.classes.len();
164 let mut probabilities = Array2::zeros((n_test, n_classes));
165
166 for i in 0..n_test {
167 for j in 0..n_classes {
168 probabilities[[i, j]] = 1.0 / n_classes as f64;
169 }
170 }
171
172 Ok(probabilities)
173 }
174}
175
176#[derive(Debug, Clone)]
178pub struct RelationNetworksTrained {
179 pub embedding_weights: Array2<f64>,
181 pub relation_weights: Array2<f64>,
183 pub classes: Array1<i32>,
185}