rust_ai_core/memory.rs

// SPDX-License-Identifier: MIT
// Copyright 2026 Tyler Zervas

//! Memory estimation and tracking utilities for GPU operations.
//!
//! ## Why This Module Exists
//!
//! GPU memory (VRAM) is a precious and limited resource. Unlike CPU memory, there's no
//! swap space fallback when VRAM runs out - operations simply fail. This module provides
//! utilities to estimate memory requirements before allocation and track usage during
//! execution, enabling crates to:
//!
//! 1. **Pre-flight checks**: Verify sufficient VRAM before starting expensive operations
//! 2. **Batch size optimization**: Automatically adjust batch sizes to fit available memory
//! 3. **Memory budgeting**: Track allocations across multiple operations
//! 4. **Debugging**: Identify memory leaks or unexpected allocations
//!
//! ## Design Decisions
//!
//! - **Conservative estimation**: Estimates include overhead buffers because running out
//!   of memory mid-operation is worse than slightly underutilizing VRAM.
//!
//! - **No global state**: `MemoryTracker` is an explicit struct, not a global singleton,
//!   because different parts of an application may need independent tracking.
//!
//! - **Candle-agnostic sizes**: Functions work with shapes and dtypes directly, not just
//!   Candle tensors, enabling estimation before tensor creation.
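//!
//! ## Example
//!
//! A minimal pre-flight sketch. The 8 GiB budget below is a hypothetical value; this
//! module does not query the device, so the caller supplies the limit:
//!
//! ```rust
//! use rust_ai_core::{estimate_tensor_bytes, MemoryTracker};
//! use candle_core::DType;
//!
//! // Hypothetical budget; in practice, derive this from the device's free VRAM.
//! let tracker = MemoryTracker::with_limit(8 * 1024 * 1024 * 1024);
//!
//! // Activations for a [batch, seq, hidden] F16 tensor.
//! let needed = estimate_tensor_bytes(&[8, 2048, 4096], DType::F16);
//! if tracker.would_fit(needed) {
//!     tracker.allocate(needed).expect("within limit");
//! }
//! ```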

use crate::error::{CoreError, Result};
use candle_core::DType;
use std::sync::atomic::{AtomicUsize, Ordering};

/// Default overhead factor applied to memory estimates.
///
/// Why 1.1x: CUDA allocators have alignment requirements and fragmentation overhead.
/// A 10% buffer prevents edge-case OOM errors when estimates are exact.
pub const DEFAULT_OVERHEAD_FACTOR: f64 = 1.1;

/// Estimate the memory required to store a tensor with given shape and dtype.
///
/// This function calculates the raw memory requirement without overhead. Use
/// [`MemoryTracker::estimate_with_overhead`] for production estimates.
///
/// ## Arguments
///
/// * `shape` - Tensor dimensions (e.g., `[batch, seq_len, hidden_dim]`)
/// * `dtype` - Data type determining bytes per element
///
/// ## Returns
///
/// Memory requirement in bytes.
///
/// ## Why This Function
///
/// Pre-computing memory requirements allows batch size optimization and preflight
/// checks before committing to expensive allocations.
///
/// ## Example
///
/// ```rust
/// use rust_ai_core::estimate_tensor_bytes;
/// use candle_core::DType;
///
/// // LLaMA-2 7B attention output: [batch, heads, seq, head_dim]
/// let bytes = estimate_tensor_bytes(&[1, 32, 4096, 128], DType::BF16);
/// assert_eq!(bytes, 1 * 32 * 4096 * 128 * 2); // 32 MB
/// ```
#[must_use]
pub fn estimate_tensor_bytes(shape: &[usize], dtype: DType) -> usize {
    let numel: usize = shape.iter().product();
    numel * dtype.size_in_bytes()
}

/// Estimate memory for attention computation.
///
/// Attention requires storing Q, K, V tensors plus the attention weights matrix.
/// This function estimates the total memory for a single attention layer.
///
/// ## Arguments
///
/// * `batch_size` - Number of sequences in the batch
/// * `num_heads` - Number of attention heads
/// * `seq_len` - Sequence length (context window)
/// * `head_dim` - Dimension per attention head
/// * `dtype` - Data type for all tensors
///
/// ## Returns
///
/// Estimated memory in bytes for one attention layer.
///
/// ## Why This Function
///
/// Attention is the primary memory consumer in transformers. The attention weights
/// matrix scales with `O(seq_len²)`, making it the bottleneck for long sequences.
/// This estimate helps determine maximum context length for a given VRAM budget.
///
/// ## Memory Breakdown
///
/// - Q, K, V: 3 × (batch × heads × seq × `head_dim`)
/// - Attention weights: batch × heads × seq × seq
/// - Output: batch × heads × seq × `head_dim`
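///
/// ## Example
///
/// A sketch using LLaMA-2 7B-like dimensions (illustrative numbers; the path assumes
/// the same crate-root re-export as [`estimate_tensor_bytes`]):
///
/// ```rust
/// use rust_ai_core::estimate_attention_memory;
/// use candle_core::DType;
///
/// // One layer at 4k context: batch 1, 32 heads, head_dim 128, BF16.
/// let bytes = estimate_attention_memory(1, 32, 4096, 128, DType::BF16);
/// // Q, K, V, output: 4 * 32 * 4096 * 128 * 2 bytes = 128 MiB
/// // Attention weights: 32 * 4096 * 4096 * 2 bytes   = 1 GiB
/// assert_eq!(bytes, 4 * 32 * 4096 * 128 * 2 + 32 * 4096 * 4096 * 2);
/// ```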
#[must_use]
pub fn estimate_attention_memory(
    batch_size: usize,
    num_heads: usize,
    seq_len: usize,
    head_dim: usize,
    dtype: DType,
) -> usize {
    let bytes_per_elem = dtype.size_in_bytes();

    // Q, K, V tensors
    let qkv_bytes = 3 * batch_size * num_heads * seq_len * head_dim * bytes_per_elem;

    // Attention weights matrix (the O(n²) component)
    let attn_weights_bytes = batch_size * num_heads * seq_len * seq_len * bytes_per_elem;

    // Output tensor
    let output_bytes = batch_size * num_heads * seq_len * head_dim * bytes_per_elem;

    qkv_bytes + attn_weights_bytes + output_bytes
}

/// Memory usage tracker for GPU operations.
///
/// Tracks allocated and peak memory usage across operations. Thread-safe via atomics.
///
/// ## Why This Struct
///
/// Unlike CPU memory which is managed by the OS, GPU memory requires explicit tracking
/// because:
///
/// 1. **No swap**: When VRAM runs out, allocations fail immediately
/// 2. **Fragmentation**: Repeated allocations can fragment the heap
/// 3. **Debugging**: Memory leaks on GPU are harder to diagnose than CPU leaks
///
/// ## Usage Pattern
///
/// ```rust
/// use rust_ai_core::MemoryTracker;
///
/// let tracker = MemoryTracker::new();
///
/// // Before allocation
/// tracker.allocate(1024 * 1024).expect("allocation should succeed"); // 1 MB
///
/// // After freeing
/// tracker.deallocate(1024 * 1024);
///
/// println!("Peak usage: {} bytes", tracker.peak_bytes());
/// ```
#[derive(Debug)]
pub struct MemoryTracker {
    /// Currently allocated bytes.
    allocated: AtomicUsize,
    /// Peak allocation during lifetime.
    peak: AtomicUsize,
    /// Optional memory limit (0 = unlimited).
    limit: AtomicUsize,
    /// Overhead factor for estimates.
    overhead_factor: f64,
}

impl Default for MemoryTracker {
    fn default() -> Self {
        Self::new()
    }
}

impl MemoryTracker {
    /// Create a new memory tracker with no limit.
    ///
    /// ## Why No Default Limit
    ///
    /// Different GPUs have vastly different VRAM capacities (4GB to 80GB+).
    /// Setting a default limit would either be too restrictive for high-end cards
    /// or meaningless for consumer cards. Users should set limits explicitly.
    #[must_use]
    pub fn new() -> Self {
        Self {
            allocated: AtomicUsize::new(0),
            peak: AtomicUsize::new(0),
            limit: AtomicUsize::new(0),
            overhead_factor: DEFAULT_OVERHEAD_FACTOR,
        }
    }

    /// Create a tracker with a memory limit.
    ///
    /// ## Arguments
    ///
    /// * `limit_bytes` - Maximum allowed allocation in bytes
    ///
    /// ## Why Memory Limits
    ///
    /// Setting a limit below actual VRAM capacity reserves space for:
    /// - CUDA context overhead (~200-500 MB)
    /// - Framework allocations (PyTorch/Candle tensor cache)
    /// - Other processes sharing the GPU
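    ///
    /// ## Example
    ///
    /// A sketch that reserves roughly 1 GiB of a hypothetical 24 GiB card for that
    /// non-tensor overhead:
    ///
    /// ```rust
    /// use rust_ai_core::MemoryTracker;
    ///
    /// const GIB: usize = 1024 * 1024 * 1024;
    /// let tracker = MemoryTracker::with_limit(23 * GIB);
    /// assert_eq!(tracker.limit_bytes(), 23 * GIB);
    /// ```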
    #[must_use]
    pub fn with_limit(limit_bytes: usize) -> Self {
        Self {
            allocated: AtomicUsize::new(0),
            peak: AtomicUsize::new(0),
            limit: AtomicUsize::new(limit_bytes),
            overhead_factor: DEFAULT_OVERHEAD_FACTOR,
        }
    }

    /// Set a custom overhead factor for estimates.
    ///
    /// ## Arguments
    ///
    /// * `factor` - Multiplier applied to estimates (default: 1.1)
    #[must_use]
    pub fn with_overhead_factor(mut self, factor: f64) -> Self {
        self.overhead_factor = factor;
        self
    }

    /// Record a memory allocation.
    ///
    /// ## Arguments
    ///
    /// * `bytes` - Number of bytes allocated
    ///
    /// ## Returns
    ///
    /// `Ok(())` if allocation is within limits.
    ///
    /// ## Errors
    ///
    /// Returns `CoreError::OutOfMemory` if allocation would exceed the limit.
    ///
    /// ## Why Track Allocations
    ///
    /// Explicit tracking catches memory issues early. Without tracking, OOM errors
    /// occur deep in CUDA kernels with unhelpful error messages.
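    ///
    /// ## Example
    ///
    /// A small sketch of the success and failure paths (sizes are arbitrary):
    ///
    /// ```rust
    /// use rust_ai_core::MemoryTracker;
    ///
    /// let tracker = MemoryTracker::with_limit(1024);
    /// assert!(tracker.allocate(512).is_ok());
    /// assert!(tracker.allocate(1024).is_err()); // 512 + 1024 exceeds the limit
    /// assert_eq!(tracker.allocated_bytes(), 512); // failed call leaves state unchanged
    /// ```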
    pub fn allocate(&self, bytes: usize) -> Result<()> {
        // Check limit BEFORE updating state to avoid partial updates on failure
        let limit = self.limit.load(Ordering::SeqCst);
        let current = self.allocated.load(Ordering::SeqCst);
        let new_allocated = current + bytes;

        if limit > 0 && new_allocated > limit {
            return Err(CoreError::oom(format!(
                "allocation of {bytes} bytes would exceed limit of {limit} bytes \
                 (current: {current} bytes)"
            )));
        }

        // Update allocated (actual update happens here)
        let actual_new = self.allocated.fetch_add(bytes, Ordering::SeqCst) + bytes;

        // Update peak
        let mut current_peak = self.peak.load(Ordering::SeqCst);
        while actual_new > current_peak {
            match self.peak.compare_exchange_weak(
                current_peak,
                actual_new,
                Ordering::SeqCst,
                Ordering::SeqCst,
            ) {
                Ok(_) => break,
                Err(p) => current_peak = p,
            }
        }

        Ok(())
    }

    /// Record a memory deallocation.
    ///
    /// ## Arguments
    ///
    /// * `bytes` - Number of bytes freed
    pub fn deallocate(&self, bytes: usize) {
        self.allocated.fetch_sub(bytes, Ordering::SeqCst);
    }

    /// Get currently allocated bytes.
    #[must_use]
    pub fn allocated_bytes(&self) -> usize {
        self.allocated.load(Ordering::SeqCst)
    }

    /// Get peak allocation during tracker lifetime.
    ///
    /// ## Why Track Peak
    ///
    /// Peak usage is more useful than current usage for capacity planning.
    /// It shows the high-water mark needed to complete a workload.
    #[must_use]
    pub fn peak_bytes(&self) -> usize {
        self.peak.load(Ordering::SeqCst)
    }

    /// Get configured memory limit (0 = unlimited).
    #[must_use]
    pub fn limit_bytes(&self) -> usize {
        self.limit.load(Ordering::SeqCst)
    }

    /// Estimate required memory with overhead factor applied.
    ///
    /// ## Arguments
    ///
    /// * `shape` - Tensor dimensions
    /// * `dtype` - Data type
    ///
    /// ## Returns
    ///
    /// Estimated bytes including overhead buffer.
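    ///
    /// ## Example
    ///
    /// A sketch with the default 1.1x factor (raw size matches
    /// [`estimate_tensor_bytes`]):
    ///
    /// ```rust
    /// use rust_ai_core::MemoryTracker;
    /// use candle_core::DType;
    ///
    /// let tracker = MemoryTracker::new();
    /// // 1000 f32 elements = 4000 raw bytes, padded by the 1.1x overhead factor.
    /// assert_eq!(tracker.estimate_with_overhead(&[10, 100], DType::F32), 4400);
    /// ```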
    #[must_use]
    pub fn estimate_with_overhead(&self, shape: &[usize], dtype: DType) -> usize {
        let raw = estimate_tensor_bytes(shape, dtype);
        #[allow(
            clippy::cast_sign_loss,
            clippy::cast_possible_truncation,
            clippy::cast_precision_loss
        )]
        {
            (raw as f64 * self.overhead_factor) as usize
        }
    }

    /// Check if an allocation would fit within limits.
    ///
    /// ## Arguments
    ///
    /// * `bytes` - Proposed allocation size
    ///
    /// ## Returns
    ///
    /// `true` if allocation would succeed.
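    ///
    /// ## Example
    ///
    /// A sketch against a 1000-byte limit (arbitrary sizes):
    ///
    /// ```rust
    /// use rust_ai_core::MemoryTracker;
    ///
    /// let tracker = MemoryTracker::with_limit(1000);
    /// tracker.allocate(600).expect("within limit");
    /// assert!(tracker.would_fit(400));  // 600 + 400 == limit, still fits
    /// assert!(!tracker.would_fit(401)); // would exceed the limit
    /// ```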
    #[must_use]
    pub fn would_fit(&self, bytes: usize) -> bool {
        let limit = self.limit.load(Ordering::SeqCst);
        if limit == 0 {
            return true; // No limit
        }
        self.allocated.load(Ordering::SeqCst) + bytes <= limit
    }

    /// Reset the tracker to initial state.
    ///
    /// ## Why Reset
    ///
    /// Between training epochs or inference batches, resetting allows tracking
    /// per-phase memory usage without creating new tracker instances.
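    ///
    /// ## Example
    ///
    /// A sketch of per-phase tracking (byte counts are arbitrary):
    ///
    /// ```rust
    /// use rust_ai_core::MemoryTracker;
    ///
    /// let tracker = MemoryTracker::new();
    /// tracker.allocate(256).expect("no limit configured");
    /// tracker.reset();
    /// assert_eq!(tracker.allocated_bytes(), 0);
    /// assert_eq!(tracker.peak_bytes(), 0);
    /// ```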
    pub fn reset(&self) {
        self.allocated.store(0, Ordering::SeqCst);
        self.peak.store(0, Ordering::SeqCst);
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_estimate_tensor_bytes() {
        // 1000 f32 elements = 4000 bytes
        assert_eq!(estimate_tensor_bytes(&[10, 100], DType::F32), 4000);

        // 1000 f16 elements = 2000 bytes
        assert_eq!(estimate_tensor_bytes(&[10, 100], DType::F16), 2000);

        // Empty tensor = 0 bytes
        assert_eq!(estimate_tensor_bytes(&[0], DType::F32), 0);
    }

    #[test]
    fn test_estimate_attention_memory() {
        // Simple case: 1 batch, 1 head, 4 seq, 2 head_dim, f32
        let bytes = estimate_attention_memory(1, 1, 4, 2, DType::F32);
        // QKV: 3 * 1 * 1 * 4 * 2 * 4 = 96
        // Attn: 1 * 1 * 4 * 4 * 4 = 64
        // Out: 1 * 1 * 4 * 2 * 4 = 32
        // Total: 192
        assert_eq!(bytes, 192);
    }

    #[test]
    fn test_memory_tracker_allocation() {
        let tracker = MemoryTracker::with_limit(1000);

        // Successful allocation
        assert!(tracker.allocate(500).is_ok());
        assert_eq!(tracker.allocated_bytes(), 500);

        // Second allocation
        assert!(tracker.allocate(400).is_ok());
        assert_eq!(tracker.allocated_bytes(), 900);

        // Exceeds limit
        assert!(tracker.allocate(200).is_err());
        assert_eq!(tracker.allocated_bytes(), 900); // Unchanged

        // Deallocation
        tracker.deallocate(400);
        assert_eq!(tracker.allocated_bytes(), 500);

        // Now fits
        assert!(tracker.allocate(200).is_ok());
    }

    #[test]
    fn test_memory_tracker_peak() {
        let tracker = MemoryTracker::new();

        tracker.allocate(100).unwrap();
        tracker.allocate(200).unwrap();
        assert_eq!(tracker.peak_bytes(), 300);

        tracker.deallocate(200);
        assert_eq!(tracker.allocated_bytes(), 100);
        assert_eq!(tracker.peak_bytes(), 300); // Peak unchanged

        tracker.allocate(50).unwrap();
        assert_eq!(tracker.peak_bytes(), 300); // Still 300

        tracker.allocate(300).unwrap();
        assert_eq!(tracker.peak_bytes(), 450); // New peak
    }

    #[test]
    fn test_would_fit() {
        let tracker = MemoryTracker::with_limit(1000);
        tracker.allocate(500).unwrap();

        assert!(tracker.would_fit(400));
        assert!(tracker.would_fit(500));
        assert!(!tracker.would_fit(501));

        // Unlimited tracker
        let unlimited = MemoryTracker::new();
        assert!(unlimited.would_fit(usize::MAX));
    }
}