

Function train_step_sgd 

pub fn train_step_sgd(
    graph: &mut Graph,
    optimizer: &mut Sgd,
    prediction: NodeId,
    target: NodeId,
    trainable_nodes: &[NodeId],
) -> Result<f32, ModelError>

Runs one full training step: a forward pass through the loss, a backward pass, and an SGD update of the trainable nodes.
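The three phases can be illustrated outside the graph machinery. The sketch below is not this crate's implementation: `Sgd` here is a hypothetical stand-in, and the toy linear model `y_hat = w * x + b`, the mean-squared-error loss, and the helper `train_step_sgd_sketch` are all assumptions chosen to keep the example self-contained. It shows the forward pass, the hand-derived backward pass, and the `param -= lr * grad` update, returning the loss computed before the update.

```rust
/// Hypothetical stand-in for an SGD optimizer; NOT the crate's `Sgd` type.
struct Sgd {
    lr: f32,
}

impl Sgd {
    /// Plain SGD update: `param -= lr * grad` for every parameter.
    fn step(&self, params: &mut [f32], grads: &[f32]) {
        for (p, g) in params.iter_mut().zip(grads) {
            *p -= self.lr * g;
        }
    }
}

/// Hypothetical helper: one training step for `y_hat = w * x + b` with MSE loss.
/// `params` holds `[w, b]`. Returns the loss from the forward pass (pre-update).
fn train_step_sgd_sketch(
    params: &mut [f32; 2],
    opt: &Sgd,
    xs: &[f32],
    ys: &[f32],
) -> Result<f32, String> {
    if xs.is_empty() || xs.len() != ys.len() {
        return Err("xs and ys must be non-empty and of equal length".to_string());
    }
    let n = xs.len() as f32;
    let (w, b) = (params[0], params[1]);

    // Forward pass: residuals (prediction - target) and mean-squared-error loss.
    let residuals: Vec<f32> = xs.iter().zip(ys).map(|(&x, &y)| w * x + b - y).collect();
    let loss = residuals.iter().map(|&r| r * r).sum::<f32>() / n;

    // Backward pass: dL/dw = (2/n) * sum(r * x), dL/db = (2/n) * sum(r).
    let grad_w = residuals.iter().zip(xs).map(|(&r, &x)| 2.0 * r * x).sum::<f32>() / n;
    let grad_b = residuals.iter().map(|&r| 2.0 * r).sum::<f32>() / n;

    // Update: apply the SGD rule to every trainable parameter.
    opt.step(params, &[grad_w, grad_b]);
    Ok(loss)
}
```

With `params = [0.0, 0.0]`, `lr = 0.1`, `xs = [1.0, 2.0]` and `ys = [2.0, 4.0]`, the first call returns a loss of 10.0 and moves the parameters to roughly `[1.0, 0.6]`; calling it again on the same data yields a smaller loss.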