pub struct SteTrainer {
    pub w1_latent: Vec<f32>,
    pub w2_latent: Vec<f32>,
    pub in_features: usize,
    pub hidden_size: usize,
    pub out_features: usize,
    pub config: QatConfig,
}
Maintains latent f32 shadow weights for a 2-layer MLP. During each training step:
- Quantize latent weights → ternary {-1, 0, +1}
- Forward pass (f32 arithmetic on quantized weights)
- MSE loss + backprop through STE
- SGD update on latent weights
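The four steps above can be sketched for a single linear layer as follows. This is a minimal illustration, not the crate's implementation: `quantize_ternary`, `ste_step`, the fixed `threshold`, and the explicit `lr` parameter are all hypothetical names and choices, and the real trainer handles two layers and takes its hyperparameters from `QatConfig`.

```rust
/// Hypothetical quantizer: maps each latent weight to -1, 0, or +1 using a
/// fixed threshold. The crate's actual quantization rule is not shown here.
fn quantize_ternary(latent: &[f32], threshold: f32) -> Vec<f32> {
    latent
        .iter()
        .map(|&w| {
            if w > threshold {
                1.0
            } else if w < -threshold {
                -1.0
            } else {
                0.0
            }
        })
        .collect()
}

/// One STE step for a single linear layer y = W x, with W stored row-major
/// as `out_features x in_features`. Returns the MSE loss before the update.
fn ste_step(latent: &mut [f32], input: &[f32], target: &[f32], lr: f32, threshold: f32) -> f32 {
    let in_features = input.len();
    let out_features = target.len();
    assert_eq!(latent.len(), in_features * out_features);

    // 1. Quantize latent weights to ternary values.
    let wq = quantize_ternary(latent, threshold);

    // 2. Forward pass in f32 arithmetic on the quantized weights.
    let output: Vec<f32> = (0..out_features)
        .map(|o| {
            (0..in_features)
                .map(|i| wq[o * in_features + i] * input[i])
                .sum::<f32>()
        })
        .collect();

    // 3. MSE loss and its gradient with respect to the output.
    let n = out_features as f32;
    let loss = output
        .iter()
        .zip(target)
        .map(|(y, t)| (y - t).powi(2))
        .sum::<f32>()
        / n;
    let grad_out: Vec<f32> = output
        .iter()
        .zip(target)
        .map(|(y, t)| 2.0 * (y - t) / n)
        .collect();

    // 4. STE: treat d(quantized)/d(latent) as 1, so the gradient computed for
    //    the quantized weight is applied directly to the latent weight (SGD).
    for o in 0..out_features {
        for i in 0..in_features {
            latent[o * in_features + i] -= lr * grad_out[o] * input[i];
        }
    }
    loss
}
```

The point of the straight-through estimator is visible in step 4: the quantizer has zero gradient almost everywhere, so the backward pass ignores it and updates the f32 latent weights, which can then drift across the threshold over many steps.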
Fields§
w1_latent: Vec<f32>
w2_latent: Vec<f32>
in_features: usize
hidden_size: usize
out_features: usize
config: QatConfig
Implementations§
impl SteTrainer
pub fn from_mlp(mlp: &TernaryMLP, config: QatConfig) -> Self
Initialise from an existing TernaryMLP’s quantized weights. The ternary {-1,0,+1} values become the initial latent floats.
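As a rough sketch of that conversion, assuming the ternary weights are available as `i8` values (the storage format of `TernaryMLP` is not shown here, and `ternary_to_latent` is a hypothetical helper):

```rust
/// Hypothetical helper: widen ternary weights (-1, 0, +1) to f32 so they can
/// serve as the starting point for the latent shadow weights.
fn ternary_to_latent(ternary: &[i8]) -> Vec<f32> {
    ternary.iter().map(|&t| f32::from(t)).collect()
}
```

Latent weights initialised this way start exactly on the ternary grid, so the first quantization should reproduce the original model under any symmetric threshold below 1.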
pub fn from_f32(
    in_features: usize,
    hidden_size: usize,
    out_features: usize,
    w1_f32: Vec<f32>,
    w2_f32: Vec<f32>,
    config: QatConfig,
) -> Self
Initialise from raw f32 weights; quantization first happens on the first training step.
pub fn train_step(&mut self, input: &[f32], target: &[f32]) -> f32
One SGD step on a single sample.
- input: flat f32 row vector, length in_features
- target: flat f32 row vector, length out_features
Returns the MSE loss for this sample.
pub fn train(&mut self, samples: &[(Vec<f32>, Vec<f32>)]) -> QatResult
Run the full training loop over a dataset.
- samples: slice of (input, target) pairs
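A training loop of this shape can be sketched generically. The helper below is an assumption for illustration only: `run_epochs`, the `epochs` parameter, and the per-epoch mean-loss history are hypothetical, and the contents of `QatResult` and `QatConfig` are not shown here.

```rust
/// Hypothetical training loop: calls `step` on every sample for `epochs`
/// passes and records the mean loss of each pass.
fn run_epochs<F>(samples: &[(Vec<f32>, Vec<f32>)], epochs: usize, mut step: F) -> Vec<f32>
where
    F: FnMut(&[f32], &[f32]) -> f32,
{
    let mut history = Vec::with_capacity(epochs);
    for _ in 0..epochs {
        let mut total = 0.0_f32;
        for (input, target) in samples {
            total += step(input.as_slice(), target.as_slice());
        }
        // Guard against an empty dataset to avoid dividing by zero.
        let mean = if samples.is_empty() {
            0.0
        } else {
            total / samples.len() as f32
        };
        history.push(mean);
    }
    history
}
```

With a trainer in hand, the closure could simply forward to `train_step`, for example `run_epochs(&samples, 10, |x, t| trainer.train_step(x, t))`, since `train_step` takes two `&[f32]` slices and returns the per-sample loss.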
pub fn finalize(&self) -> TernaryMLP
Finalise training: quantize latent weights and return a TernaryMLP.
Auto Trait Implementations§
impl Freeze for SteTrainer
impl RefUnwindSafe for SteTrainer
impl Send for SteTrainer
impl Sync for SteTrainer
impl Unpin for SteTrainer
impl UnsafeUnpin for SteTrainer
impl UnwindSafe for SteTrainer
Blanket Implementations§
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.
impl<T> IntoEither for T
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.