Skip to main content

vyre_driver/
binding.rs

1//! Backend-neutral binding-plan construction for VYRE programs.
2
3use smallvec::SmallVec;
4use std::sync::Arc;
5use vyre_foundation::ir::{BufferAccess, BufferDecl, MemoryKind, Program};
6
7use crate::BackendError;
8
/// How a single VYRE buffer takes part in a launch, from the host's point of view.
///
/// Derived from a buffer declaration's memory kind, access mode, and output
/// markers when a [`BindingPlan`] is built.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BindingRole {
    /// Uploaded from the host into a device buffer that kernels only read.
    Input,
    /// Written by the device and read back to the host after dispatch.
    Output,
    /// Uploaded from the host into a read-write device buffer, then read back.
    InputOutput,
    /// Read-only input bound through a uniform-style slot.
    Uniform,
    /// Workgroup-local scratch memory that lives only in generated target code.
    Shared,
    /// Long-lived memory handle owned by the runtime ingest APIs.
    Persistent,
}
25
26/// One validated binding descriptor.
27#[derive(Debug, Clone, PartialEq, Eq)]
28pub struct Binding {
29    /// VYRE buffer name.
30    pub name: Arc<str>,
31    /// VYRE binding number.
32    pub binding: u32,
33    /// Original buffer index in `Program::buffers`.
34    pub buffer_index: usize,
35    /// Host/device role for launch.
36    pub role: BindingRole,
37    /// Element size in bytes when statically known.
38    pub element_size: usize,
39    /// Preferred byte alignment for backend allocation/upload planning.
40    ///
41    /// This optimization contract is derived from `BufferDecl::hints` and the
42    /// scalar element size. It does not change program semantics; concrete
43    /// drivers use it to choose buffer allocation and launch paths without
44    /// rewalking the IR.
45    pub preferred_alignment: usize,
46    /// Declared or input-derived element count. Zero means runtime-sized.
47    pub element_count: u32,
48    /// Static byte count when known.
49    pub static_byte_len: Option<usize>,
50    /// Index in the caller's input slice, if this binding consumes input.
51    pub input_index: Option<usize>,
52    /// Index in the backend output vector, if this binding is observed output.
53    pub output_index: Option<usize>,
54}
55
56/// Deterministic ABI plan for a VYRE program.
57#[derive(Debug, Clone, PartialEq, Eq)]
58pub struct BindingPlan {
59    /// Ordered binding descriptors, sorted by VYRE binding number.
60    pub bindings: Vec<Binding>,
61    /// Original program buffer indices that consume host inputs.
62    pub input_indices: Vec<usize>,
63    /// Original program buffer indices that produce host outputs.
64    pub output_indices: Vec<usize>,
65    /// Original program buffer indices that are workgroup-local.
66    pub shared_indices: Vec<usize>,
67}
68
/// Uniform view over the different ways callers describe their inputs.
///
/// Planning and validation only need each input's byte length. This wrapper
/// lets owned host buffers, borrowed host slices, and bare lengths of
/// backend-resident buffers share one code path.
#[derive(Clone, Copy)]
enum InputLengths<'a> {
    /// No inputs were supplied (plan-only construction).
    None,
    /// Owned host input buffers.
    Owned(&'a [Vec<u8>]),
    /// Borrowed host input slices.
    Borrowed(&'a [&'a [u8]]),
    /// Byte lengths of inputs that already live on the backend.
    Lengths(&'a [usize]),
}

