/// Training session for tracking training history.
///
/// Stores the per-metric history recorded while training. It also keeps a few
/// summary numbers about the run and the configuration the run was started with.
pub struct TrainingSession<F>
where
    F: Float + Debug + ScalarOperand,
{
    /// Training metrics history, keyed by metric name.
    pub history: HashMap<String, Vec<F>>,
    /// Initial learning rate.
    pub initial_learning_rate: F,
    /// Number of epochs trained.
    pub epochs_trained: usize,
    /// Number of batches per epoch.
    pub batches_per_epoch: usize,
    /// Total number of parameters.
    pub total_parameters: usize,
    /// Training configuration.
    pub config: TrainingConfig,
}
Training session for tracking training history.
Fields
history: HashMap<String, Vec<F>>
Training metrics history, keyed by metric name.
initial_learning_rate: F
Initial learning rate
epochs_trained: usize
Number of epochs trained
batches_per_epoch: usize
Number of batches per epoch
total_parameters: usize
Total number of parameters
config: TrainingConfig
Training configuration
Implementations
impl<F: Float + Debug + ScalarOperand> TrainingSession<F>
pub fn new(config: TrainingConfig) -> Self
Creates a new training session from the given training configuration.
pub fn add_metric(&mut self, name: &str, value: F)
Records `value` for the metric `name` in the training history.
pub fn get_metric(&self, name: &str) -> Option<&[F]>
Returns the recorded history for the metric `name`, if any.
Examples found in repository:
examples/image_classification_complete.rs (line 379)
233fn train_image_classifier() -> Result<()> {
234 println!("🚀 Starting Image Classification Training Example");
235 println!("{}", "=".repeat(60));
236
237 // Set up reproducible random number generator
238 let mut rng = SmallRng::seed_from_u64(42);
239
240 // Dataset parameters
241 let num_samples = 1000;
242 let num_classes = 5;
243 let image_size = (32, 32);
244 let input_channels = 3;
245
246 println!("📊 Dataset Configuration:");
247 println!(" - Samples: {}", num_samples);
248 println!(" - Classes: {}", num_classes);
249 println!(" - Image Size: {}x{}", image_size.0, image_size.1);
250 println!(" - Channels: {}", input_channels);
251
252 // Create synthetic dataset
253 println!("\n🔄 Creating synthetic dataset...");
254 let dataset = SyntheticImageDataset::new(num_samples, num_classes, image_size);
255 let (train_dataset, val_dataset) = dataset.train_val_split(0.2);
256
257 println!(" - Training samples: {}", train_dataset.len());
258 println!(" - Validation samples: {}", val_dataset.len());
259
260 // Build model
261 println!("\n🏗️ Building CNN model...");
262 let model = build_cnn_model(input_channels, num_classes, &mut rng)?;
263
264 // Count parameters
265 let total_params: usize = model.params().iter().map(|p| p.len()).sum();
266 println!(" - Model layers: {}", model.len());
267 println!(" - Total parameters: {}", total_params);
268
269 // Create training configuration
270 let config = create_training_config();
271 println!("\n⚙️ Training Configuration:");
272 println!(" - Batch size: {}", config.batch_size);
273 println!(" - Learning rate: {}", config.learning_rate);
274 println!(" - Epochs: {}", config.epochs);
275 println!(
276 " - Validation split: {:.1}%",
277 config.validation.as_ref().unwrap().validation_split * 100.0
278 );
279
280 // Set up training components
281 let loss_fn = CrossEntropyLoss::new(1e-7);
282 let optimizer = Adam::new(config.learning_rate as f32, 0.9, 0.999, 1e-8);
283
284 // Create trainer
285 let mut trainer = Trainer::new(model, optimizer, loss_fn, config);
286
287 // Add callbacks
288 trainer.add_callback(Box::new(|| {
289 // Custom callback for additional logging
290 println!("🔄 Epoch completed");
291 Ok(())
292 }));
293
294 // Train the model
295 println!("\n🏋️ Starting training...");
296 println!("{}", "-".repeat(40));
297
298 let training_session = trainer.train(&train_dataset, Some(&val_dataset))?;
299
300 println!("\n✅ Training completed!");
301 println!(" - Epochs trained: {}", training_session.epochs_trained);
302 println!(
303 " - Final learning rate: {:.6}",
304 training_session.initial_learning_rate
305 );
306
307 // Evaluate on validation set
308 println!("\n📊 Final Evaluation:");
309 let val_metrics = trainer.validate(&val_dataset)?;
310
311 for (metric, value) in &val_metrics {
312 println!(" - {}: {:.4}", metric, value);
313 }
314
315 // Test predictions on a few samples
316 println!("\n🔍 Sample Predictions:");
317 let sample_indices = vec![0, 1, 2, 3, 4];
318
319 // Manually collect batch since get_batch is not part of Dataset trait
320 let mut batch_images = Vec::new();
321 let mut batch_targets = Vec::new();
322
323 for &idx in &sample_indices {
324 let (img, target) = val_dataset.get(idx)?;
325 batch_images.push(img);
326 batch_targets.push(target);
327 }
328
329 // Concatenate into batch arrays
330 let sample_images = ndarray::concatenate(
331 Axis(0),
332 &batch_images.iter().map(|a| a.view()).collect::<Vec<_>>(),
333 )?;
334 let sample_targets = ndarray::concatenate(
335 Axis(0),
336 &batch_targets.iter().map(|a| a.view()).collect::<Vec<_>>(),
337 )?;
338
339 let model = trainer.get_model();
340 let predictions = model.forward(&sample_images)?;
341
342 for i in 0..sample_indices.len() {
343 let pred_row = predictions.slice(s![i, ..]);
344 let target_row = sample_targets.slice(s![i, ..]);
345
346 let pred_class = pred_row
347 .iter()
348 .enumerate()
349 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
350 .map(|(i, _)| i)
351 .unwrap_or(0);
352
353 let target_class = target_row
354 .iter()
355 .enumerate()
356 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
357 .map(|(i, _)| i)
358 .unwrap_or(0);
359
360 let confidence = pred_row[pred_class];
361
362 println!(
363 " Sample {}: Predicted={}, Actual={}, Confidence={:.3}",
364 i + 1,
365 pred_class,
366 target_class,
367 confidence
368 );
369 }
370
371 // Calculate overall accuracy
372 let overall_predictions = trainer.get_model().forward(&sample_images)?;
373 let accuracy = calculate_accuracy(&overall_predictions, &sample_targets);
374 println!("\n🎯 Sample Accuracy: {:.2}%", accuracy * 100.0);
375
376 // Model summary
377 println!("\n📋 Training Summary:");
378 let session = trainer.get_session();
379 if let Some(loss_history) = session.get_metric("loss") {
380 if !loss_history.is_empty() {
381 println!(" - Initial loss: {:.4}", loss_history[0]);
382 println!(
383 " - Final loss: {:.4}",
384 loss_history[loss_history.len() - 1]
385 );
386 }
387 }
388
389 if let Some(val_loss_history) = session.get_metric("val_loss") {
390 if !val_loss_history.is_empty() {
391 println!(
392 " - Final validation loss: {:.4}",
393 val_loss_history[val_loss_history.len() - 1]
394 );
395 }
396 }
397
398 println!("\n🎉 Image classification example completed successfully!");
399
400 Ok(())
401}
Trait Implementations
impl<F: Clone + Float + Debug + ScalarOperand> Clone for TrainingSession<F>
fn clone(&self) -> TrainingSession<F>
Returns a duplicate of the value. Read more
1.0.0 · const fn clone_from(&mut self, source: &Self)
Performs copy-assignment from `source`. Read more
Auto Trait Implementations
impl<F> Freeze for TrainingSession<F>
where
    F: Freeze,
impl<F> RefUnwindSafe for TrainingSession<F>
where
    F: RefUnwindSafe,
impl<F> Send for TrainingSession<F>
where
    F: Send,
impl<F> Sync for TrainingSession<F>
where
    F: Sync,
impl<F> Unpin for TrainingSession<F>
where
    F: Unpin,
impl<F> UnwindSafe for TrainingSession<F>
where
    F: UnwindSafe,
Blanket Implementations
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value. Read more
impl<T> CloneToUninit for T
where
    T: Clone,
impl<T> IntoEither for T
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts `self` into a `Left` variant of `Either<Self, Self>` if `into_left` is `true`. Converts `self` into a `Right` variant of `Either<Self, Self>` otherwise. Read more
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts `self` into a `Left` variant of `Either<Self, Self>` if `into_left(&self)` returns `true`. Converts `self` into a `Right` variant of `Either<Self, Self>` otherwise. Read more