Struct syntaxdot_transformers::loss::CrossEntropyLoss
pub struct CrossEntropyLoss { /* private fields */ }
Cross-entropy loss function.
Implementations
impl CrossEntropyLoss
pub fn new(
ignore_index: i64,
label_smoothing: Option<f64>,
reduction: Reduction
) -> Self
Construct the cross-entropy loss function.
Targets whose value is ignore_index are excluded from the loss computation. If label_smoothing is set to p, the correct label gets probability 1 - p and the remaining probability p is distributed across the incorrect labels. reduction specifies how the per-example losses are reduced (summarized).
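A minimal sketch of the smoothed target distribution described above, written in plain Rust. The description does not say how p is split across the incorrect labels, so this sketch assumes a uniform split over the n_classes - 1 incorrect labels. smoothed_distribution is a hypothetical helper for illustration, not part of syntaxdot_transformers.

```rust
/// Hypothetical helper: the target distribution for one example under
/// label smoothing. Assumes the smoothing mass `p` is split uniformly
/// across the `n_classes - 1` incorrect labels.
fn smoothed_distribution(target: usize, n_classes: usize, p: f64) -> Vec<f64> {
    assert!(n_classes > 1, "label smoothing needs at least two classes");
    assert!(target < n_classes, "target label out of range");
    assert!((0.0..=1.0).contains(&p), "p must lie in [0, 1]");
    // Every incorrect label receives an equal share of p.
    let mut dist = vec![p / (n_classes - 1) as f64; n_classes];
    // The correct label keeps the remaining 1 - p.
    dist[target] = 1.0 - p;
    dist
}

fn main() {
    // Five classes, gold label 2, p = 0.1: the gold label gets 0.9 and
    // each of the other four labels gets 0.025.
    let dist = smoothed_distribution(2, 5, 0.1);
    println!("{:?}", dist);
    // The smoothed targets still form a probability distribution.
    let total: f64 = dist.iter().sum();
    assert!((total - 1.0).abs() < 1e-12);
}
```

With p = 0 the distribution collapses to a one-hot vector on the gold label, which is ordinary (unsmoothed) cross-entropy.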
pub fn forward(
&self,
logits: &Tensor,
targets: &Tensor,
target_mask: Option<&Tensor>
) -> Result<Tensor, TransformerError>
Compute the cross-entropy loss.
logits should be the unnormalized log-probabilities (logits) of shape [batch_size, n_classes] and targets the gold-standard labels with shape [batch_size]. The optional target mask must have shape [batch_size, n_classes]. If the mask is not provided, all n_classes are used in label smoothing.
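The following plain-Rust sketch approximates what forward computes for a batch, with nested Vecs standing in for the Tensor arguments. Several details are assumptions not confirmed by this page: the local Reduction enum is a hypothetical stand-in for the Reduction type the constructor takes; a true mask entry marks a class as eligible to receive smoothing mass, which is spread uniformly over the eligible incorrect classes; ignored targets yield a loss of 0, and Mean averages over the non-ignored examples only (the PyTorch convention). cross_entropy is a hypothetical function, not the crate's implementation.

```rust
// Hypothetical stand-in for the `Reduction` type passed to `new`.
#[derive(Clone, Copy, Debug)]
enum Reduction {
    None,
    Mean,
    Sum,
}

/// Numerically stable log-softmax over one row of logits.
fn log_softmax(row: &[f64]) -> Vec<f64> {
    let max = row.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let log_sum_exp = row.iter().map(|x| (x - max).exp()).sum::<f64>().ln() + max;
    row.iter().map(|x| x - log_sum_exp).collect()
}

/// Hypothetical sketch. `logits`: batch_size rows of n_classes values;
/// `targets`: one label per row; `mask`: optional batch_size x n_classes,
/// where `true` (by assumption) marks classes that may get smoothing mass.
fn cross_entropy(
    logits: &[Vec<f64>],
    targets: &[i64],
    mask: Option<&[Vec<bool>]>,
    ignore_index: i64,
    label_smoothing: Option<f64>,
    reduction: Reduction,
) -> Vec<f64> {
    let mut losses = Vec::with_capacity(logits.len());
    let mut n_counted = 0usize;
    for (i, (row, &target)) in logits.iter().zip(targets).enumerate() {
        if target == ignore_index {
            // Ignored targets contribute a loss of 0 and are not counted.
            losses.push(0.0);
            continue;
        }
        n_counted += 1;
        let t = target as usize;
        let log_probs = log_softmax(row);
        // Incorrect classes that are allowed to receive smoothing mass.
        let others: Vec<usize> = (0..row.len())
            .filter(|&c| c != t && mask.map_or(true, |m| m[i][c]))
            .collect();
        // With no eligible incorrect class, all mass stays on the gold label.
        let p = if others.is_empty() { 0.0 } else { label_smoothing.unwrap_or(0.0) };
        let share = if others.is_empty() { 0.0 } else { p / others.len() as f64 };
        // Cross-entropy between the smoothed targets and the predicted distribution.
        let loss = -(1.0 - p) * log_probs[t]
            - others.iter().map(|&c| share * log_probs[c]).sum::<f64>();
        losses.push(loss);
    }
    match reduction {
        Reduction::None => losses,
        Reduction::Sum => vec![losses.iter().sum::<f64>()],
        // Yields NaN if every target in the batch is ignored.
        Reduction::Mean => vec![losses.iter().sum::<f64>() / n_counted as f64],
    }
}

fn main() {
    let logits: Vec<Vec<f64>> = vec![vec![2.0, 0.5, -1.0], vec![0.1, 0.2, 0.3], vec![1.0, 1.0, 1.0]];
    // The last example is ignored because its target equals ignore_index (-1).
    let targets: Vec<i64> = vec![0, 2, -1];
    let per_example = cross_entropy(&logits, &targets, None, -1, Some(0.1), Reduction::None);
    let mean = cross_entropy(&logits, &targets, None, -1, Some(0.1), Reduction::Mean);
    println!("per example: {:?}, mean: {:?}", per_example, mean);
}
```

A useful sanity check: with uniform logits every class has log-probability -ln(n_classes), so the loss is ln(n_classes) whatever the smoothing value.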