impl InputLengths<'_> {
    /// Number of inputs the caller supplied.
    fn len(self) -> usize {
        match self {
            Self::Owned(buffers) => buffers.len(),
            Self::Borrowed(slices) => slices.len(),
            Self::Lengths(byte_lens) => byte_lens.len(),
            Self::None => 0,
        }
    }

    /// Byte length of input `index`.
    ///
    /// Returns `None` when no inputs were supplied or `index` is out of range.
    fn get(self, index: usize) -> Option<usize> {
        match self {
            Self::Owned(buffers) => buffers.get(index).map(|buffer| buffer.len()),
            Self::Borrowed(slices) => slices.get(index).map(|slice| slice.len()),
            Self::Lengths(byte_lens) => byte_lens.get(index).copied(),
            Self::None => None,
        }
    }
}
96
97impl BindingPlan {
98    /// Build a binding plan from a VYRE program without host input checks.
99    ///
100    /// # Errors
101    ///
102    /// Returns when memory/access combinations or static byte sizing cannot be
103    /// represented by a concrete backend ABI.
104    pub fn build(program: &Program) -> Result<Self, BackendError> {
105        Self::build_inner(program, InputLengths::None, false)
106    }
107
108    /// Build and validate a binding plan from a VYRE program.
109    ///
110    /// # Errors
111    ///
112    /// Returns when input count, input byte lengths, buffer alignment, or
113    /// memory/access combinations do not match the backend ABI contract.
114    pub fn from_program(program: &Program, inputs: &[Vec<u8>]) -> Result<Self, BackendError> {
115        Self::build_inner(program, InputLengths::Owned(inputs), true)
116    }
117
118    /// Build and validate a binding plan from borrowed input buffers.
119    ///
120    /// # Errors
121    ///
122    /// Returns when input count, input byte lengths, buffer alignment, or
123    /// memory/access combinations do not match the backend ABI contract.
124    pub fn from_borrowed_inputs(program: &Program, inputs: &[&[u8]]) -> Result<Self, BackendError> {
125        Self::build_inner(program, InputLengths::Borrowed(inputs), true)
126    }
127
128    /// Build and validate a binding plan from backend-resident input byte lengths.
129    ///
130    /// # Errors
131    ///
132    /// Returns when resident input counts are wrong or byte lengths are
133    /// smaller than the program ABI requires.
134    pub fn from_input_lengths(
135        program: &Program,
136        input_lengths: &[usize],
137    ) -> Result<Self, BackendError> {
138        Self::build_inner(program, InputLengths::Lengths(input_lengths), true)
139    }
140
141    /// Verifies backend-resident input byte lengths satisfy this binding plan.
142    ///
143    /// # Errors
144    ///
145    /// Returns when the caller supplies the wrong number of resident inputs or
146    /// a resident input is smaller than the buffer declaration captured in
147    /// this plan.
148    pub fn validate_input_byte_lengths(&self, input_lengths: &[usize]) -> Result<(), BackendError> {
149        self.validate_input_lengths(InputLengths::Lengths(input_lengths))
150    }
151
152    /// Verifies dynamic input slices match the expected plan.
153    ///
154    /// # Errors
155    ///
156    /// Returns when the caller supplies the wrong number of inputs or an input
157    /// length violates the buffer declaration.
158    pub fn validate_inputs(&self, inputs: &[Vec<u8>]) -> Result<(), BackendError> {
159        self.validate_input_lengths(InputLengths::Owned(inputs))
160    }
161
162    /// Verifies borrowed dynamic input slices match the expected plan.
163    ///
164    /// # Errors
165    ///
166    /// Returns when the caller supplies the wrong number of inputs or an input
167    /// length violates the buffer declaration.
168    pub fn validate_borrowed_inputs(&self, inputs: &[&[u8]]) -> Result<(), BackendError> {
169        self.validate_input_lengths(InputLengths::Borrowed(inputs))
170    }
171
172    fn validate_input_lengths(&self, input_lens: InputLengths<'_>) -> Result<(), BackendError> {
173        if input_lens.len() != self.input_indices.len() {
174            return Err(BackendError::InvalidProgram {
175                fix: format!(
176                    "Fix: dispatch expected {} input buffer(s) from Program declarations but received {}.",
177                    self.input_indices.len(),
178                    input_lens.len()
179                ),
180            });
181        }
182
183        for binding in &self.bindings {
184            if let Some(input_index) = binding.input_index {
185                let byte_len = input_lens.get(input_index).ok_or_else(|| {
186                    BackendError::InvalidProgram {
187                        fix: format!(
188                            "Fix: dispatch input index {input_index} for `{}` was missing after input-count validation.",
189                            binding.name
190                        ),
191                    }
192                })?;
193                validate_input_len(
194                    binding,
195                    byte_len,
196                    !matches!(input_lens, InputLengths::Lengths(_)),
197                )?;
198            }
199        }
200        Ok(())
201    }
202
203    fn build_inner(
204        program: &Program,
205        input_lens: InputLengths<'_>,
206        validate_inputs_now: bool,
207    ) -> Result<Self, BackendError> {
208        let mut ordered = SmallVec::<[(usize, &BufferDecl); 16]>::new();
209        vyre_foundation::allocation::try_reserve_smallvec_to_capacity(
210            &mut ordered,
211            program.buffers().len(),
212        )
213        .map_err(|error| {
214            BackendError::InvalidProgram {
215                fix: format!(
216                    "Fix: binding-plan construction could not reserve {} ordered buffer slot(s): {error}. Split the program buffers or construct a smaller pipeline.",
217                    program.buffers().len()
218                ),
219            }
220        })?;
221        let buffer_count = program.buffers().len();
222        ordered.extend(program.buffers().iter().enumerate());
223        ordered.sort_by_key(|(_, buffer)| buffer.binding());
224
225        let mut bindings = Vec::new();
226        crate::allocation::try_reserve_vec_to_capacity(&mut bindings, ordered.len()).map_err(
227            |error| BackendError::InvalidProgram {
228                fix: format!(
229                    "Fix: binding-plan construction could not reserve {} binding descriptor(s): {error}. Split the program buffers or construct a smaller pipeline.",
230                    ordered.len()
231                ),
232            },
233        )?;
234        let (input_slot_count, output_slot_count, shared_slot_count) =
235            binding_role_counts(&ordered)?;
236        let mut logical_input_slots = Vec::new();
237        crate::allocation::try_reserve_vec_to_capacity(&mut logical_input_slots, buffer_count)
238            .map_err(|error| BackendError::InvalidProgram {
239                fix: format!(
240                    "Fix: binding-plan construction could not reserve {buffer_count} logical input slot(s): {error}. Split the program buffers or construct a smaller pipeline.",
241                ),
242            })?;
243        logical_input_slots.resize(buffer_count, None);
244        let mut logical_output_slots = Vec::new();
245        crate::allocation::try_reserve_vec_to_capacity(&mut logical_output_slots, buffer_count)
246            .map_err(|error| BackendError::InvalidProgram {
247                fix: format!(
248                    "Fix: binding-plan construction could not reserve {buffer_count} logical output slot(s): {error}. Split the program buffers or construct a smaller pipeline.",
249                ),
250            })?;
251        logical_output_slots.resize(buffer_count, None);
252        let mut input_indices = SmallVec::<[usize; 8]>::new();
253        let mut output_indices = SmallVec::<[usize; 8]>::new();
254        let mut shared_indices = SmallVec::<[usize; 4]>::new();
255        vyre_foundation::allocation::try_reserve_smallvec_to_capacity(
256            &mut input_indices,
257            input_slot_count,
258        )
259        .map_err(|error| {
260            BackendError::InvalidProgram {
261                fix: format!(
262                    "Fix: binding-plan construction could not reserve {input_slot_count} input index slot(s): {error}. Split the program buffers or construct a smaller pipeline."
263                ),
264            }
265        })?;
266        vyre_foundation::allocation::try_reserve_smallvec_to_capacity(
267            &mut output_indices,
268            output_slot_count,
269        )
270        .map_err(|error| {
271            BackendError::InvalidProgram {
272                fix: format!(
273                    "Fix: binding-plan construction could not reserve {output_slot_count} output index slot(s): {error}. Split the program buffers or construct a smaller pipeline."
274                ),
275            }
276        })?;
277        vyre_foundation::allocation::try_reserve_smallvec_to_capacity(
278            &mut shared_indices,
279            shared_slot_count,
280        )
281        .map_err(|error| {
282            BackendError::InvalidProgram {
283                fix: format!(
284                    "Fix: binding-plan construction could not reserve {shared_slot_count} shared index slot(s): {error}. Split the program buffers or construct a smaller pipeline."
285                ),
286            }
287        })?;
288
289        for (buffer_index, buffer) in program.buffers().iter().enumerate() {
290            let role = role_for_buffer(buffer)?;
291            if matches!(
292                role,
293                BindingRole::Input | BindingRole::InputOutput | BindingRole::Uniform
294            ) {
295                let index = input_indices.len();
296                input_indices.push(buffer_index);
297                logical_input_slots[buffer_index] = Some(index);
298            }
299            if matches!(role, BindingRole::Output | BindingRole::InputOutput)
300                || buffer.pipeline_live_out
301            {
302                let index = output_indices.len();
303                output_indices.push(buffer_index);
304                logical_output_slots[buffer_index] = Some(index);
305            }
306            if role == BindingRole::Shared {
307                shared_indices.push(buffer_index);
308            }
309        }
310
311        for (buffer_index, buffer) in ordered {
312            let role = role_for_buffer(buffer)?;
313            let consumes_input = matches!(
314                role,
315                BindingRole::Input | BindingRole::InputOutput | BindingRole::Uniform
316            );
317            let produces_output = matches!(role, BindingRole::Output | BindingRole::InputOutput);
318            buffer
319                .element()
320                .validate_layout()
321                .map_err(|error| BackendError::InvalidProgram {
322                    fix: format!(
323                        "Fix: binding `{}` has malformed data-type layout metadata: {error}",
324                        buffer.name()
325                    ),
326                })?;
327            let element_size = buffer.element().min_bytes();
328            let static_byte_len = static_byte_len(buffer)?;
329            let preferred_alignment = preferred_alignment(buffer, element_size)?;
330
331            let input_index = if consumes_input {
332                Some(logical_input_slots
333                    .get(buffer_index)
334                    .copied()
335                    .flatten()
336                    .ok_or_else(|| BackendError::InvalidProgram {
337                        fix: format!(
338                            "Fix: binding `{}` consumes input but no logical input slot was assigned. Rebuild BindingPlan from Program::buffers order before launch.",
339                            buffer.name()
340                        ),
341                    })?)
342            } else {
343                None
344            };
345            let output_index = if produces_output || buffer.pipeline_live_out {
346                Some(logical_output_slots
347                    .get(buffer_index)
348                    .copied()
349                    .flatten()
350                    .ok_or_else(|| BackendError::InvalidProgram {
351                        fix: format!(
352                            "Fix: binding `{}` produces output but no logical output slot was assigned. Rebuild BindingPlan from Program::buffers order before readback.",
353                            buffer.name()
354                        ),
355                    })?)
356            } else {
357                None
358            };
359            let element_count = if buffer.count() == 0 {
360                input_index
361                    .and_then(|index| input_lens.get(index))
362                    .and_then(|byte_len| dynamic_element_count_from_bytes(buffer, byte_len))
363                    .unwrap_or(0)
364            } else {
365                buffer.count()
366            };
367
368            bindings.push(Binding {
369                name: Arc::clone(&buffer.name),
370                binding: buffer.binding(),
371                buffer_index,
372                role,
373                element_size,
374                preferred_alignment,
375                element_count,
376                static_byte_len,
377                input_index,
378                output_index,
379            });
380        }
381
382        let plan = Self {
383            bindings,
384            input_indices: input_indices.into_vec(),
385            output_indices: output_indices.into_vec(),
386            shared_indices: shared_indices.into_vec(),
387        };
388
389        if validate_inputs_now {
390            plan.validate_input_lengths(input_lens)?;
391        }
392
393        Ok(plan)
394    }
395}
396
397fn binding_role_counts(
398    ordered: &SmallVec<[(usize, &BufferDecl); 16]>,
399) -> Result<(usize, usize, usize), BackendError> {
400    ordered
401        .iter()
402        .try_fold((0usize, 0usize, 0usize), |(inputs, outputs, shared), (_, buffer)| {
403            let role = role_for_buffer(buffer)?;
404            let next_inputs = inputs
405                .checked_add(usize::from(matches!(
406                    role,
407                    BindingRole::Input | BindingRole::InputOutput | BindingRole::Uniform
408                )))
409                .ok_or_else(|| BackendError::InvalidProgram {
410                    fix: "Fix: binding-plan input role count overflowed usize. Split the program buffers before binding-plan construction.".to_string(),
411                })?;
412            let next_outputs = outputs
413                .checked_add(usize::from(
414                    matches!(role, BindingRole::Output | BindingRole::InputOutput)
415                        || buffer.pipeline_live_out,
416                ))
417                .ok_or_else(|| BackendError::InvalidProgram {
418                    fix: "Fix: binding-plan output role count overflowed usize. Split the program buffers before binding-plan construction.".to_string(),
419                })?;
420            let next_shared = shared
421                .checked_add(usize::from(role == BindingRole::Shared))
422                .ok_or_else(|| BackendError::InvalidProgram {
423                    fix: "Fix: binding-plan shared role count overflowed usize. Split the program buffers before binding-plan construction.".to_string(),
424                })?;
425            Ok((next_inputs, next_outputs, next_shared))
426        })
427}
428
429fn role_for_buffer(buffer: &BufferDecl) -> Result<BindingRole, BackendError> {
430    if buffer.kind() == MemoryKind::Shared || buffer.access() == BufferAccess::Workgroup {
431        return Ok(BindingRole::Shared);
432    }
433    if buffer.kind() == MemoryKind::Persistent {
434        return Ok(BindingRole::Persistent);
435    }
436    if buffer.is_output || buffer.pipeline_live_out {
437        return Ok(BindingRole::Output);
438    }
439    match buffer.access() {
440        BufferAccess::ReadOnly => Ok(BindingRole::Input),
441        BufferAccess::ReadWrite => Ok(BindingRole::InputOutput),
442        BufferAccess::WriteOnly => Ok(BindingRole::Output),
443        BufferAccess::Uniform => Ok(BindingRole::Uniform),
444        BufferAccess::Workgroup => Ok(BindingRole::Shared),
445        _ => Err(BackendError::InvalidProgram {
446            fix: format!(
447                "Fix: binding `{}` uses an unknown BufferAccess variant; update vyre-driver binding role mapping.",
448                buffer.name()
449            ),
450        }),
451    }
452}
453
454fn preferred_alignment(buffer: &BufferDecl, element_size: usize) -> Result<usize, BackendError> {
455    let hinted = usize::try_from(buffer.hints().preferred_alignment).map_err(|_| {
456        BackendError::InvalidProgram {
457            fix: format!(
458                "Fix: binding `{}` preferred_alignment does not fit usize on this target.",
459                buffer.name()
460            ),
461        }
462    })?;
463    if hinted != 0 && !hinted.is_power_of_two() {
464        return Err(BackendError::InvalidProgram {
465            fix: format!(
466                "Fix: binding `{}` preferred_alignment={} is not a power of two. Use 0 or a power-of-two byte alignment.",
467                buffer.name(),
468                hinted
469            ),
470        });
471    }
472    Ok(hinted.max(element_size.max(1)))
473}
474
475fn static_byte_len(buffer: &BufferDecl) -> Result<Option<usize>, BackendError> {
476    let bytes = buffer
477        .static_byte_len()
478        .map_err(|error| BackendError::InvalidProgram {
479            fix: format!(
480                "Fix: binding `{}` static byte length could not be computed: {error}",
481                buffer.name(),
482            ),
483        })?;
484    if buffer.count() == 0 {
485        return Ok(None);
486    }
487    bytes
488        .map(Some)
489        .ok_or_else(|| BackendError::InvalidProgram {
490            fix: format!(
491                "Fix: binding `{}` declares {} elements of a runtime-sized data type; use a byte-addressed buffer contract or a fixed-width element type.",
492                buffer.name(),
493                buffer.count()
494            ),
495        })
496}
497
498fn dynamic_element_count_from_bytes(buffer: &BufferDecl, byte_len: usize) -> Option<u32> {
499    if let Some(bits) = buffer.element().bit_width() {
500        let total_bits = byte_len.checked_mul(8)?;
501        return u32::try_from(total_bits / bits).ok();
502    }
503    buffer
504        .element()
505        .size_bytes()
506        .and_then(|element_size| byte_len.checked_div(element_size))
507        .and_then(|count| u32::try_from(count).ok())
508}
509
510fn validate_input_len(
511    binding: &Binding,
512    input_len: usize,
513    strict_static_input_len: bool,
514) -> Result<(), BackendError> {
515    if binding.element_size > 1 && input_len % binding.element_size != 0 {
516        return Err(BackendError::InvalidProgram {
517            fix: format!(
518                "Fix: input `{}` has {} bytes, which is not aligned to its {}-byte element size.",
519                binding.name, input_len, binding.element_size
520            ),
521        });
522    }
523    if let Some(expected) = binding.static_byte_len {
524        if strict_static_input_len && input_len != expected {
525            return Err(BackendError::InvalidProgram {
526                fix: format!(
527                    "Fix: input `{}` expected {expected} bytes from its static buffer declaration but received {} bytes.",
528                    binding.name,
529                    input_len
530                ),
531            });
532        }
533        if !strict_static_input_len && input_len < expected {
534            return Err(BackendError::InvalidProgram {
535                fix: format!(
536                    "Fix: resident input `{}` expected at least {expected} bytes from its static buffer declaration but received {} bytes.",
537                    binding.name, input_len
538                ),
539            });
540        }
541    }
542    Ok(())
543}
544
#[cfg(test)]
mod exact_length_tests {
    use super::*;
    use vyre_foundation::ir::DataType;

