Struct Zero3CpuOffloadManager 

pub struct Zero3CpuOffloadManager { /* private fields */ }

Main ZeRO-3 CPU offload manager that orchestrates all components

This is the primary interface for ZeRO-3 operations. It provides a unified API that coordinates the specialized modules, and it keeps the same interface as the original monolithic implementation for backward compatibility.

Implementations§

impl Zero3CpuOffloadManager

pub fn new(config: Zero3CpuOffloadConfig, process_group: Arc<ProcessGroup>, model_parameters: &ConfigModelParameters) -> TorshResult<Self>

Create a new ZeRO-3 CPU offload manager

Initializes all component systems and establishes distributed coordination. The manager partitions parameters, gradients, and optimizer states across ranks according to the ZeRO-3 algorithm.

§Arguments
  • config - Configuration for ZeRO-3 behavior and memory management
  • process_group - Distributed process group for coordination
  • model_parameters - Description of model parameters to be managed
§Returns

Returns a configured ZeRO-3 manager ready for training operations.
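
To make the partitioning concrete, here is a minimal sketch of one way a flat parameter buffer could be split evenly across ranks. The helper `shard_range` is hypothetical and not part of this crate; the manager's actual partitioning scheme is not documented here and may differ (for example, it could partition per tensor or pad shards to equal size).

```rust
use std::ops::Range;

/// Hypothetical helper (not part of the crate): the contiguous slice of a
/// flat parameter buffer that `rank` owns when `num_elements` values are
/// split across `world_size` ranks. The first `num_elements % world_size`
/// ranks receive one extra element, so shard sizes differ by at most one.
fn shard_range(num_elements: usize, world_size: usize, rank: usize) -> Range<usize> {
    assert!(world_size > 0 && rank < world_size);
    let base = num_elements / world_size;
    let extra = num_elements % world_size;
    // Ranks below `extra` each shift the start by one additional element.
    let start = rank * base + rank.min(extra);
    let len = base + usize::from(rank < extra);
    start..start + len
}
```

With 10 elements and 4 ranks, this rule assigns `0..3`, `3..6`, `6..8`, and `8..10`. Every element is owned by exactly one rank, which is the property ZeRO-3 style sharding relies on.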

pub async fn forward_pass(&mut self, input: &Tensor<f32>, layer_names: &[String]) -> TorshResult<Tensor<f32>>

Execute forward pass with ZeRO-3 CPU offloading

Processes each layer in order, managing parameter placement as follows:

  1. Prefetches parameters for upcoming layers
  2. Ensures current layer parameters are on GPU
  3. Executes layer computation
  4. Optionally offloads parameters back to CPU
  5. Performs memory optimization as needed
§Arguments
  • input - Input tensor for the forward pass
  • layer_names - Ordered list of layer names to execute
§Returns

Returns the output tensor after processing all layers.
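
The per-layer steps listed for this method can be sketched as a single-process simulation. Everything here is an assumption made for illustration: `Residency`, `ensure_on_gpu`, `offload`, and `forward_schedule` are hypothetical names, the sketch is synchronous, it prefetches exactly one layer ahead, it always offloads after use, and it omits the memory optimization step.

```rust
use std::collections::HashSet;

/// Toy stand-in for residency tracking (hypothetical, not the crate's type).
struct Residency {
    on_gpu: HashSet<String>,
    log: Vec<String>,
}

impl Residency {
    fn new() -> Self {
        Self { on_gpu: HashSet::new(), log: Vec::new() }
    }

    /// Record a transfer to the GPU only if the layer is not already resident.
    fn ensure_on_gpu(&mut self, layer: &str, reason: &str) {
        if self.on_gpu.insert(layer.to_string()) {
            self.log.push(format!("{} {}", reason, layer));
        }
    }

    /// Record a transfer back to the CPU if the layer is currently resident.
    fn offload(&mut self, layer: &str) {
        if self.on_gpu.remove(layer) {
            self.log.push(format!("offload {}", layer));
        }
    }
}

/// Walk the layers in order and return the sequence of simulated events.
fn forward_schedule(layer_names: &[String]) -> Vec<String> {
    let mut state = Residency::new();
    for (i, name) in layer_names.iter().enumerate() {
        // 1. Prefetch parameters for the next layer, if there is one.
        if let Some(next) = layer_names.get(i + 1) {
            state.ensure_on_gpu(next, "prefetch");
        }
        // 2. Make sure the current layer's parameters are resident.
        state.ensure_on_gpu(name, "fetch");
        // 3. Run the layer (placeholder for the real computation).
        state.log.push(format!("compute {}", name));
        // 4. Offload the parameters again (always, in this sketch).
        state.offload(name);
    }
    state.log
}
```

In this simulation a layer that was prefetched during the previous iteration needs no second transfer when its turn comes, which is the point of overlapping prefetch with computation.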

pub async fn backward_pass(&mut self, grad_output: &Tensor<f32>, layer_names: &[String]) -> TorshResult<()>

Execute backward pass with ZeRO-3 CPU offloading

Processes layers in reverse order for gradient computation:

  1. Ensures parameters are available for gradient computation
  2. Computes gradients for each layer
  3. Partitions and manages gradients according to ZeRO-3
  4. Performs all-reduce synchronization across ranks
§Arguments
  • grad_output - Gradient tensor from the loss function
  • layer_names - Ordered list of layer names (processed in reverse)
§Returns

Returns Ok(()) when the backward pass completes successfully.
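
The gradient handling described for this method can be illustrated with a small single-process sketch. The functions `backward_order`, `all_reduce_mean`, and `owned_gradient_shard` are hypothetical and simulate the collective with plain vectors; whether the real manager averages or sums gradients, and how it sizes shards, are assumptions here.

```rust
/// Layers in the order a backward pass would visit them (reverse of forward).
fn backward_order(layer_names: &[String]) -> Vec<&str> {
    layer_names.iter().rev().map(String::as_str).collect()
}

/// Simulated all-reduce with mean: element-wise average of one gradient
/// vector per rank. A real collective would run across processes.
fn all_reduce_mean(per_rank: &[Vec<f32>]) -> Vec<f32> {
    let world_size = per_rank.len();
    assert!(world_size > 0);
    let n = per_rank[0].len();
    assert!(per_rank.iter().all(|g| g.len() == n));
    let mut out = vec![0.0f32; n];
    for grads in per_rank {
        for (o, g) in out.iter_mut().zip(grads) {
            *o += *g;
        }
    }
    for o in &mut out {
        *o /= world_size as f32;
    }
    out
}

/// Keep only the slice of the reduced gradient that `rank` owns, using
/// fixed-size chunks of ceil(len / world_size) elements.
fn owned_gradient_shard(reduced: &[f32], world_size: usize, rank: usize) -> Vec<f32> {
    assert!(world_size > 0 && rank < world_size);
    let chunk = (reduced.len() + world_size - 1) / world_size;
    let start = (rank * chunk).min(reduced.len());
    let end = (start + chunk).min(reduced.len());
    reduced[start..end].to_vec()
}
```

After the reduction each rank only needs to retain its own shard, so the full gradient never has to stay resident on any single rank.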

pub async fn optimizer_step(&mut self, learning_rate: f32) -> TorshResult<()>

Update optimizer states and parameters with ZeRO-3 partitioning

Performs the optimizer step in the following stages:

  1. Gathers partitioned gradients for owned parameters
  2. Fetches optimizer states from CPU if needed
  3. Computes parameter updates using the optimizer algorithm
  4. Updates parameters and stores them back in the appropriate location
  5. Broadcasts updates to all ranks that need them
§Arguments
  • learning_rate - Learning rate for parameter updates
§Returns

Returns Ok(()) when the optimizer step completes successfully.
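
As a rough illustration of a sharded update, the sketch below applies SGD with momentum to the shard one rank owns and then rebuilds the full parameter vector with a simulated all-gather. The optimizer choice, the `ShardState` type, and both function names are assumptions; this page does not say which optimizer algorithm the manager uses or how it broadcasts updates.

```rust
/// Hypothetical optimizer state for one shard: a momentum buffer.
struct ShardState {
    momentum: Vec<f32>,
}

/// Apply one SGD-with-momentum update to the parameters this rank owns.
/// m <- beta * m + g, then p <- p - learning_rate * m.
fn sgd_momentum_step(
    params: &mut [f32],
    grads: &[f32],
    state: &mut ShardState,
    learning_rate: f32,
    beta: f32,
) {
    assert_eq!(params.len(), grads.len());
    assert_eq!(params.len(), state.momentum.len());
    for ((p, g), m) in params.iter_mut().zip(grads).zip(state.momentum.iter_mut()) {
        *m = beta * *m + *g;
        *p -= learning_rate * *m;
    }
}

/// Simulated all-gather: concatenate the updated shards in rank order.
fn all_gather(shards: &[Vec<f32>]) -> Vec<f32> {
    shards.iter().flatten().copied().collect()
}
```

Because each rank updates only its own shard, the optimizer state it must hold (here the momentum buffer) shrinks in proportion to the number of ranks.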

pub fn get_performance_stats(&self) -> Zero3PerformanceStats

Get comprehensive performance statistics

Returns detailed performance metrics including timing, throughput, memory usage, and efficiency measurements.

pub fn get_memory_stats(&self) -> Zero3MemoryStats

Get memory usage statistics

Returns current memory usage across CPU and GPU, including parameter distribution and compression effectiveness.

pub async fn force_memory_optimization(&self) -> TorshResult<()>

Force immediate memory optimization

Triggers aggressive memory optimization regardless of current pressure. Useful for cleaning up before checkpointing or when memory is critically low.

pub fn get_prefetch_status(&self) -> PrefetchQueueStatus

Get prefetch scheduler status

Returns information about current prefetch operations and queue status.

pub async fn adapt_performance(&self) -> TorshResult<()>

Adapt system performance based on runtime metrics

Analyzes recent performance and adjusts prefetch strategies, memory management policies, and other adaptive parameters.
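
One simple form such an adjustment could take is a feedback rule on the prefetch hit rate. The sketch below is purely illustrative: `adapt_prefetch_depth`, its parameters, and the 80% and 98% thresholds are invented for this example and are not taken from the crate.

```rust
/// Hypothetical tuning rule: widen the prefetch window when many parameter
/// fetches miss, and narrow it when nearly all of them hit.
fn adapt_prefetch_depth(current: usize, hits: u64, misses: u64, max_depth: usize) -> usize {
    let total = hits + misses;
    if total == 0 {
        // No observations yet, so keep the current setting.
        return current;
    }
    let hit_rate = hits as f64 / total as f64;
    if hit_rate < 0.80 {
        // Too many stalls: look further ahead, up to the configured cap.
        (current + 1).min(max_depth)
    } else if hit_rate > 0.98 {
        // Prefetching is more than sufficient: free memory, but keep at least 1.
        current.saturating_sub(1).max(1)
    } else {
        current
    }
}
```

A rule like this trades GPU memory (deeper prefetch) against stall time (shallower prefetch), which is the kind of balance an adaptive policy has to strike.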

pub async fn reset_state(&self) -> TorshResult<()>

Clear all caches and reset state

Useful for testing or when switching between different models.
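
For orientation, the sketch below shows the order in which the methods of this type would plausibly be called in a training loop. `MockManager` is a hypothetical, synchronous stand-in that only mirrors the documented method names: the real methods are async, take tensors and layer names, return `TorshResult`, and `adapt_performance` takes `&self` rather than `&mut self`. The `adapt_every` interval is also an assumption.

```rust
/// Hypothetical mock that records which methods were called, in order.
#[derive(Default)]
struct MockManager {
    calls: Vec<&'static str>,
}

impl MockManager {
    fn forward_pass(&mut self) {
        self.calls.push("forward_pass");
    }
    fn backward_pass(&mut self) {
        self.calls.push("backward_pass");
    }
    fn optimizer_step(&mut self, _learning_rate: f32) {
        self.calls.push("optimizer_step");
    }
    fn adapt_performance(&mut self) {
        self.calls.push("adapt_performance");
    }
}

/// Run `steps` training steps, adapting every `adapt_every` steps
/// (never, if `adapt_every` is zero).
fn train(manager: &mut MockManager, steps: usize, adapt_every: usize) {
    for step in 1..=steps {
        manager.forward_pass();
        manager.backward_pass();
        manager.optimizer_step(1e-3);
        if adapt_every > 0 && step % adapt_every == 0 {
            manager.adapt_performance();
        }
    }
}
```

Against the real manager, each call would be awaited and its `TorshResult` propagated with `?`.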

Auto Trait Implementations§

Blanket Implementations§

impl<T> Any for T
where T: 'static + ?Sized,

fn type_id(&self) -> TypeId

Gets the TypeId of self.

impl<T> Borrow<T> for T
where T: ?Sized,

fn borrow(&self) -> &T

Immutably borrows from an owned value.

impl<T> BorrowMut<T> for T
where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

impl<T> From<T> for T

fn from(t: T) -> T

Returns the argument unchanged.

impl<T> Instrument for T

fn instrument(self, span: Span) -> Instrumented<Self>

Instruments this type with the provided Span, returning an Instrumented wrapper.

fn in_current_span(self) -> Instrumented<Self>

Instruments this type with the current Span, returning an Instrumented wrapper.

impl<T, U> Into<U> for T
where U: From<T>,

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T> IntoEither for T

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.

impl<T> Pointable for T

const ALIGN: usize

The alignment of pointer.

type Init = T

The type for initializers.

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes a with the given initializer.

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer.

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer.

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer.

impl<T> Same for T

type Output = T

Should always be Self

impl<T, U> TryFrom<U> for T
where U: Into<T>,

type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

fn vzip(self) -> V

impl<T> WithSubscriber for T

fn with_subscriber<S>(self, subscriber: S) -> WithDispatch<Self>
where S: Into<Dispatch>,

Attaches the provided Subscriber to this type, returning a WithDispatch wrapper.

fn with_current_subscriber(self) -> WithDispatch<Self>

Attaches the current default Subscriber to this type, returning a WithDispatch wrapper.