pub struct QLoraTrainer { /* private fields */ }
Trainer for QLoRA fine-tuning.
Manages the training loop, gradient computation, and optimizer updates
for quantized LoRA training.
§Usage
- Create trainer with config
- Use var_builder() to create layers that register params in VarMap
- Call init_optimizer() to set up the optimizer with the registered params
- Call training_step() or training_step_lm() for each batch (see the sketch below)
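Taken together, the flow looks like the sketch below. It is illustrative only: the Default impls on QLoraTrainingConfig and QLoraConfig are assumptions, and the tensor shapes are arbitrary placeholders.

// Hedged end-to-end sketch; Default impls and shapes are assumptions.
use candle_core::{DType, Device, Tensor};

let device = Device::Cpu;
let config = QLoraTrainingConfig::default(); // assumed Default impl
let mut trainer = QLoraTrainer::new(config, device.clone());

// 1. Create layers through the trainer's VarBuilder so their LoRA params
//    land in the trainer's VarMap.
let weight = Tensor::randn(0f32, 1f32, (64, 64), &device)?; // placeholder base weight
let qlora_config = QLoraConfig::default(); // assumed Default impl
let vb = trainer.var_builder();
let layer = QuantizedLinear::from_weight_with_varbuilder(&weight, None, &qlora_config, vb.pp("layer0"))?;

// 2. Hand the registered params to the optimizer.
trainer.init_optimizer(&[&layer])?;

// 3. Step once per batch (see training_step / training_step_lm below).
let input = Tensor::randn(0f32, 1f32, (2, 16, 64), &device)?;
let target_ids = Tensor::zeros((2, 16), DType::U32, &device)?;
let loss = trainer.training_step_lm(&[&layer], &input, &target_ids)?;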
Implementations§
impl QLoraTrainer
pub fn new(config: QLoraTrainingConfig, device: Device) -> QLoraTrainer
pub fn var_builder(&self) -> VarBuilderArgs<'_, Box<dyn SimpleBackend + '_>>
Get a VarBuilder backed by this trainer’s VarMap.
Use this to create QuantizedLinear layers with gradient tracking.
Params created through this VarBuilder will be registered in the
trainer’s VarMap and trained by the optimizer.
§Example
let mut trainer = QLoraTrainer::new(config, device.clone());
let vb = trainer.var_builder();
let layer = QuantizedLinear::from_weight_with_varbuilder(&weight, None, &qlora_config, vb.pp("layer0"))?;
trainer.init_optimizer(&[&layer])?;
pub fn init_optimizer(
    &mut self,
    layers: &[&QuantizedLinear],
) -> Result<(), QLoraError>
Initialize the optimizer with trainable parameters.
Creates either a paged or standard AdamW optimizer based on configuration.
For paged optimizer, optimizer states are stored on CPU and paged to GPU
during updates to reduce VRAM usage.
Important: Layers must be created using var_builder() for standard AdamW,
or the optimizer will have no trainable parameters.
§Arguments
- layers - The QLoRA layers to train
§Errors
Returns an error if:
- VarMap is empty (for standard optimizer)
- Layers weren't created with var_builder()
- Optimizer initialization fails
§Panics
Panics if the VarMap mutex is poisoned.
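Whether you get the paged or the standard AdamW is decided by the config, not by the call itself. A hedged sketch of guarding the call (the error formatting assumes QLoraError implements Display, as Rust error types conventionally do):

// Initialize once, after every layer has been created via var_builder().
if let Err(e) = trainer.init_optimizer(&[&layer]) {
    // Typical causes per the docs above: empty VarMap, or layers created
    // outside var_builder().
    eprintln!("optimizer init failed: {e}");
}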
pub fn state(&self) -> &AdapterTrainingState
Get the current training state.
pub fn current_lr(&self) -> f64
Get the current learning rate.
pub fn global_step(&self) -> usize
Get the current global step.
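current_lr() and global_step() are convenient for lightweight progress logging; a small sketch (the interval of 100 is arbitrary):

// Print schedule progress every 100 steps.
if trainer.global_step() % 100 == 0 {
    println!("step {:>6}  lr {:.3e}", trainer.global_step(), trainer.current_lr());
}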
pub fn training_step(
    &mut self,
    layers: &[&QuantizedLinear],
    input: &Tensor,
    targets: &Tensor,
) -> Result<f64, QLoraError>
Perform a training step with gradient accumulation.
QLoRA training flow:
- Forward pass through the frozen quantized base plus trainable LoRA
- Compute loss (cross-entropy for LM, MSE for regression)
- Backward pass - gradients flow only through the LoRA weights
- Accumulate gradients if gradient_accumulation_steps > 1
- Optimizer step when accumulation is complete
Supports both standard AdamW and paged AdamW optimizers.
§Arguments
- layers - The QLoRA layers
- input - Input tensor [batch, seq_len, hidden]
- targets - Target tensor (logits or token IDs depending on loss)
§Returns
The loss value for this step
§Errors
Returns an error if the forward or backward pass fails
§Panics
Panics if the VarMap mutex is poisoned.
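Because accumulation happens inside training_step, the caller simply loops over batches; a hedged sketch where batches is an assumed Vec of (input, targets) tensor pairs:

// The optimizer only updates every gradient_accumulation_steps calls;
// the returned loss is still reported for each batch.
let mut total = 0.0;
for (input, targets) in &batches {
    total += trainer.training_step(&[&layer], input, targets)?;
}
println!("mean loss: {:.4}", total / batches.len() as f64);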
pub fn training_step_lm(
    &mut self,
    layers: &[&QuantizedLinear],
    input: &Tensor,
    target_ids: &Tensor,
) -> Result<f64, QLoraError>
Perform training step with cross-entropy loss for language modeling.
Supports both standard AdamW and paged AdamW optimizers.
§Arguments
- layers - The QLoRA layers
- input - Input tensor [batch, seq_len, hidden]
- target_ids - Target token IDs [batch, seq_len]
§Returns
The cross-entropy loss value
§Errors
Returns an error if the forward pass or loss computation fails
§Panics
Panics if the VarMap mutex is poisoned.
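A minimal sketch of one LM step with hand-built candle tensors; the dimensions (batch 2, seq_len 16, hidden 64) are illustrative, and the shapes follow the §Arguments above:

use candle_core::{DType, Tensor};

let input = Tensor::randn(0f32, 1f32, (2, 16, 64), &device)?;   // [batch, seq_len, hidden]
let target_ids = Tensor::zeros((2, 16), DType::U32, &device)?;  // [batch, seq_len]
let loss = trainer.training_step_lm(&[&layer], &input, &target_ids)?;
println!("cross-entropy: {loss:.4}");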
pub fn start_epoch(&mut self)
Start a new training epoch.
pub fn should_continue(&self) -> bool
Check if training should continue.
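start_epoch() and should_continue() pair naturally into the outer loop; a hedged sketch where next_epoch_batches() is an assumed data source, not part of this crate:

while trainer.should_continue() {
    trainer.start_epoch();
    for (input, target_ids) in next_epoch_batches() { // assumed helper
        trainer.training_step_lm(&[&layer], &input, &target_ids)?;
    }
}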
pub fn config(&self) -> &QLoraTrainingConfig
Get training configuration.
pub fn optimizer_memory_stats(&self) -> Option<(usize, usize)>
Get optimizer memory statistics (CPU bytes, GPU bytes).
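Useful for verifying that the paged optimizer is actually keeping its state on the CPU; a sketch (presumably None before init_optimizer() has run, though the docs don't spell that out):

if let Some((cpu_bytes, gpu_bytes)) = trainer.optimizer_memory_stats() {
    println!(
        "optimizer state: {:.1} MiB CPU, {:.1} MiB GPU",
        cpu_bytes as f64 / (1024.0 * 1024.0),
        gpu_bytes as f64 / (1024.0 * 1024.0),
    );
}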
Auto Trait Implementations§
impl Freeze for QLoraTrainer
impl !RefUnwindSafe for QLoraTrainer
impl Send for QLoraTrainer
impl Sync for QLoraTrainer
impl Unpin for QLoraTrainer
impl !UnwindSafe for QLoraTrainer
Blanket Implementations§
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.
impl<T> Instrument for T
fn instrument(self, span: Span) -> Instrumented<Self>
Instruments this type with the provided Span, returning an Instrumented wrapper.
fn in_current_span(self) -> Instrumented<Self>
Instruments this type with the current Span, returning an Instrumented wrapper.
impl<T> IntoEither for T
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left is true.
Converts self into a Right variant of Either<Self, Self> otherwise.
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true.
Converts self into a Right variant of Either<Self, Self> otherwise.