    /// Program with one read-only `u32` buffer at binding 0 holding `count`
    /// elements (`0` makes it runtime-sized).
    fn static_u32_input_program(count: u32) -> Program {
        let input = BufferDecl::read("input", 0, DataType::U32).with_count(count);
        Program::wrapped(vec![input], [1, 1, 1], Vec::new())
    }

    #[test]
    fn static_host_inputs_are_exact_while_resident_inputs_may_be_larger() {
        let program = static_u32_input_program(2);
        let short = vec![0u8; 4];
        let exact = vec![0u8; 8];
        let oversized = vec![0u8; 12];

        // Owned host buffers must match the 8-byte static size exactly.
        for (bytes, why) in [
            (&short, "owned static input length must be exact"),
            (&oversized, "owned static input length must remain exact"),
        ] {
            let err = BindingPlan::from_program(&program, &[bytes.clone()]).expect_err(why);
            assert!(err.to_string().contains("expected 8 bytes"));
        }
        assert!(BindingPlan::from_program(&program, &[exact.clone()]).is_ok());

        // Borrowed host slices follow the same exact-length rule.
        for (bytes, why) in [
            (short.as_slice(), "borrowed static input length must be exact"),
            (oversized.as_slice(), "borrowed static input length must remain exact"),
        ] {
            let err = BindingPlan::from_borrowed_inputs(&program, &[bytes]).expect_err(why);
            assert!(err.to_string().contains("expected 8 bytes"));
        }

        // Resident buffers only need to be large enough.
        let err = BindingPlan::from_input_lengths(&program, &[4])
            .expect_err("resident static input length must not be smaller than the ABI");
        assert!(err.to_string().contains("at least 8 bytes"));
        for (len, why) in [
            (8, "resident input equal to the ABI size should validate"),
            (12, "resident input larger than the ABI size should validate"),
        ] {
            let plan = BindingPlan::from_input_lengths(&program, &[len]).expect(why);
            assert_eq!(plan.bindings[0].element_count, 2);
        }
    }

