1use std::collections::{HashMap, HashSet, VecDeque};
15use std::sync::Arc;
16
17use svod_device::Buffer;
18use svod_device::device::Device;
19use svod_device::registry;
20use svod_dtype::{DType, DeviceSpec};
21use svod_ir::{Op, UOp};
22use tracing::{debug, trace};
23
24use crate::error::*;
25use crate::{Error, Result};
26use snafu::ResultExt;
27
28fn canonicalize_callable_source(src: &Arc<UOp>) -> Arc<UOp> {
29 let mut cur = src.clone();
30 loop {
31 match cur.op() {
32 Op::After { .. }
33 | Op::Buffer { .. }
34 | Op::Param { .. }
35 | Op::MSelect { .. }
36 | Op::MStack { .. }
37 | Op::Bind { .. } => return cur,
38 _ => {
39 let sources = cur.op().sources();
40 let Some(next) = sources.first() else {
41 return cur;
42 };
43 if Arc::ptr_eq(&cur, next) {
44 return cur;
45 }
46 cur = (*next).clone();
47 }
48 }
49 }
50}
51
52fn source_primary_buffer_id(src: &Arc<UOp>) -> Option<u64> {
53 let src = canonicalize_callable_source(src);
54 match src.op() {
55 Op::Buffer { .. } | Op::Param { .. } | Op::After { .. } => Some(src.buf_uop().id),
56 Op::Bind { .. } => None,
57 Op::MSelect { buffer, device_index } => {
58 if let Op::MStack { buffers } = buffer.op() {
59 buffers.get(*device_index).map(|b| b.buf_uop().id).or_else(|| Some(src.buf_uop().id))
60 } else {
61 Some(src.buf_uop().id)
62 }
63 }
64 Op::MStack { buffers } => buffers.first().map(|b| b.buf_uop().id),
65 _ => None,
66 }
67}
68
69fn collect_callable_dep_ids(dep: &Arc<UOp>, out: &mut HashSet<u64>) -> Result<()> {
70 match dep.op() {
71 Op::Call { .. } => {
72 out.insert(dep.id);
73 Ok(())
74 }
75 Op::End { computation, .. } => {
76 if matches!(computation.op(), Op::Call { .. }) {
77 out.insert(computation.id);
78 Ok(())
79 } else {
80 IrConstructionSnafu {
81 details: format!("AFTER dependency END must wrap CALL, got {:?}", computation.op()),
82 }
83 .fail()
84 }
85 }
86 Op::Store { .. } => Ok(()),
87 Op::After { deps, .. } => {
88 for nested in deps {
89 collect_callable_dep_ids(nested, out)?;
90 }
91 Ok(())
92 }
93 other => IrConstructionSnafu {
94 details: format!("AFTER dependency must be CALL/END(CALL)/STORE/AFTER, got {other:?}"),
95 }
96 .fail(),
97 }
98}
99
/// Result of splitting an `After` node's dependencies: the first vector holds
/// the CALL / END(CALL) dependencies, the second the nested `After`
/// dependencies.
type AfterDependencySplit = (Vec<Arc<UOp>>, Vec<Arc<UOp>>);
101
102fn split_after_dependencies(after: &Arc<UOp>) -> Result<AfterDependencySplit> {
103 let Op::After { deps, .. } = after.op() else {
104 return IrConstructionSnafu {
105 details: format!("expected AFTER when splitting dependencies, got {:?}", after.op()),
106 }
107 .fail();
108 };
109
110 let mut kernels = Vec::new();
111 let mut after_deps = Vec::new();
112 for dep in deps {
113 match dep.op() {
114 Op::Call { .. } => kernels.push(dep.clone()),
115 Op::End { computation, .. } if matches!(computation.op(), Op::Call { .. }) => kernels.push(dep.clone()),
116 Op::After { .. } => after_deps.push(dep.clone()),
117 Op::Store { .. } => {}
118 other => {
119 return IrConstructionSnafu {
120 details: format!("AFTER dependency must be CALL/END(CALL)/STORE/AFTER, got {other:?}"),
121 }
122 .fail();
123 }
124 }
125 }
126
127 Ok((kernels, after_deps))
128}
129
130fn collect_source_dependency_callable_ids(src: &Arc<UOp>, out: &mut HashSet<u64>) -> Result<()> {
131 let src = canonicalize_callable_source(src);
132 match src.op() {
133 Op::After { .. } => {
134 let (kernels, after_deps) = split_after_dependencies(&src)?;
135 for kernel in kernels {
136 collect_callable_dep_ids(&kernel, out)?;
137 }
138 for dep in after_deps {
139 collect_source_dependency_callable_ids(&dep, out)?;
140 }
141 Ok(())
142 }
143 Op::MStack { buffers } => {
144 for buffer in buffers {
145 collect_source_dependency_callable_ids(buffer, out)?;
146 }
147 Ok(())
148 }
149 Op::MSelect { buffer, .. } => collect_source_dependency_callable_ids(buffer, out),
150 Op::Buffer { .. } | Op::Param { .. } | Op::Bind { .. } => Ok(()),
151 other => IrConstructionSnafu {
152 details: format!("input to callable must resolve to AFTER/BUFFER/PARAM/MSELECT/MSTACK/BIND, got {other:?}"),
153 }
154 .fail(),
155 }
156}
157
158fn callable_sources(callable: &Arc<UOp>) -> Option<Vec<Arc<UOp>>> {
159 match callable.op() {
160 Op::Call { args, .. } => Some(args.iter().cloned().collect()),
161 _ => None,
162 }
163}
164
165fn collect_scheduled_range_ids(root: &Arc<UOp>, callable_ids: &HashSet<u64>) -> HashSet<u64> {
171 let mut ids = HashSet::new();
172 for node in root.toposort_call_aware(false) {
173 let Op::End { computation, ranges } = node.op() else { continue };
174 if !matches!(computation.op(), Op::Call { .. }) || !callable_ids.contains(&computation.id) {
175 continue;
176 }
177 for r in ranges {
178 if matches!(r.op(), Op::Range { .. }) {
179 ids.insert(r.id);
180 }
181 }
182 }
183 ids
184}
185
186fn collect_call_bound_ranges(callable: &Arc<UOp>, scheduled_range_ids: &HashSet<u64>) -> Result<Vec<BoundRangeRef>> {
187 let Op::Call { args, .. } = callable.op() else {
188 return ExpectedCallableOpSnafu.fail();
189 };
190
191 let mut bound_ranges = Vec::new();
192 for arg in args {
193 let Op::Bind { var, value } = arg.op() else {
194 continue;
195 };
196 let Op::DefineVar { name, .. } = var.op() else {
197 return IrConstructionSnafu {
198 details: format!("CALL BIND source must wrap DEFINE_VAR, got {:?}", var.op()),
199 }
200 .fail();
201 };
202 let Op::Range { .. } = value.op() else {
203 continue;
205 };
206 if !scheduled_range_ids.contains(&value.id) {
209 continue;
210 }
211 bound_ranges.push(BoundRangeRef { var_name: name.clone(), range_uop: value.clone() });
212 }
213 Ok(bound_ranges)
214}
215
216fn collect_linear_sched_ops_internal(
217 root: &Arc<UOp>,
218 callable_ids: &HashSet<u64>,
219 scheduled_range_ids: &HashSet<u64>,
220) -> Result<Vec<LinearSchedOp>> {
221 let mut linear_ops = Vec::new();
222
223 for node in root.toposort_call_aware(false) {
224 match node.op() {
225 Op::Range { .. } if scheduled_range_ids.contains(&node.id) => {
226 linear_ops.push(LinearSchedOp::Range { range: node.clone() });
227 }
228 Op::Call { .. } if callable_ids.contains(&node.id) => {
229 linear_ops.push(LinearSchedOp::Call { kernel_id: node.id });
230 }
231 Op::End { computation, ranges } if matches!(computation.op(), Op::Call { .. }) => {
232 if !callable_ids.contains(&computation.id) {
233 continue;
234 }
235 let wrapper_ranges: Vec<Arc<UOp>> =
236 ranges.iter().filter(|r| matches!(r.op(), Op::Range { .. })).cloned().collect();
237 match wrapper_ranges.as_slice() {
238 [] => {}
239 [outer] => linear_ops.push(LinearSchedOp::End { range: outer.clone(), kernel_id: computation.id }),
240 _ => {
241 return IrConstructionSnafu {
242 details: format!(
243 "END(CALL) must close at most one wrapper range in strict scheduler, got {}",
244 wrapper_ranges.len()
245 ),
246 }
247 .fail();
248 }
249 }
250 }
251 _ => {}
252 }
253 }
254
255 if linear_ops.is_empty() {
256 return IrConstructionSnafu { details: "strict scheduler produced empty linear control stream".to_string() }
257 .fail();
258 }
259 Ok(linear_ops)
260}
261
/// Expands the linear control stream of `root` into the concrete sequence of
/// kernel invocations, one entry per executed `Call`.
///
/// The stream is first validated (every declared range has an `End`, every
/// bound range of every item is declared) and then interpreted with a program
/// counter: `Range` starts a counter at the range's lower bound, `End`
/// increments the counter and jumps back while it is below the upper bound,
/// and `Call` snapshots the counters of its bound ranges into `fixedvars`.
///
/// # Errors
///
/// Fails with an IR-construction error when validation fails, when an
/// `End`/`Call` references an unknown kernel id, when an `End` or a bound
/// variable references a range that has no active counter, or when range
/// bounds cannot be resolved.
fn collect_kernel_invocations(
    root: &Arc<UOp>,
    items: &[PreScheduleItem],
    scheduled_range_ids: &HashSet<u64>,
) -> Result<Vec<KernelInvocation>> {
    let callable_ids: HashSet<u64> = items.iter().map(|it| it.kernel.id).collect();
    let linear_ops = collect_linear_sched_ops_internal(root, &callable_ids, scheduled_range_ids)?;

    // Kernel id -> the range bindings collected for that kernel.
    let bound_ranges_by_kernel: HashMap<u64, &[BoundRangeRef]> =
        items.iter().map(|it| (it.kernel.id, it.bound_ranges.as_slice())).collect();

    // Validation pass 1: record which ranges are opened and which are closed.
    let mut declared_ranges: HashSet<u64> = HashSet::new();
    let mut ended_ranges: HashSet<u64> = HashSet::new();
    for op in &linear_ops {
        match op {
            LinearSchedOp::Range { range } => {
                declared_ranges.insert(range.id);
            }
            LinearSchedOp::End { range, .. } => {
                ended_ranges.insert(range.id);
            }
            LinearSchedOp::Call { .. } => {}
        }
    }
    // Every opened range must be closed somewhere in the stream.
    for &rid in &declared_ranges {
        if !ended_ranges.contains(&rid) {
            return IrConstructionSnafu { details: format!("schedule range {rid} is missing END in strict scheduler") }
                .fail();
        }
    }
    // Validation pass 2: every range a kernel binds to must appear in the stream.
    for item in items {
        for br in &item.bound_ranges {
            if !declared_ranges.contains(&br.range_uop.id) {
                return IrConstructionSnafu {
                    details: format!(
                        "CALL {} bound variable '{}' references schedule range {} missing from linear schedule",
                        item.kernel.id, br.var_name, br.range_uop.id
                    ),
                }
                .fail();
            }
        }
    }

    let mut invocations = Vec::new();
    // Range id -> current counter value. Entries are never removed here, so a
    // finished loop keeps its last value.
    let mut in_ranges: HashMap<u64, i64> = HashMap::new();
    // Range id -> index of the op right after the `Range` op (loop-back target).
    let mut range_ptrs: HashMap<u64, usize> = HashMap::new();
    // Range id -> cached (vmin, vmax) from `schedule_range_bounds`.
    let mut range_bounds: HashMap<u64, (i64, i64)> = HashMap::new();

    // Interpret the stream; `sched_ptr` acts as a program counter.
    let mut sched_ptr = 0usize;
    while sched_ptr < linear_ops.len() {
        match &linear_ops[sched_ptr] {
            LinearSchedOp::Range { range } => {
                let bounds = if let Some(bounds) = range_bounds.get(&range.id).copied() {
                    bounds
                } else {
                    let bounds = schedule_range_bounds(range)?;
                    range_bounds.insert(range.id, bounds);
                    bounds
                };
                // (Re)start the counter at vmin; re-executing a `Range` op resets it.
                in_ranges.insert(range.id, bounds.0);
                range_ptrs.insert(range.id, sched_ptr + 1);
            }
            LinearSchedOp::End { range, kernel_id } => {
                if !bound_ranges_by_kernel.contains_key(kernel_id) {
                    return IrConstructionSnafu {
                        details: format!("linear END references unknown CALL id {kernel_id}"),
                    }
                    .fail();
                }
                let (_, vmax) = if let Some(bounds) = range_bounds.get(&range.id).copied() {
                    bounds
                } else {
                    let bounds = schedule_range_bounds(range)?;
                    range_bounds.insert(range.id, bounds);
                    bounds
                };
                let Some(cur) = in_ranges.get_mut(&range.id) else {
                    return IrConstructionSnafu {
                        details: format!("END references schedule range {} that is not active", range.id),
                    }
                    .fail();
                };
                // vmax is inclusive: keep looping while the counter is below it.
                if *cur < vmax {
                    *cur += 1;
                    let Some(jump_ptr) = range_ptrs.get(&range.id).copied() else {
                        return IrConstructionSnafu {
                            details: format!("missing loop jump pointer for schedule range {}", range.id),
                        }
                        .fail();
                    };
                    // Jump back to the op following the `Range`; skip the increment below.
                    sched_ptr = jump_ptr;
                    continue;
                }
            }
            LinearSchedOp::Call { kernel_id } => {
                let Some(bound_ranges) = bound_ranges_by_kernel.get(kernel_id) else {
                    return IrConstructionSnafu {
                        details: format!("linear CALL references unknown kernel id {kernel_id}"),
                    }
                    .fail();
                };
                // Snapshot the current counter of every range this kernel binds to.
                let mut fixedvars = HashMap::new();
                for br in *bound_ranges {
                    let Some(value) = in_ranges.get(&br.range_uop.id).copied() else {
                        return IrConstructionSnafu {
                            details: format!(
                                "CALL {} bound variable '{}' references inactive schedule range {}",
                                kernel_id, br.var_name, br.range_uop.id
                            ),
                        }
                        .fail();
                    };
                    fixedvars.insert(br.var_name.clone(), value);
                }
                invocations.push(KernelInvocation { kernel_id: *kernel_id, fixedvars });
            }
        }
        sched_ptr += 1;
    }

    Ok(invocations)
}
396
397fn analyze_callable_dependencies(callables: &[Arc<UOp>], root: &Arc<UOp>) -> Result<Vec<HashSet<usize>>> {
398 let callable_idx: HashMap<u64, usize> = callables.iter().enumerate().map(|(i, c)| (c.id, i)).collect();
400 let mut dependencies: Vec<HashSet<usize>> = vec![HashSet::new(); callables.len()];
403
404 for (consumer_idx, callable) in callables.iter().enumerate() {
405 let mut dep_ids = HashSet::new();
406 if let Some(sources) = callable_sources(callable) {
407 for src in sources {
408 collect_source_dependency_callable_ids(&src, &mut dep_ids)?;
409 }
410 }
411
412 for dep_id in dep_ids {
413 let Some(&producer_idx) = callable_idx.get(&dep_id) else {
414 return IrConstructionSnafu {
415 details: format!("callable dependency references unknown callable id {dep_id}"),
416 }
417 .fail();
418 };
419 if producer_idx != consumer_idx {
420 dependencies[consumer_idx].insert(producer_idx);
421 }
422 }
423 }
424
425 for node in root.toposort() {
428 let Op::After { .. } = node.op() else {
429 continue;
430 };
431
432 let (kernels, after_deps) = split_after_dependencies(&node)?;
433 for kernel in kernels {
434 let callable = match kernel.op() {
435 Op::Call { .. } => kernel.clone(),
436 Op::End { computation, .. } => computation.clone(),
437 _ => unreachable!("split_after_dependencies only returns CALL/END(CALL) kernels"),
438 };
439
440 let Some(&consumer_idx) = callable_idx.get(&callable.id) else {
441 return IrConstructionSnafu {
442 details: format!("AFTER dependency references unknown callable id {}", callable.id),
443 }
444 .fail();
445 };
446
447 let mut dep_ids = HashSet::new();
448 for dep in &after_deps {
449 collect_source_dependency_callable_ids(dep, &mut dep_ids)?;
450 }
451
452 for dep_id in dep_ids {
453 let Some(&producer_idx) = callable_idx.get(&dep_id) else {
454 return IrConstructionSnafu {
455 details: format!("callable dependency references unknown callable id {dep_id}"),
456 }
457 .fail();
458 };
459 if producer_idx != consumer_idx {
460 dependencies[consumer_idx].insert(producer_idx);
461 }
462 }
463 }
464 }
465
466 Ok(dependencies)
467}
468
/// Caller-supplied buffers keyed by a `u64` id; within this file the map is
/// looked up with buffer UOp ids.
pub type InputBuffers = HashMap<u64, Buffer>;
474
/// A `Bind` argument of a `Call` that ties a named variable to a scheduled `Range`.
#[derive(Clone, Debug)]
pub struct BoundRangeRef {
    /// Name taken from the `DefineVar` inside the `Bind`.
    pub var_name: String,
    /// The `Range` UOp the variable is bound to.
    pub range_uop: Arc<UOp>,
}
485
/// One instruction of the linear control stream that the scheduler interprets
/// to expand loops into kernel invocations.
#[derive(Clone, Debug)]
enum LinearSchedOp {
    /// Opens the loop described by `range`.
    Range { range: Arc<UOp> },
    /// Invokes the kernel whose `Call` UOp has id `kernel_id`.
    Call { kernel_id: u64 },
    /// Closes `range`; emitted for an `End` wrapping the `Call` with id `kernel_id`.
    End { range: Arc<UOp>, kernel_id: u64 },
}
498
/// A single execution of a kernel produced by expanding the linear schedule.
#[derive(Clone, Debug)]
pub struct KernelInvocation {
    /// Id of the `Call` UOp being invoked.
    pub kernel_id: u64,
    /// Value of each bound range variable at the moment of this invocation,
    /// keyed by variable name.
    pub fixedvars: HashMap<String, i64>,
}
512
/// One kernel execution in the final schedule, with its buffers resolved.
#[derive(Clone)]
pub struct ScheduleItem {
    /// The `Call` UOp this item was created from.
    pub kernel: Arc<UOp>,

