1use crate::binding::{BindingPlan, BindingRole};
11use crate::transfer_accounting::TransferAccountingPolicy;
12use crate::BackendError;
13
14const GRAPH_CAPTURE_BINDING_ACCOUNTING: TransferAccountingPolicy =
15 TransferAccountingPolicy::new("graph capture binding plan", "record a smaller graph shape");
16
17pub const SCAN_GRAPH_CAPTURE_EDIT_SCHEMA_VERSION: u32 = 1;
19
/// Table capacities required to record a graph capture for a binding plan.
///
/// Built by [`plan_graph_capture_bindings`]; shared bindings contribute to no
/// field.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct GraphCaptureBindingPlan {
    /// Number of non-shared bindings that carry an input index.
    pub input_device_capacity: usize,
    /// Number of non-shared bindings that carry no input index.
    pub output_device_capacity: usize,
    /// Number of non-shared bindings that carry an output index (input/output
    /// bindings included).
    pub output_readback_capacity: usize,
    /// One pointer slot per non-shared binding.
    pub kernel_pointer_capacity: usize,
    /// Always `kernel_pointer_capacity + 1`.
    pub kernel_argument_capacity: usize,
    /// `false` as soon as any binding has both an input and an output index.
    /// NOTE(review): presumably because replaying would read an input that the
    /// previous run overwrote — confirm with the graph executor.
    pub resident_input_replay_safe: bool,
}
42
/// The resource a scan graph-capture edit touches.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ScanGraphCaptureEditKind {
    /// The pattern database buffer was (re)uploaded.
    PatternDatabaseUpload,
    /// The haystack input buffer changed.
    HaystackBufferChange,
    /// The output slab was (re)sized.
    OutputSlabResize,
    /// The verifier changed.
    VerifierChange,
}

impl ScanGraphCaptureEditKind {
    /// Returns the snake_case name of this edit kind.
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            ScanGraphCaptureEditKind::PatternDatabaseUpload => "pattern_database_upload",
            ScanGraphCaptureEditKind::HaystackBufferChange => "haystack_buffer_change",
            ScanGraphCaptureEditKind::OutputSlabResize => "output_slab_resize",
            ScanGraphCaptureEditKind::VerifierChange => "verifier_change",
        }
    }
}
68
/// What to do with an existing graph capture after an edit.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum GraphCaptureEditAction {
    /// Chosen by the classifier when the edit changes nothing it tracks for
    /// that resource.
    Replay,
    /// The only action that sets `parameter_update_required`; chosen for
    /// same-length haystack content changes.
    Update,
    /// Chosen for length changes, and for digest changes of pattern databases
    /// and verifiers.
    Recapture,
}

impl GraphCaptureEditAction {
    /// Returns the snake_case name of this action.
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            GraphCaptureEditAction::Replay => "replay",
            GraphCaptureEditAction::Update => "update",
            GraphCaptureEditAction::Recapture => "recapture",
        }
    }
}
91
/// Whether an edit keeps the recorded graph usable.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum GraphCaptureEditStability {
    /// Paired with `Replay` or `Update` by the classifier.
    GraphStable,
    /// Paired with `Recapture` by the classifier.
    GraphBreaking,
}