    #[test]
    fn dynamic_input_length_sets_runtime_element_count() {
        let program = static_u32_input_program(0);
        let twelve_bytes = vec![0u8; 12];
        let plan = BindingPlan::from_program(&program, &[twelve_bytes])
            .expect("Fix: reject bindings without known element width; do not dispatch un-sized dynamic inputs - dynamic input byte length should define element count");

        let binding = &plan.bindings[0];
        assert_eq!(binding.element_count, 3);
        assert_eq!(binding.static_byte_len, None);
    }
}
606
607// ---------------------------------------------------------------------------
608// N7 binding-set merging across consecutive dispatches
609// ---------------------------------------------------------------------------
610
611/// Stable fingerprint of a binding set's *layout*  -  the parts that
612/// determine whether two `BindingPlan`s can share a backend bind
613/// group layout / descriptor set.
614///
615/// Two plans with the same [`BindingSetFingerprint`] can reuse the
616/// same `portable::BindGroupLayout` or native descriptor set across
617/// consecutive dispatches, skipping the layout-rebind cost. The
618/// hot-path perf snapshot puts binding rebind at ~20% of warm
619/// dispatch time on attention/softmax/reduce shapes.
620///
621/// Layout (this fingerprint) is distinct from contents (which
622/// `program_vsa_fingerprint` covers)  -  two dispatches of the same
623/// kernel on different input buffers share a layout fingerprint but
624/// differ in their content fingerprint.
625#[derive(Debug, Clone, PartialEq, Eq, Hash)]
626pub struct BindingSetFingerprint {
627    /// Per-binding layout slot: `(binding_index, role, element_size)`.
628    /// Ordered by `binding_index` for deterministic equality.
629    pub slots: Vec<(u32, BindingRole, usize)>,
630}
631
632impl BindingSetFingerprint {
633    /// Derive the layout fingerprint from a `BindingPlan`. Stable
634    /// across runs and across machines (no random salts).
635    #[must_use]
636    pub fn from_plan(plan: &BindingPlan) -> Self {
637        let mut slots: Vec<(u32, BindingRole, usize)> = plan
638            .bindings
639            .iter()
640            .map(|b| (b.binding, b.role, b.element_size))
641            .collect();
642        slots.sort_by_key(|(idx, _, _)| *idx);
643        Self { slots }
644    }
645}
646
647/// True when two binding plans can share a backend bind group
648/// layout / descriptor set. This is the N7 merge predicate; a
649/// driver maintains a cache keyed by [`BindingSetFingerprint`] and
650/// reuses the cached layout when this returns `true`.
651#[must_use]
652pub fn binding_plans_share_layout(a: &BindingPlan, b: &BindingPlan) -> bool {
653    BindingSetFingerprint::from_plan(a) == BindingSetFingerprint::from_plan(b)
654}
655
/// Backend-neutral descriptor/bind-group layout slot.
///
/// Concrete drivers own target-specific object creation. The fingerprint that
/// decides whether a descriptor layout is reusable lives here instead, so
/// portable/native/secondary drivers do not grow separate cache-key rules.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BackendLayoutSlot {
    /// Target descriptor group/set.
    pub group: u32,
    /// Binding index inside the descriptor group/set.
    pub binding: u32,
    /// Descriptor memory class.
    pub class: BackendLayoutClass,
    /// Whether storage descriptors are read-only.
    pub read_only: bool,
    /// Element size in bytes when statically known.
    pub element_size: usize,
}

