Struct Trainer

Source
pub struct Trainer<F: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync + Display> { /* private fields */ }

Trainer for a neural network model
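
A minimal end-to-end sketch of the typical workflow (hedged: build_model, train_data, and val_data are hypothetical stand-ins for your own ParamLayer model and Dataset values, and TrainingConfig::default() assumes a Default impl; otherwise fill the config fields explicitly as in the examples below):

// Sketch only: build_model, train_data, and val_data are placeholders.
let model = build_model()?; // any ParamLayer<f32> + Send + Sync + 'static
let optimizer = Adam::new(0.001, 0.9, 0.999, 1e-8);
let loss_fn = CrossEntropyLoss::new(1e-7);
let config = TrainingConfig::default();

let mut trainer = Trainer::new(model, optimizer, loss_fn, config);
let session = trainer.train(&train_data, Some(&val_data))?;
println!("Epochs trained: {}", session.epochs_trained);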

Implementations

Source

impl<F: Float + Debug + ScalarOperand + FromPrimitive + Display + Send + Sync> Trainer<F>

Source

pub fn new<L, O, LF>(
    model: L,
    optimizer: O,
    loss_fn: LF,
    config: TrainingConfig,
) -> Self
where
    L: ParamLayer<F> + Send + Sync + 'static,
    O: Optimizer<F> + Send + Sync + 'static,
    LF: Loss<F> + Send + Sync + 'static,

Create a new trainer

Examples found in repository
examples/text_classification_complete.rs (line 540)
458fn train_text_classifier() -> StdResult<()> {
459    println!("📝 Starting Text Classification Training Example");
460    println!("{}", "=".repeat(60));
461
462    let mut rng = SmallRng::seed_from_u64(42);
463
464    // Dataset parameters
465    let num_samples = 800;
466    let num_classes = 3;
467    let max_length = 20;
468    let embedding_dim = 64;
469    let hidden_dim = 128;
470
471    println!("📊 Dataset Configuration:");
472    println!("   - Samples: {}", num_samples);
473    println!(
474        "   - Classes: {} (Positive, Negative, Neutral)",
475        num_classes
476    );
477    println!("   - Max sequence length: {}", max_length);
478    println!("   - Embedding dimension: {}", embedding_dim);
479
480    // Create synthetic text dataset
481    println!("\n🔄 Creating synthetic text dataset...");
482    let dataset = TextDataset::create_synthetic_dataset(num_samples, num_classes, max_length);
483    let (train_dataset, val_dataset) = dataset.train_val_split(0.2);
484
485    println!("   - Vocabulary size: {}", dataset.vocab.vocab_size);
486    println!("   - Training samples: {}", train_dataset.len());
487    println!("   - Validation samples: {}", val_dataset.len());
488
489    // Show some example texts
490    println!("\n📄 Sample texts:");
491    for i in 0..3.min(train_dataset.texts.len()) {
492        println!(
493            "   [Class {}]: {}",
494            train_dataset.labels[i], train_dataset.texts[i]
495        );
496    }
497
498    // Build model
499    println!("\n🏗️  Building text classification model...");
500    let model = build_text_model(
501        dataset.vocab.vocab_size,
502        embedding_dim,
503        hidden_dim,
504        num_classes,
505        max_length,
506        &mut rng,
507    )?;
508
509    let total_params: usize = model.params().iter().map(|p| p.len()).sum();
510    println!("   - Model layers: {}", model.len());
511    println!("   - Total parameters: {}", total_params);
512
513    // Training configuration
514    let config = TrainingConfig {
515        batch_size: 16,
516        epochs: 30,
517        learning_rate: 0.001,
518        shuffle: true,
519        verbose: 1,
520        validation: Some(ValidationSettings {
521            enabled: true,
522            validation_split: 0.2,
523            batch_size: 32,
524            num_workers: 0,
525        }),
526        gradient_accumulation: None,
527        mixed_precision: None,
528        num_workers: 0,
529    };
530
531    println!("\n⚙️  Training Configuration:");
532    println!("   - Batch size: {}", config.batch_size);
533    println!("   - Learning rate: {}", config.learning_rate);
534    println!("   - Epochs: {}", config.epochs);
535
536    // Set up training
537    let loss_fn = CrossEntropyLoss::new(1e-7);
538    let optimizer = Adam::new(config.learning_rate as f32, 0.9, 0.999, 1e-8);
539
540    let mut trainer = Trainer::new(model, optimizer, loss_fn, config);
541
542    // Train the model
543    println!("\n🏋️  Starting training...");
544    println!("{}", "-".repeat(40));
545
546    let training_session = trainer.train(&train_dataset, Some(&val_dataset))?;
547
548    println!("\n✅ Training completed!");
549    println!("   - Epochs trained: {}", training_session.epochs_trained);
550
551    // Evaluate model
552    println!("\n📊 Final Evaluation:");
553    let val_metrics = trainer.validate(&val_dataset)?;
554
555    for (metric, value) in &val_metrics {
556        println!("   - {}: {:.4}", metric, value);
557    }
558
559    // Test on sample texts
560    println!("\n🔍 Sample Predictions:");
561    let sample_indices = vec![0, 1, 2, 3, 4];
562
563    // Manually collect batch since get_batch is not part of Dataset trait
564    let mut batch_tokens = Vec::new();
565    let mut batch_targets = Vec::new();
566
567    for &idx in &sample_indices {
568        let (tokens, targets) = val_dataset.get(idx)?;
569        batch_tokens.push(tokens);
570        batch_targets.push(targets);
571    }
572
573    // Concatenate into batch arrays
574    let sample_tokens = ndarray::concatenate(
575        ndarray::Axis(0),
576        &batch_tokens.iter().map(|a| a.view()).collect::<Vec<_>>(),
577    )?;
578    let sample_targets = ndarray::concatenate(
579        ndarray::Axis(0),
580        &batch_targets.iter().map(|a| a.view()).collect::<Vec<_>>(),
581    )?;
582
583    let model = trainer.get_model();
584    let predictions = model.forward(&sample_tokens)?;
585
586    let class_names = ["Positive", "Negative", "Neutral"];
587
588    for i in 0..sample_indices.len().min(val_dataset.texts.len()) {
589        let pred_row = predictions.slice(s![i, ..]);
590        let target_row = sample_targets.slice(s![i, ..]);
591
592        let pred_class = pred_row
593            .iter()
594            .enumerate()
595            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
596            .map(|(i, _)| i)
597            .unwrap_or(0);
598
599        let true_class = target_row
600            .iter()
601            .enumerate()
602            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
603            .map(|(i, _)| i)
604            .unwrap_or(0);
605
606        let confidence = pred_row[pred_class];
607
608        if sample_indices[i] < val_dataset.texts.len() {
609            println!("   Text: \"{}\"", val_dataset.texts[sample_indices[i]]);
610            println!(
611                "   Predicted: {} (confidence: {:.3})",
612                class_names[pred_class], confidence
613            );
614            println!("   Actual: {}", class_names[true_class]);
615            println!();
616        }
617    }
618
619    // Calculate detailed metrics
620    let detailed_metrics = calculate_text_metrics(&predictions, &sample_targets);
621    println!("📈 Detailed Metrics:");
622    for (metric, value) in &detailed_metrics {
623        println!("   - {}: {:.4}", metric, value);
624    }
625
626    Ok(())
627}
More examples
examples/image_classification_complete.rs (line 285)
233fn train_image_classifier() -> Result<()> {
234    println!("🚀 Starting Image Classification Training Example");
235    println!("{}", "=".repeat(60));
236
237    // Set up reproducible random number generator
238    let mut rng = SmallRng::seed_from_u64(42);
239
240    // Dataset parameters
241    let num_samples = 1000;
242    let num_classes = 5;
243    let image_size = (32, 32);
244    let input_channels = 3;
245
246    println!("📊 Dataset Configuration:");
247    println!("   - Samples: {}", num_samples);
248    println!("   - Classes: {}", num_classes);
249    println!("   - Image Size: {}x{}", image_size.0, image_size.1);
250    println!("   - Channels: {}", input_channels);
251
252    // Create synthetic dataset
253    println!("\n🔄 Creating synthetic dataset...");
254    let dataset = SyntheticImageDataset::new(num_samples, num_classes, image_size);
255    let (train_dataset, val_dataset) = dataset.train_val_split(0.2);
256
257    println!("   - Training samples: {}", train_dataset.len());
258    println!("   - Validation samples: {}", val_dataset.len());
259
260    // Build model
261    println!("\n🏗️  Building CNN model...");
262    let model = build_cnn_model(input_channels, num_classes, &mut rng)?;
263
264    // Count parameters
265    let total_params: usize = model.params().iter().map(|p| p.len()).sum();
266    println!("   - Model layers: {}", model.len());
267    println!("   - Total parameters: {}", total_params);
268
269    // Create training configuration
270    let config = create_training_config();
271    println!("\n⚙️  Training Configuration:");
272    println!("   - Batch size: {}", config.batch_size);
273    println!("   - Learning rate: {}", config.learning_rate);
274    println!("   - Epochs: {}", config.epochs);
275    println!(
276        "   - Validation split: {:.1}%",
277        config.validation.as_ref().unwrap().validation_split * 100.0
278    );
279
280    // Set up training components
281    let loss_fn = CrossEntropyLoss::new(1e-7);
282    let optimizer = Adam::new(config.learning_rate as f32, 0.9, 0.999, 1e-8);
283
284    // Create trainer
285    let mut trainer = Trainer::new(model, optimizer, loss_fn, config);
286
287    // Add callbacks
288    trainer.add_callback(Box::new(|| {
289        // Custom callback for additional logging
290        println!("🔄 Epoch completed");
291        Ok(())
292    }));
293
294    // Train the model
295    println!("\n🏋️  Starting training...");
296    println!("{}", "-".repeat(40));
297
298    let training_session = trainer.train(&train_dataset, Some(&val_dataset))?;
299
300    println!("\n✅ Training completed!");
301    println!("   - Epochs trained: {}", training_session.epochs_trained);
302    println!(
303        "   - Initial learning rate: {:.6}",
304        training_session.initial_learning_rate
305    );
306
307    // Evaluate on validation set
308    println!("\n📊 Final Evaluation:");
309    let val_metrics = trainer.validate(&val_dataset)?;
310
311    for (metric, value) in &val_metrics {
312        println!("   - {}: {:.4}", metric, value);
313    }
314
315    // Test predictions on a few samples
316    println!("\n🔍 Sample Predictions:");
317    let sample_indices = vec![0, 1, 2, 3, 4];
318
319    // Manually collect batch since get_batch is not part of Dataset trait
320    let mut batch_images = Vec::new();
321    let mut batch_targets = Vec::new();
322
323    for &idx in &sample_indices {
324        let (img, target) = val_dataset.get(idx)?;
325        batch_images.push(img);
326        batch_targets.push(target);
327    }
328
329    // Concatenate into batch arrays
330    let sample_images = ndarray::concatenate(
331        Axis(0),
332        &batch_images.iter().map(|a| a.view()).collect::<Vec<_>>(),
333    )?;
334    let sample_targets = ndarray::concatenate(
335        Axis(0),
336        &batch_targets.iter().map(|a| a.view()).collect::<Vec<_>>(),
337    )?;
338
339    let model = trainer.get_model();
340    let predictions = model.forward(&sample_images)?;
341
342    for i in 0..sample_indices.len() {
343        let pred_row = predictions.slice(s![i, ..]);
344        let target_row = sample_targets.slice(s![i, ..]);
345
346        let pred_class = pred_row
347            .iter()
348            .enumerate()
349            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
350            .map(|(i, _)| i)
351            .unwrap_or(0);
352
353        let target_class = target_row
354            .iter()
355            .enumerate()
356            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
357            .map(|(i, _)| i)
358            .unwrap_or(0);
359
360        let confidence = pred_row[pred_class];
361
362        println!(
363            "   Sample {}: Predicted={}, Actual={}, Confidence={:.3}",
364            i + 1,
365            pred_class,
366            target_class,
367            confidence
368        );
369    }
370
371    // Calculate overall accuracy
372    let overall_predictions = trainer.get_model().forward(&sample_images)?;
373    let accuracy = calculate_accuracy(&overall_predictions, &sample_targets);
374    println!("\n🎯 Sample Accuracy: {:.2}%", accuracy * 100.0);
375
376    // Model summary
377    println!("\n📋 Training Summary:");
378    let session = trainer.get_session();
379    if let Some(loss_history) = session.get_metric("loss") {
380        if !loss_history.is_empty() {
381            println!("   - Initial loss: {:.4}", loss_history[0]);
382            println!(
383                "   - Final loss: {:.4}",
384                loss_history[loss_history.len() - 1]
385            );
386        }
387    }
388
389    if let Some(val_loss_history) = session.get_metric("val_loss") {
390        if !val_loss_history.is_empty() {
391            println!(
392                "   - Final validation loss: {:.4}",
393                val_loss_history[val_loss_history.len() - 1]
394            );
395        }
396    }
397
398    println!("\n🎉 Image classification example completed successfully!");
399
400    Ok(())
401}
examples/advanced_training_example.rs (line 169)
124fn main() -> Result<()> {
125    println!("Advanced Training Examples");
126    println!("-------------------------");
127
128    // 1. Gradient Accumulation
129    println!("\n1. Training with Gradient Accumulation:");
130
131    // Generate synthetic dataset
132    let _dataset = generate_regression_dataset::<f32>(1000, 10, 2)?;
133    let _val_dataset = generate_regression_dataset::<f32>(200, 10, 2)?;
134
135    // Create model, optimizer, and loss function
136    let model = create_regression_model::<f32>(10, 64, 2)?;
137    let optimizer = Adam::new(0.001_f32, 0.9_f32, 0.999_f32, 1e-8_f32);
138    let loss_fn = MSELoss::new();
139
140    // Create gradient accumulation config
141    let ga_config = GradientAccumulationConfig {
142        accumulation_steps: 4,
143        average_gradients: true,
144        zero_gradients_after_update: true,
145        clip_gradients: true,
146        max_gradient_norm: Some(1.0),
147        log_gradient_stats: true,
148    };
149
150    // Create training config
151    let training_config = TrainingConfig {
152        batch_size: 32,
153        shuffle: true,
154        num_workers: 0,
155        learning_rate: 0.001,
156        epochs: 5,
157        verbose: 1,
158        validation: Some(ValidationSettings {
159            enabled: true,
160            validation_split: 0.0, // Use separate validation dataset
161            batch_size: 32,
162            num_workers: 0,
163        }),
164        gradient_accumulation: Some(ga_config),
165        mixed_precision: None,
166    };
167
168    // Create trainer
169    let _trainer = Trainer::new(model, optimizer, loss_fn, training_config);
170
171    // Note: To properly use callbacks, we would need to implement the appropriate trait interfaces
172    // Here we're simplifying for the example
173
174    // We'll use a simple closure to describe the early stopping callback
175    println!("Note: We would add callbacks like EarlyStopping and ModelCheckpoint here");
176    println!("For example: EarlyStopping with patience=5, min_delta=0.001");
177
178    // Create learning rate scheduler - we'll just demonstrate its usage
179    let _lr_scheduler = CosineAnnealingScheduler::new(0.001_f32, 0.0001_f32);
180    println!("Using CosineAnnealingScheduler with initial_lr=0.001, min_lr=0.0001");
181
182    // Train model
183    println!("\nTraining model with gradient accumulation...");
184
185    // For demonstration purposes, show what would happen with real training
186    println!("Would execute: trainer.train(&dataset, Some(&val_dataset))?");
187
188    // Since we're not actually training, just show example output
189    println!("\nExample of training output that would be shown:");
190    println!("Training completed in 3 epochs");
191    println!("Final loss: 0.0124");
192    println!("Final validation loss: 0.0156");
193
194    // 2. Manual Gradient Accumulation
195    println!("\n2. Manual Gradient Accumulation:");
196
197    // Create model, optimizer, and loss function
198    let model = create_regression_model::<f32>(10, 64, 2)?;
199    let _optimizer = Adam::new(0.001_f32, 0.9_f32, 0.999_f32, 1e-8_f32);
200    let _loss_fn = MSELoss::new();
201
202    // Create gradient accumulator
203    let mut accumulator = GradientAccumulator::new(GradientAccumulationConfig {
204        accumulation_steps: 4,
205        average_gradients: true,
206        zero_gradients_after_update: true,
207        clip_gradients: false,
208        max_gradient_norm: None,
209        log_gradient_stats: false,
210    });
211
212    // Initialize accumulator
213    accumulator.initialize(&model)?;
214
215    // We would use a DataLoader in real code, but here we'll simulate it
216    println!("Creating data loader with batch_size=32, shuffle=true");
217
218    println!("\nTraining for 1 epoch with manual gradient accumulation...");
219
220    let mut total_loss = 0.0_f32;
221    let mut processed_batches = 0;
222
223    // Train for one epoch
224    // This is a simplified example - in practice you would iterate through DataLoader batches
225    // Simulated loop to demonstrate the concept:
226    let total_batches = 5;
227    for batch_idx in 0..total_batches {
228        // In a real implementation we would get inputs and targets from data_loader
229        println!("Batch {} - Accumulating gradients...", batch_idx + 1);
230
231        // Simulate a loss value
232        let loss = 0.1 * (batch_idx as f32 + 1.0).powf(-0.5);
233        total_loss += loss;
234        processed_batches += 1;
235
236        // Simulate gradient stats
237        println!(
238            "Batch {} - Gradient stats: min={:.4}, max={:.4}, mean={:.4}, norm={:.4}",
239            batch_idx + 1,
240            -0.05 * (batch_idx as f32 + 1.0).powf(-0.5),
241            0.05 * (batch_idx as f32 + 1.0).powf(-0.5),
242            0.01 * (batch_idx as f32 + 1.0).powf(-0.5),
243            0.2 * (batch_idx as f32 + 1.0).powf(-0.5)
244        );
245
246        // Update if needed - this is conceptual
247        if (batch_idx + 1) % 4 == 0 || batch_idx == total_batches - 1 {
248            println!(
249                "Applying accumulated gradients after {} batches",
250                if (batch_idx + 1) % 4 == 0 { 4 } else { (batch_idx + 1) % 4 }
251            );
252            // In a real implementation we would apply gradients:
253            // accumulator.apply_gradients(&mut model, &mut optimizer)?;
254        }
255
256        // Early stopping for example
257        if batch_idx >= 10 {
258            break;
259        }
260    }
261
262    if processed_batches > 0 {
263        println!("Average loss: {:.4}", total_loss / processed_batches as f32);
264    }
265
266    // 3. Mixed Precision (not fully implemented, pseudocode)
267    println!("\n3. Mixed Precision Training (Pseudocode):");
268
269    println!(
270        "// Create mixed precision config
271let mp_config = MixedPrecisionConfig {{
272    dynamic_loss_scaling: true,
273    initial_loss_scale: 65536.0,
274    scale_factor: 2.0,
275    scale_window: 2000,
276    min_loss_scale: 1.0,
277    max_loss_scale: 2_f64.powi(24),
278    verbose: true,
279}};
280
281// Create high precision and low precision models
282let high_precision_model = create_regression_model::<f32>(10, 64, 2)?;
283let low_precision_model = create_regression_model::<f16>(10, 64, 2)?;
284
285// Create mixed precision model
286let mut mixed_model = MixedPrecisionModel::new(
287    high_precision_model,
288    low_precision_model,
289    mp_config,
290)?;
291
292// Create optimizer and loss function
293let mut optimizer = Adam::new(0.001);
294let loss_fn = MSELoss::new();
295
296// Train for one epoch
297mixed_model.train_epoch(
298    &mut optimizer,
299    &dataset,
300    &loss_fn,
301    32,
302    true,
303)?;"
304    );
305
306    // 4. Gradient Clipping
307    println!("\n4. Gradient Clipping:");
308
309    // Create model, optimizer, and loss function
310    let model = create_regression_model::<f32>(10, 64, 2)?;
311    let optimizer = Adam::new(0.001_f32, 0.9_f32, 0.999_f32, 1e-8_f32);
312    let loss_fn = MSELoss::new();
313
314    // Create training config - we need two separate instances
315    let gradient_clipping_config = TrainingConfig {
316        batch_size: 32,
317        shuffle: true,
318        num_workers: 0,
319        learning_rate: 0.001,
320        epochs: 5,
321        verbose: 1,
322        validation: Some(ValidationSettings {
323            enabled: true,
324            validation_split: 0.0, // Use separate validation dataset
325            batch_size: 32,
326            num_workers: 0,
327        }),
328        gradient_accumulation: None,
329        mixed_precision: None,
330    };
331
332    // Create a separate configuration for the value clipping example
333    let value_clipping_config = TrainingConfig {
334        batch_size: 32,
335        shuffle: true,
336        num_workers: 0,
337        learning_rate: 0.001,
338        epochs: 5,
339        verbose: 1,
340        validation: Some(ValidationSettings {
341            enabled: true,
342            validation_split: 0.0, // Use separate validation dataset
343            batch_size: 32,
344            num_workers: 0,
345        }),
346        gradient_accumulation: None,
347        mixed_precision: None,
348    };
349
350    // Create trainer
351    let _trainer = Trainer::new(model, optimizer, loss_fn, gradient_clipping_config);
352
353    // Instead of adding callbacks directly, we'll just demonstrate the concept
354    println!("If callbacks were fully implemented, we would add gradient clipping:");
355    println!("GradientClipping::by_global_norm(1.0_f32, true) // Max norm, log_stats");
356
357    println!("\nTraining model with gradient clipping by global norm...");
358
359    // Train model for a few epochs
360    let _dataset_small = generate_regression_dataset::<f32>(500, 10, 2)?;
361    let _val_dataset_small = generate_regression_dataset::<f32>(100, 10, 2)?;
362    println!("Would train the model with dataset_small and val_dataset_small");
363    // In a real implementation:
364    // let session = trainer.train(&dataset_small, Some(&val_dataset_small))?;
365
366    // Since we're not actually training, just show example output
367    println!("\nExample of training output that would be shown:");
368    println!("Training completed in 3 epochs");
369    println!("Final loss: 0.0124");
370    println!("Final validation loss: 0.0156");
371
372    // Example with value clipping
373    println!("\nExample with gradient clipping by value:");
374
375    // Create model and trainer with value clipping
376    let model = create_regression_model::<f32>(10, 64, 2)?;
377    let optimizer = Adam::new(0.001_f32, 0.9_f32, 0.999_f32, 1e-8_f32);
378    let _trainer = Trainer::new(model, optimizer, loss_fn, value_clipping_config);
379
380    // Instead of actual callbacks, show how we would use them
381    println!("For gradient clipping by value, we would use:");
382    println!("GradientClipping::by_value(0.5_f32, true) // Max value, log_stats");
383
384    println!("\nDemonstration of how to set up gradient clipping by value:");
385    println!("trainer.add_callback(Box::new(GradientClipping::by_value(");
386    println!("    0.5_f32, // Max value");
387    println!("    true,    // Log stats");
388    println!(")));");
389
390    // Demonstrate the training utilities
391    println!("\nAdvanced Training Examples Completed Successfully!");
392
393    Ok(())
394}
Source

pub fn add_callback(
    &mut self,
    callback: Box<dyn Fn() -> Result<()> + Send + Sync>,
)

Add a callback to the trainer
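
A focused sketch: the closure must be Box<dyn Fn() -> Result<()> + Send + Sync>, matching the signature above (the repository example below uses it as an epoch-end logging hook):

// Registered callbacks are invoked by the trainer during training.
trainer.add_callback(Box::new(|| {
    println!("epoch finished");
    Ok(())
}));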

Examples found in repository
examples/image_classification_complete.rs (lines 288-292)
(Excerpt; the full train_image_classifier listing appears under new above.)
287    // Add callbacks
288    trainer.add_callback(Box::new(|| {
289        // Custom callback for additional logging
290        println!("🔄 Epoch completed");
291        Ok(())
292    }));
Source

pub fn train<D: Dataset<F> + Clone>(
    &mut self,
    dataset: &D,
    validation_dataset: Option<&D>,
) -> Result<TrainingSession<F>>

Train the model
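
Runs the configured training loop and returns the resulting TrainingSession. A minimal sketch (train_dataset and val_dataset stand in for your own Dataset values; pass None to skip validation between epochs):

let session = trainer.train(&train_dataset, Some(&val_dataset))?;
println!("Epochs trained: {}", session.epochs_trained);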

Examples found in repository
examples/text_classification_complete.rs (line 546)
(Excerpt; the full train_text_classifier listing appears under new above.)
546    let training_session = trainer.train(&train_dataset, Some(&val_dataset))?;
More examples
examples/image_classification_complete.rs (line 298)
(Excerpt; the full train_image_classifier listing appears under new above.)
298    let training_session = trainer.train(&train_dataset, Some(&val_dataset))?;
Source

pub fn validate<D: Dataset<F>>(
    &mut self,
    dataset: &D,
) -> Result<HashMap<String, F>>

Validate the model on a dataset
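
Returns a map from metric name to value. A minimal sketch, assuming val_dataset implements Dataset:

let metrics = trainer.validate(&val_dataset)?;
for (name, value) in &metrics {
    println!("{}: {:.4}", name, value);
}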

Examples found in repository
examples/text_classification_complete.rs (line 553)
(Excerpt; the full train_text_classifier listing appears under new above.)
553    let val_metrics = trainer.validate(&val_dataset)?;
More examples
examples/image_classification_complete.rs (line 309)
(Excerpt; the full train_image_classifier listing appears under new above.)
309    let val_metrics = trainer.validate(&val_dataset)?;
Source

pub fn get_model(&self) -> &dyn Layer<F>

Get the model
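
Borrows the trained model as &dyn Layer<F>, which is handy for inference after training. A short sketch (inputs is a placeholder batch array shaped like the model's input):

let model = trainer.get_model();
let predictions = model.forward(&inputs)?;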

Examples found in repository
examples/text_classification_complete.rs (line 583)
458fn train_text_classifier() -> StdResult<()> {
459    println!("๐Ÿ“ Starting Text Classification Training Example");
460    println!("{}", "=".repeat(60));
461
462    let mut rng = SmallRng::seed_from_u64(42);
463
464    // Dataset parameters
465    let num_samples = 800;
466    let num_classes = 3;
467    let max_length = 20;
468    let embedding_dim = 64;
469    let hidden_dim = 128;
470
471    println!("๐Ÿ“Š Dataset Configuration:");
472    println!("   - Samples: {}", num_samples);
473    println!(
474        "   - Classes: {} (Positive, Negative, Neutral)",
475        num_classes
476    );
477    println!("   - Max sequence length: {}", max_length);
478    println!("   - Embedding dimension: {}", embedding_dim);
479
480    // Create synthetic text dataset
481    println!("\n๐Ÿ”„ Creating synthetic text dataset...");
482    let dataset = TextDataset::create_synthetic_dataset(num_samples, num_classes, max_length);
483    let (train_dataset, val_dataset) = dataset.train_val_split(0.2);
484
485    println!("   - Vocabulary size: {}", dataset.vocab.vocab_size);
486    println!("   - Training samples: {}", train_dataset.len());
487    println!("   - Validation samples: {}", val_dataset.len());
488
489    // Show some example texts
490    println!("\n๐Ÿ“„ Sample texts:");
491    for i in 0..3.min(train_dataset.texts.len()) {
492        println!(
493            "   [Class {}]: {}",
494            train_dataset.labels[i], train_dataset.texts[i]
495        );
496    }
497
498    // Build model
499    println!("\n๐Ÿ—๏ธ  Building text classification model...");
500    let model = build_text_model(
501        dataset.vocab.vocab_size,
502        embedding_dim,
503        hidden_dim,
504        num_classes,
505        max_length,
506        &mut rng,
507    )?;
508
509    let total_params: usize = model.params().iter().map(|p| p.len()).sum();
510    println!("   - Model layers: {}", model.len());
511    println!("   - Total parameters: {}", total_params);
512
513    // Training configuration
514    let config = TrainingConfig {
515        batch_size: 16,
516        epochs: 30,
517        learning_rate: 0.001,
518        shuffle: true,
519        verbose: 1,
520        validation: Some(ValidationSettings {
521            enabled: true,
522            validation_split: 0.2,
523            batch_size: 32,
524            num_workers: 0,
525        }),
526        gradient_accumulation: None,
527        mixed_precision: None,
528        num_workers: 0,
529    };
530
531    println!("\n⚙️  Training Configuration:");
532    println!("   - Batch size: {}", config.batch_size);
533    println!("   - Learning rate: {}", config.learning_rate);
534    println!("   - Epochs: {}", config.epochs);
535
536    // Set up training
537    let loss_fn = CrossEntropyLoss::new(1e-7);
538    let optimizer = Adam::new(config.learning_rate as f32, 0.9, 0.999, 1e-8);
539
540    let mut trainer = Trainer::new(model, optimizer, loss_fn, config);
541
542    // Train the model
543    println!("\n🏋️  Starting training...");
544    println!("{}", "-".repeat(40));
545
546    let training_session = trainer.train(&train_dataset, Some(&val_dataset))?;
547
548    println!("\n✅ Training completed!");
549    println!("   - Epochs trained: {}", training_session.epochs_trained);
550
551    // Evaluate model
552    println!("\n📊 Final Evaluation:");
553    let val_metrics = trainer.validate(&val_dataset)?;
554
555    for (metric, value) in &val_metrics {
556        println!("   - {}: {:.4}", metric, value);
557    }
558
559    // Test on sample texts
560    println!("\n🔍 Sample Predictions:");
561    let sample_indices = vec![0, 1, 2, 3, 4];
562
563    // Manually collect batch since get_batch is not part of Dataset trait
564    let mut batch_tokens = Vec::new();
565    let mut batch_targets = Vec::new();
566
567    for &idx in &sample_indices {
568        let (tokens, targets) = val_dataset.get(idx)?;
569        batch_tokens.push(tokens);
570        batch_targets.push(targets);
571    }
572
573    // Concatenate into batch arrays
574    let sample_tokens = ndarray::concatenate(
575        ndarray::Axis(0),
576        &batch_tokens.iter().map(|a| a.view()).collect::<Vec<_>>(),
577    )?;
578    let sample_targets = ndarray::concatenate(
579        ndarray::Axis(0),
580        &batch_targets.iter().map(|a| a.view()).collect::<Vec<_>>(),
581    )?;
582
583    let model = trainer.get_model();
584    let predictions = model.forward(&sample_tokens)?;
585
586    let class_names = ["Positive", "Negative", "Neutral"];
587
588    for i in 0..sample_indices.len().min(val_dataset.texts.len()) {
589        let pred_row = predictions.slice(s![i, ..]);
590        let target_row = sample_targets.slice(s![i, ..]);
591
592        let pred_class = pred_row
593            .iter()
594            .enumerate()
595            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
596            .map(|(i, _)| i)
597            .unwrap_or(0);
598
599        let true_class = target_row
600            .iter()
601            .enumerate()
602            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
603            .map(|(i, _)| i)
604            .unwrap_or(0);
605
606        let confidence = pred_row[pred_class];
607
608        if sample_indices[i] < val_dataset.texts.len() {
609            println!("   Text: \"{}\"", val_dataset.texts[sample_indices[i]]);
610            println!(
611                "   Predicted: {} (confidence: {:.3})",
612                class_names[pred_class], confidence
613            );
614            println!("   Actual: {}", class_names[true_class]);
615            println!();
616        }
617    }
618
619    // Calculate detailed metrics
620    let detailed_metrics = calculate_text_metrics(&predictions, &sample_targets);
621    println!("📈 Detailed Metrics:");
622    for (metric, value) in &detailed_metrics {
623        println!("   - {}: {:.4}", metric, value);
624    }
625
626    Ok(())
627}
More examples
examples/image_classification_complete.rs (line 339)
233fn train_image_classifier() -> Result<()> {
234    println!("🚀 Starting Image Classification Training Example");
235    println!("{}", "=".repeat(60));
236
237    // Set up reproducible random number generator
238    let mut rng = SmallRng::seed_from_u64(42);
239
240    // Dataset parameters
241    let num_samples = 1000;
242    let num_classes = 5;
243    let image_size = (32, 32);
244    let input_channels = 3;
245
246    println!("📊 Dataset Configuration:");
247    println!("   - Samples: {}", num_samples);
248    println!("   - Classes: {}", num_classes);
249    println!("   - Image Size: {}x{}", image_size.0, image_size.1);
250    println!("   - Channels: {}", input_channels);
251
252    // Create synthetic dataset
253    println!("\n🔄 Creating synthetic dataset...");
254    let dataset = SyntheticImageDataset::new(num_samples, num_classes, image_size);
255    let (train_dataset, val_dataset) = dataset.train_val_split(0.2);
256
257    println!("   - Training samples: {}", train_dataset.len());
258    println!("   - Validation samples: {}", val_dataset.len());
259
260    // Build model
261    println!("\n🏗️  Building CNN model...");
262    let model = build_cnn_model(input_channels, num_classes, &mut rng)?;
263
264    // Count parameters
265    let total_params: usize = model.params().iter().map(|p| p.len()).sum();
266    println!("   - Model layers: {}", model.len());
267    println!("   - Total parameters: {}", total_params);
268
269    // Create training configuration
270    let config = create_training_config();
271    println!("\n⚙️  Training Configuration:");
272    println!("   - Batch size: {}", config.batch_size);
273    println!("   - Learning rate: {}", config.learning_rate);
274    println!("   - Epochs: {}", config.epochs);
275    println!(
276        "   - Validation split: {:.1}%",
277        config.validation.as_ref().unwrap().validation_split * 100.0
278    );
279
280    // Set up training components
281    let loss_fn = CrossEntropyLoss::new(1e-7);
282    let optimizer = Adam::new(config.learning_rate as f32, 0.9, 0.999, 1e-8);
283
284    // Create trainer
285    let mut trainer = Trainer::new(model, optimizer, loss_fn, config);
286
287    // Add callbacks
288    trainer.add_callback(Box::new(|| {
289        // Custom callback for additional logging
290        println!("🔄 Epoch completed");
291        Ok(())
292    }));
293
294    // Train the model
295    println!("\n🏋️  Starting training...");
296    println!("{}", "-".repeat(40));
297
298    let training_session = trainer.train(&train_dataset, Some(&val_dataset))?;
299
300    println!("\n✅ Training completed!");
301    println!("   - Epochs trained: {}", training_session.epochs_trained);
302    println!(
303        "   - Final learning rate: {:.6}",
304        training_session.initial_learning_rate
305    );
306
307    // Evaluate on validation set
308    println!("\n📊 Final Evaluation:");
309    let val_metrics = trainer.validate(&val_dataset)?;
310
311    for (metric, value) in &val_metrics {
312        println!("   - {}: {:.4}", metric, value);
313    }
314
315    // Test predictions on a few samples
316    println!("\n🔍 Sample Predictions:");
317    let sample_indices = vec![0, 1, 2, 3, 4];
318
319    // Manually collect batch since get_batch is not part of Dataset trait
320    let mut batch_images = Vec::new();
321    let mut batch_targets = Vec::new();
322
323    for &idx in &sample_indices {
324        let (img, target) = val_dataset.get(idx)?;
325        batch_images.push(img);
326        batch_targets.push(target);
327    }
328
329    // Concatenate into batch arrays
330    let sample_images = ndarray::concatenate(
331        Axis(0),
332        &batch_images.iter().map(|a| a.view()).collect::<Vec<_>>(),
333    )?;
334    let sample_targets = ndarray::concatenate(
335        Axis(0),
336        &batch_targets.iter().map(|a| a.view()).collect::<Vec<_>>(),
337    )?;
338
339    let model = trainer.get_model();
340    let predictions = model.forward(&sample_images)?;
341
342    for i in 0..sample_indices.len() {
343        let pred_row = predictions.slice(s![i, ..]);
344        let target_row = sample_targets.slice(s![i, ..]);
345
346        let pred_class = pred_row
347            .iter()
348            .enumerate()
349            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
350            .map(|(i, _)| i)
351            .unwrap_or(0);
352
353        let target_class = target_row
354            .iter()
355            .enumerate()
356            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
357            .map(|(i, _)| i)
358            .unwrap_or(0);
359
360        let confidence = pred_row[pred_class];
361
362        println!(
363            "   Sample {}: Predicted={}, Actual={}, Confidence={:.3}",
364            i + 1,
365            pred_class,
366            target_class,
367            confidence
368        );
369    }
370
371    // Calculate overall accuracy
372    let overall_predictions = trainer.get_model().forward(&sample_images)?;
373    let accuracy = calculate_accuracy(&overall_predictions, &sample_targets);
374    println!("\n🎯 Sample Accuracy: {:.2}%", accuracy * 100.0);
375
376    // Model summary
377    println!("\n📋 Training Summary:");
378    let session = trainer.get_session();
379    if let Some(loss_history) = session.get_metric("loss") {
380        if !loss_history.is_empty() {
381            println!("   - Initial loss: {:.4}", loss_history[0]);
382            println!(
383                "   - Final loss: {:.4}",
384                loss_history[loss_history.len() - 1]
385            );
386        }
387    }
388
389    if let Some(val_loss_history) = session.get_metric("val_loss") {
390        if !val_loss_history.is_empty() {
391            println!(
392                "   - Final validation loss: {:.4}",
393                val_loss_history[val_loss_history.len() - 1]
394            );
395        }
396    }
397
398    println!("\n🎉 Image classification example completed successfully!");
399
400    Ok(())
401}
Source

pub fn get_model_mut(&mut self) -> &mut dyn Layer<F>

Get the model (mutable)
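No repository examples exercise this accessor. A hedged sketch of the intended pattern; `set_training` is a hypothetical method used purely for illustration and is not confirmed by this crate's Layer trait:

// Borrow the model mutably to adjust its state between phases.
let model = trainer.get_model_mut();
model.set_training(false); // hypothetical: switch layers to inference mode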

Source

pub fn get_optimizer(&self) -> &dyn Optimizer<F>

Get the optimizer
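A hedged sketch of inspecting optimizer state through the shared borrow; `learning_rate` is a hypothetical getter shown only for illustration, not a confirmed method of this crate's Optimizer trait:

// Read-only access to the optimizer driving training.
let optimizer = trainer.get_optimizer();
let lr = optimizer.learning_rate(); // hypothetical getter
println!("current learning rate: {lr}");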

Source

pub fn get_optimizer_mut(&mut self) -> &mut dyn Optimizer<F>

Get the optimizer (mutable)
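A matching hedged sketch for the mutable accessor; `set_learning_rate` is likewise hypothetical, shown only to illustrate adjusting optimizer state mid-training (e.g. manual learning-rate decay):

// Exclusive access allows tweaking optimizer hyperparameters between epochs.
trainer.get_optimizer_mut().set_learning_rate(0.0005); // hypothetical setter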

Source

pub fn get_loss_fn(&self) -> &dyn Loss<F>

Get the loss function
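A minimal sketch of scoring a held-out batch with the trainer's configured loss, assuming the Loss trait exposes a forward(predictions, targets) method (an assumption made for illustration; consult the Loss trait for the actual signature):

// Re-use the same loss function the trainer optimizes against.
let predictions = trainer.get_model().forward(&inputs)?;
let loss = trainer.get_loss_fn().forward(&predictions, &targets)?; // assumed signature
println!("held-out loss: {loss}");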

Source

pub fn get_session(&self) -> &TrainingSession<F>

Get the current training session
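A compact sketch of reading recorded history from the session; get_metric and the indexing pattern are taken directly from the repository example below:

let session = trainer.get_session();
if let Some(history) = session.get_metric("loss") {
    if !history.is_empty() {
        println!("final loss: {:.4}", history[history.len() - 1]);
    }
}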

Examples found in repository?
examples/image_classification_complete.rs (line 378)
(Full listing of train_image_classifier shown above under get_model; line 378 reads the training session to print the summary.)

Auto Trait Implementations

impl<F> Freeze for Trainer<F>
where F: Freeze,

impl<F> !RefUnwindSafe for Trainer<F>

impl<F> Send for Trainer<F>

impl<F> Sync for Trainer<F>

impl<F> Unpin for Trainer<F>
where F: Unpin,

impl<F> !UnwindSafe for Trainer<F>
Blanket Implementations

Source

impl<T> Any for T
where T: 'static + ?Sized,

Source

fn type_id(&self) -> TypeId

Gets the TypeId of self.

Source

impl<T> Borrow<T> for T
where T: ?Sized,

Source

fn borrow(&self) -> &T

Immutably borrows from an owned value.

Source

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

Source

impl<T> From<T> for T

Source

fn from(t: T) -> T

Returns the argument unchanged.

Source

impl<T, U> Into<U> for T
where U: From<T>,

Source

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source

impl<T> IntoEither for T

Source

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.

Source

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.

Source

impl<T> Pointable for T

Source

const ALIGN: usize

The alignment of the pointer.

Source

type Init = T

The type for initializers.

Source

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes a value with the given initializer.

Source

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer.

Source

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer.

Source

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer.

Source

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source

type Error = Infallible

The type returned in the event of a conversion error.

Source

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

Source

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

Source

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.

Source

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

Source

fn vzip(self) -> V