1use std::collections::{HashMap, HashSet};
63use std::sync::Arc;
64
65use smallvec::SmallVec;
66use vyre_foundation::ir::{BufferAccess, BufferDecl, Expr, Ident, MemoryKind, Node, Program};
67use vyre_foundation::memory_model::MemoryOrdering;
68
69use crate::backend::{
70 BackendError, DispatchConfig, OutputBuffers, ResidentDispatchStep, ResidentReadRange, Resource,
71 TimedDispatchResult, VyreBackend,
72};
73use crate::binding::{Binding, BindingPlan, BindingRole};
74
/// A single-node wrapper peeled off the outside of a program's entry sequence.
///
/// `peel_entry_wrappers` records these outermost-first, and `wrap_split_segment`
/// re-applies them (innermost-first) around each split segment.
#[derive(Clone, Debug, PartialEq, Eq)]
enum EntryWrapper {
    /// A `Node::Region` wrapper; only its `generator` is retained.
    Region { generator: Ident },
    /// A `Node::Block` wrapper.
    Block,
}
87
/// One host-dispatchable segment produced by `plan_host_grid_sync_segments`.
struct PlannedGridSyncSegment {
    /// Segment program after its buffer declarations were rewritten for the host split.
    program: Program,
    /// Buffer names returned by `segment_input_names` for `program`.
    input_names: Vec<Ident>,
    /// Buffer names returned by `segment_output_names` for `program`.
    output_names: Vec<Ident>,
}
93
94fn peel_entry_wrappers(program: &Program) -> (Vec<EntryWrapper>, &[Node]) {
95 let mut wrappers = Vec::new();
96 let mut entry = program.entry();
97 loop {
98 if entry.len() == 1 {
99 match &entry[0] {
100 Node::Region {
101 generator, body, ..
102 } => {
103 wrappers.push(EntryWrapper::Region {
104 generator: generator.clone(),
105 });
106 entry = body.as_slice();
107 continue;
108 }
109 Node::Block(body) => {
110 wrappers.push(EntryWrapper::Block);
111 entry = body.as_slice();
112 continue;
113 }
114 _ => {}
115 }
116 }
117 break;
118 }
119 (wrappers, entry)
120}
121
122fn entry_sequence(program: &Program) -> &[Node] {
123 peel_entry_wrappers(program).1
124}
125
126#[must_use]
135pub fn contains_grid_sync(program: &Program) -> bool {
136 if !program.stats().has_node_barrier() {
142 return false;
143 }
144 node_slice_contains_grid_sync(entry_sequence(program))
145}
146
147fn node_slice_contains_grid_sync(nodes: &[Node]) -> bool {
148 nodes.iter().any(node_contains_grid_sync)
149}
150
151fn node_contains_grid_sync(node: &Node) -> bool {
152 match node {
153 Node::Barrier {
154 ordering: MemoryOrdering::GridSync,
155 ..
156 } => true,
157 Node::If {
158 then, otherwise, ..
159 } => node_slice_contains_grid_sync(then) || node_slice_contains_grid_sync(otherwise),
160 Node::Loop { body, .. } | Node::Block(body) => node_slice_contains_grid_sync(body),
161 Node::Region { body, .. } => node_slice_contains_grid_sync(body),
162 _ => false,
163 }
164}
165
166#[must_use]
179pub fn split_on_grid_sync(program: &Program) -> Vec<Program> {
180 match try_split_on_grid_sync(program) {
181 Ok(segments) => segments,
182 Err(_error) => Vec::new(),
183 }
184}
185
186fn hoist_grid_sync_barriers(nodes: &[Node]) -> Vec<Node> {
192 let mut new_nodes = Vec::new();
193 for node in nodes {
194 match node {
195 Node::Block(body) => {
196 let new_body = hoist_grid_sync_barriers(body);
197 let has_barrier = new_body.iter().any(|n| {
198 matches!(
199 n,
200 Node::Barrier {
201 ordering: MemoryOrdering::GridSync,
202 ..
203 }
204 )
205 });
206 if has_barrier {
207 let mut current_segment = Vec::new();
208 for b_node in new_body {
209 if matches!(
210 b_node,
211 Node::Barrier {
212 ordering: MemoryOrdering::GridSync,
213 ..
214 }
215 ) {
216 new_nodes.push(Node::Block(std::mem::take(&mut current_segment)));
217 new_nodes.push(b_node);
218 } else {
219 current_segment.push(b_node);
220 }
221 }
222 new_nodes.push(Node::Block(current_segment));
223 } else {
224 new_nodes.push(Node::Block(new_body));
225 }
226 }
227 Node::Region {
228 generator,
229 source_region,
230 body,
231 } => {
232 let new_body = hoist_grid_sync_barriers(body);
233 let has_barrier = new_body.iter().any(|n| {
234 matches!(
235 n,
236 Node::Barrier {
237 ordering: MemoryOrdering::GridSync,
238 ..
239 }
240 )
241 });
242 if has_barrier {
243 let mut current_segment = Vec::new();
244 for b_node in new_body {
245 if matches!(
246 b_node,
247 Node::Barrier {
248 ordering: MemoryOrdering::GridSync,
249 ..
250 }
251 ) {
252 new_nodes.push(Node::Region {
253 generator: generator.clone(),
254 source_region: source_region.clone(),
255 body: Arc::new(std::mem::take(&mut current_segment)),
256 });
257 new_nodes.push(b_node);
258 } else {
259 current_segment.push(b_node);
260 }
261 }
262 new_nodes.push(Node::Region {
263 generator: generator.clone(),
264 source_region: source_region.clone(),
265 body: Arc::new(current_segment),
266 });
267 } else {
268 new_nodes.push(Node::Region {
269 generator: generator.clone(),
270 source_region: source_region.clone(),
271 body: Arc::new(new_body),
272 });
273 }
274 }
275 other => {
276 new_nodes.push(other.clone());
277 }
278 }
279 }
280 new_nodes
281}
282
283fn collect_global_let_bindings(nodes: &[Node], map: &mut std::collections::HashMap<String, Node>) {
284 for node in nodes {
285 match node {
286 Node::Let { name, .. } => {
287 map.insert(name.as_str().to_string(), node.clone());
288 }
289 Node::If {
290 then, otherwise, ..
291 } => {
292 collect_global_let_bindings(then, map);
293 collect_global_let_bindings(otherwise, map);
294 }
295 Node::Loop { body, .. } | Node::Block(body) => {
296 collect_global_let_bindings(body, map);
297 }
298 Node::Region { body, .. } => {
299 collect_global_let_bindings(&body[..], map);
300 }
301 _ => {}
302 }
303 }
304}
305
306fn collect_locally_defined_vars(nodes: &[Node], vars: &mut std::collections::HashSet<String>) {
307 for node in nodes {
308 match node {
309 Node::Let { name, .. } => {
310 vars.insert(name.as_str().to_string());
311 }
312 Node::Loop { var, body, .. } => {
313 vars.insert(var.as_str().to_string());
314 collect_locally_defined_vars(body, vars);
315 }
316 Node::If {
317 then, otherwise, ..
318 } => {
319 collect_locally_defined_vars(then, vars);
320 collect_locally_defined_vars(otherwise, vars);
321 }
322 Node::Block(body) => {
323 collect_locally_defined_vars(body, vars);
324 }
325 Node::Region { body, .. } => {
326 collect_locally_defined_vars(&body[..], vars);
327 }
328 _ => {}
329 }
330 }
331}
332
333fn collect_referenced_vars(expr: &Expr, vars: &mut std::collections::HashSet<String>) {
334 match expr {
335 Expr::Var(name) => {
336 vars.insert(name.as_str().to_string());
337 }
338 Expr::Load { index, .. } => {
339 collect_referenced_vars(index, vars);
340 }
341 Expr::BinOp { left, right, .. } => {
342 collect_referenced_vars(left, vars);
343 collect_referenced_vars(right, vars);
344 }
345 Expr::UnOp { operand, .. } => {
346 collect_referenced_vars(operand, vars);
347 }
348 Expr::Call { args, .. } => {
349 for arg in args {
350 collect_referenced_vars(arg, vars);
351 }
352 }
353 Expr::Select {
354 cond,
355 true_val,
356 false_val,
357 } => {
358 collect_referenced_vars(cond, vars);
359 collect_referenced_vars(true_val, vars);
360 collect_referenced_vars(false_val, vars);
361 }
362 Expr::Cast { value, .. } => {
363 collect_referenced_vars(value, vars);
364 }
365 Expr::Fma { a, b, c } => {
366 collect_referenced_vars(a, vars);
367 collect_referenced_vars(b, vars);
368 collect_referenced_vars(c, vars);
369 }
370 Expr::Atomic {
371 index,
372 expected,
373 value,
374 ..
375 } => {
376 collect_referenced_vars(index, vars);
377 if let Some(expected) = expected {
378 collect_referenced_vars(expected, vars);
379 }
380 collect_referenced_vars(value, vars);
381 }
382 Expr::SubgroupBallot { cond } => {
383 collect_referenced_vars(cond, vars);
384 }
385 Expr::SubgroupShuffle { value, lane } => {
386 collect_referenced_vars(value, vars);
387 collect_referenced_vars(lane, vars);
388 }
389 Expr::SubgroupAdd { value } => {
390 collect_referenced_vars(value, vars);
391 }
392 _ => {}
393 }
394}
395
396fn collect_node_referenced_vars(node: &Node, vars: &mut std::collections::HashSet<String>) {
397 match node {
398 Node::Let { value, .. } => {
399 collect_referenced_vars(value, vars);
400 }
401 Node::Assign { value, .. } => {
402 collect_referenced_vars(value, vars);
403 }
404 Node::Store { index, value, .. } => {
405 collect_referenced_vars(index, vars);
406 collect_referenced_vars(value, vars);
407 }
408 Node::If {
409 cond,
410 then,
411 otherwise,
412 } => {
413 collect_referenced_vars(cond, vars);
414 for n in then {
415 collect_node_referenced_vars(n, vars);
416 }
417 for n in otherwise {
418 collect_node_referenced_vars(n, vars);
419 }
420 }
421 Node::Loop { from, to, body, .. } => {
422 collect_referenced_vars(from, vars);
423 collect_referenced_vars(to, vars);
424 for n in body {
425 collect_node_referenced_vars(n, vars);
426 }
427 }
428 Node::Block(body) => {
429 for n in body {
430 collect_node_referenced_vars(n, vars);
431 }
432 }
433 Node::Region { body, .. } => {
434 for n in body.as_ref() {
435 collect_node_referenced_vars(n, vars);
436 }
437 }
438 Node::AsyncLoad { offset, size, .. } => {
439 collect_referenced_vars(offset, vars);
440 collect_referenced_vars(size, vars);
441 }
442 Node::AsyncStore { offset, size, .. } => {
443 collect_referenced_vars(offset, vars);
444 collect_referenced_vars(size, vars);
445 }
446 Node::Trap { address, .. } => {
447 collect_referenced_vars(address, vars);
448 }
449 _ => {}
450 }
451}
452
453fn resolve_dependencies(
454 name: &str,
455 global_lets: &std::collections::HashMap<String, Node>,
456 resolved_names: &mut std::collections::HashSet<String>,
457 resolved_lets: &mut Vec<Node>,
458) {
459 if resolved_names.contains(name) {
460 return;
461 }
462 if let Some(let_node) = global_lets.get(name) {
463 resolved_names.insert(name.to_string());
464 let mut deps = std::collections::HashSet::new();
465 collect_node_referenced_vars(let_node, &mut deps);
466 for dep in deps {
467 resolve_dependencies(&dep, global_lets, resolved_names, resolved_lets);
468 }
469 resolved_lets.push(let_node.clone());
470 }
471}
472
473fn propagate_let_bindings(segments: &mut [Vec<Node>], hoisted_inner: &[Node]) {
474 let mut global_lets = std::collections::HashMap::new();
475 collect_global_let_bindings(hoisted_inner, &mut global_lets);
476
477 for segment_nodes in segments {
478 let mut locally_defined = std::collections::HashSet::new();
479 collect_locally_defined_vars(segment_nodes, &mut locally_defined);
480
481 let mut referenced = std::collections::HashSet::new();
482 for node in segment_nodes.iter() {
483 collect_node_referenced_vars(node, &mut referenced);
484 }
485
486 let mut free_vars = Vec::new();
487 for name in referenced {
488 if !locally_defined.contains(&name) {
489 free_vars.push(name);
490 }
491 }
492
493 let mut resolved_lets = Vec::new();
494 let mut resolved_names = std::collections::HashSet::new();
495 for name in free_vars {
496 resolve_dependencies(&name, &global_lets, &mut resolved_names, &mut resolved_lets);
497 }
498
499 if !resolved_lets.is_empty() {
500 resolved_lets.extend(std::mem::take(segment_nodes));
501 *segment_nodes = resolved_lets;
502 }
503 }
504}
505
506pub fn try_split_on_grid_sync(program: &Program) -> Result<Vec<Program>, BackendError> {
513 let (wrappers, inner) = peel_entry_wrappers(program);
514 let hoisted_inner = hoist_grid_sync_barriers(inner);
515 let split_count = hoisted_inner
516 .iter()
517 .filter(|node| {
518 matches!(
519 node,
520 Node::Barrier {
521 ordering: MemoryOrdering::GridSync,
522 ..
523 }
524 )
525 })
526 .count();
527 if split_count == 0 {
528 let mut segments = Vec::new();
529 reserve_grid_sync_vec(&mut segments, 1, "grid-sync no-op segment")?;
530 segments.push(program.clone());
531 return Ok(segments);
532 }
533
534 let segment_count = split_count + 1;
535 let executable_nodes = hoisted_inner.len().checked_sub(split_count).ok_or_else(|| {
536 BackendError::InvalidProgram {
537 fix: format!(
538 "grid-sync split_count {split_count} exceeded entry node count {}. Fix: split_on_grid_sync must count barriers from the same entry sequence it segments.",
539 hoisted_inner.len()
540 ),
541 }
542 })?;
543 let segment_capacity = executable_nodes.div_ceil(segment_count);
544
545 let mut raw_segments = Vec::new();
546 let mut current = Vec::new();
547 reserve_grid_sync_vec(&mut current, segment_capacity, "grid-sync current segment")?;
548 for node in &hoisted_inner {
549 match node {
550 Node::Barrier {
551 ordering: MemoryOrdering::GridSync,
552 ..
553 } => {
554 let mut next = Vec::new();
555 reserve_grid_sync_vec(&mut next, segment_capacity, "grid-sync next segment")?;
556 let entry = std::mem::replace(&mut current, next);
557 raw_segments.push(entry);
558 }
559 other => {
560 current.push(other.clone());
561 }
562 }
563 }
564 raw_segments.push(current);
565
566 propagate_let_bindings(&mut raw_segments, &hoisted_inner);
567
568 let mut segments = Vec::new();
569 reserve_grid_sync_vec(
570 &mut segments,
571 raw_segments.len(),
572 "grid-sync split segments",
573 )?;
574 for entry in raw_segments {
575 segments.push(wrap_split_segment(program, &wrappers, entry));
576 }
577 Ok(segments)
578}
579
580fn wrap_split_segment(program: &Program, wrappers: &[EntryWrapper], entry: Vec<Node>) -> Program {
581 let mut wrapped_entry = entry;
585 for wrapper in wrappers.iter().rev() {
586 match wrapper {
587 EntryWrapper::Region { generator } => {
588 wrapped_entry = vec![Node::Region {
589 generator: generator.clone(),
590 source_region: None,
591 body: Arc::new(wrapped_entry),
592 }];
593 }
594 EntryWrapper::Block => {
595 wrapped_entry = vec![Node::Block(wrapped_entry)];
596 }
597 }
598 }
599 program.with_rewritten_entry(wrapped_entry)
600}
601
602pub fn plan_host_grid_sync_segment_programs(
612 program: &Program,
613) -> Result<Vec<Program>, BackendError> {
614 Ok(plan_host_grid_sync_segments(program)?
615 .into_iter()
616 .map(|segment| segment.program)
617 .collect())
618}
619
620fn plan_host_grid_sync_segments(
621 program: &Program,
622) -> Result<Vec<PlannedGridSyncSegment>, BackendError> {
623 let split = try_split_on_grid_sync(program)?;
624 let first_writer = first_writer_segment_per_buffer(&split, program)?;
625 let mut planned = Vec::new();
626 reserve_grid_sync_vec(&mut planned, split.len(), "grid-sync planned host segments")?;
627 for (segment_idx, segment) in split.into_iter().enumerate() {
628 let rewritten = rewrite_segment_buffers_for_host_split(
629 program,
630 &segment,
631 segment_idx,
632 &first_writer,
633 )?;
634 let input_names = segment_input_names(&rewritten)?;
635 let output_names = segment_output_names(&rewritten)?;
636 planned.push(PlannedGridSyncSegment {
637 program: rewritten,
638 input_names,
639 output_names,
640 });
641 }
642 Ok(planned)
643}
644
645fn first_writer_segment_per_buffer(
658 split: &[Program],
659 program: &Program,
660) -> Result<HashMap<Ident, usize>, BackendError> {
661 let mut first_writer: HashMap<Ident, usize> = HashMap::new();
662 reserve_grid_sync_hash_map(
663 &mut first_writer,
664 program.buffers().len(),
665 "grid-sync first-writer map",
666 )?;
667 for (segment_idx, segment) in split.iter().enumerate() {
668 let mut reads = HashSet::new();
669 let mut writes = HashSet::new();
670 reserve_grid_sync_hash_set(
671 &mut reads,
672 program.buffers().len(),
673 "grid-sync first-writer read scan",
674 )?;
675 reserve_grid_sync_hash_set(
676 &mut writes,
677 program.buffers().len(),
678 "grid-sync first-writer write scan",
679 )?;
680 for node in entry_sequence(segment) {
681 collect_segment_buffer_targets(node, &mut reads, &mut writes);
682 }
683 for name in writes {
684 first_writer.entry(name).or_insert(segment_idx);
685 }
686 }
687 Ok(first_writer)
688}
689
/// Builds the buffer declarations for one split segment and returns the
/// segment with those declarations installed.
///
/// * `source` - the unsplit program; its buffer list is the set of candidates.
/// * `segment` - the segment whose entry is scanned for buffer reads/writes.
/// * `segment_idx` - position of `segment` in the split.
/// * `first_writer` - for each buffer, the index of the first segment that writes it.
///
/// Buffers the segment neither reads nor writes are dropped, except
/// `ReadWrite` buffers that are neither outputs nor pipeline live-outs, which
/// are kept as `ReadWrite`. `Workgroup` buffers are copied through unchanged.
/// Every other kept buffer has its access narrowed to what the segment does
/// with it and its output / live-out flags cleared.
///
/// # Errors
///
/// Propagates reservation failures from the `reserve_grid_sync_*` helpers.
fn rewrite_segment_buffers_for_host_split(
    source: &Program,
    segment: &Program,
    segment_idx: usize,
    first_writer: &HashMap<Ident, usize>,
) -> Result<Program, BackendError> {
    // Buffers this segment reads from / writes to.
    let mut reads = HashSet::new();
    let mut writes = HashSet::new();
    reserve_grid_sync_hash_set(
        &mut reads,
        source.buffers().len(),
        "grid-sync segment read set",
    )?;
    reserve_grid_sync_hash_set(
        &mut writes,
        source.buffers().len(),
        "grid-sync segment write set",
    )?;
    for node in entry_sequence(segment) {
        collect_segment_buffer_targets(node, &mut reads, &mut writes);
    }

    let mut buffers = Vec::new();
    reserve_grid_sync_vec(
        &mut buffers,
        source.buffers().len(),
        "grid-sync segment buffers",
    )?;
    for buffer in source.buffers() {
        let name = Ident::from(buffer.name());
        let reads_this = reads.contains(&name);
        let writes_this = writes.contains(&name);
        // An untouched, non-output ReadWrite buffer is still declared on the segment.
        let readwrite_passthrough = matches!(buffer.access(), BufferAccess::ReadWrite)
            && !buffer.is_output()
            && !buffer.is_pipeline_live_out()
            && !reads_this
            && !writes_this;

        // Skip buffers the segment does not touch and that are not passthroughs.
        if !reads_this && !writes_this && !readwrite_passthrough {
            continue;
        }

        let mut rewritten = buffer.clone();
        // Workgroup buffers keep their original declaration.
        if matches!(rewritten.access(), BufferAccess::Workgroup) {
            buffers.push(rewritten);
            continue;
        }

        let is_source_output = buffer.is_output() || buffer.is_pipeline_live_out();
        // True when an earlier segment already wrote this output buffer.
        let earlier_segment_wrote_output = is_source_output
            && first_writer
                .get(&name)
                .is_some_and(|&first| first < segment_idx);

        let access = if readwrite_passthrough {
            BufferAccess::ReadWrite
        } else if earlier_segment_wrote_output && writes_this {
            // Writing an output that an earlier segment already produced is
            // declared ReadWrite rather than WriteOnly.
            BufferAccess::ReadWrite
        } else {
            match (reads_this, writes_this) {
                (true, true) => BufferAccess::ReadWrite,
                (true, false) => BufferAccess::ReadOnly,
                (false, true) => BufferAccess::WriteOnly,
                // Only possible for a passthrough, which is handled above.
                (false, false) => BufferAccess::ReadWrite,
            }
        };
        rewrite_segment_buffer_access(&mut rewritten, access);
        // Segment declarations never carry the source program's output flags.
        rewritten.is_output = false;
        rewritten.pipeline_live_out = false;
        buffers.push(rewritten);
    }

    Ok(segment.with_rewritten_buffers(buffers))
}
781
782fn rewrite_segment_buffer_access(buffer: &mut BufferDecl, access: BufferAccess) {
783 buffer.kind = match &access {
784 BufferAccess::ReadOnly => MemoryKind::Readonly,
785 BufferAccess::Uniform => MemoryKind::Uniform,
786 BufferAccess::Workgroup => MemoryKind::Shared,
787 _ => MemoryKind::Global,
788 };
789 buffer.access = access;
790}
791
792fn segment_input_names(segment: &Program) -> Result<Vec<Ident>, BackendError> {
793 let mut names = Vec::new();
794 reserve_grid_sync_vec(
795 &mut names,
796 segment.buffers().len(),
797 "grid-sync segment input names",
798 )?;
799 for buffer in segment.buffers() {
800 if matches!(buffer.access(), BufferAccess::Workgroup) {
801 continue;
802 }
803 if segment_buffer_consumes_input(buffer) {
804 names.push(Ident::from(buffer.name()));
805 }
806 }
807 Ok(names)
808}
809
810fn segment_output_names(segment: &Program) -> Result<Vec<Ident>, BackendError> {
811 let mut names = Vec::new();
812 reserve_grid_sync_vec(
813 &mut names,
814 segment.buffers().len(),
815 "grid-sync segment output names",
816 )?;
817 for buffer in segment.buffers() {
818 if matches!(buffer.access(), BufferAccess::Workgroup) {
819 continue;
820 }
821 if segment_buffer_produces_output(buffer) {
822 names.push(Ident::from(buffer.name()));
823 }
824 }
825 Ok(names)
826}
827
828fn original_input_names(program: &Program) -> Result<Vec<Ident>, BackendError> {
829 segment_input_names(program)
830}
831
832fn original_output_names(program: &Program) -> Result<Vec<Ident>, BackendError> {
833 segment_output_names(program)
834}
835
836fn segment_buffer_consumes_input(buffer: &BufferDecl) -> bool {
837 if buffer.is_output() || buffer.is_pipeline_live_out() {
838 return false;
839 }
840 matches!(
841 buffer.access(),
842 BufferAccess::ReadOnly | BufferAccess::ReadWrite | BufferAccess::Uniform
843 )
844}
845
846fn segment_buffer_produces_output(buffer: &BufferDecl) -> bool {
847 buffer.is_output()
848 || buffer.is_pipeline_live_out()
849 || matches!(
850 buffer.access(),
851 BufferAccess::ReadWrite | BufferAccess::WriteOnly
852 )
853}
854
855fn collect_segment_buffer_targets(
856 node: &Node,
857 reads: &mut HashSet<Ident>,
858 writes: &mut HashSet<Ident>,
859) {
860 match node {
861 Node::Let { value, .. } | Node::Assign { value, .. } => {
862 collect_segment_expr_targets(value, reads, writes);
863 }
864 Node::Store {
865 buffer,
866 index,
867 value,
868 } => {
869 writes.insert(Ident::from(buffer));
870 collect_segment_expr_targets(index, reads, writes);
871 collect_segment_expr_targets(value, reads, writes);
872 }
873 Node::If {
874 cond,
875 then,
876 otherwise,
877 } => {
878 collect_segment_expr_targets(cond, reads, writes);
879 for child in then.iter().chain(otherwise.iter()) {
880 collect_segment_buffer_targets(child, reads, writes);
881 }
882 }
883 Node::Loop { from, to, body, .. } => {
884 collect_segment_expr_targets(from, reads, writes);
885 collect_segment_expr_targets(to, reads, writes);
886 for child in body {
887 collect_segment_buffer_targets(child, reads, writes);
888 }
889 }
890 Node::Block(body) => {
891 for child in body {
892 collect_segment_buffer_targets(child, reads, writes);
893 }
894 }
895 Node::Region { body, .. } => {
896 for child in body.iter() {
897 collect_segment_buffer_targets(child, reads, writes);
898 }
899 }
900 Node::AllReduce { buffer, .. } | Node::Broadcast { buffer, .. } => {
901 reads.insert(buffer.clone());
902 writes.insert(buffer.clone());
903 }
904 Node::AllGather { input, output, .. } | Node::ReduceScatter { input, output, .. } => {
905 reads.insert(input.clone());
906 writes.insert(output.clone());
907 }
908 Node::IndirectDispatch { .. }
909 | Node::Return
910 | Node::Barrier { .. }
911 | Node::AsyncLoad { .. }
912 | Node::AsyncStore { .. }
913 | Node::AsyncWait { .. }
914 | Node::Trap { .. }
915 | Node::Resume { .. }
916 | Node::Opaque(_) => {}
917 _ => {}
918 }
919}
920
921fn collect_segment_expr_targets(
922 expr: &Expr,
923 reads: &mut HashSet<Ident>,
924 writes: &mut HashSet<Ident>,
925) {
926 match expr {
927 Expr::Load { buffer, index } => {
928 reads.insert(Ident::from(buffer));
929 collect_segment_expr_targets(index, reads, writes);
930 }
931 Expr::Atomic {
932 buffer,
933 index,
934 expected,
935 value,
936 ..
937 } => {
938 let name = Ident::from(buffer);
939 reads.insert(name.clone());
940 writes.insert(name);
941 collect_segment_expr_targets(index, reads, writes);
942 if let Some(expected) = expected {
943 collect_segment_expr_targets(expected, reads, writes);
944 }
945 collect_segment_expr_targets(value, reads, writes);
946 }
947 Expr::BinOp { left, right, .. } => {
948 collect_segment_expr_targets(left, reads, writes);
949 collect_segment_expr_targets(right, reads, writes);
950 }
951 Expr::UnOp { operand, .. } | Expr::Cast { value: operand, .. } => {
952 collect_segment_expr_targets(operand, reads, writes);
953 }
954 Expr::Fma { a, b, c } => {
955 collect_segment_expr_targets(a, reads, writes);
956 collect_segment_expr_targets(b, reads, writes);
957 collect_segment_expr_targets(c, reads, writes);
958 }
959 Expr::Call { args, .. } => {
960 for arg in args {
961 collect_segment_expr_targets(arg, reads, writes);
962 }
963 }
964 Expr::Select {
965 cond,
966 true_val,
967 false_val,
968 } => {
969 collect_segment_expr_targets(cond, reads, writes);
970 collect_segment_expr_targets(true_val, reads, writes);
971 collect_segment_expr_targets(false_val, reads, writes);
972 }
973 Expr::SubgroupBallot { cond } => collect_segment_expr_targets(cond, reads, writes),
974 Expr::SubgroupShuffle { value, lane } => {
975 collect_segment_expr_targets(value, reads, writes);
976 collect_segment_expr_targets(lane, reads, writes);
977 }
978 Expr::SubgroupAdd { value } => collect_segment_expr_targets(value, reads, writes),
979 _ => {}
980 }
981}
982
983pub fn dispatch_with_grid_sync_split(
1003 backend: &dyn VyreBackend,
1004 program: &Program,
1005 inputs: &[&[u8]],
1006 config: &DispatchConfig,
1007) -> Result<Vec<Vec<u8>>, BackendError> {
1008 let mut outputs = Vec::new();
1009 reserve_grid_sync_vec(
1010 &mut outputs,
1011 program.output_buffer_indices().len().max(1),
1012 "grid-sync final outputs",
1013 )?;
1014 dispatch_with_grid_sync_split_into(backend, program, inputs, config, &mut outputs)?;
1015 Ok(outputs)
1016}
1017
1018pub fn dispatch_with_grid_sync_split_timed(
1023 backend: &dyn VyreBackend,
1024 program: &Program,
1025 inputs: &[&[u8]],
1026 config: &DispatchConfig,
1027) -> Result<TimedDispatchResult, BackendError> {
1028 let started = std::time::Instant::now();
1029 let outputs = dispatch_with_grid_sync_split(backend, program, inputs, config)?;
1030 Ok(TimedDispatchResult {
1031 outputs,
1032 wall_ns: elapsed_wall_ns(started)?,
1033 device_ns: None,
1034 enqueue_ns: None,
1035 wait_ns: None,
1036 })
1037}
1038
1039pub fn dispatch_resident_with_grid_sync_split_timed(
1049 backend: &dyn VyreBackend,
1050 program: &Program,
1051 resources: &[Resource],
1052 config: &DispatchConfig,
1053) -> Result<TimedDispatchResult, BackendError> {
1054 if !contains_grid_sync(program) {
1061 return backend.dispatch_resident_timed(program, resources, config);
1062 }
1063 let segments = try_split_on_grid_sync(program)?;
1064 if segments.is_empty() {
1065 return Err(BackendError::InvalidProgram {
1066 fix: "Fix: program contains GridSync barrier but split_on_grid_sync produced 0 \
1067 segments. This is a grid_sync invariant bug - split_on_grid_sync must \
1068 always return at least one segment."
1069 .to_string(),
1070 });
1071 }
1072 let started = std::time::Instant::now();
1073 let mut final_outputs = Vec::new();
1074 let mut device_ns = Some(0_u64);
1075 let mut enqueue_ns = Some(0_u64);
1076 let mut wait_ns = Some(0_u64);
1077 for (segment_idx, segment) in segments.iter().enumerate() {
1078 let timed = backend
1079 .dispatch_resident_timed(segment, resources, config)
1080 .map_err(|error| grid_sync_segment_error(error, segment_idx, segments.len()))?;
1081 if segment_idx + 1 == segments.len() {
1082 final_outputs = timed.outputs;
1083 }
1084 device_ns = sum_optional_timing(device_ns, timed.device_ns, "device timing")?;
1085 enqueue_ns = sum_optional_timing(enqueue_ns, timed.enqueue_ns, "enqueue timing")?;
1086 wait_ns = sum_optional_timing(wait_ns, timed.wait_ns, "wait timing")?;
1087 }
1088 Ok(TimedDispatchResult {
1089 outputs: final_outputs,
1090 wall_ns: elapsed_wall_ns(started)?,
1091 device_ns,
1092 enqueue_ns,
1093 wait_ns,
1094 })
1095}
1096
1097fn elapsed_wall_ns(started: std::time::Instant) -> Result<u64, BackendError> {
1098 u64::try_from(started.elapsed().as_nanos()).map_err(|error| BackendError::InvalidProgram {
1099 fix: format!(
1100 "Fix: grid-sync segmented wall timing cannot fit u64 nanoseconds: {error}. Split telemetry windows or report per-segment timing."
1101 ),
1102 })
1103}
1104
1105fn sum_optional_timing(
1106 accumulator: Option<u64>,
1107 next: Option<u64>,
1108 field: &'static str,
1109) -> Result<Option<u64>, BackendError> {
1110 match (accumulator, next) {
1111 (Some(left), Some(right)) => Ok(Some(left.checked_add(right).ok_or_else(|| {
1112 BackendError::InvalidProgram {
1113 fix: format!(
1114 "Fix: grid-sync segmented {field} overflowed u64 nanoseconds. Split telemetry windows or report per-segment timing instead of silently clamping."
1115 ),
1116 }
1117 })?)),
1118 _ => Ok(None),
1119 }
1120}
1121
1122pub fn dispatch_with_grid_sync_split_into(
1128 backend: &dyn VyreBackend,
1129 program: &Program,
1130 inputs: &[&[u8]],
1131 config: &DispatchConfig,
1132 outputs: &mut OutputBuffers,
1133) -> Result<(), BackendError> {
1134 if !contains_grid_sync(program) {
1141 return backend.dispatch_borrowed_into(program, inputs, config, outputs);
1142 }
1143 let segments = plan_host_grid_sync_segments(program)?;
1144 if segments.is_empty() {
1145 return Err(BackendError::InvalidProgram {
1146 fix: "Fix: program contains GridSync barrier but split_on_grid_sync produced 0 \
1147 segments. This is a grid_sync invariant bug - split_on_grid_sync must \
1148 always return at least one segment."
1149 .to_string(),
1150 });
1151 }
1152 crate::observability::record_grid_sync_split(segments.len());
1153 let initial_input_names = original_input_names(program)?;
1160 if inputs.len() != initial_input_names.len() {
1161 return Err(BackendError::InvalidProgram {
1162 fix: format!(
1163 "Fix: grid-sync split expected {} initial input buffer(s) but received {}. Rebuild the dispatch inputs from the Program buffer declarations before splitting.",
1164 initial_input_names.len(),
1165 inputs.len()
1166 ),
1167 });
1168 }
1169 let mut current_inputs: HashMap<Ident, GridSyncInput<'_>> = HashMap::new();
1170 reserve_grid_sync_hash_map(
1171 &mut current_inputs,
1172 program.buffers().len(),
1173 "grid-sync rotating input map",
1174 )?;
1175 for (name, bytes) in initial_input_names.into_iter().zip(inputs.iter().copied()) {
1176 current_inputs.insert(name, GridSyncInput::Borrowed(bytes));
1177 }
1178 let mut segment_outputs = Vec::new();
1179 reserve_grid_sync_vec(
1180 &mut segment_outputs,
1181 outputs.capacity().max(1),
1182 "grid-sync intermediate outputs",
1183 )?;
1184 let final_output_names = original_output_names(program)?;
1185
1186 let iterations = crate::fixpoint_iterations::resolve_fixpoint_iterations(
1208 config,
1209 "grid-sync split",
1210 )?;
1211 let mut segment_config = config.clone();
1212 segment_config.fixpoint_iterations = Some(1);
1213
1214 let mut prev_fingerprint: Option<u64> = None;
1222 for _ in 0..iterations {
1223 for (segment_idx, segment) in segments.iter().enumerate() {
1224 let borrowed = borrowed_grid_sync_inputs_by_name(segment, ¤t_inputs)?;
1225 backend
1226 .dispatch_borrowed_into(
1227 &segment.program,
1228 borrowed.as_slice(),
1229 &segment_config,
1230 &mut segment_outputs,
1231 )
1232 .map_err(|error| grid_sync_segment_error(error, segment_idx, segments.len()))?;
1233 drop(borrowed);
1234 refresh_named_outputs(segment, &mut segment_outputs, &mut current_inputs)?;
1235 }
1236 let fingerprint = owned_accumulator_fingerprint(¤t_inputs);
1237 if prev_fingerprint == Some(fingerprint) {
1238 break;
1239 }
1240 prev_fingerprint = Some(fingerprint);
1241 }
1242 collect_final_named_outputs(&final_output_names, &mut current_inputs, outputs)?;
1243 Ok(())
1244}
1245
1246pub fn dispatch_resident_grid_sync_fixpoint_into(
1284 backend: &dyn VyreBackend,
1285 program: &Program,
1286 inputs: &[&[u8]],
1287 config: &DispatchConfig,
1288 outputs: &mut OutputBuffers,
1289) -> Result<(), BackendError> {
1290 if !contains_grid_sync(program) {
1297 return backend.dispatch_borrowed_into(program, inputs, config, outputs);
1298 }
1299 let segments = try_split_on_grid_sync(program)?;
1300 if segments.is_empty() {
1301 return Err(BackendError::InvalidProgram {
1302 fix: "Fix: program contains GridSync barrier but split_on_grid_sync produced 0 \
1303 segments. This is a grid_sync invariant bug - split_on_grid_sync must \
1304 always return at least one segment."
1305 .to_string(),
1306 });
1307 }
1308 crate::observability::record_grid_sync_split(segments.len());
1309
1310 let resident = allocate_resident_program_resources(backend, program, inputs)?;
1314 let result =
1315 run_resident_grid_sync_fixpoint(backend, program, &segments, &resident, config, outputs);
1316 let free_result = free_resident_program_resources(backend, resident);
1318 result.and(free_result)
1319}
1320
/// Resident resources allocated for one program by
/// `allocate_resident_program_resources`.
struct ResidentProgramResources {
    /// One resource per non-`Shared` binding, in binding-plan order.
    ordered: Vec<Resource>,
    /// Binding name -> (resource, byte length). The byte length is the value
    /// computed by `resident_binding_byte_len`, before the allocation size is
    /// rounded up to at least one element.
    by_name: HashMap<Ident, (Resource, usize)>,
}
1332
1333fn allocate_resident_program_resources(
1339 backend: &dyn VyreBackend,
1340 program: &Program,
1341 inputs: &[&[u8]],
1342) -> Result<ResidentProgramResources, BackendError> {
1343 let plan = BindingPlan::from_borrowed_inputs(program, inputs)?;
1344 let mut ordered = Vec::new();
1345 reserve_grid_sync_vec(&mut ordered, plan.bindings.len(), "resident grid-sync resources")?;
1346 let mut by_name = HashMap::new();
1347 reserve_grid_sync_hash_map(
1348 &mut by_name,
1349 plan.bindings.len(),
1350 "resident grid-sync resource name map",
1351 )?;
1352 for binding in &plan.bindings {
1353 if binding.role == BindingRole::Shared {
1354 continue;
1355 }
1356 let byte_len = resident_binding_byte_len(binding, inputs)?;
1365 let alloc_len = byte_len.max(binding.element_size.max(1));
1366 let resource = backend.allocate_resident(alloc_len)?;
1367 match binding.input_index {
1371 Some(index) if !inputs.get(index).copied().unwrap_or(&[]).is_empty() => {
1372 let bytes = inputs[index];
1373 backend.upload_resident(&resource, bytes)?;
1374 }
1375 _ => {
1376 let zeros = zeroed_upload_buffer(alloc_len)?;
1377 backend.upload_resident(&resource, &zeros)?;
1378 }
1379 }
1380 by_name.insert(
1381 Ident::from(binding.name.as_ref()),
1382 (resource.clone(), byte_len),
1383 );
1384 ordered.push(resource);
1385 }
1386 Ok(ResidentProgramResources { ordered, by_name })
1387}
1388
1389fn resident_binding_byte_len(
1392 binding: &Binding,
1393 inputs: &[&[u8]],
1394) -> Result<usize, BackendError> {
1395 if let Some(index) = binding.input_index {
1396 if let Some(bytes) = inputs.get(index) {
1397 return Ok(bytes.len());
1398 }
1399 }
1400 binding.static_byte_len.ok_or_else(|| BackendError::InvalidProgram {
1401 fix: format!(
1402 "Fix: resident grid-sync output buffer `{}` has no static byte length; dynamic-sized outputs are not supported on the resident grid-sync path. Declare a fixed `count` on the buffer or route this program through dispatch_with_grid_sync_split_into.",
1403 binding.name
1404 ),
1405 })
1406}
1407
1408fn zeroed_upload_buffer(byte_len: usize) -> Result<Vec<u8>, BackendError> {
1411 let mut zeros = Vec::new();
1412 crate::allocation::try_reserve_vec_to_capacity(&mut zeros, byte_len).map_err(|error| {
1413 BackendError::InvalidProgram {
1414 fix: format!(
1415 "Fix: failed to reserve a {byte_len}-byte zero-init staging buffer for a resident grid-sync output: {error}. Shard the program into smaller buffers."
1416 ),
1417 }
1418 })?;
1419 zeros.resize(byte_len, 0);
1420 Ok(zeros)
1421}
1422
1423fn run_resident_grid_sync_fixpoint(
1427 backend: &dyn VyreBackend,
1428 program: &Program,
1429 segments: &[Program],
1430 resident: &ResidentProgramResources,
1431 config: &DispatchConfig,
1432 outputs: &mut OutputBuffers,
1433) -> Result<(), BackendError> {
1434 let iterations =
1435 crate::fixpoint_iterations::resolve_fixpoint_iterations(config, "resident grid-sync split")?;
1436 let repeat_count = u32::try_from(iterations).map_err(|error| BackendError::InvalidProgram {
1437 fix: format!(
1438 "Fix: resident grid-sync fixpoint iteration count {iterations} does not fit u32: {error}."
1439 ),
1440 })?;
1441
1442 let mut steps = Vec::new();
1445 reserve_grid_sync_vec(&mut steps, segments.len(), "resident grid-sync steps")?;
1446 for segment in segments {
1447 steps.push(ResidentDispatchStep {
1448 program: segment,
1449 resources: resident.ordered.as_slice(),
1450 grid_override: config.grid_override,
1451 workgroup_override: config.workgroup_override,
1455 });
1456 }
1457
1458 let output_names = original_output_names(program)?;
1461 let mut read_ranges = Vec::new();
1462 reserve_grid_sync_vec(&mut read_ranges, output_names.len(), "resident grid-sync read ranges")?;
1463 for name in &output_names {
1464 let (resource, byte_len) =
1465 resident.by_name.get(name).ok_or_else(|| BackendError::InvalidProgram {
1466 fix: format!(
1467 "Fix: resident grid-sync final output `{name}` has no resident resource; it was not declared as a non-shared program buffer."
1468 ),
1469 })?;
1470 read_ranges.push(ResidentReadRange {
1471 resource,
1472 byte_offset: 0,
1473 byte_len: *byte_len,
1474 });
1475 }
1476
1477 while outputs.len() < output_names.len() {
1480 outputs.push(Vec::new());
1481 }
1482 outputs.truncate(output_names.len());
1483 for slot in outputs.iter_mut() {
1484 slot.clear();
1485 }
1486 let mut output_refs: Vec<&mut Vec<u8>> = outputs.iter_mut().collect();
1487
1488 backend.dispatch_resident_repeated_sequence_read_ranges_into(
1489 &[],
1490 &steps,
1491 repeat_count,
1492 &read_ranges,
1493 output_refs.as_mut_slice(),
1494 )
1495}
1496
1497fn free_resident_program_resources(
1501 backend: &dyn VyreBackend,
1502 resident: ResidentProgramResources,
1503) -> Result<(), BackendError> {
1504 let ResidentProgramResources { ordered, by_name } = resident;
1505 drop(by_name);
1508 let mut first_error: Option<BackendError> = None;
1509 for resource in ordered {
1510 if let Err(error) = backend.free_resident(resource) {
1511 if first_error.is_none() {
1512 first_error = Some(error);
1513 }
1514 }
1515 }
1516 match first_error {
1517 Some(error) => Err(error),
1518 None => Ok(()),
1519 }
1520}
1521
1522fn reserve_grid_sync_vec<T>(
1523 vec: &mut Vec<T>,
1524 capacity: usize,
1525 field: &'static str,
1526) -> Result<(), BackendError> {
1527 crate::allocation::try_reserve_vec_to_capacity(vec, capacity).map_err(|error| {
1528 BackendError::InvalidProgram {
1529 fix: format!(
1530 "Fix: failed to reserve {field} for {capacity} entries during grid-sync dispatch splitting: {error}. Split the program into fewer grid-sync segments or run on a backend with native grid sync."
1531 ),
1532 }
1533 })
1534}
1535
1536fn reserve_grid_sync_hash_map<K, V>(
1537 map: &mut HashMap<K, V>,
1538 capacity: usize,
1539 field: &'static str,
1540) -> Result<(), BackendError>
1541where
1542 K: Eq + std::hash::Hash,
1543{
1544 map.try_reserve(capacity)
1545 .map_err(|error| BackendError::InvalidProgram {
1546 fix: format!(
1547 "Fix: failed to reserve {field} for {capacity} entries during grid-sync dispatch splitting: {error}. Split the program into fewer grid-sync segments or run on a backend with native grid sync."
1548 ),
1549 })
1550}
1551
1552fn reserve_grid_sync_hash_set<T>(
1553 set: &mut HashSet<T>,
1554 capacity: usize,
1555 field: &'static str,
1556) -> Result<(), BackendError>
1557where
1558 T: Eq + std::hash::Hash,
1559{
1560 set.try_reserve(capacity)
1561 .map_err(|error| BackendError::InvalidProgram {
1562 fix: format!(
1563 "Fix: failed to reserve {field} for {capacity} entries during grid-sync dispatch splitting: {error}. Split the program into fewer grid-sync segments or run on a backend with native grid sync."
1564 ),
1565 })
1566}
1567
1568fn borrowed_grid_sync_inputs<'a>(
1569 inputs: &'a [GridSyncInput<'a>],
1570) -> Result<SmallVec<[&'a [u8]; 8]>, BackendError> {
1571 let mut borrowed = SmallVec::<[&[u8]; 8]>::new();
1572 borrowed.try_reserve(inputs.len()).map_err(|error| {
1573 BackendError::InvalidProgram {
1574 fix: format!(
1575 "Fix: failed to reserve grid-sync borrowed input slices for {} input(s): {error}. Split the program into fewer grid-sync live buffers or run on a backend with native grid sync.",
1576 inputs.len()
1577 ),
1578 }
1579 })?;
1580 borrowed.extend(inputs.iter().map(GridSyncInput::as_slice));
1581 Ok(borrowed)
1582}
1583
1584fn borrowed_grid_sync_inputs_by_name<'a>(
1585 segment: &PlannedGridSyncSegment,
1586 inputs: &'a HashMap<Ident, GridSyncInput<'a>>,
1587) -> Result<SmallVec<[&'a [u8]; 8]>, BackendError> {
1588 let mut borrowed = SmallVec::<[&[u8]; 8]>::new();
1589 borrowed
1590 .try_reserve(segment.input_names.len())
1591 .map_err(|error| BackendError::InvalidProgram {
1592 fix: format!(
1593 "Fix: failed to reserve grid-sync borrowed input slices for {} segment input(s): {error}. Split the program into fewer grid-sync live buffers or run on a backend with native grid sync.",
1594 segment.input_names.len()
1595 ),
1596 })?;
1597 for name in &segment.input_names {
1598 let input = inputs.get(name).ok_or_else(|| BackendError::InvalidProgram {
1599 fix: format!(
1600 "Fix: grid-sync segment input `{name}` has no bytes from caller input or a prior segment output. Ensure every cross-segment read is written before the GridSync barrier."
1601 ),
1602 })?;
1603 borrowed.push(input.as_slice());
1604 }
1605 Ok(borrowed)
1606}
1607
1608fn owned_accumulator_fingerprint(inputs: &HashMap<Ident, GridSyncInput<'_>>) -> u64 {
1620 const FNV_OFFSET: u64 = 0xcbf2_9ce4_8422_2325;
1621 const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;
1622 let mut combined: u64 = 0;
1623 for (name, input) in inputs {
1624 let GridSyncInput::Owned(bytes) = input else {
1625 continue;
1626 };
1627 let mut hash = FNV_OFFSET;
1628 for byte in name.as_str().as_bytes() {
1629 hash ^= u64::from(*byte);
1630 hash = hash.wrapping_mul(FNV_PRIME);
1631 }
1632 hash ^= 0xff;
1634 hash = hash.wrapping_mul(FNV_PRIME);
1635 for byte in bytes.iter() {
1636 hash ^= u64::from(*byte);
1637 hash = hash.wrapping_mul(FNV_PRIME);
1638 }
1639 combined ^= hash;
1640 }
1641 combined
1642}
1643
1644fn grid_sync_segment_error(
1645 error: BackendError,
1646 segment_idx: usize,
1647 segment_count: usize,
1648) -> BackendError {
1649 match error {
1650 BackendError::InvalidProgram { fix } => BackendError::InvalidProgram {
1651 fix: format!(
1652 "Fix: grid-sync split segment {segment_idx} of {segment_count} dispatch failed: {fix}"
1653 ),
1654 },
1655 other => other,
1656 }
1657}
1658
/// Bytes fed to a grid-sync segment: either borrowed from elsewhere or owned
/// by the splitter.
enum GridSyncInput<'a> {
    /// A slice borrowed for `'a`; it is never written through.
    Borrowed(&'a [u8]),
    /// A buffer owned by this value.
    Owned(Vec<u8>),
}
1663
1664impl GridSyncInput<'_> {
1665 fn as_slice(&self) -> &[u8] {
1666 match self {
1667 Self::Borrowed(bytes) => bytes,
1668 Self::Owned(bytes) => bytes.as_slice(),
1669 }
1670 }
1671
1672 fn refresh_from_output(&mut self, bytes: &mut Vec<u8>) -> Result<(), BackendError> {
1673 match self {
1674 Self::Borrowed(_) => {
1675 let mut owned = Vec::new();
1676 reserve_grid_sync_vec(&mut owned, bytes.len(), "grid-sync readwrite input")?;
1677 owned.extend_from_slice(bytes);
1678 *self = Self::Owned(owned);
1679 }
1680 Self::Owned(owned) => {
1681 std::mem::swap(owned, bytes);
1682 }
1683 }
1684 Ok(())
1685 }
1686}
1687
1688fn refresh_named_outputs<'a>(
1689 segment: &PlannedGridSyncSegment,
1690 outputs: &mut Vec<Vec<u8>>,
1691 inputs: &mut HashMap<Ident, GridSyncInput<'a>>,
1692) -> Result<(), BackendError> {
1693 if outputs.len() != segment.output_names.len() {
1694 return Err(BackendError::InvalidProgram {
1695 fix: format!(
1696 "Fix: grid-sync split segment produced {} output slot(s) but the planned buffer map expected {}. Preserve segment output declaration order when dispatching split kernels.",
1697 outputs.len(),
1698 segment.output_names.len()
1699 ),
1700 });
1701 }
1702 for (name, bytes) in segment.output_names.iter().cloned().zip(outputs.iter_mut()) {
1703 match inputs.get_mut(&name) {
1704 Some(slot) => slot.refresh_from_output(bytes)?,
1705 None => {
1706 let mut owned = GridSyncInput::Owned(Vec::new());
1707 owned.refresh_from_output(bytes)?;
1708 inputs.insert(name, owned);
1709 }
1710 }
1711 }
1712 for output in outputs {
1713 output.clear();
1714 }
1715 Ok(())
1716}
1717
1718fn collect_final_named_outputs<'a>(
1719 final_output_names: &[Ident],
1720 inputs: &mut HashMap<Ident, GridSyncInput<'a>>,
1721 outputs: &mut OutputBuffers,
1722) -> Result<(), BackendError> {
1723 let mut final_outputs = Vec::new();
1724 reserve_grid_sync_vec(
1725 &mut final_outputs,
1726 final_output_names.len(),
1727 "grid-sync final named outputs",
1728 )?;
1729 for name in final_output_names {
1730 let output = inputs
1731 .remove(name)
1732 .ok_or_else(|| BackendError::InvalidProgram {
1733 fix: format!(
1734 "Fix: grid-sync final output `{name}` was not produced by any split segment."
1735 ),
1736 })?;
1737 match output {
1738 GridSyncInput::Owned(bytes) => final_outputs.push(bytes),
1739 GridSyncInput::Borrowed(bytes) => {
1740 let mut owned = Vec::new();
1741 reserve_grid_sync_vec(&mut owned, bytes.len(), "grid-sync borrowed final output")?;
1742 owned.extend_from_slice(bytes);
1743 final_outputs.push(owned);
1744 }
1745 }
1746 }
1747 crate::replace_output_buffers_preserving_slots(final_outputs, outputs);
1748 Ok(())
1749}
1750
1751fn refresh_readwrite_inputs(
1758 segment: &Program,
1759 outputs: &mut Vec<Vec<u8>>,
1760 inputs: &mut [GridSyncInput<'_>],
1761) -> Result<(), BackendError> {
1762 use vyre_foundation::ir::BufferAccess;
1763 let mut input_idx = 0usize;
1769 let mut output_idx = 0usize;
1770 for buffer in segment.buffers() {
1771 if matches!(buffer.access(), BufferAccess::Workgroup) {
1772 continue;
1773 }
1774 let is_output_buffer = buffer.is_output();
1775 let is_readwrite = matches!(buffer.access(), BufferAccess::ReadWrite);
1776
1777 if is_readwrite && !is_output_buffer {
1781 if let (Some(slot), Some(bytes)) =
1782 (inputs.get_mut(input_idx), outputs.get_mut(output_idx))
1783 {
1784 slot.refresh_from_output(bytes)?;
1785 }
1786 }
1787
1788 if !is_output_buffer {
1790 input_idx += 1;
1791 }
1792 if is_readwrite {
1795 output_idx += 1;
1796 }
1797 }
1798 for output in outputs {
1799 output.clear();
1800 }
1801 Ok(())
1802}
1803
1804#[cfg(test)]
1805mod tests {
1806 use super::*;
1807 use std::sync::atomic::{AtomicUsize, Ordering};
1808 use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType, Expr};
1809
1810 fn buffer() -> BufferDecl {
1811 BufferDecl::storage("buf", 0, BufferAccess::ReadWrite, DataType::U32).with_count(4)
1812 }
1813
1814 fn region(generator: &str, body: Vec<Node>) -> Node {
1815 Node::Region {
1816 generator: Ident::from(generator),
1817 source_region: None,
1818 body: Arc::new(body),
1819 }
1820 }
1821
    /// Source-level guard: scans the production part of this file for snippets
    /// that must be present and snippets that must be absent.
    #[test]
    fn grid_sync_release_paths_use_fallible_split_storage() {
        let source = include_str!("grid_sync.rs");
        // Only the text before the first test-module attribute is inspected,
        // so the forbidden snippets may still appear inside the tests.
        let production = source
            .split("#[cfg(test)]")
            .next()
            .expect("Fix: grid-sync production source must precede tests");

        // The fallible splitting entry point and reservation helper must exist.
        assert!(
            production.contains("pub fn try_split_on_grid_sync")
                && production.contains("fn reserve_grid_sync_vec")
                && production.contains("try_reserve_vec_to_capacity"),
            "Fix: grid-sync splitting must expose fallible segment/input/output scratch reservation."
        );
        // Dispatch paths must call the fallible splitter, not the legacy one.
        assert!(
            production.contains("let segments = try_split_on_grid_sync(program)?")
                && !production.contains("let segments = split_on_grid_sync(program);"),
            "Fix: production grid-sync dispatch paths must use fallible splitting, not the legacy infallible helper."
        );
        // No infallible pre-sized vector allocation in production code.
        assert!(
            !production.contains("Vec::with_capacity"),
            "Fix: production grid-sync splitting must not allocate dispatch scratch infallibly."
        );
        // No truncating nanosecond cast and no overflow panic message.
        assert!(
            !production.contains(".as_nanos() as u64")
                && !production.contains("segmented timing overflowed u64"),
            "Fix: production grid-sync timing telemetry must return typed errors instead of truncating or panicking."
        );
    }
1851
1852 fn inner_len(program: &Program) -> usize {
1854 entry_sequence(program).len()
1855 }
1856
1857 #[test]
1858 fn no_grid_sync_returns_single_segment() {
1859 let program = Program::wrapped(
1860 vec![buffer()],
1861 [1, 1, 1],
1862 vec![region(
1863 "a",
1864 vec![Node::store("buf", Expr::u32(0), Expr::u32(1))],
1865 )],
1866 );
1867 assert!(!contains_grid_sync(&program));
1868 let segments = split_on_grid_sync(&program);
1869 assert_eq!(segments.len(), 1);
1870 assert_eq!(inner_len(&segments[0]), 1);
1872 }
1873
1874 #[test]
1875 fn one_grid_sync_splits_into_two() {
1876 let program = Program::wrapped(
1877 vec![buffer()],
1878 [1, 1, 1],
1879 vec![
1880 region("a", vec![Node::store("buf", Expr::u32(0), Expr::u32(1))]),
1881 Node::barrier_with_ordering(MemoryOrdering::GridSync),
1882 region("b", vec![Node::store("buf", Expr::u32(1), Expr::u32(2))]),
1883 ],
1884 );
1885 assert!(contains_grid_sync(&program));
1886 let segments = split_on_grid_sync(&program);
1887 assert_eq!(segments.len(), 2);
1888 assert_eq!(inner_len(&segments[0]), 1);
1889 assert_eq!(inner_len(&segments[1]), 1);
1890 }
1891
1892 #[test]
1893 fn block_nested_grid_sync_splits_into_two() {
1894 let program = Program::wrapped(
1895 vec![buffer()],
1896 [1, 1, 1],
1897 vec![Node::Block(vec![
1898 region("a", vec![Node::store("buf", Expr::u32(0), Expr::u32(1))]),
1899 Node::barrier_with_ordering(MemoryOrdering::GridSync),
1900 region("b", vec![Node::store("buf", Expr::u32(1), Expr::u32(2))]),
1901 ])],
1902 );
1903 assert!(contains_grid_sync(&program));
1904 let segments = split_on_grid_sync(&program);
1905 assert_eq!(segments.len(), 2);
1906 assert_eq!(inner_len(&segments[0]), 1);
1907 assert_eq!(inner_len(&segments[1]), 1);
1908 }
1909
1910 #[test]
1911 fn three_grid_syncs_split_into_four() {
1912 let program = Program::wrapped(
1913 vec![buffer()],
1914 [1, 1, 1],
1915 vec![
1916 region("a", vec![Node::Return]),
1917 Node::barrier_with_ordering(MemoryOrdering::GridSync),
1918 region("b", vec![Node::Return]),
1919 Node::barrier_with_ordering(MemoryOrdering::GridSync),
1920 region("c", vec![Node::Return]),
1921 Node::barrier_with_ordering(MemoryOrdering::GridSync),
1922 region("d", vec![Node::Return]),
1923 ],
1924 );
1925 let segments = split_on_grid_sync(&program);
1926 assert_eq!(segments.len(), 4);
1927 }
1928
1929 #[test]
1930 fn workgroup_barrier_does_not_split() {
1931 let program = Program::wrapped(
1932 vec![buffer()],
1933 [1, 1, 1],
1934 vec![
1935 region("a", vec![Node::Return]),
1936 Node::barrier_with_ordering(MemoryOrdering::SeqCst),
1937 region("b", vec![Node::Return]),
1938 ],
1939 );
1940 assert!(!contains_grid_sync(&program));
1941 let segments = split_on_grid_sync(&program);
1942 assert_eq!(segments.len(), 1);
1943 assert_eq!(inner_len(&segments[0]), 3);
1945 }
1946
1947 #[test]
1948 fn buffers_and_workgroup_size_propagate_to_each_segment() {
1949 let program = Program::wrapped(
1950 vec![buffer()],
1951 [256, 1, 1],
1952 vec![
1953 region("a", vec![Node::Return]),
1954 Node::barrier_with_ordering(MemoryOrdering::GridSync),
1955 region("b", vec![Node::Return]),
1956 ],
1957 );
1958 let segments = split_on_grid_sync(&program);
1959 for seg in &segments {
1960 assert_eq!(seg.workgroup_size(), [256, 1, 1]);
1961 assert_eq!(seg.buffers().len(), 1);
1962 assert_eq!(seg.buffers()[0].name(), "buf");
1963 }
1964 }
1965
    /// Checks allocation reuse in `refresh_readwrite_inputs` by comparing raw
    /// buffer addresses: the first refresh copies out of the output slot
    /// (promoting the borrowed input), the second swaps allocations with it.
    #[test]
    fn refresh_readwrite_inputs_swaps_owned_buffers_after_first_segment() {
        let segment = Program::wrapped(vec![buffer()], [1, 1, 1], vec![Node::Return]);
        let initial = [1u8, 0, 0, 0];
        let mut inputs = [GridSyncInput::Borrowed(initial.as_slice())];
        let mut outputs = vec![Vec::with_capacity(8)];
        // Address of the output slot's allocation before any refresh.
        let output_ptr = outputs[0].as_ptr() as usize;
        outputs[0].extend_from_slice(&[2, 0, 0, 0]);

        refresh_readwrite_inputs(&segment, &mut outputs, &mut inputs)
            .expect("Fix: test readwrite refresh should fit borrowed promotion storage");

        // First refresh: the borrowed input becomes owned and holds the output bytes.
        let first_owned_ptr = match &inputs[0] {
            GridSyncInput::Owned(bytes) => {
                assert_eq!(bytes, &[2, 0, 0, 0]);
                bytes.as_ptr() as usize
            }
            GridSyncInput::Borrowed(_) => panic!("ReadWrite input must become owned after refresh"),
        };
        // The output slot kept its own allocation and was only cleared.
        assert_eq!(outputs[0].as_ptr() as usize, output_ptr);
        assert!(outputs[0].is_empty());

        outputs[0].extend_from_slice(&[3, 0, 0, 0]);
        let second_output_ptr = outputs[0].as_ptr() as usize;
        refresh_readwrite_inputs(&segment, &mut outputs, &mut inputs)
            .expect("Fix: test readwrite refresh should reuse owned storage");

        // Second refresh: the owned input now holds the output's allocation.
        match &inputs[0] {
            GridSyncInput::Owned(bytes) => {
                assert_eq!(bytes, &[3, 0, 0, 0]);
                assert_eq!(
                    bytes.as_ptr() as usize,
                    second_output_ptr,
                    "owned ReadWrite input should take the backend output allocation instead of copying"
                );
            }
            GridSyncInput::Borrowed(_) => panic!("ReadWrite input must remain owned"),
        }
        // ...and the output slot received the input's previous allocation.
        assert_eq!(
            outputs[0].as_ptr() as usize,
            first_owned_ptr,
            "backend output slot should receive the previous owned input allocation for reuse"
        );
    }
2010
    /// Test backend that echoes its first input into output slot 0 and can
    /// optionally assert the addresses of the output storage it is handed.
    struct ReuseCheckingBackend {
        /// Number of `dispatch_borrowed_into` calls seen so far.
        calls: AtomicUsize,
        /// Expected address of the outputs vector on the second call; `0`
        /// disables both address checks.
        final_outputs_addr: usize,
        /// Expected address of output slot 0's storage on the second call.
        final_slot_addr: usize,
    }

    impl crate::backend::private::Sealed for ReuseCheckingBackend {}

    impl VyreBackend for ReuseCheckingBackend {
        fn id(&self) -> &'static str {
            "grid-sync-reuse-checking"
        }

        /// Not used by these tests; calling it is a test bug.
        fn dispatch(
            &self,
            _program: &Program,
            _inputs: &[Vec<u8>],
            _config: &DispatchConfig,
        ) -> Result<Vec<Vec<u8>>, BackendError> {
            unreachable!("test uses dispatch_borrowed_into")
        }

        /// Copies `inputs[0]` into `outputs[0]`, then sets byte 0 to 7 on the
        /// first call and saturating-increments it on every later call.
        fn dispatch_borrowed_into(
            &self,
            _program: &Program,
            inputs: &[&[u8]],
            _config: &DispatchConfig,
            outputs: &mut OutputBuffers,
        ) -> Result<(), BackendError> {
            let call = self.calls.fetch_add(1, Ordering::SeqCst);
            // Address checks run only on the second call and only when armed.
            if call == 1 && self.final_outputs_addr != 0 {
                assert_eq!(outputs.as_ptr() as usize, self.final_outputs_addr);
                assert_eq!(outputs[0].as_ptr() as usize, self.final_slot_addr);
            }
            if outputs.is_empty() {
                outputs.push(Vec::new());
            }
            outputs[0].clear();
            outputs[0].extend_from_slice(inputs[0]);
            if call == 0 {
                outputs[0][0] = 7;
            } else {
                outputs[0][0] = outputs[0][0].saturating_add(1);
            }
            Ok(())
        }
    }
