Struct TrainingSession

Source
/// Training session for tracking training history.
///
/// Generic over the scalar type `F` used for the learning rate and for
/// every recorded metric value.
pub struct TrainingSession<F: Float + Debug + ScalarOperand> {
    /// Training metrics history: one series of `F` values per `String` key.
    pub history: HashMap<String, Vec<F>>,
    /// Initial learning rate.
    pub initial_learning_rate: F,
    /// Number of epochs trained.
    pub epochs_trained: usize,
    /// Number of batches per epoch.
    pub batches_per_epoch: usize,
    /// Total number of parameters.
    pub total_parameters: usize,
    /// Training configuration.
    pub config: TrainingConfig,
}
Expand description

Training session for tracking training history

Fields§

§history: HashMap<String, Vec<F>>

Training metrics history

§initial_learning_rate: F

Initial learning rate

§epochs_trained: usize

Number of epochs trained

§batches_per_epoch: usize

Number of batches per epoch

§total_parameters: usize

Total number of parameters

§config: TrainingConfig

Training configuration

Implementations§

Source§

impl<F: Float + Debug + ScalarOperand> TrainingSession<F>

Source

pub fn new(config: TrainingConfig) -> Self

Create a new training session

Source

pub fn add_metric(&mut self, name: &str, value: F)

Add a metric to history

Source

pub fn get_metric(&self, name: &str) -> Option<&[F]>

Get metric history

Examples found in repository?
examples/image_classification_complete.rs (line 379)
233fn train_image_classifier() -> Result<()> {
234    println!("🚀 Starting Image Classification Training Example");
235    println!("{}", "=".repeat(60));
236
237    // Set up reproducible random number generator
238    let mut rng = SmallRng::seed_from_u64(42);
239
240    // Dataset parameters
241    let num_samples = 1000;
242    let num_classes = 5;
243    let image_size = (32, 32);
244    let input_channels = 3;
245
246    println!("📊 Dataset Configuration:");
247    println!("   - Samples: {}", num_samples);
248    println!("   - Classes: {}", num_classes);
249    println!("   - Image Size: {}x{}", image_size.0, image_size.1);
250    println!("   - Channels: {}", input_channels);
251
252    // Create synthetic dataset
253    println!("\n🔄 Creating synthetic dataset...");
254    let dataset = SyntheticImageDataset::new(num_samples, num_classes, image_size);
255    let (train_dataset, val_dataset) = dataset.train_val_split(0.2);
256
257    println!("   - Training samples: {}", train_dataset.len());
258    println!("   - Validation samples: {}", val_dataset.len());
259
260    // Build model
261    println!("\n🏗️  Building CNN model...");
262    let model = build_cnn_model(input_channels, num_classes, &mut rng)?;
263
264    // Count parameters
265    let total_params: usize = model.params().iter().map(|p| p.len()).sum();
266    println!("   - Model layers: {}", model.len());
267    println!("   - Total parameters: {}", total_params);
268
269    // Create training configuration
270    let config = create_training_config();
271    println!("\n⚙️  Training Configuration:");
272    println!("   - Batch size: {}", config.batch_size);
273    println!("   - Learning rate: {}", config.learning_rate);
274    println!("   - Epochs: {}", config.epochs);
275    println!(
276        "   - Validation split: {:.1}%",
277        config.validation.as_ref().unwrap().validation_split * 100.0
278    );
279
280    // Set up training components
281    let loss_fn = CrossEntropyLoss::new(1e-7);
282    let optimizer = Adam::new(config.learning_rate as f32, 0.9, 0.999, 1e-8);
283
284    // Create trainer
285    let mut trainer = Trainer::new(model, optimizer, loss_fn, config);
286
287    // Add callbacks
288    trainer.add_callback(Box::new(|| {
289        // Custom callback for additional logging
290        println!("🔄 Epoch completed");
291        Ok(())
292    }));
293
294    // Train the model
295    println!("\n🏋️  Starting training...");
296    println!("{}", "-".repeat(40));
297
298    let training_session = trainer.train(&train_dataset, Some(&val_dataset))?;
299
300    println!("\n✅ Training completed!");
301    println!("   - Epochs trained: {}", training_session.epochs_trained);
302    println!(
303        "   - Initial learning rate: {:.6}",
304        training_session.initial_learning_rate
305    );
306
307    // Evaluate on validation set
308    println!("\n📊 Final Evaluation:");
309    let val_metrics = trainer.validate(&val_dataset)?;
310
311    for (metric, value) in &val_metrics {
312        println!("   - {}: {:.4}", metric, value);
313    }
314
315    // Test predictions on a few samples
316    println!("\n🔍 Sample Predictions:");
317    let sample_indices = vec![0, 1, 2, 3, 4];
318
319    // Manually collect batch since get_batch is not part of Dataset trait
320    let mut batch_images = Vec::new();
321    let mut batch_targets = Vec::new();
322
323    for &idx in &sample_indices {
324        let (img, target) = val_dataset.get(idx)?;
325        batch_images.push(img);
326        batch_targets.push(target);
327    }
328
329    // Concatenate into batch arrays
330    let sample_images = ndarray::concatenate(
331        Axis(0),
332        &batch_images.iter().map(|a| a.view()).collect::<Vec<_>>(),
333    )?;
334    let sample_targets = ndarray::concatenate(
335        Axis(0),
336        &batch_targets.iter().map(|a| a.view()).collect::<Vec<_>>(),
337    )?;
338
339    let model = trainer.get_model();
340    let predictions = model.forward(&sample_images)?;
341
342    for i in 0..sample_indices.len() {
343        let pred_row = predictions.slice(s![i, ..]);
344        let target_row = sample_targets.slice(s![i, ..]);
345
346        let pred_class = pred_row
347            .iter()
348            .enumerate()
349            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
350            .map(|(i, _)| i)
351            .unwrap_or(0);
352
353        let target_class = target_row
354            .iter()
355            .enumerate()
356            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
357            .map(|(i, _)| i)
358            .unwrap_or(0);
359
360        let confidence = pred_row[pred_class];
361
362        println!(
363            "   Sample {}: Predicted={}, Actual={}, Confidence={:.3}",
364            i + 1,
365            pred_class,
366            target_class,
367            confidence
368        );
369    }
370
371    // Calculate accuracy over the sampled predictions only
372    let overall_predictions = trainer.get_model().forward(&sample_images)?;
373    let accuracy = calculate_accuracy(&overall_predictions, &sample_targets);
374    println!("\n🎯 Sample Accuracy: {:.2}%", accuracy * 100.0);
375
376    // Training summary from the session's recorded metric history
377    println!("\n📋 Training Summary:");
378    let session = trainer.get_session();
379    if let Some(loss_history) = session.get_metric("loss") {
380        if !loss_history.is_empty() {
381            println!("   - Initial loss: {:.4}", loss_history[0]);
382            println!(
383                "   - Final loss: {:.4}",
384                loss_history[loss_history.len() - 1]
385            );
386        }
387    }
388
389    if let Some(val_loss_history) = session.get_metric("val_loss") {
390        if !val_loss_history.is_empty() {
391            println!(
392                "   - Final validation loss: {:.4}",
393                val_loss_history[val_loss_history.len() - 1]
394            );
395        }
396    }
397
398    println!("\n🎉 Image classification example completed successfully!");
399
400    Ok(())
401}

