1use std::cmp::Reverse;
29use std::collections::{BinaryHeap, HashMap, HashSet};
30use std::sync::Arc;
31use std::time::Instant;
32
33use rayon::prelude::*;
34use smallvec::SmallVec;
35use svod_device::device::ProgramSpec;
36use svod_device::{Buffer, BufferId};
37use svod_dtype::DeviceSpec;
38use svod_ir::{CustomFunctionKind, Op, UOp};
39
40use crate::error::Result;
41use crate::kernel_cache::CachedKernel;
42use crate::profiler::KernelProfile;
43
/// Launch dimensions resolved for one kernel, as `(global_size, local_size)`.
///
/// This is the pair returned by `ExecutionPlan::kernel_launch_sizes` and then
/// forwarded unchanged to the compiled program's `execute` call.
type RuntimeLaunchSizes = (Option<[usize; 3]>, Option<[usize; 3]>);
45
/// A compiled kernel plus the data the plan needs to launch it.
#[derive(Clone)]
pub struct PreparedKernel {
    /// Identifier of this op; other ops refer to it through their `dependencies` lists.
    pub id: u64,

    /// `UOp` node associated with this kernel (not read by the execution code visible in this file).
    pub ast: Arc<UOp>,

    /// Cached compiled kernel. This file reads its `program`, `var_names`, `global_size`,
    /// `local_size` and `host_parallel_safe` members.
    pub kernel: Arc<CachedKernel>,

    /// Device the kernel runs on; `DeviceSpec::Cpu` kernels are checked for CPU threading.
    pub device: DeviceSpec,

    /// Indices into the plan's buffer list, one per kernel argument.
    pub buffer_indices: Vec<usize>,

    /// Positions within `buffer_indices` (not plan-level indices) that the kernel writes.
    /// `ExecutionPlanBuilder::build` rejects an empty list or an out-of-range position.
    pub output_indices: Vec<usize>,

    /// Variable values, aligned positionally with `kernel.var_names`.
    pub vals: Vec<i64>,

    /// Variables pinned to a fixed value. They take precedence over `vals` when launch
    /// dimensions are resolved and are skipped when runtime variables are updated.
    pub fixedvars: HashMap<String, i64>,

    /// Ids of the ops that must be scheduled before this one.
    pub dependencies: Vec<u64>,

    /// Raw buffer addresses, one per entry of `buffer_indices`.
    /// Overwritten by `ExecutionPlanBuilder::build`.
    pub buffer_ptrs: Vec<usize>,

    /// Buffer ids, one per entry of `buffer_indices`; used for conflict detection.
    /// Overwritten by `ExecutionPlanBuilder::build`.
    pub buffer_ids: Vec<BufferId>,

    /// Variables (with bounds) whose supplied values are range-checked before execution.
    pub runtime_vars: Vec<RuntimeVar>,
}
99
/// A named variable together with the inclusive range its value must fall in.
#[derive(Clone, Debug)]
pub struct RuntimeVar {
    /// Variable name, matched against the names supplied at execution time.
    pub name: String,
    /// Smallest accepted value (inclusive).
    pub min_val: i64,
    /// Largest accepted value (inclusive).
    pub max_val: i64,
}
107
108pub fn collect_runtime_vars(root: &Arc<UOp>) -> Vec<RuntimeVar> {
110 let mut vars = Vec::new();
111 let mut seen = std::collections::HashSet::new();
112 for node in root.toposort() {
113 if let Op::DefineVar { name, min_val, max_val } = node.op()
114 && seen.insert(name.clone())
115 {
116 vars.push(RuntimeVar { name: name.clone(), min_val: *min_val, max_val: *max_val });
117 }
118 }
119 vars
120}
121
/// A buffer-to-buffer copy scheduled as part of the plan.
#[derive(Clone, Debug)]
pub struct PreparedCopy {
    /// Identifier of this op; other ops refer to it through their `dependencies` lists.
    pub id: u64,

    /// Indices into the plan's buffer list: position 0 is the destination, position 1 the source.
    pub buffer_indices: Vec<usize>,

    /// Ids of the ops that must be scheduled before this one.
    pub dependencies: Vec<u64>,
}
134
/// A view relationship between two plan buffers that is verified at execution time.
#[derive(Clone, Debug)]
pub struct PreparedBufferView {
    /// Identifier of this op; other ops refer to it through their `dependencies` lists.
    pub id: u64,

    /// Indices into the plan's buffer list: position 0 is the view (`out`), position 1 its `base`.
    pub buffer_indices: Vec<usize>,

    /// Offset of the view in bytes, added to the base buffer's own offset.
    pub byte_offset: usize,

    /// Size in bytes the view buffer is expected to report.
    pub byte_size: usize,

    /// Ids of the ops that must be scheduled before this one.
    pub dependencies: Vec<u64>,
}
154
/// A custom (non-compiled) function invocation scheduled as part of the plan.
#[derive(Clone, Debug)]
pub struct PreparedCustomFunction {
    /// Identifier of this op; other ops refer to it through their `dependencies` lists.
    pub id: u64,

    /// Which custom function to run; forwarded to `run_custom_function`.
    pub kind: CustomFunctionKind,

    /// Attribute nodes forwarded to `run_custom_function`.
    pub attrs: SmallVec<[Arc<UOp>; 4]>,

    /// Indices into the plan's buffer list; the buffers are handed over in this order.
    pub buffer_indices: Vec<usize>,

    /// Variables pinned to a fixed value; they override the plan's runtime values of the same name.
    pub fixedvars: HashMap<String, i64>,

    /// Ids of the ops that must be scheduled before this one.
    pub dependencies: Vec<u64>,

    /// Variables (with bounds) whose supplied values are range-checked before execution.
    pub runtime_vars: Vec<RuntimeVar>,
}
181
/// One schedulable operation of an execution plan.
#[derive(Clone, Debug)]
pub enum PreparedOp {
    /// Launch of a compiled kernel.
    CompiledProgram(PreparedKernel),

    /// Copy from one plan buffer into another.
    BufferCopy(PreparedCopy),

    /// Consistency check of a buffer view against its base buffer.
    BufferView(PreparedBufferView),

