Skip to main content

svod_tensor/
realize.rs

1//! Tensor realization (execution) API.
2//!
3//! This module provides the execution pipeline for tensor operations:
4//! 1. **Rangeify** - Transform movement ops to BUFFERIZE + INDEX
5//! 2. **Kernel splitting** - Split at STORE boundaries into CALL wrappers
6//! 3. **Scheduling** - Extract callables and create execution schedule
7//! 4. **Execution** - Compile and run each kernel in dependency order
8//!
9//! Runtime plan execution is dependency-ordered with conservative mixed-op
10//! barriers and hazard-aware host parallelism for safe compiled kernels.
11//!
12//! # ExecutionPlan (Pre-compiled Execution)
13//!
14//! For repeated executions, use `Tensor::prepare()` to create an `ExecutionPlan`
15//! that pre-compiles all kernels and allocates all buffers. Then call
16//! `plan.execute()` for fast repeated execution without recompilation overhead.
17//!
18//! ```ignore
19//! // One-time preparation (compiles kernels, allocates buffers)
20//! let plan = tensor.prepare()?;
21//!
22//! // Fast execution (can be called many times)
23//! plan.execute()?;
24//!
25//! // Get results
26//! let output = plan.output_buffer();
27//! ```
28
29use std::collections::{HashMap, HashSet};
30use std::hash::{Hash, Hasher};
31
32use svod_schedule::{Scheduler, apply_post_optimization_with_renderer, beam_search_cached, prepare_scheduler};
33use tracing::{debug, trace};
34
35use crate::{
36    PrepareConfig, Result, Tensor,
37    error::{
38        BatchOutputMismatchSnafu, CompileKernelSnafu, CreateProgramSnafu, DeviceSnafu, EmptyScheduleSnafu,
39        ExecutionSnafu, IrConstructionSnafu, OptimizeSnafu, RangeifySnafu, RenderKernelSnafu, ShapeUnknownSnafu,
40        UOpSnafu,
41    },
42    schedule::ScheduleItem,
43};
44use snafu::{OptionExt, ResultExt};
45use std::sync::Arc;
46use std::time::Duration;
47use svod_device::{Buffer, device::Device};
48use svod_ir::pattern::is_any_const;
49use svod_ir::{DeviceSpec, Op, UOp, UOpKey};
50use svod_runtime::{
51    ExecutionPlan, ExecutionPlanBuilder, PreparedBufferView, PreparedCopy, PreparedCustomFunction, PreparedKernel,
52    PreparedOp,
53};
54
55fn collect_pending_indices(tensors: &[&mut Tensor]) -> Vec<usize> {
56    tensors
57        .iter()
58        .enumerate()
59        .filter(|(_, t)| !t.uop().has_buffer_identity() && !is_any_const(&t.uop()) && !t.has_zero_elements())
60        .map(|(i, _)| i)
61        .collect()
62}
63
64#[derive(Debug, Clone, PartialEq, Eq, Hash)]
65struct BufferStorageKey {
66    id: u64,
67    offset: usize,
68    size: usize,
69    dtype: svod_dtype::DType,
70}
71
72impl Tensor {
73    /// Realize (execute) this tensor's computation graph.
74    ///
75    /// This is a convenience method that prepares and executes in one call.
76    /// For repeated executions of the same computation, use `prepare()` instead.
77    ///
78    /// # Pipeline
79    ///
80    /// 1. **Prepare**: Creates an `ExecutionPlan` (compiles kernels, allocates buffers)
81    /// 2. **Execute**: Runs all kernels in dependency order
82    /// 3. **Return**: Links output buffer to this tensor's UOp
83    ///
84    /// # Example
85    ///
86    /// ```ignore
87    /// let a = Tensor::from_slice(&[1.0f32, 2.0, 3.0]);
88    /// let b = Tensor::from_slice(&[4.0f32, 5.0, 6.0]);
89    /// let c = (&a + &b).realize()?;
90    /// // c's buffer now contains [5.0, 7.0, 9.0]
91    /// ```
92    ///
93    /// # Errors
94    ///
95    /// Returns error if preparation or execution fails.
96    pub fn realize(&mut self) -> Result<()> {
97        if self.uop().has_buffer_identity() {
98            self.ensure_buffer();
99            return Ok(());
100        }
101        // Pure constants: wrap in CONTIGUOUS to force materialization into a buffer.
102        if is_any_const(&self.uop()) {
103            let contiguous_uop = self.uop().contiguous();
104            self.set_uop(contiguous_uop);
105        }
106        if self.has_zero_elements() {
107            return Ok(());
108        }
109
110        let old_uop = self.uop();
111        let input_buffer_ids: HashSet<u64> = collect_input_buffers(&old_uop).keys().copied().collect();
112
113        let t_prep = std::time::Instant::now();
114        let plan = self.prepare()?;
115        let prep_ms = t_prep.elapsed().as_millis();
116        let t_exec = std::time::Instant::now();
117        plan.execute().context(ExecutionSnafu)?;
118        let exec_ms = t_exec.elapsed().as_millis();
119        debug!(prep_ms, exec_ms, "realize complete");
120
121        self.finalize_realize(&plan, &old_uop)?;
122
123        let realized_uop = self.uop();
124        if !Arc::ptr_eq(&old_uop, &realized_uop) {
125            #[allow(clippy::mutable_key_type)]
126            let becomes_map = HashMap::from([(UOpKey(old_uop), realized_uop)]);
127            crate::tensor_registry::apply_map_to_tensors(&becomes_map);
128        }
129
130        plan.release_intermediate_buffers(|uop_id| {
131            if !input_buffer_ids.contains(&uop_id) {
132                crate::tensor_registry::remove_buffer(uop_id);
133            }
134        });
135
136        Ok(())
137    }
138
139    /// Realize tensor with custom configuration.
140    ///
141    /// Like [`realize()`](Self::realize) but allows specifying optimization strategy
142    /// and codegen backend.
143    ///
144    /// # Example
145    ///
146    /// ```ignore
147    /// use svod_tensor::PrepareConfig;
148    /// use svod_schedule::{OptStrategy, OptimizerConfig};
149    ///
150    /// let c = a.matmul(&b)?;
151    /// let config = PrepareConfig::from(
152    ///     OptimizerConfig::builder()
153    ///         .strategy(OptStrategy::Beam { width: 4 })
154    ///         .build()
155    /// );
156    /// let c = c.realize_with(&config)?;
157    /// ```
158    pub fn realize_with(&mut self, config: &PrepareConfig) -> Result<()> {
159        if self.uop().has_buffer_identity() {
160            self.ensure_buffer();
161            return Ok(());
162        }
163        // Pure constants: wrap in CONTIGUOUS to force materialization into a buffer.
164        if is_any_const(&self.uop()) {
165            let contiguous_uop = self.uop().contiguous();
166            self.set_uop(contiguous_uop);
167        }
168        if self.has_zero_elements() {
169            return Ok(());
170        }
171
172        let old_uop = self.uop();
173        let input_buffer_ids: HashSet<u64> = collect_input_buffers(&old_uop).keys().copied().collect();
174
175        let t_prep = std::time::Instant::now();
176        let plan = self.prepare_with(config)?;
177        let prep_ms = t_prep.elapsed().as_millis();
178        let t_exec = std::time::Instant::now();
179        plan.execute().context(ExecutionSnafu)?;
180        let exec_ms = t_exec.elapsed().as_millis();
181        debug!(prep_ms, exec_ms, "realize_with complete");
182
183        self.finalize_realize(&plan, &old_uop)?;
184
185        let realized_uop = self.uop();
186        if !Arc::ptr_eq(&old_uop, &realized_uop) {
187            #[allow(clippy::mutable_key_type)]
188            let becomes_map = HashMap::from([(UOpKey(old_uop), realized_uop)]);
189            crate::tensor_registry::apply_map_to_tensors(&becomes_map);
190        }
191
192        plan.release_intermediate_buffers(|uop_id| {
193            if !input_buffer_ids.contains(&uop_id) {
194                crate::tensor_registry::remove_buffer(uop_id);
195            }
196        });
197
198        Ok(())
199    }
200
201    /// Finalize realization: bind output buffer to tensor.
202    ///
203    /// Note: intermediate buffer cleanup is deferred to `realize()` so it
204    /// runs AFTER `apply_map_to_tensors`. This ensures other tensors can still
205    /// find buffers during the substitution window.
206    fn finalize_realize(&mut self, plan: &ExecutionPlan, uop: &Arc<UOp>) -> Result<()> {
207        let output_buf = plan.output_buffer().expect("realized plan must have an output buffer").clone();
208
209        trace!(
210            buffer.id = ?output_buf.id(),
211            buffer.size = output_buf.size(),
212            "Realized output buffer"
213        );
214
215        let output_dtype = uop.dtype();
216        let output_device = output_buf.allocator().device_spec();
217        let num_elements = output_buf.size() / output_dtype.bytes();
218
219        let buffer_uop = UOp::new_buffer(output_device, num_elements, output_dtype.clone());
220        let output_buf_arc = Arc::new(output_buf);
221
222        crate::tensor_registry::register_buffer(buffer_uop.id, self.entry.id, output_buf_arc.clone());
223
224        let shape = uop.shape().context(UOpSnafu)?.context(ShapeUnknownSnafu)?;
225        let realized_uop = buffer_uop.try_reshape(shape).context(UOpSnafu)?;
226
227        debug!(
228            buffer_uop.id = buffer_uop.id,
229            num_elements,
230            shape = ?shape,
231            realized_uop.id = realized_uop.id,
232            realized_uop.base_id = realized_uop.base().id,
233            "Tensor realized"
234        );
235
236        self.set_uop(realized_uop);
237        self.entry.set_buffer(Arc::clone(&output_buf_arc));
238        self.buffer = Some(output_buf_arc);
239        Ok(())
240    }
241
242    /// Prepare an execution plan for this tensor's computation graph.
243    ///
244    /// This performs all one-time work:
245    /// 1. Creates schedule from computation graph
246    /// 2. Instantiates strict range-expanded callable schedule items
247    /// 3. Compiles all kernels
248    /// 4. Allocates all buffers
249    /// 5. Builds dependency-ordered prepared op execution plan
250    ///
251    /// The returned `ExecutionPlan` can then be executed multiple times
252    /// without recompilation overhead.
253    ///
254    /// # Example
255    ///
256    /// ```ignore
257    /// let a = Tensor::from_slice(&[1.0f32, 2.0, 3.0]);
258    /// let b = Tensor::from_slice(&[4.0f32, 5.0, 6.0]);
259    /// let mut c = &a + &b;
260    ///
261    /// // One-time preparation (wires output tensor to plan buffer)
262    /// let plan = c.prepare()?;
263    ///
264    /// // Fast execution (can be called many times)
265    /// plan.execute()?;
266    ///
267    /// // Get results
268    /// let output = plan.output_buffer();
269    /// ```
270    ///
271    /// # Errors
272    ///
273    /// Returns error if:
274    /// - Rangeify transformation fails
275    /// - No kernels found after scheduling
276    /// - Kernel compilation fails
277    /// - Buffer allocation fails
278    pub fn prepare(&mut self) -> Result<ExecutionPlan> {
279        self.prepare_with(&PrepareConfig::from_env())
280    }
281
282    /// Prepare an execution plan with explicit configuration.
283    ///
284    /// This method allows fine-grained control over kernel optimization settings
285    /// and codegen backend selection.
286    ///
287    /// # Example
288    ///
289    /// ```ignore
290    /// use svod_tensor::PrepareConfig;
291    /// use svod_schedule::{OptimizerConfig, OptStrategy, BeamConfig};
292    ///
293    /// // Beam search with width 8 and 120s timeout
294    /// let config = PrepareConfig::from(
295    ///     OptimizerConfig::builder()
296    ///         .strategy(OptStrategy::Beam { width: 8 })
297    ///         .beam(BeamConfig::builder()
298    ///             .timeout_secs(120)
299    ///             .build())
300    ///         .build()
301    /// );
302    ///
303    /// let plan = tensor.prepare_with(&config)?;
304    /// plan.execute()?;
305    /// ```
306    pub fn prepare_with(&mut self, config: &PrepareConfig) -> Result<ExecutionPlan> {
307        let t_total = std::time::Instant::now();
308        let uop = self.uop();
309
310        let sink = UOp::sink(vec![uop.contiguous()]);
311        let schedule_result = schedule_result_from_sink_with_cache(sink, extract_var_vals(&uop)?, config)?;
312        // Per-kernel optimization+compilation is cached globally in prepare_execution_plan
313        // via OPT_CACHE keyed by content_hash(ast). Identical kernel ASTs across calls
314        // (e.g., sort substages, repeated model inference) skip optimize+compile.
315        let plan = prepare_execution_plan(&schedule_result, config)?;
316
317        self.wire_output_tensor(&plan, &uop)?;
318        debug!(total_ms = t_total.elapsed().as_millis() as u64, "prepare: total");
319        Ok(plan)
320    }
321
322    fn wire_output_tensor(&mut self, plan: &ExecutionPlan, uop: &Arc<UOp>) -> Result<()> {
323        if plan.num_outputs() > 0 {
324            let buf = Arc::new(plan.output_buffer().expect("plan with num_outputs > 0 must expose output").clone());
325            let dtype = uop.dtype();
326            let device = buf.allocator().device_spec();
327            let buffer_uop = UOp::new_buffer(device, buf.size() / dtype.bytes(), dtype);
328            crate::tensor_registry::register_buffer(buffer_uop.id, self.entry.id, buf.clone());
329            let shape = uop.shape().context(UOpSnafu)?.context(ShapeUnknownSnafu)?;
330            self.set_uop(buffer_uop.try_reshape(shape).context(UOpSnafu)?);
331            self.entry.set_buffer(buf.clone());
332            self.buffer = Some(buf);
333        }
334        Ok(())
335    }
336
337    // =========================================================================
338    // Batch realize / prepare
339    // =========================================================================
340
341    /// Realize multiple tensors in a single batch, sharing computation.
342    ///
343    /// Merges all tensor computation graphs into one SINK, enabling the scheduler
344    /// to share kernels across outputs. More efficient than calling `realize()`
345    /// individually when tensors share subgraphs.
346    pub fn realize_batch<'a>(tensors: impl IntoIterator<Item = &'a mut Tensor>) -> Result<()> {
347        Self::realize_batch_with(tensors, &PrepareConfig::from_env())
348    }
349
350    /// Realize multiple tensors with custom configuration.
351    pub fn realize_batch_with<'a>(
352        tensors: impl IntoIterator<Item = &'a mut Tensor>,
353        config: &PrepareConfig,
354    ) -> Result<()> {
355        let mut tensors: Vec<&mut Tensor> = tensors.into_iter().collect();
356        if tensors.is_empty() {
357            return Ok(());
358        }
359
360        // Handle already-realized tensors
361        for t in &mut tensors {
362            if t.uop().has_buffer_identity() {
363                t.ensure_buffer();
364            }
365        }
366
367        // Wrap pure constants in CONTIGUOUS to force materialization (matches realize())
368        for t in &mut tensors {
369            if !t.uop().has_buffer_identity() && is_any_const(&t.uop()) {
370                let contiguous_uop = t.uop().contiguous();
371                t.set_uop(contiguous_uop);
372            }
373        }
374
375        // Collect pending (unrealized) tensor indices
376        let pending_indices = collect_pending_indices(&tensors);
377
378        if pending_indices.is_empty() {
379            return Ok(());
380        }
381
382        // Collect input buffers and old UOps from ALL pending tensors
383        let old_uops: Vec<Arc<UOp>> = pending_indices.iter().map(|&i| tensors[i].uop()).collect();
384        let mut all_input_buffers = crate::schedule::InputBuffers::new();
385        for uop in &old_uops {
386            all_input_buffers.extend(collect_input_buffers(uop));
387        }
388        let input_ids: HashSet<u64> = all_input_buffers.keys().copied().collect();
389
390        // Create merged SINK(CONTIGUOUS(t1), ..., CONTIGUOUS(tN))
391        let contiguouses: Vec<Arc<UOp>> = old_uops.iter().map(|u| u.contiguous()).collect();
392        let sink = UOp::sink(contiguouses);
393
394        let mut var_vals = HashMap::new();
395        for uop in &old_uops {
396            let extracted = extract_var_vals(uop)?;
397            merge_var_vals_checked(&mut var_vals, &extracted, "realize_batch input collection")?;
398        }
399        let schedule_result = schedule_result_from_sink_with_cache(sink, var_vals, config)?;
400
401        let t_prep = std::time::Instant::now();
402        let plan = prepare_execution_plan(&schedule_result, config)?;
403        let prep_ms = t_prep.elapsed().as_millis();
404        let t_exec = std::time::Instant::now();
405        plan.execute().context(ExecutionSnafu)?;
406        let exec_ms = t_exec.elapsed().as_millis();
407        debug!(prep_ms, exec_ms, num_outputs = pending_indices.len(), "realize_batch complete");
408
409        snafu::ensure!(
410            plan.num_outputs() >= pending_indices.len(),
411            BatchOutputMismatchSnafu { expected: pending_indices.len(), actual: plan.num_outputs() }
412        );
413
414        // Finalize each pending tensor in-place + build batched becomes_map
415        #[allow(clippy::mutable_key_type)]
416        let mut becomes_map = HashMap::new();
417        for (buf_idx, &orig_idx) in pending_indices.iter().enumerate() {
418            let output_buf = plan.output_buffer_at(buf_idx).expect("buf_idx in range").clone();
419            let old_uop = &old_uops[buf_idx];
420
421            let output_dtype = old_uop.dtype();
422            let output_device = output_buf.allocator().device_spec();
423            let num_elements = output_buf.size() / output_dtype.bytes();
424            let buffer_uop = UOp::new_buffer(output_device, num_elements, output_dtype);
425            let buf_arc = Arc::new(output_buf);
426
427            let t = &mut tensors[orig_idx];
428            crate::tensor_registry::register_buffer(buffer_uop.id, t.entry.id, buf_arc.clone());
429            let shape = old_uop.shape().context(UOpSnafu)?.context(ShapeUnknownSnafu)?;
430            let realized_uop = buffer_uop.try_reshape(shape).context(UOpSnafu)?;
431            t.set_uop(realized_uop.clone());
432            t.entry.set_buffer(Arc::clone(&buf_arc));
433            t.buffer = Some(buf_arc);
434
435            becomes_map.insert(UOpKey(old_uop.clone()), realized_uop);
436        }
437
438        // Single batched apply_map (one global walk instead of N)
439        crate::tensor_registry::apply_map_to_tensors(&becomes_map);
440
441        // Cleanup intermediate buffers
442        plan.release_intermediate_buffers(|id| {
443            if !input_ids.contains(&id) {
444                crate::tensor_registry::remove_buffer(id);
445            }
446        });
447
448        Ok(())
449    }
450
451    /// Prepare a batch execution plan for multiple tensors.
452    ///
453    /// Output tensors are wired to plan buffers — after `execute`/`execute_with_vars`,
454    /// results are readable directly via `tensor.as_vec()` or `tensor.array_view()`.
455    pub fn prepare_batch<'a>(tensors: impl IntoIterator<Item = &'a mut Tensor>) -> Result<ExecutionPlan> {
456        Self::prepare_batch_with(tensors, &PrepareConfig::from_env())
457    }
458
459    /// Prepare a batch execution plan with custom configuration.
460    pub fn prepare_batch_with<'a>(
461        tensors: impl IntoIterator<Item = &'a mut Tensor>,
462        config: &PrepareConfig,
463    ) -> Result<ExecutionPlan> {
464        let mut tensors: Vec<&mut Tensor> = tensors.into_iter().collect();
465        if tensors.is_empty() {
466            return EmptyScheduleSnafu.fail();
467        }
468
469        // Handle already-realized tensors
470        for t in &mut tensors {
471            if t.uop().has_buffer_identity() {
472                t.ensure_buffer();
473            }
474        }
475
476        // Wrap pure constants in CONTIGUOUS to force materialization (matches realize())
477        for t in &mut tensors {
478            if !t.uop().has_buffer_identity() && is_any_const(&t.uop()) {
479                let contiguous_uop = t.uop().contiguous();
480                t.set_uop(contiguous_uop);
481            }
482        }
483
484        // Collect pending (unrealized) tensor indices
485        let pending_indices = collect_pending_indices(&tensors);
486
487        if pending_indices.is_empty() {
488            return EmptyScheduleSnafu.fail();
489        }
490
491        // Collect UOps from pending tensors only
492        let uops: Vec<Arc<UOp>> = pending_indices.iter().map(|&i| tensors[i].uop()).collect();
493
494        let mut var_vals = HashMap::new();
495        for uop in &uops {
496            let extracted = extract_var_vals(uop)?;
497            merge_var_vals_checked(&mut var_vals, &extracted, "prepare_batch input collection")?;
498        }
499
500        // Create merged SINK(CONTIGUOUS(t1), ..., CONTIGUOUS(tN)) from pending tensors
501        let contiguouses: Vec<Arc<UOp>> = uops.iter().map(|u| u.contiguous()).collect();
502        let sink = UOp::sink(contiguouses);
503
504        let schedule_result = schedule_result_from_sink_with_cache(sink, var_vals, config)?;
505
506        let plan = prepare_execution_plan(&schedule_result, config)?;
507
508        // Wire each pending output tensor to its plan buffer.
509        // After execute/execute_with_vars, tensor.array_view() reads the result directly.
510        for (buf_idx, &orig_idx) in pending_indices.iter().enumerate() {
511            if buf_idx >= plan.num_outputs() {
512                break;
513            }
514            let output_buf = plan.output_buffer_at(buf_idx).expect("buf_idx in range").clone();
515            let buf_arc = Arc::new(output_buf);
516            let old_uop = &uops[buf_idx];
517            let output_dtype = old_uop.dtype();
518            let output_device = buf_arc.allocator().device_spec();
519            let num_elements = buf_arc.size() / output_dtype.bytes();
520            let buffer_uop = UOp::new_buffer(output_device, num_elements, output_dtype);
521            let t = &mut tensors[orig_idx];
522            crate::tensor_registry::register_buffer(buffer_uop.id, t.entry.id, buf_arc.clone());
523            let shape = old_uop.shape().context(UOpSnafu)?.context(ShapeUnknownSnafu)?;
524            let realized_uop = buffer_uop.try_reshape(shape).context(UOpSnafu)?;
525            t.set_uop(realized_uop);
526            t.entry.set_buffer(Arc::clone(&buf_arc));
527            t.buffer = Some(buf_arc);
528        }
529
530        Ok(plan)
531    }
532}
533
/// Record the binding `name = val` in `var_vals`, or check it against an existing one.
///
/// - If `name` is not bound yet, the binding is inserted and `Ok(())` is returned.
/// - If `name` is already bound to `val`, nothing changes and `Ok(())` is returned.
/// - If `name` is bound to a different value, the map is left untouched and
///   `Err((previous, val))` is returned so callers can format the mismatch in
///   their own context.
fn try_bind_var_val(var_vals: &mut HashMap<String, i64>, name: &str, val: i64) -> std::result::Result<(), (i64, i64)> {
    match var_vals.get(name).copied() {
        None => {
            // Only allocate the owned key when the binding is actually new.
            var_vals.insert(name.to_owned(), val);
            Ok(())
        }
        Some(existing) if existing == val => Ok(()),
        Some(existing) => Err((existing, val)),
    }
}
553
554fn insert_var_val_checked(var_vals: &mut HashMap<String, i64>, name: &str, val: i64, context: &str) -> Result<()> {
555    match try_bind_var_val(var_vals, name, val) {
556        Ok(()) => Ok(()),
557        Err((prev, val)) => {
558            IrConstructionSnafu { details: format!("bind mismatch on {name}, {prev} != {val} ({context})") }.fail()
559        }
560    }
561}
562
563fn merge_var_vals_checked(dst: &mut HashMap<String, i64>, src: &HashMap<String, i64>, context: &str) -> Result<()> {
564    for (name, val) in src {
565        insert_var_val_checked(dst, name, *val, context)?;
566    }
567    Ok(())
568}
569
570fn extract_var_vals(root: &Arc<UOp>) -> Result<HashMap<String, i64>> {
571    let mut var_vals = HashMap::new();
572    for node in root.toposort() {
573        if let Op::Bind { var, value } = node.op()
574            && let Op::DefineVar { name, .. } = var.op()
575            && let Op::Const(cv) = value.op()
576            && let Some(val) = cv.0.try_int()
577        {
578            insert_var_val_checked(&mut var_vals, name, val, "bind extraction")?;
579        }
580    }
581    Ok(var_vals)
582}
583
/// Whether the environment sets `SVOD_DISABLE_SCHEDULE_CACHE=1`.
///
/// The variable is read once per process and the answer is memoized, so later
/// changes to the environment are not observed.
fn schedule_cache_disabled_by_env() -> bool {
    static DISABLED: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
    *DISABLED.get_or_init(|| {
        matches!(std::env::var("SVOD_DISABLE_SCHEDULE_CACHE"), Ok(value) if value == "1")
    })
}
588
589fn schedule_result_from_sink_with_cache(
590    sink: Arc<UOp>,
591    mut var_vals: HashMap<String, i64>,
592    config: &PrepareConfig,
593) -> Result<crate::schedule::ScheduleResult> {
594    if config.disable_schedule_cache || schedule_cache_disabled_by_env() {
595        return schedule_result_from_sink_uncached(sink, var_vals, config);
596    }
597
598    let normalization = normalize_for_schedule_cache(&sink)?;
599    merge_var_vals_checked(&mut var_vals, &normalization.var_vals, "schedule cache normalization")?;
600
601    let codegen = resolve_codegen(&normalization.param_buffers, config)?;
602    let sched_key = (crate::schedule_cache::content_hash(&normalization.normalized), codegen);
603
604    let cache = crate::schedule_cache::schedule_cache();
605    let entry = {
606        let guard = cache.guard();
607        cache.get(&sched_key, &guard).cloned()
608    };
609
610    let entry = match entry {
611        Some(hit) => {
612            debug!("schedule cache hit");
613            hit
614        }
615        None => {
616            let schedule_root = restore_bind_placeholders_for_schedule(&normalization.normalized, &normalization);
617            let rangeify_result = svod_schedule::rangeify_with_map(schedule_root, None).context(RangeifySnafu)?;
618            let (kernel_graph, _) = svod_schedule::try_get_kernel_graph(rangeify_result.sink).context(RangeifySnafu)?;
619            let pre_schedule = crate::schedule::create_pre_schedule(kernel_graph)?;
620            let new_entry = Arc::new(crate::schedule_cache::CachedSchedule { pre_schedule: Arc::new(pre_schedule) });
621            let guard = cache.guard();
622            cache.insert(sched_key, Arc::clone(&new_entry), &guard);
623            new_entry
624        }
625    };
626
627    let restored_pre_schedule = restore_post_schedule_pre_schedule(&entry.pre_schedule, &normalization);
628    let schedule_input_buffers = build_schedule_input_buffers(&restored_pre_schedule, &normalization);
629    let result = crate::schedule::instantiate_schedule(&restored_pre_schedule, &schedule_input_buffers, &var_vals)?;
630    Ok(result)
631}
632
633fn schedule_result_from_sink_uncached(
634    sink: Arc<UOp>,
635    var_vals: HashMap<String, i64>,
636    _config: &PrepareConfig,
637) -> Result<crate::schedule::ScheduleResult> {
638    let rangeify_result = svod_schedule::rangeify_with_map(sink, None).context(RangeifySnafu)?;
639    let (kernel_graph, _) = svod_schedule::try_get_kernel_graph(rangeify_result.sink).context(RangeifySnafu)?;
640    let pre_schedule = crate::schedule::create_pre_schedule(kernel_graph.clone())?;
641    let input_buffers = collect_input_buffers(&kernel_graph);
642    let result = crate::schedule::instantiate_schedule(&pre_schedule, &input_buffers, &var_vals)?;
643    Ok(result)
644}
645
646/// Pre-schedule cache normalization result.
647///
648/// - BUFFER -> PARAM
649/// - BUFFER_VIEW identities normalized via recursive BUFFER -> PARAM
650/// - strip runtime value from BIND(DEFINE_VAR, CONST)
651/// - normalize standalone UNIQUE identity -> LUNIQUE
652pub(crate) struct ScheduleCacheNormalization {
653    pub normalized: Arc<UOp>,
654    pub param_values: Vec<Arc<UOp>>,
655    pub param_buffers: Vec<(u64, Arc<UOp>)>,
656    pub unique_values: Vec<Arc<UOp>>,
657    pub var_vals: HashMap<String, i64>,
658}
659
660/// Context for pre-schedule cache normalization.
661pub(crate) struct NormalizeScheduleCacheCtx {
662    pub param_map: HashMap<u64, usize>,
663    pub param_values: Vec<Arc<UOp>>,
664    pub param_buffers: Vec<(u64, Arc<UOp>)>,
665    pub var_vals: HashMap<String, i64>,
666    pub bind_mismatch: Option<String>,
667}
668
669/// Full pre-schedule cache normalization.
670pub(crate) fn normalize_for_schedule_cache(sink: &Arc<UOp>) -> Result<ScheduleCacheNormalization> {
671    let mut ctx = NormalizeScheduleCacheCtx {
672        param_map: HashMap::new(),
673        param_values: Vec::new(),
674        param_buffers: Vec::new(),
675        var_vals: HashMap::new(),
676        bind_mismatch: None,
677    };
678
679    use svod_ir::op::pattern_derived::OpKey;
680    use svod_ir::pattern::{RewriteResult, SimplifiedPatternMatcher};
681    use svod_ir::rewrite::graph_rewrite;
682
683    let mut matcher = SimplifiedPatternMatcher::<NormalizeScheduleCacheCtx>::new();
684
685    fn to_param(
686        node: &Arc<UOp>,
687        ctx: &mut NormalizeScheduleCacheCtx,
688        size: usize,
689        device: Option<Arc<UOp>>,
690    ) -> Arc<UOp> {
691        let slot = *ctx.param_map.entry(node.id).or_insert_with(|| {
692            let s = ctx.param_values.len();
693            ctx.param_values.push(node.clone());
694            s
695        });
696        UOp::param(slot, size, node.dtype(), device)
697    }
698
699    // BUFFER -> PARAM (erase runtime buffer identity in cache key).
700    matcher.add(&[OpKey::Buffer], |node, ctx| {
701        let Op::Buffer { size, device, .. } = node.op() else {
702            return RewriteResult::NoMatch;
703        };
704        let slot = *ctx.param_map.entry(node.id).or_insert_with(|| {
705            let s = ctx.param_values.len();
706            ctx.param_values.push(node.clone());
707            s
708        });
709        ctx.param_buffers.push((node.id, node.clone()));
710        RewriteResult::Rewritten(UOp::param(slot, *size, node.dtype(), Some(device.clone())))
711    });
712
713    // BUFFER_VIEW -> PARAM.
714    matcher.add(&[OpKey::BufferView], |node, ctx| {
715        let Op::BufferView { size, .. } = node.op() else {
716            return RewriteResult::NoMatch;
717        };
718        RewriteResult::Rewritten(to_param(node, ctx, *size, Some(UOp::device(DeviceSpec::Cpu))))
719    });
720
721    // Strip runtime value from BIND for cache-key stability and collect var_vals.
722    // Replaced with PARAM(device=Some) so restoration stays reversible and
723    // distinguishable from internal PARAM(device=None) nodes created by rangeify.
724    matcher.add(&[OpKey::Bind], |node, ctx| {
725        let Op::Bind { var, value } = node.op() else {
726            return RewriteResult::NoMatch;
727        };
728        let Op::DefineVar { name, .. } = var.op() else {
729            return RewriteResult::NoMatch;
730        };
731        let Op::Const(cv) = value.op() else {
732            return RewriteResult::NoMatch;
733        };
734        let Some(val) = cv.0.try_int() else {
735            return RewriteResult::NoMatch;
736        };
737
738        if let Err((prev, val)) = try_bind_var_val(&mut ctx.var_vals, name, val) {
739            if ctx.bind_mismatch.is_none() {
740                ctx.bind_mismatch = Some(format!("bind mismatch on variable {name}: {prev} vs {val}"));
741            }
742            return RewriteResult::NoMatch;
743        }
744        RewriteResult::Rewritten(to_param(node, ctx, 0, Some(UOp::device(DeviceSpec::Cpu))))
745    });
746
747    // Pre-schedule cache normalization:
748    // - BUFFER(UNIQUE, DEVICE) -> PARAM
749    // - BUFFER_VIEW base identity normalized through child BUFFER -> PARAM
750    // - BIND(DEFINE_VAR, CONST) -> PARAM + var_vals capture
751    let normalized = graph_rewrite(&matcher, sink.clone(), &mut ctx);
752
753    if let Some(details) = ctx.bind_mismatch.take() {
754        return IrConstructionSnafu { details }.fail();
755    }
756
757    // Normalize standalone UNIQUE identity to deterministic LUNIQUE slots.
758    // This runs after BUFFER/BUFFER_VIEW replacement to avoid capturing UNIQUE
759    // nodes that are no longer present in the normalized graph.
760    struct UniqueNormalizationCtx {
761        unique_map: HashMap<u64, usize>,
762        unique_values: Vec<Arc<UOp>>,
763    }
764    let mut unique_ctx = UniqueNormalizationCtx { unique_map: HashMap::new(), unique_values: Vec::new() };
765    let mut unique_matcher = SimplifiedPatternMatcher::<UniqueNormalizationCtx>::new();
766    unique_matcher.add(&[OpKey::Unique], |node, ctx| {
767        let Op::Unique(_) = node.op() else {
768            return RewriteResult::NoMatch;
769        };
770        let slot = *ctx.unique_map.entry(node.id).or_insert_with(|| {
771            let s = ctx.unique_values.len();
772            ctx.unique_values.push(node.clone());
773            s
774        });
775        RewriteResult::Rewritten(UOp::lunique(Some(slot)))
776    });
777    let normalized = graph_rewrite(&unique_matcher, normalized, &mut unique_ctx);
778
779    ctx.param_buffers.sort_unstable_by_key(|(id, _)| *id);
780    ctx.param_buffers.dedup_by_key(|(id, _)| *id);
781
782    Ok(ScheduleCacheNormalization {
783        normalized,
784        param_values: ctx.param_values,
785        param_buffers: ctx.param_buffers,
786        unique_values: unique_ctx.unique_values,
787        var_vals: ctx.var_vals,
788    })
789}
790
791/// Post-schedule cache restore.
792///
793/// Restores normalized placeholders back to runtime graph form for this run:
794/// - PARAM(slot, device=Some(_)) -> original source node for current invocation
795/// - BUFFER(LUNIQUE, DEVICE, size) -> fresh runtime BUFFER (memoized by slot)
796/// - standalone LUNIQUE(slot) -> original UNIQUE identity
797///
798/// BIND runtime values are carried separately through `var_vals` and applied
799/// at execution-time via fixedvars, preserving `execute_with_vars` behavior.
800#[allow(clippy::mutable_key_type)]
801pub(crate) fn restore_post_schedule_cache(root: &Arc<UOp>, normalization: &ScheduleCacheNormalization) -> Arc<UOp> {
802    let mut subs: HashMap<UOpKey, Arc<UOp>> = HashMap::new();
803    let mut lunique_buffers: HashMap<usize, Arc<UOp>> = HashMap::new();
804
805    for node in root.toposort() {
806        match node.op() {
807            Op::Param { slot, device: Some(_), .. } => {
808                if let Some(original) = normalization.param_values.get(*slot) {
809                    let restored_original = restore_post_schedule_cache(original, normalization);
810                    subs.insert(UOpKey(node.clone()), restored_original);
811                }
812            }
813            Op::Buffer { unique, device, size } => {
814                let Op::LUnique(slot) = unique.op() else {
815                    continue;
816                };
817                let restored = if let Some(existing) = lunique_buffers.get(slot) {
818                    existing.clone()
819                } else {
820                    let runtime_unique = UOp::buffer_id(None);
821                    let fresh = UOp::new(
822                        Op::Buffer { unique: runtime_unique, device: device.clone(), size: *size },
823                        node.dtype(),
824                    );
825                    lunique_buffers.insert(*slot, fresh.clone());
826                    fresh
827                };
828                subs.insert(UOpKey(node.clone()), restored);
829            }
830            Op::LUnique(slot) => {
831                let restored = normalization.unique_values.get(*slot).cloned().unwrap_or_else(|| UOp::buffer_id(None));
832                subs.insert(UOpKey(node.clone()), restored);
833            }
834            _ => {}
835        }
836    }
837
838    // Restore over the whole cached graph so PARAM/BIND placeholders are
839    // rewritten before schedule extraction.
840    root.substitute(&subs)
841}
842
843/// Restore only normalized BIND placeholders back to BIND nodes.
844///
845/// Cache keying strips bind runtime values (`BIND -> PARAM`) for key stability,
846/// but rangeify needs BIND semantics to preserve variable tracking. This helper
847/// rewrites just those placeholders while keeping BUFFER/PARAM normalization —
848/// the kernel AST must stay parametric so the cached pre-schedule can be reused
849/// across runs with different runtime buffers (post-cache restoration only
850/// swaps the buffer-uop *lists*, not the deep kernel AST).
851#[allow(clippy::mutable_key_type)]
852fn restore_bind_placeholders_for_schedule(root: &Arc<UOp>, normalization: &ScheduleCacheNormalization) -> Arc<UOp> {
853    let mut subs: HashMap<UOpKey, Arc<UOp>> = HashMap::new();
854
855    for node in root.toposort() {
856        let Op::Param { slot, device: Some(_), .. } = node.op() else {
857            continue;
858        };
859
860        let Some(original) = normalization.param_values.get(*slot) else {
861            continue;
862        };
863        if matches!(original.op(), Op::Bind { .. }) {
864            subs.insert(UOpKey(node.clone()), original.clone());
865        }
866    }
867
868    if subs.is_empty() { root.clone() } else { root.substitute(&subs) }
869}
870
871/// Restore cached pre-schedule buffer UOps for the current invocation.
872///
873/// `pre_schedule` is cached with normalized PARAM placeholders; this helper
874/// restores source/output buffer UOps to run-specific BUFFER identities while
875/// callable identities/ASTs stay cached.
876fn restore_post_schedule_pre_schedule(
877    pre_schedule: &crate::schedule::PreSchedule,
878    normalization: &ScheduleCacheNormalization,
879) -> crate::schedule::PreSchedule {
880    let mut flat_buf_uops = Vec::new();
881    let mut source_counts = Vec::with_capacity(pre_schedule.items.len());
882
883    for item in &pre_schedule.items {
884        source_counts.push(item.sources.len());
885        flat_buf_uops.extend(item.sources.iter().cloned());
886    }
887    let outputs_offset = flat_buf_uops.len();
888    flat_buf_uops.extend(pre_schedule.output_buffer_uops.iter().cloned());
889
890    if flat_buf_uops.is_empty() {
891        return pre_schedule.clone();
892    }
893
894    let restored_flat = match restore_post_schedule_cache(&UOp::sink(flat_buf_uops), normalization).op() {
895        Op::Sink { sources, .. } => sources.iter().cloned().collect::<Vec<_>>(),
896        _ => unreachable!("sink substitution must preserve SINK root"),
897    };
898
899    let mut cursor = 0usize;
900    let mut restored_items = Vec::with_capacity(pre_schedule.items.len());
901    for (item, source_count) in pre_schedule.items.iter().zip(source_counts) {
902        let end = cursor + source_count;
903        let sources = restored_flat[cursor..end].to_vec();
904        cursor = end;
905        let ast = restore_post_schedule_cache(&item.ast, normalization);
906        restored_items.push(crate::schedule::PreScheduleItem {
907            kernel: item.kernel.clone(),
908            ast,
909            sources,
910            dependencies: item.dependencies.clone(),
911            bound_ranges: item.bound_ranges.clone(),
912        });
913    }
914
915    let output_buffer_uops = restored_flat[outputs_offset..].to_vec();
916    crate::schedule::PreSchedule {
917        items: restored_items,
918        invocations: pre_schedule.invocations.clone(),
919        output_buffer_uops,
920    }
921}
922
923fn build_schedule_input_buffers(
924    pre_schedule: &crate::schedule::PreSchedule,
925    _normalization: &ScheduleCacheNormalization,
926) -> crate::schedule::InputBuffers {
927    let mut inputs = crate::schedule::InputBuffers::new();
928
929    for item in &pre_schedule.items {
930        for src in &item.sources {
931            let buf = src.buf_uop();
932            if let Op::Buffer { .. } = buf.op()
933                && let Some(buffer) = crate::tensor_registry::get_buffer(buf.id)
934            {
935                inputs.insert(buf.id, buffer);
936            }
937        }
938    }
939
940    inputs
941}
942
943/// Collect input buffers from a computation graph.
944///
945/// Walks the UOp graph and collects all BUFFER UOps that have
946/// associated buffers in the tensor registry's buffer index.
947/// Input tensors (from `from_slice()`) and realized tensors
948/// register their buffers for this lookup to work.
949///
950/// This allows schedule creation to receive buffers explicitly without
951/// needing global registry lookups during kernel buffer collection.
952fn collect_input_buffers(root: &Arc<UOp>) -> crate::schedule::InputBuffers {
953    let mut inputs = HashMap::new();
954    for node in root.toposort() {
955        if let Op::Buffer { .. } = node.op() {
956            // Buffers are registered in from_slice_on() and realize()
957            if let Some(buf) = crate::tensor_registry::get_buffer(node.id) {
958                inputs.insert(node.id, buf);
959            }
960        }
961    }
962    inputs
963}
964
965fn output_indices_from_program_metadata(globals: &[usize], outs: &[usize], num_buffers: usize) -> Result<Vec<usize>> {
966    if num_buffers == 0 {
967        return IrConstructionSnafu { details: "cannot map outputs for kernel with zero buffers".to_string() }.fail();
968    }
969    if globals.is_empty() {
970        return IrConstructionSnafu { details: "ProgramSpec.globals is empty".to_string() }.fail();
971    }
972    if outs.is_empty() {
973        return IrConstructionSnafu { details: "ProgramSpec.outs is empty".to_string() }.fail();
974    }
975
976    let slot_to_position: HashMap<usize, usize> =
977        globals.iter().copied().enumerate().map(|(position, slot)| (slot, position)).collect();
978
979    let mut output_indices = Vec::with_capacity(outs.len());
980    for &slot in outs {
981        let Some(position) = slot_to_position.get(&slot).copied() else {
982            return IrConstructionSnafu {
983                details: format!("ProgramSpec.outs slot {slot} not found in ProgramSpec.globals={globals:?}"),
984            }
985            .fail();
986        };
987        if position >= num_buffers {
988            return IrConstructionSnafu {
989                details: format!(
990                    "ProgramSpec output index {position} (slot {slot}) out of range for {num_buffers} buffers"
991                ),
992            }
993            .fail();
994        }
995        output_indices.push(position);
996    }
997
998    output_indices.sort_unstable();
999    output_indices.dedup();
1000    if output_indices.is_empty() {
1001        return IrConstructionSnafu { details: "ProgramSpec output mapping resolved to empty set".to_string() }.fail();
1002    }
1003
1004    Ok(output_indices)
1005}
1006
1007fn resolve_item_buffer_indices(item: &ScheduleItem, uop_id_to_idx: &HashMap<u64, usize>) -> Result<Vec<usize>> {
1008    let mut indices = Vec::with_capacity(item.buffer_uop_ids.len());
1009    for &uop_id in &item.buffer_uop_ids {
1010        let Some(idx) = uop_id_to_idx.get(&uop_id).copied() else {
1011            return Err(crate::error::Error::BufferNotFound { uop_id });
1012        };
1013        indices.push(idx);
1014    }
1015    Ok(indices)
1016}
1017
1018fn resolve_compiled_kernel_buffer_indices(
1019    item: &ScheduleItem,
1020    uop_id_to_idx: &HashMap<u64, usize>,
1021    globals: &[usize],
1022) -> Result<Vec<usize>> {
1023    let buffer_indices = resolve_item_buffer_indices(item, uop_id_to_idx)?;
1024
1025    let mut ordered = Vec::with_capacity(globals.len());
1026    for &position in globals {
1027        let Some(idx) = buffer_indices.get(position).copied() else {
1028            return IrConstructionSnafu {
1029                details: format!(
1030                    "ProgramSpec.globals position {position} out of range for CALL {} buffer list len {} (buffer_uop_ids={:?})",
1031                    item.kernel.id,
1032                    buffer_indices.len(),
1033                    item.buffer_uop_ids
1034                ),
1035            }
1036            .fail();
1037        };
1038        ordered.push(idx);
1039    }
1040
1041    Ok(ordered)
1042}
1043
1044type OptKey = (u64, DeviceSpec, &'static str, u64);
1045
1046/// Bounded global cache for optimized + compiled kernels keyed by AST hash.
1047///
1048/// Reads are lock-free via the underlying `papaya::HashMap`; the FIFO side
1049/// structure is touched only on insert under a short-lived mutex. The cap is
1050/// read once via `SVOD_OPT_CACHE_MAX` (default 4096); when capacity is
1051/// exceeded, the oldest insertions are evicted from both the map and the
1052/// FIFO.
1053struct OptCacheState {
1054    map: papaya::HashMap<OptKey, Arc<svod_runtime::kernel_cache::CachedKernel>>,
1055    fifo: parking_lot::Mutex<std::collections::VecDeque<OptKey>>,
1056    cap: usize,
1057}
1058
1059impl OptCacheState {
1060    const DEFAULT_CAP: usize = 4096;
1061
1062    fn new() -> Self {
1063        let cap = std::env::var("SVOD_OPT_CACHE_MAX")
1064            .ok()
1065            .and_then(|s| s.parse::<usize>().ok())
1066            .filter(|&n| n > 0)
1067            .unwrap_or(Self::DEFAULT_CAP);
1068        Self { map: papaya::HashMap::new(), fifo: parking_lot::Mutex::new(std::collections::VecDeque::new()), cap }
1069    }
1070
1071    fn insert(&self, key: OptKey, val: Arc<svod_runtime::kernel_cache::CachedKernel>) {
1072        let guard = self.map.guard();
1073        let was_new = self.map.insert(key.clone(), val, &guard).is_none();
1074        if !was_new {
1075            return;
1076        }
1077        let mut fifo = self.fifo.lock();
1078        fifo.push_back(key);
1079        while fifo.len() > self.cap {
1080            if let Some(evict) = fifo.pop_front() {
1081                self.map.remove(&evict, &guard);
1082            }
1083        }
1084    }
1085}
1086
1087pub(crate) fn runtime_effect_ast(ast: &Arc<UOp>) -> &Arc<UOp> {
1088    match ast.op() {
1089        Op::End { computation, .. }
1090            if matches!(computation.op(), Op::Copy { .. } | Op::BufferView { .. } | Op::CustomFunction { .. }) =>
1091        {
1092            computation
1093        }
1094        _ => ast,
1095    }
1096}
1097
1098fn optimizer_config_fingerprint(config: &PrepareConfig) -> u64 {
1099    let mut hasher = std::collections::hash_map::DefaultHasher::new();
1100    config.optimizer.hash(&mut hasher);
1101    hasher.finish()
1102}
1103
1104/// Prepare an execution plan from a schedule.
1105///
1106/// This performs all one-time preparation work:
1107/// 1. Allocates all buffers
1108/// 2. Compiles callable kernels
1109/// 3. Creates prepared runtime ops (compiled program + copy/view/custom-function handling)
1110///
1111/// # Arguments
1112///
1113/// * `schedule` - The schedule from `create_schedule()`
1114///
1115/// # Returns
1116///
1117/// An `ExecutionPlan` ready for fast repeated execution.
1118///
1119/// # Errors
1120///
1121/// Returns error if compilation or buffer allocation fails.
1122fn prepare_execution_plan(
1123    schedule_result: &crate::schedule::ScheduleResult,
1124    config: &PrepareConfig,
1125) -> Result<ExecutionPlan> {
1126    // Schedule items are already fully expanded by strict scheduler unroll.
1127    let mut schedule_items = schedule_result.items.clone();
1128
1129    // Liveness-based memory planning. `PlannerMode::Arena` (default) packs
1130    // plannable buffers into one or two large allocations; `Remap` swaps
1131    // per-pool `Arc<Buffer>`s; `Disabled` short-circuits. Mode is selected
1132    // by `SVOD_MEMORY_PLANNER` (`Arena` if unset).
1133    let planner_mode = crate::memory_planner::mode_from_env();
1134    let output_buffer_ids = collect_output_buffer_ids(&schedule_items, &schedule_result.output_uop_ids);
1135    let planner_result = crate::memory_planner::memory_planner(&schedule_items, &output_buffer_ids, planner_mode);
1136    if !planner_result.buffer_replace.is_empty() {
1137        trace!(
1138            replacements = planner_result.buffer_replace.len(),
1139            buffers_reused = planner_result.buffers_reused,
1140            memory_saved_bytes = planner_result.memory_saved,
1141            "applying memory planner buffer replacements"
1142        );
1143        crate::memory_planner::apply_reuse_dependencies(&mut schedule_items, &planner_result.reuse_dependencies);
1144        crate::memory_planner::apply_buffer_replacements(&mut schedule_items, &planner_result.buffer_replace);
1145    }
1146
1147    debug!(num_items = schedule_items.len(), "schedule items ready for execution plan");
1148
1149    // Resolve primary plan device from the first schedule item for plan metadata.
1150    // Individual compiled kernels may still resolve/compile on per-item devices.
1151    let alloc_registry = svod_device::registry::registry();
1152    let plan_device = if !schedule_items.is_empty() {
1153        let device_spec = schedule_items
1154            .iter()
1155            .flat_map(|item| item.buffers.iter().map(|b| b.allocator().device_spec()))
1156            .find(|spec| !spec.is_disk())
1157            .unwrap_or(DeviceSpec::Cpu);
1158        config.resolve_device(&device_spec, alloc_registry)?
1159    } else {
1160        return EmptyScheduleSnafu.fail();
1161    };
1162    let optimizer_fingerprint = optimizer_config_fingerprint(config);
1163
1164    // Build the ExecutionPlan using the builder
1165    let mut builder = ExecutionPlanBuilder::new(plan_device.device.clone());
1166
1167    // Step 1: Add all buffers to the plan
1168    // Buffers in each ScheduleItem are already in the correct order (from collect_callable_buffers).
1169    // We track buffers by their UOp ID (what they were registered under in tensor_registry's buffer index).
1170    let mut uop_id_to_idx: HashMap<u64, usize> = HashMap::new();
1171    let mut storage_to_idx: HashMap<BufferStorageKey, usize> = HashMap::new();
1172
1173    // BUFFER_VIEW output slots are replaced later with base views. Keep them as
1174    // distinct entries even if they currently share physical storage, so replace
1175    // cannot accidentally mutate another logical buffer mapping.
1176    let buffer_view_output_uop_ids: HashSet<u64> = schedule_items
1177        .iter()
1178        .filter_map(|item| {
1179            if matches!(runtime_effect_ast(&item.ast).op(), Op::BufferView { .. }) {
1180                item.buffer_uop_ids.first().copied()
1181            } else {
1182                None
1183            }
1184        })
1185        .collect();
1186
1187    for item in &schedule_items {
1188        // Ensure all buffers are allocated
1189        for (buffer, &uop_id) in item.buffers.iter().zip(item.buffer_uop_ids.iter()) {
1190            buffer.ensure_allocated().context(DeviceSnafu)?;
1191
1192            if uop_id_to_idx.contains_key(&uop_id) {
1193                continue;
1194            }
1195
1196            let storage_key = BufferStorageKey {
1197                id: buffer.id().0,
1198                offset: buffer.offset(),
1199                size: buffer.size(),
1200                dtype: buffer.dtype(),
1201            };
1202
1203            let idx = if !buffer_view_output_uop_ids.contains(&uop_id) {
1204                if let Some(&existing_idx) = storage_to_idx.get(&storage_key) {
1205                    builder.map_buffer(uop_id, existing_idx);
1206                    existing_idx
1207                } else {
1208                    let new_idx = builder.add_buffer(uop_id, buffer.clone());
1209                    storage_to_idx.insert(storage_key, new_idx);
1210                    new_idx
1211                }
1212            } else {
1213                builder.add_buffer(uop_id, buffer.clone())
1214            };
1215            uop_id_to_idx.insert(uop_id, idx);
1216        }
1217
1218        // Collect alias IDs for cleanup
1219        builder.add_alias_ids(item.alias_registered_ids.iter().copied());
1220    }
1221
1222    // Step 2: Compile callable kernels and create prepared runtime ops
1223
1224    // Pre-compile: optimize + compile each UNIQUE ast once, cache by pre-optimization ast id.
1225    // Uses global cache so identical kernels across prepare calls (e.g., sort substages
1226    // with same axis) skip both optimization and compilation. Bounded via FIFO eviction
1227    // to keep long-running processes from accumulating dead kernel entries indefinitely.
1228    static OPT_CACHE: std::sync::OnceLock<OptCacheState> = std::sync::OnceLock::new();
1229    let opt_state = OPT_CACHE.get_or_init(OptCacheState::new);
1230    let opt_cache = &opt_state.map;
1231    let opt_guard = opt_cache.guard();
1232
1233    for item in &schedule_items {
1234        // COPY operations: buffer-to-buffer transfer (DISK→CPU, CPU→CUDA, etc.)
1235        // No compilation needed — register as PreparedOp for runtime execution.
1236        let runtime_ast = runtime_effect_ast(&item.ast);
1237
1238        if matches!(runtime_ast.op(), Op::Copy { .. }) {
1239            let buffer_indices = resolve_item_buffer_indices(item, &uop_id_to_idx)?;
1240            builder.add_op_with_instance_dependencies(
1241                PreparedOp::BufferCopy(PreparedCopy {
1242                    id: item.kernel.id,
1243                    buffer_indices,
1244                    dependencies: item.dependencies.clone(),
1245                }),
1246                item.instance_dependencies.clone(),
1247            );
1248            continue;
1249        }
1250
1251        // BUFFER_VIEW: zero-copy sub-buffer view (DISK weight views).
1252        // Creates a view into the base buffer at the specified byte offset.
1253        // BUFFER_VIEW lowers to a base.view(size, dtype, offset) on the source buffer.
1254        if let Op::BufferView { size, offset, .. } = runtime_ast.op() {
1255            let buffer_indices = resolve_item_buffer_indices(item, &uop_id_to_idx)?;
1256
1257            if item.buffers.len() >= 2 && item.buffer_uop_ids.len() >= 2 && buffer_indices.len() >= 2 {
1258                let base = &item.buffers[1];
1259                let byte_offset = offset * base.dtype().bytes();
1260                let byte_size = size * runtime_ast.dtype().bytes();
1261                let view = base.view(byte_offset, byte_size).map_err(|e| crate::error::Error::IrConstruction {
1262                    details: format!(
1263                        "BUFFER_VIEW failed for kernel {}: base_buffer_id={}, byte_offset={}, byte_size={}: {e}",
1264                        item.kernel.id,
1265                        base.id().0,
1266                        byte_offset,
1267                        byte_size
1268                    ),
1269                })?;
1270                // Register the view under the output buffer's UOp ID so downstream
1271                // COPY/kernel items find it as their source buffer.
1272                let output_uop_id = item.buffer_uop_ids[0];
1273                if let Some(&idx) = uop_id_to_idx.get(&output_uop_id) {
1274                    builder.replace_buffer(idx, view);
1275                }
1276
1277                builder.add_op_with_instance_dependencies(
1278                    PreparedOp::BufferView(PreparedBufferView {
1279                        id: item.kernel.id,
1280                        buffer_indices,
1281                        byte_offset,
1282                        byte_size,
1283                        dependencies: item.dependencies.clone(),
1284                    }),
1285                    item.instance_dependencies.clone(),
1286                );
1287            }
1288            continue;
1289        }
1290
1291        // CALL bodies rooted at CUSTOM_FUNCTION are lowered directly to runtime
1292        // PreparedOp::CustomFunction with typed dispatch. Match against the
1293        // unwrapped runtime AST so END(CustomFunction) reaches this branch
1294        // consistently with Copy/BufferView above.
1295        if let Op::CustomFunction { kind, attrs } = runtime_ast.op() {
1296            let buffer_indices = resolve_item_buffer_indices(item, &uop_id_to_idx)?;
1297            let runtime_vars = attrs.iter().flat_map(svod_runtime::execution_plan::collect_runtime_vars).collect();
1298            builder.add_op_with_instance_dependencies(
1299                PreparedOp::CustomFunction(PreparedCustomFunction {
1300                    id: item.kernel.id,
1301                    kind: kind.clone(),
1302                    attrs: attrs.clone(),
1303                    buffer_indices,
1304                    fixedvars: item.fixedvars.clone(),
1305                    dependencies: item.dependencies.clone(),
1306                    runtime_vars,
1307                }),
1308                item.instance_dependencies.clone(),
1309            );
1310            continue;
1311        }
1312
1313        let item_device_spec = item
1314            .buffers
1315            .iter()
1316            .map(|b| b.allocator().device_spec())
1317            .find(|spec| !spec.is_disk())
1318            .unwrap_or(DeviceSpec::Cpu);
1319        let item_device = config.resolve_device(&item_device_spec, alloc_registry)?;
1320        let item_codegen: &'static str = item_device.compiler.cache_key();
1321
1322        let opt_key = (
1323            crate::schedule_cache::content_hash(&item.ast),
1324            item_device.device.clone(),
1325            item_codegen,
1326            optimizer_fingerprint,
1327        );
1328
1329        let cached = if let Some(cached) = opt_cache.get(&opt_key, &opt_guard) {
1330            Arc::clone(cached)
1331        } else {
1332            let optimizer_renderer = get_optimizer_renderer(&item_device);
1333            let optimized_ast = if let svod_schedule::OptStrategy::Beam { .. } = config.optimizer.strategy {
1334                beam_search_optimize(
1335                    item.ast.clone(),
1336                    &optimizer_renderer,
1337                    &item_device,
1338                    &item.buffers,
1339                    &config.optimizer,
1340                )?
1341            } else {
1342                svod_schedule::optimize_kernel_with_config(item.ast.clone(), &optimizer_renderer, &config.optimizer)
1343            };
1344
1345            let kernel_name =
1346                optimized_ast.metadata::<svod_schedule::optimizer::KernelInfo>().map(|info| info.function_name());
1347
1348            let ast_decomposed = match item_device.renderer.decompositor() {
1349                Some(matcher) => svod_ir::decompositions::decompose_with(&optimized_ast, &matcher),
1350                None => optimized_ast,
1351            };
1352            let program = svod_codegen::program_pipeline::program_from_sink(ast_decomposed, item_device.device.clone());
1353
1354            let result = svod_runtime::kernel_cache::get_or_compile_kernel(
1355                crate::schedule_cache::content_hash(&program),
1356                item_codegen,
1357                || {
1358                    let (spec, compiled) = compile_with_program_pipeline_components(
1359                        program.clone(),
1360                        item_device.renderer.as_ref(),
1361                        item_device.compiler.as_ref(),
1362                        kernel_name.as_deref(),
1363                    )?;
1364                    let program = (item_device.runtime)(&compiled).context(CreateProgramSnafu)?;
1365                    Ok(svod_runtime::kernel_cache::CachedKernel {
1366                        program,
1367                        device: item_codegen.to_string(),
1368                        code: spec.src.clone(),
1369                        entry_point: spec.name.clone(),
1370                        var_names: spec.var_names.clone(),
1371                        globals: spec.globals.clone(),
1372                        outs: spec.outs.clone(),
1373                        ins: spec.ins.clone(),
1374                        host_parallel_safe: matches!(item_device.device, DeviceSpec::Cpu),
1375                        global_size: spec.global_size.clone(),
1376                        local_size: spec.local_size.clone(),
1377                    })
1378                },
1379            )?;
1380            opt_state.insert(opt_key, Arc::clone(&result));
1381            result
1382        };
1383
1384        // Build buffer indices in compiled ABI order (`ProgramSpec.globals`), not necessarily CALL arg order.
1385        let buffer_indices = resolve_compiled_kernel_buffer_indices(item, &uop_id_to_idx, &cached.globals)?;
1386
1387        trace!(kernel.ast_id = item.ast.id, num_buffers = item.buffers.len(), "kernel buffer mapping");
1388
1389        // Create PreparedKernel
1390        // Note: buffer_ptrs and buffer_ids will be computed in ExecutionPlanBuilder::build()
1391        // Convert fixedvars HashMap to vals Vec using var_names order from CachedKernel
1392        let vals: Vec<i64> =
1393            cached.var_names.iter().map(|name| item.fixedvars.get(name).copied().unwrap_or(0)).collect();
1394        let non_overridable_fixedvars = collect_non_overridable_fixedvars(item);
1395
1396        let output_indices = output_indices_from_program_metadata(&cached.globals, &cached.outs, buffer_indices.len())
1397            .map_err(|e| crate::error::Error::IrConstruction {
1398                details: format!(
1399                    "invalid ProgramSpec output metadata for kernel id {} (globals={:?}, outs={:?}, num_buffers={}): {e}",
1400                    item.kernel.id,
1401                    cached.globals,
1402                    cached.outs,
1403                    buffer_indices.len()
1404                ),
1405            })?;
1406
1407        let runtime_vars = svod_runtime::execution_plan::collect_runtime_vars(&item.ast);
1408        let prepared = PreparedKernel {
1409            id: item.kernel.id,
1410            ast: item.ast.clone(),
1411            kernel: cached,
1412            device: item_device.device.clone(),
1413            buffer_indices,
1414            output_indices,
1415            vals,
1416            fixedvars: non_overridable_fixedvars,
1417            dependencies: item.dependencies.clone(),
1418            buffer_ptrs: Vec::new(), // Computed in build()
1419            buffer_ids: Vec::new(),  // Computed in build()
1420            runtime_vars,
1421        };
1422
1423        builder.add_op_with_instance_dependencies(
1424            PreparedOp::CompiledProgram(prepared),
1425            item.instance_dependencies.clone(),
1426        );
1427    }
1428
1429    // Deterministic output identification via ScheduleResult.output_uop_ids
1430    let mut output_buffer_indices = Vec::with_capacity(schedule_result.output_uop_ids.len());
1431    for &uop_id in &schedule_result.output_uop_ids {
1432        let Some(idx) = uop_id_to_idx.get(&uop_id).copied() else {
1433            return Err(crate::error::Error::BufferNotFound { uop_id });
1434        };
1435        output_buffer_indices.push(idx);
1436    }
1437    if output_buffer_indices.is_empty() {
1438        return IrConstructionSnafu { details: "prepare_execution_plan produced no output buffer indices".to_string() }
1439            .fail();
1440    }
1441    builder.set_output_buffers(output_buffer_indices);
1442
1443    builder.build().context(ExecutionSnafu)
1444}
1445
1446fn collect_output_buffer_ids(schedule: &crate::schedule::Schedule, output_uop_ids: &[u64]) -> HashSet<u64> {
1447    let output_uop_set: HashSet<u64> = output_uop_ids.iter().copied().collect();
1448    let mut output_buffer_ids = HashSet::new();
1449    for item in schedule {
1450        for (buffer, &uop_id) in item.buffers.iter().zip(item.buffer_uop_ids.iter()) {
1451            if output_uop_set.contains(&uop_id) {
1452                output_buffer_ids.insert(buffer.id().0);
1453            }
1454        }
1455    }
1456    output_buffer_ids
1457}
1458
1459fn collect_non_overridable_fixedvars(item: &ScheduleItem) -> HashMap<String, i64> {
1460    // Schedule-loop bindings (eagerly unrolled outer ranges) must not be
1461    // overridden by user `var_vals` — they're loop counters, not symbolic
1462    // input variables. `loop_var_names` is populated at instantiation time
1463    // from the keys of `KernelInvocation.fixedvars`, structurally separating
1464    // loop counters from runtime variable binds.
1465    let mut locked = HashMap::with_capacity(item.loop_var_names.len());
1466    for name in &item.loop_var_names {
1467        if let Some(v) = item.fixedvars.get(name) {
1468            locked.insert(name.clone(), *v);
1469        }
1470    }
1471    locked
1472}
1473
1474/// Render/compile entrypoint backed by PROGRAM pipeline stages.
1475fn compile_with_program_pipeline_components(
1476    kernel_ast: Arc<UOp>,
1477    renderer: &dyn svod_device::device::Renderer,
1478    compiler: &dyn svod_device::device::Compiler,
1479    kernel_name: Option<&str>,
1480) -> Result<(svod_device::device::ProgramSpec, svod_device::device::CompiledSpec)> {
1481    let mut program = match kernel_ast.op() {
1482        Op::Program { .. } => kernel_ast,
1483        other => {
1484            return IrConstructionSnafu {
1485                details: format!("compile_with_program_pipeline_components expects PROGRAM input, got {other:?}"),
1486            }
1487            .fail();
1488        }
1489    };
1490
1491    program = svod_codegen::program_pipeline::get_program(
1492        &program,
1493        renderer,
1494        compiler,
1495        kernel_name,
1496        svod_codegen::program_pipeline::ProgramTarget::Source,
1497    )
1498    .context(RenderKernelSnafu)?;
1499
1500    let rendered_entry = svod_device::device::ProgramSpec::from_uop(&program).map(|spec| spec.name).map_err(|e| {
1501        crate::error::Error::IrConstruction { details: format!("PROGRAM pipeline produced invalid SOURCE stage: {e}") }
1502    })?;
1503
1504    let (program, compiled) =
1505        svod_codegen::program_pipeline::do_compile(&program, compiler).context(CompileKernelSnafu)?;
1506
1507    let spec =
1508        svod_device::device::ProgramSpec::from_uop(&program).map_err(|e| crate::error::Error::IrConstruction {
1509            details: format!(
1510                "PROGRAM pipeline produced invalid ProgramSpec after compile (entry='{}'): {e}",
1511                rendered_entry
1512            ),
1513        })?;
1514    Ok((spec, compiled))
1515}
1516
1517/// Resolve the device string for cache keying (includes compiler cache key).
1518pub(crate) fn resolve_codegen(param_buffers: &[(u64, Arc<UOp>)], config: &PrepareConfig) -> Result<&'static str> {
1519    let alloc_registry = svod_device::registry::registry();
1520    let spec = param_buffers
1521        .iter()
1522        .find_map(|(id, _)| {
1523            let spec = crate::tensor_registry::get_buffer(*id)?.allocator().device_spec();
1524            (!spec.is_disk()).then_some(spec)
1525        })
1526        .or_else(|| {
1527            param_buffers.iter().find_map(|(_, u)| {
1528                let Op::Buffer { device, .. } = u.op() else {
1529                    return None;
1530                };
1531                let Op::Device(spec) = device.op() else {
1532                    return None;
1533                };
1534                (!spec.is_disk()).then_some(spec.clone())
1535            })
1536        })
1537        .unwrap_or(DeviceSpec::Cpu);
1538    let device = config.resolve_device(&spec, alloc_registry)?;
1539    Ok(device.compiler.cache_key())
1540}
1541
1542/// Get the optimizer renderer for a device.
1543fn get_optimizer_renderer(device: &Device) -> svod_schedule::OptimizerRenderer {
1544    match device.device {
1545        DeviceSpec::Cpu => {
1546            if std::env::var("SVOD_AMX").as_deref() == Ok("1") {
1547                svod_schedule::OptimizerRenderer::apple_amx()
1548            } else {
1549                svod_schedule::OptimizerRenderer::cpu()
1550            }
1551        }
1552        DeviceSpec::Cuda { .. } => svod_schedule::OptimizerRenderer::cuda(),
1553        DeviceSpec::Metal { .. } => svod_schedule::OptimizerRenderer::metal(),
1554        _ => svod_schedule::OptimizerRenderer::cpu(),
1555    }
1556}
1557
1558/// Optimize a kernel AST using beam search auto-tuning.
1559///
1560/// Beam search explores multiple optimization paths and selects the fastest
1561/// by compiling and timing each candidate. Slower than heuristics but can
1562/// find better optimizations. Beam and heuristic are mutually exclusive.
1563/// Count the top-K most frequent Op variants in a flat uop list. Used by the
1564/// `BEAM_LOG_SURPASS_MAX` diagnostic to identify which Op type is
1565/// bloating the linearized count for a dropped BEAM candidate.
1566pub(crate) fn count_top_ops(ops: &[Arc<UOp>], top_k: usize) -> Vec<(String, usize)> {
1567    let mut counts: HashMap<String, usize> = HashMap::new();
1568    for u in ops {
1569        *counts.entry(u.op().as_ref().to_string()).or_insert(0) += 1;
1570    }
1571    let mut v: Vec<(String, usize)> = counts.into_iter().collect();
1572    v.sort_by_key(|(_, n)| std::cmp::Reverse(*n));
1573    v.truncate(top_k);
1574    v
1575}
1576
/// Render `(op_name, count)` pairs as `name=count` entries joined by `", "`.
///
/// For example, `[("Add", 3), ("Mul", 1)]` renders as `"Add=3, Mul=1"`. An
/// empty slice renders as the empty string.
pub(crate) fn fmt_op_counts(counts: &[(String, usize)]) -> String {
    let mut rendered = String::new();
    for (position, (op_name, tally)) in counts.iter().enumerate() {
        if position != 0 {
            rendered.push_str(", ");
        }
        rendered.push_str(op_name);
        rendered.push('=');
        rendered.push_str(&tally.to_string());
    }
    rendered
}
1580
1581fn beam_search_optimize(
1582    ast: Arc<UOp>,
1583    renderer: &svod_schedule::OptimizerRenderer,
1584    device: &Device,
1585    buffers: &[Buffer],
1586    optimizer_config: &svod_schedule::OptimizerConfig,
1587) -> Result<Arc<UOp>> {
1588    let beam_config = &optimizer_config.beam;
1589    // Prepare scheduler (applies symbolic simplification and loop→global).
1590    // BEAM and heuristic are mutually exclusive.
1591    let scheduler = prepare_scheduler(ast, renderer);
1592
1593    // Ensure all buffers are allocated for timing
1594    for buf in buffers {
1595        buf.ensure_allocated().context(DeviceSnafu)?;
1596    }
1597
1598    // Clone buffers for the closure (Buffer is Clone + Send + Sync)
1599    let buffers: Vec<Buffer> = buffers.to_vec();
1600    let bench_config = svod_runtime::BenchmarkConfig::default();
1601
1602    // Clone device components for the closure
1603    let dev_renderer = device.renderer.clone();
1604    let dev_compiler = device.compiler.clone();
1605    let dev_runtime = device.runtime.clone();
1606    let dev_device = device.device.clone();
1607    let max_uops = beam_config.max_uops;
1608
1609    // Force rayon's global thread pool to materialise before the BEAM loop so
1610    // the first candidate's bench doesn't pay pool-init cost. Subsequent calls
1611    // are O(1) thread-pool dispatch — but the lazy init can dominate the first
1612    // 1-2 measurements at the small kernel sizes BEAM-time uses, biasing
1613    // ranking against fast candidates that happen to run first.
1614    svod_runtime::warmup_thread_pool();
1615
1616    // Per-candidate **compile-only** timeout (default 10s). Rust can't
1617    // safely deliver SIGALRM to a worker, so we use a two-stage detached
1618    // worker thread: the worker signals `CompileDone` once
1619    // codegen/compile/runtime-link finish, after which execution runs
1620    // unbounded (only `early_stop` aborts a slow run). Side effect: a hung
1621    // clang invocation orphans one OS thread per timeout, reaped at
1622    // process exit.
1623    let compile_timeout =
1624        Duration::from_secs(std::env::var("BEAM_TIMEOUT_SEC").ok().and_then(|s| s.parse().ok()).unwrap_or(10));
1625
1626    // When `BEAM_LOG_SURPASS_MAX` is set, every dropped candidate prints
1627    // one line with the failure reason, applied-opt chain, and (for "too
1628    // many uops") the top Op-variant counts in the linearized program.
1629    let log_surpass = std::env::var("BEAM_LOG_SURPASS_MAX").is_ok();
1630
1631    // Per-BEAM-call cache for `apply_post_optimization_with_renderer`. Many
1632    // candidates expand from the same parent state and share the underlying
1633    // raw AST; without this cache the rewrite engine re-runs the full
1634    // post-optimisation graph rewrite for each, costing ~13 ms × N candidates.
1635    // Scoped to one `beam_search_optimize` invocation so stale entries cannot
1636    // leak between calls. Read-mostly access pattern fits papaya's lock-free
1637    // reads + bounded-lock writes.
1638    let post_opt_cache: Arc<papaya::HashMap<u64, Arc<UOp>>> = Arc::new(papaya::HashMap::new());
1639
1640    // Compile-and-time closure: compilation is NOT timed, only execution. Wrapped
1641    // in catch_unwind because beam search explores speculative candidates that
1642    // may trigger rewrite engine limits or other panics. Wrapped in a worker
1643    // thread + recv_timeout so a hung clang invocation cannot block BEAM.
1644    //
1645    // Returns `CandidateMetrics` (timing + structural IR hash + compute-op count)
1646    // so the beam loop can apply `seen_libs` dedup and the
1647    // `least_compute_ops*1000` compute-bloat filter. The optional `early_stop`
1648    // argument is propagated into `BenchmarkConfig` so `benchmark_kernel` can
1649    // abort the run loop the moment any single run exceeds the threshold.
1650    let compile_and_time = |s: &Scheduler, early_stop: Option<Duration>| -> Option<svod_schedule::CandidateMetrics> {
1651        use std::panic::{AssertUnwindSafe, catch_unwind};
1652        use std::sync::mpsc;
1653
1654        // Per-call clones move into the worker. All Arc/Clone — no deep copy.
1655        let s_owned = s.clone();
1656        let renderer_c = renderer.clone();
1657        let dev_renderer_c = dev_renderer.clone();
1658        let dev_compiler_c = dev_compiler.clone();
1659        let dev_runtime_c = dev_runtime.clone();
1660        let dev_device_c = dev_device.clone();
1661        let buffers_c = buffers.clone();
1662        let bench_config_c = bench_config.clone();
1663        let max_uops_c = max_uops;
1664        let post_opt_cache_c = Arc::clone(&post_opt_cache);
1665        let log_surpass_c = log_surpass;
1666        // Snapshot the applied-opts chain so the diagnostic line in the worker
1667        // thread can identify which BEAM branch triggered the drop without
1668        // having to send the full Scheduler back across the channel.
1669        let opts_snapshot: Vec<svod_schedule::optimizer::Opt> = s_owned.applied_opts.clone();
1670
1671        // Two-stage signal: CompileDone fires when codegen/compile/runtime-link
1672        // finish; Final carries the benchmark result (or None on failure/panic).
1673        // Capacity 2 so the worker never blocks on send. Detached: drop the
1674        // JoinHandle so the parent can abandon the worker on compile timeout.
1675        enum WorkerMsg {
1676            CompileDone,
1677            Final(Option<svod_schedule::CandidateMetrics>),
1678        }
1679        let (tx, rx) = mpsc::sync_channel::<WorkerMsg>(2);
1680        let tx_compile = tx.clone();
1681        let _ = std::thread::spawn(move || {
1682            let result = catch_unwind(AssertUnwindSafe(|| {
1683                let raw_ast = s_owned.get_optimized_ast(None);
1684
1685                // Apply post-optimization passes for accurate timing.
1686                // Pass the renderer so pm_add_gpudims fires for has_threads/has_local backends —
1687                // otherwise the Thread axis stays as a plain RANGE and gets rendered as a
1688                // sequential `for` loop instead of a parallel core_id dispatch, making BEAM
1689                // candidates time as single-threaded and converge to wrong tile shapes.
1690                //
1691                // Cache by `raw_ast.content_hash`: BEAM expands children of the same
1692                // parent in lockstep, so siblings share `raw_ast` and would otherwise
1693                // re-run the full graph rewrite (~13 ms each).
1694                let cache_key = raw_ast.content_hash;
1695                let cache_pin = post_opt_cache_c.pin();
1696                let optimized = if let Some(cached) = cache_pin.get(&cache_key) {
1697                    cached.clone()
1698                } else {
1699                    let opt = apply_post_optimization_with_renderer(raw_ast, Some(&renderer_c));
1700                    cache_pin.insert(cache_key, opt.clone());
1701                    opt
1702                };
1703
1704                // Extract kernel name before decomposition (which loses metadata)
1705                let kernel_name =
1706                    optimized.metadata::<svod_schedule::optimizer::KernelInfo>().map(|info| info.function_name());
1707
1708                // Pre-codegen metrics: structural hash for `seen_libs`, ALU node
1709                // count for `least_compute_ops`. Computed before compile so even
1710                // failed compiles still consume a slot fairly.
1711                let ir_hash = svod_schedule::hash_post_codegen_ir(&optimized);
1712                let compute_ops = svod_schedule::compute_ops_estimate(&optimized);
1713
1714                // Apply decomposition
1715                let decomposed = match dev_renderer_c.decompositor() {
1716                    Some(m) => svod_ir::decompositions::decompose_with(&optimized, &m),
1717                    None => optimized,
1718                };
1719                let mut program = svod_codegen::program_pipeline::program_from_sink(decomposed, dev_device_c.clone());
1720
1721                // Linearize *now* so we can count flat uops. Counting the
1722                // post-optimization AST `toposort()` (the earlier behavior)
1723                // under-counts because svod's high-level Reduce/Index/Cast
1724                // nodes get fanned out into many flat uops by the codegen
1725                // pipeline. Counting post-linearize gives the number to
1726                // compare against `BEAM_UOPS_MAX`.
1727                program = match svod_codegen::program_pipeline::do_linearize(&program) {
1728                    Ok(p) => p,
1729                    Err(e) => {
1730                        if log_surpass_c {
1731                            eprintln!("[BEAM drop] linearize_err: {e:?} opts={opts_snapshot:?}");
1732                        }
1733                        return None;
1734                    }
1735                };
1736                let (linear_uops_count, top_op_counts) = if let svod_ir::Op::Program { linear: Some(linear), .. } =
1737                    program.op()
1738                    && let svod_ir::Op::Linear { ops } = linear.op()
1739                {
1740                    (ops.len(), if log_surpass_c { count_top_ops(ops, 8) } else { Vec::new() })
1741                } else {
1742                    (0, Vec::new())
1743                };
1744                if linear_uops_count > max_uops_c {
1745                    if log_surpass_c {
1746                        eprintln!(
1747                            "[BEAM drop] too_many_uops: linear={linear_uops_count} max={max_uops_c} opts={opts_snapshot:?} top_ops=[{}]",
1748                            fmt_op_counts(&top_op_counts)
1749                        );
1750                    }
1751                    return None;
1752                }
1753
1754                // Render and compile through PROGRAM stages (NOT timed).
1755                let (spec, compiled) = match compile_with_program_pipeline_components(
1756                    program,
1757                    dev_renderer_c.as_ref(),
1758                    dev_compiler_c.as_ref(),
1759                    kernel_name.as_deref(),
1760                ) {
1761                    Ok(v) => v,
1762                    Err(e) => {
1763                        if log_surpass_c {
1764                            eprintln!("[BEAM drop] compile_err: {e:?} opts={opts_snapshot:?}");
1765                        }
1766                        return None;
1767                    }
1768                };
1769                let program = match (dev_runtime_c)(&compiled) {
1770                    Ok(p) => p,
1771                    Err(e) => {
1772                        if log_surpass_c {
1773                            eprintln!("[BEAM drop] runtime_err: {e:?} opts={opts_snapshot:?}");
1774                        }
1775                        return None;
1776                    }
1777                };
1778
1779                // Compile phase done — release parent from the `BEAM_TIMEOUT_SEC`
1780                // bound. Anything below is execution-only (bounded by `early_stop`).
1781                let _ = tx_compile.send(WorkerMsg::CompileDone);
1782
1783                // Extract buffer pointers inside the worker (avoids Sync issue
1784                // and keeps raw pointers thread-local).
1785                let buffer_ptrs: Vec<*mut u8> = buffers_c.iter().map(|b| unsafe { b.as_raw_ptr() }).collect();
1786
1787                // Time ONLY execution. Each non-runtime variable gets the midpoint
1788                // of its declared range (`(vmin+vmax)/2`) so symbolic-bound kernels
1789                // do representative work. `core_id` stays unbound (patched per-thread
1790                // by `execute_parallel`).
1791                let mut user_var_vals: HashMap<&str, i64> = HashMap::new();
1792                for v in &spec.vars {
1793                    if v.name != "core_id" {
1794                        user_var_vals.insert(v.name.as_str(), (v.min + v.max) / 2);
1795                    }
1796                }
1797                let launch_dims = spec.launch_dims(&user_var_vals).ok()?;
1798                let vals: Vec<i64> =
1799                    spec.var_names.iter().map(|n| user_var_vals.get(n.as_str()).copied().unwrap_or(0)).collect();
1800
1801                // Bound BEAM dispatch grid: if `prod(global_size) > MAX_TEST_GLOBAL_SIZE`,
1802                // halve the largest dim (>16) until it fits, then scale the measured
1803                // time by `factor = original_size / shrunk_size` to recover the
1804                // full-grid estimate. Svod's CPU dispatch has `global_size[0]` = thread
1805                // count, typically ≤ 16, so this is a no-op for CPU and only engages
1806                // for GPU-style large grids.
1807                const MAX_TEST_GLOBAL_SIZE: usize = 65536;
1808                let mut test_global_size = launch_dims.global_size;
1809                let original_size: usize = test_global_size.iter().product();
1810                while test_global_size.iter().product::<usize>() > MAX_TEST_GLOBAL_SIZE {
1811                    let mut halved = false;
1812                    for j in (0..test_global_size.len()).rev() {
1813                        if test_global_size[j] > 16 {
1814                            test_global_size[j] /= 2;
1815                            halved = true;
1816                            break;
1817                        }
1818                    }
1819                    if !halved {
1820                        break;
1821                    }
1822                }
1823                let shrunk_size: usize = test_global_size.iter().product();
1824                let factor: f64 = if shrunk_size > 0 { original_size as f64 / shrunk_size as f64 } else { 1.0 };
1825
1826                let mut bench_config = bench_config_c.clone();
1827                // Translate the unshrunk early-stop threshold into the shrunk
1828                // timing domain so per-run abort fires at the same effective point.
1829                bench_config.early_stop = early_stop.map(|t| {
1830                    let nanos = t.as_nanos() as f64 / factor;
1831                    Duration::from_nanos(nanos.min(u64::MAX as f64) as u64)
1832                });
1833                // CPU/AMX have no hardware cache-invalidate primitive — run warm-cache.
1834                // GPU backends keep the invalidate.
1835                bench_config.clear_l2 = renderer_c.device.has_hardware_cache_invalidate();
1836                let result = unsafe {
1837                    svod_runtime::benchmark_kernel(
1838                        program.as_ref(),
1839                        &buffer_ptrs,
1840                        &vals,
1841                        Some(test_global_size),
1842                        launch_dims.local_size,
1843                        &bench_config,
1844                    )
1845                    .ok()?
1846                };
1847
1848                // Scale measured time back to the full-grid estimate.
1849                let scaled_nanos = (result.min.as_nanos() as f64 * factor).min(u64::MAX as f64);
1850                let timing = Duration::from_nanos(scaled_nanos as u64);
1851                Some(svod_schedule::CandidateMetrics { timing, ir_hash, compute_ops })
1852            }));
1853            let final_result = match result {
1854                Ok(opt) => opt,
1855                Err(_) => {
1856                    if log_surpass_c {
1857                        eprintln!("[BEAM drop] panic_in_worker opts={opts_snapshot:?}");
1858                    }
1859                    None
1860                }
1861            };
1862            // Receiver may have already given up on compile timeout — ignore send errors.
1863            let _ = tx.send(WorkerMsg::Final(final_result));
1864        });
1865
1866        // Stage 1: wait for compile to finish, bounded by `BEAM_TIMEOUT_SEC`.
1867        // Three outcomes:
1868        // - `CompileDone`: compile succeeded, fall through to unbounded execution wait.
1869        // - `Final(_)`: compile failed/aborted before reaching the signal (early return
1870        //   on linearize_err / too_many_uops / compile_err / runtime_err, or panic).
1871        // - timeout: clang hung past the budget; abandon the worker (orphaned thread).
1872        match rx.recv_timeout(compile_timeout) {
1873            Ok(WorkerMsg::CompileDone) => {
1874                // Stage 2: execution. Unbounded — `bench_config.early_stop`
1875                // aborts a slow run from inside `benchmark_kernel`.
1876                match rx.recv() {
1877                    Ok(WorkerMsg::Final(metrics)) => metrics,
1878                    _ => None,
1879                }
1880            }
1881            Ok(WorkerMsg::Final(metrics)) => metrics,
1882            Err(_) => {
1883                if log_surpass {
1884                    eprintln!("[BEAM drop] compile_timeout opts={:?}", s.applied_opts);
1885                }
1886                None
1887            }
1888        }
1889    };
1890
1891    // Suppress panic output during beam search. Speculative candidates may panic
1892    // at compile or runtime — this is expected. catch_unwind catches panics
1893    // but the default hook prints them first.
1894    let prev_hook = std::panic::take_hook();
1895    std::panic::set_hook(Box::new(|_| {}));
1896    let result = beam_search_cached(scheduler, beam_config, compile_and_time);
1897    std::panic::set_hook(prev_hook);
1898    let result = result.context(OptimizeSnafu)?;
1899
1900    // Debug: log beam search results
1901    tracing::debug!(
1902        opts = ?result.scheduler.applied_opts,
1903        timing = ?result.timing,
1904        iterations = result.iterations,
1905        "beam_search_optimize: completed"
1906    );
1907
1908    // Apply post-optimization to final result with renderer so pm_add_gpudims runs
1909    // (Thread → core_id, Global → SPECIAL).
1910    let raw_ast = result.scheduler.get_optimized_ast(None);
1911    Ok(apply_post_optimization_with_renderer(raw_ast, Some(renderer)))
1912}
1913
1914#[cfg(test)]
1915#[path = "test/unit/realize_internal.rs"]
1916mod tests;