Trait Implementations§

Source§

impl<F: Clone + Float + Debug + ScalarOperand> Clone for TrainingSession<F>

Source§

fn clone(&self) -> TrainingSession<F>

Returns a duplicate of the value. Read more
1.0.0 · Source§

const fn clone_from(&mut self, source: &Self)

Performs copy-assignment from source. Read more
Source§

impl<F: Debug + Float + Debug + ScalarOperand> Debug for TrainingSession<F>

Source§

fn fmt(&self, f: &mut Formatter<'_>) -> Result

Formats the value using the given formatter. Read more

Auto Trait Implementations§

§

impl<F> Freeze for TrainingSession<F>
where F: Freeze,

§

impl<F> RefUnwindSafe for TrainingSession<F>
where F: RefUnwindSafe,

§

impl<F> Send for TrainingSession<F>
where F: Send,

§

impl<F> Sync for TrainingSession<F>
where F: Sync,

§

impl<F> Unpin for TrainingSession<F>
where F: Unpin,

§

impl<F> UnwindSafe for TrainingSession<F>
where F: UnwindSafe,

Blanket Implementations§

Source§

impl<T> Any for T
where T: 'static + ?Sized,

Source§

fn type_id(&self) -> TypeId

Gets the TypeId of self. Read more
Source§

impl<T> Borrow<T> for T
where T: ?Sized,

Source§

fn borrow(&self) -> &T

Immutably borrows from an owned value. Read more
Source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value. Read more
Source§

impl<T> CloneToUninit for T
where T: Clone,

Source§

unsafe fn clone_to_uninit(&self, dest: *mut u8)

🔬This is a nightly-only experimental API. (clone_to_uninit)
Performs copy-assignment from self to dest. Read more
Source§

impl<T> From<T> for T

Source§

fn from(t: T) -> T

Returns the argument unchanged.

Source§

impl<T, U> Into<U> for T
where U: From<T>,

Source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source§

impl<T> IntoEither for T

Source§

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§

impl<T> Pointable for T

Source§

const ALIGN: usize

The alignment of the pointer.
Source§

type Init = T

The type for initializers.
Source§

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes an object with the given initializer. Read more
Source§

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer. Read more
Source§

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer. Read more
Source§

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer. Read more
Source§

impl<T> ToOwned for T
where T: Clone,

Source§

type Owned = T

The resulting type after obtaining ownership.
Source§

fn to_owned(&self) -> T

Creates owned data from borrowed data, usually by cloning. Read more
Source§

fn clone_into(&self, target: &mut T)

Uses borrowed data to replace owned data, usually by cloning. Read more
Source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source§

type Error = Infallible

The type returned in the event of a conversion error.
Source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
Source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
Source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.
Source§

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

Source§

fn vzip(self) -> V