impl GraphCaptureEditStability {
    /// Returns the snake_case name of this stability level.
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            GraphCaptureEditStability::GraphStable => "graph_stable",
            GraphCaptureEditStability::GraphBreaking => "graph_breaking",
        }
    }
}
111
112#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
114pub struct ScanGraphCaptureEdit {
115 pub kind: ScanGraphCaptureEditKind,
117 pub previous_byte_len: u64,
119 pub next_byte_len: u64,
121 pub previous_digest: u64,
123 pub next_digest: u64,
125}
126
127impl ScanGraphCaptureEdit {
128 #[must_use]
130 pub const fn new(
131 kind: ScanGraphCaptureEditKind,
132 previous_byte_len: u64,
133 next_byte_len: u64,
134 previous_digest: u64,
135 next_digest: u64,
136 ) -> Self {
137 Self {
138 kind,
139 previous_byte_len,
140 next_byte_len,
141 previous_digest,
142 next_digest,
143 }
144 }
145
146 const fn shape_unchanged(self) -> bool {
147 self.previous_byte_len == self.next_byte_len
148 }
149
150 const fn digest_unchanged(self) -> bool {
151 self.previous_digest == self.next_digest
152 }
153}
154
155#[derive(Debug, Clone, Copy, PartialEq, Eq)]
157pub struct ScanGraphCaptureEditClassification {
158 pub schema_version: u32,
160 pub edit_kind: ScanGraphCaptureEditKind,
162 pub action: GraphCaptureEditAction,
164 pub stability: GraphCaptureEditStability,
166 pub reason: &'static str,
168 pub graph_stable: bool,
170 pub graph_breaking: bool,
172 pub parameter_update_required: bool,
174}
175
176impl ScanGraphCaptureEditClassification {
177 #[must_use]
180 pub const fn is_complete(self) -> bool {
181 self.schema_version == SCAN_GRAPH_CAPTURE_EDIT_SCHEMA_VERSION
182 && !self.reason.is_empty()
183 && self.graph_stable == matches!(self.stability, GraphCaptureEditStability::GraphStable)
184 && self.graph_breaking
185 == matches!(self.stability, GraphCaptureEditStability::GraphBreaking)
186 && self.parameter_update_required
187 == matches!(self.action, GraphCaptureEditAction::Update)
188 }
189}
190
191pub fn plan_graph_capture_bindings(
198 bindings: &BindingPlan,
199) -> Result<GraphCaptureBindingPlan, BackendError> {
200 let mut input_device_capacity = 0usize;
201 let mut output_device_capacity = 0usize;
202 let mut output_readback_capacity = 0usize;
203 let mut kernel_pointer_capacity = 0usize;
204 let mut resident_input_replay_safe = true;
205
206 for binding in &bindings.bindings {
207 if binding.role == BindingRole::Shared {
208 continue;
209 }
210
211 kernel_pointer_capacity =
212 graph_capture_capacity_add(kernel_pointer_capacity, 1, "kernel pointer table")?;
213
214 if binding.input_index.is_some() {
215 input_device_capacity =
216 graph_capture_capacity_add(input_device_capacity, 1, "input device table")?;
217 } else {
218 output_device_capacity =
219 graph_capture_capacity_add(output_device_capacity, 1, "output device table")?;
220 }
221
222 if binding.output_index.is_some() {
223 output_readback_capacity =
224 graph_capture_capacity_add(output_readback_capacity, 1, "output readback table")?;
225 }
226
227 if binding.input_index.is_some() && binding.output_index.is_some() {
228 resident_input_replay_safe = false;
229 }
230 }
231
232 let kernel_argument_capacity =
233 graph_capture_capacity_add(kernel_pointer_capacity, 1, "kernel argument table")?;
234
235 Ok(GraphCaptureBindingPlan {
236 input_device_capacity,
237 output_device_capacity,
238 output_readback_capacity,
239 kernel_pointer_capacity,
240 kernel_argument_capacity,
241 resident_input_replay_safe,
242 })
243}
244
245#[must_use]
252pub const fn classify_scan_graph_capture_edit(
253 edit: ScanGraphCaptureEdit,
254) -> ScanGraphCaptureEditClassification {
255 match edit.kind {
256 ScanGraphCaptureEditKind::PatternDatabaseUpload => {
257 if edit.shape_unchanged() && edit.digest_unchanged() {
258 scan_graph_capture_classification(
259 edit.kind,
260 GraphCaptureEditAction::Replay,
261 GraphCaptureEditStability::GraphStable,
262 "pattern_database_unchanged",
263 )
264 } else {
265 scan_graph_capture_classification(
266 edit.kind,
267 GraphCaptureEditAction::Recapture,
268 GraphCaptureEditStability::GraphBreaking,
269 "pattern_database_changed",
270 )
271 }
272 }
273 ScanGraphCaptureEditKind::HaystackBufferChange => {
274 if edit.shape_unchanged() {
275 if edit.digest_unchanged() {
276 scan_graph_capture_classification(
277 edit.kind,
278 GraphCaptureEditAction::Replay,
279 GraphCaptureEditStability::GraphStable,
280 "haystack_unchanged",
281 )
282 } else {
283 scan_graph_capture_classification(
284 edit.kind,
285 GraphCaptureEditAction::Update,
286 GraphCaptureEditStability::GraphStable,
287 "haystack_contents_changed_same_shape",
288 )
289 }
290 } else {
291 scan_graph_capture_classification(
292 edit.kind,
293 GraphCaptureEditAction::Recapture,
294 GraphCaptureEditStability::GraphBreaking,
295 "haystack_shape_changed",
296 )
297 }
298 }
299 ScanGraphCaptureEditKind::OutputSlabResize => {
300 if edit.shape_unchanged() {
301 scan_graph_capture_classification(
302 edit.kind,
303 GraphCaptureEditAction::Replay,
304 GraphCaptureEditStability::GraphStable,
305 "output_slab_unchanged",
306 )
307 } else {
308 scan_graph_capture_classification(
309 edit.kind,
310 GraphCaptureEditAction::Recapture,
311 GraphCaptureEditStability::GraphBreaking,
312 "output_slab_size_changed",
313 )
314 }
315 }
316 ScanGraphCaptureEditKind::VerifierChange => {
317 if edit.shape_unchanged() && edit.digest_unchanged() {
318 scan_graph_capture_classification(
319 edit.kind,
320 GraphCaptureEditAction::Replay,
321 GraphCaptureEditStability::GraphStable,
322 "verifier_unchanged",
323 )
324 } else {
325 scan_graph_capture_classification(
326 edit.kind,
327 GraphCaptureEditAction::Recapture,
328 GraphCaptureEditStability::GraphBreaking,
329 "verifier_changed",
330 )
331 }
332 }
333 }
334}
335
336const fn scan_graph_capture_classification(
337 edit_kind: ScanGraphCaptureEditKind,
338 action: GraphCaptureEditAction,
339 stability: GraphCaptureEditStability,
340 reason: &'static str,
341) -> ScanGraphCaptureEditClassification {
342 ScanGraphCaptureEditClassification {
343 schema_version: SCAN_GRAPH_CAPTURE_EDIT_SCHEMA_VERSION,
344 edit_kind,
345 action,
346 stability,
347 reason,
348 graph_stable: matches!(stability, GraphCaptureEditStability::GraphStable),
349 graph_breaking: matches!(stability, GraphCaptureEditStability::GraphBreaking),
350 parameter_update_required: matches!(action, GraphCaptureEditAction::Update),
351 }
352}
353
354fn graph_capture_capacity_add(lhs: usize, rhs: usize, label: &str) -> Result<usize, BackendError> {
355 GRAPH_CAPTURE_BINDING_ACCOUNTING.add_usize_capacity(lhs, rhs, label)
356}
357
#[cfg(test)]
mod tests {
    use super::{
        classify_scan_graph_capture_edit, graph_capture_capacity_add, plan_graph_capture_bindings,
        GraphCaptureBindingPlan, GraphCaptureEditAction, GraphCaptureEditStability,
        ScanGraphCaptureEdit, ScanGraphCaptureEditKind,
    };
    use crate::binding::{Binding, BindingPlan, BindingRole};
    use std::sync::Arc;