2058
    /// After a two-segment split dispatch, the caller's outputs vector and its
    /// first slot must still live at the addresses they had before the call.
    #[test]
    fn split_into_preserves_caller_output_slot_after_named_output_collection() {
        let program = Program::wrapped(
            vec![buffer()],
            [1, 1, 1],
            vec![
                region("a", vec![Node::Return]),
                Node::barrier_with_ordering(MemoryOrdering::GridSync),
                region("b", vec![Node::Return]),
            ],
        );
        let mut outputs = vec![Vec::with_capacity(8)];
        // Addresses recorded up front and compared again after the dispatch.
        let outputs_addr = outputs.as_ptr() as usize;
        let slot_addr = outputs[0].as_ptr() as usize;
        // Zeroed address fields leave the backend's own address checks off.
        let backend = ReuseCheckingBackend {
            calls: AtomicUsize::new(0),
            final_outputs_addr: 0,
            final_slot_addr: 0,
        };
        let input = [0u8, 0, 0, 0];
        dispatch_with_grid_sync_split_into(
            &backend,
            &program,
            &[input.as_slice()],
            &DispatchConfig::default(),
            &mut outputs,
        )
        .expect("Fix: grid-sync split should write into caller-owned output storage");

        assert_eq!(backend.calls.load(Ordering::SeqCst), 2);
        // First call writes 7, second call increments it to 8.
        assert_eq!(outputs, vec![vec![8, 0, 0, 0]]);
        assert_eq!(outputs.as_ptr() as usize, outputs_addr);
        assert_eq!(outputs[0].as_ptr() as usize, slot_addr);
    }
