pub struct ScalarWeightClassifier { /* private fields */ }
A classifier that operates on a scalar-weighted combination of layer outputs.
See Peters et al., 2018 and Kondratyuk & Straka, 2019.
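The scalar weighting ("scalar mix") of Peters et al. (2018) combines per-layer hidden states as γ · Σ_j s_j · h_j, where s = softmax(w) is computed from one learned raw weight w_j per layer and γ is a learned scale. The sketch below shows that formula in plain Rust for a single token. `softmax` and `scalar_mix` are hypothetical helpers, not this crate's API. The real type presumably works on batched tensors (`LayerOutput`, `Tensor`) and may differ in details, such as regularization of the layer weights.

```rust
/// Hypothetical helper: numerically stable softmax over raw layer weights.
fn softmax(xs: &[f32]) -> Vec<f32> {
    // Subtract the maximum before exponentiating to avoid overflow.
    let max = xs.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = xs.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

/// Hypothetical scalar mix: gamma * sum_j softmax(weights)_j * layers[j].
/// `layers[j]` is one token's hidden vector from layer j.
fn scalar_mix(layers: &[Vec<f32>], weights: &[f32], gamma: f32) -> Vec<f32> {
    assert!(!layers.is_empty(), "need at least one layer");
    assert_eq!(layers.len(), weights.len(), "one weight per layer");
    let dim = layers[0].len();
    let s = softmax(weights);
    let mut out = vec![0.0f32; dim];
    for (layer, &s_j) in layers.iter().zip(&s) {
        assert_eq!(layer.len(), dim, "all layers must have the same width");
        // Accumulate the normalized, weighted contribution of this layer.
        for (o, &h) in out.iter_mut().zip(layer) {
            *o += s_j * h;
        }
    }
    // Apply the global scale gamma.
    out.iter_mut().for_each(|o| *o *= gamma);
    out
}

fn main() {
    let layers: Vec<Vec<f32>> = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];
    // Equal raw weights give a uniform softmax, i.e. a plain average of the layers.
    let mixed = scalar_mix(&layers, &[0.0, 0.0, 0.0], 1.0);
    println!("{:?}", mixed);
}
```

Because the weights pass through a softmax, the mix is always a convex combination of the layers (before scaling by γ), so training can shift emphasis between layers without changing the overall magnitude.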
Implementations
impl ScalarWeightClassifier
pub fn new<'a>(
vs: impl Borrow<PathExt<'a>>,
config: &ScalarWeightClassifierConfig
) -> Result<ScalarWeightClassifier, TransformerError>
pub fn forward(
&self,
layers: &[LayerOutput],
train: bool
) -> Result<Tensor, TransformerError>
pub fn logits(
&self,
layers: &[LayerOutput],
train: bool
) -> Result<Tensor, TransformerError>
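This page does not document how `forward` and `logits` differ. Going by the usual meaning of "logits", one plausible reading (an assumption, not confirmed here) is that `logits` returns unnormalized per-class scores and `forward` returns them normalized into probabilities, e.g. with a softmax. The sketch below illustrates that relationship for one token using plain slices. `probabilities` and `argmax` are hypothetical helpers, not part of this crate.

```rust
/// Hypothetical helper: turn one token's unnormalized logits into
/// class probabilities with a numerically stable softmax.
fn probabilities(logits: &[f32]) -> Vec<f32> {
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&l| (l - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.into_iter().map(|e| e / sum).collect()
}

/// Hypothetical helper: index of the highest-scoring class, if any.
fn argmax(xs: &[f32]) -> Option<usize> {
    xs.iter()
        .enumerate()
        .max_by(|a, b| a.1.total_cmp(b.1))
        .map(|(i, _)| i)
}

fn main() {
    let logits = [2.0f32, -1.0, 0.5];
    let probs = probabilities(&logits);
    // Softmax is monotonic, so the predicted class is the same either way.
    assert_eq!(argmax(&logits), argmax(&probs));
    println!("{:?} -> class {:?}", probs, argmax(&probs));
}
```

Since softmax preserves the ranking of scores, raw logits suffice for picking the best label. Normalized probabilities matter when the values themselves are used, for example as confidences.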
Trait Implementations
Auto Trait Implementations
impl RefUnwindSafe for ScalarWeightClassifier
impl Send for ScalarWeightClassifier
impl !Sync for ScalarWeightClassifier
impl Unpin for ScalarWeightClassifier
impl UnwindSafe for ScalarWeightClassifier
Blanket Implementations
impl<T> BorrowMut<T> for T where
T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.