    /// Invocation of a custom function.
    CustomFunction(PreparedCustomFunction),
}
197
198fn op_identity(op: &PreparedOp) -> (u64, Vec<u64>) {
199 match op {
200 PreparedOp::CompiledProgram(kernel) => (kernel.id, kernel.dependencies.clone()),
201 PreparedOp::BufferCopy(copy) => (copy.id, copy.dependencies.clone()),
202 PreparedOp::BufferView(view) => (view.id, view.dependencies.clone()),
203 PreparedOp::CustomFunction(custom) => (custom.id, custom.dependencies.clone()),
204 }
205}
206
207fn validate_var_bound(name: &str, value: i64, min_val: i64, max_val: i64) -> Result<()> {
208 if value < min_val || value > max_val {
209 return Err(crate::error::Error::Execution {
210 reason: format!("variable {name}={value} is outside bounds [{min_val}, {max_val}]"),
211 });
212 }
213 Ok(())
214}
215
/// Index-based dependency graph over a slice of prepared ops.
///
/// All three vectors have one entry per op, in the same order as the input slice.
struct DependencyGraph {
    /// Id of the op at each position (used for error reporting).
    op_ids: Vec<u64>,
    /// Number of dependency edges pointing at each op.
    in_degree: Vec<usize>,
    /// For each op, the positions of the ops that depend on it.
    successors: Vec<Vec<usize>>,
}
221
222fn build_dependency_graph(ops: &[PreparedOp], instance_deps_per_op: Option<&[Vec<usize>]>) -> Result<DependencyGraph> {
223 if let Some(instance_deps) = instance_deps_per_op
224 && instance_deps.len() != ops.len()
225 {
226 return Err(crate::error::Error::Execution {
227 reason: format!(
228 "prepared op instance dependency table length mismatch: ops={}, instance_deps={}",
229 ops.len(),
230 instance_deps.len()
231 ),
232 });
233 }
234
235 let mut op_ids = Vec::with_capacity(ops.len());
236 let mut deps_per_op = Vec::with_capacity(ops.len());
237 let mut id_counts: HashMap<u64, usize> = HashMap::with_capacity(ops.len());
238
239 for op in ops {
240 let (op_id, deps) = op_identity(op);
241 op_ids.push(op_id);
242 deps_per_op.push(deps);
243 *id_counts.entry(op_id).or_insert(0) += 1;
244 }
245
246 let has_duplicate_ids = id_counts.values().any(|&count| count > 1);
247
248 let mut in_degree = vec![0usize; ops.len()];
249 let mut successors: Vec<Vec<usize>> = vec![Vec::new(); ops.len()];
250
251 if !has_duplicate_ids {
252 let mut id_to_idx: HashMap<u64, usize> = HashMap::with_capacity(ops.len());
253 for (idx, &op_id) in op_ids.iter().enumerate() {
254 id_to_idx.insert(op_id, idx);
255 }
256
257 for (idx, deps) in deps_per_op.iter().enumerate() {
258 for dep in deps {
259 let Some(&dep_idx) = id_to_idx.get(dep) else {
260 return Err(crate::error::Error::Execution {
261 reason: format!("prepared op {} depends on unknown op id {}", op_ids[idx], dep),
262 });
263 };
264 in_degree[idx] += 1;
265 successors[dep_idx].push(idx);
266 }
267 }
268 } else {
269 let mut last_seen: HashMap<u64, usize> = HashMap::with_capacity(ops.len());
272
273 for (idx, deps) in deps_per_op.iter().enumerate() {
274 for dep in deps {
275 let Some(&dep_idx) = last_seen.get(dep) else {
276 return Err(crate::error::Error::Execution {
277 reason: format!(
278 "prepared op {} depends on unknown prior op id {} (duplicate-id schedule mode)",
279 op_ids[idx], dep
280 ),
281 });
282 };
283 in_degree[idx] += 1;
284 successors[dep_idx].push(idx);
285 }
286
287 last_seen.insert(op_ids[idx], idx);
288 }
289 }
290
291 if let Some(instance_deps_per_op) = instance_deps_per_op {
292 for (idx, instance_deps) in instance_deps_per_op.iter().enumerate() {
293 for &dep_idx in instance_deps {
294 if dep_idx >= ops.len() {
295 return Err(crate::error::Error::Execution {
296 reason: format!("prepared op {} depends on unknown op index {}", op_ids[idx], dep_idx),
297 });
298 }
299 if dep_idx == idx {
300 return Err(crate::error::Error::Execution {
301 reason: format!("prepared op {} cannot depend on itself by op index {}", op_ids[idx], dep_idx),
302 });
303 }
304 in_degree[idx] += 1;
305 successors[dep_idx].push(idx);
306 }
307 }
308 }
309
310 Ok(DependencyGraph { op_ids, in_degree, successors })
311}
312
/// Test-only shortcut: topological order of `ops` using id-based dependencies only.
#[cfg(test)]
fn compute_mixed_op_order(ops: &[PreparedOp]) -> Result<Vec<usize>> {
    let no_instance_deps: &[Vec<usize>] = &[];
    compute_mixed_op_order_with_instance_dependencies(ops, no_instance_deps)
}
317
318fn compute_mixed_op_order_with_instance_dependencies(
319 ops: &[PreparedOp],
320 instance_deps_per_op: &[Vec<usize>],
321) -> Result<Vec<usize>> {
322 let instance_deps = (!instance_deps_per_op.is_empty()).then_some(instance_deps_per_op);
323 let DependencyGraph { op_ids, mut in_degree, successors } = build_dependency_graph(ops, instance_deps)?;
324
325 let mut ready: BinaryHeap<Reverse<usize>> = BinaryHeap::new();
326 for (idx, °) in in_degree.iter().enumerate() {
327 if deg == 0 {
328 ready.push(Reverse(idx));
329 }
330 }
331
332 let mut order = Vec::with_capacity(ops.len());
333 while let Some(Reverse(idx)) = ready.pop() {
334 order.push(idx);
335 for &succ in &successors[idx] {
336 in_degree[succ] -= 1;
337 if in_degree[succ] == 0 {
338 ready.push(Reverse(succ));
339 }
340 }
341 }
342
343 if order.len() != ops.len() {
344 let blocked: Vec<u64> = in_degree
345 .iter()
346 .enumerate()
347 .filter_map(|(idx, °)| if deg > 0 { Some(op_ids[idx]) } else { None })
348 .collect();
349 return Err(crate::error::Error::Execution {
350 reason: format!("cycle detected in prepared op dependencies: blocked_ids={blocked:?}"),
351 });
352 }
353
354 Ok(order)
355}
356
/// Test-only shortcut: execution levels of `ops` using id-based dependencies only.
#[cfg(test)]
fn compute_execution_levels(ops: &[PreparedOp]) -> Result<Vec<Vec<usize>>> {
    let no_instance_deps: &[Vec<usize>] = &[];
    compute_execution_levels_with_instance_dependencies(ops, no_instance_deps)
}
361
362fn compute_execution_levels_with_instance_dependencies(
363 ops: &[PreparedOp],
364 instance_deps_per_op: &[Vec<usize>],
365) -> Result<Vec<Vec<usize>>> {
366 let instance_deps = (!instance_deps_per_op.is_empty()).then_some(instance_deps_per_op);
367 let DependencyGraph { op_ids, mut in_degree, successors } = build_dependency_graph(ops, instance_deps)?;
368
369 let mut ready: BinaryHeap<Reverse<usize>> = BinaryHeap::new();
370 for (idx, °) in in_degree.iter().enumerate() {
371 if deg == 0 {
372 ready.push(Reverse(idx));
373 }
374 }
375
376 let mut levels: Vec<Vec<usize>> = Vec::new();
377 let mut visited = 0usize;
378
379 while !ready.is_empty() {
380 let mut level: Vec<usize> = Vec::new();
381 while let Some(Reverse(idx)) = ready.pop() {
382 level.push(idx);
383 }
384
385 let mut next_ready: BinaryHeap<Reverse<usize>> = BinaryHeap::new();
386 for &idx in &level {
387 visited += 1;
388 for &succ in &successors[idx] {
389 in_degree[succ] -= 1;
390 if in_degree[succ] == 0 {
391 next_ready.push(Reverse(succ));
392 }
393 }
394 }
395
396 levels.push(level);
397 ready = next_ready;
398 }
399
400 if visited != ops.len() {
401 let blocked: Vec<u64> = in_degree
402 .iter()
403 .enumerate()
404 .filter_map(|(idx, °)| if deg > 0 { Some(op_ids[idx]) } else { None })
405 .collect();
406 return Err(crate::error::Error::Execution {
407 reason: format!("cycle detected in prepared op dependencies: blocked_ids={blocked:?}"),
408 });
409 }
410
411 Ok(levels)
412}
413
/// A fully prepared, repeatedly executable schedule of ops over a set of buffers.
///
/// Instances are produced by `ExecutionPlanBuilder::build`.
pub struct ExecutionPlan {
    /// All ops of the plan, in insertion order.
    ops: Vec<PreparedOp>,