2093
    /// Test backend that echoes its first input into output slot 0 with byte 0
    /// incremented, and insists on `fixpoint_iterations == Some(1)`.
    struct IncrementingBackend {
        /// Number of `dispatch_borrowed_into` calls seen so far.
        calls: AtomicUsize,
    }

    impl crate::backend::private::Sealed for IncrementingBackend {}

    impl VyreBackend for IncrementingBackend {
        fn id(&self) -> &'static str {
            "grid-sync-incrementing"
        }

        /// Not used by these tests; calling it is a test bug.
        fn dispatch(
            &self,
            _program: &Program,
            _inputs: &[Vec<u8>],
            _config: &DispatchConfig,
        ) -> Result<Vec<Vec<u8>>, BackendError> {
            unreachable!("test uses dispatch_borrowed_into")
        }

        /// Counts the call, checks the per-segment iteration count, then writes
        /// `inputs[0]` with byte 0 saturating-incremented into `outputs[0]`.
        fn dispatch_borrowed_into(
            &self,
            _program: &Program,
            inputs: &[&[u8]],
            config: &DispatchConfig,
            outputs: &mut OutputBuffers,
        ) -> Result<(), BackendError> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            assert_eq!(
                config.fixpoint_iterations,
                Some(1),
                "segment dispatch must receive fixpoint_iterations=1; the outer split loop owns the iteration count"
            );
            if outputs.is_empty() {
                outputs.push(Vec::new());
            }
            outputs[0].clear();
            outputs[0].extend_from_slice(inputs[0]);
            outputs[0][0] = outputs[0][0].saturating_add(1);
            Ok(())
        }
    }
