Struct syntaxdot_transformers::loss::CrossEntropyLoss
pub struct CrossEntropyLoss { /* private fields */ }
Cross-entropy loss function.
Implementations
impl CrossEntropyLoss
pub fn new(
ignore_index: i64,
label_smoothing: Option<f64>,
reduction: Reduction
) -> Self
Construct the cross-entropy loss function.
Targets whose value equals ignore_index are excluded from the loss computation. If label_smoothing is set to Some(p), the correct label gets probability 1 - p and the remaining probability mass p is distributed across the incorrect labels. reduction specifies how the per-target losses are reduced (summarized).
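To make the label-smoothing rule concrete, here is a minimal, dependency-free sketch of the target distribution for a single example. It is not the crate's code. The helper name smoothed_targets is hypothetical, and the sketch assumes that the mass p is spread uniformly over the incorrect labels, which the description above does not state explicitly.

```rust
/// Hypothetical helper (not part of syntaxdot_transformers): the smoothed
/// target distribution for one example whose gold label is `gold`.
/// Assumption: the smoothing mass `p` is spread uniformly over the
/// incorrect labels.
fn smoothed_targets(n_classes: usize, gold: usize, p: f64) -> Vec<f64> {
    assert!(gold < n_classes, "gold label out of range");
    assert!((0.0..=1.0).contains(&p), "label smoothing must lie in [0, 1]");
    if n_classes == 1 {
        // There are no incorrect labels, so the gold label keeps all mass.
        return vec![1.0];
    }
    // Every incorrect label receives an equal share of p.
    let share = p / (n_classes - 1) as f64;
    (0..n_classes)
        .map(|c| if c == gold { 1.0 - p } else { share })
        .collect()
}

fn main() {
    // Corresponds to label_smoothing = Some(0.1), 4 classes, gold label 2.
    let dist = smoothed_targets(4, 2, 0.1);
    println!("{:?}", dist);
    // The distribution sums to 1, up to floating-point rounding.
    println!("sum = {}", dist.iter().sum::<f64>());
}
```

With p = 0 the distribution is one-hot, which recovers the ordinary cross-entropy target.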
pub fn forward(
&self,
logits: &Tensor,
targets: &Tensor,
target_mask: Option<&Tensor>
) -> Result<Tensor, TransformerError>
Compute the cross-entropy loss.
logits should be the unnormalized log-probabilities of shape [batch_size, n_classes], and targets the gold-standard labels of shape [batch_size]. The optional target mask must have shape [batch_size, n_classes]; it determines which classes are used in label smoothing. If no mask is provided, all n_classes labels are used in label smoothing.
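The following is a minimal, self-contained sketch of what forward computes, written over plain Rust vectors rather than Tensor values so that it runs without dependencies. It is not the crate's implementation. Assumptions: logits are converted to log-probabilities with a log-softmax; label-smoothing mass is spread uniformly over the incorrect classes the mask permits (a mask entry of true is taken to mean the class may receive mass); targets equal to ignore_index yield 0.0 under Reduction::None and are left out of the mean's denominator. The functions cross_entropy and log_softmax and the local Reduction enum are hypothetical.

```rust
/// Hypothetical local stand-in for the `reduction` argument.
#[derive(Clone, Copy, Debug)]
enum Reduction {
    None,
    Mean,
    Sum,
}

/// Numerically stable log-softmax over one row of logits.
fn log_softmax(row: &[f64]) -> Vec<f64> {
    let max = row.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let log_sum_exp = max + row.iter().map(|&x| (x - max).exp()).sum::<f64>().ln();
    row.iter().map(|&x| x - log_sum_exp).collect()
}

/// Sketch of the forward pass on plain vectors (hypothetical helper).
/// logits: [batch_size][n_classes], targets: [batch_size],
/// target_mask: optional [batch_size][n_classes], where `true` marks a class
/// that may receive label-smoothing mass (an assumed interpretation).
fn cross_entropy(
    logits: &[Vec<f64>],
    targets: &[i64],
    target_mask: Option<&[Vec<bool>]>,
    ignore_index: i64,
    label_smoothing: Option<f64>,
    reduction: Reduction,
) -> Vec<f64> {
    assert_eq!(logits.len(), targets.len(), "batch sizes differ");
    let p = label_smoothing.unwrap_or(0.0);
    let mut losses = Vec::with_capacity(targets.len());
    let mut n_counted = 0usize;
    for (i, (row, &gold)) in logits.iter().zip(targets).enumerate() {
        // Targets equal to ignore_index contribute nothing to the loss.
        if gold == ignore_index {
            losses.push(0.0);
            continue;
        }
        assert!(gold >= 0 && (gold as usize) < row.len(), "target out of range");
        let gold = gold as usize;
        let log_probs = log_softmax(row);
        // Without a mask, every class takes part in label smoothing.
        let allowed = |c: usize| target_mask.map_or(true, |m| m[i][c]);
        let n_incorrect = (0..row.len()).filter(|&c| c != gold && allowed(c)).count();
        let mut loss = 0.0;
        for (c, &lp) in log_probs.iter().enumerate() {
            // Target probability q for class c under uniform label smoothing.
            let q = if c == gold {
                if n_incorrect == 0 { 1.0 } else { 1.0 - p }
            } else if allowed(c) {
                p / n_incorrect as f64
            } else {
                0.0
            };
            if q > 0.0 {
                loss -= q * lp;
            }
        }
        losses.push(loss);
        n_counted += 1;
    }
    match reduction {
        Reduction::None => losses,
        Reduction::Sum => vec![losses.iter().sum::<f64>()],
        // Ignored targets are left out of the denominator; this sketch
        // returns 0.0 if every target is ignored.
        Reduction::Mean => vec![losses.iter().sum::<f64>() / n_counted.max(1) as f64],
    }
}

fn main() {
    let logits = vec![vec![2.0, 0.5, -1.0], vec![0.1, 0.2, 3.0]];
    // The second target equals ignore_index (-100 here, an arbitrary choice).
    let targets: [i64; 2] = [0, -100];
    let mean = cross_entropy(&logits, &targets, None, -100, Some(0.1), Reduction::Mean);
    println!("mean loss: {:?}", mean);
    // Keep class 2 out of label smoothing for the first example.
    let mask = vec![vec![true, true, false], vec![true, true, true]];
    let per_target =
        cross_entropy(&logits, &targets, Some(mask.as_slice()), -100, Some(0.1), Reduction::None);
    println!("per-target losses: {:?}", per_target);
}
```

Without label smoothing, each non-ignored loss in this sketch reduces to the negative log-probability of the gold label, -log_softmax(logits)[gold].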
Auto Trait Implementations
impl RefUnwindSafe for CrossEntropyLoss
impl Send for CrossEntropyLoss
impl Sync for CrossEntropyLoss
impl Unpin for CrossEntropyLoss
impl UnwindSafe for CrossEntropyLoss
Blanket Implementations
impl<T> BorrowMut<T> for T where
T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.