1mod tlsf;
15
16use std::collections::{HashMap, HashSet};
17use std::sync::Arc;
18
19use svod_device::Buffer;
20use svod_dtype::{DType, DeviceSpec};
21use svod_ir::{Op, UOp};
22use tracing::{debug, trace};
23
24use crate::schedule::Schedule;
25
/// Allocation granularity of the planners.
///
/// Buffer sizes are rounded up to a multiple of this value before they are pooled (remap mode)
/// or placed in an arena (arena mode). It is also passed as the block-size argument to
/// `tlsf::TlsfAllocator::new`, and it is the lower bound of the arena budget.
const MIN_BLOCK_SIZE: usize = 256;
28
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
/// Strategy used by `memory_planner` to share memory between the buffers of a schedule.
pub enum PlannerMode {
    /// No planning: `memory_planner` returns an empty result and the schedule is left as-is.
    Disabled,
    /// Whole-buffer reuse: once a buffer's live range ends, the buffer is returned to a pool
    /// keyed by (device, dtype, rounded size). A later buffer with the same key can take it over.
    Remap,
    /// Arena placement: every plannable buffer becomes a view at some offset inside one arena
    /// buffer per lane (device + copy/non-copy). Offsets are chosen by a TLSF allocator.
    /// This is the default chosen by `parse_mode` for a missing or unrecognized value.
    Arena,
}
53
54pub fn parse_mode(raw: Option<&str>) -> PlannerMode {
61 let Some(raw) = raw else {
62 return PlannerMode::Arena;
63 };
64 let normalized = raw.trim().to_ascii_lowercase();
65 match normalized.as_str() {
66 "0" | "off" | "none" | "disabled" => PlannerMode::Disabled,
67 "remap" | "pool" => PlannerMode::Remap,
68 _ => PlannerMode::Arena,
70 }
71}
72
73pub fn mode_from_env() -> PlannerMode {
75 parse_mode(std::env::var("SVOD_MEMORY_PLANNER").ok().as_deref())
76}
77
/// One logical view onto a storage: `(offset, size, dtype, shape)` as reported by the `Buffer`.
/// `collect_noopt_buffer_ids` counts distinct views per storage id to detect aliasing.
type LogicalBufferView = (usize, usize, DType, Vec<usize>);
79
/// Returns the smallest multiple of `block_size` that is greater than or equal to `size`.
///
/// # Panics
/// Panics if `block_size` is zero.
#[inline]
fn round_up(size: usize, block_size: usize) -> usize {
    size.next_multiple_of(block_size)
}
85
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
/// Bucket key for the remap planner's free pools. Only buffers with equal keys can share storage.
pub struct BufferPoolKey {
    /// Device spec of the buffer's allocator.
    pub device: DeviceSpec,
    /// Element type of the buffer.
    pub dtype: DType,
    /// `Buffer::size()` rounded up to a multiple of `MIN_BLOCK_SIZE`. Buffers of different exact
    /// sizes can therefore share a key.
    pub size: usize,
}
107
#[derive(Debug, Clone)]
/// Live range and metadata of one plannable buffer, measured in schedule step indices.
pub struct BufferLiveness {
    /// Index of the first schedule item that references the buffer.
    pub first_appearance: usize,
    /// Index of the last schedule item that references the buffer (inclusive).
    pub last_appearance: usize,
    /// Pool bucket used by the remap planner.
    pub pool_key: BufferPoolKey,
    /// Clone of the buffer taken at its first occurrence. It supplies the device, size and
    /// allocator, and is used as-is when no pooled buffer is reused.
    pub prototype: Buffer,
}
120
#[derive(Debug, Clone)]
/// Allocation or release of a buffer on the planner's event timeline.
struct BufferEvent {
    /// Step at which the event takes effect. Releases are placed after the buffer's last use.
    timestep: usize,
    /// `true` for an allocation, `false` for a release. Timelines are sorted by
    /// `(timestep, is_alloc, buffer_id)`, so releases precede allocations at the same timestep.
    is_alloc: bool,
    /// Id (`buffer.id().0`) of the buffer the event refers to.
    buffer_id: u64,
}
131
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
/// Ordering edge introduced by memory reuse: `successor_step` uses memory last used by
/// `predecessor_step`. `apply_reuse_dependencies` records the predecessor in the successor's
/// `instance_dependencies`.
pub struct ReuseDependency {
    /// Last schedule step that used the memory before it was released.
    pub predecessor_step: usize,
    /// First schedule step of the buffer that takes the memory over.
    pub successor_step: usize,
}
140
/// A released buffer waiting in a remap-mode free pool.
struct ReusableBuffer {
    /// The concrete buffer that a later logical buffer can take over.
    buffer: Buffer,
    /// `last_appearance` of the logical buffer that released it. It becomes the predecessor
    /// step of the reuse dependency when the buffer is handed out again.
    released_by_step: usize,
}
145
/// Result of `analyze_liveness`, consumed by both the remap and the arena planner.
struct PlannerInput {
    /// Live range and metadata for every buffer id the planner may rewrite.
    liveness: HashMap<u64, BufferLiveness>,
    /// Every `(slot, buffer id)` occurrence of a plannable buffer, in schedule order.
    occurrences: Vec<(BufferKey, u64)>,
}
153
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
/// A buffer slot in the schedule: the item index plus the index into that item's `buffers`.
pub struct BufferKey {
    /// Index of the schedule item.
    pub kernel_idx: usize,
    /// Index into the item's `buffers` vector.
    pub buffer_idx: usize,
}
165
#[derive(Debug)]
/// Output of `memory_planner`.
pub struct MemoryPlannerResult {
    /// Replacement buffer for every schedule slot the planner rewrote; apply with
    /// `apply_buffer_replacements`.
    pub buffer_replace: HashMap<BufferKey, Buffer>,
    /// Estimated bytes saved (logged as `memory_saved_bytes`). In remap mode it is the sum of
    /// rounded pool sizes served from the pool. In arena mode it is the total rounded size minus
    /// the total arena size.
    pub memory_saved: usize,
    /// Remap mode: number of logical buffers served from a pool. Arena mode: number of slots
    /// rewritten into arena views.
    pub buffers_reused: usize,
    /// Ordering edges to add with `apply_reuse_dependencies`.
    pub reuse_dependencies: Vec<ReuseDependency>,
}
179
180fn collect_noopt_buffer_ids(schedule: &Schedule) -> HashSet<u64> {
191 let mut by_storage: HashMap<u64, HashSet<LogicalBufferView>> = HashMap::new();
196 let mut masked_store_ids = HashSet::new();
197 for item in schedule {
198 for buffer in &item.buffers {
199 by_storage.entry(buffer.storage_id().0).or_default().insert((
200 buffer.offset(),
201 buffer.size(),
202 buffer.dtype(),
203 buffer.shape().to_vec(),
204 ));
205 }
206
207 let uop_id_to_buffer_id: HashMap<u64, u64> =
208 item.buffer_uop_ids.iter().copied().zip(item.buffers.iter().map(|b| b.id().0)).collect();
209 for node in item.ast.toposort() {
210 let Op::Store { index, .. } = node.op() else {
211 continue;
212 };
213 collect_masked_store_buffer_ids(index, &uop_id_to_buffer_id, &mut masked_store_ids);
214 }
215 }
216
217 let aliased_storages: HashSet<u64> =
220 by_storage.into_iter().filter_map(|(sid, views)| (views.len() > 1).then_some(sid)).collect();
221 let aliased_ids = schedule.iter().flat_map(|item| {
222 item.buffers
223 .iter()
224 .filter(|b| aliased_storages.contains(&b.storage_id().0))
225 .map(|b| b.id().0)
226 .collect::<Vec<_>>()
227 });
228
229 schedule
230 .iter()
231 .filter(|item| !matches!(item.ast.op(), Op::Sink { .. }))
232 .flat_map(|item| item.buffers.iter().map(|b| b.id().0))
233 .chain(aliased_ids)
234 .chain(masked_store_ids)
235 .collect()
236}
237
238fn collect_masked_store_buffer_ids(
239 index: &Arc<UOp>,
240 uop_id_to_buffer_id: &HashMap<u64, u64>,
241 masked_store_ids: &mut HashSet<u64>,
242) {
243 match index.op() {
244 Op::Index { buffer, gate: Some(_), .. } => {
245 if let Some(buffer_id) = uop_id_to_buffer_id.get(&buffer.buf_uop().id) {
246 masked_store_ids.insert(*buffer_id);
247 }
248 }
249 Op::Index { .. } => {}
250 other => {
251 for child in other.children() {
252 collect_masked_store_buffer_ids(child, uop_id_to_buffer_id, masked_store_ids);
253 }
254 }
255 }
256}
257
258fn should_skip_buffer(buffer: &Buffer, output_buffer_ids: &HashSet<u64>, noopt_buffer_ids: &HashSet<u64>) -> bool {
259 buffer.allocator().device_spec().is_disk()
262 || buffer.offset() != 0
263 || buffer.is_allocated()
264 || output_buffer_ids.contains(&buffer.id().0)
265 || noopt_buffer_ids.contains(&buffer.id().0)
266}
267
268fn analyze_liveness(schedule: &Schedule, output_buffer_ids: &HashSet<u64>) -> PlannerInput {
269 let noopt_buffer_ids = collect_noopt_buffer_ids(schedule);
270 let mut liveness: HashMap<u64, BufferLiveness> = HashMap::new();
271 let mut occurrences: Vec<(BufferKey, u64)> = Vec::new();
272
273 for (step_idx, item) in schedule.iter().enumerate() {
274 for (buf_idx, buffer) in item.buffers.iter().enumerate() {
275 let key = BufferKey { kernel_idx: step_idx, buffer_idx: buf_idx };
276 let buf_id = buffer.id().0;
277
278 if should_skip_buffer(buffer, output_buffer_ids, &noopt_buffer_ids) {
279 trace!(step_idx, buf_idx, buffer_id = buf_id, "skipping buffer in memory planner");
280 continue;
281 }
282
283 occurrences.push((key, buf_id));
284
285 let pool_key = BufferPoolKey {
286 device: buffer.allocator().device_spec(),
287 dtype: buffer.dtype(),
288 size: round_up(buffer.size(), MIN_BLOCK_SIZE),
289 };
290
291 liveness
292 .entry(buf_id)
293 .and_modify(|info| {
294 info.first_appearance = info.first_appearance.min(step_idx);
295 info.last_appearance = info.last_appearance.max(step_idx);
296 })
297 .or_insert_with(|| BufferLiveness {
298 first_appearance: step_idx,
299 last_appearance: step_idx,
300 pool_key,
301 prototype: buffer.clone(),
302 });
303 }
304 }
305
306 debug!(num_optimizable = liveness.len(), "liveness analysis complete");
307
308 PlannerInput { liveness, occurrences }
309}
310
311fn build_event_timeline(liveness: &HashMap<u64, BufferLiveness>) -> Vec<BufferEvent> {
323 let mut events = Vec::with_capacity(liveness.len() * 2);
324
325 for (&buf_id, info) in liveness {
326 events.push(BufferEvent { timestep: info.first_appearance, is_alloc: true, buffer_id: buf_id });
328
329 events.push(BufferEvent { timestep: info.last_appearance + 1, is_alloc: false, buffer_id: buf_id });
331 }
332
333 events.sort_by_key(|e| (e.timestep, e.is_alloc, e.buffer_id));
335
336 events
337}
338
339fn process_events(
352 events: &[BufferEvent],
353 liveness: &HashMap<u64, BufferLiveness>,
354 occurrences: &[(BufferKey, u64)],
355) -> (HashMap<BufferKey, Buffer>, usize, usize, Vec<ReuseDependency>) {
356 let mut free_pools: HashMap<BufferPoolKey, Vec<ReusableBuffer>> = HashMap::new();
357 let mut memory_saved: usize = 0;
358 let mut buffers_reused: usize = 0;
359 let mut reuse_dependencies = Vec::new();
360 let mut chosen_by_id: HashMap<u64, Buffer> = HashMap::new();
361
362 let mut active_buffers: HashMap<u64, Buffer> = HashMap::new();
364
365 for event in events {
366 let info = match liveness.get(&event.buffer_id) {
367 Some(info) => info,
368 None => continue,
369 };
370 let pool_key = &info.pool_key;
371
372 if event.is_alloc {
373 if let Some(pool) = free_pools.get_mut(pool_key)
375 && let Some(reused) = pool.pop()
376 {
377 trace!(timestep = event.timestep, reused_buffer_id = reused.buffer.id().0, "reusing buffer from pool");
378
379 reuse_dependencies.push(ReuseDependency {
380 predecessor_step: reused.released_by_step,
381 successor_step: event.timestep,
382 });
383 chosen_by_id.insert(event.buffer_id, reused.buffer.clone());
384 active_buffers.insert(event.buffer_id, reused.buffer);
385 memory_saved += pool_key.size;
386 buffers_reused += 1;
387 continue;
388 }
389
390 chosen_by_id.insert(event.buffer_id, info.prototype.clone());
392 active_buffers.insert(event.buffer_id, info.prototype.clone());
393 } else {
394 if let Some(buffer) = active_buffers.remove(&event.buffer_id) {
396 free_pools
397 .entry(pool_key.clone())
398 .or_default()
399 .push(ReusableBuffer { buffer, released_by_step: info.last_appearance });
400 }
401 }
402 }
403
404 let mut buffer_replace: HashMap<BufferKey, Buffer> = HashMap::new();
405 for (key, buf_id) in occurrences {
406 if let Some(chosen) = chosen_by_id.get(buf_id)
407 && chosen.id().0 != *buf_id
408 {
409 buffer_replace.insert(*key, chosen.clone());
410 }
411 }
412
413 (buffer_replace, memory_saved, buffers_reused, reuse_dependencies)
414}
415
/// Arena-planner lane: `(device spec, is copy buffer)`. Each lane gets its own TLSF allocator
/// and arena, so buffers touched by copy kernels never share an arena with other buffers.
type LaneKey = (DeviceSpec, bool);
425
426fn memory_plan_arena(schedule: &Schedule, output_buffer_ids: &HashSet<u64>) -> MemoryPlannerResult {
434 let empty_result = || MemoryPlannerResult {
435 buffer_replace: HashMap::new(),
436 memory_saved: 0,
437 buffers_reused: 0,
438 reuse_dependencies: Vec::new(),
439 };
440
441 let planner_input = analyze_liveness(schedule, output_buffer_ids);
442 let liveness = planner_input.liveness;
443 if liveness.is_empty() {
444 return empty_result();
445 }
446
447 let mut copy_bufs: HashSet<u64> = HashSet::new();
450 for item in schedule {
451 let runtime_ast = crate::realize::runtime_effect_ast(&item.ast);
452 if !matches!(runtime_ast.op(), Op::Copy { .. }) {
453 continue;
454 }
455 for buffer in &item.buffers {
456 let id = buffer.id().0;
457 if liveness.contains_key(&id) {
458 copy_bufs.insert(id);
459 }
460 }
461 }
462
463 let lane_key = |id: u64| -> LaneKey {
464 let info = &liveness[&id];
465 (info.prototype.allocator().device_spec(), copy_bufs.contains(&id))
466 };
467
468 let buf_hold: HashMap<u64, usize> = copy_bufs
471 .iter()
472 .map(|&id| {
473 let info = &liveness[&id];
474 (id, info.last_appearance - info.first_appearance + 1)
475 })
476 .collect();
477
478 let nbytes_rounded: HashMap<u64, usize> =
481 liveness.iter().map(|(&id, info)| (id, round_up(info.prototype.size(), MIN_BLOCK_SIZE))).collect();
482
483 let mut events: Vec<BufferEvent> = Vec::with_capacity(liveness.len() * 2);
485 for (&id, info) in &liveness {
486 events.push(BufferEvent { timestep: info.first_appearance, is_alloc: true, buffer_id: id });
487 events.push(BufferEvent {
488 timestep: info.last_appearance + 1 + buf_hold.get(&id).copied().unwrap_or(0),
489 is_alloc: false,
490 buffer_id: id,
491 });
492 }
493 events.sort_by_key(|e| (e.timestep, e.is_alloc, e.buffer_id));
494
495 let total_bytes: usize = nbytes_rounded.values().sum();
498 let arena_budget = total_bytes.saturating_mul(2).max(MIN_BLOCK_SIZE);
499 let mut tlsfs: HashMap<LaneKey, tlsf::TlsfAllocator> = HashMap::new();
500 let mut offsets: HashMap<u64, usize> = HashMap::new();
501 let mut peaks: HashMap<LaneKey, usize> = HashMap::new();
502 let mut freed_ranges: HashMap<LaneKey, Vec<(usize, usize, usize)>> = HashMap::new();
507 let mut reuse_dependencies: Vec<ReuseDependency> = Vec::new();
508
509 for event in &events {
510 let lane = lane_key(event.buffer_id);
511 let info = &liveness[&event.buffer_id];
512 let alloc =
513 tlsfs.entry(lane.clone()).or_insert_with(|| tlsf::TlsfAllocator::new(arena_budget, 0, MIN_BLOCK_SIZE, 32));
514 if event.is_alloc {
515 let req = nbytes_rounded[&event.buffer_id];
516 let off = match alloc.alloc(req, 1) {
517 Ok(o) => o,
518 Err(e) => {
519 tracing::warn!(?e, "arena planner: TLSF alloc failed; skipping arena rewrite");
520 return empty_result();
521 }
522 };
523 offsets.insert(event.buffer_id, off);
524 let used_end = off + info.prototype.size();
526 let peak = peaks.entry(lane.clone()).or_insert(0);
527 if used_end > *peak {
528 *peak = used_end;
529 }
530 let alloc_end = off + req;
533 if let Some(ranges) = freed_ranges.get(&lane) {
534 for &(prev_off, prev_end, prev_last_step) in ranges {
535 let overlaps = off < prev_end && prev_off < alloc_end;
536 if overlaps {
537 reuse_dependencies.push(ReuseDependency {
538 predecessor_step: prev_last_step,
539 successor_step: info.first_appearance,
540 });
541 }
542 }
543 }
544 if let Some(ranges) = freed_ranges.get_mut(&lane) {
551 ranges.retain(|&(o, e, _)| o >= alloc_end || e <= off);
552 }
553 } else if let Some(off) = offsets.get(&event.buffer_id).copied() {
554 let req = nbytes_rounded[&event.buffer_id];
555 if let Err(e) = alloc.free(off) {
556 tracing::warn!(?e, "arena planner: TLSF free failed; skipping arena rewrite");
557 return empty_result();
558 }
559 freed_ranges.entry(lane).or_default().push((off, off + req, info.last_appearance));
560 }
561 }
562
563 let mut lane_proto: HashMap<LaneKey, Buffer> = HashMap::with_capacity(peaks.len());
566 for (&id, info) in &liveness {
567 lane_proto.entry(lane_key(id)).or_insert_with(|| info.prototype.clone());
568 }
569 let mut arenas: HashMap<LaneKey, Buffer> = HashMap::new();
570 for (lane, &peak) in &peaks {
571 if peak == 0 {
572 continue;
573 }
574 let arena_size = round_up(peak, MIN_BLOCK_SIZE);
575 let prototype = lane_proto.get(lane).expect("every populated lane must have a prototype");
576 let arena = Buffer::new(
577 prototype.allocator_arc(),
578 svod_dtype::DType::UInt8,
579 vec![arena_size],
580 svod_device::allocator::BufferOptions::default(),
581 );
582 arenas.insert(lane.clone(), arena);
583 }
584
585 let mut buffer_replace: HashMap<BufferKey, Buffer> = HashMap::new();
590 let mut buffers_reused = 0usize;
591 for (key, buf_id) in &planner_input.occurrences {
592 let Some(&offset) = offsets.get(buf_id) else {
593 continue;
594 };
595 let Some(arena) = arenas.get(&lane_key(*buf_id)) else {
596 continue;
597 };
598 let info = &liveness[buf_id];
599 let byte_size = info.prototype.size();
600 let view = match arena.view(offset, byte_size) {
601 Ok(v) => v,
602 Err(e) => {
603 tracing::warn!(?e, "arena planner: view failed; skipping rewrite for one slot");
604 continue;
605 }
606 };
607 buffer_replace.insert(*key, view);
608 buffers_reused += 1;
609 }
610
611 let arena_total: usize = peaks.values().map(|&p| round_up(p, MIN_BLOCK_SIZE)).sum();
612 let memory_saved = total_bytes.saturating_sub(arena_total);
613
614 debug!(
615 buffers_planned = liveness.len(),
616 buffers_replaced = buffers_reused,
617 memory_saved_bytes = memory_saved,
618 arena_count = arenas.len(),
619 "arena memory planner complete"
620 );
621
622 MemoryPlannerResult { buffer_replace, memory_saved, buffers_reused, reuse_dependencies }
623}
624
625#[allow(rustdoc::private_intra_doc_links)]
647pub fn memory_planner(schedule: &Schedule, output_buffer_ids: &HashSet<u64>, mode: PlannerMode) -> MemoryPlannerResult {
648 let empty_result = || MemoryPlannerResult {
649 buffer_replace: HashMap::new(),
650 memory_saved: 0,
651 buffers_reused: 0,
652 reuse_dependencies: Vec::new(),
653 };
654
655 if matches!(mode, PlannerMode::Disabled) {
656 return empty_result();
657 }
658
659 if schedule.is_empty() {
660 return empty_result();
661 }
662
663 if matches!(mode, PlannerMode::Arena) {
664 return memory_plan_arena(schedule, output_buffer_ids);
665 }
666
667 let planner_input = analyze_liveness(schedule, output_buffer_ids);
669 let liveness = planner_input.liveness;
670
671 if liveness.is_empty() {
672 debug!("no optimizable buffers found");
673 return empty_result();
674 }
675
676 let events = build_event_timeline(&liveness);
678
679 let (buffer_replace, memory_saved, buffers_reused, reuse_dependencies) =
681 process_events(&events, &liveness, &planner_input.occurrences);
682
683 debug!(
684 buffers_analyzed = liveness.len(),
685 buffers_reused,
686 memory_saved_bytes = memory_saved,
687 "memory planner complete"
688 );
689
690 MemoryPlannerResult { buffer_replace, memory_saved, buffers_reused, reuse_dependencies }
691}
692
693pub fn apply_buffer_replacements(schedule: &mut Schedule, replacements: &HashMap<BufferKey, Buffer>) {
698 for (&key, replacement) in replacements {
699 if let Some(item) = schedule.get_mut(key.kernel_idx)
700 && let Some(buffer) = item.buffers.get_mut(key.buffer_idx)
701 {
702 *buffer = replacement.clone();
703 }
704 }
705}
706
707pub fn apply_reuse_dependencies(schedule: &mut Schedule, reuse_dependencies: &[ReuseDependency]) {
709 for dep in reuse_dependencies {
710 if dep.predecessor_step == dep.successor_step {
711 continue;
712 }
713
714 debug_assert!(
715 dep.successor_step > dep.predecessor_step,
716 "reuse dependency must be forward-edge: predecessor={} >= successor={}",
717 dep.predecessor_step,
718 dep.successor_step,
719 );
720
721 if dep.predecessor_step >= schedule.len() {
722 continue;
723 }
724 let Some(successor) = schedule.get_mut(dep.successor_step) else {
725 continue;
726 };
727 if !successor.instance_dependencies.contains(&dep.predecessor_step) {
728 successor.instance_dependencies.push(dep.predecessor_step);
729 }
730 }
731}
732
733#[cfg(test)]
738#[path = "../test/unit/memory_planner.rs"]
739mod tests;