2145
    /// The split dispatcher runs the whole segment sequence once by default and
    /// `fixpoint_iterations` times when configured, carrying the buffer value
    /// from one segment dispatch to the next.
    #[test]
    fn split_into_loops_whole_sequence_fixpoint_iterations_times() {
        let program = Program::wrapped(
            vec![buffer()],
            [1, 1, 1],
            vec![
                region("a", vec![Node::Return]),
                Node::barrier_with_ordering(MemoryOrdering::GridSync),
                region("b", vec![Node::Return]),
            ],
        );

        // Default config: two segment dispatches, so byte 0 ends at 2.
        let backend = IncrementingBackend {
            calls: AtomicUsize::new(0),
        };
        let mut outputs = vec![Vec::new()];
        dispatch_with_grid_sync_split_into(
            &backend,
            &program,
            &[[0u8, 0, 0, 0].as_slice()],
            &DispatchConfig::default(),
            &mut outputs,
        )
        .expect("single-pass split dispatch");
        assert_eq!(backend.calls.load(Ordering::SeqCst), 2);
        assert_eq!(outputs, vec![vec![2, 0, 0, 0]]);

        // Three fixpoint iterations: 2 segments x 3 passes = 6 dispatches.
        let backend = IncrementingBackend {
            calls: AtomicUsize::new(0),
        };
        let config = DispatchConfig {
            fixpoint_iterations: Some(3),
            ..DispatchConfig::default()
        };
        let mut outputs = vec![Vec::new()];
        dispatch_with_grid_sync_split_into(
            &backend,
            &program,
            &[[0u8, 0, 0, 0].as_slice()],
            &config,
            &mut outputs,
        )
        .expect("multi-pass split dispatch");
        assert_eq!(
            backend.calls.load(Ordering::SeqCst),
            6,
            "split must re-run the whole 2-segment sequence 3 times"
        );
        assert_eq!(
            outputs,
            vec![vec![6, 0, 0, 0]],
            "accumulator must advance one hop per fixpoint pass (2 segments × 3 passes)"
        );
    }
