//! Multi-output multilayer perceptron (MLP).
//!
//! Provides [`MultiOutputMLP`], a feed-forward network trained with full-batch
//! gradient descent that predicts several targets at once, along with the
//! [`MultiOutputMLPClassifier`] and [`MultiOutputMLPRegressor`] convenience aliases.

use scirs2_core::ndarray::{Array1, Array2, ArrayView2, Axis};
use scirs2_core::random::RandNormal;
use scirs2_core::random::Rng;
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Predict, Untrained},
    types::Float,
};

use crate::activation::ActivationFunction;
use crate::loss::LossFunction;

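/// Multi-output multilayer perceptron built with a consuming builder API.
///
/// The untrained estimator is configured through chained setters and consumed by
/// [`Fit::fit`], which returns a `MultiOutputMLP<MultiOutputMLPTrained>` exposing
/// [`Predict::predict`] along with the learned weights and the training loss curve.
///
/// Illustrative sketch only (not a doctest): `x`, `y`, and `x_new` are assumed to be
/// `ndarray` arrays of shape `(n_samples, n_features)` and `(n_samples, n_outputs)`.
///
/// ```ignore
/// let model = MultiOutputMLP::new()
///     .hidden_layer_sizes(vec![64, 32])
///     .learning_rate(0.01)
///     .max_iter(500)
///     .random_state(Some(0));
/// let trained = model.fit(&x.view(), &y)?;
/// let predictions = trained.predict(&x_new.view())?;
/// ```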
#[derive(Debug, Clone)]
pub struct MultiOutputMLP<S = Untrained> {
    state: S,
    hidden_layer_sizes: Vec<usize>,
    activation: ActivationFunction,
    output_activation: ActivationFunction,
    loss_function: LossFunction,
    learning_rate: Float,
    max_iter: usize,
    tolerance: Float,
    random_state: Option<u64>,
    alpha: Float,
    batch_size: Option<usize>,
    early_stopping: bool,
    validation_fraction: Float,
}

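/// Trained state produced by fitting a [`MultiOutputMLP`].
///
/// Holds the learned layer parameters together with the metadata needed to
/// validate and transform new inputs at prediction time.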
#[derive(Debug, Clone)]
pub struct MultiOutputMLPTrained {
    /// Weight matrix for each layer, shaped `(n_units_out, n_units_in)`.
    weights: Vec<Array2<Float>>,
    /// Bias vector for each layer.
    biases: Vec<Array1<Float>>,
    /// Number of input features seen during fitting.
    n_features: usize,
    /// Number of output targets.
    n_outputs: usize,
    /// Hidden layer sizes used to build the network.
    hidden_layer_sizes: Vec<usize>,
    /// Activation function applied in the hidden layers.
    activation: ActivationFunction,
    /// Activation function applied in the output layer.
    output_activation: ActivationFunction,
    /// Training loss recorded at each iteration.
    loss_curve: Vec<Float>,
    /// Number of iterations actually performed.
    n_iter: usize,
}

impl MultiOutputMLP<Untrained> {
    /// Creates a new multi-output MLP with default hyperparameters:
    /// one hidden layer of 100 ReLU units, a linear output layer,
    /// mean squared error loss, and a learning rate of 0.001.
    pub fn new() -> Self {
        Self {
            state: Untrained,
            hidden_layer_sizes: vec![100],
            activation: ActivationFunction::ReLU,
            output_activation: ActivationFunction::Linear,
            loss_function: LossFunction::MeanSquaredError,
            learning_rate: 0.001,
            max_iter: 200,
            tolerance: 1e-4,
            random_state: None,
            alpha: 0.0001,
            batch_size: None,
            early_stopping: false,
            validation_fraction: 0.1,
        }
    }

    /// Sets the sizes of the hidden layers.
    pub fn hidden_layer_sizes(mut self, sizes: Vec<usize>) -> Self {
        self.hidden_layer_sizes = sizes;
        self
    }

    /// Sets the activation function used in the hidden layers.
    pub fn activation(mut self, activation: ActivationFunction) -> Self {
        self.activation = activation;
        self
    }

    /// Sets the activation function used in the output layer.
    pub fn output_activation(mut self, activation: ActivationFunction) -> Self {
        self.output_activation = activation;
        self
    }

    /// Sets the loss function minimized during training.
    pub fn loss_function(mut self, loss_function: LossFunction) -> Self {
        self.loss_function = loss_function;
        self
    }

    /// Sets the gradient descent learning rate.
    pub fn learning_rate(mut self, learning_rate: Float) -> Self {
        self.learning_rate = learning_rate;
        self
    }

    /// Sets the maximum number of training iterations.
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Sets the convergence tolerance on the change in training loss.
    pub fn tolerance(mut self, tolerance: Float) -> Self {
        self.tolerance = tolerance;
        self
    }

    /// Sets the random seed used for weight initialization.
    pub fn random_state(mut self, random_state: Option<u64>) -> Self {
        self.random_state = random_state;
        self
    }

    /// Sets the L2 regularization strength.
    pub fn alpha(mut self, alpha: Float) -> Self {
        self.alpha = alpha;
        self
    }

    /// Sets the mini-batch size.
    pub fn batch_size(mut self, batch_size: Option<usize>) -> Self {
        self.batch_size = batch_size;
        self
    }

    /// Enables or disables early stopping.
    pub fn early_stopping(mut self, early_stopping: bool) -> Self {
        self.early_stopping = early_stopping;
        self
    }

    /// Sets the fraction of training data reserved for validation.
    pub fn validation_fraction(mut self, validation_fraction: Float) -> Self {
        self.validation_fraction = validation_fraction;
        self
    }
}

impl Default for MultiOutputMLP<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for MultiOutputMLP<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

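/// Trains the network with full-batch gradient descent.
///
/// Weights are initialized with Xavier/Glorot-scaled Gaussian noise, and training
/// stops once the change in training loss between consecutive iterations falls
/// below `tolerance` or `max_iter` is reached.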
impl Fit<ArrayView2<'_, Float>, Array2<Float>> for MultiOutputMLP<Untrained> {
    type Fitted = MultiOutputMLP<MultiOutputMLPTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, y: &Array2<Float>) -> SklResult<Self::Fitted> {
        let (n_samples, n_features) = X.dim();
        let (n_samples_y, n_outputs) = y.dim();

        if n_samples != n_samples_y {
            return Err(SklearsError::InvalidInput(
                "X and y must have the same number of samples".to_string(),
            ));
        }

        if n_samples == 0 {
            return Err(SklearsError::InvalidInput(
                "Cannot fit with zero samples".to_string(),
            ));
        }

        // Use the provided seed, or a fixed default so runs stay reproducible.
        let mut rng = match self.random_state {
            Some(seed) => scirs2_core::random::seeded_rng(seed),
            None => scirs2_core::random::seeded_rng(42),
        };

        // Full layer layout: input -> hidden layers -> output.
        let mut layer_sizes = vec![n_features];
        layer_sizes.extend(&self.hidden_layer_sizes);
        layer_sizes.push(n_outputs);

        let mut weights = Vec::new();
        let mut biases = Vec::new();

        for i in 0..layer_sizes.len() - 1 {
            let input_size = layer_sizes[i];
            let output_size = layer_sizes[i + 1];

            // Xavier/Glorot-scaled Gaussian initialization; biases start at zero.
            let scale = (2.0 / (input_size + output_size) as Float).sqrt();
            let normal_dist = RandNormal::new(0.0, scale).unwrap();
            let mut weight_matrix = Array2::<Float>::zeros((output_size, input_size));
            for row in 0..output_size {
                for col in 0..input_size {
                    weight_matrix[[row, col]] = rng.sample(normal_dist);
                }
            }
            let bias_vector = Array1::<Float>::zeros(output_size);

            weights.push(weight_matrix);
            biases.push(bias_vector);
        }

        let mut loss_curve = Vec::new();
        let X_owned = X.to_owned();
        let y_owned = y.to_owned();

        for epoch in 0..self.max_iter {
            // Forward pass over the full batch.
            let (activations, _) = self.forward_pass(&X_owned, &weights, &biases)?;
            let predictions = activations.last().unwrap();

            let loss = self.loss_function.compute_loss(predictions, &y_owned);
            loss_curve.push(loss);

            // Stop once the loss change between iterations drops below the tolerance.
            if epoch > 0 && (loss_curve[epoch - 1] - loss).abs() < self.tolerance {
                break;
            }

            self.backward_pass(&X_owned, &y_owned, &mut weights, &mut biases)?;
        }

        // The loop may exit before `max_iter` when the tolerance check triggers,
        // so record the number of iterations actually performed.
        let n_iter = loss_curve.len();

        let trained_state = MultiOutputMLPTrained {
            weights,
            biases,
            n_features,
            n_outputs,
            hidden_layer_sizes: self.hidden_layer_sizes.clone(),
            activation: self.activation,
            output_activation: self.output_activation,
            loss_curve,
            n_iter,
        };

        Ok(MultiOutputMLP {
            state: trained_state,
            hidden_layer_sizes: self.hidden_layer_sizes,
            activation: self.activation,
            output_activation: self.output_activation,
            loss_function: self.loss_function,
            learning_rate: self.learning_rate,
            max_iter: self.max_iter,
            tolerance: self.tolerance,
            random_state: self.random_state,
            alpha: self.alpha,
            batch_size: self.batch_size,
            early_stopping: self.early_stopping,
            validation_fraction: self.validation_fraction,
        })
    }
}

impl MultiOutputMLP<Untrained> {
    /// Runs a forward pass, returning the activations of every layer (including
    /// the input) and the pre-activation values `z` of every layer.
    #[allow(clippy::type_complexity)]
    #[allow(non_snake_case)]
    fn forward_pass(
        &self,
        X: &Array2<Float>,
        weights: &[Array2<Float>],
        biases: &[Array1<Float>],
    ) -> SklResult<(Vec<Array2<Float>>, Vec<Array2<Float>>)> {
        let mut activations = vec![X.clone()];
        let mut z_values = Vec::new();

        for (i, (weight, bias)) in weights.iter().zip(biases.iter()).enumerate() {
            let current_input = activations.last().unwrap();

            // z = a . W^T + b, with the bias broadcast across samples.
            let z = current_input.dot(&weight.t()) + bias.view().insert_axis(Axis(0));
            z_values.push(z.clone());

            // The last layer uses the output activation; all others use the hidden one.
            let activation_fn = if i == weights.len() - 1 {
                self.output_activation
            } else {
                self.activation
            };

            let activated = activation_fn.apply_2d(&z);
            activations.push(activated);
        }

        Ok((activations, z_values))
    }

    /// Runs one backpropagation step of full-batch gradient descent, updating
    /// `weights` and `biases` in place (with L2 regularization on the weights).
    #[allow(non_snake_case)]
    fn backward_pass(
        &self,
        X: &Array2<Float>,
        y: &Array2<Float>,
        weights: &mut [Array2<Float>],
        biases: &mut [Array1<Float>],
    ) -> SklResult<()> {
        let (activations, z_values) = self.forward_pass(X, weights, biases)?;
        let n_samples = X.nrows() as Float;

        // Output-layer error signal.
        let output_predictions = activations.last().unwrap();
        let mut delta = output_predictions - y;

        for i in (0..weights.len()).rev() {
            let current_activation = &activations[i];

            // Average the gradients over the batch.
            let weight_gradient = delta.t().dot(current_activation) / n_samples;
            let bias_gradient = delta.mean_axis(Axis(0)).unwrap();

            // Add the L2 penalty term to the weight gradient.
            let regularized_weight_gradient = weight_gradient + self.alpha * &weights[i];

            weights[i] = &weights[i] - self.learning_rate * regularized_weight_gradient;
            biases[i] = &biases[i] - self.learning_rate * bias_gradient;

            if i > 0 {
                // Propagate the error to layer i - 1. That layer is always a hidden
                // layer, so its derivative is taken from the hidden activation.
                let derivative_approx = match self.activation {
                    ActivationFunction::ReLU => {
                        z_values[i - 1].map(|&val| if val > 0.0 { 1.0 } else { 0.0 })
                    }
                    ActivationFunction::Sigmoid => {
                        let sigmoid_vals = &activations[i];
                        sigmoid_vals.map(|&val| val * (1.0 - val))
                    }
                    ActivationFunction::Tanh => {
                        let tanh_vals = &activations[i];
                        tanh_vals.map(|&val| 1.0 - val * val)
                    }
                    _ => Array2::ones(z_values[i - 1].dim()),
                };

                delta = delta.dot(&weights[i]) * derivative_approx;
            }
        }

        Ok(())
    }
}

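/// Runs a forward pass through the trained network and returns one row of
/// predictions per input sample.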
impl Predict<ArrayView2<'_, Float>, Array2<Float>> for MultiOutputMLP<MultiOutputMLPTrained> {
    #[allow(non_snake_case)]
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
        let (_n_samples, n_features) = X.dim();

        if n_features != self.state.n_features {
            return Err(SklearsError::InvalidInput(format!(
                "X has {} features, but the model was trained with {}",
                n_features, self.state.n_features
            )));
        }

        let X_owned = X.to_owned();
        let (activations, _) = self.forward_pass_trained(&X_owned)?;
        let predictions = activations.last().unwrap().clone();

        Ok(predictions)
    }
}

impl MultiOutputMLP<MultiOutputMLPTrained> {
    /// Forward pass using the learned weights; mirrors the training-time pass.
    #[allow(clippy::type_complexity)]
    #[allow(non_snake_case)]
    fn forward_pass_trained(
        &self,
        X: &Array2<Float>,
    ) -> SklResult<(Vec<Array2<Float>>, Vec<Array2<Float>>)> {
        let mut activations = vec![X.clone()];
        let mut z_values = Vec::new();

        for (i, (weight, bias)) in self
            .state
            .weights
            .iter()
            .zip(self.state.biases.iter())
            .enumerate()
        {
            let current_input = activations.last().unwrap();

            let z = current_input.dot(&weight.t()) + bias.view().insert_axis(Axis(0));
            z_values.push(z.clone());

            let activation_fn = if i == self.state.weights.len() - 1 {
                self.state.output_activation
            } else {
                self.state.activation
            };

            let activated = activation_fn.apply_2d(&z);
            activations.push(activated);
        }

        Ok((activations, z_values))
    }

    /// Returns the training loss recorded at each iteration.
    pub fn loss_curve(&self) -> &[Float] {
        &self.state.loss_curve
    }

    /// Returns the number of training iterations that were performed.
    pub fn n_iter(&self) -> usize {
        self.state.n_iter
    }

    /// Returns the learned weight matrices, one per layer.
    pub fn weights(&self) -> &[Array2<Float>] {
        &self.state.weights
    }

    /// Returns the learned bias vectors, one per layer.
    pub fn biases(&self) -> &[Array1<Float>] {
        &self.state.biases
    }
}

/// Convenience alias for a [`MultiOutputMLP`] configured for multi-label classification.
pub type MultiOutputMLPClassifier<S = Untrained> = MultiOutputMLP<S>;

impl MultiOutputMLPClassifier<Untrained> {
    /// Creates a classifier preset: sigmoid output activation with binary cross-entropy loss.
    pub fn new_classifier() -> Self {
        Self::new()
            .output_activation(ActivationFunction::Sigmoid)
            .loss_function(LossFunction::BinaryCrossEntropy)
    }
}

/// Convenience alias for a [`MultiOutputMLP`] configured for multi-output regression.
pub type MultiOutputMLPRegressor<S = Untrained> = MultiOutputMLP<S>;

impl MultiOutputMLPRegressor<Untrained> {
    /// Creates a regressor preset: linear output activation with mean squared error loss.
    pub fn new_regressor() -> Self {
        Self::new()
            .output_activation(ActivationFunction::Linear)
            .loss_function(LossFunction::MeanSquaredError)
    }
}
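
#[cfg(test)]
mod tests {
    use super::*;

    // A minimal smoke-test sketch of the fit/predict round trip. It assumes `Float`
    // is an f64-style alias and only checks output shapes and bookkeeping, not
    // learning quality; constant inputs are enough for that.
    #[test]
    fn fit_predict_preserves_shapes() {
        // Four samples with three features each, and two output targets per sample.
        let x = Array2::<Float>::ones((4, 3));
        let y = Array2::<Float>::zeros((4, 2));

        let model = MultiOutputMLP::new()
            .hidden_layer_sizes(vec![8])
            .max_iter(10)
            .random_state(Some(0));

        let trained = model.fit(&x.view(), &y).expect("fit should succeed");
        let predictions = trained.predict(&x.view()).expect("predict should succeed");

        assert_eq!(predictions.dim(), (4, 2));
        assert!(trained.n_iter() <= 10);
        assert!(!trained.loss_curve().is_empty());
    }
}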