    /// Body of the `Call`.
    pub ast: Arc<UOp>,

    /// Buffers resolved for the call's sources, in source order
    /// (`Bind` sources contribute no buffer).
    pub buffers: Vec<Buffer>,

    /// Buffer UOp ids, parallel to `buffers`.
    pub buffer_uop_ids: Vec<u64>,

    /// Variable values for this execution: the caller's `var_vals` restricted to
    /// variables defined in `ast`, overlaid with the invocation's loop values.
    pub fixedvars: HashMap<String, i64>,

    /// Names of the variables in `fixedvars` that came from the invocation
    /// (i.e. from scheduled ranges).
    pub loop_var_names: HashSet<String>,

    /// Ids of the callables this kernel depends on.
    pub dependencies: Vec<u64>,

    /// Left empty by `instantiate_schedule`.
    /// NOTE(review): presumably filled in by a later stage — confirm.
    pub instance_dependencies: Vec<usize>,

    /// UOp ids of sources that resolved to a buffer registered under a different id.
    pub alias_registered_ids: Vec<u64>,
}
561
/// The executable schedule: one `ScheduleItem` per kernel invocation, in
/// invocation order.
pub type Schedule = Vec<ScheduleItem>;
564
/// A kernel as discovered in the graph, before any buffers are resolved.
#[derive(Clone)]
pub struct PreScheduleItem {
    /// The `Call` UOp.
    pub kernel: Arc<UOp>,
    /// Body of the `Call`.
    pub ast: Arc<UOp>,
    /// Arguments of the `Call`.
    pub sources: Vec<Arc<UOp>>,
    /// Ids of the callables this kernel depends on.
    pub dependencies: Vec<u64>,
    /// Variables of this call that are bound to scheduled ranges.
    pub bound_ranges: Vec<BoundRangeRef>,
}
581
/// Buffer-independent schedule description built from a transformed graph.
#[derive(Clone)]
pub struct PreSchedule {
    /// One entry per `Call`, in dependency order.
    pub items: Vec<PreScheduleItem>,
    /// Kernel executions in order, with loops already expanded.
    pub invocations: Vec<KernelInvocation>,
    /// `buf_uop()` of each sink source, or of the root when the root is not a `Sink`.
    pub output_buffer_uops: Vec<Arc<UOp>>,
}
596
/// Callables in dependency order, plus a map from each callable id to the
/// sorted ids of the callables it depends on.
type SortedCallables = (Vec<Arc<UOp>>, HashMap<u64, Vec<u64>>);
598
/// Output of schedule instantiation.
pub struct ScheduleResult {
    /// The schedule items, one per kernel invocation.
    pub items: Schedule,
    /// Buffer UOp ids of the graph outputs.
    pub output_uop_ids: Vec<u64>,
}
608
/// Buffers resolved for the sources of one callable.
struct CallableBuffers {
    /// Resolved buffers, in source order.
    buffers: Vec<Buffer>,
    /// Buffer UOp ids, parallel to `buffers`.
    uop_ids: Vec<u64>,
    /// Sorted, de-duplicated UOp ids of sources whose id differs from the id
    /// their buffer was resolved under.
    alias_ids: Vec<u64>,
}
618
619fn sort_callables_by_dependencies(callables: &[Arc<UOp>], root: &Arc<UOp>) -> Result<SortedCallables> {
627 debug!(num_callables = callables.len(), "sorting callables by dependencies");
628
629 let dependencies = analyze_callable_dependencies(callables, root)?;
630
631 let mut in_degree: Vec<usize> = dependencies.iter().map(|deps| deps.len()).collect();
633 let mut dependents: Vec<Vec<usize>> = vec![vec![]; callables.len()];
634
635 for (consumer, deps) in dependencies.iter().enumerate() {
636 for &producer in deps {
637 dependents[producer].push(consumer);
638 }
639 }
640
641 let mut queue: VecDeque<usize> =
642 in_degree.iter().enumerate().filter(|&(_, °)| deg == 0).map(|(idx, _)| idx).collect();
643
644 let mut sorted_indices = Vec::new();
645 while let Some(idx) = queue.pop_front() {
646 sorted_indices.push(idx);
647 for &dependent in &dependents[idx] {
648 in_degree[dependent] -= 1;
649 if in_degree[dependent] == 0 {
650 queue.push_back(dependent);
651 }
652 }
653 }
654
655 if sorted_indices.len() < callables.len() {
656 return DependencyCyclesSnafu.fail();
657 }
658
659 let sorted: Vec<Arc<UOp>> = sorted_indices.iter().map(|&idx| callables[idx].clone()).collect();
660
661 let dependency_ids_by_callable: HashMap<u64, Vec<u64>> = callables
662 .iter()
663 .enumerate()
664 .map(|(idx, callable)| {
665 let mut deps: Vec<u64> = dependencies[idx].iter().map(|&dep_idx| callables[dep_idx].id).collect();
666 deps.sort_unstable();
667 (callable.id, deps)
668 })
669 .collect();
670
671 debug!(num_sorted = sorted.len(), "callables sorted");
672
673 Ok((sorted, dependency_ids_by_callable))
674}
675
676pub fn create_pre_schedule(transformed: Arc<UOp>) -> Result<PreSchedule> {
696 let mut callables = Vec::new();
698 for node in transformed.toposort_call_aware(false) {
699 if matches!(node.op(), Op::Call { .. }) {
700 callables.push(node);
701 }
702 }
703
704 if callables.is_empty() {
705 return NoKernelsFoundSnafu.fail();
706 }
707
708 let (callables, dependency_ids_by_callable) = sort_callables_by_dependencies(&callables, &transformed)?;
710
711 let callable_ids: HashSet<u64> = callables.iter().map(|c| c.id).collect();
715 let scheduled_range_ids = collect_scheduled_range_ids(&transformed, &callable_ids);
716
717 let mut items = Vec::with_capacity(callables.len());
719 for callable_uop in callables {
720 let Op::Call { body, args, .. } = callable_uop.op() else {
721 unreachable!("filtered to only call wrappers above")
722 };
723 let dependencies = dependency_ids_by_callable.get(&callable_uop.id).cloned().unwrap_or_default();
724 let bound_ranges = collect_call_bound_ranges(&callable_uop, &scheduled_range_ids)?;
725 items.push(PreScheduleItem {
726 kernel: callable_uop.clone(),
727 ast: body.clone(),
728 sources: args.iter().cloned().collect(),
729 dependencies,
730 bound_ranges,
731 });
732 }
733
734 let invocations = collect_kernel_invocations(&transformed, &items, &scheduled_range_ids)?;
738
739 let output_buffer_uops: Vec<Arc<UOp>> = match transformed.op() {
741 Op::Sink { sources, .. } => sources.iter().map(|src| src.buf_uop()).collect(),
742 _ => vec![transformed.buf_uop()],
743 };
744
745 Ok(PreSchedule { items, invocations, output_buffer_uops })
746}
747
748pub fn instantiate_schedule(
753 pre_schedule: &PreSchedule,
754 input_buffers: &InputBuffers,
755 var_vals: &HashMap<String, i64>,
756) -> Result<ScheduleResult> {
757 let mut allocated_buffers: HashMap<u64, Buffer> = HashMap::new();
759
760 let mut templates: HashMap<u64, ScheduleItemTemplate> = HashMap::with_capacity(pre_schedule.items.len());
761 for item in &pre_schedule.items {
762 let nodes = item.ast.toposort();
763
764 let kb = collect_callable_buffers(&item.sources, &item.ast, input_buffers, &mut allocated_buffers)?;
766
767 debug!(callable.id = item.kernel.id, num_sources = item.sources.len(), "Schedule item created");
768
769 let fixedvars: HashMap<String, i64> = if var_vals.is_empty() {
771 HashMap::new()
772 } else {
773 let ast_var_names: HashSet<&str> = nodes
774 .iter()
775 .filter_map(|n| match n.op() {
776 Op::DefineVar { name, .. } => Some(name.as_str()),
777 _ => None,
778 })
779 .collect();
780 var_vals
781 .iter()
782 .filter(|(name, _)| ast_var_names.contains(name.as_str()))
783 .map(|(k, v)| (k.clone(), *v))
784 .collect()
785 };
786
787 templates.insert(
788 item.kernel.id,
789 ScheduleItemTemplate {
790 kernel: item.kernel.clone(),
791 ast: item.ast.clone(),
792 buffers: kb.buffers,
793 buffer_uop_ids: kb.uop_ids,
794 dependencies: item.dependencies.clone(),
795 alias_registered_ids: kb.alias_ids,
796 base_fixedvars: fixedvars,
797 },
798 );
799 }
800
801 let mut schedule = Vec::with_capacity(pre_schedule.invocations.len());
802 for invocation in &pre_schedule.invocations {
803 let Some(template) = templates.get(&invocation.kernel_id) else {
804 return IrConstructionSnafu {
805 details: format!("invocation references unknown kernel id {}", invocation.kernel_id),
806 }
807 .fail();
808 };
809
810 let mut fixedvars = template.base_fixedvars.clone();
813 fixedvars.extend(invocation.fixedvars.iter().map(|(k, v)| (k.clone(), *v)));
814 let loop_var_names: HashSet<String> = invocation.fixedvars.keys().cloned().collect();
815
816 schedule.push(ScheduleItem {
817 kernel: template.kernel.clone(),
818 ast: template.ast.clone(),
819 buffers: template.buffers.clone(),
820 buffer_uop_ids: template.buffer_uop_ids.clone(),
821 fixedvars,
822 loop_var_names,
823 dependencies: template.dependencies.clone(),
824 instance_dependencies: Vec::new(),
825 alias_registered_ids: template.alias_registered_ids.clone(),
826 });
827 }
828
829 if schedule.is_empty() {
830 return EmptyScheduleSnafu.fail();
831 }
832
833 let output_uop_ids: Vec<u64> = pre_schedule.output_buffer_uops.iter().map(|u| u.buf_uop().id).collect();
834 Ok(ScheduleResult { items: schedule, output_uop_ids })
835}
836
837pub fn create_schedule(
838 transformed: Arc<UOp>,
839 input_buffers: &InputBuffers,
840 var_vals: &HashMap<String, i64>,
841) -> Result<ScheduleResult> {
842 let pre = create_pre_schedule(transformed)?;
843 instantiate_schedule(&pre, input_buffers, var_vals)
844}
845
846fn find_first_input_buffer_device(
858 sources: &[Arc<UOp>],
859 input_buffers: &InputBuffers,
860 allocated_buffers: &HashMap<u64, Buffer>,
861) -> Result<Arc<Device>> {
862 let alloc_registry = registry::registry();
863
864 for src in sources {
865 if let Some(buf_id) = source_primary_buffer_id(src) {
866 let buffer = allocated_buffers.get(&buf_id).cloned().or_else(|| input_buffers.get(&buf_id).cloned());
867 if let Some(buffer) = buffer {
868 let device_spec = buffer.allocator().device_spec();
869 if device_spec.is_disk() {
870 continue;
871 }
872 return svod_runtime::DEVICE_FACTORIES.device(&device_spec, alloc_registry).context(DeviceFactorySnafu);
873 }
874 }
875 }
876
877 svod_runtime::DEVICE_FACTORIES.device(&DeviceSpec::Cpu, alloc_registry).context(DeviceFactorySnafu)
879}
880
/// Resolves (or allocates) one buffer for each buffer-like source of a callable.
///
/// Each source is canonicalized and handled by kind:
/// * `After`, `MSelect`, `MStack` — the buffer must already exist in
///   `allocated_buffers` or `input_buffers`; it is shared, never created.
/// * `DefineLocal` — a fresh buffer is always allocated on the target device.
/// * `Buffer`, `Param` — an existing buffer is reused (inputs are checked
///   before allocated buffers); otherwise one is allocated.
/// * `Bind` — contributes no buffer.
///
/// Newly allocated and shared buffers are recorded in `allocated_buffers`.
/// The returned `alias_ids` list every source UOp id that differs from the id
/// its buffer was resolved under, sorted and de-duplicated.
///
/// # Errors
///
/// Fails when a required existing buffer is missing, when a multi-device
/// source has no primary buffer id, when a `DefineLocal` has no sized pointer
/// dtype, when the source op is unsupported, or when the target device cannot
/// be obtained.
fn collect_callable_buffers(
    sources: &[Arc<UOp>],
    ast: &Arc<UOp>,
    input_buffers: &InputBuffers,
    allocated_buffers: &mut HashMap<u64, Buffer>,
) -> Result<CallableBuffers> {
    // Device whose allocator is used for any buffer created below.
    let target_device = find_first_input_buffer_device(sources, input_buffers, allocated_buffers)?;

    let mut buffers = Vec::new();
    let mut uop_ids = Vec::new();
    let mut alias_ids = Vec::new();

    for src in sources {
        let canonical_src = canonicalize_callable_source(src);
        // The original source id is an alias whenever canonicalization moved to another node.
        if canonical_src.id != src.id {
            alias_ids.push(src.id);
        }

        match canonical_src.op() {
            Op::After { passthrough, .. } => {
                // The buffer is identified through the passthrough operand.
                let buf_id = passthrough.buf_uop().id;
                if buf_id != canonical_src.id {
                    alias_ids.push(canonical_src.id);
                }

                // Buffers tracked in `allocated_buffers` take precedence over inputs.
                let existing = allocated_buffers.get(&buf_id).cloned().or_else(|| input_buffers.get(&buf_id).cloned());

                if let Some(buffer) = existing {
                    trace!(
                        buf_id,
                        buffer.id = ?buffer.id(),
                        "Found shared buffer from AFTER"
                    );

                    // Register the buffer under `buf_id` if it is not tracked yet.
                    allocated_buffers.entry(buf_id).or_insert_with(|| buffer.clone());

                    buffers.push(buffer);
                    uop_ids.push(buf_id);
                } else {
                    trace!(buf_id, "after buffer not found in allocated_buffers or input_buffers");
                    return Err(Error::BufferNotFound { uop_id: buf_id });
                }
            }
            Op::MSelect { .. } | Op::MStack { .. } => {
                let Some(canonical_id) = source_primary_buffer_id(&canonical_src) else {
                    return IrConstructionSnafu {
                        details: format!(
                            "multi-device callable source must resolve a primary buffer id: source_id={}, op={:?}",
                            canonical_src.id,
                            canonical_src.op()
                        ),
                    }
                    .fail();
                };
                if canonical_id != canonical_src.id {
                    alias_ids.push(canonical_src.id);
                }

                let existing =
                    allocated_buffers.get(&canonical_id).cloned().or_else(|| input_buffers.get(&canonical_id).cloned());

                if let Some(buffer) = existing {
                    trace!(canonical_id, buffer.id = ?buffer.id(), "Found shared buffer from MSELECT/MSTACK source");
                    allocated_buffers.entry(canonical_id).or_insert_with(|| buffer.clone());
                    buffers.push(buffer);
                    uop_ids.push(canonical_id);
                } else {
                    trace!(canonical_id, "multi-device source buffer not found in allocated_buffers or input_buffers");
                    return Err(Error::BufferNotFound { uop_id: canonical_id });
                }
            }
            Op::DefineLocal(_id) => {
                let ptr_dtype = canonical_src.dtype();
                // Element count comes from the pointer dtype's size.
                let size = compute_buffer_size(ast, &canonical_src)?;

                // The buffer stores the pointee type, not the pointer type.
                let scalar_dtype = match ptr_dtype {
                    svod_dtype::DType::Ptr { base, .. } => *base,
                    other => {
                        return ExpectedPtrDtypeSnafu { context: "DEFINE_LOCAL", actual: other.clone() }.fail();
                    }
                };

                let buffer =
                    Buffer::new(target_device.allocator.clone(), scalar_dtype.clone(), vec![size], Default::default());

                // Unconditional insert: replaces any buffer previously stored under this id.
                allocated_buffers.insert(canonical_src.id, buffer.clone());

                buffers.push(buffer);
                uop_ids.push(canonical_src.id);
            }
            Op::Buffer { size, .. } | Op::Param { size, .. } => {
                let canonical_id = canonical_src.buf_uop().id;
                if canonical_id != canonical_src.id {
                    alias_ids.push(canonical_src.id);
                }

                // Lookup order: input buffers first, then previously allocated ones;
                // each map is tried under the buffer id and then the source's own id.
                if let Some(buffer) =
                    input_buffers.get(&canonical_id).cloned().or_else(|| input_buffers.get(&canonical_src.id).cloned())
                {
                    buffers.push(buffer);
                    uop_ids.push(canonical_id);
                } else if let Some(buffer) = allocated_buffers
                    .get(&canonical_id)
                    .cloned()
                    .or_else(|| allocated_buffers.get(&canonical_src.id).cloned())
                {
                    buffers.push(buffer);
                    uop_ids.push(canonical_id);
                } else {
                    // Not known anywhere: allocate a new buffer on the target device.
                    trace!(src.id = canonical_src.id, canonical_id, size, "Allocating output BUFFER/PARAM");
                    let scalar_dtype = canonical_src.dtype();

                    let buffer = Buffer::new(
                        target_device.allocator.clone(),
                        scalar_dtype.clone(),
                        vec![*size],
                        Default::default(),
                    );

                    allocated_buffers.insert(canonical_id, buffer.clone());
                    buffers.push(buffer);
                    uop_ids.push(canonical_id);
                }
            }
            Op::Bind { .. } => {
                // Variable bindings carry no buffer.
                continue;
            }
            other => {
                return IrConstructionSnafu {
                    details: format!("unsupported callable source op for buffer collection: {other:?}"),
                }
                .fail();
            }
        }
    }

    // `dedup` only removes adjacent duplicates, hence the sort first.
    alias_ids.sort_unstable();
    alias_ids.dedup();
    Ok(CallableBuffers { buffers, uop_ids, alias_ids })
}
1050
/// Per-kernel data computed once during instantiation and copied into every
/// `ScheduleItem` created for that kernel's invocations.
#[derive(Clone)]
struct ScheduleItemTemplate {
    /// The `Call` UOp.
    kernel: Arc<UOp>,
    /// Body of the `Call`.
    ast: Arc<UOp>,
    /// Buffers resolved for the call's sources.
    buffers: Vec<Buffer>,
    /// Buffer UOp ids, parallel to `buffers`.
    buffer_uop_ids: Vec<u64>,
    /// Ids of the callables this kernel depends on.
    dependencies: Vec<u64>,
    /// UOp ids of sources that resolved to a buffer registered under a different id.
    alias_registered_ids: Vec<u64>,
    /// Caller-provided variable values restricted to variables defined in `ast`;
    /// loop-variable values are layered on top per invocation.
    base_fixedvars: HashMap<String, i64>,
}
1061
1062fn schedule_range_bounds(range: &Arc<UOp>) -> Result<(i64, i64)> {
1063 let Op::Range { .. } = range.op() else {
1064 return IrConstructionSnafu {
1065 details: format!("expected RANGE for schedule loop control, got {:?}", range.op()),
1066 }
1067 .fail();
1068 };
1069
1070 let Some(vmin) = range.vmin().try_int() else {
1071 return IrConstructionSnafu {
1072 details: format!("schedule range vmin must be concrete integer, got {:?}", range.vmin()),
1073 }
1074 .fail();
1075 };
1076 let Some(vmax) = range.vmax().try_int() else {
1077 return IrConstructionSnafu {
1078 details: format!("schedule range vmax must be concrete integer, got {:?}", range.vmax()),
1079 }
1080 .fail();
1081 };
1082 if vmax < vmin {
1083 return IrConstructionSnafu { details: format!("invalid schedule range bounds: vmin={vmin}, vmax={vmax}") }
1084 .fail();
1085 }
1086 Ok((vmin, vmax))
1087}
1088
1089fn compute_buffer_size(_ast: &Arc<UOp>, buffer_def: &Arc<UOp>) -> Result<usize> {
1094 match buffer_def.dtype() {
1096 DType::Ptr { size: Some(s), .. } => Ok(s),
1097 DType::Ptr { size: None, .. } => BufferPtrNoSizeSnafu.fail(),
1098 other => ExpectedPtrDtypeSnafu { context: "buffer_size", actual: other.clone() }.fail(),
1099 }
1100}