2206
    /// Test backend that, on its second call, asserts the outputs vector it is
    /// handed already has capacity for at least one slot.
    struct OwnedFinalReserveBackend {
        /// Number of `dispatch_borrowed_into` calls seen so far.
        calls: AtomicUsize,
    }

    impl crate::backend::private::Sealed for OwnedFinalReserveBackend {}

    impl VyreBackend for OwnedFinalReserveBackend {
        fn id(&self) -> &'static str {
            "grid-sync-owned-final-reserve"
        }

        /// Not used by these tests; calling it is a test bug.
        fn dispatch(
            &self,
            _program: &Program,
            _inputs: &[Vec<u8>],
            _config: &DispatchConfig,
        ) -> Result<Vec<Vec<u8>>, BackendError> {
            unreachable!("test uses dispatch_borrowed_into")
        }

        /// Checks the capacity precondition on the second call, then writes
        /// `inputs[0]` with byte 0 saturating-incremented into `outputs[0]`.
        fn dispatch_borrowed_into(
            &self,
            _program: &Program,
            inputs: &[&[u8]],
            _config: &DispatchConfig,
            outputs: &mut OutputBuffers,
        ) -> Result<(), BackendError> {
            let call = self.calls.fetch_add(1, Ordering::SeqCst);
            if call == 1 {
                assert!(
                    outputs.capacity() >= 1,
                    "owned grid-sync split wrapper must pre-reserve final output slots before the final segment dispatch"
                );
            }
            if outputs.is_empty() {
                outputs.push(Vec::new());
            }
            outputs[0].clear();
            outputs[0].extend_from_slice(inputs[0]);
            outputs[0][0] = outputs[0][0].saturating_add(1);
            Ok(())
        }
    }
