use ndarray::{Array, ScalarOperand};
use num_traits::{Float, FromPrimitive};
use rand::{rngs::SmallRng, Rng, SeedableRng};
use scirs2_neural::{
    callbacks::LearningRateScheduler,
    data::InMemoryDataset,
    error::Result,
    layers::{Dense, Dropout, Sequential},
    losses::MeanSquaredError as MSELoss,
    optimizers::Adam,
    training::{
        GradientAccumulationConfig, GradientAccumulator, Trainer, TrainingConfig,
        ValidationSettings,
    },
};
use std::fmt::Debug;
use std::marker::{Send, Sync};

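/// Builds a small feed-forward regression model:
/// Dense(relu) -> Dropout(0.2) -> Dense(relu) -> Dropout(0.2) -> Dense(linear output).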
fn create_regression_model<
    F: Float + Debug + ScalarOperand + Send + Sync + FromPrimitive + 'static,
>(
    input_dim: usize,
    hidden_dim: usize,
    output_dim: usize,
) -> Result<Sequential<F>> {
    let mut model = Sequential::new();

    let mut rng = SmallRng::seed_from_u64(42);

    let dense1 = Dense::new(input_dim, hidden_dim, Some("relu"), &mut rng)?;
    model.add(dense1);

    let dropout1 = Dropout::new(0.2, &mut rng)?;
    model.add(dropout1);

    let dense2 = Dense::new(hidden_dim, hidden_dim / 2, Some("relu"), &mut rng)?;
    model.add(dense2);

    let dropout2 = Dropout::new(0.2, &mut rng)?;
    model.add(dropout2);

    let dense3 = Dense::new(hidden_dim / 2, output_dim, None, &mut rng)?;
    model.add(dense3);

    Ok(model)
}

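/// Generates a synthetic regression dataset: each target is a weighted sum of the
/// input features (weights derived from the feature and output indices) plus
/// uniform noise in [-0.1, 0.1).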
fn generate_regression_dataset<
    F: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync + 'static,
>(
    n_samples: usize,
    input_dim: usize,
    output_dim: usize,
) -> Result<InMemoryDataset<F>> {
    let mut rng = SmallRng::seed_from_u64(42);

    let mut features_vec = Vec::with_capacity(n_samples * input_dim);
    let mut labels_vec = Vec::with_capacity(n_samples * output_dim);

    for _ in 0..n_samples {
        let mut input_features = Vec::with_capacity(input_dim);
        for _ in 0..input_dim {
            input_features.push(F::from(rng.random_range(0.0..1.0)).unwrap());
        }
        features_vec.extend(input_features.iter());

        let mut target_values = Vec::with_capacity(output_dim);
        for o in 0..output_dim {
            let mut val = F::zero();
            for j in 0..input_dim {
                let weight = F::from(((j + o) % input_dim) as f64 / input_dim as f64).unwrap();
                val = val + input_features[j] * weight;
            }

            let noise = F::from(rng.random_range(-0.1..0.1)).unwrap();
            val = val + noise;

            target_values.push(val);
        }
        labels_vec.extend(target_values.iter());
    }

    let features = Array::from_shape_vec([n_samples, input_dim], features_vec)?;
    let labels = Array::from_shape_vec([n_samples, output_dim], labels_vec)?;

    InMemoryDataset::new(features.into_dyn(), labels.into_dyn())
}

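/// Cosine annealing learning rate schedule:
/// lr(p) = min_lr + (initial_lr - min_lr) * (1 + cos(pi * p)) / 2, for progress p in [0, 1].
/// The rate starts at initial_lr, decays smoothly, and reaches min_lr at p = 1; with
/// initial_lr = 0.001 and min_lr = 0.0001 the midpoint (p = 0.5) is 0.00055.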
struct CosineAnnealingScheduler<F: Float + Debug + ScalarOperand> {
    initial_lr: F,
    min_lr: F,
}

impl<F: Float + Debug + ScalarOperand> CosineAnnealingScheduler<F> {
    fn new(initial_lr: F, min_lr: F) -> Self {
        Self { initial_lr, min_lr }
    }
}

impl<F: Float + Debug + ScalarOperand> LearningRateScheduler<F> for CosineAnnealingScheduler<F> {
    fn get_learning_rate(&mut self, progress: f64) -> Result<F> {
        let cosine = (1.0 + (std::f64::consts::PI * progress).cos()) / 2.0;
        let lr = self.min_lr + (self.initial_lr - self.min_lr) * F::from(cosine).unwrap();
        Ok(lr)
    }
}

fn main() -> Result<()> {
    println!("Advanced Training Examples");
    println!("-------------------------");

    println!("\n1. Training with Gradient Accumulation:");

    let _dataset = generate_regression_dataset::<f32>(1000, 10, 2)?;
    let _val_dataset = generate_regression_dataset::<f32>(200, 10, 2)?;

    let model = create_regression_model::<f32>(10, 64, 2)?;
    let optimizer = Adam::new(0.001_f32, 0.9_f32, 0.999_f32, 1e-8_f32);
    let loss_fn = MSELoss::new();

    let ga_config = GradientAccumulationConfig {
        accumulation_steps: 4,
        average_gradients: true,
        zero_gradients_after_update: true,
        clip_gradients: true,
        max_gradient_norm: Some(1.0),
        log_gradient_stats: true,
    };
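    // With accumulation_steps = 4 and batch_size = 32 (below), gradients from four
    // mini-batches are averaged before each optimizer step, for an effective batch
    // size of 4 * 32 = 128 while only one mini-batch is processed at a time.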

    let training_config = TrainingConfig {
        batch_size: 32,
        shuffle: true,
        num_workers: 0,
        learning_rate: 0.001,
        epochs: 5,
        verbose: 1,
        validation: Some(ValidationSettings {
            enabled: true,
            validation_split: 0.0,
            batch_size: 32,
            num_workers: 0,
        }),
        gradient_accumulation: Some(ga_config),
        mixed_precision: None,
    };

    let _trainer = Trainer::new(model, optimizer, loss_fn, training_config);

    println!("Note: We would add callbacks like EarlyStopping and ModelCheckpoint here");
    println!("For example: EarlyStopping with patience=5, min_delta=0.001");

    let _lr_scheduler = CosineAnnealingScheduler::new(0.001_f32, 0.0001_f32);
    println!("Using CosineAnnealingScheduler with initial_lr=0.001, min_lr=0.0001");

    println!("\nTraining model with gradient accumulation...");

    println!("Would execute: trainer.train(&dataset, Some(&val_dataset))?");

    println!("\nExample of training output that would be shown:");
    println!("Training completed in 3 epochs");
    println!("Final loss: 0.0124");
    println!("Final validation loss: 0.0156");

    println!("\n2. Manual Gradient Accumulation:");

    let model = create_regression_model::<f32>(10, 64, 2)?;
    let _optimizer = Adam::new(0.001_f32, 0.9_f32, 0.999_f32, 1e-8_f32);
    let _loss_fn = MSELoss::new();

    let mut accumulator = GradientAccumulator::new(GradientAccumulationConfig {
        accumulation_steps: 4,
        average_gradients: true,
        zero_gradients_after_update: true,
        clip_gradients: false,
        max_gradient_norm: None,
        log_gradient_stats: false,
    });

    accumulator.initialize(&model)?;

    println!("Creating data loader with batch_size=32, shuffle=true");

    println!("\nTraining for 1 epoch with manual gradient accumulation...");

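    // The loop below only simulates manual gradient accumulation with fabricated
    // loss and gradient statistics. A real loop would, for each mini-batch, run a
    // forward and backward pass, hand the gradients to `accumulator`, and apply and
    // zero them every `accumulation_steps` batches (the exact accumulate/apply calls
    // depend on the scirs2_neural API and are not shown here).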
    let mut total_loss = 0.0_f32;
    let mut processed_batches = 0;

    let total_batches = 5;
    for batch_idx in 0..total_batches {
        println!("Batch {} - Accumulating gradients...", batch_idx + 1);

        let loss = 0.1 * (batch_idx as f32 + 1.0).powf(-0.5);
        total_loss += loss;
        processed_batches += 1;

        println!(
            "Batch {} - Gradient stats: min={:.4}, max={:.4}, mean={:.4}, norm={:.4}",
            batch_idx + 1,
            -0.05 * (batch_idx as f32 + 1.0).powf(-0.5),
            0.05 * (batch_idx as f32 + 1.0).powf(-0.5),
            0.01 * (batch_idx as f32 + 1.0).powf(-0.5),
            0.2 * (batch_idx as f32 + 1.0).powf(-0.5)
        );

        if (batch_idx + 1) % 4 == 0 || batch_idx == total_batches - 1 {
            // Number of batches accumulated since the last optimizer update.
            let accumulated = match (batch_idx + 1) % 4 {
                0 => 4,
                n => n,
            };
            println!(
                "Applying accumulated gradients after {} batches",
                accumulated
            );
        }
    }

    if processed_batches > 0 {
        println!("Average loss: {:.4}", total_loss / processed_batches as f32);
    }

    println!("\n3. Mixed Precision Training (Pseudocode):");

    println!(
        "// Create mixed precision config
let mp_config = MixedPrecisionConfig {{
    dynamic_loss_scaling: true,
    initial_loss_scale: 65536.0,
    scale_factor: 2.0,
    scale_window: 2000,
    min_loss_scale: 1.0,
    max_loss_scale: 2_f64.powi(24),
    verbose: true,
}};

// Create high precision and low precision models
let high_precision_model = create_regression_model::<f32>(10, 64, 2)?;
let low_precision_model = create_regression_model::<f16>(10, 64, 2)?;

// Create mixed precision model
let mut mixed_model = MixedPrecisionModel::new(
    high_precision_model,
    low_precision_model,
    mp_config,
)?;

// Create optimizer and loss function
let mut optimizer = Adam::new(0.001);
let loss_fn = MSELoss::new();

// Train for one epoch
mixed_model.train_epoch(
    &mut optimizer,
    &dataset,
    &loss_fn,
    32,
    true,
)?;"
    );
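    // In mixed precision training the forward and backward passes typically run in
    // f16 while an f32 master copy of the parameters is updated. Dynamic loss
    // scaling multiplies the loss by the current scale (starting at
    // initial_loss_scale) so small f16 gradients do not underflow; on overflow the
    // scale is usually reduced by scale_factor, and after scale_window good steps it
    // is increased again, staying within [min_loss_scale, max_loss_scale].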

    println!("\n4. Gradient Clipping:");

    let model = create_regression_model::<f32>(10, 64, 2)?;
    let optimizer = Adam::new(0.001_f32, 0.9_f32, 0.999_f32, 1e-8_f32);
    let loss_fn = MSELoss::new();

    let gradient_clipping_config = TrainingConfig {
        batch_size: 32,
        shuffle: true,
        num_workers: 0,
        learning_rate: 0.001,
        epochs: 5,
        verbose: 1,
        validation: Some(ValidationSettings {
            enabled: true,
            validation_split: 0.0,
            batch_size: 32,
            num_workers: 0,
        }),
        gradient_accumulation: None,
        mixed_precision: None,
    };

    let value_clipping_config = TrainingConfig {
        batch_size: 32,
        shuffle: true,
        num_workers: 0,
        learning_rate: 0.001,
        epochs: 5,
        verbose: 1,
        validation: Some(ValidationSettings {
            enabled: true,
            validation_split: 0.0,
            batch_size: 32,
            num_workers: 0,
        }),
        gradient_accumulation: None,
        mixed_precision: None,
    };

    let _trainer = Trainer::new(model, optimizer, loss_fn, gradient_clipping_config);

    println!("If callbacks were fully implemented, we would add gradient clipping:");
    println!("GradientClipping::by_global_norm(1.0_f32, true) // Max norm, log_stats");
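    // Clipping by global norm rescales the whole gradient vector when its L2 norm
    // exceeds the threshold (g <- g * max_norm / ||g|| if ||g|| > max_norm), so the
    // gradient direction is preserved and only the step size shrinks.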

    println!("\nTraining model with gradient clipping by global norm...");

    let _dataset_small = generate_regression_dataset::<f32>(500, 10, 2)?;
    let _val_dataset_small = generate_regression_dataset::<f32>(100, 10, 2)?;
    println!("Would train the model with dataset_small and val_dataset_small");
    println!("\nExample of training output that would be shown:");
    println!("Training completed in 3 epochs");
    println!("Final loss: 0.0124");
    println!("Final validation loss: 0.0156");

    println!("\nExample with gradient clipping by value:");

    // Recreate the model, optimizer, and loss function, since the previous ones
    // were moved into the first trainer.
    let model = create_regression_model::<f32>(10, 64, 2)?;
    let optimizer = Adam::new(0.001_f32, 0.9_f32, 0.999_f32, 1e-8_f32);
    let loss_fn = MSELoss::new();
    let _trainer = Trainer::new(model, optimizer, loss_fn, value_clipping_config);

    println!("For gradient clipping by value, we would use:");
    println!("GradientClipping::by_value(0.5_f32, true) // Max value, log_stats");

    println!("\nDemonstration of how to set up gradient clipping by value:");
    println!("trainer.add_callback(Box::new(GradientClipping::by_value(");
    println!("    0.5_f32, // Max value");
    println!("    true,    // Log stats");
    println!(")));");
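    // Clipping by value instead clamps each gradient element independently to
    // [-max_value, max_value] (here [-0.5, 0.5]); this bounds individual components
    // but, unlike global-norm clipping, can change the gradient direction.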

    println!("\nAdvanced Training Examples Completed Successfully!");

    Ok(())
}