    /// Builds a 16-element, 4-byte-element (64 static bytes) binding at `slot`.
    fn binding(
        name: &'static str,
        slot: u32,
        role: BindingRole,
        input_index: Option<usize>,
        output_index: Option<usize>,
    ) -> Binding {
        Binding {
            name: Arc::from(name),
            binding: slot,
            buffer_index: slot as usize,
            role,
            element_size: 4,
            preferred_alignment: 4,
            element_count: 16,
            static_byte_len: Some(64),
            input_index,
            output_index,
        }
    }

    /// Wraps `bindings` in a plan with empty index side tables; the graph
    /// capture planner only reads `bindings`.
    fn plan(bindings: Vec<Binding>) -> BindingPlan {
        BindingPlan {
            bindings,
            input_indices: vec![],
            output_indices: vec![],
            shared_indices: vec![],
        }
    }

    #[test]
    fn graph_capture_binding_plan_counts_distinct_device_and_readback_tables() {
        let bindings = plan(vec![
            binding("input", 0, BindingRole::Input, Some(0), None),
            binding("shared", 1, BindingRole::Shared, None, None),
            binding("output", 2, BindingRole::Output, None, Some(0)),
            binding("state", 3, BindingRole::InputOutput, Some(1), Some(1)),
        ]);

        assert_eq!(
            plan_graph_capture_bindings(&bindings)
                .expect("Fix: graph capture planning should accept normal bindings"),
            GraphCaptureBindingPlan {
                input_device_capacity: 2,
                output_device_capacity: 1,
                output_readback_capacity: 2,
                kernel_pointer_capacity: 3,
                kernel_argument_capacity: 4,
                resident_input_replay_safe: false,
            }
        );
    }