2250
2251 #[test]
2252 fn split_owned_wrapper_reserves_final_output_vector_before_final_segment() {
2253 let program = Program::wrapped(
2254 vec![buffer()],
2255 [1, 1, 1],
2256 vec![
2257 region("a", vec![Node::Return]),
2258 Node::barrier_with_ordering(MemoryOrdering::GridSync),
2259 region("b", vec![Node::Return]),
2260 ],
2261 );
2262 let backend = OwnedFinalReserveBackend {
2263 calls: AtomicUsize::new(0),
2264 };
2265 let input = [4u8, 0, 0, 0];
2266
2267 let outputs = dispatch_with_grid_sync_split(
2268 &backend,
2269 &program,
2270 &[input.as_slice()],
2271 &DispatchConfig::default(),
2272 )
2273 .expect("Fix: owned grid-sync split should reserve and return final outputs");
2274
2275 assert_eq!(backend.calls.load(Ordering::SeqCst), 2);
2276 assert_eq!(outputs, vec![vec![6, 0, 0, 0]]);
2277 }
2278
    /// A three-segment split dispatch must advance the dispatch telemetry
    /// counters by at least one split, three segments, and two sync points.
    #[test]
    fn grid_sync_split_records_segment_telemetry() {
        let program = Program::wrapped(
            vec![buffer()],
            [1, 1, 1],
            vec![
                region("a", vec![Node::Return]),
                Node::barrier_with_ordering(MemoryOrdering::GridSync),
                region("b", vec![Node::Return]),
                Node::barrier_with_ordering(MemoryOrdering::GridSync),
                region("c", vec![Node::Return]),
            ],
        );
        let backend = ReuseCheckingBackend {
            calls: AtomicUsize::new(0),
            final_outputs_addr: 0,
            final_slot_addr: 0,
        };
        // Snapshot before and after; the assertions below compare the two.
        let before = crate::observability::snapshot_dispatch_telemetry();
        let input = [0u8, 0, 0, 0];
        let mut outputs = Vec::new();

        dispatch_with_grid_sync_split_into(
            &backend,
            &program,
            &[input.as_slice()],
            &DispatchConfig::default(),
            &mut outputs,
        )
        .expect("Fix: grid-sync split should dispatch every segment");

        let after = crate::observability::snapshot_dispatch_telemetry();
        assert_eq!(backend.calls.load(Ordering::SeqCst), 3);
        // Lower bounds only, so the counters may also have advanced elsewhere.
        assert!(after.grid_sync_splits >= before.grid_sync_splits + 1);
        assert!(after.grid_sync_segments >= before.grid_sync_segments + 3);
        assert!(after.grid_sync_points >= before.grid_sync_points + 2);
    }
2316
    /// Test backend that records the output storage addresses it sees on the
    /// first call and asserts the second call receives the same storage.
    struct IntermediateReuseBackend {
        /// Number of `dispatch_borrowed_into` calls seen so far.
        calls: AtomicUsize,
        /// Address of the outputs vector observed on the first call.
        first_outputs_addr: AtomicUsize,
        /// Address of output slot 0's storage observed on the first call.
        first_slot_addr: AtomicUsize,
    }

    impl crate::backend::private::Sealed for IntermediateReuseBackend {}

    impl VyreBackend for IntermediateReuseBackend {
        fn id(&self) -> &'static str {
            "grid-sync-intermediate-reuse"
        }

        /// Not used by these tests; calling it is a test bug.
        fn dispatch(
            &self,
            _program: &Program,
            _inputs: &[Vec<u8>],
            _config: &DispatchConfig,
        ) -> Result<Vec<Vec<u8>>, BackendError> {
            unreachable!("test uses dispatch_borrowed_into")
        }

        /// Records (call 0) or checks (call 1) the output storage addresses,
        /// then writes `inputs[0]` with byte 0 saturating-incremented into
        /// `outputs[0]`.
        fn dispatch_borrowed_into(
            &self,
            _program: &Program,
            inputs: &[&[u8]],
            _config: &DispatchConfig,
            outputs: &mut OutputBuffers,
        ) -> Result<(), BackendError> {
            let call = self.calls.fetch_add(1, Ordering::SeqCst);
            if outputs.is_empty() {
                outputs.push(Vec::with_capacity(8));
            }
            if call == 0 {
                self.first_outputs_addr
                    .store(outputs.as_ptr() as usize, Ordering::SeqCst);
                self.first_slot_addr
                    .store(outputs[0].as_ptr() as usize, Ordering::SeqCst);
            } else if call == 1 {
                assert_eq!(
                    outputs.as_ptr() as usize,
                    self.first_outputs_addr.load(Ordering::SeqCst)
                );
                assert_eq!(
                    outputs[0].as_ptr() as usize,
                    self.first_slot_addr.load(Ordering::SeqCst)
                );
            }
            outputs[0].clear();
            outputs[0].extend_from_slice(inputs[0]);
            outputs[0][0] = outputs[0][0].saturating_add(1);
            Ok(())
        }
    }
