Skip to main content

svod_tensor/memory_planner/
mod.rs

1//! Memory planner for buffer reuse optimization.
2//!
3//! This module implements liveness-based memory planning following Tinygrad's approach.
4//! The memory planner analyzes buffer lifetimes across the schedule and reuses buffers
5//! with non-overlapping lifetimes, reducing memory consumption.
6//!
7//! # Algorithm
8//!
9//! 1. **Liveness Analysis**: Track first/last appearance of each buffer in schedule
10//! 2. **Event Timeline**: Create sorted alloc/free events (frees before allocs at same step)
11//! 3. **Pool-Based Allocation**: Reuse buffers by (device, dtype, size) key
12//! 4. **Apply Replacements**: Map logical buffers to physical buffers
13
14mod tlsf;
15
16use std::collections::{HashMap, HashSet};
17use std::sync::Arc;
18
19use svod_device::Buffer;
20use svod_dtype::{DType, DeviceSpec};
21use svod_ir::{Op, UOp};
22use tracing::{debug, trace};
23
24use crate::schedule::Schedule;
25
/// Minimum block size for buffer pooling (256-byte alignment, matching tinygrad).
///
/// Planned buffer sizes are rounded up to a multiple of this value: the pool
/// planner stores the rounded size in its [`BufferPoolKey`], and the arena
/// planner uses it for its per-buffer request sizes and for the final arena
/// size.
const MIN_BLOCK_SIZE: usize = 256;
28
/// Selects the buffer-allocation strategy used by the planner entrypoint.
///
/// - `Disabled` short-circuits the planner and emits no replacements.
/// - `Remap` runs liveness-based pool reuse: groups buffers by
///   `(device, dtype, rounded_size)` and lets disjoint-lifetime buffers
///   share an underlying allocation.
/// - `Arena` packs all plannable buffers into one or two large
///   per-`(device, copy-lane)` arenas using a TLSF allocator and rewrites
///   each logical buffer as a `Buffer::view` into its lane's arena
///   (tinygrad parity).
///
/// [`parse_mode`] maps the `SVOD_MEMORY_PLANNER` environment value onto these
/// variants; an unset or unrecognized value selects `Arena`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PlannerMode {
    /// Skip the planner entirely. Each `Buffer` keeps its original allocation
    /// and is freed by lazy `Drop`. Useful for memory-debugging baselines.
    Disabled,
    /// Liveness-based pool reuse: groups buffers by
    /// `(device, dtype, rounded_size)` and lets disjoint-lifetime buffers
    /// share an underlying allocation via `Arc<Buffer>` swap.
    Remap,
    /// Tinygrad-style packing: pack every plannable buffer into one or two
    /// per-`(device, copy-lane)` arenas using a TLSF allocator and rewrite
    /// each logical buffer as a fresh `Buffer::view` into its lane's arena.
    Arena,
}
53
54/// Pure parser for the `SVOD_MEMORY_PLANNER` env var, exposed for testing.
55///
56/// Default (env unset) is [`PlannerMode::Arena`], matching tinygrad's
57/// `NO_MEMORY_PLANNER=0` default — the arena planner runs unless the user
58/// explicitly opts out. `remap` / `pool` keep the older liveness-based pool
59/// reuse for parity with the previous default if a workload regresses.
60pub fn parse_mode(raw: Option<&str>) -> PlannerMode {
61    let Some(raw) = raw else {
62        return PlannerMode::Arena;
63    };
64    let normalized = raw.trim().to_ascii_lowercase();
65    match normalized.as_str() {
66        "0" | "off" | "none" | "disabled" => PlannerMode::Disabled,
67        "remap" | "pool" => PlannerMode::Remap,
68        // "1" | "on" | "arena" | "" or any unrecognized → Arena (tinygrad default)
69        _ => PlannerMode::Arena,
70    }
71}
72
73/// Read `SVOD_MEMORY_PLANNER` from the environment and resolve to a [`PlannerMode`].
74pub fn mode_from_env() -> PlannerMode {
75    parse_mode(std::env::var("SVOD_MEMORY_PLANNER").ok().as_deref())
76}
77
/// Fingerprint of one logical view onto a storage allocation, in field order
/// `(offset, size, dtype, shape)` as reported by the `Buffer` accessors of the
/// same names. Used to count the distinct views per storage when detecting
/// aliased storage.
type LogicalBufferView = (usize, usize, DType, Vec<usize>);
79
/// Return the smallest multiple of `block_size` that is `>= size`.
///
/// A `size` that is already a multiple (including `0`) is returned unchanged.
/// Panics if `block_size` is zero.
#[inline]
fn round_up(size: usize, block_size: usize) -> usize {
    match size % block_size {
        0 => size,
        remainder => size + (block_size - remainder),
    }
}
85
86// ============================================================================
87// DATA STRUCTURES
88// ============================================================================
89
/// Key for buffer pooling - groups buffers that can be reused interchangeably.
///
/// Buffer reuse is shape-agnostic: codegen reads logical shape from the UOp
/// graph, runtime dispatch passes raw `*mut u8` pointers, and the planner
/// skips output buffers (the only consumers of `Buffer::shape()` via
/// `as_array`/`as_array_mut`). Two non-output buffers with the same
/// `(device, dtype, rounded_size)` are interchangeable storage regardless
/// of their logical shapes.
///
/// Used as the hash-map key of the free pools in the `Remap` planner, hence
/// the `Eq + Hash` derives.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct BufferPoolKey {
    /// Device where buffer is allocated.
    pub device: DeviceSpec,
    /// Data type of buffer elements.
    pub dtype: DType,
    /// Buffer size in bytes (rounded up to MIN_BLOCK_SIZE).
    /// Two buffers whose raw sizes differ but round to the same value share a pool.
    pub size: usize,
}
107
/// Liveness information for a buffer.
///
/// `first_appearance..=last_appearance` is the inclusive range of schedule
/// steps during which the buffer must keep its storage.
#[derive(Debug, Clone)]
pub struct BufferLiveness {
    /// Index of first schedule step that uses this buffer.
    pub first_appearance: usize,
    /// Index of last schedule step that uses this buffer.
    pub last_appearance: usize,
    /// Pool key for buffer grouping.
    pub pool_key: BufferPoolKey,
    /// Representative logical buffer for this allocation ID: a clone of the
    /// buffer handle taken the first time liveness analysis encounters the ID.
    pub prototype: Buffer,
}
120
/// Buffer allocation/deallocation event for timeline scheduling.
///
/// Timelines are sorted by `(timestep, is_alloc, buffer_id)`, so at equal
/// timesteps deallocations (`false`) are processed before allocations (`true`).
#[derive(Debug, Clone)]
struct BufferEvent {
    /// Schedule item index when this event occurs. Allocation events use the
    /// buffer's first appearance; deallocation events are placed after its
    /// last appearance (`last_appearance + 1`, plus any hold extension the
    /// arena planner adds).
    timestep: usize,
    /// True for allocation, false for deallocation.
    is_alloc: bool,
    /// Handle id of the buffer this event refers to; it is the key into the
    /// planner's liveness map.
    buffer_id: u64,
}
131
/// Schedule-order dependency introduced by a physical buffer reuse.
///
/// When two logical buffers end up sharing storage, the step that last touched
/// the earlier buffer must run before the step that first touches the later
/// one. [`apply_reuse_dependencies`] turns each of these into an entry in the
/// successor's `instance_dependencies`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ReuseDependency {
    /// Schedule item that last uses the old logical buffer occupying the storage.
    pub predecessor_step: usize,
    /// Schedule item that first uses the new logical buffer reusing that storage.
    pub successor_step: usize,
}
140
/// A buffer sitting in a free pool of the `Remap` planner, waiting to be
/// handed to the next allocation with the same [`BufferPoolKey`].
struct ReusableBuffer {
    /// The storage handle that becomes available for reuse.
    buffer: Buffer,
    /// Last schedule step of the logical buffer that most recently occupied
    /// this storage; becomes the `predecessor_step` of the reuse dependency.
    released_by_step: usize,
}
145
/// Collected planner inputs derived from schedule traversal.
///
/// Produced by `analyze_liveness`; only buffers that passed the skip filter
/// appear in either field.
struct PlannerInput {
    /// Liveness keyed by buffer handle id (`Buffer::id().0` of the buffer as it
    /// appears in the schedule).
    liveness: HashMap<u64, BufferLiveness>,
    /// Logical schedule slots that are eligible for replacement, each paired
    /// with the handle id of the buffer currently in that slot.
    occurrences: Vec<(BufferKey, u64)>,
}
153
/// Key to identify a buffer within a schedule.
///
/// We use (kernel_index, buffer_index) because the same UOp ID might appear
/// in multiple kernels due to buffer sharing.
///
/// Both indices are positional: replacing a buffer means overwriting
/// `schedule[kernel_idx].buffers[buffer_idx]`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BufferKey {
    /// Index of kernel in the schedule (the schedule item index).
    pub kernel_idx: usize,
    /// Index of buffer within that kernel's buffer list.
    pub buffer_idx: usize,
}
165
/// Result of memory planning.
///
/// The two counters are computed differently per planner mode:
/// - `Remap`: `buffers_reused` counts reuse events and `memory_saved` adds the
///   rounded pool-key size for each of them.
/// - `Arena`: `buffers_reused` counts rewritten schedule slots and
///   `memory_saved` is the sum of rounded buffer sizes minus the sum of the
///   rounded arena sizes (saturating at zero).
#[derive(Debug)]
pub struct MemoryPlannerResult {
    /// Mapping from (kernel_idx, buffer_idx) to replacement buffer.
    /// Only contains entries for buffers that were replaced.
    pub buffer_replace: HashMap<BufferKey, Buffer>,
    /// Total memory saved through buffer reuse (in bytes).
    pub memory_saved: usize,
    /// Number of buffers that were reused.
    pub buffers_reused: usize,
    /// Execution ordering constraints required by reuse decisions.
    pub reuse_dependencies: Vec<ReuseDependency>,
}
179
180// ============================================================================
181// LIVENESS ANALYSIS
182// ============================================================================
183
184/// Analyze buffer liveness across the schedule.
185///
186/// Tracks first and last appearance of each buffer, skipping:
187/// - Already allocated buffers (inputs)
188/// - Output buffers
189/// - Transfer operations
190fn collect_noopt_buffer_ids(schedule: &Schedule) -> HashSet<u64> {
191    // Alias detection groups views/buffers that share the same underlying
192    // storage. Keying by `Buffer::id()` would miss views (since each view
193    // mints a fresh handle id); keying by `storage_id()` correctly groups
194    // every view of one allocation under one bucket.
195    let mut by_storage: HashMap<u64, HashSet<LogicalBufferView>> = HashMap::new();
196    let mut masked_store_ids = HashSet::new();
197    for item in schedule {
198        for buffer in &item.buffers {
199            by_storage.entry(buffer.storage_id().0).or_default().insert((
200                buffer.offset(),
201                buffer.size(),
202                buffer.dtype(),
203                buffer.shape().to_vec(),
204            ));
205        }
206
207        let uop_id_to_buffer_id: HashMap<u64, u64> =
208            item.buffer_uop_ids.iter().copied().zip(item.buffers.iter().map(|b| b.id().0)).collect();
209        for node in item.ast.toposort() {
210            let Op::Store { index, .. } = node.op() else {
211                continue;
212            };
213            collect_masked_store_buffer_ids(index, &uop_id_to_buffer_id, &mut masked_store_ids);
214        }
215    }
216
217    // Map aliased storage ids back to handle ids — every Buffer in the
218    // schedule whose storage has multiple distinct views is non-plannable.
219    let aliased_storages: HashSet<u64> =
220        by_storage.into_iter().filter_map(|(sid, views)| (views.len() > 1).then_some(sid)).collect();
221    let aliased_ids = schedule.iter().flat_map(|item| {
222        item.buffers
223            .iter()
224            .filter(|b| aliased_storages.contains(&b.storage_id().0))
225            .map(|b| b.id().0)
226            .collect::<Vec<_>>()
227    });
228
229    schedule
230        .iter()
231        .filter(|item| !matches!(item.ast.op(), Op::Sink { .. }))
232        .flat_map(|item| item.buffers.iter().map(|b| b.id().0))
233        .chain(aliased_ids)
234        .chain(masked_store_ids)
235        .collect()
236}
237
238fn collect_masked_store_buffer_ids(
239    index: &Arc<UOp>,
240    uop_id_to_buffer_id: &HashMap<u64, u64>,
241    masked_store_ids: &mut HashSet<u64>,
242) {
243    match index.op() {
244        Op::Index { buffer, gate: Some(_), .. } => {
245            if let Some(buffer_id) = uop_id_to_buffer_id.get(&buffer.buf_uop().id) {
246                masked_store_ids.insert(*buffer_id);
247            }
248        }
249        Op::Index { .. } => {}
250        other => {
251            for child in other.children() {
252                collect_masked_store_buffer_ids(child, uop_id_to_buffer_id, masked_store_ids);
253            }
254        }
255    }
256}
257
258fn should_skip_buffer(buffer: &Buffer, output_buffer_ids: &HashSet<u64>, noopt_buffer_ids: &HashSet<u64>) -> bool {
259    // Phase 1 planner only remaps full logical buffers. View/offset buffers require
260    // an alias-preserving remap pass (planned for arena phase).
261    buffer.allocator().device_spec().is_disk()
262        || buffer.offset() != 0
263        || buffer.is_allocated()
264        || output_buffer_ids.contains(&buffer.id().0)
265        || noopt_buffer_ids.contains(&buffer.id().0)
266}
267
268fn analyze_liveness(schedule: &Schedule, output_buffer_ids: &HashSet<u64>) -> PlannerInput {
269    let noopt_buffer_ids = collect_noopt_buffer_ids(schedule);
270    let mut liveness: HashMap<u64, BufferLiveness> = HashMap::new();
271    let mut occurrences: Vec<(BufferKey, u64)> = Vec::new();
272
273    for (step_idx, item) in schedule.iter().enumerate() {
274        for (buf_idx, buffer) in item.buffers.iter().enumerate() {
275            let key = BufferKey { kernel_idx: step_idx, buffer_idx: buf_idx };
276            let buf_id = buffer.id().0;
277
278            if should_skip_buffer(buffer, output_buffer_ids, &noopt_buffer_ids) {
279                trace!(step_idx, buf_idx, buffer_id = buf_id, "skipping buffer in memory planner");
280                continue;
281            }
282
283            occurrences.push((key, buf_id));
284
285            let pool_key = BufferPoolKey {
286                device: buffer.allocator().device_spec(),
287                dtype: buffer.dtype(),
288                size: round_up(buffer.size(), MIN_BLOCK_SIZE),
289            };
290
291            liveness
292                .entry(buf_id)
293                .and_modify(|info| {
294                    info.first_appearance = info.first_appearance.min(step_idx);
295                    info.last_appearance = info.last_appearance.max(step_idx);
296                })
297                .or_insert_with(|| BufferLiveness {
298                    first_appearance: step_idx,
299                    last_appearance: step_idx,
300                    pool_key,
301                    prototype: buffer.clone(),
302                });
303        }
304    }
305
306    debug!(num_optimizable = liveness.len(), "liveness analysis complete");
307
308    PlannerInput { liveness, occurrences }
309}
310
311// ============================================================================
312// EVENT TIMELINE
313// ============================================================================
314
315/// Build sorted event timeline from liveness information.
316///
317/// Events are sorted by (timestep, is_alloc) so that:
318/// - Earlier timesteps come first
319/// - At the same timestep, frees (is_alloc=false) come before allocs (is_alloc=true)
320///
321/// This ordering allows immediate reuse of freed buffers.
322fn build_event_timeline(liveness: &HashMap<u64, BufferLiveness>) -> Vec<BufferEvent> {
323    let mut events = Vec::with_capacity(liveness.len() * 2);
324
325    for (&buf_id, info) in liveness {
326        // Allocation event at first appearance
327        events.push(BufferEvent { timestep: info.first_appearance, is_alloc: true, buffer_id: buf_id });
328
329        // Deallocation event after last appearance
330        events.push(BufferEvent { timestep: info.last_appearance + 1, is_alloc: false, buffer_id: buf_id });
331    }
332
333    // Sort by (timestep, is_alloc) - false < true ensures frees before allocs
334    events.sort_by_key(|e| (e.timestep, e.is_alloc, e.buffer_id));
335
336    events
337}
338
339// ============================================================================
340// POOL-BASED ALLOCATION
341// ============================================================================
342
343/// Process events and compute buffer replacements using pool-based allocation.
344///
345/// For each allocation event:
346/// - Try to reuse a buffer from the pool with matching key
347/// - If no match, the buffer keeps its original allocation
348///
349/// For each deallocation event:
350/// - Return the buffer to the pool for future reuse
351fn process_events(
352    events: &[BufferEvent],
353    liveness: &HashMap<u64, BufferLiveness>,
354    occurrences: &[(BufferKey, u64)],
355) -> (HashMap<BufferKey, Buffer>, usize, usize, Vec<ReuseDependency>) {
356    let mut free_pools: HashMap<BufferPoolKey, Vec<ReusableBuffer>> = HashMap::new();
357    let mut memory_saved: usize = 0;
358    let mut buffers_reused: usize = 0;
359    let mut reuse_dependencies = Vec::new();
360    let mut chosen_by_id: HashMap<u64, Buffer> = HashMap::new();
361
362    // Track live assignment during timeline simulation.
363    let mut active_buffers: HashMap<u64, Buffer> = HashMap::new();
364
365    for event in events {
366        let info = match liveness.get(&event.buffer_id) {
367            Some(info) => info,
368            None => continue,
369        };
370        let pool_key = &info.pool_key;
371
372        if event.is_alloc {
373            // Try to reuse a buffer from the pool
374            if let Some(pool) = free_pools.get_mut(pool_key)
375                && let Some(reused) = pool.pop()
376            {
377                trace!(timestep = event.timestep, reused_buffer_id = reused.buffer.id().0, "reusing buffer from pool");
378
379                reuse_dependencies.push(ReuseDependency {
380                    predecessor_step: reused.released_by_step,
381                    successor_step: event.timestep,
382                });
383                chosen_by_id.insert(event.buffer_id, reused.buffer.clone());
384                active_buffers.insert(event.buffer_id, reused.buffer);
385                memory_saved += pool_key.size;
386                buffers_reused += 1;
387                continue;
388            }
389
390            // No reuse - use original buffer
391            chosen_by_id.insert(event.buffer_id, info.prototype.clone());
392            active_buffers.insert(event.buffer_id, info.prototype.clone());
393        } else {
394            // Deallocation - return buffer to pool
395            if let Some(buffer) = active_buffers.remove(&event.buffer_id) {
396                free_pools
397                    .entry(pool_key.clone())
398                    .or_default()
399                    .push(ReusableBuffer { buffer, released_by_step: info.last_appearance });
400            }
401        }
402    }
403
404    let mut buffer_replace: HashMap<BufferKey, Buffer> = HashMap::new();
405    for (key, buf_id) in occurrences {
406        if let Some(chosen) = chosen_by_id.get(buf_id)
407            && chosen.id().0 != *buf_id
408        {
409            buffer_replace.insert(*key, chosen.clone());
410        }
411    }
412
413    (buffer_replace, memory_saved, buffers_reused, reuse_dependencies)
414}
415
416// ============================================================================
417// ARENA-BASED ALLOCATION
418// ============================================================================
419
/// Per-`(device, lane)` arena identifier. The `bool` separates the copy lane
/// (`true`) from the compute lane (`false`); this mirrors tinygrad's lane
/// keying and prevents introducing copy→compute→copy dependencies that
/// would force serialization.
///
/// A buffer is placed in the copy lane when it is plannable and appears among
/// the buffers of a schedule item whose runtime-effect AST is `Op::Copy`.
type LaneKey = (DeviceSpec, bool);
425
426/// Tinygrad-style arena planner: replaces every plannable buffer with a
427/// `Buffer::view` into a per-lane arena allocated by [`tlsf::TlsfAllocator`].
428///
429/// Tinygrad rewrites the UOp graph to swap each `BUFFER` for a
430/// `BUFFER_VIEW(arena, ...)`; Svod achieves the same effect at runtime by
431/// populating [`MemoryPlannerResult::buffer_replace`] with arena views, which
432/// the existing [`apply_buffer_replacements`] then swaps into the schedule.
433fn memory_plan_arena(schedule: &Schedule, output_buffer_ids: &HashSet<u64>) -> MemoryPlannerResult {
434    let empty_result = || MemoryPlannerResult {
435        buffer_replace: HashMap::new(),
436        memory_saved: 0,
437        buffers_reused: 0,
438        reuse_dependencies: Vec::new(),
439    };
440
441    let planner_input = analyze_liveness(schedule, output_buffer_ids);
442    let liveness = planner_input.liveness;
443    if liveness.is_empty() {
444        return empty_result();
445    }
446
447    // Identify copy-lane buffers: any plannable buffer that appears as an
448    // argument to a Copy schedule item.
449    let mut copy_bufs: HashSet<u64> = HashSet::new();
450    for item in schedule {
451        let runtime_ast = crate::realize::runtime_effect_ast(&item.ast);
452        if !matches!(runtime_ast.op(), Op::Copy { .. }) {
453            continue;
454        }
455        for buffer in &item.buffers {
456            let id = buffer.id().0;
457            if liveness.contains_key(&id) {
458                copy_bufs.insert(id);
459            }
460        }
461    }
462
463    let lane_key = |id: u64| -> LaneKey {
464        let info = &liveness[&id];
465        (info.prototype.allocator().device_spec(), copy_bufs.contains(&id))
466    };
467
468    // `buf_hold`: copy buffers stay live past their last appearance to avoid
469    // clobbering before downstream copies finish.
470    let buf_hold: HashMap<u64, usize> = copy_bufs
471        .iter()
472        .map(|&id| {
473            let info = &liveness[&id];
474            (id, info.last_appearance - info.first_appearance + 1)
475        })
476        .collect();
477
478    // Per-buffer rounded size and byte size: round to `block_size` so the
479    // TLSF allocator's bucket math stays correct.
480    let nbytes_rounded: HashMap<u64, usize> =
481        liveness.iter().map(|(&id, info)| (id, round_up(info.prototype.size(), MIN_BLOCK_SIZE))).collect();
482
483    // Build event timeline with copy-lane hold extension on free events.
484    let mut events: Vec<BufferEvent> = Vec::with_capacity(liveness.len() * 2);
485    for (&id, info) in &liveness {
486        events.push(BufferEvent { timestep: info.first_appearance, is_alloc: true, buffer_id: id });
487        events.push(BufferEvent {
488            timestep: info.last_appearance + 1 + buf_hold.get(&id).copied().unwrap_or(0),
489            is_alloc: false,
490            buffer_id: id,
491        });
492    }
493    events.sort_by_key(|e| (e.timestep, e.is_alloc, e.buffer_id));
494
495    // Per-lane TLSF allocators. Generous size budget = 2 × Σ(rounded sizes)
496    // so even worst-case fragmentation can fit.
497    let total_bytes: usize = nbytes_rounded.values().sum();
498    let arena_budget = total_bytes.saturating_mul(2).max(MIN_BLOCK_SIZE);
499    let mut tlsfs: HashMap<LaneKey, tlsf::TlsfAllocator> = HashMap::new();
500    let mut offsets: HashMap<u64, usize> = HashMap::new();
501    let mut peaks: HashMap<LaneKey, usize> = HashMap::new();
502    // Track ranges freed within this lane so a later allocator-overlap
503    // becomes an explicit `ReuseDependency`. We record `(offset, end, last_step)`
504    // for every free; on alloc, any live entry whose `[offset, end)` overlaps
505    // the new alloc emits a dep.
506    let mut freed_ranges: HashMap<LaneKey, Vec<(usize, usize, usize)>> = HashMap::new();
507    let mut reuse_dependencies: Vec<ReuseDependency> = Vec::new();
508
509    for event in &events {
510        let lane = lane_key(event.buffer_id);
511        let info = &liveness[&event.buffer_id];
512        let alloc =
513            tlsfs.entry(lane.clone()).or_insert_with(|| tlsf::TlsfAllocator::new(arena_budget, 0, MIN_BLOCK_SIZE, 32));
514        if event.is_alloc {
515            let req = nbytes_rounded[&event.buffer_id];
516            let off = match alloc.alloc(req, 1) {
517                Ok(o) => o,
518                Err(e) => {
519                    tracing::warn!(?e, "arena planner: TLSF alloc failed; skipping arena rewrite");
520                    return empty_result();
521                }
522            };
523            offsets.insert(event.buffer_id, off);
524            // Peak reflects actual byte usage (`buf.arg * itemsize`), not bucket-rounded size.
525            let used_end = off + info.prototype.size();
526            let peak = peaks.entry(lane.clone()).or_insert(0);
527            if used_end > *peak {
528                *peak = used_end;
529            }
530            // Emit reuse dependencies for any freed range this alloc
531            // overlaps — they share storage now that the offset is reused.
532            let alloc_end = off + req;
533            if let Some(ranges) = freed_ranges.get(&lane) {
534                for &(prev_off, prev_end, prev_last_step) in ranges {
535                    let overlaps = off < prev_end && prev_off < alloc_end;
536                    if overlaps {
537                        reuse_dependencies.push(ReuseDependency {
538                            predecessor_step: prev_last_step,
539                            successor_step: info.first_appearance,
540                        });
541                    }
542                }
543            }
544            // Drop freed ranges that this alloc overlaps. Their ReuseDependency
545            // edges were just emitted above, and a future overlapping alloc
546            // should depend on the *current* allocation's last appearance, not
547            // the older eclipsed range — leaving them in would re-emit
548            // redundant deps. The retain predicate keeps only ranges fully
549            // disjoint from `[off, alloc_end)`.
550            if let Some(ranges) = freed_ranges.get_mut(&lane) {
551                ranges.retain(|&(o, e, _)| o >= alloc_end || e <= off);
552            }
553        } else if let Some(off) = offsets.get(&event.buffer_id).copied() {
554            let req = nbytes_rounded[&event.buffer_id];
555            if let Err(e) = alloc.free(off) {
556                tracing::warn!(?e, "arena planner: TLSF free failed; skipping arena rewrite");
557                return empty_result();
558            }
559            freed_ranges.entry(lane).or_default().push((off, off + req, info.last_appearance));
560        }
561    }
562
563    // Allocate one arena buffer per lane, sized to the lane's peak. Precompute a
564    // lane→prototype map so we don't re-scan `liveness` once per lane.
565    let mut lane_proto: HashMap<LaneKey, Buffer> = HashMap::with_capacity(peaks.len());
566    for (&id, info) in &liveness {
567        lane_proto.entry(lane_key(id)).or_insert_with(|| info.prototype.clone());
568    }
569    let mut arenas: HashMap<LaneKey, Buffer> = HashMap::new();
570    for (lane, &peak) in &peaks {
571        if peak == 0 {
572            continue;
573        }
574        let arena_size = round_up(peak, MIN_BLOCK_SIZE);
575        let prototype = lane_proto.get(lane).expect("every populated lane must have a prototype");
576        let arena = Buffer::new(
577            prototype.allocator_arc(),
578            svod_dtype::DType::UInt8,
579            vec![arena_size],
580            svod_device::allocator::BufferOptions::default(),
581        );
582        arenas.insert(lane.clone(), arena);
583    }
584
585    // Build buffer_replace by viewing each plannable buffer's slice of its
586    // lane's arena. `Buffer::view` mints a fresh handle id per view (Path Y),
587    // so disjoint views naturally appear as independent buffers to the
588    // hazard model.
589    let mut buffer_replace: HashMap<BufferKey, Buffer> = HashMap::new();
590    let mut buffers_reused = 0usize;
591    for (key, buf_id) in &planner_input.occurrences {
592        let Some(&offset) = offsets.get(buf_id) else {
593            continue;
594        };
595        let Some(arena) = arenas.get(&lane_key(*buf_id)) else {
596            continue;
597        };
598        let info = &liveness[buf_id];
599        let byte_size = info.prototype.size();
600        let view = match arena.view(offset, byte_size) {
601            Ok(v) => v,
602            Err(e) => {
603                tracing::warn!(?e, "arena planner: view failed; skipping rewrite for one slot");
604                continue;
605            }
606        };
607        buffer_replace.insert(*key, view);
608        buffers_reused += 1;
609    }
610
611    let arena_total: usize = peaks.values().map(|&p| round_up(p, MIN_BLOCK_SIZE)).sum();
612    let memory_saved = total_bytes.saturating_sub(arena_total);
613
614    debug!(
615        buffers_planned = liveness.len(),
616        buffers_replaced = buffers_reused,
617        memory_saved_bytes = memory_saved,
618        arena_count = arenas.len(),
619        "arena memory planner complete"
620    );
621
622    MemoryPlannerResult { buffer_replace, memory_saved, buffers_reused, reuse_dependencies }
623}
624
625// ============================================================================
626// MAIN ENTRY POINT
627// ============================================================================
628
629/// Run memory planner on a schedule.
630///
631/// Analyzes buffer lifetimes and identifies opportunities for buffer reuse.
632/// Returns a mapping from logical buffers to physical buffers.
633///
634/// # Arguments
635///
636/// * `schedule` - The execution schedule to optimize
637/// * `output_buffer_ids` - IDs of output buffers that must not be reused
638/// * `mode` - Selects the planner strategy. [`PlannerMode::Disabled`] returns
639///   an empty result without analyzing the schedule. [`PlannerMode::Remap`]
640///   runs liveness-based pool reuse. [`PlannerMode::Arena`] runs the
641///   tinygrad-style arena packing pass via [`memory_plan_arena`].
642///
643/// # Returns
644///
645/// `MemoryPlannerResult` containing buffer replacements and statistics.
646#[allow(rustdoc::private_intra_doc_links)]
647pub fn memory_planner(schedule: &Schedule, output_buffer_ids: &HashSet<u64>, mode: PlannerMode) -> MemoryPlannerResult {
648    let empty_result = || MemoryPlannerResult {
649        buffer_replace: HashMap::new(),
650        memory_saved: 0,
651        buffers_reused: 0,
652        reuse_dependencies: Vec::new(),
653    };
654
655    if matches!(mode, PlannerMode::Disabled) {
656        return empty_result();
657    }
658
659    if schedule.is_empty() {
660        return empty_result();
661    }
662
663    if matches!(mode, PlannerMode::Arena) {
664        return memory_plan_arena(schedule, output_buffer_ids);
665    }
666
667    // Phase 1: Liveness analysis
668    let planner_input = analyze_liveness(schedule, output_buffer_ids);
669    let liveness = planner_input.liveness;
670
671    if liveness.is_empty() {
672        debug!("no optimizable buffers found");
673        return empty_result();
674    }
675
676    // Phase 2: Build event timeline
677    let events = build_event_timeline(&liveness);
678
679    // Phase 3: Process events and compute replacements
680    let (buffer_replace, memory_saved, buffers_reused, reuse_dependencies) =
681        process_events(&events, &liveness, &planner_input.occurrences);
682
683    debug!(
684        buffers_analyzed = liveness.len(),
685        buffers_reused,
686        memory_saved_bytes = memory_saved,
687        "memory planner complete"
688    );
689
690    MemoryPlannerResult { buffer_replace, memory_saved, buffers_reused, reuse_dependencies }
691}
692
693/// Apply buffer replacements to the schedule.
694///
695/// Modifies the schedule in place, replacing logical buffers with their
696/// physical replacements.
697pub fn apply_buffer_replacements(schedule: &mut Schedule, replacements: &HashMap<BufferKey, Buffer>) {
698    for (&key, replacement) in replacements {
699        if let Some(item) = schedule.get_mut(key.kernel_idx)
700            && let Some(buffer) = item.buffers.get_mut(key.buffer_idx)
701        {
702            *buffer = replacement.clone();
703        }
704    }
705}
706
707/// Add execution-order edges required by physical buffer reuse.
708pub fn apply_reuse_dependencies(schedule: &mut Schedule, reuse_dependencies: &[ReuseDependency]) {
709    for dep in reuse_dependencies {
710        if dep.predecessor_step == dep.successor_step {
711            continue;
712        }
713
714        debug_assert!(
715            dep.successor_step > dep.predecessor_step,
716            "reuse dependency must be forward-edge: predecessor={} >= successor={}",
717            dep.predecessor_step,
718            dep.successor_step,
719        );
720
721        if dep.predecessor_step >= schedule.len() {
722            continue;
723        }
724        let Some(successor) = schedule.get_mut(dep.successor_step) else {
725            continue;
726        };
727        if !successor.instance_dependencies.contains(&dep.predecessor_step) {
728            successor.instance_dependencies.push(dep.predecessor_step);
729        }
730    }
731}
732
733// ============================================================================
734// TESTS
735// ============================================================================
736
737#[cfg(test)]
738#[path = "../test/unit/memory_planner.rs"]
739mod tests;