    /// Index-based dependencies, one list per op (parallel to `ops`).
    op_instance_dependencies: Vec<Vec<usize>>,

    /// Topological order of `ops` (as positions), computed at build time.
    op_order: Vec<usize>,

    /// Dependency levels of `ops` (as positions), computed at build time; `execute` walks these.
    op_levels: Vec<Vec<usize>>,

    /// Buffers referenced by the ops through their `buffer_indices`.
    buffers: Vec<Buffer>,

    /// Maps an AST id to a position in `buffers`.
    ast_to_buffer: HashMap<u64, usize>,

    /// Positions in `buffers` that hold the plan's outputs.
    output_buffer_indices: Vec<usize>,

    /// Device the plan was built for.
    device: DeviceSpec,

    /// Most recent runtime variable values supplied by the caller (never contains `"core_id"`).
    runtime_var_vals: HashMap<String, i64>,

    /// Extra ids that are passed to the removal callback when buffers are released.
    alias_ids: Vec<u64>,
}
449
450impl ExecutionPlan {
455 fn kernel_launch_sizes(kernel: &PreparedKernel) -> Result<RuntimeLaunchSizes> {
456 let mut vars: HashMap<&str, i64> =
457 HashMap::with_capacity(kernel.kernel.var_names.len() + kernel.fixedvars.len());
458 for (idx, name) in kernel.kernel.var_names.iter().enumerate() {
459 let value = kernel.vals.get(idx).copied().ok_or_else(|| crate::error::Error::Execution {
460 reason: format!(
461 "Kernel {} has {} var names but only {} values",
462 kernel.id,
463 kernel.kernel.var_names.len(),
464 kernel.vals.len()
465 ),
466 })?;
467 vars.insert(name.as_str(), value);
468 }
469 for (name, value) in &kernel.fixedvars {
470 vars.insert(name.as_str(), *value);
471 }
472
473 let dims =
474 ProgramSpec::resolve_launch_dims(&kernel.kernel.global_size, kernel.kernel.local_size.as_ref(), &vars)
475 .map_err(|e| crate::error::Error::Execution {
476 reason: format!("Kernel {} launch dimensions failed: {e}", kernel.id),
477 })?;
478 Ok((Some(dims.global_size), dims.local_size))
479 }
480
481 fn kernel_uses_cpu_threading(kernel: &PreparedKernel) -> Result<bool> {
482 if !matches!(kernel.device, DeviceSpec::Cpu) {
483 return Ok(false);
484 }
485 let (global_size, _) = Self::kernel_launch_sizes(kernel)?;
486 Ok(global_size.map(|[x, _, _]| x > 1).unwrap_or(false))
487 }
488
489 #[inline]
490 fn execute_kernel(kernel: &PreparedKernel) -> Result<()> {
491 let buffer_ptrs: SmallVec<[*mut u8; 8]> = kernel.buffer_ptrs.iter().map(|&ptr| ptr as *mut u8).collect();
492 let (global_size, local_size) = Self::kernel_launch_sizes(kernel)?;
493 unsafe {
494 kernel
495 .kernel
496 .program
497 .execute(&buffer_ptrs, &kernel.vals, global_size, local_size)
498 .map_err(|e| crate::error::Error::Execution { reason: format!("Kernel {} failed: {}", kernel.id, e) })
499 }
500 }
501
502 fn validate_runtime_var_bounds(&self, var_vals: &[(&str, i64)]) -> Result<()> {
503 let vals_map: HashMap<&str, i64> = var_vals.iter().copied().collect();
504 for op in &self.ops {
505 match op {
506 PreparedOp::CompiledProgram(kernel) => {
507 for var in &kernel.runtime_vars {
508 if kernel.fixedvars.contains_key(&var.name) || var.name == "core_id" {
509 continue;
510 }
511 if let Some(&value) = vals_map.get(var.name.as_str()) {
512 validate_var_bound(&var.name, value, var.min_val, var.max_val)?;
513 }
514 }
515 }
516 PreparedOp::CustomFunction(custom) => {
517 for var in &custom.runtime_vars {
518 if custom.fixedvars.contains_key(&var.name) || var.name == "core_id" {
519 continue;
520 }
521 if let Some(&value) = vals_map.get(var.name.as_str()) {
522 validate_var_bound(&var.name, value, var.min_val, var.max_val)?;
523 }
524 }
525 }
526 PreparedOp::BufferCopy(_) | PreparedOp::BufferView(_) => {}
527 }
528 }
529 Ok(())
530 }
531
532 fn update_runtime_var_vals(&mut self, var_vals: &[(&str, i64)]) -> Result<()> {
533 self.validate_runtime_var_bounds(var_vals)?;
534
535 let vals_map: HashMap<&str, i64> = var_vals.iter().copied().collect();
536 for &(name, value) in var_vals {
537 if name == "core_id" {
538 continue;
539 }
540 self.runtime_var_vals.insert(name.to_string(), value);
541 }
542 for op in &mut self.ops {
543 if let PreparedOp::CompiledProgram(kernel) = op {
544 for (idx, name) in kernel.kernel.var_names.iter().enumerate() {
545 if kernel.fixedvars.contains_key(name) || name == "core_id" {
546 continue;
547 }
548 if let Some(&v) = vals_map.get(name.as_str()) {
549 let Some(slot) = kernel.vals.get_mut(idx) else {
550 return Err(crate::error::Error::Execution {
551 reason: format!(
552 "Kernel {} has {} var names but only {} values",
553 kernel.id,
554 kernel.kernel.var_names.len(),
555 kernel.vals.len()
556 ),
557 });
558 };
559 *slot = v;
560 }
561 }
562 }
563 }
564 Ok(())
565 }
566
567 #[inline]
568 fn execute_copy(&self, copy: &PreparedCopy) -> Result<()> {
569 if copy.buffer_indices.len() < 2 {
570 return Err(crate::error::Error::Execution {
571 reason: format!(
572 "Copy op {} requires at least two buffer indices (dst, src), got {}",
573 copy.id,
574 copy.buffer_indices.len()
575 ),
576 });
577 }
578 let dst_idx = copy.buffer_indices[0];
579 let src_idx = copy.buffer_indices[1];
580
581 if dst_idx >= self.buffers.len() || src_idx >= self.buffers.len() {
582 return Err(crate::error::Error::Execution {
583 reason: format!(
584 "Copy op {} buffer index out of range: dst={}, src={}, total_buffers={}",
585 copy.id,
586 dst_idx,
587 src_idx,
588 self.buffers.len()
589 ),
590 });
591 }
592
593 let mut dst = self.buffers[dst_idx].clone();
594 let src = &self.buffers[src_idx];
595 dst.copy_from(src)
596 .map_err(|e| crate::error::Error::Execution { reason: format!("Copy op {} failed: {}", copy.id, e) })
597 }
598
599 #[inline]
600 fn execute_buffer_view(&self, view: &PreparedBufferView) -> Result<()> {
601 if view.buffer_indices.len() < 2 {
602 return Err(crate::error::Error::Execution {
603 reason: format!(
604 "BufferView op {} requires at least two buffer indices (out, base), got {}",
605 view.id,
606 view.buffer_indices.len()
607 ),
608 });
609 }
610 let out_idx = view.buffer_indices[0];
611 let base_idx = view.buffer_indices[1];
612
613 if out_idx >= self.buffers.len() || base_idx >= self.buffers.len() {
614 return Err(crate::error::Error::Execution {
615 reason: format!(
616 "BufferView op {} buffer index out of range: out={}, base={}, total_buffers={}",
617 view.id,
618 out_idx,
619 base_idx,
620 self.buffers.len()
621 ),
622 });
623 }
624
625 let out = &self.buffers[out_idx];
626 let base = &self.buffers[base_idx];
627 let expected_offset = base.offset() + view.byte_offset;
628
629 if out.storage_id() != base.storage_id() || out.offset() != expected_offset || out.size() != view.byte_size {
630 return Err(crate::error::Error::Execution {
631 reason: format!(
632 "BufferView op {} mismatch: out(storage={:?},off={},size={}) base(storage={:?},off={}) expected(off={},size={})",
633 view.id,
634 out.storage_id(),
635 out.offset(),
636 out.size(),
637 base.storage_id(),
638 base.offset(),
639 expected_offset,
640 view.byte_size,
641 ),
642 });
643 }
644 Ok(())
645 }
646
647 #[inline]
648 fn execute_custom_function(&self, custom: &PreparedCustomFunction) -> Result<()> {
649 let mut buffers = Vec::with_capacity(custom.buffer_indices.len());
650 for &idx in &custom.buffer_indices {
651 let Some(buffer) = self.buffers.get(idx) else {
652 return Err(crate::error::Error::Execution {
653 reason: format!(
654 "Custom function op {} ({:?}) buffer index out of range: idx={}, total_buffers={}",
655 custom.id,
656 custom.kind,
657 idx,
658 self.buffers.len()
659 ),
660 });
661 };
662 buffers.push(buffer.clone());
663 }
664
665 let mut vars = self.runtime_var_vals.clone();
666 vars.extend(custom.fixedvars.iter().map(|(k, v)| (k.clone(), *v)));
667
668 crate::custom_function::run_custom_function(&custom.kind, &custom.attrs, &mut buffers, &vars).map_err(|e| {
669 match e {
672 crate::error::Error::Unsupported { .. } => e,
673 other => crate::error::Error::Execution {
674 reason: format!("Custom function op {} ({:?}) failed: {other}", custom.id, custom.kind),
675 },
676 }
677 })
678 }
679
680 #[inline]
681 fn execute_op(&self, op: &PreparedOp) -> Result<()> {
682 match op {
683 PreparedOp::CompiledProgram(kernel) => Self::execute_kernel(kernel),
684 PreparedOp::BufferCopy(copy) => self.execute_copy(copy),
685 PreparedOp::BufferView(view) => self.execute_buffer_view(view),
686 PreparedOp::CustomFunction(custom) => self.execute_custom_function(custom),
687 }
688 }
689
690 #[inline]
691 fn op_requires_serial(op: &PreparedOp) -> bool {
692 match op {
693 PreparedOp::CompiledProgram(kernel) => !kernel.kernel.host_parallel_safe,
694 PreparedOp::BufferCopy(_) | PreparedOp::BufferView(_) | PreparedOp::CustomFunction(_) => true,
695 }
696 }
697
698 #[inline]
699 fn compiled_kernel_at(&self, idx: usize) -> Option<&PreparedKernel> {
700 match &self.ops[idx] {
701 PreparedOp::CompiledProgram(kernel) => Some(kernel),
702 _ => None,
703 }
704 }
705
706 fn kernels_conflict(lhs: &PreparedKernel, rhs: &PreparedKernel) -> bool {
707 let lhs_outputs: HashSet<BufferId> =
708 lhs.output_indices.iter().filter_map(|&out_idx| lhs.buffer_ids.get(out_idx).copied()).collect();
709 let rhs_outputs: HashSet<BufferId> =
710 rhs.output_indices.iter().filter_map(|&out_idx| rhs.buffer_ids.get(out_idx).copied()).collect();
711
712 if !lhs_outputs.is_disjoint(&rhs_outputs) {
713 return true;
714 }
715
716 let lhs_reads: HashSet<BufferId> = lhs
717 .buffer_ids
718 .iter()
719 .enumerate()
720 .filter_map(|(idx, &buf)| (!lhs.output_indices.contains(&idx)).then_some(buf))
721 .collect();
722 let rhs_reads: HashSet<BufferId> = rhs
723 .buffer_ids
724 .iter()
725 .enumerate()
726 .filter_map(|(idx, &buf)| (!rhs.output_indices.contains(&idx)).then_some(buf))
727 .collect();
728
729 !lhs_outputs.is_disjoint(&rhs_reads) || !rhs_outputs.is_disjoint(&lhs_reads)
730 }
731
732 fn partition_parallel_safe_group(&self, indices: &[usize]) -> Result<Vec<Vec<usize>>> {
733 let mut groups: Vec<Vec<usize>> = Vec::new();
734
735 for &idx in indices {
736 let Some(kernel) = self.compiled_kernel_at(idx) else {
737 return Err(crate::error::Error::Execution {
738 reason: format!("parallel partition expected compiled kernel at op index {idx}"),
739 });
740 };
741
742 let mut placed = false;
743 for group in &mut groups {
744 let has_conflict = group.iter().any(|&existing_idx| {
745 self.compiled_kernel_at(existing_idx)
746 .map(|existing| Self::kernels_conflict(existing, kernel))
747 .unwrap_or(true)
748 });
749 if !has_conflict {
750 group.push(idx);
751 placed = true;
752 break;
753 }
754 }
755
756 if !placed {
757 groups.push(vec![idx]);
758 }
759 }
760
761 Ok(groups)
762 }
763
764 fn execute_parallel_group(&self, indices: &[usize]) -> Result<()> {
765 if indices.len() <= 1 {
766 if let Some(&idx) = indices.first() {
767 self.execute_op(&self.ops[idx])?;
768 }
769 return Ok(());
770 }
771
772 let has_threaded_cpu_kernel = indices.iter().try_fold(false, |acc, &idx| {
773 let Some(kernel) = self.compiled_kernel_at(idx) else {
774 return Err(crate::error::Error::Execution {
775 reason: format!("parallel execution expected compiled kernel at op index {idx}"),
776 });
777 };
778 Ok(acc || Self::kernel_uses_cpu_threading(kernel)?)
779 })?;
780
781 if has_threaded_cpu_kernel {
782 for &idx in indices {
783 let Some(kernel) = self.compiled_kernel_at(idx) else {
784 return Err(crate::error::Error::Execution {
785 reason: format!("parallel execution expected compiled kernel at op index {idx}"),
786 });
787 };
788 Self::execute_kernel(kernel)?;
789 }
790 return Ok(());
791 }
792
793 indices
794 .par_iter()
795 .map(|&idx| {
796 let Some(kernel) = self.compiled_kernel_at(idx) else {
797 return Err(crate::error::Error::Execution {
798 reason: format!("parallel execution expected compiled kernel at op index {idx}"),
799 });
800 };
801 Self::execute_kernel(kernel)
802 })
803 .collect::<Result<Vec<_>>>()?;
804
805 Ok(())
806 }
807
808 fn execute_parallel_group_profiled(&self, indices: &[usize]) -> Result<Vec<(usize, KernelProfile)>> {
809 if indices.len() <= 1 {
810 let mut profiles = Vec::new();
811 if let Some(&idx) = indices.first() {
812 let Some(kernel) = self.compiled_kernel_at(idx) else {
813 return Err(crate::error::Error::Execution {
814 reason: format!("profiled execution expected compiled kernel at op index {idx}"),
815 });
816 };
817 let start = Instant::now();
818 Self::execute_kernel(kernel)?;
819 profiles.push((
820 idx,
821 KernelProfile {
822 kernel: Arc::clone(&kernel.kernel),
823 device: kernel.device.clone(),
824 num_buffers: kernel.buffer_ptrs.len(),
825 elapsed: start.elapsed(),
826 },
827 ));
828 }
829 return Ok(profiles);
830 }
831
832 let has_threaded_cpu_kernel = indices.iter().try_fold(false, |acc, &idx| {
833 let Some(kernel) = self.compiled_kernel_at(idx) else {
834 return Err(crate::error::Error::Execution {
835 reason: format!("profiled execution expected compiled kernel at op index {idx}"),
836 });
837 };
838 Ok(acc || Self::kernel_uses_cpu_threading(kernel)?)
839 })?;
840
841 if has_threaded_cpu_kernel {
842 let mut profiles = Vec::with_capacity(indices.len());
843 for &idx in indices {
844 let Some(kernel) = self.compiled_kernel_at(idx) else {
845 return Err(crate::error::Error::Execution {
846 reason: format!("profiled execution expected compiled kernel at op index {idx}"),
847 });
848 };
849 let start = Instant::now();
850 Self::execute_kernel(kernel)?;
851 profiles.push((
852 idx,
853 KernelProfile {
854 kernel: Arc::clone(&kernel.kernel),
855 device: kernel.device.clone(),
856 num_buffers: kernel.buffer_ptrs.len(),
857 elapsed: start.elapsed(),
858 },
859 ));
860 }
861 return Ok(profiles);
862 }
863
864 let mut profiles = indices
865 .par_iter()
866 .map(|&idx| {
867 let Some(kernel) = self.compiled_kernel_at(idx) else {
868 return Err(crate::error::Error::Execution {
869 reason: format!("profiled execution expected compiled kernel at op index {idx}"),
870 });
871 };
872 let start = Instant::now();
873 Self::execute_kernel(kernel)?;
874 Ok((
875 idx,
876 KernelProfile {
877 kernel: Arc::clone(&kernel.kernel),
878 device: kernel.device.clone(),
879 num_buffers: kernel.buffer_ptrs.len(),
880 elapsed: start.elapsed(),
881 },
882 ))
883 })
884 .collect::<Result<Vec<_>>>()?;
885
886 profiles.sort_by_key(|(idx, _)| *idx);
887 Ok(profiles)
888 }
889
890 pub fn output_buffer(&self) -> Option<&Buffer> {
895 self.output_buffer_indices.first().and_then(|&i| self.buffers.get(i))
896 }
897
898 pub fn output_buffer_at(&self, position: usize) -> Option<&Buffer> {
902 self.output_buffer_indices.get(position).and_then(|&i| self.buffers.get(i))
903 }
904
905 pub fn output_buffers(&self) -> Vec<&Buffer> {
907 self.output_buffer_indices.iter().map(|&i| &self.buffers[i]).collect()
908 }
909
    /// Returns the number of registered output buffers.
    pub fn num_outputs(&self) -> usize {
        self.output_buffer_indices.len()
    }
914
915 pub fn buffer(&self, ast_id: u64) -> Option<&Buffer> {
917 self.ast_to_buffer.get(&ast_id).map(|&idx| &self.buffers[idx])
918 }
919
920 pub fn buffer_mut_by_id(&mut self, ast_id: u64) -> Option<&mut Buffer> {
922 self.ast_to_buffer.get(&ast_id).copied().map(|idx| &mut self.buffers[idx])
923 }
924
    /// Returns the device this plan was built for.
    pub fn device(&self) -> &DeviceSpec {
        &self.device
    }
929
    /// Returns all buffers of the plan, in builder insertion order.
    pub fn buffers(&self) -> &[Buffer] {
        &self.buffers
    }
934
    /// Returns all buffers of the plan for mutation.
    ///
    /// NOTE(review): compiled kernels keep buffer addresses captured at build time;
    /// presumably replacing a buffer through this slice does not refresh them — confirm.
    pub fn buffers_mut(&mut self) -> &mut [Buffer] {
        &mut self.buffers
    }
939
    /// Returns the buffer at position `index` for mutation, or `None` when out of range.
    pub fn buffer_at_mut(&mut self, index: usize) -> Option<&mut Buffer> {
        self.buffers.get_mut(index)
    }
944
945 pub fn prepared_kernels(&self) -> Vec<&PreparedKernel> {
947 self.ops
948 .iter()
949 .filter_map(|op| match op {
950 PreparedOp::CompiledProgram(kernel) => Some(kernel),
951 _ => None,
952 })
953 .collect()
954 }
955
    /// Returns every op of the plan, in insertion order.
    pub fn prepared_ops(&self) -> &[PreparedOp] {
        &self.ops
    }
960
961 pub fn kernels(&self) -> impl Iterator<Item = &CachedKernel> {
963 self.ops.iter().filter_map(|op| match op {
964 PreparedOp::CompiledProgram(kernel) => Some(kernel.kernel.as_ref()),
965 _ => None,
966 })
967 }
968
969 pub fn execute(&self) -> Result<()> {
973 for level in &self.op_levels {
974 let mut pending_parallel: Vec<usize> = Vec::new();
975
976 for &idx in level {
977 let op = &self.ops[idx];
978 if Self::op_requires_serial(op) {
979 if !pending_parallel.is_empty() {
980 let groups = self.partition_parallel_safe_group(&pending_parallel)?;
981 for group in groups {
982 self.execute_parallel_group(&group)?;
983 }
984 pending_parallel.clear();
985 }
986 self.execute_op(op)?;
987 } else {
988 pending_parallel.push(idx);
989 }
990 }
991
992 if !pending_parallel.is_empty() {
993 let groups = self.partition_parallel_safe_group(&pending_parallel)?;
994 for group in groups {
995 self.execute_parallel_group(&group)?;
996 }
997 }
998 }
999 Ok(())
1000 }
1001
1002 pub fn execute_profiled(&self) -> Result<Vec<KernelProfile>> {
1020 let mut profiles = Vec::new();
1021 for level in &self.op_levels {
1022 let mut pending_parallel: Vec<usize> = Vec::new();
1023
1024 for &idx in level {
1025 match &self.ops[idx] {
1026 PreparedOp::CompiledProgram(kernel) if kernel.kernel.host_parallel_safe => {
1027 pending_parallel.push(idx);
1028 }
1029 PreparedOp::CompiledProgram(kernel) => {
1030 if !pending_parallel.is_empty() {
1031 let groups = self.partition_parallel_safe_group(&pending_parallel)?;
1032 for group in groups {
1033 let mut prof = self.execute_parallel_group_profiled(&group)?;
1034 profiles.extend(prof.drain(..).map(|(_, p)| p));
1035 }
1036 pending_parallel.clear();
1037 }
1038
1039 let start = Instant::now();
1040 Self::execute_kernel(kernel)?;
1041 profiles.push(KernelProfile {
1042 kernel: Arc::clone(&kernel.kernel),
1043 device: kernel.device.clone(),
1044 num_buffers: kernel.buffer_ptrs.len(),
1045 elapsed: start.elapsed(),
1046 });
1047 }
1048 PreparedOp::BufferCopy(copy) => {
1049 if !pending_parallel.is_empty() {
1050 let groups = self.partition_parallel_safe_group(&pending_parallel)?;
1051 for group in groups {
1052 let mut prof = self.execute_parallel_group_profiled(&group)?;
1053 profiles.extend(prof.drain(..).map(|(_, p)| p));
1054 }
1055 pending_parallel.clear();
1056 }
1057 self.execute_copy(copy)?;
1058 }
1059 PreparedOp::BufferView(view) => {
1060 if !pending_parallel.is_empty() {
1061 let groups = self.partition_parallel_safe_group(&pending_parallel)?;
1062 for group in groups {
1063 let mut prof = self.execute_parallel_group_profiled(&group)?;
1064 profiles.extend(prof.drain(..).map(|(_, p)| p));
1065 }
1066 pending_parallel.clear();
1067 }
1068 self.execute_buffer_view(view)?;
1069 }
1070 PreparedOp::CustomFunction(custom) => {
1071 if !pending_parallel.is_empty() {
1072 let groups = self.partition_parallel_safe_group(&pending_parallel)?;
1073 for group in groups {
1074 let mut prof = self.execute_parallel_group_profiled(&group)?;
1075 profiles.extend(prof.drain(..).map(|(_, p)| p));
1076 }
1077 pending_parallel.clear();
1078 }
1079 self.execute_custom_function(custom)?;
1080 }
1081 }
1082 }
1083
1084 if !pending_parallel.is_empty() {
1085 let groups = self.partition_parallel_safe_group(&pending_parallel)?;
1086 for group in groups {
1087 let mut prof = self.execute_parallel_group_profiled(&group)?;
1088 profiles.extend(prof.drain(..).map(|(_, p)| p));
1089 }
1090 }
1091 }
1092 Ok(profiles)
1093 }
1094
    /// Applies the given runtime variable values and then executes the plan.
    ///
    /// The values are bounds-checked and written into the plan and its kernels
    /// before any op runs.
    ///
    /// # Errors
    /// Returns an error when a value is out of bounds, when a kernel has fewer
    /// value slots than variable names, or when execution fails.
    pub fn execute_with_vars(&mut self, var_vals: &[(&str, i64)]) -> Result<()> {
        self.update_runtime_var_vals(var_vals)?;
        self.execute()
    }
1115
    /// Applies the given runtime variable values and then executes the plan with profiling.
    ///
    /// # Errors
    /// Returns an error when a value is out of bounds, when a kernel has fewer
    /// value slots than variable names, or when execution fails.
    pub fn execute_with_vars_profiled(&mut self, var_vals: &[(&str, i64)]) -> Result<Vec<KernelProfile>> {
        self.update_runtime_var_vals(var_vals)?;
        self.execute_profiled()
    }
1124
    /// Returns the buffer-list position of the first output buffer.
    ///
    /// # Panics
    /// Panics when the plan has no output buffer indices.
    pub fn output_buffer_idx(&self) -> usize {
        self.output_buffer_indices[0]
    }
1129
    /// Returns the mapping from AST id to buffer-list position.
    pub fn ast_to_buffer_map(&self) -> &HashMap<u64, usize> {
        &self.ast_to_buffer
    }
1134
    /// Invokes `remove_fn` for every mapped AST id whose buffer is not an output
    /// buffer, and then for every alias id.
    pub fn release_intermediate_buffers<F>(&self, remove_fn: F)
    where
        F: Fn(u64),
    {
        self.release_buffers_impl(remove_fn, true);
    }
1145
    /// Invokes `remove_fn` for every mapped AST id (outputs included), and then
    /// for every alias id.
    pub fn release_all_buffers<F>(&self, remove_fn: F)
    where
        F: Fn(u64),
    {
        self.release_buffers_impl(remove_fn, false);
    }
1153
1154 fn release_buffers_impl<F>(&self, remove_fn: F, skip_output: bool)
1155 where
1156 F: Fn(u64),
1157 {
1158 let output_buf_ids: std::collections::HashSet<u64> = if skip_output {
1159 self.output_buffer_indices.iter().filter_map(|&idx| self.buffers.get(idx).map(|b| b.id().0)).collect()
1160 } else {
1161 std::collections::HashSet::new()
1162 };
1163
1164 for (&ast_id, &buf_idx) in &self.ast_to_buffer {
1165 if skip_output && output_buf_ids.contains(&self.buffers[buf_idx].id().0) {
1166 continue;
1167 }
1168 remove_fn(ast_id);
1169 }
1170
1171 for &alias_id in &self.alias_ids {
1172 remove_fn(alias_id);
1173 }
1174 }
1175}
1176
1177impl std::fmt::Debug for ExecutionPlan {
1178 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1179 let kernel_count = self.ops.iter().filter(|op| matches!(op, PreparedOp::CompiledProgram(_))).count();
1180 f.debug_struct("ExecutionPlan")
1181 .field("ops", &self.ops.len())
1182 .field("op_instance_dependencies", &self.op_instance_dependencies.len())
1183 .field("op_order", &self.op_order.len())
1184 .field("kernels", &kernel_count)
1185 .field("buffers", &self.buffers.len())
1186 .field("device", &self.device)
1187 .finish()
1188 }
1189}
1190
1191impl std::fmt::Debug for PreparedKernel {
1192 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1193 f.debug_struct("PreparedKernel")
1194 .field("id", &self.id)
1195 .field("device", &self.device)
1196 .field("buffer_indices", &self.buffer_indices)
1197 .field("output_indices", &self.output_indices)
1198 .field("vals", &self.vals)
1199 .field("fixedvars", &self.fixedvars)
1200 .field("dependencies", &self.dependencies)
1201 .finish()
1202 }
1203}
1204
/// Incrementally collects ops, buffers and mappings and turns them into an `ExecutionPlan`.
pub struct ExecutionPlanBuilder {
    /// Ops added so far, in insertion order.
    ops: Vec<PreparedOp>,
    /// Index-based dependencies, one list per op (parallel to `ops`).
    op_instance_dependencies: Vec<Vec<usize>>,
    /// Buffers added so far; ops refer to them by position.
    buffers: Vec<Buffer>,
    /// Maps an AST id to a position in `buffers`.
    ast_to_buffer: HashMap<u64, usize>,
    /// Positions in `buffers` that hold the plan's outputs.
    output_buffer_indices: Vec<usize>,
    /// Device the plan is built for.
    device: DeviceSpec,
    /// Extra ids handed to the removal callback when the plan's buffers are released.
    alias_ids: Vec<u64>,
}
1219
1220impl ExecutionPlanBuilder {
    /// Creates an empty builder for `device`: no ops, buffers, mappings, outputs or alias ids.
    pub fn new(device: DeviceSpec) -> Self {
        Self {
            ops: Vec::new(),
            op_instance_dependencies: Vec::new(),
            buffers: Vec::new(),
            ast_to_buffer: HashMap::new(),
            output_buffer_indices: Vec::new(),
            device,
            alias_ids: Vec::new(),
        }
    }
1233
    /// Appends `ids` to the list of alias ids carried into the built plan.
    pub fn add_alias_ids(&mut self, ids: impl IntoIterator<Item = u64>) {
        self.alias_ids.extend(ids);
    }
1238
1239 pub fn add_buffer(&mut self, ast_id: u64, buffer: Buffer) -> usize {
1241 let idx = self.buffers.len();
1242 self.buffers.push(buffer);
1243 self.ast_to_buffer.insert(ast_id, idx);
1244 idx
1245 }
1246
    /// Maps `ast_id` to the buffer at position `idx`, replacing any existing mapping.
    ///
    /// `idx` is stored as given; it is not checked against the buffer list.
    pub fn map_buffer(&mut self, ast_id: u64, idx: usize) {
        self.ast_to_buffer.insert(ast_id, idx);
    }
1251
    /// Replaces the buffer at position `idx` with `buffer`.
    ///
    /// # Panics
    /// Panics when `idx` is not a valid position in the buffer list.
    pub fn replace_buffer(&mut self, idx: usize, buffer: Buffer) {
        self.buffers[idx] = buffer;
    }
1256
    /// Declares the buffer at position `idx` as the plan's single output,
    /// discarding any previously declared outputs.
    pub fn set_output_buffer(&mut self, idx: usize) {
        self.output_buffer_indices = vec![idx];
    }
1261
    /// Declares the buffers at positions `indices` as the plan's outputs,
    /// discarding any previously declared outputs.
    pub fn set_output_buffers(&mut self, indices: Vec<usize>) {
        self.output_buffer_indices = indices;
    }
1266
    /// Adds a compiled kernel as an op without index-based dependencies.
    pub fn add_kernel(&mut self, kernel: PreparedKernel) {
        self.add_op(PreparedOp::CompiledProgram(kernel));
    }
1273
    /// Adds `op` without index-based dependencies (its id-based `dependencies` still apply).
    pub fn add_op(&mut self, op: PreparedOp) {
        self.add_op_with_instance_dependencies(op, Vec::new());
    }
1278
    /// Adds `op` together with dependencies expressed as positions of other ops.
    ///
    /// The positions are validated only when the plan is built.
    pub fn add_op_with_instance_dependencies(&mut self, op: PreparedOp, instance_dependencies: Vec<usize>) {
        self.ops.push(op);
        self.op_instance_dependencies.push(instance_dependencies);
    }
1284
    /// Validates the collected ops and produces the final `ExecutionPlan`.
    ///
    /// For every compiled kernel the buffer indices are resolved and the raw buffer
    /// addresses and buffer ids are stored on the kernel, overwriting whatever the
    /// caller had put there. The topological order and the dependency levels are
    /// then computed from the ops and their dependencies.
    ///
    /// # Errors
    /// Returns an `Execution` error when a kernel has no output indices, when an
    /// output or buffer index is out of range, when buffers exist but no output
    /// was declared, or when the dependencies are invalid or cyclic.
    pub fn build(mut self) -> Result<ExecutionPlan> {
        for op in &mut self.ops {
            // Only compiled kernels carry cached buffer addresses.
            let PreparedOp::CompiledProgram(kernel) = op else {
                continue;
            };

            if kernel.output_indices.is_empty() {
                return Err(crate::error::Error::Execution {
                    reason: format!("CompiledProgram {} has no output indices", kernel.id),
                });
            }
            // Output indices are positions within the kernel's own argument list.
            for &out_idx in &kernel.output_indices {
                if out_idx >= kernel.buffer_indices.len() {
                    return Err(crate::error::Error::Execution {
                        reason: format!(
                            "CompiledProgram {} output index out of range: output_idx={}, kernel_buffers={}",
                            kernel.id,
                            out_idx,
                            kernel.buffer_indices.len()
                        ),
                    });
                }
            }

            let mut buffer_ptrs = Vec::with_capacity(kernel.buffer_indices.len());
            let mut buffer_ids = Vec::with_capacity(kernel.buffer_indices.len());

            for &idx in &kernel.buffer_indices {
                let Some(buffer) = self.buffers.get(idx) else {
                    return Err(crate::error::Error::Execution {
                        reason: format!(
                            "CompiledProgram {} buffer index out of range: idx={}, total_buffers={}",
                            kernel.id,
                            idx,
                            self.buffers.len()
                        ),
                    });
                };
                // NOTE(review): the address is stored as a plain integer and reused on every
                // execution; presumably it stays valid for as long as the plan owns the
                // buffer — confirm against `Buffer::as_raw_ptr`'s contract.
                buffer_ptrs.push(unsafe { buffer.as_raw_ptr() } as usize);
                buffer_ids.push(buffer.id());
            }

            kernel.buffer_ptrs = buffer_ptrs;
            kernel.buffer_ids = buffer_ids;
        }

        // A plan without any buffers may legitimately have no outputs.
        if self.output_buffer_indices.is_empty() && !self.buffers.is_empty() {
            return Err(crate::error::Error::Execution {
                reason: "execution plan output buffers must be set explicitly".to_string(),
            });
        }

        let op_order = compute_mixed_op_order_with_instance_dependencies(&self.ops, &self.op_instance_dependencies)?;
        let op_levels = compute_execution_levels_with_instance_dependencies(&self.ops, &self.op_instance_dependencies)?;

        Ok(ExecutionPlan {
            ops: self.ops,
            op_instance_dependencies: self.op_instance_dependencies,
            op_order,
            op_levels,
            buffers: self.buffers,
            ast_to_buffer: self.ast_to_buffer,
            output_buffer_indices: self.output_buffer_indices,
            device: self.device,
            // Runtime values are only known once the caller supplies them at execution time.
            runtime_var_vals: HashMap::new(),
            alias_ids: self.alias_ids,
        })
    }
1357}
1358
1359#[cfg(test)]
1360#[path = "test/unit/execution_plan.rs"]
1361mod tests;