rust_ai_core/memory.rs
// SPDX-License-Identifier: MIT
// Copyright 2026 Tyler Zervas

//! Memory estimation and tracking utilities for GPU operations.
//!
//! ## Why This Module Exists
//!
//! GPU memory (VRAM) is a precious and limited resource. Unlike CPU memory, there's no
//! swap space fallback when VRAM runs out - operations simply fail. This module provides
//! utilities to estimate memory requirements before allocation and track usage during
//! execution, enabling crates to:
//!
//! 1. **Pre-flight checks**: Verify sufficient VRAM before starting expensive operations
//! 2. **Batch size optimization**: Automatically adjust batch sizes to fit available memory
//! 3. **Memory budgeting**: Track allocations across multiple operations
//! 4. **Debugging**: Identify memory leaks or unexpected allocations
//!
//! ## Design Decisions
//!
//! - **Conservative estimation**: Estimates include overhead buffers because running out
//!   of memory mid-operation is worse than slightly underutilizing VRAM.
//!
//! - **No global state**: `MemoryTracker` is an explicit struct, not a global singleton,
//!   because different parts of an application may need independent tracking.
//!
//! - **Candle-agnostic sizes**: Functions work with shapes and dtypes directly, not just
//!   Candle tensors, enabling estimation before tensor creation.
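//!
//! ## Example
//!
//! A minimal pre-flight sketch combining estimation and tracking; the 1 GiB budget and
//! the `[8, 2048, 4096]` activation shape are illustrative numbers, not recommendations:
//!
//! ```rust
//! use rust_ai_core::MemoryTracker;
//! use candle_core::DType;
//!
//! // Budget an illustrative 1 GiB for this phase of the workload.
//! let tracker = MemoryTracker::with_limit(1024 * 1024 * 1024);
//!
//! // Estimate a [batch, seq_len, hidden_dim] activation tensor before creating it.
//! let needed = tracker.estimate_with_overhead(&[8, 2048, 4096], DType::BF16);
//!
//! if tracker.would_fit(needed) {
//!     tracker.allocate(needed).expect("checked by would_fit above");
//! } else {
//!     // Shrink the batch size and re-estimate.
//! }
//! ```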

use crate::error::{CoreError, Result};
use candle_core::DType;
use std::sync::atomic::{AtomicUsize, Ordering};

/// Default overhead factor applied to memory estimates.
///
/// Why 1.1x: CUDA allocators have alignment requirements and fragmentation overhead.
/// A 10% buffer absorbs that overhead, preventing edge-case OOM errors when the raw
/// estimate alone would leave no headroom.
pub const DEFAULT_OVERHEAD_FACTOR: f64 = 1.1;

/// Estimate the memory required to store a tensor with given shape and dtype.
///
/// This function calculates the raw memory requirement without overhead. Use
/// [`MemoryTracker::estimate_with_overhead`] for production estimates.
///
/// ## Arguments
///
/// * `shape` - Tensor dimensions (e.g., `[batch, seq_len, hidden_dim]`)
/// * `dtype` - Data type determining bytes per element
///
/// ## Returns
///
/// Memory requirement in bytes.
///
/// ## Why This Function
///
/// Pre-computing memory requirements allows batch size optimization and preflight
/// checks before committing to expensive allocations.
///
/// ## Example
///
/// ```rust
/// use rust_ai_core::estimate_tensor_bytes;
/// use candle_core::DType;
///
/// // LLaMA-2 7B attention output: [batch, heads, seq, head_dim]
/// let bytes = estimate_tensor_bytes(&[1, 32, 4096, 128], DType::BF16);
/// assert_eq!(bytes, 1 * 32 * 4096 * 128 * 2); // 32 MB
/// ```
#[must_use]
pub fn estimate_tensor_bytes(shape: &[usize], dtype: DType) -> usize {
    let numel: usize = shape.iter().product();
    numel * dtype.size_in_bytes()
}

/// Estimate memory for attention computation.
///
/// Attention requires storing Q, K, V tensors plus the attention weights matrix.
/// This function estimates the total memory for a single attention layer.
///
/// ## Arguments
///
/// * `batch_size` - Number of sequences in the batch
/// * `num_heads` - Number of attention heads
/// * `seq_len` - Sequence length (context window)
/// * `head_dim` - Dimension per attention head
/// * `dtype` - Data type for all tensors
///
/// ## Returns
///
/// Estimated memory in bytes for one attention layer.
///
/// ## Why This Function
///
/// Attention is the primary memory consumer in transformers. The attention weights
/// matrix scales with `O(seq_len²)`, making it the bottleneck for long sequences.
/// This estimate helps determine maximum context length for a given VRAM budget.
///
/// ## Memory Breakdown
///
/// - Q, K, V: 3 × (batch × heads × seq × `head_dim`)
/// - Attention weights: batch × heads × seq × seq
/// - Output: batch × heads × seq × `head_dim`
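///
/// ## Example
///
/// A small worked instance of the breakdown above (the shapes are illustrative, and the
/// import assumes this function is re-exported at the crate root like
/// [`estimate_tensor_bytes`]):
///
/// ```rust
/// use rust_ai_core::estimate_attention_memory;
/// use candle_core::DType;
///
/// // 1 sequence, 32 heads, 2048 tokens, 128-dim heads, BF16 (2 bytes per element).
/// let bytes = estimate_attention_memory(1, 32, 2048, 128, DType::BF16);
///
/// let qkv = 3 * 32 * 2048 * 128 * 2; // Q, K, V
/// let weights = 32 * 2048 * 2048 * 2; // O(seq_len²) attention weights
/// let output = 32 * 2048 * 128 * 2; // attention output
/// assert_eq!(bytes, qkv + weights + output);
/// ```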
#[must_use]
pub fn estimate_attention_memory(
    batch_size: usize,
    num_heads: usize,
    seq_len: usize,
    head_dim: usize,
    dtype: DType,
) -> usize {
    let bytes_per_elem = dtype.size_in_bytes();

    // Q, K, V tensors
    let qkv_bytes = 3 * batch_size * num_heads * seq_len * head_dim * bytes_per_elem;

    // Attention weights matrix (the O(n²) component)
    let attn_weights_bytes = batch_size * num_heads * seq_len * seq_len * bytes_per_elem;

    // Output tensor
    let output_bytes = batch_size * num_heads * seq_len * head_dim * bytes_per_elem;

    qkv_bytes + attn_weights_bytes + output_bytes
}

/// Memory usage tracker for GPU operations.
///
/// Tracks allocated and peak memory usage across operations. Thread-safe via atomics.
///
/// ## Why This Struct
///
/// Unlike CPU memory which is managed by the OS, GPU memory requires explicit tracking
/// because:
///
/// 1. **No swap**: When VRAM runs out, allocations fail immediately
/// 2. **Fragmentation**: Repeated allocations can fragment the heap
/// 3. **Debugging**: Memory leaks on GPU are harder to diagnose than CPU leaks
///
/// ## Usage Pattern
///
/// ```rust
/// use rust_ai_core::MemoryTracker;
///
/// let tracker = MemoryTracker::new();
///
/// // Before allocation
/// tracker.allocate(1024 * 1024).expect("allocation should succeed"); // 1 MB
///
/// // After freeing
/// tracker.deallocate(1024 * 1024);
///
/// println!("Peak usage: {} bytes", tracker.peak_bytes());
/// ```
#[derive(Debug)]
pub struct MemoryTracker {
    /// Currently allocated bytes.
    allocated: AtomicUsize,
    /// Peak allocation during lifetime.
    peak: AtomicUsize,
    /// Optional memory limit (0 = unlimited).
    limit: AtomicUsize,
    /// Overhead factor for estimates.
    overhead_factor: f64,
}

impl Default for MemoryTracker {
    fn default() -> Self {
        Self::new()
    }
}

impl MemoryTracker {
    /// Create a new memory tracker with no limit.
    ///
    /// ## Why No Default Limit
    ///
    /// Different GPUs have vastly different VRAM capacities (4GB to 80GB+).
    /// Setting a default limit would either be too restrictive for high-end cards
    /// or meaningless for consumer cards. Users should set limits explicitly.
    #[must_use]
    pub fn new() -> Self {
        Self {
            allocated: AtomicUsize::new(0),
            peak: AtomicUsize::new(0),
            limit: AtomicUsize::new(0),
            overhead_factor: DEFAULT_OVERHEAD_FACTOR,
        }
    }

    /// Create a tracker with a memory limit.
    ///
    /// ## Arguments
    ///
    /// * `limit_bytes` - Maximum allowed allocation in bytes
    ///
    /// ## Why Memory Limits
    ///
    /// Setting a limit below actual VRAM capacity reserves space for:
    /// - CUDA context overhead (~200-500 MB)
    /// - Framework allocations (PyTorch/Candle tensor cache)
    /// - Other processes sharing the GPU
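    ///
    /// ## Example
    ///
    /// A minimal sketch; the 24 GiB card and 1 GiB reserve are illustrative figures:
    ///
    /// ```rust
    /// use rust_ai_core::MemoryTracker;
    ///
    /// // Leave 1 GiB of headroom on a hypothetical 24 GiB card.
    /// let total_vram: usize = 24 * 1024 * 1024 * 1024;
    /// let reserve: usize = 1024 * 1024 * 1024;
    /// let tracker = MemoryTracker::with_limit(total_vram - reserve);
    ///
    /// assert_eq!(tracker.limit_bytes(), total_vram - reserve);
    /// ```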
    #[must_use]
    pub fn with_limit(limit_bytes: usize) -> Self {
        Self {
            allocated: AtomicUsize::new(0),
            peak: AtomicUsize::new(0),
            limit: AtomicUsize::new(limit_bytes),
            overhead_factor: DEFAULT_OVERHEAD_FACTOR,
        }
    }

    /// Set a custom overhead factor for estimates.
    ///
    /// ## Arguments
    ///
    /// * `factor` - Multiplier applied to estimates (default: 1.1)
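    ///
    /// ## Example
    ///
    /// A minimal sketch; the 1.5 factor is an arbitrary illustrative value:
    ///
    /// ```rust
    /// use rust_ai_core::MemoryTracker;
    /// use candle_core::DType;
    ///
    /// // Use a 50% buffer instead of the default 10%.
    /// let tracker = MemoryTracker::new().with_overhead_factor(1.5);
    /// assert_eq!(tracker.estimate_with_overhead(&[100], DType::F32), 600);
    /// ```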
    #[must_use]
    pub fn with_overhead_factor(mut self, factor: f64) -> Self {
        self.overhead_factor = factor;
        self
    }

    /// Record a memory allocation.
    ///
    /// ## Arguments
    ///
    /// * `bytes` - Number of bytes allocated
    ///
    /// ## Returns
    ///
    /// `Ok(())` if allocation is within limits.
    ///
    /// ## Errors
    ///
    /// Returns `CoreError::OutOfMemory` if allocation would exceed the limit.
    ///
    /// ## Why Track Allocations
    ///
    /// Explicit tracking catches memory issues early. Without tracking, OOM errors
    /// occur deep in CUDA kernels with unhelpful error messages.
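    ///
    /// ## Example
    ///
    /// A minimal sketch of the error path; the 1024-byte limit is deliberately tiny:
    ///
    /// ```rust
    /// use rust_ai_core::MemoryTracker;
    ///
    /// let tracker = MemoryTracker::with_limit(1024);
    ///
    /// // Fits within the limit.
    /// tracker.allocate(512).expect("within limit");
    ///
    /// // Exceeds the limit: the error is returned and no counters change.
    /// assert!(tracker.allocate(1024).is_err());
    /// assert_eq!(tracker.allocated_bytes(), 512);
    /// ```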
    pub fn allocate(&self, bytes: usize) -> Result<()> {
        // Check limit BEFORE updating state to avoid partial updates on failure
        let limit = self.limit.load(Ordering::SeqCst);
        let current = self.allocated.load(Ordering::SeqCst);
        let new_allocated = current + bytes;

        if limit > 0 && new_allocated > limit {
            return Err(CoreError::oom(format!(
                "allocation of {bytes} bytes would exceed limit of {limit} bytes \
                 (current: {current} bytes)"
            )));
        }

        // Update allocated (actual update happens here)
        let actual_new = self.allocated.fetch_add(bytes, Ordering::SeqCst) + bytes;

        // Update peak
        let mut current_peak = self.peak.load(Ordering::SeqCst);
        while actual_new > current_peak {
            match self.peak.compare_exchange_weak(
                current_peak,
                actual_new,
                Ordering::SeqCst,
                Ordering::SeqCst,
            ) {
                Ok(_) => break,
                Err(p) => current_peak = p,
            }
        }

        Ok(())
    }

    /// Record a memory deallocation.
    ///
    /// ## Arguments
    ///
    /// * `bytes` - Number of bytes freed
    pub fn deallocate(&self, bytes: usize) {
        self.allocated.fetch_sub(bytes, Ordering::SeqCst);
    }

    /// Get currently allocated bytes.
    #[must_use]
    pub fn allocated_bytes(&self) -> usize {
        self.allocated.load(Ordering::SeqCst)
    }

    /// Get peak allocation during tracker lifetime.
    ///
    /// ## Why Track Peak
    ///
    /// Peak usage is more useful than current usage for capacity planning.
    /// It shows the high-water mark needed to complete a workload.
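    ///
    /// ## Example
    ///
    /// A minimal sketch: the peak survives deallocation, recording the high-water mark
    /// rather than the current balance:
    ///
    /// ```rust
    /// use rust_ai_core::MemoryTracker;
    ///
    /// let tracker = MemoryTracker::new();
    /// tracker.allocate(300).expect("no limit set");
    /// tracker.deallocate(200);
    ///
    /// assert_eq!(tracker.allocated_bytes(), 100);
    /// assert_eq!(tracker.peak_bytes(), 300);
    /// ```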
    #[must_use]
    pub fn peak_bytes(&self) -> usize {
        self.peak.load(Ordering::SeqCst)
    }

    /// Get configured memory limit (0 = unlimited).
    #[must_use]
    pub fn limit_bytes(&self) -> usize {
        self.limit.load(Ordering::SeqCst)
    }

    /// Estimate required memory with overhead factor applied.
    ///
    /// ## Arguments
    ///
    /// * `shape` - Tensor dimensions
    /// * `dtype` - Data type
    ///
    /// ## Returns
    ///
    /// Estimated bytes including overhead buffer.
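    ///
    /// ## Example
    ///
    /// A minimal sketch using the default 1.1x factor:
    ///
    /// ```rust
    /// use rust_ai_core::MemoryTracker;
    /// use candle_core::DType;
    ///
    /// let tracker = MemoryTracker::new();
    ///
    /// // Raw requirement: 1000 elements * 4 bytes = 4000; with the 10% buffer: 4400.
    /// assert_eq!(tracker.estimate_with_overhead(&[10, 100], DType::F32), 4400);
    /// ```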
    #[must_use]
    pub fn estimate_with_overhead(&self, shape: &[usize], dtype: DType) -> usize {
        let raw = estimate_tensor_bytes(shape, dtype);
        #[allow(
            clippy::cast_sign_loss,
            clippy::cast_possible_truncation,
            clippy::cast_precision_loss
        )]
        {
            (raw as f64 * self.overhead_factor) as usize
        }
    }

    /// Check if an allocation would fit within limits.
    ///
    /// ## Arguments
    ///
    /// * `bytes` - Proposed allocation size
    ///
    /// ## Returns
    ///
    /// `true` if allocation would succeed.
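    ///
    /// ## Example
    ///
    /// A minimal sketch of checking the remaining budget before allocating:
    ///
    /// ```rust
    /// use rust_ai_core::MemoryTracker;
    ///
    /// let tracker = MemoryTracker::with_limit(1000);
    /// tracker.allocate(600).expect("within limit");
    ///
    /// assert!(tracker.would_fit(400)); // exactly fills the remaining budget
    /// assert!(!tracker.would_fit(401)); // one byte too many
    /// ```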
    #[must_use]
    pub fn would_fit(&self, bytes: usize) -> bool {
        let limit = self.limit.load(Ordering::SeqCst);
        if limit == 0 {
            return true; // No limit
        }
        self.allocated.load(Ordering::SeqCst) + bytes <= limit
    }

    /// Reset the allocation counters to zero. The configured limit is preserved.
    ///
    /// ## Why Reset
    ///
    /// Between training epochs or inference batches, resetting allows tracking
    /// per-phase memory usage without creating new tracker instances.
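    ///
    /// ## Example
    ///
    /// A minimal sketch of per-phase tracking across two batches:
    ///
    /// ```rust
    /// use rust_ai_core::MemoryTracker;
    ///
    /// let tracker = MemoryTracker::new();
    ///
    /// // Phase 1.
    /// tracker.allocate(500).expect("no limit set");
    /// assert_eq!(tracker.peak_bytes(), 500);
    ///
    /// // Start phase 2 with fresh counters.
    /// tracker.reset();
    /// assert_eq!(tracker.allocated_bytes(), 0);
    /// assert_eq!(tracker.peak_bytes(), 0);
    /// ```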
    pub fn reset(&self) {
        self.allocated.store(0, Ordering::SeqCst);
        self.peak.store(0, Ordering::SeqCst);
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_estimate_tensor_bytes() {
        // 1000 f32 elements = 4000 bytes
        assert_eq!(estimate_tensor_bytes(&[10, 100], DType::F32), 4000);

        // 1000 f16 elements = 2000 bytes
        assert_eq!(estimate_tensor_bytes(&[10, 100], DType::F16), 2000);

        // Empty tensor = 0 bytes
        assert_eq!(estimate_tensor_bytes(&[0], DType::F32), 0);
    }

    #[test]
    fn test_estimate_attention_memory() {
        // Simple case: 1 batch, 1 head, 4 seq, 2 head_dim, f32
        let bytes = estimate_attention_memory(1, 1, 4, 2, DType::F32);
        // QKV: 3 * 1 * 1 * 4 * 2 * 4 = 96
        // Attn: 1 * 1 * 4 * 4 * 4 = 64
        // Out: 1 * 1 * 4 * 2 * 4 = 32
        // Total: 192
        assert_eq!(bytes, 192);
    }

    #[test]
    fn test_memory_tracker_allocation() {
        let tracker = MemoryTracker::with_limit(1000);

        // Successful allocation
        assert!(tracker.allocate(500).is_ok());
        assert_eq!(tracker.allocated_bytes(), 500);

        // Second allocation
        assert!(tracker.allocate(400).is_ok());
        assert_eq!(tracker.allocated_bytes(), 900);

        // Exceeds limit
        assert!(tracker.allocate(200).is_err());
        assert_eq!(tracker.allocated_bytes(), 900); // Unchanged

        // Deallocation
        tracker.deallocate(400);
        assert_eq!(tracker.allocated_bytes(), 500);

        // Now fits
        assert!(tracker.allocate(200).is_ok());
    }

    #[test]
    fn test_memory_tracker_peak() {
        let tracker = MemoryTracker::new();

        tracker.allocate(100).unwrap();
        tracker.allocate(200).unwrap();
        assert_eq!(tracker.peak_bytes(), 300);

        tracker.deallocate(200);
        assert_eq!(tracker.allocated_bytes(), 100);
        assert_eq!(tracker.peak_bytes(), 300); // Peak unchanged

        tracker.allocate(50).unwrap();
        assert_eq!(tracker.peak_bytes(), 300); // Still 300

        tracker.allocate(300).unwrap();
        assert_eq!(tracker.peak_bytes(), 450); // New peak
    }

    #[test]
    fn test_would_fit() {
        let tracker = MemoryTracker::with_limit(1000);
        tracker.allocate(500).unwrap();

        assert!(tracker.would_fit(400));
        assert!(tracker.would_fit(500));
        assert!(!tracker.would_fit(501));

        // Unlimited tracker
        let unlimited = MemoryTracker::new();
        assert!(unlimited.would_fit(usize::MAX));
    }
}