    #[test]
    fn generated_graph_capture_binding_plan_preserves_order_independent_counts() {
        let mut state = 0x9e37_79b9_7f4a_7c15_u64;
        for case_index in 0..768usize {
            let binding_count = 1 + (next_u64(&mut state) as usize % 96);
            let mut bindings = Vec::with_capacity(binding_count);
            let mut expected_input_device_capacity = 0usize;
            let mut expected_output_device_capacity = 0usize;
            let mut expected_output_readback_capacity = 0usize;
            let mut expected_kernel_pointer_capacity = 0usize;
            let mut expected_safe = true;
            let mut next_input = 0usize;
            let mut next_output = 0usize;

            for slot in 0..binding_count {
                let role_selector = (next_u64(&mut state) % 4) as u8;
                let (role, input_index, output_index) = match role_selector {
                    0 => {
                        let index = next_input;
                        next_input += 1;
                        (BindingRole::Input, Some(index), None)
                    }
                    1 => {
                        let index = next_output;
                        next_output += 1;
                        (BindingRole::Output, None, Some(index))
                    }
                    2 => {
                        let input = next_input;
                        let output = next_output;
                        next_input += 1;
                        next_output += 1;
                        expected_safe = false;
                        (BindingRole::InputOutput, Some(input), Some(output))
                    }
                    _ => (BindingRole::Shared, None, None),
                };

                if role != BindingRole::Shared {
                    expected_kernel_pointer_capacity += 1;
                    if input_index.is_some() {
                        expected_input_device_capacity += 1;
                    } else {
                        expected_output_device_capacity += 1;
                    }
                    if output_index.is_some() {
                        expected_output_readback_capacity += 1;
                    }
                }

                bindings.push(binding(
                    "generated",
                    slot as u32,
                    role,
                    input_index,
                    output_index,
                ));
            }

            let planned = plan_graph_capture_bindings(&plan(bindings))
                .expect("Fix: generated graph capture plan should fit host capacities");
            assert_eq!(
                planned,
                GraphCaptureBindingPlan {
                    input_device_capacity: expected_input_device_capacity,
                    output_device_capacity: expected_output_device_capacity,
                    output_readback_capacity: expected_output_readback_capacity,
                    kernel_pointer_capacity: expected_kernel_pointer_capacity,
                    kernel_argument_capacity: expected_kernel_pointer_capacity + 1,
                    resident_input_replay_safe: expected_safe,
                },
                "case {case_index}"
            );
        }
    }

    #[test]
    fn graph_capture_capacity_overflow_fails_loudly() {
        let error = graph_capture_capacity_add(usize::MAX, 1, "kernel argument table")
            .expect_err("Fix: graph capture capacity overflow must not wrap");
        let message = error.to_string();
        assert!(message.contains("graph capture binding plan"));
        assert!(message.contains("kernel argument table"));
        assert!(message.contains("record a smaller graph shape"));
    }

