rust_ai_core/memory.rs
// SPDX-License-Identifier: MIT
// Copyright 2026 Tyler Zervas

//! Memory estimation and tracking utilities for GPU operations.
//!
//! ## Why This Module Exists
//!
//! GPU memory (VRAM) is a scarce, limited resource. Unlike CPU memory, there is no
//! swap space to fall back on when VRAM runs out; operations simply fail. This module
//! provides utilities to estimate memory requirements before allocation and track usage
//! during execution, enabling crates to:
//!
//! 1. **Pre-flight checks**: Verify sufficient VRAM before starting expensive operations
//! 2. **Batch size optimization**: Automatically adjust batch sizes to fit available
//!    memory (see the example below)
//! 3. **Memory budgeting**: Track allocations across multiple operations
//! 4. **Debugging**: Identify memory leaks or unexpected allocations
//!
//! ## Design Decisions
//!
//! - **Conservative estimation**: Estimates include overhead buffers because running out
//!   of memory mid-operation is worse than slightly underutilizing VRAM.
//!
//! - **No global state**: `MemoryTracker` is an explicit struct, not a global singleton,
//!   because different parts of an application may need independent tracking.
//!
//! - **Candle-agnostic sizes**: Functions work with shapes and dtypes directly, not just
//!   Candle tensors, enabling estimation before tensor creation.
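//!
//! ## Example
//!
//! A sketch of a pre-flight check that halves the batch size until one attention layer
//! fits the budget (the VRAM figure and tensor shapes are illustrative):
//!
//! ```rust
//! use rust_ai_core::{estimate_attention_memory, MemoryTracker};
//! use candle_core::DType;
//!
//! // Pretend the usable budget is 8 GiB.
//! let tracker = MemoryTracker::with_limit(8 * 1024 * 1024 * 1024);
//!
//! let mut batch_size = 64;
//! while batch_size > 1
//!     && !tracker.would_fit(estimate_attention_memory(batch_size, 32, 4096, 128, DType::BF16))
//! {
//!     batch_size /= 2;
//! }
//! assert!(batch_size >= 1);
//! ```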

use crate::error::{CoreError, Result};
use candle_core::DType;
use std::sync::atomic::{AtomicUsize, Ordering};

/// Default overhead factor applied to memory estimates.
///
/// Why 1.1x: CUDA allocators have alignment requirements and fragmentation overhead.
/// A 10% buffer prevents edge-case OOM errors when estimates are exact.
pub const DEFAULT_OVERHEAD_FACTOR: f64 = 1.1;

/// Estimate the memory required to store a tensor with given shape and dtype.
///
/// This function calculates the raw memory requirement without overhead. Use
/// [`MemoryTracker::estimate_with_overhead`] for production estimates.
///
/// ## Arguments
///
/// * `shape` - Tensor dimensions (e.g., `[batch, seq_len, hidden_dim]`)
/// * `dtype` - Data type determining bytes per element
///
/// ## Returns
///
/// Memory requirement in bytes.
///
/// ## Why This Function
///
/// Pre-computing memory requirements allows batch size optimization and pre-flight
/// checks before committing to expensive allocations.
///
/// ## Example
///
/// ```rust
/// use rust_ai_core::estimate_tensor_bytes;
/// use candle_core::DType;
///
/// // LLaMA-2 7B attention output: [batch, heads, seq, head_dim]
/// let bytes = estimate_tensor_bytes(&[1, 32, 4096, 128], DType::BF16);
/// assert_eq!(bytes, 1 * 32 * 4096 * 128 * 2); // 32 MiB
/// ```
#[must_use]
pub fn estimate_tensor_bytes(shape: &[usize], dtype: DType) -> usize {
    let numel: usize = shape.iter().product();
    numel * dtype.size_in_bytes()
}

/// Estimate memory for attention computation.
///
/// Attention requires storing Q, K, V tensors plus the attention weights matrix.
/// This function estimates the total memory for a single attention layer.
///
/// ## Arguments
///
/// * `batch_size` - Number of sequences in the batch
/// * `num_heads` - Number of attention heads
/// * `seq_len` - Sequence length (context window)
/// * `head_dim` - Dimension per attention head
/// * `dtype` - Data type for all tensors
///
/// ## Returns
///
/// Estimated memory in bytes for one attention layer.
///
/// ## Why This Function
///
/// Attention is the primary memory consumer in transformers. The attention weights
/// matrix scales with `O(seq_len²)`, making it the bottleneck for long sequences.
/// This estimate helps determine maximum context length for a given VRAM budget.
///
/// ## Memory Breakdown
///
/// - Q, K, V: 3 × (batch × heads × seq × `head_dim`)
/// - Attention weights: batch × heads × seq × seq
/// - Output: batch × heads × seq × `head_dim`
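///
/// ## Example
///
/// A rough illustration of the quadratic term: at a 4096-token context the attention
/// weights dwarf the Q/K/V and output buffers.
///
/// ```rust
/// use rust_ai_core::estimate_attention_memory;
/// use candle_core::DType;
///
/// let bytes = estimate_attention_memory(1, 32, 4096, 128, DType::BF16);
/// // Q, K, V + output: 4 * 32 * 4096 * 128 * 2 bytes = 128 MiB
/// // Attention weights: 32 * 4096 * 4096 * 2 bytes   = 1 GiB
/// assert_eq!(bytes, 134_217_728 + 1_073_741_824);
/// ```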
#[must_use]
pub fn estimate_attention_memory(
    batch_size: usize,
    num_heads: usize,
    seq_len: usize,
    head_dim: usize,
    dtype: DType,
) -> usize {
    let bytes_per_elem = dtype.size_in_bytes();

    // Q, K, V tensors
    let qkv_bytes = 3 * batch_size * num_heads * seq_len * head_dim * bytes_per_elem;

    // Attention weights matrix (the O(n²) component)
    let attn_weights_bytes = batch_size * num_heads * seq_len * seq_len * bytes_per_elem;

    // Output tensor
    let output_bytes = batch_size * num_heads * seq_len * head_dim * bytes_per_elem;

    qkv_bytes + attn_weights_bytes + output_bytes
}

/// Memory usage tracker for GPU operations.
///
/// Tracks allocated and peak memory usage across operations. Thread-safe via atomics.
///
/// ## Why This Struct
///
/// Unlike CPU memory, which is managed by the OS, GPU memory requires explicit tracking
/// because:
///
/// 1. **No swap**: When VRAM runs out, allocations fail immediately
/// 2. **Fragmentation**: Repeated allocations can fragment the heap
/// 3. **Debugging**: Memory leaks on GPU are harder to diagnose than CPU leaks
///
/// ## Usage Pattern
///
/// ```rust
/// use rust_ai_core::MemoryTracker;
///
/// let tracker = MemoryTracker::new();
///
/// // Before allocation
/// tracker.allocate(1024 * 1024).expect("allocation should succeed"); // 1 MiB
///
/// // After freeing
/// tracker.deallocate(1024 * 1024);
///
/// println!("Peak usage: {} bytes", tracker.peak_bytes());
/// ```
#[derive(Debug)]
pub struct MemoryTracker {
    /// Currently allocated bytes.
    allocated: AtomicUsize,
    /// Peak allocation during lifetime.
    peak: AtomicUsize,
    /// Optional memory limit (0 = unlimited).
    limit: AtomicUsize,
    /// Overhead factor for estimates.
    overhead_factor: f64,
}

impl Default for MemoryTracker {
    fn default() -> Self {
        Self::new()
    }
}

impl MemoryTracker {
    /// Create a new memory tracker with no limit.
    ///
    /// ## Why No Default Limit
    ///
    /// Different GPUs have vastly different VRAM capacities (4GB to 80GB+).
    /// Setting a default limit would either be too restrictive for high-end cards
    /// or meaningless for consumer cards. Users should set limits explicitly.
    #[must_use]
    pub fn new() -> Self {
        Self {
            allocated: AtomicUsize::new(0),
            peak: AtomicUsize::new(0),
            limit: AtomicUsize::new(0),
            overhead_factor: DEFAULT_OVERHEAD_FACTOR,
        }
    }

    /// Create a tracker with a memory limit.
    ///
    /// ## Arguments
    ///
    /// * `limit_bytes` - Maximum allowed allocation in bytes
    ///
    /// ## Why Memory Limits
    ///
    /// Setting a limit below actual VRAM capacity reserves space for:
    /// - CUDA context overhead (~200-500 MB)
    /// - Framework allocations (PyTorch/Candle tensor cache)
    /// - Other processes sharing the GPU
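    ///
    /// ## Example
    ///
    /// An illustrative budget: reserve 1 GiB of a hypothetical 8 GiB card for the
    /// CUDA context and other processes.
    ///
    /// ```rust
    /// use rust_ai_core::MemoryTracker;
    ///
    /// let total_vram: usize = 8 * 1024 * 1024 * 1024;
    /// let reserved: usize = 1024 * 1024 * 1024;
    /// let tracker = MemoryTracker::with_limit(total_vram - reserved);
    /// assert_eq!(tracker.limit_bytes(), 7 * 1024 * 1024 * 1024);
    /// ```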
    #[must_use]
    pub fn with_limit(limit_bytes: usize) -> Self {
        Self {
            allocated: AtomicUsize::new(0),
            peak: AtomicUsize::new(0),
            limit: AtomicUsize::new(limit_bytes),
            overhead_factor: DEFAULT_OVERHEAD_FACTOR,
        }
    }

    /// Set a custom overhead factor for estimates.
    ///
    /// ## Arguments
    ///
    /// * `factor` - Multiplier applied to estimates (default: 1.1)
    #[must_use]
    pub fn with_overhead_factor(mut self, factor: f64) -> Self {
        self.overhead_factor = factor;
        self
    }

    /// Record a memory allocation.
    ///
    /// ## Arguments
    ///
    /// * `bytes` - Number of bytes allocated
    ///
    /// ## Returns
    ///
    /// `Ok(())` if allocation is within limits.
    ///
    /// ## Errors
    ///
    /// Returns `CoreError::OutOfMemory` if allocation would exceed the limit.
    ///
    /// ## Why Track Allocations
    ///
    /// Explicit tracking catches memory issues early. Without tracking, OOM errors
    /// occur deep in CUDA kernels with unhelpful error messages.
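    ///
    /// ## Example
    ///
    /// A minimal sketch of the failure path when a limit is set:
    ///
    /// ```rust
    /// use rust_ai_core::MemoryTracker;
    ///
    /// let tracker = MemoryTracker::with_limit(1024);
    /// assert!(tracker.allocate(512).is_ok());
    /// assert!(tracker.allocate(1024).is_err()); // 512 + 1024 exceeds the 1 KiB limit
    /// assert_eq!(tracker.allocated_bytes(), 512); // the failed allocation is not recorded
    /// ```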
    pub fn allocate(&self, bytes: usize) -> Result<()> {
        // Check the limit BEFORE updating state so a rejected allocation leaves the
        // counters untouched. Note: the check and the update below are separate atomic
        // operations, so concurrent allocations can transiently race past the limit.
        let limit = self.limit.load(Ordering::SeqCst);
        let current = self.allocated.load(Ordering::SeqCst);
        let new_allocated = current + bytes;

        if limit > 0 && new_allocated > limit {
            return Err(CoreError::oom(format!(
                "allocation of {bytes} bytes would exceed limit of {limit} bytes \
                 (current: {current} bytes)"
            )));
        }

        // Update allocated (the actual update happens here)
        let actual_new = self.allocated.fetch_add(bytes, Ordering::SeqCst) + bytes;

        // Update peak with a compare-and-swap loop so concurrent allocations never
        // lower the recorded high-water mark.
        let mut current_peak = self.peak.load(Ordering::SeqCst);
        while actual_new > current_peak {
            match self.peak.compare_exchange_weak(
                current_peak,
                actual_new,
                Ordering::SeqCst,
                Ordering::SeqCst,
            ) {
                Ok(_) => break,
                Err(p) => current_peak = p,
            }
        }

        Ok(())
    }

    /// Record a memory deallocation.
    ///
    /// ## Arguments
    ///
    /// * `bytes` - Number of bytes freed
    pub fn deallocate(&self, bytes: usize) {
        self.allocated.fetch_sub(bytes, Ordering::SeqCst);
    }

    /// Get currently allocated bytes.
    #[must_use]
    pub fn allocated_bytes(&self) -> usize {
        self.allocated.load(Ordering::SeqCst)
    }

    /// Get peak allocation during tracker lifetime.
    ///
    /// ## Why Track Peak
    ///
    /// Peak usage is more useful than current usage for capacity planning.
    /// It shows the high-water mark needed to complete a workload.
    #[must_use]
    pub fn peak_bytes(&self) -> usize {
        self.peak.load(Ordering::SeqCst)
    }

    /// Get configured memory limit (0 = unlimited).
    #[must_use]
    pub fn limit_bytes(&self) -> usize {
        self.limit.load(Ordering::SeqCst)
    }

    /// Estimate required memory with overhead factor applied.
    ///
    /// ## Arguments
    ///
    /// * `shape` - Tensor dimensions
    /// * `dtype` - Data type
    ///
    /// ## Returns
    ///
    /// Estimated bytes including overhead buffer.
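    ///
    /// ## Example
    ///
    /// A small sketch using the default 1.1x overhead factor (the shape is illustrative):
    ///
    /// ```rust
    /// use rust_ai_core::MemoryTracker;
    /// use candle_core::DType;
    ///
    /// let tracker = MemoryTracker::new();
    /// // 10 * 100 f32 elements = 4000 raw bytes; with the 1.1x buffer: 4400.
    /// assert_eq!(tracker.estimate_with_overhead(&[10, 100], DType::F32), 4400);
    /// ```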
    #[must_use]
    pub fn estimate_with_overhead(&self, shape: &[usize], dtype: DType) -> usize {
        let raw = estimate_tensor_bytes(shape, dtype);
        #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
        {
            (raw as f64 * self.overhead_factor) as usize
        }
    }

    /// Check if an allocation would fit within limits.
    ///
    /// ## Arguments
    ///
    /// * `bytes` - Proposed allocation size
    ///
    /// ## Returns
    ///
    /// `true` if the allocation would succeed.
    #[must_use]
    pub fn would_fit(&self, bytes: usize) -> bool {
        let limit = self.limit.load(Ordering::SeqCst);
        if limit == 0 {
            return true; // No limit
        }
        self.allocated.load(Ordering::SeqCst) + bytes <= limit
    }

    /// Reset the tracker to initial state.
    ///
    /// ## Why Reset
    ///
    /// Between training epochs or inference batches, resetting allows tracking
    /// per-phase memory usage without creating new tracker instances.
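    ///
    /// ## Example
    ///
    /// A sketch of per-phase tracking (the loop and sizes are illustrative):
    ///
    /// ```rust
    /// use rust_ai_core::MemoryTracker;
    ///
    /// let tracker = MemoryTracker::new();
    /// for _epoch in 0..3 {
    ///     tracker.allocate(256).expect("within limits");
    ///     // ... run the epoch, then release its buffers ...
    ///     tracker.deallocate(256);
    ///     println!("epoch peak: {} bytes", tracker.peak_bytes());
    ///     tracker.reset();
    /// }
    /// ```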
    pub fn reset(&self) {
        self.allocated.store(0, Ordering::SeqCst);
        self.peak.store(0, Ordering::SeqCst);
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_estimate_tensor_bytes() {
        // 1000 f32 elements = 4000 bytes
        assert_eq!(estimate_tensor_bytes(&[10, 100], DType::F32), 4000);

        // 1000 f16 elements = 2000 bytes
        assert_eq!(estimate_tensor_bytes(&[10, 100], DType::F16), 2000);

        // Empty tensor = 0 bytes
        assert_eq!(estimate_tensor_bytes(&[0], DType::F32), 0);
    }

    #[test]
    fn test_estimate_attention_memory() {
        // Simple case: 1 batch, 1 head, 4 seq, 2 head_dim, f32
        let bytes = estimate_attention_memory(1, 1, 4, 2, DType::F32);
        // QKV: 3 * 1 * 1 * 4 * 2 * 4 = 96
        // Attn: 1 * 1 * 4 * 4 * 4 = 64
        // Out: 1 * 1 * 4 * 2 * 4 = 32
        // Total: 192
        assert_eq!(bytes, 192);
    }

    #[test]
    fn test_memory_tracker_allocation() {
        let tracker = MemoryTracker::with_limit(1000);

        // Successful allocation
        assert!(tracker.allocate(500).is_ok());
        assert_eq!(tracker.allocated_bytes(), 500);

        // Second allocation
        assert!(tracker.allocate(400).is_ok());
        assert_eq!(tracker.allocated_bytes(), 900);

        // Exceeds limit
        assert!(tracker.allocate(200).is_err());
        assert_eq!(tracker.allocated_bytes(), 900); // Unchanged

        // Deallocation
        tracker.deallocate(400);
        assert_eq!(tracker.allocated_bytes(), 500);

        // Now fits
        assert!(tracker.allocate(200).is_ok());
    }

    #[test]
    fn test_memory_tracker_peak() {
        let tracker = MemoryTracker::new();

        tracker.allocate(100).unwrap();
        tracker.allocate(200).unwrap();
        assert_eq!(tracker.peak_bytes(), 300);

        tracker.deallocate(200);
        assert_eq!(tracker.allocated_bytes(), 100);
        assert_eq!(tracker.peak_bytes(), 300); // Peak unchanged

        tracker.allocate(50).unwrap();
        assert_eq!(tracker.peak_bytes(), 300); // Still 300

        tracker.allocate(300).unwrap();
        assert_eq!(tracker.peak_bytes(), 450); // New peak
    }

    #[test]
    fn test_would_fit() {
        let tracker = MemoryTracker::with_limit(1000);
        tracker.allocate(500).unwrap();

        assert!(tracker.would_fit(400));
        assert!(tracker.would_fit(500));
        assert!(!tracker.would_fit(501));

        // Unlimited tracker
        let unlimited = MemoryTracker::new();
        assert!(unlimited.would_fit(usize::MAX));
    }
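
    // Suggested additional coverage (not part of the original suite): reset should
    // clear both the running and peak counters, and estimates should reflect the
    // default 1.1x overhead factor.
    #[test]
    fn test_reset_clears_counters() {
        let tracker = MemoryTracker::new();
        tracker.allocate(100).unwrap();
        tracker.deallocate(40);
        tracker.reset();
        assert_eq!(tracker.allocated_bytes(), 0);
        assert_eq!(tracker.peak_bytes(), 0);
    }

    #[test]
    fn test_estimate_with_overhead() {
        let tracker = MemoryTracker::new();
        // 4000 raw bytes * 1.1 overhead = 4400
        assert_eq!(tracker.estimate_with_overhead(&[10, 100], DType::F32), 4400);
    }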
}