Skip to main content

svod_tensor/
schedule.rs

1//! Callable scheduling types and execution.
2//!
3//! This module provides types and functions for managing the execution
4//! schedule of tensor operations. After the rangeify pipeline transforms
5//! the computation graph into callable operations (`CALL`), we need to:
6//!
7//! 1. Extract callable operations from the transformed graph
8//! 2. Allocate buffers for intermediate results (PARAM/DEFINE_LOCAL)
9//! 3. Execute callables in dependency order
10//!
11//! The scheduling process converts from lazy tensor operations to
12//! executable callables with properly allocated device buffers.
13
14use std::collections::{HashMap, HashSet, VecDeque};
15use std::sync::Arc;
16
17use svod_device::Buffer;
18use svod_device::device::Device;
19use svod_device::registry;
20use svod_dtype::{DType, DeviceSpec};
21use svod_ir::{Op, UOp};
22use tracing::{debug, trace};
23
24use crate::error::*;
25use crate::{Error, Result};
26use snafu::ResultExt;
27
28fn canonicalize_callable_source(src: &Arc<UOp>) -> Arc<UOp> {
29    let mut cur = src.clone();
30    loop {
31        match cur.op() {
32            Op::After { .. }
33            | Op::Buffer { .. }
34            | Op::Param { .. }
35            | Op::MSelect { .. }
36            | Op::MStack { .. }
37            | Op::Bind { .. } => return cur,
38            _ => {
39                let sources = cur.op().sources();
40                let Some(next) = sources.first() else {
41                    return cur;
42                };
43                if Arc::ptr_eq(&cur, next) {
44                    return cur;
45                }
46                cur = (*next).clone();
47            }
48        }
49    }
50}
51
52fn source_primary_buffer_id(src: &Arc<UOp>) -> Option<u64> {
53    let src = canonicalize_callable_source(src);
54    match src.op() {
55        Op::Buffer { .. } | Op::Param { .. } | Op::After { .. } => Some(src.buf_uop().id),
56        Op::Bind { .. } => None,
57        Op::MSelect { buffer, device_index } => {
58            if let Op::MStack { buffers } = buffer.op() {
59                buffers.get(*device_index).map(|b| b.buf_uop().id).or_else(|| Some(src.buf_uop().id))
60            } else {
61                Some(src.buf_uop().id)
62            }
63        }
64        Op::MStack { buffers } => buffers.first().map(|b| b.buf_uop().id),
65        _ => None,
66    }
67}
68
69fn collect_callable_dep_ids(dep: &Arc<UOp>, out: &mut HashSet<u64>) -> Result<()> {
70    match dep.op() {
71        Op::Call { .. } => {
72            out.insert(dep.id);
73            Ok(())
74        }
75        Op::End { computation, .. } => {
76            if matches!(computation.op(), Op::Call { .. }) {
77                out.insert(computation.id);
78                Ok(())
79            } else {
80                IrConstructionSnafu {
81                    details: format!("AFTER dependency END must wrap CALL, got {:?}", computation.op()),
82                }
83                .fail()
84            }
85        }
86        Op::Store { .. } => Ok(()),
87        Op::After { deps, .. } => {
88            for nested in deps {
89                collect_callable_dep_ids(nested, out)?;
90            }
91            Ok(())
92        }
93        other => IrConstructionSnafu {
94            details: format!("AFTER dependency must be CALL/END(CALL)/STORE/AFTER, got {other:?}"),
95        }
96        .fail(),
97    }
98}
99
100type AfterDependencySplit = (Vec<Arc<UOp>>, Vec<Arc<UOp>>);
101
102fn split_after_dependencies(after: &Arc<UOp>) -> Result<AfterDependencySplit> {
103    let Op::After { deps, .. } = after.op() else {
104        return IrConstructionSnafu {
105            details: format!("expected AFTER when splitting dependencies, got {:?}", after.op()),
106        }
107        .fail();
108    };
109
110    let mut kernels = Vec::new();
111    let mut after_deps = Vec::new();
112    for dep in deps {
113        match dep.op() {
114            Op::Call { .. } => kernels.push(dep.clone()),
115            Op::End { computation, .. } if matches!(computation.op(), Op::Call { .. }) => kernels.push(dep.clone()),
116            Op::After { .. } => after_deps.push(dep.clone()),
117            Op::Store { .. } => {}
118            other => {
119                return IrConstructionSnafu {
120                    details: format!("AFTER dependency must be CALL/END(CALL)/STORE/AFTER, got {other:?}"),
121                }
122                .fail();
123            }
124        }
125    }
126
127    Ok((kernels, after_deps))
128}
129
130fn collect_source_dependency_callable_ids(src: &Arc<UOp>, out: &mut HashSet<u64>) -> Result<()> {
131    let src = canonicalize_callable_source(src);
132    match src.op() {
133        Op::After { .. } => {
134            let (kernels, after_deps) = split_after_dependencies(&src)?;
135            for kernel in kernels {
136                collect_callable_dep_ids(&kernel, out)?;
137            }
138            for dep in after_deps {
139                collect_source_dependency_callable_ids(&dep, out)?;
140            }
141            Ok(())
142        }
143        Op::MStack { buffers } => {
144            for buffer in buffers {
145                collect_source_dependency_callable_ids(buffer, out)?;
146            }
147            Ok(())
148        }
149        Op::MSelect { buffer, .. } => collect_source_dependency_callable_ids(buffer, out),
150        Op::Buffer { .. } | Op::Param { .. } | Op::Bind { .. } => Ok(()),
151        other => IrConstructionSnafu {
152            details: format!("input to callable must resolve to AFTER/BUFFER/PARAM/MSELECT/MSTACK/BIND, got {other:?}"),
153        }
154        .fail(),
155    }
156}
157
158fn callable_sources(callable: &Arc<UOp>) -> Option<Vec<Arc<UOp>>> {
159    match callable.op() {
160        Op::Call { args, .. } => Some(args.iter().cloned().collect()),
161        _ => None,
162    }
163}
164
165/// Schedule-level RANGEs are uniquely identified by being paired with an
166/// `END(Call)` in the transformed graph — the END structurally proves the
167/// Range wraps a kernel call. Standalone `Bind(DefineVar, Range)` arguments
168/// (user-supplied symbolic variable binds) reference Ranges with no END
169/// pairing and must be skipped to avoid being mistaken for loop wrappers.
170fn collect_scheduled_range_ids(root: &Arc<UOp>, callable_ids: &HashSet<u64>) -> HashSet<u64> {
171    let mut ids = HashSet::new();
172    for node in root.toposort_call_aware(false) {
173        let Op::End { computation, ranges } = node.op() else { continue };
174        if !matches!(computation.op(), Op::Call { .. }) || !callable_ids.contains(&computation.id) {
175            continue;
176        }
177        for r in ranges {
178            if matches!(r.op(), Op::Range { .. }) {
179                ids.insert(r.id);
180            }
181        }
182    }
183    ids
184}
185
186fn collect_call_bound_ranges(callable: &Arc<UOp>, scheduled_range_ids: &HashSet<u64>) -> Result<Vec<BoundRangeRef>> {
187    let Op::Call { args, .. } = callable.op() else {
188        return ExpectedCallableOpSnafu.fail();
189    };
190
191    let mut bound_ranges = Vec::new();
192    for arg in args {
193        let Op::Bind { var, value } = arg.op() else {
194            continue;
195        };
196        let Op::DefineVar { name, .. } = var.op() else {
197            return IrConstructionSnafu {
198                details: format!("CALL BIND source must wrap DEFINE_VAR, got {:?}", var.op()),
199            }
200            .fail();
201        };
202        let Op::Range { .. } = value.op() else {
203            // User variable binds (`BIND(DEFINE_VAR, CONST)`) are not schedule loops.
204            continue;
205        };
206        // Only Ranges paired with an `END(Call)` are schedule-level wrappers;
207        // standalone Range-valued binds carry runtime values, not loop counters.
208        if !scheduled_range_ids.contains(&value.id) {
209            continue;
210        }
211        bound_ranges.push(BoundRangeRef { var_name: name.clone(), range_uop: value.clone() });
212    }
213    Ok(bound_ranges)
214}
215
216fn collect_linear_sched_ops_internal(
217    root: &Arc<UOp>,
218    callable_ids: &HashSet<u64>,
219    scheduled_range_ids: &HashSet<u64>,
220) -> Result<Vec<LinearSchedOp>> {
221    let mut linear_ops = Vec::new();
222
223    for node in root.toposort_call_aware(false) {
224        match node.op() {
225            Op::Range { .. } if scheduled_range_ids.contains(&node.id) => {
226                linear_ops.push(LinearSchedOp::Range { range: node.clone() });
227            }
228            Op::Call { .. } if callable_ids.contains(&node.id) => {
229                linear_ops.push(LinearSchedOp::Call { kernel_id: node.id });
230            }
231            Op::End { computation, ranges } if matches!(computation.op(), Op::Call { .. }) => {
232                if !callable_ids.contains(&computation.id) {
233                    continue;
234                }
235                let wrapper_ranges: Vec<Arc<UOp>> =
236                    ranges.iter().filter(|r| matches!(r.op(), Op::Range { .. })).cloned().collect();
237                match wrapper_ranges.as_slice() {
238                    [] => {}
239                    [outer] => linear_ops.push(LinearSchedOp::End { range: outer.clone(), kernel_id: computation.id }),
240                    _ => {
241                        return IrConstructionSnafu {
242                            details: format!(
243                                "END(CALL) must close at most one wrapper range in strict scheduler, got {}",
244                                wrapper_ranges.len()
245                            ),
246                        }
247                        .fail();
248                    }
249                }
250            }
251            _ => {}
252        }
253    }
254
255    if linear_ops.is_empty() {
256        return IrConstructionSnafu { details: "strict scheduler produced empty linear control stream".to_string() }
257            .fail();
258    }
259    Ok(linear_ops)
260}
261
262/// Eagerly unroll the schedule control stream into a flat list of kernel
263/// invocations.
264///
265/// Every outer loop iteration produces one invocation per kernel inside it,
266/// with concrete `fixedvars` derived from the loop counters at that point.
267/// Outer ranges must have concrete `vmin`/`vmax` (validated by
268/// `schedule_range_bounds`) — there is no symbolic-iteration support today.
269fn collect_kernel_invocations(
270    root: &Arc<UOp>,
271    items: &[PreScheduleItem],
272    scheduled_range_ids: &HashSet<u64>,
273) -> Result<Vec<KernelInvocation>> {
274    let callable_ids: HashSet<u64> = items.iter().map(|it| it.kernel.id).collect();
275    let linear_ops = collect_linear_sched_ops_internal(root, &callable_ids, scheduled_range_ids)?;
276
277    let bound_ranges_by_kernel: HashMap<u64, &[BoundRangeRef]> =
278        items.iter().map(|it| (it.kernel.id, it.bound_ranges.as_slice())).collect();
279
280    // Pre-validation: every declared Range must have a matching End, and every
281    // bound_range on every kernel must reference a declared Range.
282    let mut declared_ranges: HashSet<u64> = HashSet::new();
283    let mut ended_ranges: HashSet<u64> = HashSet::new();
284    for op in &linear_ops {
285        match op {
286            LinearSchedOp::Range { range } => {
287                declared_ranges.insert(range.id);
288            }
289            LinearSchedOp::End { range, .. } => {
290                ended_ranges.insert(range.id);
291            }
292            LinearSchedOp::Call { .. } => {}
293        }
294    }
295    for &rid in &declared_ranges {
296        if !ended_ranges.contains(&rid) {
297            return IrConstructionSnafu { details: format!("schedule range {rid} is missing END in strict scheduler") }
298                .fail();
299        }
300    }
301    for item in items {
302        for br in &item.bound_ranges {
303            if !declared_ranges.contains(&br.range_uop.id) {
304                return IrConstructionSnafu {
305                    details: format!(
306                        "CALL {} bound variable '{}' references schedule range {} missing from linear schedule",
307                        item.kernel.id, br.var_name, br.range_uop.id
308                    ),
309                }
310                .fail();
311            }
312        }
313    }
314
315    // Bytecode interpreter (range/end drives a counter, call emits an
316    // invocation). The output is the eagerly-unrolled invocation list.
317    let mut invocations = Vec::new();
318    let mut in_ranges: HashMap<u64, i64> = HashMap::new();
319    let mut range_ptrs: HashMap<u64, usize> = HashMap::new();
320    let mut range_bounds: HashMap<u64, (i64, i64)> = HashMap::new();
321
322    let mut sched_ptr = 0usize;
323    while sched_ptr < linear_ops.len() {
324        match &linear_ops[sched_ptr] {
325            LinearSchedOp::Range { range } => {
326                let bounds = if let Some(bounds) = range_bounds.get(&range.id).copied() {
327                    bounds
328                } else {
329                    let bounds = schedule_range_bounds(range)?;
330                    range_bounds.insert(range.id, bounds);
331                    bounds
332                };
333                in_ranges.insert(range.id, bounds.0);
334                range_ptrs.insert(range.id, sched_ptr + 1);
335            }
336            LinearSchedOp::End { range, kernel_id } => {
337                if !bound_ranges_by_kernel.contains_key(kernel_id) {
338                    return IrConstructionSnafu {
339                        details: format!("linear END references unknown CALL id {kernel_id}"),
340                    }
341                    .fail();
342                }
343                let (_, vmax) = if let Some(bounds) = range_bounds.get(&range.id).copied() {
344                    bounds
345                } else {
346                    let bounds = schedule_range_bounds(range)?;
347                    range_bounds.insert(range.id, bounds);
348                    bounds
349                };
350                let Some(cur) = in_ranges.get_mut(&range.id) else {
351                    return IrConstructionSnafu {
352                        details: format!("END references schedule range {} that is not active", range.id),
353                    }
354                    .fail();
355                };
356                if *cur < vmax {
357                    *cur += 1;
358                    let Some(jump_ptr) = range_ptrs.get(&range.id).copied() else {
359                        return IrConstructionSnafu {
360                            details: format!("missing loop jump pointer for schedule range {}", range.id),
361                        }
362                        .fail();
363                    };
364                    sched_ptr = jump_ptr;
365                    continue;
366                }
367            }
368            LinearSchedOp::Call { kernel_id } => {
369                let Some(bound_ranges) = bound_ranges_by_kernel.get(kernel_id) else {
370                    return IrConstructionSnafu {
371                        details: format!("linear CALL references unknown kernel id {kernel_id}"),
372                    }
373                    .fail();
374                };
375                let mut fixedvars = HashMap::new();
376                for br in *bound_ranges {
377                    let Some(value) = in_ranges.get(&br.range_uop.id).copied() else {
378                        return IrConstructionSnafu {
379                            details: format!(
380                                "CALL {} bound variable '{}' references inactive schedule range {}",
381                                kernel_id, br.var_name, br.range_uop.id
382                            ),
383                        }
384                        .fail();
385                    };
386                    fixedvars.insert(br.var_name.clone(), value);
387                }
388                invocations.push(KernelInvocation { kernel_id: *kernel_id, fixedvars });
389            }
390        }
391        sched_ptr += 1;
392    }
393
394    Ok(invocations)
395}
396
397fn analyze_callable_dependencies(callables: &[Arc<UOp>], root: &Arc<UOp>) -> Result<Vec<HashSet<usize>>> {
398    // Build callable ID → index mapping
399    let callable_idx: HashMap<u64, usize> = callables.iter().enumerate().map(|(i, c)| (c.id, i)).collect();
400    // Build dependency edges from source-driven dependency extraction
401    // (avoids global writer-union heuristics).
402    let mut dependencies: Vec<HashSet<usize>> = vec![HashSet::new(); callables.len()];
403
404    for (consumer_idx, callable) in callables.iter().enumerate() {
405        let mut dep_ids = HashSet::new();
406        if let Some(sources) = callable_sources(callable) {
407            for src in sources {
408                collect_source_dependency_callable_ids(&src, &mut dep_ids)?;
409            }
410        }
411
412        for dep_id in dep_ids {
413            let Some(&producer_idx) = callable_idx.get(&dep_id) else {
414                return IrConstructionSnafu {
415                    details: format!("callable dependency references unknown callable id {dep_id}"),
416                }
417                .fail();
418            };
419            if producer_idx != consumer_idx {
420                dependencies[consumer_idx].insert(producer_idx);
421            }
422        }
423    }
424
425    // Preserve ordering-only dependencies encoded through AFTER surfaces that
426    // may include void/custom callables with no direct buffer edge.
427    for node in root.toposort() {
428        let Op::After { .. } = node.op() else {
429            continue;
430        };
431
432        let (kernels, after_deps) = split_after_dependencies(&node)?;
433        for kernel in kernels {
434            let callable = match kernel.op() {
435                Op::Call { .. } => kernel.clone(),
436                Op::End { computation, .. } => computation.clone(),
437                _ => unreachable!("split_after_dependencies only returns CALL/END(CALL) kernels"),
438            };
439
440            let Some(&consumer_idx) = callable_idx.get(&callable.id) else {
441                return IrConstructionSnafu {
442                    details: format!("AFTER dependency references unknown callable id {}", callable.id),
443                }
444                .fail();
445            };
446
447            let mut dep_ids = HashSet::new();
448            for dep in &after_deps {
449                collect_source_dependency_callable_ids(dep, &mut dep_ids)?;
450            }
451
452            for dep_id in dep_ids {
453                let Some(&producer_idx) = callable_idx.get(&dep_id) else {
454                    return IrConstructionSnafu {
455                        details: format!("callable dependency references unknown callable id {dep_id}"),
456                    }
457                    .fail();
458                };
459                if producer_idx != consumer_idx {
460                    dependencies[consumer_idx].insert(producer_idx);
461                }
462            }
463        }
464    }
465
466    Ok(dependencies)
467}
468
469/// Input buffers collected before schedule creation.
470///
471/// Maps BUFFER UOp ID → Buffer for input tensors.
472/// This keeps schedule instantiation explicit and avoids global lookups.
473pub type InputBuffers = HashMap<u64, Buffer>;
474
475/// Schedule-level Range bound to a DEFINE_VAR via a CALL argument's
476/// `Bind(DefineVar, Range)`. Each invocation of the wrapped CALL substitutes
477/// the loop counter value into the kernel's variable.
478#[derive(Clone, Debug)]
479pub struct BoundRangeRef {
480    /// Variable name (e.g., "range_0")
481    pub var_name: String,
482    /// RANGE UOp for this bound variable.
483    pub range_uop: Arc<UOp>,
484}
485
486/// Linearized scheduling control op (internal to schedule construction).
487///
488/// The strict scheduler walks these as a small bytecode (Range/End drive a
489/// loop counter, Call emits an invocation) — see `collect_kernel_invocations`.
490/// Eager unrolling at pre-schedule time turns them into a flat
491/// `Vec<KernelInvocation>`.
492#[derive(Clone, Debug)]
493enum LinearSchedOp {
494    Range { range: Arc<UOp> },
495    Call { kernel_id: u64 },
496    End { range: Arc<UOp>, kernel_id: u64 },
497}
498
/// A single concrete kernel call: which kernel runs, and with which
/// loop-variable values.
///
/// The pre-schedule stores a flat list of these, produced by eagerly unrolling
/// schedule loops in `create_pre_schedule`. Each entry is one atomic kernel
/// CALL with concrete bindings.
#[derive(Debug, Clone)]
pub struct KernelInvocation {
    /// Id of the CALL; matches the `kernel.id` of some `PreScheduleItem`.
    pub kernel_id: u64,
    /// `var_name -> value` bindings taken from the enclosing loop counters at
    /// the moment of this invocation.
    pub fixedvars: HashMap<String, i64>,
}
512
513/// A single executable callable with its buffers and variable bindings.
514///
515/// Each ScheduleItem represents one callable that needs to be compiled
516/// and executed. The callable AST contains STORE operations that write
517/// results to buffers.
518///
519/// Schedule items are fully expanded during schedule instantiation.
520#[derive(Clone)]
521pub struct ScheduleItem {
522    /// The callable wrapper UOp (`CALL`) used for dependency identity.
523    pub kernel: Arc<UOp>,
524
525    /// The inner callable AST (typically SINK containing STORE ops) - for codegen
526    pub ast: Arc<UOp>,
527
528    /// Device buffers for this callable (in order expected by codegen)
529    pub buffers: Vec<Buffer>,
530
531    /// UOp IDs under which each buffer was registered in buffer index.
532    /// Same length as `buffers`. Used for cleanup - to remove buffers from
533    /// the global registry, we need to know what key they were registered under.
534    pub buffer_uop_ids: Vec<u64>,
535
536    /// Fixed variable values for this specific kernel invocation.
537    /// Maps variable name (e.g., "range_0") to concrete i64 value.
538    /// Always concrete in the strict scheduler path.
539    pub fixedvars: HashMap<String, i64>,
540
541    /// Names of variables in `fixedvars` whose values came from schedule-loop
542    /// counters. User `var_vals` must not override these — see
543    /// `collect_non_overridable_fixedvars`.
544    pub loop_var_names: HashSet<String>,
545
546    /// Callable UOp IDs that must complete before this item can execute.
547    /// Empty for callables without dependencies (first in chain or independent).
548    /// Dependencies are implicit in scheduling order after topological sort.
549    pub dependencies: Vec<u64>,
550
551    /// Concrete schedule-item indices that must complete before this item.
552    /// Used for ordering constraints that cannot be represented by callable ID
553    /// after strict unrolling creates repeated callable IDs.
554    pub instance_dependencies: Vec<usize>,
555
556    /// Additional UOp IDs registered as aliases in buffer index.
557    /// These are IDs where the same buffer was registered under a different key
558    /// for lookup convenience. They need to be cleaned up along with buffer_uop_ids.
559    pub alias_registered_ids: Vec<u64>,
560}
561
562/// Full execution schedule (list of callables in dependency order).
563pub type Schedule = Vec<ScheduleItem>;
564
565/// Cached pre-schedule item.
566///
567/// Contains callable identity/AST and argument UOps, but no concrete buffers.
568#[derive(Clone)]
569pub struct PreScheduleItem {
570    /// Callable wrapper UOp (`CALL`) used for dependency identity.
571    pub kernel: Arc<UOp>,
572    /// Callable body AST used for codegen.
573    pub ast: Arc<UOp>,
574    /// Callable argument UOps in canonical order.
575    pub sources: Vec<Arc<UOp>>,
576    /// Callable dependencies by callable UOp ID.
577    pub dependencies: Vec<u64>,
578    /// Schedule-level Range bindings (`BIND(DEFINE_VAR, RANGE)`) from CALL args.
579    pub bound_ranges: Vec<BoundRangeRef>,
580}
581
582/// Cached pre-schedule artifact.
583///
584/// A flat list of kernel invocations with their concrete bindings. Outer
585/// loops are eagerly unrolled at construction time, so there is no
586/// schedule-level Range/End bytecode — just one entry per kernel call.
587#[derive(Clone)]
588pub struct PreSchedule {
589    /// Per-kernel descriptor pool indexed by `kernel.id` from `KernelInvocation`.
590    pub items: Vec<PreScheduleItem>,
591    /// Flat sequence of kernel invocations after eager loop unrolling.
592    pub invocations: Vec<KernelInvocation>,
593    /// Output buffers in sink source order.
594    pub output_buffer_uops: Vec<Arc<UOp>>,
595}
596
597type SortedCallables = (Vec<Arc<UOp>>, HashMap<u64, Vec<u64>>);
598
599/// Result of schedule creation, including output buffer identification.
600pub struct ScheduleResult {
601    /// The schedule items in dependency order.
602    pub items: Schedule,
603    /// UOp IDs of output buffers, in SINK source order.
604    /// Extracted directly from the SINK's sources via `buf_uop()`.
605    /// For single-tensor realize, contains one ID.
606    pub output_uop_ids: Vec<u64>,
607}
608
609/// Buffers collected for a single callable.
610struct CallableBuffers {
611    /// Device buffers in codegen order.
612    buffers: Vec<Buffer>,
613    /// UOp IDs for each buffer.
614    uop_ids: Vec<u64>,
615    /// Additional alias IDs for cleanup.
616    alias_ids: Vec<u64>,
617}
618
619/// Sort callables by dependencies (producers before consumers).
620///
621/// Uses Kahn's algorithm for topological sort based on buffer dependencies
622/// derived from the graph structure (AFTER nodes and callable sources).
623/// This ensures producer callables are processed before consumers, which is
624/// critical for buffer sharing: the producer allocates the buffer first,
625/// then the consumer finds it in the registry via `get_or_create_buffer`.
626fn sort_callables_by_dependencies(callables: &[Arc<UOp>], root: &Arc<UOp>) -> Result<SortedCallables> {
627    debug!(num_callables = callables.len(), "sorting callables by dependencies");
628
629    let dependencies = analyze_callable_dependencies(callables, root)?;
630
631    // Kahn's algorithm for topological sort
632    let mut in_degree: Vec<usize> = dependencies.iter().map(|deps| deps.len()).collect();
633    let mut dependents: Vec<Vec<usize>> = vec![vec![]; callables.len()];
634
635    for (consumer, deps) in dependencies.iter().enumerate() {
636        for &producer in deps {
637            dependents[producer].push(consumer);
638        }
639    }
640
641    let mut queue: VecDeque<usize> =
642        in_degree.iter().enumerate().filter(|&(_, &deg)| deg == 0).map(|(idx, _)| idx).collect();
643
644    let mut sorted_indices = Vec::new();
645    while let Some(idx) = queue.pop_front() {
646        sorted_indices.push(idx);
647        for &dependent in &dependents[idx] {
648            in_degree[dependent] -= 1;
649            if in_degree[dependent] == 0 {
650                queue.push_back(dependent);
651            }
652        }
653    }
654
655    if sorted_indices.len() < callables.len() {
656        return DependencyCyclesSnafu.fail();
657    }
658
659    let sorted: Vec<Arc<UOp>> = sorted_indices.iter().map(|&idx| callables[idx].clone()).collect();
660
661    let dependency_ids_by_callable: HashMap<u64, Vec<u64>> = callables
662        .iter()
663        .enumerate()
664        .map(|(idx, callable)| {
665            let mut deps: Vec<u64> = dependencies[idx].iter().map(|&dep_idx| callables[dep_idx].id).collect();
666            deps.sort_unstable();
667            (callable.id, deps)
668        })
669        .collect();
670
671    debug!(num_sorted = sorted.len(), "callables sorted");
672
673    Ok((sorted, dependency_ids_by_callable))
674}
675
676/// Extract callables from transformed graph and create pre-schedule artifact.
677///
678/// This function walks the transformed UOp graph (after rangeify and
679/// kernel splitting) and extracts all callable wrappers. For each callable,
680/// it records callable identity, dependencies, and strict control-flow ops.
681///
682/// # Arguments
683///
684/// * `transformed` - The UOp graph after rangeify + kernel splitting
685///
686/// # Returns
687///
688/// A pre-schedule artifact ready for per-run instantiation.
689///
690/// # Errors
691///
692/// Returns error if:
693/// - No callables found after scheduling pipeline
694/// - Callable dependency graph contains cycles
695pub fn create_pre_schedule(transformed: Arc<UOp>) -> Result<PreSchedule> {
696    // Step 1: Find all callable wrappers (CALL) without descending into CALL bodies.
697    let mut callables = Vec::new();
698    for node in transformed.toposort_call_aware(false) {
699        if matches!(node.op(), Op::Call { .. }) {
700            callables.push(node);
701        }
702    }
703
704    if callables.is_empty() {
705        return NoKernelsFoundSnafu.fail();
706    }
707
708    // Step 1.5: Sort callables by dependencies (producers before consumers)
709    let (callables, dependency_ids_by_callable) = sort_callables_by_dependencies(&callables, &transformed)?;
710
711    // Step 1.75: Compute the set of Range UOp IDs paired with `END(Call)` —
712    // these are the schedule-level loop wrappers, identified structurally
713    // (no axis_type filter).
714    let callable_ids: HashSet<u64> = callables.iter().map(|c| c.id).collect();
715    let scheduled_range_ids = collect_scheduled_range_ids(&transformed, &callable_ids);
716
717    // Step 2: Build pre-schedule items (AST + sources + dependencies + bound ranges).
718    let mut items = Vec::with_capacity(callables.len());
719    for callable_uop in callables {
720        let Op::Call { body, args, .. } = callable_uop.op() else {
721            unreachable!("filtered to only call wrappers above")
722        };
723        let dependencies = dependency_ids_by_callable.get(&callable_uop.id).cloned().unwrap_or_default();
724        let bound_ranges = collect_call_bound_ranges(&callable_uop, &scheduled_range_ids)?;
725        items.push(PreScheduleItem {
726            kernel: callable_uop.clone(),
727            ast: body.clone(),
728            sources: args.iter().cloned().collect(),
729            dependencies,
730            bound_ranges,
731        });
732    }
733
734    // Step 3: Eagerly unroll outer loops into a flat list of kernel
735    // invocations — each invocation carries the concrete `fixedvars` produced
736    // by its enclosing loop counters.
737    let invocations = collect_kernel_invocations(&transformed, &items, &scheduled_range_ids)?;
738
739    // Output buffers in SINK source order.
740    let output_buffer_uops: Vec<Arc<UOp>> = match transformed.op() {
741        Op::Sink { sources, .. } => sources.iter().map(|src| src.buf_uop()).collect(),
742        _ => vec![transformed.buf_uop()],
743    };
744
745    Ok(PreSchedule { items, invocations, output_buffer_uops })
746}
747
748/// Instantiate a concrete execution schedule from a pre-schedule artifact.
749///
750/// This is the per-run phase that attaches concrete buffers and runtime
751/// variable bindings to cached callable descriptors.
752pub fn instantiate_schedule(
753    pre_schedule: &PreSchedule,
754    input_buffers: &InputBuffers,
755    var_vals: &HashMap<String, i64>,
756) -> Result<ScheduleResult> {
757    // Track allocated intermediate buffers locally (no global registry needed)
758    let mut allocated_buffers: HashMap<u64, Buffer> = HashMap::new();
759
760    let mut templates: HashMap<u64, ScheduleItemTemplate> = HashMap::with_capacity(pre_schedule.items.len());
761    for item in &pre_schedule.items {
762        let nodes = item.ast.toposort();
763
764        // Map sources to actual Buffers.
765        let kb = collect_callable_buffers(&item.sources, &item.ast, input_buffers, &mut allocated_buffers)?;
766
767        debug!(callable.id = item.kernel.id, num_sources = item.sources.len(), "Schedule item created");
768
769        // Populate fixedvars with only the user Variables referenced by this kernel's AST.
770        let fixedvars: HashMap<String, i64> = if var_vals.is_empty() {
771            HashMap::new()
772        } else {
773            let ast_var_names: HashSet<&str> = nodes
774                .iter()
775                .filter_map(|n| match n.op() {
776                    Op::DefineVar { name, .. } => Some(name.as_str()),
777                    _ => None,
778                })
779                .collect();
780            var_vals
781                .iter()
782                .filter(|(name, _)| ast_var_names.contains(name.as_str()))
783                .map(|(k, v)| (k.clone(), *v))
784                .collect()
785        };
786
787        templates.insert(
788            item.kernel.id,
789            ScheduleItemTemplate {
790                kernel: item.kernel.clone(),
791                ast: item.ast.clone(),
792                buffers: kb.buffers,
793                buffer_uop_ids: kb.uop_ids,
794                dependencies: item.dependencies.clone(),
795                alias_registered_ids: kb.alias_ids,
796                base_fixedvars: fixedvars,
797            },
798        );
799    }
800
801    let mut schedule = Vec::with_capacity(pre_schedule.invocations.len());
802    for invocation in &pre_schedule.invocations {
803        let Some(template) = templates.get(&invocation.kernel_id) else {
804            return IrConstructionSnafu {
805                details: format!("invocation references unknown kernel id {}", invocation.kernel_id),
806            }
807            .fail();
808        };
809
810        // Merge the kernel's user-Variable bindings (`base_fixedvars`) with
811        // the loop-counter bindings produced at this iteration.
812        let mut fixedvars = template.base_fixedvars.clone();
813        fixedvars.extend(invocation.fixedvars.iter().map(|(k, v)| (k.clone(), *v)));
814        let loop_var_names: HashSet<String> = invocation.fixedvars.keys().cloned().collect();
815
816        schedule.push(ScheduleItem {
817            kernel: template.kernel.clone(),
818            ast: template.ast.clone(),
819            buffers: template.buffers.clone(),
820            buffer_uop_ids: template.buffer_uop_ids.clone(),
821            fixedvars,
822            loop_var_names,
823            dependencies: template.dependencies.clone(),
824            instance_dependencies: Vec::new(),
825            alias_registered_ids: template.alias_registered_ids.clone(),
826        });
827    }
828
829    if schedule.is_empty() {
830        return EmptyScheduleSnafu.fail();
831    }
832
833    let output_uop_ids: Vec<u64> = pre_schedule.output_buffer_uops.iter().map(|u| u.buf_uop().id).collect();
834    Ok(ScheduleResult { items: schedule, output_uop_ids })
835}
836
837pub fn create_schedule(
838    transformed: Arc<UOp>,
839    input_buffers: &InputBuffers,
840    var_vals: &HashMap<String, i64>,
841) -> Result<ScheduleResult> {
842    let pre = create_pre_schedule(transformed)?;
843    instantiate_schedule(&pre, input_buffers, var_vals)
844}
845
846/// Extract device from the first input buffer in callable sources.
847///
848/// The first buffer's device determines the device for codegen/compilation and
849/// output buffer allocation.
850///
851/// DISK buffers are skipped: a disk-resident input is never a viable execution
852/// device (no compiler), so the search continues to find a compute-capable
853/// source. When every source is on disk (e.g. fully-materialized parameter
854/// graphs that exist only as safetensors mmaps before realization), we fall
855/// back to CPU so the kernel still has somewhere to run; the runtime then
856/// arranges the disk→CPU copies via the normal Copy ops.
857fn find_first_input_buffer_device(
858    sources: &[Arc<UOp>],
859    input_buffers: &InputBuffers,
860    allocated_buffers: &HashMap<u64, Buffer>,
861) -> Result<Arc<Device>> {
862    let alloc_registry = registry::registry();
863
864    for src in sources {
865        if let Some(buf_id) = source_primary_buffer_id(src) {
866            let buffer = allocated_buffers.get(&buf_id).cloned().or_else(|| input_buffers.get(&buf_id).cloned());
867            if let Some(buffer) = buffer {
868                let device_spec = buffer.allocator().device_spec();
869                if device_spec.is_disk() {
870                    continue;
871                }
872                return svod_runtime::DEVICE_FACTORIES.device(&device_spec, alloc_registry).context(DeviceFactorySnafu);
873            }
874        }
875    }
876
877    // Fallback to CPU if no input buffers found
878    svod_runtime::DEVICE_FACTORIES.device(&DeviceSpec::Cpu, alloc_registry).context(DeviceFactorySnafu)
879}
880
/// Resolve every callable source to a concrete device buffer.
///
/// Each source is first canonicalized with `canonicalize_callable_source`.
/// When that yields a different node, the original source id is recorded as an
/// alias. The canonical node is then handled by op:
/// - `Op::After`: a buffer shared with a producer callable. The underlying
///   buffer id is `passthrough.buf_uop().id`, which walks AFTER chains. That id
///   must already exist in `allocated_buffers` or `input_buffers`, and it is
///   tracked in `allocated_buffers` if not already present.
/// - `Op::MSelect` / `Op::MStack`: multi-device sources. They are resolved to a
///   primary buffer id with `source_primary_buffer_id` and looked up exactly
///   like `Op::After`.
/// - `Op::DefineLocal`: a new buffer is always allocated. It is sized from the
///   node's `Ptr` dtype and tracked in `allocated_buffers` under the node's id.
/// - `Op::Buffer` / `Op::Param`: both are treated the same way. The buffer is
///   reused from `input_buffers`, else from `allocated_buffers`. Each map is
///   probed by the `buf_uop()` id first, then by the node's own id. If neither
///   map has it, a buffer of the declared `size` is allocated and tracked under
///   the `buf_uop()` id.
/// - `Op::Bind`: a variable binding, not a buffer; it is skipped.
///
/// Newly allocated buffers use the allocator of the device returned by
/// `find_first_input_buffer_device`. That is the first non-DISK source
/// buffer's device, with CPU as the fallback.
///
/// Returns three lists:
/// - the buffers in source order, with `Bind` sources omitted;
/// - the parallel list of buffer UOp ids they are keyed by;
/// - the sorted, deduplicated ids of source/canonical nodes that were resolved
///   to a different id.
///
/// # Errors
///
/// - `Error::BufferNotFound` when an AFTER/MSELECT/MSTACK source's buffer is
///   not yet known.
/// - An IR-construction error when an MSELECT/MSTACK source has no primary
///   buffer id, or for any other unsupported source op.
/// - Ptr-dtype errors from sizing a DEFINE_LOCAL.
/// - Device-factory errors from selecting the target device.
fn collect_callable_buffers(
    sources: &[Arc<UOp>],
    ast: &Arc<UOp>,
    input_buffers: &InputBuffers,
    allocated_buffers: &mut HashMap<u64, Buffer>,
) -> Result<CallableBuffers> {
    // Device for any buffers allocated below: the first non-DISK source
    // buffer's device, or CPU if there is none.
    let target_device = find_first_input_buffer_device(sources, input_buffers, allocated_buffers)?;

    let mut buffers = Vec::new();
    let mut uop_ids = Vec::new();
    let mut alias_ids = Vec::new();

    for src in sources {
        let canonical_src = canonicalize_callable_source(src);
        if canonical_src.id != src.id {
            alias_ids.push(src.id);
        }

        match canonical_src.op() {
            Op::After { passthrough, .. } => {
                // Shared buffer from producer kernel.
                // Use buf_uop() to get underlying buffer ID (handles AFTER chains).
                let buf_id = passthrough.buf_uop().id;
                if buf_id != canonical_src.id {
                    alias_ids.push(canonical_src.id);
                }

                // Look up from allocated_buffers first, then input_buffers.
                let existing = allocated_buffers.get(&buf_id).cloned().or_else(|| input_buffers.get(&buf_id).cloned());

                if let Some(buffer) = existing {
                    trace!(
                        buf_id,
                        buffer.id = ?buffer.id(),
                        "Found shared buffer from AFTER"
                    );

                    // Track under buf_id if not already tracked (this also
                    // copies input buffers into allocated_buffers).
                    allocated_buffers.entry(buf_id).or_insert_with(|| buffer.clone());

                    buffers.push(buffer);
                    uop_ids.push(buf_id);
                } else {
                    trace!(buf_id, "after buffer not found in allocated_buffers or input_buffers");
                    return Err(Error::BufferNotFound { uop_id: buf_id });
                }
            }
            Op::MSelect { .. } | Op::MStack { .. } => {
                // Multi-device source: resolve to its primary buffer id, then
                // treat it like an AFTER (must already exist).
                let Some(canonical_id) = source_primary_buffer_id(&canonical_src) else {
                    return IrConstructionSnafu {
                        details: format!(
                            "multi-device callable source must resolve a primary buffer id: source_id={}, op={:?}",
                            canonical_src.id,
                            canonical_src.op()
                        ),
                    }
                    .fail();
                };
                if canonical_id != canonical_src.id {
                    alias_ids.push(canonical_src.id);
                }

                let existing =
                    allocated_buffers.get(&canonical_id).cloned().or_else(|| input_buffers.get(&canonical_id).cloned());

                if let Some(buffer) = existing {
                    trace!(canonical_id, buffer.id = ?buffer.id(), "Found shared buffer from MSELECT/MSTACK source");
                    allocated_buffers.entry(canonical_id).or_insert_with(|| buffer.clone());
                    buffers.push(buffer);
                    uop_ids.push(canonical_id);
                } else {
                    trace!(canonical_id, "multi-device source buffer not found in allocated_buffers or input_buffers");
                    return Err(Error::BufferNotFound { uop_id: canonical_id });
                }
            }
            // Callable args/sources are typically Buffer/Param/After/DefineLocal.
            Op::DefineLocal(_id) => {
                // Always allocate a fresh buffer for the DEFINE_LOCAL from the
                // target device's allocator.
                // NOTE(review): this is an ordinary device buffer, not
                // on-chip local/shared memory; confirm that is intended.
                let ptr_dtype = canonical_src.dtype();
                let size = compute_buffer_size(ast, &canonical_src)?;

                // Extract the base scalar dtype from the Ptr type.
                // compute_buffer_size above already rejects non-Ptr dtypes for
                // this same node, so the `other` arm is a defensive fallback.
                let scalar_dtype = match ptr_dtype {
                    svod_dtype::DType::Ptr { base, .. } => *base,
                    other => {
                        return ExpectedPtrDtypeSnafu { context: "DEFINE_LOCAL", actual: other.clone() }.fail();
                    }
                };

                let buffer =
                    Buffer::new(target_device.allocator.clone(), scalar_dtype.clone(), vec![size], Default::default());

                // Track in allocated_buffers under the node's own id (replaces
                // any previous entry for that id).
                allocated_buffers.insert(canonical_src.id, buffer.clone());

                buffers.push(buffer);
                uop_ids.push(canonical_src.id);
            }
            Op::Buffer { size, .. } | Op::Param { size, .. } => {
                let canonical_id = canonical_src.buf_uop().id;
                if canonical_id != canonical_src.id {
                    alias_ids.push(canonical_src.id);
                }

                // BUFFER and PARAM are handled identically: they may be caller
                // inputs, buffers allocated by an earlier item, or new outputs.
                // Try input_buffers first, then allocated_buffers, then allocate.
                // Each map is probed by canonical_id, then by the node's own id.
                if let Some(buffer) =
                    input_buffers.get(&canonical_id).cloned().or_else(|| input_buffers.get(&canonical_src.id).cloned())
                {
                    buffers.push(buffer);
                    uop_ids.push(canonical_id);
                } else if let Some(buffer) = allocated_buffers
                    .get(&canonical_id)
                    .cloned()
                    .or_else(|| allocated_buffers.get(&canonical_src.id).cloned())
                {
                    buffers.push(buffer);
                    uop_ids.push(canonical_id);
                } else {
                    // Not an input and not allocated earlier: allocate `size`
                    // elements of the node's dtype on the target device.
                    // NOTE(review): assumes BUFFER/PARAM dtype is the scalar
                    // element dtype (unlike DEFINE_LOCAL's Ptr); confirm.
                    trace!(src.id = canonical_src.id, canonical_id, size, "Allocating output BUFFER/PARAM");
                    let scalar_dtype = canonical_src.dtype();

                    let buffer = Buffer::new(
                        target_device.allocator.clone(),
                        scalar_dtype.clone(),
                        vec![*size],
                        Default::default(),
                    );

                    // Track under canonical_id so later items can find it.
                    allocated_buffers.insert(canonical_id, buffer.clone());
                    buffers.push(buffer);
                    uop_ids.push(canonical_id);
                }
            }
            Op::Bind { .. } => {
                // Variable binding - not a buffer, skip
                continue;
            }
            other => {
                return IrConstructionSnafu {
                    details: format!("unsupported callable source op for buffer collection: {other:?}"),
                }
                .fail();
            }
        }
    }

    alias_ids.sort_unstable();
    alias_ids.dedup();
    Ok(CallableBuffers { buffers, uop_ids, alias_ids })
}
1050
1051#[derive(Clone)]
1052struct ScheduleItemTemplate {
1053    kernel: Arc<UOp>,
1054    ast: Arc<UOp>,
1055    buffers: Vec<Buffer>,
1056    buffer_uop_ids: Vec<u64>,
1057    dependencies: Vec<u64>,
1058    alias_registered_ids: Vec<u64>,
1059    base_fixedvars: HashMap<String, i64>,
1060}
1061
1062fn schedule_range_bounds(range: &Arc<UOp>) -> Result<(i64, i64)> {
1063    let Op::Range { .. } = range.op() else {
1064        return IrConstructionSnafu {
1065            details: format!("expected RANGE for schedule loop control, got {:?}", range.op()),
1066        }
1067        .fail();
1068    };
1069
1070    let Some(vmin) = range.vmin().try_int() else {
1071        return IrConstructionSnafu {
1072            details: format!("schedule range vmin must be concrete integer, got {:?}", range.vmin()),
1073        }
1074        .fail();
1075    };
1076    let Some(vmax) = range.vmax().try_int() else {
1077        return IrConstructionSnafu {
1078            details: format!("schedule range vmax must be concrete integer, got {:?}", range.vmax()),
1079        }
1080        .fail();
1081    };
1082    if vmax < vmin {
1083        return IrConstructionSnafu { details: format!("invalid schedule range bounds: vmin={vmin}, vmax={vmax}") }
1084            .fail();
1085    }
1086    Ok((vmin, vmax))
1087}
1088
1089/// Compute buffer size from the buffer definition's dtype.
1090///
1091/// Buffer size is embedded in the Ptr dtype by debuf() during rangeify
1092/// (`dtype.ptr(size=...)`).
1093fn compute_buffer_size(_ast: &Arc<UOp>, buffer_def: &Arc<UOp>) -> Result<usize> {
1094    // Extract size from Ptr dtype (set by debuf() in split_patterns.rs)
1095    match buffer_def.dtype() {
1096        DType::Ptr { size: Some(s), .. } => Ok(s),
1097        DType::Ptr { size: None, .. } => BufferPtrNoSizeSnafu.fail(),
1098        other => ExpectedPtrDtypeSnafu { context: "buffer_size", actual: other.clone() }.fail(),
1099    }
1100}