training_loop_example/
training_loop_example.rs

use ndarray::{Array, IxDyn};
use scirs2_neural::callbacks::{
    CallbackContext, CallbackTiming, EarlyStopping, ReduceOnPlateau, TensorBoardLogger,
    VisualizationCallback,
};
use scirs2_neural::data::{
    DataLoader, InMemoryDataset, OneHotEncoder, StandardScaler, TransformedDataset,
};
use scirs2_neural::error::Result;
use scirs2_neural::layers::{Dense, Layer};
use scirs2_neural::losses::CrossEntropyLoss;
use scirs2_neural::optimizers::Adam;
use scirs2_neural::utils::{
    analyze_training_history, ascii_plot, export_history_to_csv, LearningRateSchedule, PlotOptions,
};
use std::collections::HashMap;
use std::path::{Path, PathBuf};

fn main() -> Result<()> {
    println!("Training loop example with visualization");

    // Create dummy data
    let n_samples = 1000;
    let n_features = 10;
    let n_classes = 3;

    println!(
        "Generating dummy data with {} samples, {} features, {} classes",
        n_samples, n_features, n_classes
    );

    // Generate random features
    let features = Array::from_shape_fn(IxDyn(&[n_samples, n_features]), |_| {
        rand::random::<f32>() * 2.0 - 1.0
    });

    // Generate random labels (integers 0 to n_classes-1)
    let labels = Array::from_shape_fn(IxDyn(&[n_samples, 1]), |_| {
        (rand::random::<f32>() * n_classes as f32).floor()
    });

    // Create dataset
    let dataset = InMemoryDataset::new(features, labels)?;

    // Split into training and validation sets
    let (train_dataset, val_dataset) = dataset.train_test_split(0.2)?;

    println!(
        "Split data into {} training samples and {} validation samples",
        train_dataset.features.shape()[0],
        val_dataset.features.shape()[0]
    );

    // Create transformations
    let feature_scaler = StandardScaler::new(false);
    let label_encoder = OneHotEncoder::new(n_classes);

    // Apply transformations
    let train_dataset = TransformedDataset::new(train_dataset)
        .with_feature_transform(feature_scaler)
        .with_label_transform(label_encoder);

    let val_dataset = TransformedDataset::new(val_dataset)
        .with_feature_transform(StandardScaler::new(false))
        .with_label_transform(OneHotEncoder::new(n_classes));

    // Create data loaders
    let batch_size = 32;
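    // Note on DataLoader::new: the boolean arguments are assumptions about the
    // crate's API; the first presumably toggles shuffling (enabled for the
    // training loader, disabled for validation).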
    let train_loader = DataLoader::new(train_dataset.clone(), batch_size, true, false);
    let val_loader = DataLoader::new(val_dataset.clone(), batch_size, false, false);

    println!(
        "Created data loaders with batch size {}. Training: {} batches, Validation: {} batches",
        batch_size,
        train_loader.num_batches(),
        val_loader.num_batches()
    );

    // Create model, loss, and optimizer
    let _model = create_model(n_features, n_classes)?;
    let _loss_fn = CrossEntropyLoss::new(1e-10);

    // Create learning rate schedule
    let lr_schedule = LearningRateSchedule::StepDecay {
        initial_lr: 0.001,
        decay_factor: 0.5,
        step_size: 3,
    };
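    // StepDecay presumably computes lr = initial_lr * decay_factor^floor(epoch / step_size),
    // i.e. the rate is halved every 3 epochs; the loop below prints the resulting schedule.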

    // Generate learning rates for all epochs
    let num_epochs = 10;
    let learning_rates = lr_schedule.generate_schedule(num_epochs);
    println!("Learning rate schedule:");
    for (i, &lr) in learning_rates.iter().enumerate() {
        println!("  Epoch {}: {:.6}", i + 1, lr);
    }

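    // Adam optimizer; by the standard Adam convention the arguments below are
    // presumably the learning rate, beta1, beta2, and epsilon.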
    let _optimizer = Adam::new(0.001, 0.9, 0.999, 1e-8);

    // Create callbacks
    let _checkpoint_dir = PathBuf::from("./checkpoints");
    let tensorboard_dir = PathBuf::from("./logs");

    // Create output directories if they don't exist
    create_dir_if_not_exists("./checkpoints")?;
    create_dir_if_not_exists("./logs")?;
    create_dir_if_not_exists("./outputs")?;

    // For this example, we omit the ModelCheckpoint callback
    let mut callbacks: Vec<Box<dyn scirs2_neural::callbacks::Callback<f32>>> = vec![
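        // The numeric arguments below are assumptions about the callback APIs:
        // EarlyStopping presumably takes (patience, min_delta, restore_best),
        // ReduceOnPlateau presumably (initial_lr, factor, patience, min_delta, min_lr).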
        Box::new(EarlyStopping::new(5, 0.001, true)),
        // ModelCheckpoint removed for simplicity as it requires special handling
        Box::new(ReduceOnPlateau::new(0.001, 0.5, 3, 0.001, 0.0001)),
        Box::new(TensorBoardLogger::new(tensorboard_dir, true, 10)),
        // Add our visualization callback
        Box::new(
            VisualizationCallback::new(1)
                .with_save_path("./outputs/training_plot.txt")
                .with_tracked_metrics(vec![
                    "train_loss".to_string(),
                    "val_loss".to_string(),
                    "accuracy".to_string(),
                    "learning_rate".to_string(),
                ]),
        ),
    ];

    // Training loop
    let mut history = HashMap::<String, Vec<f32>>::new();
    history.insert("train_loss".to_string(), Vec::new());
    history.insert("val_loss".to_string(), Vec::new());
    history.insert("learning_rate".to_string(), Vec::new());
    history.insert("accuracy".to_string(), Vec::new());

    println!("Starting training for {} epochs", num_epochs);

    // Run callbacks before training
    // Create a copy of history for the context
    let mut context_history = HashMap::<String, Vec<f32>>::new();
    context_history.insert("train_loss".to_string(), Vec::new());
    context_history.insert("val_loss".to_string(), Vec::new());
    context_history.insert("learning_rate".to_string(), Vec::new());
    context_history.insert("accuracy".to_string(), Vec::new());

    // For this example, we adapt to use Vec<F> for metrics
    // which is simpler than using Vec<(String, Option<F>)>
    // In a real implementation, use the proper context format
    let mut context = CallbackContext {
        epoch: 0,
        total_epochs: num_epochs,
        batch: 0,
        total_batches: train_loader.num_batches(),
        batch_loss: None,
        epoch_loss: None,
        val_loss: None,
        metrics: Vec::new(),
        history: &context_history,
        stop_training: false,
        model: None,
    };

    for callback in &mut callbacks {
        callback.on_event(CallbackTiming::BeforeTraining, &mut context)?;
    }

    // Training loop
    for epoch in 0..num_epochs {
        println!("Epoch {}/{}", epoch + 1, num_epochs);

        // Get learning rate for this epoch
        let learning_rate = learning_rates[epoch];
        history
            .get_mut("learning_rate")
            .unwrap()
            .push(learning_rate);

        // Reset data loader
        let mut train_loader = DataLoader::new(train_dataset.clone(), batch_size, true, false);
        train_loader.reset();

        // Update context
        context.epoch = epoch;
        context.epoch_loss = None;
        context.val_loss = None;

        // Run callbacks before epoch
        for callback in &mut callbacks {
            callback.on_event(CallbackTiming::BeforeEpoch, &mut context)?;
        }

        // Train on batches
        let mut epoch_loss = 0.0;
        let mut batch_count = 0;

        for (batch, batch_result) in train_loader.enumerate() {
            let (_batch_x, _batch_y) = batch_result?;

            // Update context
            context.batch = batch;
            context.batch_loss = None;

            // Run callbacks before batch
            for callback in &mut callbacks {
                callback.on_event(CallbackTiming::BeforeBatch, &mut context)?;
            }

            // In a real implementation, we'd train the model here
            // For now, just compute a random loss
            let batch_loss = rand::random::<f32>() * (1.0 / (epoch as f32 + 1.0));

            // Update batch loss
            context.batch_loss = Some(batch_loss);

            // Run callbacks after batch
            for callback in &mut callbacks {
                callback.on_event(CallbackTiming::AfterBatch, &mut context)?;
            }

            epoch_loss += batch_loss;
            batch_count += 1;
        }

        // Compute epoch loss
        epoch_loss /= batch_count as f32;
        history.get_mut("train_loss").unwrap().push(epoch_loss);
        context.epoch_loss = Some(epoch_loss);

        println!("Train loss: {:.6}", epoch_loss);

        // Evaluate on validation set
        let mut val_loss = 0.0;
        let mut val_batch_count = 0;

        let mut val_loader = DataLoader::new(val_dataset.clone(), batch_size, false, false);
        val_loader.reset();

        for batch_result in val_loader {
            let (_batch_x, _batch_y) = batch_result?;

            // In a real implementation, we'd evaluate the model here
            // For now, just compute a random loss
            let batch_loss = rand::random::<f32>() * (1.0 / (epoch as f32 + 1.0)) * 1.1;

            val_loss += batch_loss;
            val_batch_count += 1;
        }

        // Compute validation loss
        val_loss /= val_batch_count as f32;
        history.get_mut("val_loss").unwrap().push(val_loss);
        context.val_loss = Some(val_loss);

        // Simulate accuracy metric
        let accuracy =
            0.5 + 0.4 * (epoch as f32 / num_epochs as f32) + rand::random::<f32>() * 0.05;
        history.get_mut("accuracy").unwrap().push(accuracy);

        // Add accuracy to metrics
        context.metrics = vec![accuracy];

        println!("Validation loss: {:.6}", val_loss);
        println!("Accuracy: {:.2}%", accuracy * 100.0);

        // Run callbacks after epoch
        for callback in &mut callbacks {
            callback.on_event(CallbackTiming::AfterEpoch, &mut context)?;
        }

        // Check if training should be stopped
        if context.stop_training {
            println!("Early stopping triggered, terminating training");
            break;
        }

        // Visualize after each epoch
        if epoch > 0 {
            // Plot all tracked metrics so far
            let metrics_plot = ascii_plot(
                &history,
                Some("Training Metrics"),
                Some(PlotOptions {
                    width: 80,
                    height: 20,
                    max_x_ticks: 10,
                    max_y_ticks: 5,
                    line_char: '─',
                    point_char: '●',
                    background_char: ' ',
                    show_grid: true,
                    show_legend: true,
                }),
            )?;
            println!("\n{}", metrics_plot);
        }
    }

    // Run callbacks after training
    for callback in &mut callbacks {
        callback.on_event(CallbackTiming::AfterTraining, &mut context)?;
    }

    println!("Training complete!");

    // Export metrics to CSV
    let csv_path = "./outputs/training_history.csv";
    export_history_to_csv(&history, csv_path)?;
    println!("Training history exported to {}", csv_path);

    // Analyze training history
    let analysis = analyze_training_history(&history);
    println!("\nTraining Analysis:");
    for issue in analysis {
        println!("  {}", issue);
    }

    // Final visualization of metrics
    println!("\nFinal Training Metrics:\n");

    // Prepare subset of metrics for separate accuracy plot
    let mut accuracy_data = HashMap::new();
    accuracy_data.insert(
        "accuracy".to_string(),
        history.get("accuracy").unwrap().clone(),
    );

    // Plot accuracy
    let accuracy_plot = ascii_plot(
        &accuracy_data,
        Some("Model Accuracy"),
        Some(PlotOptions {
            width: 80,
            height: 15,
            max_x_ticks: 10,
            max_y_ticks: 5,
            line_char: '─',
            point_char: '●',
            background_char: ' ',
            show_grid: true,
            show_legend: true,
        }),
    )?;
    println!("{}", accuracy_plot);

    // Prepare subset of metrics for learning rate plot
    let mut lr_data = HashMap::new();
    lr_data.insert(
        "learning_rate".to_string(),
        history.get("learning_rate").unwrap().clone(),
    );

    // Plot learning rate
    let lr_plot = ascii_plot(
        &lr_data,
        Some("Learning Rate Schedule"),
        Some(PlotOptions {
            width: 80,
            height: 15,
            max_x_ticks: 10,
            max_y_ticks: 5,
            line_char: '─',
            point_char: '■',
            background_char: ' ',
            show_grid: true,
            show_legend: true,
        }),
    )?;
    println!("{}", lr_plot);

    // Visualize both train and validation losses in a single plot
    let mut loss_data = HashMap::new();
    loss_data.insert(
        "train_loss".to_string(),
        history.get("train_loss").unwrap().clone(),
    );
    loss_data.insert(
        "val_loss".to_string(),
        history.get("val_loss").unwrap().clone(),
    );

    let loss_plot = ascii_plot(
        &loss_data,
        Some("Training and Validation Loss"),
        Some(PlotOptions {
            width: 80,
            height: 20,
            max_x_ticks: 10,
            max_y_ticks: 5,
            line_char: '─',
            point_char: '●',
            background_char: ' ',
            show_grid: true,
            show_legend: true,
        }),
    )?;
    println!("{}", loss_plot);

    Ok(())
}

// Create a simple model for classification
fn create_model(input_size: usize, num_classes: usize) -> Result<impl Layer<f32>> {
    // Create a mutable RNG for initialization
    let mut rng = rand::rng();

    // Create a simple neural network model with two hidden layers
    let hidden_size1 = 64;
    let hidden_size2 = 32;

    // In a real implementation, we'd connect more layers here
    // For demo purposes, we're just returning the first layer
    println!("Creating model with architecture:");
    println!("  Input size: {}", input_size);
    println!("  Hidden layer 1: {}", hidden_size1);
    println!("  Hidden layer 2: {}", hidden_size2);
    println!("  Output size: {}", num_classes);

    // Create a dense layer with ReLU activation
    let model = Dense::<f32>::new(input_size, hidden_size1, Some("relu"), &mut rng)?;

    Ok(model)
}

// Helper function to create a directory if it doesn't exist
fn create_dir_if_not_exists(path: impl AsRef<Path>) -> Result<()> {
    let path = path.as_ref();
    if !path.exists() {
        std::fs::create_dir_all(path).map_err(|e| {
            scirs2_neural::error::NeuralError::IOError(format!(
                "Failed to create directory {}: {}",
                path.display(),
                e
            ))
        })?;
    }
    Ok(())
}