2371
    /// Across a three-segment split, the backend must be handed the same output
    /// storage on its second call as on its first (asserted inside the
    /// backend), and the buffer value must be incremented once per segment.
    #[test]
    fn split_reuses_intermediate_output_slot_between_segments() {
        let program = Program::wrapped(
            vec![buffer()],
            [1, 1, 1],
            vec![
                region("a", vec![Node::Return]),
                Node::barrier_with_ordering(MemoryOrdering::GridSync),
                region("b", vec![Node::Return]),
                Node::barrier_with_ordering(MemoryOrdering::GridSync),
                region("c", vec![Node::Return]),
            ],
        );
        let backend = IntermediateReuseBackend {
            calls: AtomicUsize::new(0),
            first_outputs_addr: AtomicUsize::new(0),
            first_slot_addr: AtomicUsize::new(0),
        };
        let input = [1u8, 0, 0, 0];
        let mut outputs = vec![Vec::with_capacity(8)];

        dispatch_with_grid_sync_split_into(
            &backend,
            &program,
            &[input.as_slice()],
            &DispatchConfig::default(),
            &mut outputs,
        )
        .expect("Fix: grid-sync split should reuse intermediate output scratch");

        assert_eq!(backend.calls.load(Ordering::SeqCst), 3);
        // Starting from 1, three increments give 4.
        assert_eq!(outputs, vec![vec![4, 0, 0, 0]]);
    }
2405
    /// When both segments store to the same output buffer, the planned segment
    /// programs must declare it write-only in the first segment and read-write
    /// in the second, and forward it as an input to the second.
    #[test]
    fn split_keeps_multi_segment_output_as_readwrite_accumulator() {
        let out = BufferDecl::output("out", 0, DataType::U32).with_count(4);
        let program = Program::wrapped(
            vec![out],
            [1, 1, 1],
            vec![
                region("a", vec![Node::store("out", Expr::u32(0), Expr::u32(0xAA))]),
                Node::barrier_with_ordering(MemoryOrdering::GridSync),
                region("b", vec![Node::store("out", Expr::u32(2), Expr::u32(0xBB))]),
            ],
        );
        let segments =
            plan_host_grid_sync_segment_programs(&program).expect("plan host grid-sync segments");
        assert_eq!(segments.len(), 2, "one GridSync barrier -> two segments");

        // Segment 0: first writer of `out`.
        let seg0_out = segments[0]
            .buffers()
            .iter()
            .find(|b| b.name() == "out")
            .expect("segment 0 must declare the output it writes");
        assert_eq!(
            seg0_out.access(),
            BufferAccess::WriteOnly,
            "the first writer establishes the accumulator as write-only"
        );
        assert!(
            !seg0_out.is_output() && !seg0_out.is_pipeline_live_out(),
            "split segment buffers must never be marked program-output; final values are reassembled by name"
        );

        // Segment 1: later writer of `out`.
        let seg1_out = segments[1]
            .buffers()
            .iter()
            .find(|b| b.name() == "out")
            .expect("segment 1 must declare the output it writes");
        assert_eq!(
            seg1_out.access(),
            BufferAccess::ReadWrite,
            "a later writer of a multi-segment output must read+merge the accumulated value, not overwrite it"
        );
        assert!(
            !seg1_out.is_output() && !seg1_out.is_pipeline_live_out(),
            "the later writer must consume its forwarded prior value, which `segment_buffer_consumes_input` refuses for is_output buffers"
        );
        // `out` must appear among segment 1's input names.
        assert!(
            segment_input_names(&segments[1])
                .expect("segment 1 input names")
                .iter()
                .any(|n| n.as_str() == "out"),
            "the accumulated output must be forwarded as an input to the later writing segment"
        );
    }
2467
    /// Test backend whose dispatch applies literal stores to the `out` buffer.
    struct SlotStoringBackend {
        /// Incremented once per `dispatch_borrowed_into` call.
        calls: AtomicUsize,
    }

    impl crate::backend::private::Sealed for SlotStoringBackend {}
