Struct CrossEntropyLoss

pub struct CrossEntropyLoss { /* private fields */ }

Cross-entropy loss function.

The cross-entropy loss is defined as L = -sum(y_true * log(y_pred)), where y_true are the target probabilities (typically one-hot encoded labels) and y_pred are the predicted probabilities. For this definition, the gradient with respect to the predictions is dL/dy_pred = -y_true / y_pred, up to the epsilon guard and any batch-averaging convention.

It is commonly used for classification problems.

§Examples

use scirs2_neural::losses::{CrossEntropyLoss, Loss};
use ndarray::arr2;

let ce = CrossEntropyLoss::new(1e-10);

// One-hot encoded targets and softmax'd predictions for a 3-class problem
let predictions = arr2(&[
    [0.7, 0.2, 0.1],  // First sample, class probabilities
    [0.3, 0.6, 0.1]   // Second sample, class probabilities
]).into_dyn();

let targets = arr2(&[
    [1.0, 0.0, 0.0],  // First sample, true class is 0
    [0.0, 1.0, 0.0]   // Second sample, true class is 1
]).into_dyn();

// Forward pass to calculate loss
let loss = ce.forward(&predictions, &targets).unwrap();

// Backward pass to calculate gradients
let gradients = ce.backward(&predictions, &targets).unwrap();
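
For intuition, the loss above can be checked directly against the definition. A minimal hand computation (this sketch assumes the loss is averaged over the batch; if the implementation sums instead, drop the division by 2):

// Hand-computed cross-entropy for the two samples above:
// sample 1 contributes -ln(0.7), sample 2 contributes -ln(0.6).
let expected = (-(0.7f64).ln() - (0.6f64).ln()) / 2.0;
// expected ≈ (0.3567 + 0.5108) / 2 ≈ 0.4338
assert!((expected - 0.4338).abs() < 1e-3);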

Implementations§

impl CrossEntropyLoss

pub fn new(epsilon: f64) -> Self

Creates a new cross-entropy loss function.

§Arguments
  • epsilon - Small value added to the predictions to avoid taking log(0)
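
The epsilon guard matters because a predicted probability of exactly zero for the true class would otherwise make the loss infinite. A minimal sketch of the idea (the arithmetic below illustrates why the argument exists; it is not the crate's actual internal code):

// -ln(0) is infinite; -ln(0 + epsilon) is large but finite.
let epsilon = 1e-10f64;
let p = 0.0f64; // predicted probability assigned to the true class
assert!((-p.ln()).is_infinite());
assert!((-(p + epsilon).ln()).is_finite()); // ≈ 23.03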
Examples found in repository
examples/object_detection_complete.rs (line 324)
322    pub fn new(classification_weight: f32, regression_weight: f32) -> Self {
323        Self {
324            classification_loss: CrossEntropyLoss::new(1e-7),
325            regression_loss: MeanSquaredError,
326            classification_weight,
327            regression_weight,
328        }
329    }
examples/improved_model_serialization.rs (line 75)
66fn train_model(
67    model: &mut Sequential<f32>,
68    x: &Array2<f32>,
69    y: &Array2<f32>,
70    epochs: usize,
71) -> Result<()> {
72    println!("Training XOR model...");
73
74    // Setup loss function and optimizer
75    let loss_fn = CrossEntropyLoss::new(1e-10);
76    let mut optimizer = Adam::new(0.01, 0.9, 0.999, 1e-8);
77
78    // Train for specified number of epochs
79    for epoch in 0..epochs {
80        // Convert to dynamic dimension arrays
81        let x_dyn = x.clone().into_dyn();
82        let y_dyn = y.clone().into_dyn();
83
84        // Perform a training step
85        let loss = model.train_batch(&x_dyn, &y_dyn, &loss_fn, &mut optimizer)?;
86
87        // Print progress every 100 epochs
88        if epoch % 100 == 0 || epoch == epochs - 1 {
89            println!("Epoch {}/{}: loss = {:.6}", epoch + 1, epochs, loss);
90        }
91    }
92
93    println!("Training completed.");
94    Ok(())
95}
examples/generative_models_complete.rs (line 786)
770fn train_gan_model() -> StdResult<()> {
771    println!("⚔️ Starting GAN Training");
772
773    let mut rng = SmallRng::seed_from_u64(42);
774    let config = GenerativeConfig::default();
775
776    // Create models
777    println!("🏗️ Building GAN models...");
778    let generator = GANGenerator::new(config.clone(), &mut rng)?;
779    let discriminator = GANDiscriminator::new(config.clone(), &mut rng)?;
780    println!("✅ GAN models created");
781
782    // Create dataset
783    let mut dataset = GenerativeDataset::new(config.clone(), 456);
784
785    // Create loss functions
786    let _adversarial_loss = CrossEntropyLoss::new(1e-7);
787
788    println!("📊 GAN training configuration:");
789    println!("   - Generator latent dim: {}", config.latent_dim);
790    println!("   - Discriminator architecture: {:?}", config.hidden_dims);
791
792    // Training loop (simplified)
793    let num_epochs = 15;
794    let batch_size = 4;
795
796    for epoch in 0..num_epochs {
797        println!("\n📈 Epoch {}/{}", epoch + 1, num_epochs);
798
799        let mut d_loss_total = 0.0;
800        let mut g_loss_total = 0.0;
801        let num_batches = 8;
802
803        for batch_idx in 0..num_batches {
804            // Train Discriminator
805            let real_images = dataset.generate_batch(batch_size);
806            let real_images_dyn = real_images.into_dyn();
807
808            // Generate fake images
809            let mut noise = Array2::<f32>::zeros((batch_size, config.latent_dim));
810            for elem in noise.iter_mut() {
811                *elem = rng.random_range(-1.0..1.0);
812            }
813            let noise_dyn = noise.into_dyn();
814            let fake_images = generator.forward(&noise_dyn)?;
815
816            // Discriminator predictions
817            let real_pred = discriminator.forward(&real_images_dyn)?;
818            let fake_pred = discriminator.forward(&fake_images)?;
819
820            // Simplified loss calculation (normally would use proper labels)
821            let mut d_loss_real = 0.0f32;
822            let mut d_loss_fake = 0.0f32;
823
824            for &pred in real_pred.iter() {
825                d_loss_real += -pred.ln(); // Log loss for real=1: -ln(pred)
826            }
827
828            for &pred in fake_pred.iter() {
829                d_loss_fake += -(1.0 - pred).ln(); // Log loss for fake=0
830            }
831
832            let d_loss = (d_loss_real + d_loss_fake) / (batch_size * 2) as f32;
833            d_loss_total += d_loss;
834
835            // Train Generator (simplified)
836            let fake_pred_for_g = discriminator.forward(&fake_images)?;
837            let mut g_loss = 0.0f32;
838
839            for &pred in fake_pred_for_g.iter() {
840                g_loss += -pred.ln(); // Want discriminator to output 1 for fakes: -ln(pred)
841            }
842            g_loss /= batch_size as f32;
843            g_loss_total += g_loss;
844
845            if batch_idx % 4 == 0 {
846                print!(
847                    "🔄 Batch {}/{} - D Loss: {:.4}, G Loss: {:.4}        \r",
848                    batch_idx + 1,
849                    num_batches,
850                    d_loss,
851                    g_loss
852                );
853            }
854        }
855
856        let avg_d_loss = d_loss_total / num_batches as f32;
857        let avg_g_loss = g_loss_total / num_batches as f32;
858
859        println!(
860            "✅ Epoch {} - D Loss: {:.4}, G Loss: {:.4}",
861            epoch + 1,
862            avg_d_loss,
863            avg_g_loss
864        );
865
866        // Generate samples every few epochs
867        if (epoch + 1) % 5 == 0 {
868            println!("🎲 Generating samples...");
869
870            let mut sample_noise = Array2::<f32>::zeros((4, config.latent_dim));
871            for elem in sample_noise.iter_mut() {
872                *elem = rng.random_range(-1.0..1.0);
873            }
874            let sample_noise_dyn = sample_noise.into_dyn();
875            let generated = generator.forward(&sample_noise_dyn)?;
876
877            println!("📊 Generated {} samples", generated.shape()[0]);
878        }
879    }
880
881    println!("\n🎉 GAN training completed!");
882    Ok(())
883}
examples/loss_functions_example.rs (line 29)
6fn main() -> Result<(), Box<dyn std::error::Error>> {
7    println!("Loss functions example");
8
9    // Mean Squared Error example
10    println!("\n--- Mean Squared Error Example ---");
11    let mse = MeanSquaredError::new();
12
13    // Create sample data for regression
14    let predictions = Array::from_vec(vec![1.0, 2.0, 3.0]).into_dyn();
15    let targets = Array::from_vec(vec![1.5, 1.8, 2.5]).into_dyn();
16
17    // Calculate loss
18    let loss = mse.forward(&predictions, &targets)?;
19    println!("Predictions: {:?}", predictions);
20    println!("Targets: {:?}", targets);
21    println!("MSE Loss: {:.4}", loss);
22
23    // Calculate gradients
24    let gradients = mse.backward(&predictions, &targets)?;
25    println!("MSE Gradients: {:?}", gradients);
26
27    // Cross-Entropy Loss example
28    println!("\n--- Cross-Entropy Loss Example ---");
29    let ce = CrossEntropyLoss::new(1e-10);
30
31    // Create sample data for multi-class classification
32    let predictions = Array::from_shape_vec(IxDyn(&[2, 3]), vec![0.7, 0.2, 0.1, 0.3, 0.6, 0.1])?;
33    let targets = Array::from_shape_vec(IxDyn(&[2, 3]), vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0])?;
34
35    // Calculate loss
36    let loss = ce.forward(&predictions, &targets)?;
37    println!("Predictions (probabilities):");
38    println!("{:?}", predictions);
39    println!("Targets (one-hot):");
40    println!("{:?}", targets);
41    println!("Cross-Entropy Loss: {:.4}", loss);
42
43    // Calculate gradients
44    let gradients = ce.backward(&predictions, &targets)?;
45    println!("Cross-Entropy Gradients:");
46    println!("{:?}", gradients);
47
48    // Focal Loss example
49    println!("\n--- Focal Loss Example ---");
50    let focal = FocalLoss::new(2.0, Some(0.25), 1e-10);
51
52    // Create sample data for imbalanced classification
53    let predictions = Array::from_shape_vec(IxDyn(&[2, 3]), vec![0.7, 0.2, 0.1, 0.3, 0.6, 0.1])?;
54    let targets = Array::from_shape_vec(IxDyn(&[2, 3]), vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0])?;
55
56    // Calculate loss
57    let loss = focal.forward(&predictions, &targets)?;
58    println!("Predictions (probabilities):");
59    println!("{:?}", predictions);
60    println!("Targets (one-hot):");
61    println!("{:?}", targets);
62    println!("Focal Loss (gamma=2.0, alpha=0.25): {:.4}", loss);
63
64    // Calculate gradients
65    let gradients = focal.backward(&predictions, &targets)?;
66    println!("Focal Loss Gradients:");
67    println!("{:?}", gradients);
68
69    // Contrastive Loss example
70    println!("\n--- Contrastive Loss Example ---");
71    let contrastive = ContrastiveLoss::new(1.0);
72
73    // Create sample data for similarity learning
74    // Embedding pairs (batch_size x 2 x embedding_dim)
75    let embeddings = Array::from_shape_vec(
76        IxDyn(&[2, 2, 3]),
77        vec![
78            0.1, 0.2, 0.3, // First pair, first embedding
79            0.1, 0.3, 0.3, // First pair, second embedding (similar)
80            0.5, 0.5, 0.5, // Second pair, first embedding
81            0.9, 0.8, 0.7, // Second pair, second embedding (dissimilar)
82        ],
83    )?;
84
85    // Labels: 1 for similar pairs, 0 for dissimilar
86    let labels = Array::from_shape_vec(IxDyn(&[2, 1]), vec![1.0, 0.0])?;
87
88    // Calculate loss
89    let loss = contrastive.forward(&embeddings, &labels)?;
90    println!("Embeddings (batch_size x 2 x embedding_dim):");
91    println!("{:?}", embeddings);
92    println!("Labels (1 for similar, 0 for dissimilar):");
93    println!("{:?}", labels);
94    println!("Contrastive Loss (margin=1.0): {:.4}", loss);
95
96    // Calculate gradients
97    let gradients = contrastive.backward(&embeddings, &labels)?;
98    println!("Contrastive Loss Gradients (first few):");
99    println!("{:?}", gradients.slice(ndarray::s![0, .., 0]));
100
101    // Triplet Loss example
102    println!("\n--- Triplet Loss Example ---");
103    let triplet = TripletLoss::new(0.5);
104
105    // Create sample data for triplet learning
106    // Embedding triplets (batch_size x 3 x embedding_dim)
107    let embeddings = Array::from_shape_vec(
108        IxDyn(&[2, 3, 3]),
109        vec![
110            0.1, 0.2, 0.3, // First triplet, anchor
111            0.1, 0.3, 0.3, // First triplet, positive
112            0.5, 0.5, 0.5, // First triplet, negative
113            0.6, 0.6, 0.6, // Second triplet, anchor
114            0.5, 0.6, 0.6, // Second triplet, positive
115            0.1, 0.1, 0.1, // Second triplet, negative
116        ],
117    )?;
118
119    // Dummy labels (not used by triplet loss)
120    let dummy_labels = Array::zeros(IxDyn(&[2, 1]));
121
122    // Calculate loss
123    let loss = triplet.forward(&embeddings, &dummy_labels)?;
124    println!("Embeddings (batch_size x 3 x embedding_dim):");
125    println!("  - First dimension: batch size");
126    println!("  - Second dimension: [anchor, positive, negative]");
127    println!("  - Third dimension: embedding components");
128    println!("{:?}", embeddings);
129    println!("Triplet Loss (margin=0.5): {:.4}", loss);
130
131    // Calculate gradients
132    let gradients = triplet.backward(&embeddings, &dummy_labels)?;
133    println!("Triplet Loss Gradients (first few):");
134    println!("{:?}", gradients.slice(ndarray::s![0, .., 0]));
135
136    Ok(())
137}
examples/semantic_segmentation_complete.rs (line 634)
617fn train_segmentation_model() -> StdResult<()> {
618    println!("🎨 Starting Semantic Segmentation Training");
619
620    let mut rng = SmallRng::seed_from_u64(42);
621    let config = SegmentationConfig::default();
622
623    println!("🚀 Starting model training...");
624
625    // Create model
626    println!("🏗️ Building U-Net segmentation model...");
627    let model = UNetModel::new(config.clone(), &mut rng)?;
628    println!("✅ Model created with {} classes", config.num_classes);
629
630    // Create dataset
631    let mut dataset = SegmentationDataset::new(config.clone(), 123);
632
633    // Create loss function
634    let loss_fn = CrossEntropyLoss::new(1e-7);
635
636    // Create metrics
637    let metrics = SegmentationMetrics::new(config.num_classes);
638
639    println!("📊 Training configuration:");
640    println!("   - Input size: {:?}", config.input_size);
641    println!("   - Number of classes: {}", config.num_classes);
642    println!("   - Encoder channels: {:?}", config.encoder_channels);
643    println!("   - Decoder channels: {:?}", config.decoder_channels);
644    println!("   - Skip connections: {}", config.skip_connections);
645
646    // Training loop
647    let num_epochs = 15;
648    let batch_size = 2; // Small batch size due to memory constraints
649    let _learning_rate = 0.001;
650
651    for epoch in 0..num_epochs {
652        println!("\n📈 Epoch {}/{}", epoch + 1, num_epochs);
653
654        let mut epoch_loss = 0.0;
655        let num_batches = 10; // Small number of batches for demo
656
657        for batch_idx in 0..num_batches {
658            // Generate training batch
659            let (images, masks) = dataset.generate_batch(batch_size);
660            let images_dyn = images.into_dyn();
661
662            // Forward pass
663            let logits = model.forward(&images_dyn)?;
664
665            // Prepare targets
666            let targets = masks_to_targets(&masks, config.num_classes);
667
668            // Compute loss
669            let batch_loss = loss_fn.forward(&logits, &targets)?;
670            epoch_loss += batch_loss;
671
672            if batch_idx % 5 == 0 {
673                print!(
674                    "🔄 Batch {}/{} - Loss: {:.4}                \r",
675                    batch_idx + 1,
676                    num_batches,
677                    batch_loss
678                );
679            }
680        }
681
682        let avg_loss = epoch_loss / num_batches as f32;
683        println!(
684            "✅ Epoch {} completed - Average Loss: {:.4}",
685            epoch + 1,
686            avg_loss
687        );
688
689        // Evaluation every few epochs
690        if (epoch + 1) % 5 == 0 {
691            println!("🔍 Running evaluation...");
692
693            // Generate validation batch
694            let (val_images, val_masks) = dataset.generate_batch(batch_size);
695            let val_images_dyn = val_images.into_dyn();
696
697            // Get predictions
698            let val_logits = model.forward(&val_images_dyn)?;
699            let predictions = logits_to_predictions(&val_logits);
700
701            // Calculate metrics
702            let pixel_acc = metrics.pixel_accuracy(&predictions, &val_masks);
703            let miou = metrics.mean_iou(&predictions, &val_masks);
704
705            println!("📊 Validation metrics:");
706            println!("   - Pixel Accuracy: {:.4}", pixel_acc);
707            println!("   - Mean IoU: {:.4}", miou);
708
709            // Print class-wise IoU
710            println!("   - Class-wise IoU:");
711            for class_id in 0..config.num_classes {
712                let class_iou = metrics.class_iou(&predictions, &val_masks, class_id);
713                if !class_iou.is_nan() {
714                    println!("     Class {}: {:.4}", class_id, class_iou);
715                }
716            }
717        }
718    }
719
720    println!("\n🎉 Semantic segmentation training completed!");
721
722    // Final evaluation
723    println!("🔬 Final evaluation...");
724    let (test_images, test_masks) = dataset.generate_batch(4);
725    let test_images_dyn = test_images.into_dyn();
726
727    let test_logits = model.forward(&test_images_dyn)?;
728    let final_predictions = logits_to_predictions(&test_logits);
729
730    let final_pixel_acc = metrics.pixel_accuracy(&final_predictions, &test_masks);
731    let final_miou = metrics.mean_iou(&final_predictions, &test_masks);
732
733    println!("📈 Final metrics:");
734    println!("   - Pixel Accuracy: {:.4}", final_pixel_acc);
735    println!("   - Mean IoU: {:.4}", final_miou);
736
737    // Confusion matrix
738    let confusion = metrics.confusion_matrix(&final_predictions, &test_masks);
739    println!("   - Confusion Matrix:");
740    for i in 0..config.num_classes {
741        print!("     [");
742        for j in 0..config.num_classes {
743            print!("{:4}", confusion[[i, j]]);
744        }
745        println!("]");
746    }
747
748    // Performance analysis
749    println!("\n📊 Model Analysis:");
750    println!("   - Architecture: U-Net with skip connections");
751    println!(
752        "   - Parameters: ~{:.1}K (estimated)",
753        (config.encoder_channels.iter().sum::<usize>()
754            + config.decoder_channels.iter().sum::<usize>())
755            / 1000
756    );
757    println!("   - Memory efficient: ✅ (skip connections preserve spatial info)");
758    println!("   - JIT optimized: ✅");
759
760    Ok(())
761}
examples/text_classification_complete.rs (line 537)
458fn train_text_classifier() -> StdResult<()> {
459    println!("📝 Starting Text Classification Training Example");
460    println!("{}", "=".repeat(60));
461
462    let mut rng = SmallRng::seed_from_u64(42);
463
464    // Dataset parameters
465    let num_samples = 800;
466    let num_classes = 3;
467    let max_length = 20;
468    let embedding_dim = 64;
469    let hidden_dim = 128;
470
471    println!("📊 Dataset Configuration:");
472    println!("   - Samples: {}", num_samples);
473    println!(
474        "   - Classes: {} (Positive, Negative, Neutral)",
475        num_classes
476    );
477    println!("   - Max sequence length: {}", max_length);
478    println!("   - Embedding dimension: {}", embedding_dim);
479
480    // Create synthetic text dataset
481    println!("\n🔄 Creating synthetic text dataset...");
482    let dataset = TextDataset::create_synthetic_dataset(num_samples, num_classes, max_length);
483    let (train_dataset, val_dataset) = dataset.train_val_split(0.2);
484
485    println!("   - Vocabulary size: {}", dataset.vocab.vocab_size);
486    println!("   - Training samples: {}", train_dataset.len());
487    println!("   - Validation samples: {}", val_dataset.len());
488
489    // Show some example texts
490    println!("\n📄 Sample texts:");
491    for i in 0..3.min(train_dataset.texts.len()) {
492        println!(
493            "   [Class {}]: {}",
494            train_dataset.labels[i], train_dataset.texts[i]
495        );
496    }
497
498    // Build model
499    println!("\n🏗️  Building text classification model...");
500    let model = build_text_model(
501        dataset.vocab.vocab_size,
502        embedding_dim,
503        hidden_dim,
504        num_classes,
505        max_length,
506        &mut rng,
507    )?;
508
509    let total_params: usize = model.params().iter().map(|p| p.len()).sum();
510    println!("   - Model layers: {}", model.len());
511    println!("   - Total parameters: {}", total_params);
512
513    // Training configuration
514    let config = TrainingConfig {
515        batch_size: 16,
516        epochs: 30,
517        learning_rate: 0.001,
518        shuffle: true,
519        verbose: 1,
520        validation: Some(ValidationSettings {
521            enabled: true,
522            validation_split: 0.2,
523            batch_size: 32,
524            num_workers: 0,
525        }),
526        gradient_accumulation: None,
527        mixed_precision: None,
528        num_workers: 0,
529    };
530
531    println!("\n⚙️  Training Configuration:");
532    println!("   - Batch size: {}", config.batch_size);
533    println!("   - Learning rate: {}", config.learning_rate);
534    println!("   - Epochs: {}", config.epochs);
535
536    // Set up training
537    let loss_fn = CrossEntropyLoss::new(1e-7);
538    let optimizer = Adam::new(config.learning_rate as f32, 0.9, 0.999, 1e-8);
539
540    let mut trainer = Trainer::new(model, optimizer, loss_fn, config);
541
542    // Train the model
543    println!("\n🏋️  Starting training...");
544    println!("{}", "-".repeat(40));
545
546    let training_session = trainer.train(&train_dataset, Some(&val_dataset))?;
547
548    println!("\n✅ Training completed!");
549    println!("   - Epochs trained: {}", training_session.epochs_trained);
550
551    // Evaluate model
552    println!("\n📊 Final Evaluation:");
553    let val_metrics = trainer.validate(&val_dataset)?;
554
555    for (metric, value) in &val_metrics {
556        println!("   - {}: {:.4}", metric, value);
557    }
558
559    // Test on sample texts
560    println!("\n🔍 Sample Predictions:");
561    let sample_indices = vec![0, 1, 2, 3, 4];
562
563    // Manually collect batch since get_batch is not part of Dataset trait
564    let mut batch_tokens = Vec::new();
565    let mut batch_targets = Vec::new();
566
567    for &idx in &sample_indices {
568        let (tokens, targets) = val_dataset.get(idx)?;
569        batch_tokens.push(tokens);
570        batch_targets.push(targets);
571    }
572
573    // Concatenate into batch arrays
574    let sample_tokens = ndarray::concatenate(
575        ndarray::Axis(0),
576        &batch_tokens.iter().map(|a| a.view()).collect::<Vec<_>>(),
577    )?;
578    let sample_targets = ndarray::concatenate(
579        ndarray::Axis(0),
580        &batch_targets.iter().map(|a| a.view()).collect::<Vec<_>>(),
581    )?;
582
583    let model = trainer.get_model();
584    let predictions = model.forward(&sample_tokens)?;
585
586    let class_names = ["Positive", "Negative", "Neutral"];
587
588    for i in 0..sample_indices.len().min(val_dataset.texts.len()) {
589        let pred_row = predictions.slice(s![i, ..]);
590        let target_row = sample_targets.slice(s![i, ..]);
591
592        let pred_class = pred_row
593            .iter()
594            .enumerate()
595            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
596            .map(|(i, _)| i)
597            .unwrap_or(0);
598
599        let true_class = target_row
600            .iter()
601            .enumerate()
602            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
603            .map(|(i, _)| i)
604            .unwrap_or(0);
605
606        let confidence = pred_row[pred_class];
607
608        if sample_indices[i] < val_dataset.texts.len() {
609            println!("   Text: \"{}\"", val_dataset.texts[sample_indices[i]]);
610            println!(
611                "   Predicted: {} (confidence: {:.3})",
612                class_names[pred_class], confidence
613            );
614            println!("   Actual: {}", class_names[true_class]);
615            println!();
616        }
617    }
618
619    // Calculate detailed metrics
620    let detailed_metrics = calculate_text_metrics(&predictions, &sample_targets);
621    println!("📈 Detailed Metrics:");
622    for (metric, value) in &detailed_metrics {
623        println!("   - {}: {:.4}", metric, value);
624    }
625
626    Ok(())
627}

Trait Implementations§

impl Clone for CrossEntropyLoss

fn clone(&self) -> CrossEntropyLoss

Returns a duplicate of the value.
const fn clone_from(&mut self, source: &Self)

Performs copy-assignment from source.
impl Debug for CrossEntropyLoss

fn fmt(&self, f: &mut Formatter<'_>) -> Result

Formats the value using the given formatter.
impl Default for CrossEntropyLoss

fn default() -> Self

Returns the “default value” for a type.
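
Default can be used when a hand-picked epsilon is not needed; a minimal sketch (the exact default epsilon is an implementation detail not stated here, so prefer new when the value matters):

use scirs2_neural::losses::CrossEntropyLoss;

// Constructed with the crate's default epsilon.
let ce = CrossEntropyLoss::default();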
impl<F: Float + Debug> Loss<F> for CrossEntropyLoss

fn forward(&self, predictions: &Array<F, IxDyn>, targets: &Array<F, IxDyn>) -> Result<F>

Calculates the loss between the predictions and the targets.
fn backward(&self, predictions: &Array<F, IxDyn>, targets: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>>

Calculates the gradient of the loss with respect to the predictions.
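
Because CrossEntropyLoss implements Loss<F> for any F: Float + Debug, it can sit behind a generic bound. A minimal sketch of a generic helper built only on the two methods above (eval_loss is a hypothetical name, and the path of the crate's Result alias is assumed):

use std::fmt::Debug;
use ndarray::{Array, IxDyn};
use num_traits::Float;
use scirs2_neural::error::Result; // assumed path for the crate's Result alias
use scirs2_neural::losses::Loss;

// Hypothetical helper: compute the loss value and its gradient in one call.
fn eval_loss<F: Float + Debug, L: Loss<F>>(
    loss: &L,
    predictions: &Array<F, IxDyn>,
    targets: &Array<F, IxDyn>,
) -> Result<(F, Array<F, IxDyn>)> {
    let value = loss.forward(predictions, targets)?;
    let grad = loss.backward(predictions, targets)?;
    Ok((value, grad))
}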
impl Copy for CrossEntropyLoss
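
Because CrossEntropyLoss is Copy, assigning or passing it by value duplicates it rather than moving it, so one configured instance can be handed to several components. A minimal sketch:

use scirs2_neural::losses::CrossEntropyLoss;

let ce = CrossEntropyLoss::new(1e-10);
let ce2 = ce;      // a copy, not a move; `ce` remains usable
let _ = (ce, ce2);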

Auto Trait Implementations§

Blanket Implementations§

impl<T> Any for T where T: 'static + ?Sized
impl<T> Borrow<T> for T where T: ?Sized
impl<T> BorrowMut<T> for T where T: ?Sized
impl<T> CloneToUninit for T where T: Clone
impl<T> From<T> for T
impl<T, U> Into<U> for T where U: From<T>
impl<T> IntoEither for T
impl<T> Pointable for T
impl<T> ToOwned for T where T: Clone
impl<T, U> TryFrom<U> for T where U: Into<T>
impl<T, U> TryInto<U> for T where U: TryFrom<T>
impl<V, T> VZip<V> for T where V: MultiLane<T>