use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Predict, Untrained},
    types::Float,
};
use std::collections::HashMap;

use crate::activation::ActivationFunction;
use crate::loss::LossFunction;

/// Strategy for encouraging shared representations that carry little
/// information about which task an example belongs to.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum AdversarialStrategy {
    /// Reverse the discriminator gradient flowing into the shared layers (DANN-style).
    GradientReversal,
    /// Train the shared layers against an explicit task discriminator.
    DomainAdversarial,
    /// Minimize the mutual information between shared features and task identity.
    MutualInformationMin,
}

/// Configuration for the gradient reversal layer's scaling factor (lambda).
#[derive(Debug, Clone)]
pub struct GradientReversalConfig {
    /// Lambda value at the start of training.
    pub lambda_init: Float,
    /// Lambda value at the end of training.
    pub lambda_final: Float,
    /// Schedule used to move lambda from `lambda_init` to `lambda_final`.
    pub schedule: LambdaSchedule,
}

/// Schedule for ramping the gradient reversal strength over training.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum LambdaSchedule {
    /// Keep lambda fixed throughout training.
    Constant,
    /// Interpolate lambda linearly between its initial and final values.
    Linear,
    /// Ramp lambda with an exponential saturation curve.
    Exponential,
}

impl Default for GradientReversalConfig {
    fn default() -> Self {
        Self {
            lambda_init: 0.0,
            lambda_final: 1.0,
            schedule: LambdaSchedule::Linear,
        }
    }
}
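
// Illustrative sketch only (not part of this module's API): one way the three
// `LambdaSchedule` variants could map a training-progress value in [0, 1] to a
// lambda coefficient. The exact exponential form, and pinning `Constant` to
// `lambda_final`, are assumptions for illustration; DANN-style training often
// uses 2 / (1 + exp(-10 * p)) - 1 instead.
#[allow(dead_code)]
fn lambda_for_progress(config: &GradientReversalConfig, progress: Float) -> Float {
    let p = progress.clamp(0.0, 1.0);
    let span = config.lambda_final - config.lambda_init;
    match config.schedule {
        LambdaSchedule::Constant => config.lambda_final,
        LambdaSchedule::Linear => config.lambda_init + span * p,
        // Fast early ramp that saturates near `lambda_final` as p approaches 1.
        LambdaSchedule::Exponential => config.lambda_init + span * (1.0 - (-5.0 * p).exp()),
    }
}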

/// Auxiliary discriminator that tries to predict task identity from shared features.
#[derive(Debug, Clone)]
pub struct TaskDiscriminator {
    /// Hidden layer sizes of the discriminator network.
    hidden_sizes: Vec<usize>,
    /// Weight matrix for each layer.
    weights: Vec<Array2<Float>>,
    /// Bias vector for each layer.
    biases: Vec<Array1<Float>>,
    /// Number of tasks the discriminator distinguishes between.
    num_tasks: usize,
}

impl TaskDiscriminator {
    /// Create a new discriminator; parameters are allocated by `initialize_parameters`.
    pub fn new(_input_size: usize, hidden_sizes: Vec<usize>, num_tasks: usize) -> Self {
        Self {
            hidden_sizes,
            weights: Vec::new(),
            biases: Vec::new(),
            num_tasks,
        }
    }

    /// Allocate discriminator parameters (currently placeholder zero-initialized layers).
    pub fn initialize_parameters(
        &mut self,
        _rng: &mut scirs2_core::random::CoreRandom,
    ) -> SklResult<()> {
        for _ in &self.hidden_sizes {
            self.weights.push(Array2::<Float>::zeros((10, 10)));
            self.biases.push(Array1::<Float>::zeros(10));
        }
        Ok(())
    }

    /// Forward pass returning per-task logits; placeholder implementation returns zeros.
    pub fn forward(&self, features: &Array2<Float>) -> SklResult<Array2<Float>> {
        Ok(Array2::<Float>::zeros((features.nrows(), self.num_tasks)))
    }

    /// Predict the most likely task for each row by taking the arg-max over the logits.
    pub fn predict_task(&self, features: &Array2<Float>) -> SklResult<Array1<usize>> {
        let predictions = self.forward(features)?;
        let mut task_predictions = Array1::<usize>::zeros(features.nrows());

        for i in 0..features.nrows() {
            let mut max_idx = 0;
            let mut max_val = predictions[[i, 0]];
            for j in 1..self.num_tasks {
                if predictions[[i, j]] > max_val {
                    max_val = predictions[[i, j]];
                    max_idx = j;
                }
            }
            task_predictions[i] = max_idx;
        }

        Ok(task_predictions)
    }
}
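
// A minimal sketch of exercising the discriminator on its own, assuming the
// placeholder `forward` above; with all-zero logits every row resolves to task 0,
// so the test only checks the output shape and label range.
#[cfg(test)]
mod task_discriminator_sketch {
    use super::*;

    #[test]
    fn predict_task_returns_one_label_per_row() -> SklResult<()> {
        let discriminator = TaskDiscriminator::new(3, vec![8], 4);
        let features = Array2::<Float>::zeros((5, 3));
        let tasks = discriminator.predict_task(&features)?;
        assert_eq!(tasks.len(), 5);
        assert!(tasks.iter().all(|&t| t < 4));
        Ok(())
    }
}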

/// Multi-task neural network trained with an adversarial objective.
///
/// Shared layers learn features useful for every task while a task
/// discriminator pushes those features to be task-invariant; per-task
/// private layers and output heads capture task-specific structure.
#[derive(Debug, Clone)]
pub struct AdversarialMultiTaskNetwork<S = Untrained> {
    state: S,
    config: AdversarialConfig,
    /// Output dimensionality for each task.
    task_outputs: HashMap<String, usize>,
    /// Loss function used for each task.
    task_loss_functions: HashMap<String, LossFunction>,
    /// Relative weight of each task in the combined loss.
    task_weights: HashMap<String, Float>,
    shared_activation: ActivationFunction,
    private_activation: ActivationFunction,
    output_activations: HashMap<String, ActivationFunction>,
    learning_rate: Float,
    max_iter: usize,
    tolerance: Float,
    random_state: Option<u64>,
    /// L2 regularization strength.
    alpha: Float,
}

/// Trained state of an [`AdversarialMultiTaskNetwork`]: learned parameters,
/// architecture metadata, and per-iteration training diagnostics.
#[derive(Debug, Clone)]
pub struct AdversarialMultiTaskNetworkTrained {
    shared_weights: Vec<Array2<Float>>,
    shared_biases: Vec<Array1<Float>>,
    private_weights: HashMap<String, Vec<Array2<Float>>>,
    private_biases: HashMap<String, Vec<Array1<Float>>>,
    output_weights: HashMap<String, Array2<Float>>,
    output_biases: HashMap<String, Array1<Float>>,
    task_discriminator: TaskDiscriminator,
    n_features: usize,
    task_outputs: HashMap<String, usize>,
    shared_layer_sizes: Vec<usize>,
    private_layer_sizes: Vec<usize>,
    shared_activation: ActivationFunction,
    private_activation: ActivationFunction,
    output_activations: HashMap<String, ActivationFunction>,
    task_loss_curves: HashMap<String, Vec<Float>>,
    adversarial_loss_curve: Vec<Float>,
    orthogonality_loss_curve: Vec<Float>,
    combined_loss_curve: Vec<Float>,
    discriminator_accuracy_curve: Vec<Float>,
    adversarial_strategy: AdversarialStrategy,
    adversarial_weight: Float,
    orthogonality_weight: Float,
    gradient_reversal_config: GradientReversalConfig,
    n_iter: usize,
}

/// Architecture and adversarial-training hyperparameters for the network.
#[derive(Debug, Clone)]
pub struct AdversarialConfig {
    /// Sizes of the shared hidden layers.
    pub shared_layer_sizes: Vec<usize>,
    /// Sizes of the per-task private hidden layers.
    pub private_layer_sizes: Vec<usize>,
    /// Strategy used for the adversarial objective.
    pub adversarial_strategy: AdversarialStrategy,
    /// Weight of the adversarial loss term.
    pub adversarial_weight: Float,
    /// Weight of the orthogonality penalty between shared and private features.
    pub orthogonality_weight: Float,
    /// Schedule for the gradient reversal coefficient.
    pub gradient_reversal_config: GradientReversalConfig,
}

impl Default for AdversarialConfig {
    fn default() -> Self {
        Self {
            shared_layer_sizes: vec![50, 25],
            private_layer_sizes: vec![25],
            adversarial_strategy: AdversarialStrategy::GradientReversal,
            adversarial_weight: 0.1,
            orthogonality_weight: 0.01,
            gradient_reversal_config: GradientReversalConfig::default(),
        }
    }
}

impl AdversarialMultiTaskNetwork<Untrained> {
    /// Create an untrained network with the default configuration.
    pub fn new() -> Self {
        Self {
            state: Untrained,
            config: AdversarialConfig::default(),
            task_outputs: HashMap::new(),
            task_loss_functions: HashMap::new(),
            task_weights: HashMap::new(),
            shared_activation: ActivationFunction::ReLU,
            private_activation: ActivationFunction::ReLU,
            output_activations: HashMap::new(),
            learning_rate: 0.001,
            max_iter: 1000,
            tolerance: 1e-6,
            random_state: None,
            alpha: 0.0001,
        }
    }

    /// Set the sizes of the shared hidden layers.
    pub fn shared_layers(mut self, sizes: Vec<usize>) -> Self {
        self.config.shared_layer_sizes = sizes;
        self
    }

    /// Set the sizes of the per-task private hidden layers.
    pub fn private_layers(mut self, sizes: Vec<usize>) -> Self {
        self.config.private_layer_sizes = sizes;
        self
    }

    /// Register the tasks as `(name, output_size)` pairs.
    ///
    /// Single-output tasks default to mean squared error with a linear output;
    /// multi-output tasks default to cross-entropy with a softmax output.
    pub fn task_outputs(mut self, tasks: &[(&str, usize)]) -> Self {
        for (task_name, output_size) in tasks {
            self.task_outputs
                .insert(task_name.to_string(), *output_size);
            self.task_loss_functions.insert(
                task_name.to_string(),
                if *output_size == 1 {
                    LossFunction::MeanSquaredError
                } else {
                    LossFunction::CrossEntropy
                },
            );
            self.task_weights.insert(task_name.to_string(), 1.0);
            self.output_activations.insert(
                task_name.to_string(),
                if *output_size == 1 {
                    ActivationFunction::Linear
                } else {
                    ActivationFunction::Softmax
                },
            );
        }
        self
    }

    /// Set the adversarial training strategy.
    pub fn adversarial_strategy(mut self, strategy: AdversarialStrategy) -> Self {
        self.config.adversarial_strategy = strategy;
        self
    }

    /// Set the weight of the adversarial loss term.
    pub fn adversarial_weight(mut self, weight: Float) -> Self {
        self.config.adversarial_weight = weight;
        self
    }

    /// Set the weight of the orthogonality penalty.
    pub fn orthogonality_weight(mut self, weight: Float) -> Self {
        self.config.orthogonality_weight = weight;
        self
    }

    /// Set the learning rate.
    pub fn learning_rate(mut self, lr: Float) -> Self {
        self.learning_rate = lr;
        self
    }

    /// Set the maximum number of training iterations.
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set the random seed for reproducible parameter initialization.
    pub fn random_state(mut self, seed: Option<u64>) -> Self {
        self.random_state = seed;
        self
    }
}

impl Default for AdversarialMultiTaskNetwork<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for AdversarialMultiTaskNetwork<Untrained> {
    type Config = AdversarialConfig;
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

impl Fit<ArrayView2<'_, Float>, HashMap<String, Array2<Float>>>
    for AdversarialMultiTaskNetwork<Untrained>
{
    type Fitted = AdversarialMultiTaskNetwork<AdversarialMultiTaskNetworkTrained>;

    fn fit(
        self,
        x: &ArrayView2<Float>,
        y: &HashMap<String, Array2<Float>>,
    ) -> SklResult<Self::Fitted> {
        if x.nrows() == 0 || x.ncols() == 0 {
            return Err(SklearsError::InvalidInput("Empty input data".to_string()));
        }

        if y.is_empty() {
            return Err(SklearsError::InvalidInput("No tasks provided".to_string()));
        }

        let n_features = x.ncols();
        let n_tasks = self.task_outputs.len();

        // Placeholder parameter initialization: weights are zero matrices with
        // fixed layer widths (50 shared, 25 private); no optimization loop runs yet.
        let shared_weights = vec![Array2::<Float>::zeros((n_features, 50))];
        let shared_biases = vec![Array1::<Float>::zeros(50)];
        let mut private_weights = HashMap::new();
        let mut private_biases = HashMap::new();
        let mut output_weights = HashMap::new();
        let mut output_biases = HashMap::new();

        for (task_name, &output_size) in &self.task_outputs {
            private_weights.insert(task_name.clone(), vec![Array2::<Float>::zeros((50, 25))]);
            private_biases.insert(task_name.clone(), vec![Array1::<Float>::zeros(25)]);
            output_weights.insert(task_name.clone(), Array2::<Float>::zeros((25, output_size)));
            output_biases.insert(task_name.clone(), Array1::<Float>::zeros(output_size));
        }

        let task_discriminator = TaskDiscriminator::new(50, vec![25], n_tasks);

        // Pre-allocate per-task loss curves; entries stay at zero until a real
        // training loop fills them in.
        let mut task_loss_curves = HashMap::new();
        for task_name in self.task_outputs.keys() {
            task_loss_curves.insert(task_name.clone(), vec![0.0; self.max_iter]);
        }

        let trained_state = AdversarialMultiTaskNetworkTrained {
            shared_weights,
            shared_biases,
            private_weights,
            private_biases,
            output_weights,
            output_biases,
            task_discriminator,
            n_features,
            task_outputs: self.task_outputs.clone(),
            shared_layer_sizes: self.config.shared_layer_sizes.clone(),
            private_layer_sizes: self.config.private_layer_sizes.clone(),
            shared_activation: self.shared_activation,
            private_activation: self.private_activation,
            output_activations: self.output_activations.clone(),
            task_loss_curves,
            adversarial_loss_curve: vec![0.0; self.max_iter],
            orthogonality_loss_curve: vec![0.0; self.max_iter],
            combined_loss_curve: vec![0.0; self.max_iter],
            discriminator_accuracy_curve: vec![0.0; self.max_iter],
            adversarial_strategy: self.config.adversarial_strategy,
            adversarial_weight: self.config.adversarial_weight,
            orthogonality_weight: self.config.orthogonality_weight,
            gradient_reversal_config: self.config.gradient_reversal_config.clone(),
            n_iter: self.max_iter,
        };

        Ok(AdversarialMultiTaskNetwork {
            state: trained_state,
            config: self.config,
            task_outputs: self.task_outputs,
            task_loss_functions: self.task_loss_functions,
            task_weights: self.task_weights,
            shared_activation: self.shared_activation,
            private_activation: self.private_activation,
            output_activations: self.output_activations,
            learning_rate: self.learning_rate,
            max_iter: self.max_iter,
            tolerance: self.tolerance,
            random_state: self.random_state,
            alpha: self.alpha,
        })
    }
}

impl Predict<ArrayView2<'_, Float>, HashMap<String, Array2<Float>>>
    for AdversarialMultiTaskNetwork<AdversarialMultiTaskNetworkTrained>
{
    fn predict(&self, x: &ArrayView2<'_, Float>) -> SklResult<HashMap<String, Array2<Float>>> {
        let (n_samples, n_features) = x.dim();

        if n_features != self.state.n_features {
            return Err(SklearsError::InvalidInput(
                "X has different number of features than training data".to_string(),
            ));
        }

        // Placeholder forward pass: return zero-valued predictions with the
        // correct shape for each registered task.
        let mut predictions = HashMap::new();

        for (task_name, &output_size) in &self.state.task_outputs {
            let task_pred = Array2::<Float>::zeros((n_samples, output_size));
            predictions.insert(task_name.clone(), task_pred);
        }

        Ok(predictions)
    }
}

impl AdversarialMultiTaskNetwork<AdversarialMultiTaskNetworkTrained> {
    /// Per-task training loss recorded at each iteration.
    pub fn task_loss_curves(&self) -> &HashMap<String, Vec<Float>> {
        &self.state.task_loss_curves
    }

    /// Adversarial (discriminator) loss recorded at each iteration.
    pub fn adversarial_loss_curve(&self) -> &[Float] {
        &self.state.adversarial_loss_curve
    }

    /// Orthogonality penalty recorded at each iteration.
    pub fn orthogonality_loss_curve(&self) -> &[Float] {
        &self.state.orthogonality_loss_curve
    }

    /// Combined training loss recorded at each iteration.
    pub fn combined_loss_curve(&self) -> &[Float] {
        &self.state.combined_loss_curve
    }

    /// Task-discriminator accuracy recorded at each iteration.
    pub fn discriminator_accuracy_curve(&self) -> &[Float] {
        &self.state.discriminator_accuracy_curve
    }

    /// Number of training iterations that were run.
    pub fn n_iter(&self) -> usize {
        self.state.n_iter
    }
}
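
// A minimal end-to-end usage sketch, assuming the builder defaults above: two
// toy tasks are registered, the network is "fit" (placeholder initialization),
// and the prediction shapes are checked. Task names and data are illustrative.
#[cfg(test)]
mod adversarial_network_usage_sketch {
    use super::*;

    #[test]
    fn fit_and_predict_shapes() -> SklResult<()> {
        let x = Array2::<Float>::zeros((8, 4));
        let mut y: HashMap<String, Array2<Float>> = HashMap::new();
        y.insert("regression".to_string(), Array2::<Float>::zeros((8, 1)));
        y.insert("classification".to_string(), Array2::<Float>::zeros((8, 3)));

        let model = AdversarialMultiTaskNetwork::new()
            .task_outputs(&[("regression", 1), ("classification", 3)])
            .adversarial_strategy(AdversarialStrategy::GradientReversal)
            .adversarial_weight(0.1)
            .max_iter(10)
            .fit(&x.view(), &y)?;

        let predictions = model.predict(&x.view())?;
        assert_eq!(predictions["regression"].dim(), (8, 1));
        assert_eq!(predictions["classification"].dim(), (8, 3));
        assert_eq!(model.n_iter(), 10);
        Ok(())
    }
}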