/// Backend-neutral descriptor memory class.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BackendLayoutClass {
    /// Read-write or read-only storage buffer.
    Storage,
    /// Uniform/constant buffer.
    Uniform,
}

/// Stable descriptor-layout fingerprint for backend object caches.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct BackendLayoutFingerprint {
    /// Canonical slots, ordered by `(group, binding)`. Ties are broken by the
    /// remaining fields.
    pub slots: Vec<BackendLayoutSlot>,
}

impl BackendLayoutFingerprint {
    /// Build a canonical fingerprint from layout slots supplied in any order.
    ///
    /// Slots are sorted primarily by `(group, binding)`. Duplicate
    /// `(group, binding)` pairs are ordered by class, read-only flag, and
    /// element size. As a result, every permutation of the same slots yields an
    /// identical fingerprint and the same cache key. A sort keyed only on
    /// `(group, binding)` would leave duplicates in input order.
    #[must_use]
    pub fn new(mut slots: Vec<BackendLayoutSlot>) -> Self {
        // The key covers every field, so equal keys imply identical slots and
        // an unstable sort is still fully deterministic.
        slots.sort_unstable_by_key(|slot| {
            let class_rank: u8 = match slot.class {
                BackendLayoutClass::Storage => 0,
                BackendLayoutClass::Uniform => 1,
            };
            (
                slot.group,
                slot.binding,
                class_rank,
                slot.read_only,
                slot.element_size,
            )
        });
        Self { slots }
    }
}
699
#[cfg(test)]
mod n7_tests {
    use super::*;
    use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType, Program};

    /// Two-binding program: read-only `input` at slot 0 and `out` at slot 1,
    /// both `u32` with 16 elements.
    fn add_one_program() -> Program {
        Program::wrapped(
            vec![
                BufferDecl::storage("input", 0, BufferAccess::ReadOnly, DataType::U32)
                    .with_count(16),
                BufferDecl::output("out", 1, DataType::U32).with_count(16),
            ],
            [16, 1, 1],
            vec![],
        )
    }

    fn add_one_program_different_input_count() -> Program {
        // Same binding shape (slot 0 ReadOnly, slot 1 output, both
        // U32), different element_count. Layout fingerprint must match;
        // content fingerprint will not.
        Program::wrapped(
            vec![
                BufferDecl::storage("input", 0, BufferAccess::ReadOnly, DataType::U32)
                    .with_count(64),
                BufferDecl::output("out", 1, DataType::U32).with_count(64),
            ],
            [16, 1, 1],
            vec![],
        )
    }

    fn different_layout_program() -> Program {
        // Three bindings instead of two - must NOT share layout.
        Program::wrapped(
            vec![
                BufferDecl::storage("a", 0, BufferAccess::ReadOnly, DataType::U32).with_count(16),
                BufferDecl::storage("b", 1, BufferAccess::ReadOnly, DataType::U32).with_count(16),
                BufferDecl::output("out", 2, DataType::U32).with_count(16),
            ],
            [16, 1, 1],
            vec![],
        )
    }

    #[test]
    fn same_layout_with_different_element_counts_shares_fingerprint() {
        let a = BindingPlan::build(&add_one_program()).unwrap();
        let b = BindingPlan::build(&add_one_program_different_input_count()).unwrap();
        assert!(
            binding_plans_share_layout(&a, &b),
            "plans with same (binding, role, element_size) tuples must share layout"
        );
    }

    #[test]
    fn different_binding_count_does_not_share_layout() {
        let a = BindingPlan::build(&add_one_program()).unwrap();
        let b = BindingPlan::build(&different_layout_program()).unwrap();
        assert!(
            !binding_plans_share_layout(&a, &b),
            "plans with different binding count must not share layout"
        );
    }

    #[test]
    fn fingerprint_is_stable_across_repeated_builds() {
        let a = BindingPlan::build(&add_one_program()).unwrap();
        let b = BindingPlan::build(&add_one_program()).unwrap();
        assert_eq!(
            BindingSetFingerprint::from_plan(&a),
            BindingSetFingerprint::from_plan(&b),
            "repeated build of the same Program must produce identical fingerprints"
        );
    }

    #[test]
    fn fingerprint_slots_are_sorted_by_binding_index() {
        let plan = BindingPlan::build(&add_one_program()).unwrap();
        let fp = BindingSetFingerprint::from_plan(&plan);
        let indices: Vec<u32> = fp.slots.iter().map(|(i, _, _)| *i).collect();
        assert_eq!(indices, [0, 1], "slots must be sorted by binding index");
    }

    /// Writable storage slot at (group 1, binding 4).
    fn storage_slot_group1_binding4() -> BackendLayoutSlot {
        BackendLayoutSlot {
            group: 1,
            binding: 4,
            class: BackendLayoutClass::Storage,
            read_only: false,
            element_size: 4,
        }
    }

    /// Read-only uniform slot at (group 0, binding 1).
    fn uniform_slot_group0_binding1() -> BackendLayoutSlot {
        BackendLayoutSlot {
            group: 0,
            binding: 1,
            class: BackendLayoutClass::Uniform,
            read_only: true,
            element_size: 4,
        }
    }

    #[test]
    fn backend_layout_fingerprint_sorts_slots() {
        let a = BackendLayoutFingerprint::new(vec![
            storage_slot_group1_binding4(),
            uniform_slot_group0_binding1(),
        ]);
        let b = BackendLayoutFingerprint::new(vec![
            uniform_slot_group0_binding1(),
            storage_slot_group1_binding4(),
        ]);
        // Input order must not affect the fingerprint...
        assert_eq!(a, b);
        // ...and the canonical order must be ascending (group, binding); equality
        // alone would also accept e.g. a descending canonicalization.
        let order: Vec<_> = a
            .slots
            .iter()
            .map(|slot| (slot.group, slot.binding))
            .collect();
        assert_eq!(
            order,
            [(0, 1), (1, 4)],
            "backend layout slots must be ordered by ascending (group, binding)"
        );
    }
}
821
#[cfg(test)]
mod tests {
    use super::*;
    use vyre_foundation::ir::{CacheLocality, DataType, MemoryHints};