2478
2479 impl VyreBackend for SlotStoringBackend {
2480 fn id(&self) -> &'static str {
2481 "grid-sync-slot-storing"
2482 }
2483
2484 fn dispatch(
2485 &self,
2486 _program: &Program,
2487 _inputs: &[Vec<u8>],
2488 _config: &DispatchConfig,
2489 ) -> Result<Vec<Vec<u8>>, BackendError> {
2490 unreachable!("test uses dispatch_borrowed_into")
2491 }
2492
2493 fn dispatch_borrowed_into(
2494 &self,
2495 program: &Program,
2496 inputs: &[&[u8]],
2497 _config: &DispatchConfig,
2498 outputs: &mut OutputBuffers,
2499 ) -> Result<(), BackendError> {
2500 let mut in_pos = None;
2503 let mut cur_in = 0usize;
2504 let mut out_pos = None;
2505 let mut cur_out = 0usize;
2506 for buffer in program.buffers() {
2507 if matches!(buffer.access(), BufferAccess::Workgroup) {
2508 continue;
2509 }
2510 let consumes = segment_buffer_consumes_input(buffer);
2511 let produces = segment_buffer_produces_output(buffer);
2512 if buffer.name() == "out" {
2513 if consumes {
2514 in_pos = Some(cur_in);
2515 }
2516 if produces {
2517 out_pos = Some(cur_out);
2518 }
2519 }
2520 if consumes {
2521 cur_in += 1;
2522 }
2523 if produces {
2524 cur_out += 1;
2525 }
2526 }
2527 let out_pos = out_pos.expect("every writing segment must produce `out`");
2528 let mut state = match in_pos {
2529 Some(i) => inputs[i].to_vec(),
2530 None => vec![0u8; 16],
2531 };
2532
2533 fn apply(nodes: &[Node], state: &mut [u8]) {
2534 for node in nodes {
2535 match node {
2536 Node::Store {
2537 buffer,
2538 index: Expr::LitU32(i),
2539 value: Expr::LitU32(v),
2540 } if buffer.as_str() == "out" => {
2541 let off = (*i as usize) * 4;
2542 state[off] = (*v & 0xff) as u8;
2543 }
2544 Node::Region { body, .. } => apply(body, state),
2545 Node::Block(body) => apply(body, state),
2546 Node::If {
2547 then, otherwise, ..
2548 } => {
2549 apply(then, state);
2550 apply(otherwise, state);
2551 }
2552 Node::Loop { body, .. } => apply(body, state),
2553 _ => {}
2554 }
2555 }
2556 }
2557 apply(entry_sequence(program), &mut state);
2558
2559 self.calls.fetch_add(1, Ordering::SeqCst);
2560 while outputs.len() <= out_pos {
2561 outputs.push(Vec::new());
2562 }
2563 outputs[out_pos].clear();
2564 outputs[out_pos].extend_from_slice(&state);
2565 Ok(())
2566 }
2567 }
2568
2569 #[test]
2570 fn split_preserves_earlier_segment_output_slots_end_to_end() {
2571 let out = BufferDecl::output("out", 0, DataType::U32).with_count(4);
2578 let program = Program::wrapped(
2579 vec![out],
2580 [1, 1, 1],
2581 vec![
2582 region("a", vec![Node::store("out", Expr::u32(0), Expr::u32(0xAA))]),
2583 Node::barrier_with_ordering(MemoryOrdering::GridSync),
2584 region("b", vec![Node::store("out", Expr::u32(2), Expr::u32(0xBB))]),
2585 ],
2586 );
2587 let backend = SlotStoringBackend {
2588 calls: AtomicUsize::new(0),
2589 };
2590 let mut outputs = vec![Vec::new()];
2591 dispatch_with_grid_sync_split_into(
2592 &backend,
2593 &program,
2594 &[],
2595 &DispatchConfig::default(),
2596 &mut outputs,
2597 )
2598 .expect("split dispatch");
2599 assert_eq!(
2600 backend.calls.load(Ordering::SeqCst),
2601 2,
2602 "two segments, single fixpoint pass"
2603 );
2604 assert_eq!(outputs.len(), 1);
2605 assert_eq!(
2606 outputs[0].len(),
2607 16,
2608 "output buffer is 4 × u32 = 16 bytes"
2609 );
2610 assert_eq!(
2611 outputs[0][0], 0xAA,
2612 "segment 0's slot (element 0) must survive the final segment's write"
2613 );
2614 assert_eq!(
2615 outputs[0][8], 0xBB,
2616 "the final segment's slot (element 2) is also present"
2617 );
2618 }
2619
2620 struct SaturatingBackend {
2624 calls: AtomicUsize,
2625 cap: u8,
2626 }
2627
2628 impl crate::backend::private::Sealed for SaturatingBackend {}
2629
2630 impl VyreBackend for SaturatingBackend {
2631 fn id(&self) -> &'static str {
2632 "grid-sync-saturating"
2633 }
2634
2635 fn dispatch(
2636 &self,
2637 _program: &Program,
2638 _inputs: &[Vec<u8>],
2639 _config: &DispatchConfig,
2640 ) -> Result<Vec<Vec<u8>>, BackendError> {
2641 unreachable!("test uses dispatch_borrowed_into")
2642 }
2643
2644 fn dispatch_borrowed_into(
2645 &self,
2646 _program: &Program,
2647 inputs: &[&[u8]],
2648 _config: &DispatchConfig,
2649 outputs: &mut OutputBuffers,
2650 ) -> Result<(), BackendError> {
2651 self.calls.fetch_add(1, Ordering::SeqCst);
2652 if outputs.is_empty() {
2653 outputs.push(Vec::new());
2654 }
2655 outputs[0].clear();
2656 outputs[0].extend_from_slice(inputs[0]);
2657 if outputs[0][0] < self.cap {
2658 outputs[0][0] += 1;
2659 }
2660 Ok(())
2661 }
2662 }
2663
2664 #[test]
2665 fn split_outer_loop_early_exits_when_accumulator_reaches_fixpoint() {
2666 let program = Program::wrapped(
2671 vec![buffer()],
2672 [1, 1, 1],
2673 vec![
2674 region("a", vec![Node::Return]),
2675 Node::barrier_with_ordering(MemoryOrdering::GridSync),
2676 region("b", vec![Node::Return]),
2677 ],
2678 );
2679 let backend = SaturatingBackend {
2680 calls: AtomicUsize::new(0),
2681 cap: 3,
2682 };
2683 let config = DispatchConfig {
2684 fixpoint_iterations: Some(10),
2685 ..DispatchConfig::default()
2686 };
2687 let mut outputs = vec![Vec::new()];
2688 dispatch_with_grid_sync_split_into(
2689 &backend,
2690 &program,
2691 &[[0u8, 0, 0, 0].as_slice()],
2692 &config,
2693 &mut outputs,
2694 )
2695 .expect("converging split dispatch");
2696 assert_eq!(
2699 backend.calls.load(Ordering::SeqCst),
2700 6,
2701 "outer loop must early-exit one pass after the accumulator stops changing, not run all 10 iterations"
2702 );
2703 assert_eq!(
2704 outputs,
2705 vec![vec![3, 0, 0, 0]],
2706 "early-exit must return the converged fixpoint value, identical to running every iteration"
2707 );
2708 }
2709
2710 #[test]
2711 fn split_non_converging_accumulator_runs_full_iteration_budget() {
2712 let program = Program::wrapped(
2716 vec![buffer()],
2717 [1, 1, 1],
2718 vec![
2719 region("a", vec![Node::Return]),
2720 Node::barrier_with_ordering(MemoryOrdering::GridSync),
2721 region("b", vec![Node::Return]),
2722 ],
2723 );
2724 let backend = SaturatingBackend {
2726 calls: AtomicUsize::new(0),
2727 cap: 255,
2728 };
2729 let config = DispatchConfig {
2730 fixpoint_iterations: Some(4),
2731 ..DispatchConfig::default()
2732 };
2733 let mut outputs = vec![Vec::new()];
2734 dispatch_with_grid_sync_split_into(
2735 &backend,
2736 &program,
2737 &[[0u8, 0, 0, 0].as_slice()],
2738 &config,
2739 &mut outputs,
2740 )
2741 .expect("non-converging split dispatch");
2742 assert_eq!(
2743 backend.calls.load(Ordering::SeqCst),
2744 8,
2745 "a still-advancing accumulator must run the full 4 iterations x 2 segments"
2746 );
2747 assert_eq!(outputs, vec![vec![8, 0, 0, 0]]);
2748 }
2749
2750 struct ResidentReuseBackend {
2751 calls: AtomicUsize,
2752 }
2753
2754 impl crate::backend::private::Sealed for ResidentReuseBackend {}
2755
2756 impl VyreBackend for ResidentReuseBackend {
2757 fn id(&self) -> &'static str {
2758 "grid-sync-resident-reuse"
2759 }
2760
2761 fn dispatch(
2762 &self,
2763 _program: &Program,
2764 _inputs: &[Vec<u8>],
2765 _config: &DispatchConfig,
2766 ) -> Result<Vec<Vec<u8>>, BackendError> {
2767 unreachable!("test uses dispatch_resident_timed")
2768 }
2769
2770 fn dispatch_borrowed_into(
2771 &self,
2772 _program: &Program,
2773 _inputs: &[&[u8]],
2774 _config: &DispatchConfig,
2775 _outputs: &mut OutputBuffers,
2776 ) -> Result<(), BackendError> {
2777 unreachable!("resident grid-sync split must not refresh through host borrowed inputs")
2778 }
2779
2780 fn dispatch_resident_timed(
2781 &self,
2782 _program: &Program,
2783 resources: &[Resource],
2784 _config: &DispatchConfig,
2785 ) -> Result<TimedDispatchResult, BackendError> {
2786 assert!(
2787 matches!(resources, [Resource::Resident(11), Resource::Resident(22)]),
2788 "Fix: resident grid-sync split must keep the original device handles bound across every segment."
2789 );
2790 let call = self.calls.fetch_add(1, Ordering::SeqCst);
2791 Ok(TimedDispatchResult {
2792 outputs: vec![vec![call as u8]],
2793 wall_ns: 10,
2794 device_ns: Some(2),
2795 enqueue_ns: Some(3),
2796 wait_ns: Some(4),
2797 })
2798 }
2799 }
2800
2801 #[test]
2802 fn resident_split_reuses_same_device_resources_across_segments() {
2803 let program = Program::wrapped(
2804 vec![buffer()],
2805 [1, 1, 1],
2806 vec![
2807 region("a", vec![Node::Return]),
2808 Node::barrier_with_ordering(MemoryOrdering::GridSync),
2809 region("b", vec![Node::Return]),
2810 Node::barrier_with_ordering(MemoryOrdering::GridSync),
2811 region("c", vec![Node::Return]),
2812 ],
2813 );
2814 let backend = ResidentReuseBackend {
2815 calls: AtomicUsize::new(0),
2816 };
2817
2818 let timed = dispatch_resident_with_grid_sync_split_timed(
2819 &backend,
2820 &program,
2821 &[Resource::Resident(11), Resource::Resident(22)],
2822 &DispatchConfig::default(),
2823 )
2824 .expect("Fix: resident grid-sync split should run each segment on the same device handles");
2825
2826 assert_eq!(backend.calls.load(Ordering::SeqCst), 3);
2827 assert_eq!(timed.outputs, vec![vec![2]]);
2828 assert_eq!(timed.device_ns, Some(6));
2829 assert_eq!(timed.enqueue_ns, Some(9));
2830 assert_eq!(timed.wait_ns, Some(12));
2831 }
2832
2833 struct ResidentDeviceBackend {
2839 next_id: std::sync::atomic::AtomicU64,
2840 buffers: std::sync::Mutex<HashMap<u64, Vec<u8>>>,
2841 freed: std::sync::Mutex<Vec<u64>>,
2842 dispatches: AtomicUsize,
2843 }
2844
2845 impl ResidentDeviceBackend {
2846 fn new() -> Self {
2847 Self {
2848 next_id: std::sync::atomic::AtomicU64::new(1),
2849 buffers: std::sync::Mutex::new(HashMap::new()),
2850 freed: std::sync::Mutex::new(Vec::new()),
2851 dispatches: AtomicUsize::new(0),
2852 }
2853 }
2854
2855 fn resident_id(resource: &Resource) -> u64 {
2856 match resource {
2857 Resource::Resident(id) => *id,
2858 Resource::Borrowed(_) => {
2859 panic!("Fix: resident grid-sync fixpoint must bind Resident handles, not Borrowed")
2860 }
2861 }
2862 }
2863 }
2864
2865 impl crate::backend::private::Sealed for ResidentDeviceBackend {}
2866
2867 impl VyreBackend for ResidentDeviceBackend {
2868 fn id(&self) -> &'static str {
2869 "grid-sync-resident-device"
2870 }
2871
2872 fn dispatch(
2873 &self,
2874 _program: &Program,
2875 _inputs: &[Vec<u8>],
2876 _config: &DispatchConfig,
2877 ) -> Result<Vec<Vec<u8>>, BackendError> {
2878 unreachable!("resident fixpoint test uses resident dispatch")
2879 }
2880
2881 fn dispatch_borrowed_into(
2882 &self,
2883 _program: &Program,
2884 _inputs: &[&[u8]],
2885 _config: &DispatchConfig,
2886 _outputs: &mut OutputBuffers,
2887 ) -> Result<(), BackendError> {
2888 unreachable!("resident fixpoint must thread device handles, never host borrowed inputs")
2889 }
2890
2891 fn allocate_resident(&self, byte_len: usize) -> Result<Resource, BackendError> {
2892 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
2893 self.buffers.lock().unwrap().insert(id, vec![0xFFu8; byte_len]);
2896 Ok(Resource::Resident(id))
2897 }
2898
2899 fn upload_resident(&self, resource: &Resource, bytes: &[u8]) -> Result<(), BackendError> {
2900 let id = Self::resident_id(resource);
2901 let mut buffers = self.buffers.lock().unwrap();
2902 let buf = buffers.get_mut(&id).expect("resident handle exists");
2903 assert!(
2904 bytes.len() <= buf.len(),
2905 "upload {} bytes into a {}-byte resident buffer",
2906 bytes.len(),
2907 buf.len()
2908 );
2909 buf[..bytes.len()].copy_from_slice(bytes);
2910 Ok(())
2911 }
2912
2913 fn download_resident_range_into(
2914 &self,
2915 resource: &Resource,
2916 byte_offset: usize,
2917 byte_len: usize,
2918 output: &mut Vec<u8>,
2919 ) -> Result<(), BackendError> {
2920 let id = Self::resident_id(resource);
2921 let buffers = self.buffers.lock().unwrap();
2922 let buf = buffers.get(&id).expect("resident handle exists");
2923 output.clear();
2924 output.extend_from_slice(&buf[byte_offset..byte_offset + byte_len]);
2925 Ok(())
2926 }
2927
2928 fn free_resident(&self, resource: Resource) -> Result<(), BackendError> {
2929 let id = Self::resident_id(&resource);
2930 self.buffers.lock().unwrap().remove(&id);
2931 self.freed.lock().unwrap().push(id);
2932 Ok(())
2933 }
2934
2935 fn dispatch_resident_timed(
2936 &self,
2937 program: &Program,
2938 resources: &[Resource],
2939 _config: &DispatchConfig,
2940 ) -> Result<TimedDispatchResult, BackendError> {
2941 self.dispatches.fetch_add(1, Ordering::SeqCst);
2942 let plan = BindingPlan::build(program)?;
2945 let mut out_slot = None;
2946 let mut pos = 0usize;
2947 for binding in &plan.bindings {
2948 if binding.role == BindingRole::Shared {
2949 continue;
2950 }
2951 if binding.name.as_ref() == "out" {
2952 out_slot = Some(pos);
2953 }
2954 pos += 1;
2955 }
2956 let out_slot = out_slot.expect("program declares `out`");
2957 let id = Self::resident_id(&resources[out_slot]);
2958 let mut buffers = self.buffers.lock().unwrap();
2959 let buf = buffers.get_mut(&id).expect("resident `out` handle exists");
2960
2961 fn apply(nodes: &[Node], state: &mut [u8]) {
2964 for node in nodes {
2965 match node {
2966 Node::Store {
2967 buffer,
2968 index: Expr::LitU32(i),
2969 value: Expr::LitU32(v),
2970 } if buffer.as_str() == "out" => {
2971 state[(*i as usize) * 4] = (*v & 0xff) as u8;
2972 }
2973 Node::Region { body, .. } => apply(body, state),
2974 Node::Block(body) => apply(body, state),
2975 Node::If { then, otherwise, .. } => {
2976 apply(then, state);
2977 apply(otherwise, state);
2978 }
2979 Node::Loop { body, .. } => apply(body, state),
2980 _ => {}
2981 }
2982 }
2983 }
2984 apply(entry_sequence(program), buf.as_mut_slice());
2985
2986 Ok(TimedDispatchResult {
2987 outputs: Vec::new(),
2988 wall_ns: 1,
2989 device_ns: Some(1),
2990 enqueue_ns: Some(1),
2991 wait_ns: Some(1),
2992 })
2993 }
2994 }
2995
2996 #[test]
2997 fn resident_fixpoint_accumulates_across_segments_zero_inits_and_frees() {
2998 let out = BufferDecl::output("out", 0, DataType::U32).with_count(4);
3004 let program = Program::wrapped(
3005 vec![out],
3006 [1, 1, 1],
3007 vec![
3008 region("a", vec![Node::store("out", Expr::u32(0), Expr::u32(0xAA))]),
3009 Node::barrier_with_ordering(MemoryOrdering::GridSync),
3010 region("b", vec![Node::store("out", Expr::u32(2), Expr::u32(0xBB))]),
3011 ],
3012 );
3013 let backend = ResidentDeviceBackend::new();
3014 let mut outputs = vec![Vec::new()];
3015 dispatch_resident_grid_sync_fixpoint_into(
3016 &backend,
3017 &program,
3018 &[],
3019 &DispatchConfig::default(),
3020 &mut outputs,
3021 )
3022 .expect("resident grid-sync fixpoint dispatch");
3023
3024 assert_eq!(
3025 backend.dispatches.load(Ordering::SeqCst),
3026 2,
3027 "two segments, single fixpoint pass under the default config"
3028 );
3029 assert_eq!(outputs.len(), 1, "one output buffer (`out`)");
3030 assert_eq!(outputs[0].len(), 16, "4 × u32 = 16 bytes");
3031 assert_eq!(
3032 outputs[0][0], 0xAA,
3033 "segment 0's slot survives - resident accumulation, no clobber"
3034 );
3035 assert_eq!(outputs[0][8], 0xBB, "the final segment's slot is present");
3036 assert_eq!(outputs[0][4], 0x00, "untouched slot 1 was zero-initialized");
3040 assert_eq!(outputs[0][12], 0x00, "untouched slot 3 was zero-initialized");
3041 assert_eq!(
3043 backend.freed.lock().unwrap().len(),
3044 1,
3045 "the single `out` resident buffer is freed"
3046 );
3047 assert!(
3048 backend.buffers.lock().unwrap().is_empty(),
3049 "no resident buffer leaks after dispatch"
3050 );
3051 }
3052
3053 #[test]
3054 fn resident_fixpoint_repeats_to_fixpoint_bound() {
3055 let out = BufferDecl::output("out", 0, DataType::U32).with_count(4);
3059 let program = Program::wrapped(
3060 vec![out],
3061 [1, 1, 1],
3062 vec![
3063 region("a", vec![Node::store("out", Expr::u32(0), Expr::u32(0xAA))]),
3064 Node::barrier_with_ordering(MemoryOrdering::GridSync),
3065 region("b", vec![Node::store("out", Expr::u32(2), Expr::u32(0xBB))]),
3066 ],
3067 );
3068 let backend = ResidentDeviceBackend::new();
3069 let mut config = DispatchConfig::default();
3070 config.fixpoint_iterations = Some(3);
3071 let mut outputs = vec![Vec::new()];
3072 dispatch_resident_grid_sync_fixpoint_into(
3073 &backend,
3074 &program,
3075 &[],
3076 &config,
3077 &mut outputs,
3078 )
3079 .expect("resident grid-sync fixpoint dispatch");
3080 assert_eq!(
3081 backend.dispatches.load(Ordering::SeqCst),
3082 6,
3083 "2 segments × 3 fixpoint passes"
3084 );
3085 assert_eq!(outputs[0][0], 0xAA);
3086 assert_eq!(outputs[0][8], 0xBB);
3087 }
3088}