1use std::collections::{HashMap, HashSet};
30use std::hash::{Hash, Hasher};
31
32use svod_schedule::{Scheduler, apply_post_optimization_with_renderer, beam_search_cached, prepare_scheduler};
33use tracing::{debug, trace};
34
35use crate::{
36 PrepareConfig, Result, Tensor,
37 error::{
38 BatchOutputMismatchSnafu, CompileKernelSnafu, CreateProgramSnafu, DeviceSnafu, EmptyScheduleSnafu,
39 ExecutionSnafu, IrConstructionSnafu, OptimizeSnafu, RangeifySnafu, RenderKernelSnafu, ShapeUnknownSnafu,
40 UOpSnafu,
41 },
42 schedule::ScheduleItem,
43};
44use snafu::{OptionExt, ResultExt};
45use std::sync::Arc;
46use std::time::Duration;
47use svod_device::{Buffer, device::Device};
48use svod_ir::pattern::is_any_const;
49use svod_ir::{DeviceSpec, Op, UOp, UOpKey};
50use svod_runtime::{
51 ExecutionPlan, ExecutionPlanBuilder, PreparedBufferView, PreparedCopy, PreparedCustomFunction, PreparedKernel,
52 PreparedOp,
53};
54
55fn collect_pending_indices(tensors: &[&mut Tensor]) -> Vec<usize> {
56 tensors
57 .iter()
58 .enumerate()
59 .filter(|(_, t)| !t.uop().has_buffer_identity() && !is_any_const(&t.uop()) && !t.has_zero_elements())
60 .map(|(i, _)| i)
61 .collect()
62}
63
64#[derive(Debug, Clone, PartialEq, Eq, Hash)]
65struct BufferStorageKey {
66 id: u64,
67 offset: usize,
68 size: usize,
69 dtype: svod_dtype::DType,
70}
71
72impl Tensor {
73 pub fn realize(&mut self) -> Result<()> {
97 if self.uop().has_buffer_identity() {
98 self.ensure_buffer();
99 return Ok(());
100 }
101 if is_any_const(&self.uop()) {
103 let contiguous_uop = self.uop().contiguous();
104 self.set_uop(contiguous_uop);
105 }
106 if self.has_zero_elements() {
107 return Ok(());
108 }
109
110 let old_uop = self.uop();
111 let input_buffer_ids: HashSet<u64> = collect_input_buffers(&old_uop).keys().copied().collect();
112
113 let t_prep = std::time::Instant::now();
114 let plan = self.prepare()?;
115 let prep_ms = t_prep.elapsed().as_millis();
116 let t_exec = std::time::Instant::now();
117 plan.execute().context(ExecutionSnafu)?;
118 let exec_ms = t_exec.elapsed().as_millis();
119 debug!(prep_ms, exec_ms, "realize complete");
120
121 self.finalize_realize(&plan, &old_uop)?;
122
123 let realized_uop = self.uop();
124 if !Arc::ptr_eq(&old_uop, &realized_uop) {
125 #[allow(clippy::mutable_key_type)]
126 let becomes_map = HashMap::from([(UOpKey(old_uop), realized_uop)]);
127 crate::tensor_registry::apply_map_to_tensors(&becomes_map);
128 }
129
130 plan.release_intermediate_buffers(|uop_id| {
131 if !input_buffer_ids.contains(&uop_id) {
132 crate::tensor_registry::remove_buffer(uop_id);
133 }
134 });
135
136 Ok(())
137 }
138
139 pub fn realize_with(&mut self, config: &PrepareConfig) -> Result<()> {
159 if self.uop().has_buffer_identity() {
160 self.ensure_buffer();
161 return Ok(());
162 }
163 if is_any_const(&self.uop()) {
165 let contiguous_uop = self.uop().contiguous();
166 self.set_uop(contiguous_uop);
167 }
168 if self.has_zero_elements() {
169 return Ok(());
170 }
171
172 let old_uop = self.uop();
173 let input_buffer_ids: HashSet<u64> = collect_input_buffers(&old_uop).keys().copied().collect();
174
175 let t_prep = std::time::Instant::now();
176 let plan = self.prepare_with(config)?;
177 let prep_ms = t_prep.elapsed().as_millis();
178 let t_exec = std::time::Instant::now();
179 plan.execute().context(ExecutionSnafu)?;
180 let exec_ms = t_exec.elapsed().as_millis();
181 debug!(prep_ms, exec_ms, "realize_with complete");
182
183 self.finalize_realize(&plan, &old_uop)?;
184
185 let realized_uop = self.uop();
186 if !Arc::ptr_eq(&old_uop, &realized_uop) {
187 #[allow(clippy::mutable_key_type)]
188 let becomes_map = HashMap::from([(UOpKey(old_uop), realized_uop)]);
189 crate::tensor_registry::apply_map_to_tensors(&becomes_map);
190 }
191
192 plan.release_intermediate_buffers(|uop_id| {
193 if !input_buffer_ids.contains(&uop_id) {
194 crate::tensor_registry::remove_buffer(uop_id);
195 }
196 });
197
198 Ok(())
199 }
200
201 fn finalize_realize(&mut self, plan: &ExecutionPlan, uop: &Arc<UOp>) -> Result<()> {
207 let output_buf = plan.output_buffer().expect("realized plan must have an output buffer").clone();
208
209 trace!(
210 buffer.id = ?output_buf.id(),
211 buffer.size = output_buf.size(),
212 "Realized output buffer"
213 );
214
215 let output_dtype = uop.dtype();
216 let output_device = output_buf.allocator().device_spec();
217 let num_elements = output_buf.size() / output_dtype.bytes();
218
219 let buffer_uop = UOp::new_buffer(output_device, num_elements, output_dtype.clone());
220 let output_buf_arc = Arc::new(output_buf);
221
222 crate::tensor_registry::register_buffer(buffer_uop.id, self.entry.id, output_buf_arc.clone());
223
224 let shape = uop.shape().context(UOpSnafu)?.context(ShapeUnknownSnafu)?;
225 let realized_uop = buffer_uop.try_reshape(shape).context(UOpSnafu)?;
226
227 debug!(
228 buffer_uop.id = buffer_uop.id,
229 num_elements,
230 shape = ?shape,
231 realized_uop.id = realized_uop.id,
232 realized_uop.base_id = realized_uop.base().id,
233 "Tensor realized"
234 );
235
236 self.set_uop(realized_uop);
237 self.entry.set_buffer(Arc::clone(&output_buf_arc));
238 self.buffer = Some(output_buf_arc);
239 Ok(())
240 }
241
242 pub fn prepare(&mut self) -> Result<ExecutionPlan> {
279 self.prepare_with(&PrepareConfig::from_env())
280 }
281
282 pub fn prepare_with(&mut self, config: &PrepareConfig) -> Result<ExecutionPlan> {
307 let t_total = std::time::Instant::now();
308 let uop = self.uop();
309
310 let sink = UOp::sink(vec![uop.contiguous()]);
311 let schedule_result = schedule_result_from_sink_with_cache(sink, extract_var_vals(&uop)?, config)?;
312 let plan = prepare_execution_plan(&schedule_result, config)?;
316
317 self.wire_output_tensor(&plan, &uop)?;
318 debug!(total_ms = t_total.elapsed().as_millis() as u64, "prepare: total");
319 Ok(plan)
320 }
321
322 fn wire_output_tensor(&mut self, plan: &ExecutionPlan, uop: &Arc<UOp>) -> Result<()> {
323 if plan.num_outputs() > 0 {
324 let buf = Arc::new(plan.output_buffer().expect("plan with num_outputs > 0 must expose output").clone());
325 let dtype = uop.dtype();
326 let device = buf.allocator().device_spec();
327 let buffer_uop = UOp::new_buffer(device, buf.size() / dtype.bytes(), dtype);
328 crate::tensor_registry::register_buffer(buffer_uop.id, self.entry.id, buf.clone());
329 let shape = uop.shape().context(UOpSnafu)?.context(ShapeUnknownSnafu)?;
330 self.set_uop(buffer_uop.try_reshape(shape).context(UOpSnafu)?);
331 self.entry.set_buffer(buf.clone());
332 self.buffer = Some(buf);
333 }
334 Ok(())
335 }
336
337 pub fn realize_batch<'a>(tensors: impl IntoIterator<Item = &'a mut Tensor>) -> Result<()> {
347 Self::realize_batch_with(tensors, &PrepareConfig::from_env())
348 }
349
350 pub fn realize_batch_with<'a>(
352 tensors: impl IntoIterator<Item = &'a mut Tensor>,
353 config: &PrepareConfig,
354 ) -> Result<()> {
355 let mut tensors: Vec<&mut Tensor> = tensors.into_iter().collect();
356 if tensors.is_empty() {
357 return Ok(());
358 }
359
360 for t in &mut tensors {
362 if t.uop().has_buffer_identity() {
363 t.ensure_buffer();
364 }
365 }
366
367 for t in &mut tensors {
369 if !t.uop().has_buffer_identity() && is_any_const(&t.uop()) {
370 let contiguous_uop = t.uop().contiguous();
371 t.set_uop(contiguous_uop);
372 }
373 }
374
375 let pending_indices = collect_pending_indices(&tensors);
377
378 if pending_indices.is_empty() {
379 return Ok(());
380 }
381
382 let old_uops: Vec<Arc<UOp>> = pending_indices.iter().map(|&i| tensors[i].uop()).collect();
384 let mut all_input_buffers = crate::schedule::InputBuffers::new();
385 for uop in &old_uops {
386 all_input_buffers.extend(collect_input_buffers(uop));
387 }
388 let input_ids: HashSet<u64> = all_input_buffers.keys().copied().collect();
389
390 let contiguouses: Vec<Arc<UOp>> = old_uops.iter().map(|u| u.contiguous()).collect();
392 let sink = UOp::sink(contiguouses);
393
394 let mut var_vals = HashMap::new();
395 for uop in &old_uops {
396 let extracted = extract_var_vals(uop)?;
397 merge_var_vals_checked(&mut var_vals, &extracted, "realize_batch input collection")?;
398 }
399 let schedule_result = schedule_result_from_sink_with_cache(sink, var_vals, config)?;
400
401 let t_prep = std::time::Instant::now();
402 let plan = prepare_execution_plan(&schedule_result, config)?;
403 let prep_ms = t_prep.elapsed().as_millis();
404 let t_exec = std::time::Instant::now();
405 plan.execute().context(ExecutionSnafu)?;
406 let exec_ms = t_exec.elapsed().as_millis();
407 debug!(prep_ms, exec_ms, num_outputs = pending_indices.len(), "realize_batch complete");
408
409 snafu::ensure!(
410 plan.num_outputs() >= pending_indices.len(),
411 BatchOutputMismatchSnafu { expected: pending_indices.len(), actual: plan.num_outputs() }
412 );
413
414 #[allow(clippy::mutable_key_type)]
416 let mut becomes_map = HashMap::new();
417 for (buf_idx, &orig_idx) in pending_indices.iter().enumerate() {
418 let output_buf = plan.output_buffer_at(buf_idx).expect("buf_idx in range").clone();
419 let old_uop = &old_uops[buf_idx];
420
421 let output_dtype = old_uop.dtype();
422 let output_device = output_buf.allocator().device_spec();
423 let num_elements = output_buf.size() / output_dtype.bytes();
424 let buffer_uop = UOp::new_buffer(output_device, num_elements, output_dtype);
425 let buf_arc = Arc::new(output_buf);
426
427 let t = &mut tensors[orig_idx];
428 crate::tensor_registry::register_buffer(buffer_uop.id, t.entry.id, buf_arc.clone());
429 let shape = old_uop.shape().context(UOpSnafu)?.context(ShapeUnknownSnafu)?;
430 let realized_uop = buffer_uop.try_reshape(shape).context(UOpSnafu)?;
431 t.set_uop(realized_uop.clone());
432 t.entry.set_buffer(Arc::clone(&buf_arc));
433 t.buffer = Some(buf_arc);
434
435 becomes_map.insert(UOpKey(old_uop.clone()), realized_uop);
436 }
437
438 crate::tensor_registry::apply_map_to_tensors(&becomes_map);
440
441 plan.release_intermediate_buffers(|id| {
443 if !input_ids.contains(&id) {
444 crate::tensor_registry::remove_buffer(id);
445 }
446 });
447
448 Ok(())
449 }
450
451 pub fn prepare_batch<'a>(tensors: impl IntoIterator<Item = &'a mut Tensor>) -> Result<ExecutionPlan> {
456 Self::prepare_batch_with(tensors, &PrepareConfig::from_env())
457 }
458
459 pub fn prepare_batch_with<'a>(
461 tensors: impl IntoIterator<Item = &'a mut Tensor>,
462 config: &PrepareConfig,
463 ) -> Result<ExecutionPlan> {
464 let mut tensors: Vec<&mut Tensor> = tensors.into_iter().collect();
465 if tensors.is_empty() {
466 return EmptyScheduleSnafu.fail();
467 }
468
469 for t in &mut tensors {
471 if t.uop().has_buffer_identity() {
472 t.ensure_buffer();
473 }
474 }
475
476 for t in &mut tensors {
478 if !t.uop().has_buffer_identity() && is_any_const(&t.uop()) {
479 let contiguous_uop = t.uop().contiguous();
480 t.set_uop(contiguous_uop);
481 }
482 }
483
484 let pending_indices = collect_pending_indices(&tensors);
486
487 if pending_indices.is_empty() {
488 return EmptyScheduleSnafu.fail();
489 }
490
491 let uops: Vec<Arc<UOp>> = pending_indices.iter().map(|&i| tensors[i].uop()).collect();
493
494 let mut var_vals = HashMap::new();
495 for uop in &uops {
496 let extracted = extract_var_vals(uop)?;
497 merge_var_vals_checked(&mut var_vals, &extracted, "prepare_batch input collection")?;
498 }
499
500 let contiguouses: Vec<Arc<UOp>> = uops.iter().map(|u| u.contiguous()).collect();
502 let sink = UOp::sink(contiguouses);
503
504 let schedule_result = schedule_result_from_sink_with_cache(sink, var_vals, config)?;
505
506 let plan = prepare_execution_plan(&schedule_result, config)?;
507
508 for (buf_idx, &orig_idx) in pending_indices.iter().enumerate() {
511 if buf_idx >= plan.num_outputs() {
512 break;
513 }
514 let output_buf = plan.output_buffer_at(buf_idx).expect("buf_idx in range").clone();
515 let buf_arc = Arc::new(output_buf);
516 let old_uop = &uops[buf_idx];
517 let output_dtype = old_uop.dtype();
518 let output_device = buf_arc.allocator().device_spec();
519 let num_elements = buf_arc.size() / output_dtype.bytes();
520 let buffer_uop = UOp::new_buffer(output_device, num_elements, output_dtype);
521 let t = &mut tensors[orig_idx];
522 crate::tensor_registry::register_buffer(buffer_uop.id, t.entry.id, buf_arc.clone());
523 let shape = old_uop.shape().context(UOpSnafu)?.context(ShapeUnknownSnafu)?;
524 let realized_uop = buffer_uop.try_reshape(shape).context(UOpSnafu)?;
525 t.set_uop(realized_uop);
526 t.entry.set_buffer(Arc::clone(&buf_arc));
527 t.buffer = Some(buf_arc);
528 }
529
530 Ok(plan)
531 }
532}
533
/// Binds `name` to `val` in `var_vals` unless a different value is already bound.
///
/// Returns `Ok(())` when the name was unbound (the binding is inserted) or already bound
/// to the same value. Returns `Err((existing, val))` when the name is bound to a different
/// value; the map is left unchanged in that case.
fn try_bind_var_val(var_vals: &mut HashMap<String, i64>, name: &str, val: i64) -> std::result::Result<(), (i64, i64)> {
    match var_vals.get(name).copied() {
        Some(existing) if existing != val => Err((existing, val)),
        Some(_) => Ok(()),
        None => {
            var_vals.insert(name.to_string(), val);
            Ok(())
        }
    }
}
553
554fn insert_var_val_checked(var_vals: &mut HashMap<String, i64>, name: &str, val: i64, context: &str) -> Result<()> {
555 match try_bind_var_val(var_vals, name, val) {
556 Ok(()) => Ok(()),
557 Err((prev, val)) => {
558 IrConstructionSnafu { details: format!("bind mismatch on {name}, {prev} != {val} ({context})") }.fail()
559 }
560 }
561}
562
563fn merge_var_vals_checked(dst: &mut HashMap<String, i64>, src: &HashMap<String, i64>, context: &str) -> Result<()> {
564 for (name, val) in src {
565 insert_var_val_checked(dst, name, *val, context)?;
566 }
567 Ok(())
568}
569
570fn extract_var_vals(root: &Arc<UOp>) -> Result<HashMap<String, i64>> {
571 let mut var_vals = HashMap::new();
572 for node in root.toposort() {
573 if let Op::Bind { var, value } = node.op()
574 && let Op::DefineVar { name, .. } = var.op()
575 && let Op::Const(cv) = value.op()
576 && let Some(val) = cv.0.try_int()
577 {
578 insert_var_val_checked(&mut var_vals, name, val, "bind extraction")?;
579 }
580 }
581 Ok(var_vals)
582}
583
/// Reports whether `SVOD_DISABLE_SCHEDULE_CACHE` is set to exactly `"1"`.
///
/// The environment variable is read on the first call only; later calls return the
/// remembered answer.
fn schedule_cache_disabled_by_env() -> bool {
    static DISABLED: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
    let read_env = || matches!(std::env::var("SVOD_DISABLE_SCHEDULE_CACHE").as_deref(), Ok("1"));
    *DISABLED.get_or_init(read_env)
}
588
589fn schedule_result_from_sink_with_cache(
590 sink: Arc<UOp>,
591 mut var_vals: HashMap<String, i64>,
592 config: &PrepareConfig,
593) -> Result<crate::schedule::ScheduleResult> {
594 if config.disable_schedule_cache || schedule_cache_disabled_by_env() {
595 return schedule_result_from_sink_uncached(sink, var_vals, config);
596 }
597
598 let normalization = normalize_for_schedule_cache(&sink)?;
599 merge_var_vals_checked(&mut var_vals, &normalization.var_vals, "schedule cache normalization")?;
600
601 let codegen = resolve_codegen(&normalization.param_buffers, config)?;
602 let sched_key = (crate::schedule_cache::content_hash(&normalization.normalized), codegen);
603
604 let cache = crate::schedule_cache::schedule_cache();
605 let entry = {
606 let guard = cache.guard();
607 cache.get(&sched_key, &guard).cloned()
608 };
609
610 let entry = match entry {
611 Some(hit) => {
612 debug!("schedule cache hit");
613 hit
614 }
615 None => {
616 let schedule_root = restore_bind_placeholders_for_schedule(&normalization.normalized, &normalization);
617 let rangeify_result = svod_schedule::rangeify_with_map(schedule_root, None).context(RangeifySnafu)?;
618 let (kernel_graph, _) = svod_schedule::try_get_kernel_graph(rangeify_result.sink).context(RangeifySnafu)?;
619 let pre_schedule = crate::schedule::create_pre_schedule(kernel_graph)?;
620 let new_entry = Arc::new(crate::schedule_cache::CachedSchedule { pre_schedule: Arc::new(pre_schedule) });
621 let guard = cache.guard();
622 cache.insert(sched_key, Arc::clone(&new_entry), &guard);
623 new_entry
624 }
625 };
626
627 let restored_pre_schedule = restore_post_schedule_pre_schedule(&entry.pre_schedule, &normalization);
628 let schedule_input_buffers = build_schedule_input_buffers(&restored_pre_schedule, &normalization);
629 let result = crate::schedule::instantiate_schedule(&restored_pre_schedule, &schedule_input_buffers, &var_vals)?;
630 Ok(result)
631}
632
633fn schedule_result_from_sink_uncached(
634 sink: Arc<UOp>,
635 var_vals: HashMap<String, i64>,
636 _config: &PrepareConfig,
637) -> Result<crate::schedule::ScheduleResult> {
638 let rangeify_result = svod_schedule::rangeify_with_map(sink, None).context(RangeifySnafu)?;
639 let (kernel_graph, _) = svod_schedule::try_get_kernel_graph(rangeify_result.sink).context(RangeifySnafu)?;
640 let pre_schedule = crate::schedule::create_pre_schedule(kernel_graph.clone())?;
641 let input_buffers = collect_input_buffers(&kernel_graph);
642 let result = crate::schedule::instantiate_schedule(&pre_schedule, &input_buffers, &var_vals)?;
643 Ok(result)
644}
645
646pub(crate) struct ScheduleCacheNormalization {
653 pub normalized: Arc<UOp>,
654 pub param_values: Vec<Arc<UOp>>,
655 pub param_buffers: Vec<(u64, Arc<UOp>)>,
656 pub unique_values: Vec<Arc<UOp>>,
657 pub var_vals: HashMap<String, i64>,
658}
659
660pub(crate) struct NormalizeScheduleCacheCtx {
662 pub param_map: HashMap<u64, usize>,
663 pub param_values: Vec<Arc<UOp>>,
664 pub param_buffers: Vec<(u64, Arc<UOp>)>,
665 pub var_vals: HashMap<String, i64>,
666 pub bind_mismatch: Option<String>,
667}
668
669pub(crate) fn normalize_for_schedule_cache(sink: &Arc<UOp>) -> Result<ScheduleCacheNormalization> {
671 let mut ctx = NormalizeScheduleCacheCtx {
672 param_map: HashMap::new(),
673 param_values: Vec::new(),
674 param_buffers: Vec::new(),
675 var_vals: HashMap::new(),
676 bind_mismatch: None,
677 };
678
679 use svod_ir::op::pattern_derived::OpKey;
680 use svod_ir::pattern::{RewriteResult, SimplifiedPatternMatcher};
681 use svod_ir::rewrite::graph_rewrite;
682
683 let mut matcher = SimplifiedPatternMatcher::<NormalizeScheduleCacheCtx>::new();
684
685 fn to_param(
686 node: &Arc<UOp>,
687 ctx: &mut NormalizeScheduleCacheCtx,
688 size: usize,
689 device: Option<Arc<UOp>>,
690 ) -> Arc<UOp> {
691 let slot = *ctx.param_map.entry(node.id).or_insert_with(|| {
692 let s = ctx.param_values.len();
693 ctx.param_values.push(node.clone());
694 s
695 });
696 UOp::param(slot, size, node.dtype(), device)
697 }
698
699 matcher.add(&[OpKey::Buffer], |node, ctx| {
701 let Op::Buffer { size, device, .. } = node.op() else {
702 return RewriteResult::NoMatch;
703 };
704 let slot = *ctx.param_map.entry(node.id).or_insert_with(|| {
705 let s = ctx.param_values.len();
706 ctx.param_values.push(node.clone());
707 s
708 });
709 ctx.param_buffers.push((node.id, node.clone()));
710 RewriteResult::Rewritten(UOp::param(slot, *size, node.dtype(), Some(device.clone())))
711 });
712
713 matcher.add(&[OpKey::BufferView], |node, ctx| {
715 let Op::BufferView { size, .. } = node.op() else {
716 return RewriteResult::NoMatch;
717 };
718 RewriteResult::Rewritten(to_param(node, ctx, *size, Some(UOp::device(DeviceSpec::Cpu))))
719 });
720
721 matcher.add(&[OpKey::Bind], |node, ctx| {
725 let Op::Bind { var, value } = node.op() else {
726 return RewriteResult::NoMatch;
727 };
728 let Op::DefineVar { name, .. } = var.op() else {
729 return RewriteResult::NoMatch;
730 };
731 let Op::Const(cv) = value.op() else {
732 return RewriteResult::NoMatch;
733 };
734 let Some(val) = cv.0.try_int() else {
735 return RewriteResult::NoMatch;
736 };
737
738 if let Err((prev, val)) = try_bind_var_val(&mut ctx.var_vals, name, val) {
739 if ctx.bind_mismatch.is_none() {
740 ctx.bind_mismatch = Some(format!("bind mismatch on variable {name}: {prev} vs {val}"));
741 }
742 return RewriteResult::NoMatch;
743 }
744 RewriteResult::Rewritten(to_param(node, ctx, 0, Some(UOp::device(DeviceSpec::Cpu))))
745 });
746
747 let normalized = graph_rewrite(&matcher, sink.clone(), &mut ctx);
752
753 if let Some(details) = ctx.bind_mismatch.take() {
754 return IrConstructionSnafu { details }.fail();
755 }
756
757 struct UniqueNormalizationCtx {
761 unique_map: HashMap<u64, usize>,
762 unique_values: Vec<Arc<UOp>>,
763 }
764 let mut unique_ctx = UniqueNormalizationCtx { unique_map: HashMap::new(), unique_values: Vec::new() };
765 let mut unique_matcher = SimplifiedPatternMatcher::<UniqueNormalizationCtx>::new();
766 unique_matcher.add(&[OpKey::Unique], |node, ctx| {
767 let Op::Unique(_) = node.op() else {
768 return RewriteResult::NoMatch;
769 };
770 let slot = *ctx.unique_map.entry(node.id).or_insert_with(|| {
771 let s = ctx.unique_values.len();
772 ctx.unique_values.push(node.clone());
773 s
774 });
775 RewriteResult::Rewritten(UOp::lunique(Some(slot)))
776 });
777 let normalized = graph_rewrite(&unique_matcher, normalized, &mut unique_ctx);
778
779 ctx.param_buffers.sort_unstable_by_key(|(id, _)| *id);
780 ctx.param_buffers.dedup_by_key(|(id, _)| *id);
781
782 Ok(ScheduleCacheNormalization {
783 normalized,
784 param_values: ctx.param_values,
785 param_buffers: ctx.param_buffers,
786 unique_values: unique_ctx.unique_values,
787 var_vals: ctx.var_vals,
788 })
789}
790
791#[allow(clippy::mutable_key_type)]
801pub(crate) fn restore_post_schedule_cache(root: &Arc<UOp>, normalization: &ScheduleCacheNormalization) -> Arc<UOp> {
802 let mut subs: HashMap<UOpKey, Arc<UOp>> = HashMap::new();
803 let mut lunique_buffers: HashMap<usize, Arc<UOp>> = HashMap::new();
804
805 for node in root.toposort() {
806 match node.op() {
807 Op::Param { slot, device: Some(_), .. } => {
808 if let Some(original) = normalization.param_values.get(*slot) {
809 let restored_original = restore_post_schedule_cache(original, normalization);
810 subs.insert(UOpKey(node.clone()), restored_original);
811 }
812 }
813 Op::Buffer { unique, device, size } => {
814 let Op::LUnique(slot) = unique.op() else {
815 continue;
816 };
817 let restored = if let Some(existing) = lunique_buffers.get(slot) {
818 existing.clone()
819 } else {
820 let runtime_unique = UOp::buffer_id(None);
821 let fresh = UOp::new(
822 Op::Buffer { unique: runtime_unique, device: device.clone(), size: *size },
823 node.dtype(),
824 );
825 lunique_buffers.insert(*slot, fresh.clone());
826 fresh
827 };
828 subs.insert(UOpKey(node.clone()), restored);
829 }
830 Op::LUnique(slot) => {
831 let restored = normalization.unique_values.get(*slot).cloned().unwrap_or_else(|| UOp::buffer_id(None));
832 subs.insert(UOpKey(node.clone()), restored);
833 }
834 _ => {}
835 }
836 }
837
838 root.substitute(&subs)
841}
842
843#[allow(clippy::mutable_key_type)]
852fn restore_bind_placeholders_for_schedule(root: &Arc<UOp>, normalization: &ScheduleCacheNormalization) -> Arc<UOp> {
853 let mut subs: HashMap<UOpKey, Arc<UOp>> = HashMap::new();
854
855 for node in root.toposort() {
856 let Op::Param { slot, device: Some(_), .. } = node.op() else {
857 continue;
858 };
859
860 let Some(original) = normalization.param_values.get(*slot) else {
861 continue;
862 };
863 if matches!(original.op(), Op::Bind { .. }) {
864 subs.insert(UOpKey(node.clone()), original.clone());
865 }
866 }
867
868 if subs.is_empty() { root.clone() } else { root.substitute(&subs) }
869}
870
871fn restore_post_schedule_pre_schedule(
877 pre_schedule: &crate::schedule::PreSchedule,
878 normalization: &ScheduleCacheNormalization,
879) -> crate::schedule::PreSchedule {
880 let mut flat_buf_uops = Vec::new();
881 let mut source_counts = Vec::with_capacity(pre_schedule.items.len());
882
883 for item in &pre_schedule.items {
884 source_counts.push(item.sources.len());
885 flat_buf_uops.extend(item.sources.iter().cloned());
886 }
887 let outputs_offset = flat_buf_uops.len();
888 flat_buf_uops.extend(pre_schedule.output_buffer_uops.iter().cloned());
889
890 if flat_buf_uops.is_empty() {
891 return pre_schedule.clone();
892 }
893
894 let restored_flat = match restore_post_schedule_cache(&UOp::sink(flat_buf_uops), normalization).op() {
895 Op::Sink { sources, .. } => sources.iter().cloned().collect::<Vec<_>>(),
896 _ => unreachable!("sink substitution must preserve SINK root"),
897 };
898
899 let mut cursor = 0usize;
900 let mut restored_items = Vec::with_capacity(pre_schedule.items.len());
901 for (item, source_count) in pre_schedule.items.iter().zip(source_counts) {
902 let end = cursor + source_count;
903 let sources = restored_flat[cursor..end].to_vec();
904 cursor = end;
905 let ast = restore_post_schedule_cache(&item.ast, normalization);
906 restored_items.push(crate::schedule::PreScheduleItem {
907 kernel: item.kernel.clone(),
908 ast,
909 sources,
910 dependencies: item.dependencies.clone(),
911 bound_ranges: item.bound_ranges.clone(),
912 });
913 }
914
915 let output_buffer_uops = restored_flat[outputs_offset..].to_vec();
916 crate::schedule::PreSchedule {
917 items: restored_items,
918 invocations: pre_schedule.invocations.clone(),
919 output_buffer_uops,
920 }
921}
922
923fn build_schedule_input_buffers(
924 pre_schedule: &crate::schedule::PreSchedule,
925 _normalization: &ScheduleCacheNormalization,
926) -> crate::schedule::InputBuffers {
927 let mut inputs = crate::schedule::InputBuffers::new();
928
929 for item in &pre_schedule.items {
930 for src in &item.sources {
931 let buf = src.buf_uop();
932 if let Op::Buffer { .. } = buf.op()
933 && let Some(buffer) = crate::tensor_registry::get_buffer(buf.id)
934 {
935 inputs.insert(buf.id, buffer);
936 }
937 }
938 }
939
940 inputs
941}
942
943fn collect_input_buffers(root: &Arc<UOp>) -> crate::schedule::InputBuffers {
953 let mut inputs = HashMap::new();
954 for node in root.toposort() {
955 if let Op::Buffer { .. } = node.op() {
956 if let Some(buf) = crate::tensor_registry::get_buffer(node.id) {
958 inputs.insert(node.id, buf);
959 }
960 }
961 }
962 inputs
963}
964
965fn output_indices_from_program_metadata(globals: &[usize], outs: &[usize], num_buffers: usize) -> Result<Vec<usize>> {
966 if num_buffers == 0 {
967 return IrConstructionSnafu { details: "cannot map outputs for kernel with zero buffers".to_string() }.fail();
968 }
969 if globals.is_empty() {
970 return IrConstructionSnafu { details: "ProgramSpec.globals is empty".to_string() }.fail();
971 }
972 if outs.is_empty() {
973 return IrConstructionSnafu { details: "ProgramSpec.outs is empty".to_string() }.fail();
974 }
975
976 let slot_to_position: HashMap<usize, usize> =
977 globals.iter().copied().enumerate().map(|(position, slot)| (slot, position)).collect();
978
979 let mut output_indices = Vec::with_capacity(outs.len());
980 for &slot in outs {
981 let Some(position) = slot_to_position.get(&slot).copied() else {
982 return IrConstructionSnafu {
983 details: format!("ProgramSpec.outs slot {slot} not found in ProgramSpec.globals={globals:?}"),
984 }
985 .fail();
986 };
987 if position >= num_buffers {
988 return IrConstructionSnafu {
989 details: format!(
990 "ProgramSpec output index {position} (slot {slot}) out of range for {num_buffers} buffers"
991 ),
992 }
993 .fail();
994 }
995 output_indices.push(position);
996 }
997
998 output_indices.sort_unstable();
999 output_indices.dedup();
1000 if output_indices.is_empty() {
1001 return IrConstructionSnafu { details: "ProgramSpec output mapping resolved to empty set".to_string() }.fail();
1002 }
1003
1004 Ok(output_indices)
1005}
1006
1007fn resolve_item_buffer_indices(item: &ScheduleItem, uop_id_to_idx: &HashMap<u64, usize>) -> Result<Vec<usize>> {
1008 let mut indices = Vec::with_capacity(item.buffer_uop_ids.len());
1009 for &uop_id in &item.buffer_uop_ids {
1010 let Some(idx) = uop_id_to_idx.get(&uop_id).copied() else {
1011 return Err(crate::error::Error::BufferNotFound { uop_id });
1012 };
1013 indices.push(idx);
1014 }
1015 Ok(indices)
1016}
1017
1018fn resolve_compiled_kernel_buffer_indices(
1019 item: &ScheduleItem,
1020 uop_id_to_idx: &HashMap<u64, usize>,
1021 globals: &[usize],
1022) -> Result<Vec<usize>> {
1023 let buffer_indices = resolve_item_buffer_indices(item, uop_id_to_idx)?;
1024
1025 let mut ordered = Vec::with_capacity(globals.len());
1026 for &position in globals {
1027 let Some(idx) = buffer_indices.get(position).copied() else {
1028 return IrConstructionSnafu {
1029 details: format!(
1030 "ProgramSpec.globals position {position} out of range for CALL {} buffer list len {} (buffer_uop_ids={:?})",
1031 item.kernel.id,
1032 buffer_indices.len(),
1033 item.buffer_uop_ids
1034 ),
1035 }
1036 .fail();
1037 };
1038 ordered.push(idx);
1039 }
1040
1041 Ok(ordered)
1042}
1043
1044type OptKey = (u64, DeviceSpec, &'static str, u64);
1045
1046struct OptCacheState {
1054 map: papaya::HashMap<OptKey, Arc<svod_runtime::kernel_cache::CachedKernel>>,
1055 fifo: parking_lot::Mutex<std::collections::VecDeque<OptKey>>,
1056 cap: usize,
1057}
1058
1059impl OptCacheState {
1060 const DEFAULT_CAP: usize = 4096;
1061
1062 fn new() -> Self {
1063 let cap = std::env::var("SVOD_OPT_CACHE_MAX")
1064 .ok()
1065 .and_then(|s| s.parse::<usize>().ok())
1066 .filter(|&n| n > 0)
1067 .unwrap_or(Self::DEFAULT_CAP);
1068 Self { map: papaya::HashMap::new(), fifo: parking_lot::Mutex::new(std::collections::VecDeque::new()), cap }
1069 }
1070
1071 fn insert(&self, key: OptKey, val: Arc<svod_runtime::kernel_cache::CachedKernel>) {
1072 let guard = self.map.guard();
1073 let was_new = self.map.insert(key.clone(), val, &guard).is_none();
1074 if !was_new {
1075 return;
1076 }
1077 let mut fifo = self.fifo.lock();
1078 fifo.push_back(key);
1079 while fifo.len() > self.cap {
1080 if let Some(evict) = fifo.pop_front() {
1081 self.map.remove(&evict, &guard);
1082 }
1083 }
1084 }
1085}
1086
1087pub(crate) fn runtime_effect_ast(ast: &Arc<UOp>) -> &Arc<UOp> {
1088 match ast.op() {
1089 Op::End { computation, .. }
1090 if matches!(computation.op(), Op::Copy { .. } | Op::BufferView { .. } | Op::CustomFunction { .. }) =>
1091 {
1092 computation
1093 }
1094 _ => ast,
1095 }
1096}
1097
1098fn optimizer_config_fingerprint(config: &PrepareConfig) -> u64 {
1099 let mut hasher = std::collections::hash_map::DefaultHasher::new();
1100 config.optimizer.hash(&mut hasher);
1101 hasher.finish()
1102}
1103
1104fn prepare_execution_plan(
1123 schedule_result: &crate::schedule::ScheduleResult,
1124 config: &PrepareConfig,
1125) -> Result<ExecutionPlan> {
1126 let mut schedule_items = schedule_result.items.clone();
1128
1129 let planner_mode = crate::memory_planner::mode_from_env();
1134 let output_buffer_ids = collect_output_buffer_ids(&schedule_items, &schedule_result.output_uop_ids);
1135 let planner_result = crate::memory_planner::memory_planner(&schedule_items, &output_buffer_ids, planner_mode);
1136 if !planner_result.buffer_replace.is_empty() {
1137 trace!(
1138 replacements = planner_result.buffer_replace.len(),
1139 buffers_reused = planner_result.buffers_reused,
1140 memory_saved_bytes = planner_result.memory_saved,
1141 "applying memory planner buffer replacements"
1142 );
1143 crate::memory_planner::apply_reuse_dependencies(&mut schedule_items, &planner_result.reuse_dependencies);
1144 crate::memory_planner::apply_buffer_replacements(&mut schedule_items, &planner_result.buffer_replace);
1145 }
1146
1147 debug!(num_items = schedule_items.len(), "schedule items ready for execution plan");
1148
1149 let alloc_registry = svod_device::registry::registry();
1152 let plan_device = if !schedule_items.is_empty() {
1153 let device_spec = schedule_items
1154 .iter()
1155 .flat_map(|item| item.buffers.iter().map(|b| b.allocator().device_spec()))
1156 .find(|spec| !spec.is_disk())
1157 .unwrap_or(DeviceSpec::Cpu);
1158 config.resolve_device(&device_spec, alloc_registry)?
1159 } else {
1160 return EmptyScheduleSnafu.fail();
1161 };
1162 let optimizer_fingerprint = optimizer_config_fingerprint(config);
1163
1164 let mut builder = ExecutionPlanBuilder::new(plan_device.device.clone());
1166
1167 let mut uop_id_to_idx: HashMap<u64, usize> = HashMap::new();
1171 let mut storage_to_idx: HashMap<BufferStorageKey, usize> = HashMap::new();
1172
1173 let buffer_view_output_uop_ids: HashSet<u64> = schedule_items
1177 .iter()
1178 .filter_map(|item| {
1179 if matches!(runtime_effect_ast(&item.ast).op(), Op::BufferView { .. }) {
1180 item.buffer_uop_ids.first().copied()
1181 } else {
1182 None
1183 }
1184 })
1185 .collect();
1186
1187 for item in &schedule_items {
1188 for (buffer, &uop_id) in item.buffers.iter().zip(item.buffer_uop_ids.iter()) {
1190 buffer.ensure_allocated().context(DeviceSnafu)?;
1191
1192 if uop_id_to_idx.contains_key(&uop_id) {
1193 continue;
1194 }
1195
1196 let storage_key = BufferStorageKey {
1197 id: buffer.id().0,
1198 offset: buffer.offset(),
1199 size: buffer.size(),
1200 dtype: buffer.dtype(),
1201 };
1202
1203 let idx = if !buffer_view_output_uop_ids.contains(&uop_id) {
1204 if let Some(&existing_idx) = storage_to_idx.get(&storage_key) {
1205 builder.map_buffer(uop_id, existing_idx);
1206 existing_idx
1207 } else {
1208 let new_idx = builder.add_buffer(uop_id, buffer.clone());
1209 storage_to_idx.insert(storage_key, new_idx);
1210 new_idx
1211 }
1212 } else {
1213 builder.add_buffer(uop_id, buffer.clone())
1214 };
1215 uop_id_to_idx.insert(uop_id, idx);
1216 }
1217
1218 builder.add_alias_ids(item.alias_registered_ids.iter().copied());
1220 }
1221
1222 static OPT_CACHE: std::sync::OnceLock<OptCacheState> = std::sync::OnceLock::new();
1229 let opt_state = OPT_CACHE.get_or_init(OptCacheState::new);
1230 let opt_cache = &opt_state.map;
1231 let opt_guard = opt_cache.guard();
1232
1233 for item in &schedule_items {
1234 let runtime_ast = runtime_effect_ast(&item.ast);
1237
1238 if matches!(runtime_ast.op(), Op::Copy { .. }) {
1239 let buffer_indices = resolve_item_buffer_indices(item, &uop_id_to_idx)?;
1240 builder.add_op_with_instance_dependencies(
1241 PreparedOp::BufferCopy(PreparedCopy {
1242 id: item.kernel.id,
1243 buffer_indices,
1244 dependencies: item.dependencies.clone(),
1245 }),
1246 item.instance_dependencies.clone(),
1247 );
1248 continue;
1249 }
1250
1251 if let Op::BufferView { size, offset, .. } = runtime_ast.op() {
1255 let buffer_indices = resolve_item_buffer_indices(item, &uop_id_to_idx)?;
1256
1257 if item.buffers.len() >= 2 && item.buffer_uop_ids.len() >= 2 && buffer_indices.len() >= 2 {
1258 let base = &item.buffers[1];
1259 let byte_offset = offset * base.dtype().bytes();
1260 let byte_size = size * runtime_ast.dtype().bytes();
1261 let view = base.view(byte_offset, byte_size).map_err(|e| crate::error::Error::IrConstruction {
1262 details: format!(
1263 "BUFFER_VIEW failed for kernel {}: base_buffer_id={}, byte_offset={}, byte_size={}: {e}",
1264 item.kernel.id,
1265 base.id().0,
1266 byte_offset,
1267 byte_size
1268 ),
1269 })?;
1270 let output_uop_id = item.buffer_uop_ids[0];
1273 if let Some(&idx) = uop_id_to_idx.get(&output_uop_id) {
1274 builder.replace_buffer(idx, view);
1275 }
1276
1277 builder.add_op_with_instance_dependencies(
1278 PreparedOp::BufferView(PreparedBufferView {
1279 id: item.kernel.id,
1280 buffer_indices,
1281 byte_offset,
1282 byte_size,
1283 dependencies: item.dependencies.clone(),
1284 }),
1285 item.instance_dependencies.clone(),
1286 );
1287 }
1288 continue;
1289 }
1290
1291 if let Op::CustomFunction { kind, attrs } = runtime_ast.op() {
1296 let buffer_indices = resolve_item_buffer_indices(item, &uop_id_to_idx)?;
1297 let runtime_vars = attrs.iter().flat_map(svod_runtime::execution_plan::collect_runtime_vars).collect();
1298 builder.add_op_with_instance_dependencies(
1299 PreparedOp::CustomFunction(PreparedCustomFunction {
1300 id: item.kernel.id,
1301 kind: kind.clone(),
1302 attrs: attrs.clone(),
1303 buffer_indices,
1304 fixedvars: item.fixedvars.clone(),
1305 dependencies: item.dependencies.clone(),
1306 runtime_vars,
1307 }),
1308 item.instance_dependencies.clone(),
1309 );
1310 continue;
1311 }
1312
1313 let item_device_spec = item
1314 .buffers
1315 .iter()
1316 .map(|b| b.allocator().device_spec())
1317 .find(|spec| !spec.is_disk())
1318 .unwrap_or(DeviceSpec::Cpu);
1319 let item_device = config.resolve_device(&item_device_spec, alloc_registry)?;
1320 let item_codegen: &'static str = item_device.compiler.cache_key();
1321
1322 let opt_key = (
1323 crate::schedule_cache::content_hash(&item.ast),
1324 item_device.device.clone(),
1325 item_codegen,
1326 optimizer_fingerprint,
1327 );
1328
1329 let cached = if let Some(cached) = opt_cache.get(&opt_key, &opt_guard) {
1330 Arc::clone(cached)
1331 } else {
1332 let optimizer_renderer = get_optimizer_renderer(&item_device);
1333 let optimized_ast = if let svod_schedule::OptStrategy::Beam { .. } = config.optimizer.strategy {
1334 beam_search_optimize(
1335 item.ast.clone(),
1336 &optimizer_renderer,
1337 &item_device,
1338 &item.buffers,
1339 &config.optimizer,
1340 )?
1341 } else {
1342 svod_schedule::optimize_kernel_with_config(item.ast.clone(), &optimizer_renderer, &config.optimizer)
1343 };
1344
1345 let kernel_name =
1346 optimized_ast.metadata::<svod_schedule::optimizer::KernelInfo>().map(|info| info.function_name());
1347
1348 let ast_decomposed = match item_device.renderer.decompositor() {
1349 Some(matcher) => svod_ir::decompositions::decompose_with(&optimized_ast, &matcher),
1350 None => optimized_ast,
1351 };
1352 let program = svod_codegen::program_pipeline::program_from_sink(ast_decomposed, item_device.device.clone());
1353
1354 let result = svod_runtime::kernel_cache::get_or_compile_kernel(
1355 crate::schedule_cache::content_hash(&program),
1356 item_codegen,
1357 || {
1358 let (spec, compiled) = compile_with_program_pipeline_components(
1359 program.clone(),
1360 item_device.renderer.as_ref(),
1361 item_device.compiler.as_ref(),
1362 kernel_name.as_deref(),
1363 )?;
1364 let program = (item_device.runtime)(&compiled).context(CreateProgramSnafu)?;
1365 Ok(svod_runtime::kernel_cache::CachedKernel {
1366 program,
1367 device: item_codegen.to_string(),
1368 code: spec.src.clone(),
1369 entry_point: spec.name.clone(),
1370 var_names: spec.var_names.clone(),
1371 globals: spec.globals.clone(),
1372 outs: spec.outs.clone(),
1373 ins: spec.ins.clone(),
1374 host_parallel_safe: matches!(item_device.device, DeviceSpec::Cpu),
1375 global_size: spec.global_size.clone(),
1376 local_size: spec.local_size.clone(),
1377 })
1378 },
1379 )?;
1380 opt_state.insert(opt_key, Arc::clone(&result));
1381 result
1382 };
1383
1384 let buffer_indices = resolve_compiled_kernel_buffer_indices(item, &uop_id_to_idx, &cached.globals)?;
1386
1387 trace!(kernel.ast_id = item.ast.id, num_buffers = item.buffers.len(), "kernel buffer mapping");
1388
1389 let vals: Vec<i64> =
1393 cached.var_names.iter().map(|name| item.fixedvars.get(name).copied().unwrap_or(0)).collect();
1394 let non_overridable_fixedvars = collect_non_overridable_fixedvars(item);
1395
1396 let output_indices = output_indices_from_program_metadata(&cached.globals, &cached.outs, buffer_indices.len())
1397 .map_err(|e| crate::error::Error::IrConstruction {
1398 details: format!(
1399 "invalid ProgramSpec output metadata for kernel id {} (globals={:?}, outs={:?}, num_buffers={}): {e}",
1400 item.kernel.id,
1401 cached.globals,
1402 cached.outs,
1403 buffer_indices.len()
1404 ),
1405 })?;
1406
1407 let runtime_vars = svod_runtime::execution_plan::collect_runtime_vars(&item.ast);
1408 let prepared = PreparedKernel {
1409 id: item.kernel.id,
1410 ast: item.ast.clone(),
1411 kernel: cached,
1412 device: item_device.device.clone(),
1413 buffer_indices,
1414 output_indices,
1415 vals,
1416 fixedvars: non_overridable_fixedvars,
1417 dependencies: item.dependencies.clone(),
1418 buffer_ptrs: Vec::new(), buffer_ids: Vec::new(), runtime_vars,
1421 };
1422
1423 builder.add_op_with_instance_dependencies(
1424 PreparedOp::CompiledProgram(prepared),
1425 item.instance_dependencies.clone(),
1426 );
1427 }
1428
1429 let mut output_buffer_indices = Vec::with_capacity(schedule_result.output_uop_ids.len());
1431 for &uop_id in &schedule_result.output_uop_ids {
1432 let Some(idx) = uop_id_to_idx.get(&uop_id).copied() else {
1433 return Err(crate::error::Error::BufferNotFound { uop_id });
1434 };
1435 output_buffer_indices.push(idx);
1436 }
1437 if output_buffer_indices.is_empty() {
1438 return IrConstructionSnafu { details: "prepare_execution_plan produced no output buffer indices".to_string() }
1439 .fail();
1440 }
1441 builder.set_output_buffers(output_buffer_indices);
1442
1443 builder.build().context(ExecutionSnafu)
1444}
1445
1446fn collect_output_buffer_ids(schedule: &crate::schedule::Schedule, output_uop_ids: &[u64]) -> HashSet<u64> {
1447 let output_uop_set: HashSet<u64> = output_uop_ids.iter().copied().collect();
1448 let mut output_buffer_ids = HashSet::new();
1449 for item in schedule {
1450 for (buffer, &uop_id) in item.buffers.iter().zip(item.buffer_uop_ids.iter()) {
1451 if output_uop_set.contains(&uop_id) {
1452 output_buffer_ids.insert(buffer.id().0);
1453 }
1454 }
1455 }
1456 output_buffer_ids
1457}
1458
1459fn collect_non_overridable_fixedvars(item: &ScheduleItem) -> HashMap<String, i64> {
1460 let mut locked = HashMap::with_capacity(item.loop_var_names.len());
1466 for name in &item.loop_var_names {
1467 if let Some(v) = item.fixedvars.get(name) {
1468 locked.insert(name.clone(), *v);
1469 }
1470 }
1471 locked
1472}
1473
1474fn compile_with_program_pipeline_components(
1476 kernel_ast: Arc<UOp>,
1477 renderer: &dyn svod_device::device::Renderer,
1478 compiler: &dyn svod_device::device::Compiler,
1479 kernel_name: Option<&str>,
1480) -> Result<(svod_device::device::ProgramSpec, svod_device::device::CompiledSpec)> {
1481 let mut program = match kernel_ast.op() {
1482 Op::Program { .. } => kernel_ast,
1483 other => {
1484 return IrConstructionSnafu {
1485 details: format!("compile_with_program_pipeline_components expects PROGRAM input, got {other:?}"),
1486 }
1487 .fail();
1488 }
1489 };
1490
1491 program = svod_codegen::program_pipeline::get_program(
1492 &program,
1493 renderer,
1494 compiler,
1495 kernel_name,
1496 svod_codegen::program_pipeline::ProgramTarget::Source,
1497 )
1498 .context(RenderKernelSnafu)?;
1499
1500 let rendered_entry = svod_device::device::ProgramSpec::from_uop(&program).map(|spec| spec.name).map_err(|e| {
1501 crate::error::Error::IrConstruction { details: format!("PROGRAM pipeline produced invalid SOURCE stage: {e}") }
1502 })?;
1503
1504 let (program, compiled) =
1505 svod_codegen::program_pipeline::do_compile(&program, compiler).context(CompileKernelSnafu)?;
1506
1507 let spec =
1508 svod_device::device::ProgramSpec::from_uop(&program).map_err(|e| crate::error::Error::IrConstruction {
1509 details: format!(
1510 "PROGRAM pipeline produced invalid ProgramSpec after compile (entry='{}'): {e}",
1511 rendered_entry
1512 ),
1513 })?;
1514 Ok((spec, compiled))
1515}
1516
1517pub(crate) fn resolve_codegen(param_buffers: &[(u64, Arc<UOp>)], config: &PrepareConfig) -> Result<&'static str> {
1519 let alloc_registry = svod_device::registry::registry();
1520 let spec = param_buffers
1521 .iter()
1522 .find_map(|(id, _)| {
1523 let spec = crate::tensor_registry::get_buffer(*id)?.allocator().device_spec();
1524 (!spec.is_disk()).then_some(spec)
1525 })
1526 .or_else(|| {
1527 param_buffers.iter().find_map(|(_, u)| {
1528 let Op::Buffer { device, .. } = u.op() else {
1529 return None;
1530 };
1531 let Op::Device(spec) = device.op() else {
1532 return None;
1533 };
1534 (!spec.is_disk()).then_some(spec.clone())
1535 })
1536 })
1537 .unwrap_or(DeviceSpec::Cpu);
1538 let device = config.resolve_device(&spec, alloc_registry)?;
1539 Ok(device.compiler.cache_key())
1540}
1541
1542fn get_optimizer_renderer(device: &Device) -> svod_schedule::OptimizerRenderer {
1544 match device.device {
1545 DeviceSpec::Cpu => {
1546 if std::env::var("SVOD_AMX").as_deref() == Ok("1") {
1547 svod_schedule::OptimizerRenderer::apple_amx()
1548 } else {
1549 svod_schedule::OptimizerRenderer::cpu()
1550 }
1551 }
1552 DeviceSpec::Cuda { .. } => svod_schedule::OptimizerRenderer::cuda(),
1553 DeviceSpec::Metal { .. } => svod_schedule::OptimizerRenderer::metal(),
1554 _ => svod_schedule::OptimizerRenderer::cpu(),
1555 }
1556}
1557
1558pub(crate) fn count_top_ops(ops: &[Arc<UOp>], top_k: usize) -> Vec<(String, usize)> {
1567 let mut counts: HashMap<String, usize> = HashMap::new();
1568 for u in ops {
1569 *counts.entry(u.op().as_ref().to_string()).or_insert(0) += 1;
1570 }
1571 let mut v: Vec<(String, usize)> = counts.into_iter().collect();
1572 v.sort_by_key(|(_, n)| std::cmp::Reverse(*n));
1573 v.truncate(top_k);
1574 v
1575}
1576
/// Renders op counts as a comma-separated `name=count` list, e.g. `"ADD=3, MUL=1"`.
///
/// An empty slice yields an empty string.
pub(crate) fn fmt_op_counts(counts: &[(String, usize)]) -> String {
    use std::fmt::Write as _;

    let mut rendered = String::new();
    for (position, (op_name, occurrences)) in counts.iter().enumerate() {
        if position > 0 {
            rendered.push_str(", ");
        }
        // Writing into a `String` cannot fail, so the `fmt::Result` is safe to discard.
        let _ = write!(rendered, "{op_name}={occurrences}");
    }
    rendered
}
1580
1581fn beam_search_optimize(
1582 ast: Arc<UOp>,
1583 renderer: &svod_schedule::OptimizerRenderer,
1584 device: &Device,
1585 buffers: &[Buffer],
1586 optimizer_config: &svod_schedule::OptimizerConfig,
1587) -> Result<Arc<UOp>> {
1588 let beam_config = &optimizer_config.beam;
1589 let scheduler = prepare_scheduler(ast, renderer);
1592
1593 for buf in buffers {
1595 buf.ensure_allocated().context(DeviceSnafu)?;
1596 }
1597
1598 let buffers: Vec<Buffer> = buffers.to_vec();
1600 let bench_config = svod_runtime::BenchmarkConfig::default();
1601
1602 let dev_renderer = device.renderer.clone();
1604 let dev_compiler = device.compiler.clone();
1605 let dev_runtime = device.runtime.clone();
1606 let dev_device = device.device.clone();
1607 let max_uops = beam_config.max_uops;
1608
1609 svod_runtime::warmup_thread_pool();
1615
1616 let compile_timeout =
1624 Duration::from_secs(std::env::var("BEAM_TIMEOUT_SEC").ok().and_then(|s| s.parse().ok()).unwrap_or(10));
1625
1626 let log_surpass = std::env::var("BEAM_LOG_SURPASS_MAX").is_ok();
1630
1631 let post_opt_cache: Arc<papaya::HashMap<u64, Arc<UOp>>> = Arc::new(papaya::HashMap::new());
1639
1640 let compile_and_time = |s: &Scheduler, early_stop: Option<Duration>| -> Option<svod_schedule::CandidateMetrics> {
1651 use std::panic::{AssertUnwindSafe, catch_unwind};
1652 use std::sync::mpsc;
1653
1654 let s_owned = s.clone();
1656 let renderer_c = renderer.clone();
1657 let dev_renderer_c = dev_renderer.clone();
1658 let dev_compiler_c = dev_compiler.clone();
1659 let dev_runtime_c = dev_runtime.clone();
1660 let dev_device_c = dev_device.clone();
1661 let buffers_c = buffers.clone();
1662 let bench_config_c = bench_config.clone();
1663 let max_uops_c = max_uops;
1664 let post_opt_cache_c = Arc::clone(&post_opt_cache);
1665 let log_surpass_c = log_surpass;
1666 let opts_snapshot: Vec<svod_schedule::optimizer::Opt> = s_owned.applied_opts.clone();
1670
1671 enum WorkerMsg {
1676 CompileDone,
1677 Final(Option<svod_schedule::CandidateMetrics>),
1678 }
1679 let (tx, rx) = mpsc::sync_channel::<WorkerMsg>(2);
1680 let tx_compile = tx.clone();
1681 let _ = std::thread::spawn(move || {
1682 let result = catch_unwind(AssertUnwindSafe(|| {
1683 let raw_ast = s_owned.get_optimized_ast(None);
1684
1685 let cache_key = raw_ast.content_hash;
1695 let cache_pin = post_opt_cache_c.pin();
1696 let optimized = if let Some(cached) = cache_pin.get(&cache_key) {
1697 cached.clone()
1698 } else {
1699 let opt = apply_post_optimization_with_renderer(raw_ast, Some(&renderer_c));
1700 cache_pin.insert(cache_key, opt.clone());
1701 opt
1702 };
1703
1704 let kernel_name =
1706 optimized.metadata::<svod_schedule::optimizer::KernelInfo>().map(|info| info.function_name());
1707
1708 let ir_hash = svod_schedule::hash_post_codegen_ir(&optimized);
1712 let compute_ops = svod_schedule::compute_ops_estimate(&optimized);
1713
1714 let decomposed = match dev_renderer_c.decompositor() {
1716 Some(m) => svod_ir::decompositions::decompose_with(&optimized, &m),
1717 None => optimized,
1718 };
1719 let mut program = svod_codegen::program_pipeline::program_from_sink(decomposed, dev_device_c.clone());
1720
1721 program = match svod_codegen::program_pipeline::do_linearize(&program) {
1728 Ok(p) => p,
1729 Err(e) => {
1730 if log_surpass_c {
1731 eprintln!("[BEAM drop] linearize_err: {e:?} opts={opts_snapshot:?}");
1732 }
1733 return None;
1734 }
1735 };
1736 let (linear_uops_count, top_op_counts) = if let svod_ir::Op::Program { linear: Some(linear), .. } =
1737 program.op()
1738 && let svod_ir::Op::Linear { ops } = linear.op()
1739 {
1740 (ops.len(), if log_surpass_c { count_top_ops(ops, 8) } else { Vec::new() })
1741 } else {
1742 (0, Vec::new())
1743 };
1744 if linear_uops_count > max_uops_c {
1745 if log_surpass_c {
1746 eprintln!(
1747 "[BEAM drop] too_many_uops: linear={linear_uops_count} max={max_uops_c} opts={opts_snapshot:?} top_ops=[{}]",
1748 fmt_op_counts(&top_op_counts)
1749 );
1750 }
1751 return None;
1752 }
1753
1754 let (spec, compiled) = match compile_with_program_pipeline_components(
1756 program,
1757 dev_renderer_c.as_ref(),
1758 dev_compiler_c.as_ref(),
1759 kernel_name.as_deref(),
1760 ) {
1761 Ok(v) => v,
1762 Err(e) => {
1763 if log_surpass_c {
1764 eprintln!("[BEAM drop] compile_err: {e:?} opts={opts_snapshot:?}");
1765 }
1766 return None;
1767 }
1768 };
1769 let program = match (dev_runtime_c)(&compiled) {
1770 Ok(p) => p,
1771 Err(e) => {
1772 if log_surpass_c {
1773 eprintln!("[BEAM drop] runtime_err: {e:?} opts={opts_snapshot:?}");
1774 }
1775 return None;
1776 }
1777 };
1778
1779 let _ = tx_compile.send(WorkerMsg::CompileDone);
1782
1783 let buffer_ptrs: Vec<*mut u8> = buffers_c.iter().map(|b| unsafe { b.as_raw_ptr() }).collect();
1786
1787 let mut user_var_vals: HashMap<&str, i64> = HashMap::new();
1792 for v in &spec.vars {
1793 if v.name != "core_id" {
1794 user_var_vals.insert(v.name.as_str(), (v.min + v.max) / 2);
1795 }
1796 }
1797 let launch_dims = spec.launch_dims(&user_var_vals).ok()?;
1798 let vals: Vec<i64> =
1799 spec.var_names.iter().map(|n| user_var_vals.get(n.as_str()).copied().unwrap_or(0)).collect();
1800
1801 const MAX_TEST_GLOBAL_SIZE: usize = 65536;
1808 let mut test_global_size = launch_dims.global_size;
1809 let original_size: usize = test_global_size.iter().product();
1810 while test_global_size.iter().product::<usize>() > MAX_TEST_GLOBAL_SIZE {
1811 let mut halved = false;
1812 for j in (0..test_global_size.len()).rev() {
1813 if test_global_size[j] > 16 {
1814 test_global_size[j] /= 2;
1815 halved = true;
1816 break;
1817 }
1818 }
1819 if !halved {
1820 break;
1821 }
1822 }
1823 let shrunk_size: usize = test_global_size.iter().product();
1824 let factor: f64 = if shrunk_size > 0 { original_size as f64 / shrunk_size as f64 } else { 1.0 };
1825
1826 let mut bench_config = bench_config_c.clone();
1827 bench_config.early_stop = early_stop.map(|t| {
1830 let nanos = t.as_nanos() as f64 / factor;
1831 Duration::from_nanos(nanos.min(u64::MAX as f64) as u64)
1832 });
1833 bench_config.clear_l2 = renderer_c.device.has_hardware_cache_invalidate();
1836 let result = unsafe {
1837 svod_runtime::benchmark_kernel(
1838 program.as_ref(),
1839 &buffer_ptrs,
1840 &vals,
1841 Some(test_global_size),
1842 launch_dims.local_size,
1843 &bench_config,
1844 )
1845 .ok()?
1846 };
1847
1848 let scaled_nanos = (result.min.as_nanos() as f64 * factor).min(u64::MAX as f64);
1850 let timing = Duration::from_nanos(scaled_nanos as u64);
1851 Some(svod_schedule::CandidateMetrics { timing, ir_hash, compute_ops })
1852 }));
1853 let final_result = match result {
1854 Ok(opt) => opt,
1855 Err(_) => {
1856 if log_surpass_c {
1857 eprintln!("[BEAM drop] panic_in_worker opts={opts_snapshot:?}");
1858 }
1859 None
1860 }
1861 };
1862 let _ = tx.send(WorkerMsg::Final(final_result));
1864 });
1865
1866 match rx.recv_timeout(compile_timeout) {
1873 Ok(WorkerMsg::CompileDone) => {
1874 match rx.recv() {
1877 Ok(WorkerMsg::Final(metrics)) => metrics,
1878 _ => None,
1879 }
1880 }
1881 Ok(WorkerMsg::Final(metrics)) => metrics,
1882 Err(_) => {
1883 if log_surpass {
1884 eprintln!("[BEAM drop] compile_timeout opts={:?}", s.applied_opts);
1885 }
1886 None
1887 }
1888 }
1889 };
1890
1891 let prev_hook = std::panic::take_hook();
1895 std::panic::set_hook(Box::new(|_| {}));
1896 let result = beam_search_cached(scheduler, beam_config, compile_and_time);
1897 std::panic::set_hook(prev_hook);
1898 let result = result.context(OptimizeSnafu)?;
1899
1900 tracing::debug!(
1902 opts = ?result.scheduler.applied_opts,
1903 timing = ?result.timing,
1904 iterations = result.iterations,
1905 "beam_search_optimize: completed"
1906 );
1907
1908 let raw_ast = result.scheduler.get_optimized_ast(None);
1911 Ok(apply_post_optimization_with_renderer(raw_ast, Some(renderer)))
1912}
1913
1914#[cfg(test)]
1915#[path = "test/unit/realize_internal.rs"]
1916mod tests;