pub struct Trainer<F: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync + Display> { /* private fields */ }
Trainer for a neural network model. It owns the model, optimizer, loss function, and training configuration, and exposes train and validate along with accessors for each component.
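A minimal end-to-end sketch of the typical workflow, assembled from the repository examples below. This is an illustration, not a verified snippet: the regression helpers generate_regression_dataset and create_regression_model come from examples/advanced_training_example.rs, and imports are elided as in those examples.

fn quick_start() -> Result<()> {
    // Synthetic data: 1000 training and 200 validation samples, 10 features, 2 outputs.
    let dataset = generate_regression_dataset::<f32>(1000, 10, 2)?;
    let val_dataset = generate_regression_dataset::<f32>(200, 10, 2)?;

    // Model, optimizer, and loss function.
    let model = create_regression_model::<f32>(10, 64, 2)?;
    let optimizer = Adam::new(0.001_f32, 0.9_f32, 0.999_f32, 1e-8_f32);
    let loss_fn = MSELoss::new();

    let config = TrainingConfig {
        batch_size: 32,
        epochs: 5,
        learning_rate: 0.001,
        shuffle: true,
        verbose: 1,
        validation: Some(ValidationSettings {
            enabled: true,
            validation_split: 0.0, // a separate validation dataset is passed to train below
            batch_size: 32,
            num_workers: 0,
        }),
        gradient_accumulation: None,
        mixed_precision: None,
        num_workers: 0,
    };

    // Train, then inspect the session and the validation metrics.
    let mut trainer = Trainer::new(model, optimizer, loss_fn, config);
    let session = trainer.train(&dataset, Some(&val_dataset))?;
    println!("Epochs trained: {}", session.epochs_trained);

    let metrics = trainer.validate(&val_dataset)?;
    for (metric, value) in &metrics {
        println!("{}: {:.4}", metric, value);
    }
    Ok(())
}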
Implementations§
Source§
impl<F: Float + Debug + ScalarOperand + FromPrimitive + Display + Send + Sync> Trainer<F>
Source
pub fn new<L, O, LF>(
    model: L,
    optimizer: O,
    loss_fn: LF,
    config: TrainingConfig,
) -> Self
Create a new trainer from a model, an optimizer, a loss function, and a training configuration. The generic parameters correspond to the accessors below: L is the model (Layer<F>), O the optimizer (Optimizer<F>), and LF the loss function (Loss<F>).
Examples found in repository
examples/text_classification_complete.rs (line 540)
458 fn train_text_classifier() -> StdResult<()> {
459     println!("🚀 Starting Text Classification Training Example");
460 println!("{}", "=".repeat(60));
461
462 let mut rng = SmallRng::seed_from_u64(42);
463
464 // Dataset parameters
465 let num_samples = 800;
466 let num_classes = 3;
467 let max_length = 20;
468 let embedding_dim = 64;
469 let hidden_dim = 128;
470
471     println!("📊 Dataset Configuration:");
472 println!(" - Samples: {}", num_samples);
473 println!(
474 " - Classes: {} (Positive, Negative, Neutral)",
475 num_classes
476 );
477 println!(" - Max sequence length: {}", max_length);
478 println!(" - Embedding dimension: {}", embedding_dim);
479
480 // Create synthetic text dataset
481     println!("\n🔄 Creating synthetic text dataset...");
482 let dataset = TextDataset::create_synthetic_dataset(num_samples, num_classes, max_length);
483 let (train_dataset, val_dataset) = dataset.train_val_split(0.2);
484
485 println!(" - Vocabulary size: {}", dataset.vocab.vocab_size);
486 println!(" - Training samples: {}", train_dataset.len());
487 println!(" - Validation samples: {}", val_dataset.len());
488
489 // Show some example texts
490     println!("\n📝 Sample texts:");
491 for i in 0..3.min(train_dataset.texts.len()) {
492 println!(
493 " [Class {}]: {}",
494 train_dataset.labels[i], train_dataset.texts[i]
495 );
496 }
497
498 // Build model
499     println!("\n🏗️ Building text classification model...");
500 let model = build_text_model(
501 dataset.vocab.vocab_size,
502 embedding_dim,
503 hidden_dim,
504 num_classes,
505 max_length,
506 &mut rng,
507 )?;
508
509 let total_params: usize = model.params().iter().map(|p| p.len()).sum();
510 println!(" - Model layers: {}", model.len());
511 println!(" - Total parameters: {}", total_params);
512
513 // Training configuration
514 let config = TrainingConfig {
515 batch_size: 16,
516 epochs: 30,
517 learning_rate: 0.001,
518 shuffle: true,
519 verbose: 1,
520 validation: Some(ValidationSettings {
521 enabled: true,
522 validation_split: 0.2,
523 batch_size: 32,
524 num_workers: 0,
525 }),
526 gradient_accumulation: None,
527 mixed_precision: None,
528 num_workers: 0,
529 };
530
531     println!("\n⚙️ Training Configuration:");
532 println!(" - Batch size: {}", config.batch_size);
533 println!(" - Learning rate: {}", config.learning_rate);
534 println!(" - Epochs: {}", config.epochs);
535
536 // Set up training
537 let loss_fn = CrossEntropyLoss::new(1e-7);
538 let optimizer = Adam::new(config.learning_rate as f32, 0.9, 0.999, 1e-8);
539
540 let mut trainer = Trainer::new(model, optimizer, loss_fn, config);
541
542 // Train the model
543     println!("\n🏋️ Starting training...");
544 println!("{}", "-".repeat(40));
545
546 let training_session = trainer.train(&train_dataset, Some(&val_dataset))?;
547
548     println!("\n✅ Training completed!");
549 println!(" - Epochs trained: {}", training_session.epochs_trained);
550
551 // Evaluate model
552     println!("\n📊 Final Evaluation:");
553 let val_metrics = trainer.validate(&val_dataset)?;
554
555 for (metric, value) in &val_metrics {
556 println!(" - {}: {:.4}", metric, value);
557 }
558
559 // Test on sample texts
560     println!("\n🔍 Sample Predictions:");
561 let sample_indices = vec![0, 1, 2, 3, 4];
562
563 // Manually collect batch since get_batch is not part of Dataset trait
564 let mut batch_tokens = Vec::new();
565 let mut batch_targets = Vec::new();
566
567 for &idx in &sample_indices {
568 let (tokens, targets) = val_dataset.get(idx)?;
569 batch_tokens.push(tokens);
570 batch_targets.push(targets);
571 }
572
573 // Concatenate into batch arrays
574 let sample_tokens = ndarray::concatenate(
575 ndarray::Axis(0),
576 &batch_tokens.iter().map(|a| a.view()).collect::<Vec<_>>(),
577 )?;
578 let sample_targets = ndarray::concatenate(
579 ndarray::Axis(0),
580 &batch_targets.iter().map(|a| a.view()).collect::<Vec<_>>(),
581 )?;
582
583 let model = trainer.get_model();
584 let predictions = model.forward(&sample_tokens)?;
585
586 let class_names = ["Positive", "Negative", "Neutral"];
587
588 for i in 0..sample_indices.len().min(val_dataset.texts.len()) {
589 let pred_row = predictions.slice(s![i, ..]);
590 let target_row = sample_targets.slice(s![i, ..]);
591
592 let pred_class = pred_row
593 .iter()
594 .enumerate()
595 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
596 .map(|(i, _)| i)
597 .unwrap_or(0);
598
599 let true_class = target_row
600 .iter()
601 .enumerate()
602 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
603 .map(|(i, _)| i)
604 .unwrap_or(0);
605
606 let confidence = pred_row[pred_class];
607
608 if sample_indices[i] < val_dataset.texts.len() {
609 println!(" Text: \"{}\"", val_dataset.texts[sample_indices[i]]);
610 println!(
611 " Predicted: {} (confidence: {:.3})",
612 class_names[pred_class], confidence
613 );
614 println!(" Actual: {}", class_names[true_class]);
615 println!();
616 }
617 }
618
619 // Calculate detailed metrics
620 let detailed_metrics = calculate_text_metrics(&predictions, &sample_targets);
621     println!("📈 Detailed Metrics:");
622 for (metric, value) in &detailed_metrics {
623 println!(" - {}: {:.4}", metric, value);
624 }
625
626 Ok(())
627 }
More examples
examples/image_classification_complete.rs (line 285)
233 fn train_image_classifier() -> Result<()> {
234     println!("🚀 Starting Image Classification Training Example");
235 println!("{}", "=".repeat(60));
236
237 // Set up reproducible random number generator
238 let mut rng = SmallRng::seed_from_u64(42);
239
240 // Dataset parameters
241 let num_samples = 1000;
242 let num_classes = 5;
243 let image_size = (32, 32);
244 let input_channels = 3;
245
246     println!("📊 Dataset Configuration:");
247 println!(" - Samples: {}", num_samples);
248 println!(" - Classes: {}", num_classes);
249 println!(" - Image Size: {}x{}", image_size.0, image_size.1);
250 println!(" - Channels: {}", input_channels);
251
252 // Create synthetic dataset
253     println!("\n🔄 Creating synthetic dataset...");
254 let dataset = SyntheticImageDataset::new(num_samples, num_classes, image_size);
255 let (train_dataset, val_dataset) = dataset.train_val_split(0.2);
256
257 println!(" - Training samples: {}", train_dataset.len());
258 println!(" - Validation samples: {}", val_dataset.len());
259
260 // Build model
261     println!("\n🏗️ Building CNN model...");
262 let model = build_cnn_model(input_channels, num_classes, &mut rng)?;
263
264 // Count parameters
265 let total_params: usize = model.params().iter().map(|p| p.len()).sum();
266 println!(" - Model layers: {}", model.len());
267 println!(" - Total parameters: {}", total_params);
268
269 // Create training configuration
270 let config = create_training_config();
271     println!("\n⚙️ Training Configuration:");
272 println!(" - Batch size: {}", config.batch_size);
273 println!(" - Learning rate: {}", config.learning_rate);
274 println!(" - Epochs: {}", config.epochs);
275 println!(
276 " - Validation split: {:.1}%",
277 config.validation.as_ref().unwrap().validation_split * 100.0
278 );
279
280 // Set up training components
281 let loss_fn = CrossEntropyLoss::new(1e-7);
282 let optimizer = Adam::new(config.learning_rate as f32, 0.9, 0.999, 1e-8);
283
284 // Create trainer
285 let mut trainer = Trainer::new(model, optimizer, loss_fn, config);
286
287 // Add callbacks
288 trainer.add_callback(Box::new(|| {
289 // Custom callback for additional logging
290         println!("📈 Epoch completed");
291 Ok(())
292 }));
293
294 // Train the model
295     println!("\n🏋️ Starting training...");
296 println!("{}", "-".repeat(40));
297
298 let training_session = trainer.train(&train_dataset, Some(&val_dataset))?;
299
300     println!("\n✅ Training completed!");
301 println!(" - Epochs trained: {}", training_session.epochs_trained);
302 println!(
303 " - Final learning rate: {:.6}",
304 training_session.initial_learning_rate
305 );
306
307 // Evaluate on validation set
308     println!("\n📊 Final Evaluation:");
309 let val_metrics = trainer.validate(&val_dataset)?;
310
311 for (metric, value) in &val_metrics {
312 println!(" - {}: {:.4}", metric, value);
313 }
314
315 // Test predictions on a few samples
316     println!("\n🔍 Sample Predictions:");
317 let sample_indices = vec![0, 1, 2, 3, 4];
318
319 // Manually collect batch since get_batch is not part of Dataset trait
320 let mut batch_images = Vec::new();
321 let mut batch_targets = Vec::new();
322
323 for &idx in &sample_indices {
324 let (img, target) = val_dataset.get(idx)?;
325 batch_images.push(img);
326 batch_targets.push(target);
327 }
328
329 // Concatenate into batch arrays
330 let sample_images = ndarray::concatenate(
331 Axis(0),
332 &batch_images.iter().map(|a| a.view()).collect::<Vec<_>>(),
333 )?;
334 let sample_targets = ndarray::concatenate(
335 Axis(0),
336 &batch_targets.iter().map(|a| a.view()).collect::<Vec<_>>(),
337 )?;
338
339 let model = trainer.get_model();
340 let predictions = model.forward(&sample_images)?;
341
342 for i in 0..sample_indices.len() {
343 let pred_row = predictions.slice(s![i, ..]);
344 let target_row = sample_targets.slice(s![i, ..]);
345
346 let pred_class = pred_row
347 .iter()
348 .enumerate()
349 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
350 .map(|(i, _)| i)
351 .unwrap_or(0);
352
353 let target_class = target_row
354 .iter()
355 .enumerate()
356 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
357 .map(|(i, _)| i)
358 .unwrap_or(0);
359
360 let confidence = pred_row[pred_class];
361
362 println!(
363 " Sample {}: Predicted={}, Actual={}, Confidence={:.3}",
364 i + 1,
365 pred_class,
366 target_class,
367 confidence
368 );
369 }
370
371 // Calculate overall accuracy
372 let overall_predictions = trainer.get_model().forward(&sample_images)?;
373 let accuracy = calculate_accuracy(&overall_predictions, &sample_targets);
374     println!("\n🎯 Sample Accuracy: {:.2}%", accuracy * 100.0);
375
376 // Model summary
377     println!("\n📋 Training Summary:");
378 let session = trainer.get_session();
379 if let Some(loss_history) = session.get_metric("loss") {
380 if !loss_history.is_empty() {
381 println!(" - Initial loss: {:.4}", loss_history[0]);
382 println!(
383 " - Final loss: {:.4}",
384 loss_history[loss_history.len() - 1]
385 );
386 }
387 }
388
389 if let Some(val_loss_history) = session.get_metric("val_loss") {
390 if !val_loss_history.is_empty() {
391 println!(
392 " - Final validation loss: {:.4}",
393 val_loss_history[val_loss_history.len() - 1]
394 );
395 }
396 }
397
398     println!("\n🎉 Image classification example completed successfully!");
399
400 Ok(())
401 }
examples/advanced_training_example.rs (line 169)
124 fn main() -> Result<()> {
125 println!("Advanced Training Examples");
126 println!("-------------------------");
127
128 // 1. Gradient Accumulation
129 println!("\n1. Training with Gradient Accumulation:");
130
131 // Generate synthetic dataset
132 let _dataset = generate_regression_dataset::<f32>(1000, 10, 2)?;
133 let _val_dataset = generate_regression_dataset::<f32>(200, 10, 2)?;
134
135 // Create model, optimizer, and loss function
136 let model = create_regression_model::<f32>(10, 64, 2)?;
137 let optimizer = Adam::new(0.001_f32, 0.9_f32, 0.999_f32, 1e-8_f32);
138 let loss_fn = MSELoss::new();
139
140 // Create gradient accumulation config
141 let ga_config = GradientAccumulationConfig {
142 accumulation_steps: 4,
143 average_gradients: true,
144 zero_gradients_after_update: true,
145 clip_gradients: true,
146 max_gradient_norm: Some(1.0),
147 log_gradient_stats: true,
148 };
149
150 // Create training config
151 let training_config = TrainingConfig {
152 batch_size: 32,
153 shuffle: true,
154 num_workers: 0,
155 learning_rate: 0.001,
156 epochs: 5,
157 verbose: 1,
158 validation: Some(ValidationSettings {
159 enabled: true,
160 validation_split: 0.0, // Use separate validation dataset
161 batch_size: 32,
162 num_workers: 0,
163 }),
164 gradient_accumulation: Some(ga_config),
165 mixed_precision: None,
166 };
167
168 // Create trainer
169 let _trainer = Trainer::new(model, optimizer, loss_fn, training_config);
170
171 // Note: To properly use callbacks, we would need to implement the appropriate trait interfaces
172 // Here we're simplifying for the example
173
174 // We'll use a simple closure to describe the early stopping callback
175 println!("Note: We would add callbacks like EarlyStopping and ModelCheckpoint here");
176 println!("For example: EarlyStopping with patience=5, min_delta=0.001");
177
178 // Create learning rate scheduler - we'll just demonstrate its usage
179 let _lr_scheduler = CosineAnnealingScheduler::new(0.001_f32, 0.0001_f32);
180 println!("Using CosineAnnealingScheduler with initial_lr=0.001, min_lr=0.0001");
181
182 // Train model
183 println!("\nTraining model with gradient accumulation...");
184
185 // For demonstration purposes, show what would happen with real training
186 println!("Would execute: trainer.train(&dataset, Some(&val_dataset))?");
187
188 // Since we're not actually training, just show example output
189 println!("\nExample of training output that would be shown:");
190 println!("Training completed in 3 epochs");
191 println!("Final loss: 0.0124");
192 println!("Final validation loss: 0.0156");
193
194 // 2. Manual Gradient Accumulation
195 println!("\n2. Manual Gradient Accumulation:");
196
197 // Create model, optimizer, and loss function
198 let model = create_regression_model::<f32>(10, 64, 2)?;
199 let _optimizer = Adam::new(0.001_f32, 0.9_f32, 0.999_f32, 1e-8_f32);
200 let _loss_fn = MSELoss::new();
201
202 // Create gradient accumulator
203 let mut accumulator = GradientAccumulator::new(GradientAccumulationConfig {
204 accumulation_steps: 4,
205 average_gradients: true,
206 zero_gradients_after_update: true,
207 clip_gradients: false,
208 max_gradient_norm: None,
209 log_gradient_stats: false,
210 });
211
212 // Initialize accumulator
213 accumulator.initialize(&model)?;
214
215 // We would use a DataLoader in real code, but here we'll simulate it
216 println!("Creating data loader with batch_size=32, shuffle=true");
217
218 println!("\nTraining for 1 epoch with manual gradient accumulation...");
219
220 let mut total_loss = 0.0_f32;
221 let mut processed_batches = 0;
222
223 // Train for one epoch
224 // This is a simplified example - in practice you would iterate through DataLoader batches
225 // Simulated loop to demonstrate the concept:
226 let total_batches = 5;
227 for batch_idx in 0..total_batches {
228 // In a real implementation we would get inputs and targets from data_loader
229 println!("Batch {} - Accumulating gradients...", batch_idx + 1);
230
231 // Simulate a loss value
232 let loss = 0.1 * (batch_idx as f32 + 1.0).powf(-0.5);
233 total_loss += loss;
234 processed_batches += 1;
235
236 // Simulate gradient stats
237 println!(
238 "Batch {} - Gradient stats: min={:.4}, max={:.4}, mean={:.4}, norm={:.4}",
239 batch_idx + 1,
240 -0.05 * (batch_idx as f32 + 1.0).powf(-0.5),
241 0.05 * (batch_idx as f32 + 1.0).powf(-0.5),
242 0.01 * (batch_idx as f32 + 1.0).powf(-0.5),
243 0.2 * (batch_idx as f32 + 1.0).powf(-0.5)
244 );
245
246 // Update if needed - this is conceptual
247 if (batch_idx + 1) % 4 == 0 || batch_idx == total_batches - 1 {
248 println!(
249 "Applying accumulated gradients after {} batches",
250                 (batch_idx % 4) + 1 // batches accumulated since the last update
251 );
252 // In a real implementation we would apply gradients:
253 // accumulator.apply_gradients(&mut model, &mut optimizer)?;
254 }
255
256 // Early stopping for example
257 if batch_idx >= 10 {
258 break;
259 }
260 }
261
262 if processed_batches > 0 {
263 println!("Average loss: {:.4}", total_loss / processed_batches as f32);
264 }
265
266 // 3. Mixed Precision (not fully implemented, pseudocode)
267 println!("\n3. Mixed Precision Training (Pseudocode):");
268
269 println!(
270 "// Create mixed precision config
271 let mp_config = MixedPrecisionConfig {{
272 dynamic_loss_scaling: true,
273 initial_loss_scale: 65536.0,
274 scale_factor: 2.0,
275 scale_window: 2000,
276 min_loss_scale: 1.0,
277 max_loss_scale: 2_f64.powi(24),
278 verbose: true,
279}};
280
281 // Create high precision and low precision models
282 let high_precision_model = create_regression_model::<f32>(10, 64, 2)?;
283 let low_precision_model = create_regression_model::<f16>(10, 64, 2)?;
284
285 // Create mixed precision model
286 let mut mixed_model = MixedPrecisionModel::new(
287 high_precision_model,
288 low_precision_model,
289 mp_config,
290 )?;
291
292 // Create optimizer and loss function
293 let mut optimizer = Adam::new(0.001);
294 let loss_fn = MSELoss::new();
295
296 // Train for one epoch
297 mixed_model.train_epoch(
298 &mut optimizer,
299 &dataset,
300 &loss_fn,
301 32,
302 true,
303 )?;"
304 );
305
306 // 4. Gradient Clipping
307 println!("\n4. Gradient Clipping:");
308
309 // Create model, optimizer, and loss function
310 let model = create_regression_model::<f32>(10, 64, 2)?;
311 let optimizer = Adam::new(0.001_f32, 0.9_f32, 0.999_f32, 1e-8_f32);
312 let loss_fn = MSELoss::new();
313
314 // Create training config - we need two separate instances
315 let gradient_clipping_config = TrainingConfig {
316 batch_size: 32,
317 shuffle: true,
318 num_workers: 0,
319 learning_rate: 0.001,
320 epochs: 5,
321 verbose: 1,
322 validation: Some(ValidationSettings {
323 enabled: true,
324 validation_split: 0.0, // Use separate validation dataset
325 batch_size: 32,
326 num_workers: 0,
327 }),
328 gradient_accumulation: None,
329 mixed_precision: None,
330 };
331
332 // Create a separate configuration for the value clipping example
333 let value_clipping_config = TrainingConfig {
334 batch_size: 32,
335 shuffle: true,
336 num_workers: 0,
337 learning_rate: 0.001,
338 epochs: 5,
339 verbose: 1,
340 validation: Some(ValidationSettings {
341 enabled: true,
342 validation_split: 0.0, // Use separate validation dataset
343 batch_size: 32,
344 num_workers: 0,
345 }),
346 gradient_accumulation: None,
347 mixed_precision: None,
348 };
349
350 // Create trainer
351 let _trainer = Trainer::new(model, optimizer, loss_fn, gradient_clipping_config);
352
353 // Instead of adding callbacks directly, we'll just demonstrate the concept
354 println!("If callbacks were fully implemented, we would add gradient clipping:");
355 println!("GradientClipping::by_global_norm(1.0_f32, true) // Max norm, log_stats");
356
357 println!("\nTraining model with gradient clipping by global norm...");
358
359 // Train model for a few epochs
360 let _dataset_small = generate_regression_dataset::<f32>(500, 10, 2)?;
361 let _val_dataset_small = generate_regression_dataset::<f32>(100, 10, 2)?;
362 println!("Would train the model with dataset_small and val_dataset_small");
363 // In a real implementation:
364 // let session = trainer.train(&dataset_small, Some(&val_dataset_small))?;
365
366 // Since we're not actually training, just show example output
367 println!("\nExample of training output that would be shown:");
368 println!("Training completed in 3 epochs");
369 println!("Final loss: 0.0124");
370 println!("Final validation loss: 0.0156");
371
372 // Example with value clipping
373 println!("\nExample with gradient clipping by value:");
374
375 // Create model and trainer with value clipping
376 let model = create_regression_model::<f32>(10, 64, 2)?;
377 let optimizer = Adam::new(0.001_f32, 0.9_f32, 0.999_f32, 1e-8_f32);
378     let _trainer = Trainer::new(model, optimizer, MSELoss::new(), value_clipping_config); // fresh loss function: the earlier one was moved into the first trainer
379
380 // Instead of actual callbacks, show how we would use them
381 println!("For gradient clipping by value, we would use:");
382 println!("GradientClipping::by_value(0.5_f32, true) // Max value, log_stats");
383
384 println!("\nDemonstration of how to set up gradient clipping by value:");
385 println!("trainer.add_callback(Box::new(GradientClipping::by_value(");
386 println!(" 0.5_f32, // Max value");
387 println!(" true, // Log stats");
388 println!(")));");
389
390 // Demonstrate the training utilities
391 println!("\nAdvanced Training Examples Completed Successfully!");
392
393 Ok(())
394 }
Source
pub fn add_callback(
    &mut self,
    callback: Box<dyn Fn() -> Result<()> + Send + Sync>,
)
Add a callback that the trainer invokes during training; the repository example uses one for per-epoch logging.
Examples found in repository
examples/image_classification_complete.rs (lines 288-292)
The full program appears under new above; the cited lines are:

288     trainer.add_callback(Box::new(|| {
289         // Custom callback for additional logging
290         println!("📈 Epoch completed");
291         Ok(())
292     }));
Source
pub fn train<D: Dataset<F> + Clone>(
    &mut self,
    dataset: &D,
    validation_dataset: Option<&D>,
) -> Result<TrainingSession<F>>
Train the model on a dataset, optionally evaluating on a separate validation dataset, and return the completed training session.
Examples found in repository
examples/text_classification_complete.rs (line 546)
The full program appears under new above; the cited line is:

546     let training_session = trainer.train(&train_dataset, Some(&val_dataset))?;
More examples
examples/image_classification_complete.rs (line 298) — shown in full under new above; the cited line is:

298     let training_session = trainer.train(&train_dataset, Some(&val_dataset))?;
Source
pub fn validate<D: Dataset<F>>(
    &mut self,
    dataset: &D,
) -> Result<HashMap<String, F>>
Evaluate the model on a dataset and return a map from metric name to value.
Examples found in repository
examples/text_classification_complete.rs (line 553)
The full program appears under new above; the cited lines are:

553     let val_metrics = trainer.validate(&val_dataset)?;
554
555     for (metric, value) in &val_metrics {
556         println!(" - {}: {:.4}", metric, value);
557     }
More examples
examples/image_classification_complete.rs (line 309) — shown in full under new above; the cited line is:

309     let val_metrics = trainer.validate(&val_dataset)?;
Source
pub fn get_model(&self) -> &dyn Layer<F>
Get the model
Examples found in repository
examples/text_classification_complete.rs (line 583)
The full program appears under new above; the cited lines are:

583     let model = trainer.get_model();
584     let predictions = model.forward(&sample_tokens)?;
More examples
examples/image_classification_complete.rs (line 339) — shown in full under new above; the cited lines are:

339     let model = trainer.get_model();
340     let predictions = model.forward(&sample_images)?;
Sourcepub fn get_model_mut(&mut self) -> &mut dyn Layer<F>
pub fn get_model_mut(&mut self) -> &mut dyn Layer<F>
Get the model (mutable)
Sourcepub fn get_optimizer(&self) -> &dyn Optimizer<F>
pub fn get_optimizer(&self) -> &dyn Optimizer<F>
Get the optimizer
Sourcepub fn get_optimizer_mut(&mut self) -> &mut dyn Optimizer<F>
pub fn get_optimizer_mut(&mut self) -> &mut dyn Optimizer<F>
Get the optimizer (mutable)
Sourcepub fn get_loss_fn(&self) -> &dyn Loss<F>
pub fn get_loss_fn(&self) -> &dyn Loss<F>
Get the loss function
Sourcepub fn get_session(&self) -> &TrainingSession<F>
pub fn get_session(&self) -> &TrainingSession<F>
Get the current training session
Examples found in repository?
examples/image_classification_complete.rs (line 378)
Auto Trait Implementationsยง
impl<F> Freeze for Trainer<F> where F: Freeze
impl<F> !RefUnwindSafe for Trainer<F>
impl<F> Send for Trainer<F>
impl<F> Sync for Trainer<F>
impl<F> Unpin for Trainer<F> where F: Unpin
impl<F> !UnwindSafe for Trainer<F>
Blanket Implementationsยง
Sourceยงimpl<T> BorrowMut<T> for T where T: ?Sized
impl<T> BorrowMut<T> for T where T: ?Sized
Sourceยงfn borrow_mut(&mut self) -> &mut T
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value. Read more
Sourceยงimpl<T> IntoEither for T
impl<T> IntoEither for T
Sourceยงfn into_either(self, into_left: bool) -> Either<Self, Self>
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Sourceยงfn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more