
Struct QLoraTrainer 

pub struct QLoraTrainer { /* private fields */ }

Trainer for QLoRA fine-tuning.

Manages the training loop, gradient computation, and optimizer updates for quantized LoRA training.

§Usage

  1. Create trainer with config
  2. Use var_builder() to create layers that register params in VarMap
  3. Call init_optimizer() to set up optimizer with registered params
  4. Call training_step() or training_step_lm() for each batch
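
A minimal end-to-end sketch of these four steps. It is a hedged example: it assumes candle_core::{Device, Tensor}, a Default impl on QLoraTrainingConfig, and that both error types convert into Box<dyn std::error::Error>; the tensor shapes and the "layer0" prefix are illustrative only.

use candle_core::{Device, Tensor};

fn run(weight: Tensor, qlora_config: &QLoraConfig) -> Result<(), Box<dyn std::error::Error>> {
    let device = Device::Cpu;
    // 1. Create the trainer (Default on the config is an assumption).
    let mut trainer = QLoraTrainer::new(QLoraTrainingConfig::default(), device.clone());
    // 2. Build layers through the trainer's VarBuilder so the LoRA params
    //    are registered in its VarMap.
    let vb = trainer.var_builder();
    let layer = QuantizedLinear::from_weight_with_varbuilder(
        &weight, None, qlora_config, vb.pp("layer0"))?;
    // 3. Hand the registered params to the optimizer.
    trainer.init_optimizer(&[&layer])?;
    // 4. Run one step on an illustrative random batch: [batch=2, seq_len=16, hidden=64].
    let input = Tensor::randn(0f32, 1f32, (2, 16, 64), &device)?;
    let targets = Tensor::randn(0f32, 1f32, (2, 16, 64), &device)?;
    let loss = trainer.training_step(&[&layer], &input, &targets)?;
    println!("step loss: {loss}");
    Ok(())
}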

Implementations§

impl QLoraTrainer

pub fn new(config: QLoraTrainingConfig, device: Device) -> QLoraTrainer

Create a new QLoRA trainer.

§Arguments
  • config - Training configuration
  • device - Device for computation
§Returns

New trainer instance

pub fn var_builder(&self) -> VarBuilderArgs<'_, Box<dyn SimpleBackend + '_>>

Get a VarBuilder backed by this trainer’s VarMap.

Use this to create QuantizedLinear layers with gradient tracking. Params created through this VarBuilder will be registered in the trainer’s VarMap and trained by the optimizer.

§Example
let mut trainer = QLoraTrainer::new(config, device.clone());
// Params created through this VarBuilder are registered in the trainer's VarMap.
let vb = trainer.var_builder();
let layer = QuantizedLinear::from_weight_with_varbuilder(
    &weight, None, &qlora_config, vb.pp("layer0"))?;
trainer.init_optimizer(&[&layer])?;

pub fn init_optimizer(&mut self, layers: &[&QuantizedLinear]) -> Result<(), QLoraError>

Initialize the optimizer with trainable parameters.

Creates either a paged or standard AdamW optimizer, depending on configuration. With the paged optimizer, optimizer state is kept on the CPU and paged to the GPU during updates to reduce VRAM usage.

Important: for standard AdamW, layers must be created through var_builder(), or the optimizer will have no trainable parameters.

§Arguments
  • layers - The QLoRA layers to train
§Errors

Returns error if:

  • The VarMap is empty (standard optimizer), meaning the layers were not created with var_builder()
  • Optimizer initialization fails
§Panics

Panics if the VarMap mutex is poisoned.
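
A short, hedged sketch of handling the failure mode described above; `trainer` and `layer` are assumed to be set up as in the var_builder() example, and QLoraError is assumed to implement Display:

if let Err(e) = trainer.init_optimizer(&[&layer]) {
    // Usually means the layers were not created through trainer.var_builder(),
    // so the VarMap holds no trainable parameters.
    eprintln!("optimizer init failed: {e}");
}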

pub fn state(&self) -> &AdapterTrainingState

Get the current training state.

pub fn current_lr(&self) -> f64

Get the current learning rate.

pub fn global_step(&self) -> usize

Get the current step.

pub fn epoch(&self) -> usize

Get the current epoch.

pub fn training_step(&mut self, layers: &[&QuantizedLinear], input: &Tensor, targets: &Tensor) -> Result<f64, QLoraError>

Perform a training step with gradient accumulation.

QLoRA training flow:

  1. Forward pass through frozen quantized base + trainable LoRA
  2. Compute loss (cross-entropy for LM, MSE for regression)
  3. Backward pass - gradients flow only through LoRA weights
  4. Accumulate gradients if gradient_accumulation_steps > 1
  5. Optimizer step when accumulation complete

Supports both standard AdamW and paged AdamW optimizers.

§Arguments
  • layers - The QLoRA layers
  • input - Input tensor [batch, seq_len, hidden]
  • targets - Target tensor (logits or token IDs depending on loss)
§Returns

The loss value for this step

§Errors

Returns error if forward pass or backward pass fails

§Panics

Panics if the VarMap mutex is poisoned.
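
A hedged sketch of the per-batch loop with gradient accumulation; `trainer` and `layer` are set up as in the var_builder() example, and `batches` is a hypothetical iterator of prepared (input, targets) tensor pairs:

// `batches` is a hypothetical data source yielding (Tensor, Tensor) pairs.
for (input, targets) in batches {
    // Gradients accumulate internally; the optimizer only steps once
    // gradient_accumulation_steps micro-batches have been processed.
    let loss = trainer.training_step(&[&layer], &input, &targets)?;
    println!("step {} loss {loss}", trainer.global_step());
}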

pub fn training_step_lm(&mut self, layers: &[&QuantizedLinear], input: &Tensor, target_ids: &Tensor) -> Result<f64, QLoraError>

Perform training step with cross-entropy loss for language modeling.

Supports both standard AdamW and paged AdamW optimizers.

§Arguments
  • layers - The QLoRA layers
  • input - Input tensor [batch, seq_len, hidden]
  • target_ids - Target token IDs [batch, seq_len]
§Returns

The cross-entropy loss value

§Errors

Returns error if forward pass or loss computation fails

§Panics

Panics if the VarMap mutex is poisoned.
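
A hedged sketch of a single LM step; it assumes candle_core::{DType, Tensor} and u32 token IDs, with illustrative sizes (batch=2, seq_len=16, hidden=64):

use candle_core::{DType, Tensor};

// Hidden states from the embedding/previous layer: [batch, seq_len, hidden].
let input = Tensor::randn(0f32, 1f32, (2, 16, 64), &device)?;
// Target token IDs: [batch, seq_len]; zeros stand in for real tokens here.
let target_ids = Tensor::zeros((2, 16), DType::U32, &device)?;
let ce_loss = trainer.training_step_lm(&[&layer], &input, &target_ids)?;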

pub fn start_epoch(&mut self)

Start a new training epoch.

pub fn should_continue(&self) -> bool

Check if training should continue.

pub fn update_lr(&mut self)

Update learning rate based on schedule.
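
start_epoch(), should_continue(), and update_lr() compose into the outer training loop. A hedged sketch (batches() is a hypothetical data source; calling update_lr() once per step rather than per epoch is one reasonable choice, not mandated by this API):

while trainer.should_continue() {
    trainer.start_epoch();
    for (input, target_ids) in batches() {
        trainer.training_step_lm(&[&layer], &input, &target_ids)?;
        trainer.update_lr(); // advance the schedule each step
    }
    println!("epoch {} lr {}", trainer.epoch(), trainer.current_lr());
}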

pub fn config(&self) -> &QLoraTrainingConfig

Get training configuration.

pub fn optimizer_memory_stats(&self) -> Option<(usize, usize)>

Get optimizer memory statistics (CPU bytes, GPU bytes).
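
A small hedged sketch for logging these stats; None is assumed to mean no paged-optimizer stats are available, and the byte-to-MiB conversion is just for display:

if let Some((cpu_bytes, gpu_bytes)) = trainer.optimizer_memory_stats() {
    println!(
        "optimizer state: {:.1} MiB CPU, {:.1} MiB GPU",
        cpu_bytes as f64 / (1024.0 * 1024.0),
        gpu_bytes as f64 / (1024.0 * 1024.0),
    );
}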

pub fn zero_grad(&mut self)

Zero gradients for next accumulation cycle.

Resets the accumulation step counter. Note: In candle, gradients are automatically zeroed when backward_step is called on the optimizer.
