Trait BackwardHook

pub trait BackwardHook<A: Float, D: Dimension> {
    // Required methods
    fn pre_backward(
        &mut self,
        layer_id: &LayerId,
        grad_outputs: &[Array<A, D>],
    ) -> Result<()>;
    fn post_backward(
        &mut self,
        layer_id: &LayerId,
        grad_inputs: &[Array<A, D>],
    ) -> Result<()>;
}

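Backward-pass hook for gradient processing. Its methods are called immediately before and after a layer's backward pass. Gradients are passed as shared slices (&[Array<A, D>]), so a hook can inspect or record them but cannot modify them in place.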
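As an illustration of implementing the trait, the sketch below records the L2 norm of every gradient tensor seen before and after each layer's backward pass. To stay dependency-free it re-declares a simplified copy of the trait: Vec<f64> stands in for Array<A, D>, and LayerId (a String newtype) and the boxed-error Result alias are assumptions, not the crate's real definitions. GradNormRecorder is a hypothetical name.

```rust
use std::collections::HashMap;

// Stand-ins (assumptions, not the crate's real types).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct LayerId(pub String);

pub type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;

// Simplified copy of BackwardHook with Vec<f64> in place of Array<A, D>.
pub trait BackwardHook {
    fn pre_backward(&mut self, layer_id: &LayerId, grad_outputs: &[Vec<f64>]) -> Result<()>;
    fn post_backward(&mut self, layer_id: &LayerId, grad_inputs: &[Vec<f64>]) -> Result<()>;
}

/// Hypothetical hook: stores per-layer L2 norms of the gradients it sees.
#[derive(Default)]
pub struct GradNormRecorder {
    pub pre: HashMap<LayerId, Vec<f64>>,
    pub post: HashMap<LayerId, Vec<f64>>,
}

fn l2_norm(g: &[f64]) -> f64 {
    g.iter().map(|x| x * x).sum::<f64>().sqrt()
}

impl BackwardHook for GradNormRecorder {
    fn pre_backward(&mut self, layer_id: &LayerId, grad_outputs: &[Vec<f64>]) -> Result<()> {
        // One norm per output-gradient tensor, keyed by layer.
        let norms = grad_outputs.iter().map(|g| l2_norm(g)).collect();
        self.pre.insert(layer_id.clone(), norms);
        Ok(())
    }

    fn post_backward(&mut self, layer_id: &LayerId, grad_inputs: &[Vec<f64>]) -> Result<()> {
        // One norm per input-gradient tensor, keyed by layer.
        let norms = grad_inputs.iter().map(|g| l2_norm(g)).collect();
        self.post.insert(layer_id.clone(), norms);
        Ok(())
    }
}

fn main() {
    let mut hook = GradNormRecorder::default();
    let id = LayerId("dense1".to_string());
    hook.pre_backward(&id, &[vec![3.0, 4.0]]).unwrap();
    hook.post_backward(&id, &[vec![6.0, 8.0], vec![0.0]]).unwrap();
    println!("{:?}", hook.pre[&id]); // [5.0]
    println!("{:?}", hook.post[&id]); // [10.0, 0.0]
}
```

Because both methods take &mut self, a hook is free to accumulate state across calls like this, even though the gradients themselves are read-only.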

Required Methods


fn pre_backward(&mut self, layer_id: &LayerId, grad_outputs: &[Array<A, D>]) -> Result<()>

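Called before a layer's backward pass. grad_outputs holds the gradients with respect to the layer's outputs, that is, the gradients arriving from downstream.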
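Because pre_backward returns a Result, one possible use is a sanity check on incoming gradients. The sketch below rejects any grad_outputs tensor containing NaN or infinity. Whether the surrounding framework aborts the backward pass on such an Err is not stated on this page and is an assumption. As before, the trait is a simplified dependency-free copy: Vec<f64> stands in for Array<A, D>, and LayerId, the Result alias, FiniteGradCheck, and NonFiniteGradient are hypothetical.

```rust
use std::fmt;

// Stand-ins (assumptions, not the crate's real types).
#[derive(Debug, Clone)]
pub struct LayerId(pub String);

pub type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;

// Simplified copy of BackwardHook with Vec<f64> in place of Array<A, D>.
pub trait BackwardHook {
    fn pre_backward(&mut self, layer_id: &LayerId, grad_outputs: &[Vec<f64>]) -> Result<()>;
    fn post_backward(&mut self, layer_id: &LayerId, grad_inputs: &[Vec<f64>]) -> Result<()>;
}

/// Hypothetical error: a non-finite value was found in an output gradient.
#[derive(Debug)]
pub struct NonFiniteGradient {
    pub layer: String,
    pub tensor_index: usize,
}

impl fmt::Display for NonFiniteGradient {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "non-finite value in grad_outputs[{}] of layer {}", self.tensor_index, self.layer)
    }
}

impl std::error::Error for NonFiniteGradient {}

/// Hypothetical hook that validates output gradients before backward runs.
pub struct FiniteGradCheck;

impl BackwardHook for FiniteGradCheck {
    fn pre_backward(&mut self, layer_id: &LayerId, grad_outputs: &[Vec<f64>]) -> Result<()> {
        for (i, g) in grad_outputs.iter().enumerate() {
            // is_finite() is false for NaN and for +/- infinity.
            if g.iter().any(|x| !x.is_finite()) {
                return Err(NonFiniteGradient { layer: layer_id.0.clone(), tensor_index: i }.into());
            }
        }
        Ok(())
    }

    fn post_backward(&mut self, _layer_id: &LayerId, _grad_inputs: &[Vec<f64>]) -> Result<()> {
        Ok(()) // nothing to check after backward in this sketch
    }
}

fn main() {
    let mut hook = FiniteGradCheck;
    let id = LayerId("dense1".to_string());
    assert!(hook.pre_backward(&id, &[vec![0.5, -1.0]]).is_ok());
    let err = hook.pre_backward(&id, &[vec![1.0], vec![f64::NAN]]).unwrap_err();
    println!("{err}"); // non-finite value in grad_outputs[1] of layer dense1
}
```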


fn post_backward(&mut self, layer_id: &LayerId, grad_inputs: &[Array<A, D>]) -> Result<()>

Called after a layer's backward pass. grad_inputs holds the gradients the layer computed with respect to its inputs.
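This page does not show how a framework invokes the hooks. The sketch below assumes a driver that calls pre_backward with the incoming gradient, computes the layer's input gradient, then calls post_backward with the result, propagating any hook error with ?. The toy layer computes y = w * x elementwise, so its input gradient is w times the output gradient. ScaleLayer, EventLog, and the backward driver are hypothetical; the trait is a simplified copy with Vec<f64> in place of Array<A, D>, and LayerId and the Result alias are stand-ins.

```rust
// Stand-ins (assumptions, not the crate's real types).
#[derive(Debug, Clone)]
pub struct LayerId(pub String);

pub type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;

// Simplified copy of BackwardHook with Vec<f64> in place of Array<A, D>.
pub trait BackwardHook {
    fn pre_backward(&mut self, layer_id: &LayerId, grad_outputs: &[Vec<f64>]) -> Result<()>;
    fn post_backward(&mut self, layer_id: &LayerId, grad_inputs: &[Vec<f64>]) -> Result<()>;
}

/// Hypothetical elementwise layer: y = w * x.
struct ScaleLayer {
    id: LayerId,
    w: f64,
}

impl ScaleLayer {
    /// Hypothetical driver: runs the hooks around this layer's backward pass.
    fn backward(&self, grad_output: Vec<f64>, hook: &mut dyn BackwardHook) -> Result<Vec<f64>> {
        hook.pre_backward(&self.id, std::slice::from_ref(&grad_output))?;
        // For y = w * x, dL/dx = w * dL/dy elementwise.
        let grad_input: Vec<f64> = grad_output.iter().map(|g| self.w * g).collect();
        hook.post_backward(&self.id, std::slice::from_ref(&grad_input))?;
        Ok(grad_input)
    }
}

/// Hypothetical hook that logs each call, to show the call order.
#[derive(Default)]
struct EventLog(Vec<String>);

impl BackwardHook for EventLog {
    fn pre_backward(&mut self, layer_id: &LayerId, grad_outputs: &[Vec<f64>]) -> Result<()> {
        self.0.push(format!("pre {} {:?}", layer_id.0, grad_outputs));
        Ok(())
    }

    fn post_backward(&mut self, layer_id: &LayerId, grad_inputs: &[Vec<f64>]) -> Result<()> {
        self.0.push(format!("post {} {:?}", layer_id.0, grad_inputs));
        Ok(())
    }
}

fn main() {
    let layer = ScaleLayer { id: LayerId("scale".to_string()), w: 2.0 };
    let mut log = EventLog::default();
    let grad_in = layer.backward(vec![1.0, -0.5], &mut log).unwrap();
    println!("{:?}", grad_in); // [2.0, -1.0]
    for event in &log.0 {
        println!("{event}");
    }
    // pre scale [[1.0, -0.5]]
    // post scale [[2.0, -1.0]]
}
```

The hook is passed as &mut dyn BackwardHook here; the simplified trait is object-safe because its methods take &mut self and have no generic parameters.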

Implementors