    /// Returns the descriptor whose VYRE binding number equals `number`.
    ///
    /// Panics with `missing` when no such descriptor exists; `#[track_caller]`
    /// keeps the reported panic location at the calling test line.
    #[track_caller]
    fn descriptor_with_binding<'a>(
        plan: &'a BindingPlan,
        number: u32,
        missing: &str,
    ) -> &'a Binding {
        plan.bindings
            .iter()
            .find(|descriptor| descriptor.binding == number)
            .expect(missing)
    }

    /// Program with a single read-only packed `I4` buffer of three elements.
    fn packed_i4_program() -> Program {
        let packed = BufferDecl::storage("packed_i4", 0, BufferAccess::ReadOnly, DataType::I4)
            .with_count(3);
        Program::wrapped(vec![packed], [1, 1, 1], vec![])
    }

    #[test]
    fn binding_plan_carries_alignment_hints() {
        let streaming_hints = MemoryHints {
            coalesce_axis: Some(0),
            preferred_alignment: 64,
            cache_locality: CacheLocality::Streaming,
        };
        let out = BufferDecl::output("out", 0, DataType::U32)
            .with_count(16)
            .with_hints(streaming_hints);
        let program = Program::wrapped(vec![out], [64, 1, 1], vec![]);

        let plan = BindingPlan::build(&program).expect("Fix: alignment hint should build");
        assert_eq!(plan.bindings[0].preferred_alignment, 64);
    }

    #[test]
    fn binding_plan_keeps_logical_slots_when_binding_numbers_are_reordered() {
        // Declaration order (inputs 9 then 0, outputs 8 then 1) intentionally
        // disagrees with ascending binding-number order.
        let buffers = vec![
            BufferDecl::read("declared_first_high_binding", 9, DataType::U32),
            BufferDecl::output("declared_output_first_high_binding", 8, DataType::U32)
                .with_count(1),
            BufferDecl::read("declared_second_low_binding", 0, DataType::U32),
            BufferDecl::output("declared_output_second_low_binding", 1, DataType::U32)
                .with_count(1),
        ];
        let program = Program::wrapped(buffers, [1, 1, 1], vec![]);
        // 12 bytes -> 3 u32 elements for the first input, 8 bytes -> 2 for the second.
        let inputs = [vec![0u8; 12], vec![0u8; 8]];

        let plan = BindingPlan::from_program(&program, &inputs)
            .expect("Fix: binding plan must accept logical input order before descriptor sorting");

        let descriptor_order: Vec<_> = plan.bindings.iter().map(|b| b.binding).collect();
        assert_eq!(
            descriptor_order,
            [0, 1, 8, 9],
            "descriptor ABI must remain sorted by VYRE binding number"
        );
        assert_eq!(
            plan.input_indices,
            [0, 2],
            "caller input slots must follow Program::buffers declaration order"
        );
        assert_eq!(
            plan.output_indices,
            [1, 3],
            "backend output slots must follow Program::buffers declaration order"
        );

        let high_input =
            descriptor_with_binding(&plan, 9, "high binding input descriptor must exist");
        assert_eq!(high_input.input_index, Some(0));
        assert_eq!(high_input.element_count, 3);

        let low_input = descriptor_with_binding(&plan, 0, "low binding input descriptor must exist");
        assert_eq!(low_input.input_index, Some(1));
        assert_eq!(low_input.element_count, 2);

        let high_output =
            descriptor_with_binding(&plan, 8, "high binding output descriptor must exist");
        assert_eq!(high_output.output_index, Some(0));

        let low_output =
            descriptor_with_binding(&plan, 1, "low binding output descriptor must exist");
        assert_eq!(low_output.output_index, Some(1));
    }

    #[test]
    fn binding_plan_rejects_non_power_of_two_alignment_hint() {
        let bad_hints = MemoryHints {
            coalesce_axis: None,
            preferred_alignment: 48,
            cache_locality: CacheLocality::Temporal,
        };
        let out = BufferDecl::output("out", 0, DataType::U32)
            .with_count(16)
            .with_hints(bad_hints);
        let program = Program::wrapped(vec![out], [64, 1, 1], vec![]);

        let err = BindingPlan::build(&program).expect_err("bad alignment must fail");
        let rendered = format!("{err}");
        assert!(rendered.contains("preferred_alignment=48"));
    }

    #[test]
    fn binding_plan_alignment_defaults_to_element_size() {
        let out = BufferDecl::output("out", 0, DataType::U32).with_count(16);
        let program = Program::wrapped(vec![out], [64, 1, 1], vec![]);

        let plan = BindingPlan::build(&program).expect("Fix: default alignment should build");
        // No hints supplied, so alignment falls back to the 4-byte u32 element.
        assert_eq!(plan.bindings[0].preferred_alignment, 4);
    }

    #[test]
    fn binding_plan_uses_packed_static_byte_len_for_subbyte_elements() {
        let program = packed_i4_program();
        let plan =
            BindingPlan::build(&program).expect("Fix: packed I4 binding layout should build");
        let packed = &plan.bindings[0];

        assert_eq!(packed.element_size, 1);
        assert_eq!(packed.static_byte_len, Some(2));
    }

    #[test]
    fn binding_plan_validates_packed_static_input_lengths() {
        let program = packed_i4_program();
        let plan = BindingPlan::from_input_lengths(&program, &[2])
            .expect("Fix: packed I4 input should accept the exact packed byte count");

        // Both the exact packed size and a larger resident allocation are accepted.
        plan.validate_input_byte_lengths(&[2])
            .expect("Fix: cached packed I4 input length should remain valid");
        plan.validate_input_byte_lengths(&[3])
            .expect("Fix: resident packed I4 input may be larger than its static ABI byte count");

        let error = plan
            .validate_input_byte_lengths(&[1])
            .expect_err("undersized resident byte length must not satisfy packed I4 contract");
        let rendered = format!("{error}");
        assert!(
            rendered.contains("at least 2 bytes"),
            "Fix: packed resident byte mismatch must be explicit: {error}"
        );
    }

    #[test]
    fn binding_plan_rejects_malformed_data_type_layouts() {
        let zero_lane_vec = DataType::Vec {
            element: Box::new(DataType::U32),
            count: 0,
        };
        let bad_vec = BufferDecl::output("bad_vec", 0, zero_lane_vec).with_count(1);
        let program = Program::wrapped(vec![bad_vec], [1, 1, 1], vec![]);

        let error = BindingPlan::build(&program)
            .expect_err("zero-lane vector layout must not enter binding planning");
        let rendered = format!("{error}");
        assert!(
            rendered.contains("Vec count must be > 0"),
            "Fix: malformed data-type layout diagnostics must survive binding planning: {error}"
        );
    }

    #[test]
    fn binding_plan_validates_cached_resident_input_lengths() {
        let buffers = vec![
            BufferDecl::read("in", 0, DataType::U32).with_count(4),
            BufferDecl::output("out", 1, DataType::U32).with_count(4),
        ];
        let program = Program::wrapped(buffers, [4, 1, 1], vec![]);
        // u32[4] is 16 bytes.
        let plan = BindingPlan::from_input_lengths(&program, &[16])
            .expect("Fix: resident input length should match the declared u32[4] input");

        plan.validate_input_byte_lengths(&[16])
            .expect("Fix: cached resident plan should accept the same input byte length");
        plan.validate_input_byte_lengths(&[20])
            .expect("Fix: cached resident plan should accept a larger reused allocation");

        let error = plan
            .validate_input_byte_lengths(&[12])
            .expect_err("cached resident plan must reject stale pipeline shape reuse");
        let rendered = format!("{error}");
        assert!(
            rendered.contains("at least 16 bytes"),
            "wrong resident input length must produce an actionable size mismatch: {error}"
        );
    }
}
1029            format!("{error}").contains("at least 16 bytes"),
1030            "wrong resident input length must produce an actionable size mismatch: {error}"
1031        );
1032    }
1033}