    #[test]
    fn scan_graph_capture_classifies_replay_update_and_recapture_reasons() {
        let cases = [
            (
                ScanGraphCaptureEdit::new(
                    ScanGraphCaptureEditKind::PatternDatabaseUpload,
                    4096,
                    4096,
                    11,
                    11,
                ),
                GraphCaptureEditAction::Replay,
                GraphCaptureEditStability::GraphStable,
                "pattern_database_unchanged",
            ),
            (
                ScanGraphCaptureEdit::new(
                    ScanGraphCaptureEditKind::PatternDatabaseUpload,
                    4096,
                    4096,
                    11,
                    12,
                ),
                GraphCaptureEditAction::Recapture,
                GraphCaptureEditStability::GraphBreaking,
                "pattern_database_changed",
            ),
            // A resized pattern database recaptures even when the digest is unchanged.
            (
                ScanGraphCaptureEdit::new(
                    ScanGraphCaptureEditKind::PatternDatabaseUpload,
                    4096,
                    8192,
                    11,
                    11,
                ),
                GraphCaptureEditAction::Recapture,
                GraphCaptureEditStability::GraphBreaking,
                "pattern_database_changed",
            ),
            (
                ScanGraphCaptureEdit::new(
                    ScanGraphCaptureEditKind::HaystackBufferChange,
                    8192,
                    8192,
                    21,
                    21,
                ),
                GraphCaptureEditAction::Replay,
                GraphCaptureEditStability::GraphStable,
                "haystack_unchanged",
            ),
            (
                ScanGraphCaptureEdit::new(
                    ScanGraphCaptureEditKind::HaystackBufferChange,
                    8192,
                    8192,
                    21,
                    22,
                ),
                GraphCaptureEditAction::Update,
                GraphCaptureEditStability::GraphStable,
                "haystack_contents_changed_same_shape",
            ),
            (
                ScanGraphCaptureEdit::new(
                    ScanGraphCaptureEditKind::HaystackBufferChange,
                    8192,
                    16_384,
                    21,
                    22,
                ),
                GraphCaptureEditAction::Recapture,
                GraphCaptureEditStability::GraphBreaking,
                "haystack_shape_changed",
            ),
            // Output slabs are judged by size alone: a digest change still replays.
            (
                ScanGraphCaptureEdit::new(
                    ScanGraphCaptureEditKind::OutputSlabResize,
                    1024,
                    1024,
                    31,
                    32,
                ),
                GraphCaptureEditAction::Replay,
                GraphCaptureEditStability::GraphStable,
                "output_slab_unchanged",
            ),
            (
                ScanGraphCaptureEdit::new(
                    ScanGraphCaptureEditKind::OutputSlabResize,
                    1024,
                    2048,
                    31,
                    31,
                ),
                GraphCaptureEditAction::Recapture,
                GraphCaptureEditStability::GraphBreaking,
                "output_slab_size_changed",
            ),
            (
                ScanGraphCaptureEdit::new(
                    ScanGraphCaptureEditKind::VerifierChange,
                    512,
                    512,
                    41,
                    41,
                ),
                GraphCaptureEditAction::Replay,
                GraphCaptureEditStability::GraphStable,
                "verifier_unchanged",
            ),
            (
                ScanGraphCaptureEdit::new(
                    ScanGraphCaptureEditKind::VerifierChange,
                    512,
                    512,
                    41,
                    42,
                ),
                GraphCaptureEditAction::Recapture,
                GraphCaptureEditStability::GraphBreaking,
                "verifier_changed",
            ),
            (
                ScanGraphCaptureEdit::new(
                    ScanGraphCaptureEditKind::VerifierChange,
                    512,
                    1024,
                    41,
                    41,
                ),
                GraphCaptureEditAction::Recapture,
                GraphCaptureEditStability::GraphBreaking,
                "verifier_changed",
            ),
        ];

        for (edit, action, stability, reason) in cases {
            let classified = classify_scan_graph_capture_edit(edit);
            assert!(classified.is_complete(), "{reason}");
            assert_eq!(classified.edit_kind, edit.kind, "{reason}");
            assert_eq!(classified.action, action, "{reason}");
            assert_eq!(classified.stability, stability, "{reason}");
            assert_eq!(classified.reason, reason);
        }
    }

    #[test]
    fn scan_graph_capture_same_shape_haystack_update_is_not_a_hidden_recapture() {
        let classified = classify_scan_graph_capture_edit(ScanGraphCaptureEdit::new(
            ScanGraphCaptureEditKind::HaystackBufferChange,
            65_536,
            65_536,
            100,
            101,
        ));

        assert_eq!(classified.action, GraphCaptureEditAction::Update);
        assert!(classified.graph_stable);
        assert!(!classified.graph_breaking);
        assert!(classified.parameter_update_required);
        assert_eq!(classified.reason, "haystack_contents_changed_same_shape");
    }

    #[test]
    fn scan_graph_capture_is_complete_rejects_inconsistent_classifications() {
        let classified = classify_scan_graph_capture_edit(ScanGraphCaptureEdit::new(
            ScanGraphCaptureEditKind::VerifierChange,
            512,
            512,
            41,
            41,
        ));
        assert!(classified.is_complete());

        let mut stale_schema = classified;
        stale_schema.schema_version += 1;
        assert!(!stale_schema.is_complete());

        let mut missing_reason = classified;
        missing_reason.reason = "";
        assert!(!missing_reason.is_complete());

        let mut flipped_stable = classified;
        flipped_stable.graph_stable = !flipped_stable.graph_stable;
        assert!(!flipped_stable.is_complete());

        let mut flipped_breaking = classified;
        flipped_breaking.graph_breaking = !flipped_breaking.graph_breaking;
        assert!(!flipped_breaking.is_complete());

        let mut spurious_update = classified;
        spurious_update.parameter_update_required = true;
        assert!(!spurious_update.is_complete());
    }

    /// Advances a 64-bit LCG and returns the high 32 bits of the new state.
    ///
    /// With a power-of-two modulus, bit `k` of an LCG state repeats every
    /// `2^(k+1)` steps. Taking `% 4` of the raw state would therefore cycle
    /// through the same four role selectors forever. The high bits have long
    /// periods, so the generated plans actually vary.
    fn next_u64(state: &mut u64) -> u64 {
        *state = state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        *state >> 32
    }
}