1use std::hash::Hash;
9
10use rustc_hash::{FxHashMap, FxHashSet};
11
12use crate::ordering::sort_unstable_by_key_if_needed;
13use crate::reservation_policy::{
14 reserve_typed_hash_map_to_capacity, reserve_typed_hash_set_to_capacity,
15 reserve_typed_vec_to_capacity, reserved_typed_vec, ReservationPolicy,
16};
17use crate::ResidentGraphReuseTelemetry;
18
/// Reservation policy passed to the fallible reservation helpers in this file.
///
/// NOTE(review): the two string arguments look like a subsystem label and a
/// remediation hint (they mirror the wording of this module's error
/// messages) — confirm against `ReservationPolicy::new`.
const MULTI_QUERY_RESERVATION: ReservationPolicy = ReservationPolicy::new(
    "multi-query execution",
    "shard the query batch before planning",
);
23
/// One traversal query submitted to the multi-query planner.
///
/// Queries sharing the same `(graph_layout_hash, traversal_key)` pair are
/// candidates for the same launch group.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MultiQuery {
    /// Caller-assigned id; a multi-query batch that repeats an id is rejected.
    pub query: u32,
    /// Identity of the graph layout the query runs over; zero is rejected.
    pub graph_layout_hash: u64,
    /// Traversal compatibility key; zero is rejected.
    pub traversal_key: u64,
    /// Byte count of the resident graph; zero is rejected, and every query
    /// naming the same `graph_layout_hash` must report the same value.
    pub graph_upload_bytes: u64,
    /// Frontier bytes for this query; summed across a launch group.
    pub frontier_bytes: u64,
    /// Scratch bytes for this query; a launch group is charged the maximum.
    pub scratch_bytes: u64,
    /// Output bytes for this query; summed across a launch group.
    pub output_bytes: u64,
}
42
43#[derive(Clone, Debug, Eq, PartialEq)]
45pub struct MultiQueryGroup {
46 pub graph_layout_hash: u64,
48 pub traversal_key: u64,
50 pub queries: Vec<u32>,
52 pub graph_upload_bytes: u64,
54 pub frontier_bytes: u64,
56 pub peak_scratch_bytes: u64,
58 pub output_bytes: u64,
60 pub resident_bytes: u64,
62 pub avoided_launches: u32,
64 pub avoided_host_fences: u32,
66 pub avoided_graph_upload_bytes: u64,
68 pub graph_reuse: ResidentGraphReuseTelemetry,
70}
71
72#[derive(Clone, Debug, Eq, PartialEq)]
74pub struct MultiQueryExecutionPlan {
75 pub groups: Vec<MultiQueryGroup>,
77 pub launch_count: u32,
79 pub avoided_launches: u32,
81 pub avoided_host_fences: u32,
83 pub avoided_graph_upload_bytes: u64,
85 pub graph_reuse: ResidentGraphReuseTelemetry,
87 pub peak_resident_bytes: u64,
89 pub final_only_host_fence_per_group: bool,
91}
92
93#[derive(Debug, Default)]
95pub struct MultiQueryExecutionScratch {
96 group_indices: FxHashMap<(u64, u64), usize>,
97 group_query_counts: FxHashMap<(u64, u64), usize>,
98 resident_graphs: FxHashSet<u64>,
99 resident_graph_bytes: FxHashMap<u64, u64>,
100 grouped_queries: Vec<((u64, u64), Vec<MultiQuery>)>,
101 free_query_buckets: Vec<Vec<MultiQuery>>,
102 seen_queries: FxHashSet<u32>,
103}
104
105impl MultiQueryExecutionScratch {
106 #[must_use]
108 pub fn new() -> Self {
109 Self {
110 group_indices: FxHashMap::default(),
111 group_query_counts: FxHashMap::default(),
112 resident_graphs: FxHashSet::default(),
113 resident_graph_bytes: FxHashMap::default(),
114 grouped_queries: Vec::new(),
115 free_query_buckets: Vec::new(),
116 seen_queries: FxHashSet::default(),
117 }
118 }
119
120 pub fn try_with_capacity(query_count: usize) -> Result<Self, MultiQueryExecutionError> {
122 let mut scratch = Self::new();
123 scratch.try_reserve_query_shape(query_count)?;
124 Ok(scratch)
125 }
126
127 fn try_reserve_query_shape(
128 &mut self,
129 query_count: usize,
130 ) -> Result<(), MultiQueryExecutionError> {
131 reserve_map(
132 &mut self.group_indices,
133 query_count,
134 "multi-query group index table",
135 )?;
136 reserve_map(
137 &mut self.group_query_counts,
138 query_count,
139 "multi-query group size table",
140 )?;
141 reserve_set(
142 &mut self.resident_graphs,
143 query_count,
144 "multi-query resident graph set",
145 )?;
146 reserve_map(
147 &mut self.resident_graph_bytes,
148 query_count,
149 "multi-query resident graph byte table",
150 )?;
151 reserve_vec(
152 &mut self.grouped_queries,
153 query_count,
154 "multi-query grouped-query buckets",
155 )?;
156 reserve_set(
157 &mut self.seen_queries,
158 query_count,
159 "multi-query seen query ids",
160 )
161 }
162
163 #[must_use]
165 pub fn group_index_capacity(&self) -> usize {
166 self.group_indices.capacity()
167 }
168
169 #[must_use]
171 pub fn grouped_query_capacity(&self) -> usize {
172 self.grouped_queries.capacity()
173 }
174
175 #[must_use]
177 pub fn resident_graph_capacity(&self) -> usize {
178 self.resident_graphs.capacity()
179 }
180
181 #[must_use]
183 pub fn retained_query_bucket_capacity(&self) -> usize {
184 self.free_query_buckets
185 .iter()
186 .map(Vec::capacity)
187 .sum::<usize>()
188 + self
189 .grouped_queries
190 .iter()
191 .map(|(_, queries)| queries.capacity())
192 .sum::<usize>()
193 }
194
195 fn clear(&mut self) -> Result<(), MultiQueryExecutionError> {
196 self.group_indices.clear();
197 self.group_query_counts.clear();
198 self.resident_graphs.clear();
199 self.resident_graph_bytes.clear();
200 let retained_bucket_count = self
201 .free_query_buckets
202 .len()
203 .checked_add(self.grouped_queries.len())
204 .ok_or(MultiQueryExecutionError::ByteCountOverflow {
205 field: "retained multi-query bucket count",
206 })?;
207 reserve_vec(
208 &mut self.free_query_buckets,
209 retained_bucket_count,
210 "multi-query retained bucket pool",
211 )?;
212 for (_, mut queries) in self.grouped_queries.drain(..) {
213 queries.clear();
214 self.free_query_buckets.push(queries);
215 }
216 self.seen_queries.clear();
217 Ok(())
218 }
219}
220
221fn take_reserved_query_bucket(
222 free_query_buckets: &mut Vec<Vec<MultiQuery>>,
223 query_count: usize,
224) -> Result<Vec<MultiQuery>, MultiQueryExecutionError> {
225 let mut queries = free_query_buckets.pop().unwrap_or_default();
226 if let Err(error) = reserve_vec(
227 &mut queries,
228 query_count,
229 "multi-query grouped query bucket",
230 ) {
231 free_query_buckets.push(queries);
232 return Err(error);
233 }
234 queries.clear();
235 Ok(queries)
236}
237
/// Reasons multi-query planning can fail.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MultiQueryExecutionError {
    /// The same query id appeared more than once in the batch.
    DuplicateQuery {
        /// The repeated query id.
        query: u32,
    },
    /// A query carried `graph_layout_hash == 0`.
    ZeroGraphHash {
        /// The offending query id.
        query: u32,
    },
    /// A query carried `traversal_key == 0`.
    ZeroTraversalKey {
        /// The offending query id.
        query: u32,
    },
    /// A query carried `graph_upload_bytes == 0`.
    ZeroGraphUploadBytes {
        /// The offending query id.
        query: u32,
    },
    /// Two queries named the same graph hash with different byte counts.
    GraphUploadBytesMismatch {
        /// Graph hash both queries referred to.
        graph_layout_hash: u64,
        /// Byte count recorded from the first query that named the graph.
        expected_bytes: u64,
        /// Byte count reported by the conflicting query.
        actual_bytes: u64,
        /// The conflicting query id.
        query: u32,
    },
    /// The caller passed a budget of zero bytes.
    ZeroBudget,
    /// A checked arithmetic step or count conversion overflowed.
    ByteCountOverflow {
        /// Name of the quantity being computed.
        field: &'static str,
    },
    /// A single query (or the first query of a chunk) does not fit the budget.
    OverBudget {
        /// Graph hash of the group that did not fit.
        graph_layout_hash: u64,
        /// Traversal key of the group that did not fit.
        traversal_key: u64,
        /// Resident bytes the group needs.
        required_bytes: u64,
        /// Budget the caller supplied.
        budget_bytes: u64,
    },
    /// A fallible storage reservation failed.
    StorageReserveFailed {
        /// Name of the table or buffer being reserved.
        field: &'static str,
        /// Number of entries requested.
        requested: usize,
        /// Message describing the underlying failure.
        message: String,
    },
    /// The planner reached a state its own bookkeeping should rule out.
    InternalInvariant {
        /// Description of the violated invariant.
        message: &'static str,
    },
}
305
/// Builds a [`MultiQueryExecutionError::StorageReserveFailed`] from its three
/// parts. Passed as the error constructor to the reservation helpers below.
fn storage_reserve_failed(
    field: &'static str,
    requested: usize,
    message: String,
) -> MultiQueryExecutionError {
    MultiQueryExecutionError::StorageReserveFailed {
        field,
        requested,
        message,
    }
}
317
/// Fallibly reserves `vec` by delegating to `reserve_typed_vec_to_capacity`
/// with this module's reservation policy.
///
/// `field` names the buffer in the resulting error.
/// NOTE(review): presumably grows `vec` until its capacity reaches
/// `target_capacity` — confirm against the helper's definition.
fn reserve_vec<T>(
    vec: &mut Vec<T>,
    target_capacity: usize,
    field: &'static str,
) -> Result<(), MultiQueryExecutionError> {
    reserve_typed_vec_to_capacity(
        MULTI_QUERY_RESERVATION,
        vec,
        target_capacity,
        field,
        storage_reserve_failed,
    )
}
331
/// Returns a new vector obtained from `reserved_typed_vec` under this
/// module's reservation policy.
///
/// `field` names the buffer in the resulting error.
/// NOTE(review): presumably the returned vector is empty with capacity for
/// `target_capacity` elements — confirm against the helper's definition.
fn reserved_vec<T>(
    target_capacity: usize,
    field: &'static str,
) -> Result<Vec<T>, MultiQueryExecutionError> {
    reserved_typed_vec(
        MULTI_QUERY_RESERVATION,
        target_capacity,
        field,
        storage_reserve_failed,
    )
}
343
344fn reserve_set<T>(
345 set: &mut FxHashSet<T>,
346 target_capacity: usize,
347 field: &'static str,
348) -> Result<(), MultiQueryExecutionError>
349where
350 T: Eq + Hash,
351{
352 reserve_typed_hash_set_to_capacity(
353 MULTI_QUERY_RESERVATION,
354 set,
355 target_capacity,
356 field,
357 storage_reserve_failed,
358 )
359}
360
361fn reserve_map<K, V>(
362 map: &mut FxHashMap<K, V>,
363 target_capacity: usize,
364 field: &'static str,
365) -> Result<(), MultiQueryExecutionError>
366where
367 K: Eq + Hash,
368{
369 reserve_typed_hash_map_to_capacity(
370 MULTI_QUERY_RESERVATION,
371 map,
372 target_capacity,
373 field,
374 storage_reserve_failed,
375 )
376}
377
378fn checked_add(lhs: u64, rhs: u64, field: &'static str) -> Result<u64, MultiQueryExecutionError> {
379 lhs.checked_add(rhs)
380 .ok_or(MultiQueryExecutionError::ByteCountOverflow { field })
381}
382
383fn checked_add_u32(
384 lhs: u32,
385 rhs: u32,
386 field: &'static str,
387) -> Result<u32, MultiQueryExecutionError> {
388 lhs.checked_add(rhs)
389 .ok_or(MultiQueryExecutionError::ByteCountOverflow { field })
390}
391
// Human-readable rendering of each planning error. Every message states what
// went wrong and ends with a "Fix:" sentence telling the caller how to
// recover.
impl std::fmt::Display for MultiQueryExecutionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::DuplicateQuery { query } => write!(
                f,
                "multi-query execution received duplicate query id {query}. Fix: assign unique ids before batch planning."
            ),
            Self::ZeroGraphHash { query } => write!(
                f,
                "multi-query {query} has graph_layout_hash=0. Fix: normalize and hash the resident graph before query batching."
            ),
            Self::ZeroTraversalKey { query } => write!(
                f,
                "multi-query {query} has traversal_key=0. Fix: emit a concrete traversal compatibility key before multi-query batching."
            ),
            Self::ZeroGraphUploadBytes { query } => write!(
                f,
                "multi-query {query} has graph_upload_bytes=0. Fix: pass the concrete resident graph topology byte count before multi-query batching."
            ),
            Self::GraphUploadBytesMismatch {
                graph_layout_hash,
                expected_bytes,
                actual_bytes,
                query,
            } => write!(
                f,
                "multi-query graph hash {graph_layout_hash} reported conflicting resident byte widths: expected {expected_bytes}, query {query} reported {actual_bytes}. Fix: canonicalize graph layout hashing and byte accounting before multi-query batching."
            ),
            Self::ZeroBudget => write!(
                f,
                "multi-query execution received a zero device budget. Fix: pass an explicit resident memory budget before planning."
            ),
            Self::ByteCountOverflow { field } => write!(
                f,
                "multi-query execution overflowed while computing {field}. Fix: shard the query batch before planning."
            ),
            Self::OverBudget {
                graph_layout_hash,
                traversal_key,
                required_bytes,
                budget_bytes,
            } => write!(
                f,
                "multi-query group graph={graph_layout_hash} traversal={traversal_key} requires {required_bytes} bytes but budget allows {budget_bytes}. Fix: split the group or raise the explicit multi-query budget."
            ),
            Self::StorageReserveFailed {
                field,
                requested,
                message,
            } => write!(
                f,
                "multi-query execution could not reserve {requested} {field} entries: {message}. Fix: shard the query batch before planning."
            ),
            Self::InternalInvariant { message } => write!(
                f,
                "multi-query execution violated an internal planner invariant: {message}. Fix: keep group counting and bucket indexing in one validated planning pass."
            ),
        }
    }
}
452
// Marker impl so the error works with `?` into `Box<dyn std::error::Error>`;
// no source error is exposed.
impl std::error::Error for MultiQueryExecutionError {}
454
/// Plans launch groups for `queries` under `budget_bytes`, using a freshly
/// reserved scratch that is dropped when the call returns.
///
/// Use [`plan_multi_query_execution_with_scratch`] to keep the scratch
/// allocations between calls.
///
/// # Errors
/// Returns a reservation failure from building the scratch, or any error
/// produced by [`plan_multi_query_execution_with_scratch`].
pub fn plan_multi_query_execution(
    queries: &[MultiQuery],
    budget_bytes: u64,
) -> Result<MultiQueryExecutionPlan, MultiQueryExecutionError> {
    let mut scratch = MultiQueryExecutionScratch::try_with_capacity(queries.len())?;
    plan_multi_query_execution_with_scratch(queries, budget_bytes, &mut scratch)
}
463
464pub fn plan_multi_query_execution_with_scratch(
466 queries: &[MultiQuery],
467 budget_bytes: u64,
468 scratch: &mut MultiQueryExecutionScratch,
469) -> Result<MultiQueryExecutionPlan, MultiQueryExecutionError> {
470 if budget_bytes == 0 {
471 return Err(MultiQueryExecutionError::ZeroBudget);
472 }
473 if queries.is_empty() {
474 return Ok(MultiQueryExecutionPlan {
475 launch_count: 0,
476 groups: Vec::new(),
477 avoided_launches: 0,
478 avoided_host_fences: 0,
479 avoided_graph_upload_bytes: 0,
480 graph_reuse: ResidentGraphReuseTelemetry::default(),
481 peak_resident_bytes: 0,
482 final_only_host_fence_per_group: true,
483 });
484 }
485 if queries.len() == 1 {
486 let query = queries[0];
487 if query.graph_layout_hash == 0 {
488 return Err(MultiQueryExecutionError::ZeroGraphHash { query: query.query });
489 }
490 if query.traversal_key == 0 {
491 return Err(MultiQueryExecutionError::ZeroTraversalKey { query: query.query });
492 }
493 if query.graph_upload_bytes == 0 {
494 return Err(MultiQueryExecutionError::ZeroGraphUploadBytes { query: query.query });
495 }
496 let resident_bytes = group_resident_bytes(
497 query.graph_upload_bytes,
498 query.frontier_bytes,
499 query.scratch_bytes,
500 query.output_bytes,
501 )?;
502 if resident_bytes > budget_bytes {
503 return Err(MultiQueryExecutionError::OverBudget {
504 graph_layout_hash: query.graph_layout_hash,
505 traversal_key: query.traversal_key,
506 required_bytes: resident_bytes,
507 budget_bytes,
508 });
509 }
510 let mut query_ids = reserved_vec(1, "multi-query singleton query ids")?;
511 query_ids.push(query.query);
512 let mut groups = reserved_vec(1, "multi-query output groups")?;
513 groups.push(MultiQueryGroup {
514 graph_layout_hash: query.graph_layout_hash,
515 traversal_key: query.traversal_key,
516 queries: query_ids,
517 graph_upload_bytes: query.graph_upload_bytes,
518 frontier_bytes: query.frontier_bytes,
519 peak_scratch_bytes: query.scratch_bytes,
520 output_bytes: query.output_bytes,
521 resident_bytes,
522 avoided_launches: 0,
523 avoided_host_fences: 0,
524 avoided_graph_upload_bytes: 0,
525 graph_reuse: ResidentGraphReuseTelemetry::cold_upload(query.graph_upload_bytes),
526 });
527 return Ok(MultiQueryExecutionPlan {
528 launch_count: 1,
529 groups,
530 avoided_launches: 0,
531 avoided_host_fences: 0,
532 avoided_graph_upload_bytes: 0,
533 graph_reuse: ResidentGraphReuseTelemetry::cold_upload(query.graph_upload_bytes),
534 peak_resident_bytes: resident_bytes,
535 final_only_host_fence_per_group: true,
536 });
537 }
538
539 scratch.clear()?;
540 scratch.try_reserve_query_shape(queries.len())?;
541 for query in queries {
542 if !scratch.seen_queries.insert(query.query) {
543 return Err(MultiQueryExecutionError::DuplicateQuery { query: query.query });
544 }
545 if query.graph_layout_hash == 0 {
546 return Err(MultiQueryExecutionError::ZeroGraphHash { query: query.query });
547 }
548 if query.traversal_key == 0 {
549 return Err(MultiQueryExecutionError::ZeroTraversalKey { query: query.query });
550 }
551 if query.graph_upload_bytes == 0 {
552 return Err(MultiQueryExecutionError::ZeroGraphUploadBytes { query: query.query });
553 }
554 match scratch
555 .resident_graph_bytes
556 .get(&query.graph_layout_hash)
557 .copied()
558 {
559 Some(expected_bytes) if expected_bytes != query.graph_upload_bytes => {
560 return Err(MultiQueryExecutionError::GraphUploadBytesMismatch {
561 graph_layout_hash: query.graph_layout_hash,
562 expected_bytes,
563 actual_bytes: query.graph_upload_bytes,
564 query: query.query,
565 });
566 }
567 Some(_) => {}
568 None => {
569 scratch
570 .resident_graph_bytes
571 .insert(query.graph_layout_hash, query.graph_upload_bytes);
572 }
573 }
574 let key = (query.graph_layout_hash, query.traversal_key);
575 let count = scratch.group_query_counts.entry(key).or_insert(0);
576 *count = count
577 .checked_add(1)
578 .ok_or(MultiQueryExecutionError::ByteCountOverflow {
579 field: "multi-query grouped query count",
580 })?;
581 }
582
583 reserve_vec(
584 &mut scratch.grouped_queries,
585 scratch.group_query_counts.len(),
586 "multi-query grouped-query buckets",
587 )?;
588 for (&key, &query_count) in &scratch.group_query_counts {
589 let index = scratch.grouped_queries.len();
590 let queries = take_reserved_query_bucket(&mut scratch.free_query_buckets, query_count)?;
591 scratch.grouped_queries.push((key, queries));
592 scratch.group_indices.insert(key, index);
593 }
594
595 for query in queries {
596 let key = (query.graph_layout_hash, query.traversal_key);
597 let index = scratch.group_indices.get(&key).copied().ok_or(
598 MultiQueryExecutionError::InternalInvariant {
599 message: "validated multi-query group key missing from exact-capacity bucket index",
600 },
601 )?;
602 scratch.grouped_queries[index].1.push(*query);
603 }
604
605 let mut groups = reserved_vec(scratch.grouped_queries.len(), "multi-query output groups")?;
606 let mut avoided_launches = 0_u32;
607 let mut avoided_host_fences = 0_u32;
608 let mut avoided_graph_upload_bytes = 0_u64;
609 let mut graph_reuse = ResidentGraphReuseTelemetry::default();
610 let mut peak_resident_bytes = 0_u64;
611
612 sort_unstable_by_key_if_needed(&mut scratch.grouped_queries, |(key, _)| *key);
613 for ((graph_layout_hash, traversal_key), group_queries) in &mut scratch.grouped_queries {
614 sort_unstable_by_key_if_needed(group_queries, |query| query.query);
615 let first_new_group = groups.len();
616 let graph_already_resident = !scratch.resident_graphs.insert(*graph_layout_hash);
617 append_memory_fit_groups(
618 *graph_layout_hash,
619 *traversal_key,
620 group_queries,
621 budget_bytes,
622 graph_already_resident,
623 &mut groups,
624 )?;
625 for group in &groups[first_new_group..] {
626 avoided_launches =
627 checked_add_u32(avoided_launches, group.avoided_launches, "avoided launches")?;
628 avoided_host_fences = checked_add_u32(
629 avoided_host_fences,
630 group.avoided_host_fences,
631 "avoided host fences",
632 )?;
633 avoided_graph_upload_bytes = checked_add(
634 avoided_graph_upload_bytes,
635 group.avoided_graph_upload_bytes,
636 "avoided graph upload bytes",
637 )?;
638 graph_reuse = graph_reuse.checked_add(group.graph_reuse).map_err(|_| {
639 MultiQueryExecutionError::ByteCountOverflow {
640 field: "graph reuse telemetry",
641 }
642 })?;
643 peak_resident_bytes = peak_resident_bytes.max(group.resident_bytes);
644 }
645 }
646 let launch_count =
647 u32::try_from(groups.len()).map_err(|_| MultiQueryExecutionError::ByteCountOverflow {
648 field: "launch count",
649 })?;
650
651 Ok(MultiQueryExecutionPlan {
652 launch_count,
653 groups,
654 avoided_launches,
655 avoided_host_fences,
656 avoided_graph_upload_bytes,
657 graph_reuse,
658 peak_resident_bytes,
659 final_only_host_fence_per_group: true,
660 })
661}
662
663fn append_memory_fit_groups(
664 graph_layout_hash: u64,
665 traversal_key: u64,
666 queries: &[MultiQuery],
667 budget_bytes: u64,
668 graph_already_resident: bool,
669 groups: &mut Vec<MultiQueryGroup>,
670) -> Result<(), MultiQueryExecutionError> {
671 let mut start = 0usize;
672 let resident_graph_bytes = queries[0].graph_upload_bytes;
673 while start < queries.len() {
674 let graph_upload_bytes = if start == 0 && !graph_already_resident {
675 resident_graph_bytes
676 } else {
677 0
678 };
679 let mut avoided_graph_upload_bytes = if graph_upload_bytes == 0 {
680 queries[start].graph_upload_bytes
681 } else {
682 0
683 };
684 let mut warm_reuses = if graph_upload_bytes == 0 { 1 } else { 0 };
685 let mut frontier_bytes = 0_u64;
686 let mut peak_scratch_bytes = 0_u64;
687 let mut output_bytes = 0_u64;
688 let mut resident_bytes = graph_upload_bytes;
689 let mut cursor = start;
690
691 while cursor < queries.len() {
692 let query = queries[cursor];
693 let candidate_frontier =
694 checked_add(frontier_bytes, query.frontier_bytes, "frontier bytes")?;
695 let candidate_scratch = peak_scratch_bytes.max(query.scratch_bytes);
696 let candidate_output = checked_add(output_bytes, query.output_bytes, "output bytes")?;
697 let candidate_resident = group_resident_bytes(
698 resident_graph_bytes,
699 candidate_frontier,
700 candidate_scratch,
701 candidate_output,
702 )?;
703
704 if candidate_resident > budget_bytes {
705 if cursor == start {
706 return Err(MultiQueryExecutionError::OverBudget {
707 graph_layout_hash,
708 traversal_key,
709 required_bytes: candidate_resident,
710 budget_bytes,
711 });
712 }
713 break;
714 }
715
716 if cursor != start {
717 avoided_graph_upload_bytes = checked_add(
718 avoided_graph_upload_bytes,
719 query.graph_upload_bytes,
720 "avoided graph upload bytes",
721 )?;
722 warm_reuses = checked_add(warm_reuses, 1, "warm resident graph reuse count")?;
723 }
724 frontier_bytes = candidate_frontier;
725 peak_scratch_bytes = candidate_scratch;
726 output_bytes = candidate_output;
727 resident_bytes = candidate_resident;
728 cursor += 1;
729 }
730
731 let chunk_len =
732 cursor
733 .checked_sub(start)
734 .ok_or(MultiQueryExecutionError::InternalInvariant {
735 message: "multi-query chunk cursor moved before chunk start",
736 })?;
737 let mut query_ids = reserved_vec(chunk_len, "multi-query chunk query ids")?;
738 for query in &queries[start..cursor] {
739 query_ids.push(query.query);
740 }
741
742 let avoided = u32::try_from(chunk_len - 1).map_err(|_| {
743 MultiQueryExecutionError::ByteCountOverflow {
744 field: "avoided launches",
745 }
746 })?;
747 groups.push(MultiQueryGroup {
748 graph_layout_hash,
749 traversal_key,
750 queries: query_ids,
751 graph_upload_bytes,
752 frontier_bytes,
753 peak_scratch_bytes,
754 output_bytes,
755 resident_bytes,
756 avoided_launches: avoided,
757 avoided_host_fences: avoided,
758 avoided_graph_upload_bytes,
759 graph_reuse: ResidentGraphReuseTelemetry::from_counters(
760 u64::from(graph_upload_bytes != 0),
761 warm_reuses,
762 graph_upload_bytes,
763 avoided_graph_upload_bytes,
764 ),
765 });
766 start = cursor;
767 }
768 Ok(())
769}
770
771fn group_resident_bytes(
772 graph_upload_bytes: u64,
773 frontier_bytes: u64,
774 peak_scratch_bytes: u64,
775 output_bytes: u64,
776) -> Result<u64, MultiQueryExecutionError> {
777 let graph_plus_frontier = checked_add(
778 graph_upload_bytes,
779 frontier_bytes,
780 "graph plus frontier resident bytes",
781 )?;
782 let with_scratch = checked_add(
783 graph_plus_frontier,
784 peak_scratch_bytes,
785 "resident bytes with scratch",
786 )?;
787 checked_add(with_scratch, output_bytes, "resident bytes with outputs")
788}
789
790#[cfg(test)]
791mod tests {
792 use super::*;
793
// Three queries with the same graph hash and traversal key fit the budget
// together, so they form one launch: ids come out ascending, frontier and
// output bytes are summed, and scratch is the maximum.
#[test]
fn multi_query_batches_compatible_queries_over_one_resident_graph() {
    let plan = plan_multi_query_execution(
        &[
            query(3, 0xabc, 0x10, 4_096, 64, 128, 32),
            query(1, 0xabc, 0x10, 4_096, 32, 64, 16),
            query(2, 0xabc, 0x10, 4_096, 48, 96, 24),
        ],
        8_192,
    )
    .expect("Fix: compatible queries should batch");

    assert_eq!(plan.launch_count, 1);
    assert_eq!(plan.avoided_launches, 2);
    assert_eq!(plan.avoided_host_fences, 2);
    assert_eq!(plan.avoided_graph_upload_bytes, 8_192);
    assert_eq!(
        plan.graph_reuse,
        ResidentGraphReuseTelemetry::from_counters(1, 2, 4_096, 8_192)
    );
    assert_eq!(plan.groups[0].queries, vec![1, 2, 3]);
    assert_eq!(
        plan.groups[0].graph_reuse,
        ResidentGraphReuseTelemetry::from_counters(1, 2, 4_096, 8_192)
    );
    assert_eq!(plan.groups[0].frontier_bytes, 144);
    assert_eq!(plan.groups[0].peak_scratch_bytes, 128);
    assert_eq!(plan.groups[0].output_bytes, 72);
    assert!(plan.final_only_host_fence_per_group);
}
824
// A compatible group that exceeds the budget is split into chunks. Only the
// first chunk uploads the graph, but the second chunk's resident bytes still
// include the graph (100 + 100 + 10 + 10 = 220).
#[test]
fn multi_query_splits_compatible_group_to_fit_cuda_budget_without_reuploading_graph() {
    let plan = plan_multi_query_execution(
        &[
            query(1, 0xabc, 0x10, 100, 100, 10, 10),
            query(2, 0xabc, 0x10, 100, 100, 10, 10),
            query(3, 0xabc, 0x10, 100, 100, 10, 10),
        ],
        350,
    )
    .expect("Fix: compatible multi-query queries should split into budget-fit resident chunks");

    assert_eq!(plan.launch_count, 2);
    assert_eq!(plan.avoided_launches, 1);
    assert_eq!(plan.avoided_host_fences, 1);
    assert_eq!(plan.avoided_graph_upload_bytes, 200);
    assert_eq!(
        plan.graph_reuse,
        ResidentGraphReuseTelemetry::from_counters(1, 2, 100, 200)
    );
    assert_eq!(plan.peak_resident_bytes, 330);
    assert_eq!(plan.groups[0].queries, vec![1, 2]);
    assert_eq!(plan.groups[0].graph_upload_bytes, 100);
    assert_eq!(plan.groups[0].resident_bytes, 330);
    assert_eq!(plan.groups[1].queries, vec![3]);
    assert_eq!(plan.groups[1].graph_upload_bytes, 0);
    assert_eq!(plan.groups[1].resident_bytes, 220);
    assert!(plan.final_only_host_fence_per_group);
}
854
// With a budget of 150 not even one query fits once the resident graph is
// counted (100 + 100 + 10 + 10 = 220), so planning must report OverBudget.
#[test]
fn multi_query_later_chunks_still_count_resident_graph_memory() {
    assert_eq!(
        plan_multi_query_execution(
            &[
                query(1, 0xabc, 0x10, 100, 100, 10, 10),
                query(2, 0xabc, 0x10, 100, 100, 10, 10),
            ],
            150,
        )
        .expect_err("later resident chunk still needs graph memory and should exceed budget"),
        MultiQueryExecutionError::OverBudget {
            graph_layout_hash: 0xabc,
            traversal_key: 0x10,
            required_bytes: 220,
            budget_bytes: 150,
        }
    );
}
874
// A budget that fits exactly one query per launch yields four single-query
// chunks. The trailing source-text check pins the planner to reserving each
// chunk's id list by the chunk length; the forbidden pattern is assembled
// with `concat!` so this test's own text does not match it.
#[test]
fn multi_query_split_chunks_reserve_only_actual_chunk_ids() {
    let plan = plan_multi_query_execution(
        &[
            query(1, 0xabc, 0x10, 100, 100, 10, 10),
            query(2, 0xabc, 0x10, 100, 100, 10, 10),
            query(3, 0xabc, 0x10, 100, 100, 10, 10),
            query(4, 0xabc, 0x10, 100, 100, 10, 10),
        ],
        220,
    )
    .expect("Fix: multi-query planner should split into single-query chunks");

    assert_eq!(plan.launch_count, 4);
    assert!(plan.groups.iter().all(|group| group.queries.len() == 1));
    assert_eq!(plan.avoided_launches, 0);
    assert_eq!(plan.avoided_host_fences, 0);
    assert_eq!(plan.avoided_graph_upload_bytes, 300);

    let src = include_str!("multi_query_execution.rs");
    assert!(
        src.contains("let chunk_len =")
            && src.contains("reserved_vec(chunk_len, \"multi-query chunk query ids\")")
            && !src.contains(concat!("reserved_vec(queries.len()", " - start")),
        "Fix: split multi-query chunks must reserve only the actual chunk size, not the whole remaining tail."
    );
}
902
// Queries that differ in graph hash or traversal key land in separate
// groups, emitted in ascending (graph hash, traversal key) order. The second
// group over graph 0xabc reuses the graph uploaded by the first.
#[test]
fn multi_query_splits_incompatible_graph_or_traversal_keys() {
    let plan = plan_multi_query_execution(
        &[
            query(1, 0xdef, 0x10, 1_024, 32, 64, 16),
            query(2, 0xabc, 0x20, 1_024, 32, 64, 16),
            query(3, 0xabc, 0x10, 1_024, 32, 64, 16),
        ],
        4_096,
    )
    .expect("Fix: incompatible queries should become separate groups");

    assert_eq!(plan.launch_count, 3);
    assert_eq!(plan.avoided_launches, 0);
    assert_eq!(plan.avoided_graph_upload_bytes, 1_024);
    assert_eq!(
        plan.graph_reuse,
        ResidentGraphReuseTelemetry::from_counters(2, 1, 2_048, 1_024)
    );
    assert_eq!(plan.groups[0].graph_upload_bytes, 1_024);
    assert_eq!(plan.groups[1].graph_upload_bytes, 0);
    assert_eq!(plan.groups[2].graph_upload_bytes, 1_024);
    assert_eq!(
        plan.groups
            .iter()
            .map(|group| (group.graph_layout_hash, group.traversal_key))
            .collect::<Vec<_>>(),
        vec![(0xabc, 0x10), (0xabc, 0x20), (0xdef, 0x10)]
    );
}
933
// Source-text guard: the whole file, tests and comments included, must never
// name the standard ordered tree map. The needle is assembled with `concat!`
// so this test does not match itself.
#[test]
fn multi_query_grouping_avoids_tree_lookup_per_query() {
    let src = include_str!("multi_query_execution.rs");
    assert!(
        !src.contains(concat!("BTree", "Map")),
        "Fix: multi-query grouping should hash query ids and group indices, then sort final groups once for deterministic output."
    );
}
942
// Planning twice with the same scratch must not shrink any retained capacity,
// and the inner query buckets must survive between calls.
#[test]
fn multi_query_planner_reuses_caller_owned_grouping_scratch() {
    let mut scratch = MultiQueryExecutionScratch::try_with_capacity(128)
        .expect("Fix: multi-query scratch should reserve");
    let wide = (0..128)
        .map(|index| query(index, 0xabc, 0x10, 4_096, 4, 8, 4))
        .collect::<Vec<_>>();
    let first = plan_multi_query_execution_with_scratch(&wide, 16_384, &mut scratch)
        .expect("Fix: wide compatible query batch should plan");
    let group_index_capacity = scratch.group_index_capacity();
    let grouped_query_capacity = scratch.grouped_query_capacity();
    let resident_graph_capacity = scratch.resident_graph_capacity();
    let query_bucket_capacity = scratch.retained_query_bucket_capacity();

    assert_eq!(first.launch_count, 1);
    assert_eq!(first.groups[0].queries.len(), 128);
    assert!(
        query_bucket_capacity >= 128,
        "Fix: multi-query scratch must retain inner grouped-query bucket capacity across planning calls"
    );

    let second = plan_multi_query_execution_with_scratch(
        &[
            query(9, 0xdef, 0x20, 1_024, 16, 32, 8),
            query(7, 0xabc, 0x10, 1_024, 16, 32, 8),
        ],
        4_096,
        &mut scratch,
    )
    .expect("Fix: smaller incompatible query batch should reuse previous scratch");

    assert_eq!(second.launch_count, 2);
    assert!(scratch.group_index_capacity() >= group_index_capacity);
    assert!(scratch.grouped_query_capacity() >= grouped_query_capacity);
    assert!(scratch.resident_graph_capacity() >= resident_graph_capacity);
    assert!(scratch.retained_query_bucket_capacity() >= query_bucket_capacity);

    // Inspect only the production half of the file. Searching the whole file
    // would let these needles match their own string literals below, which
    // made the previous in-place-sort check pass no matter what the planner
    // did.
    let production = include_str!("multi_query_execution.rs")
        .split("#[cfg(test)]")
        .next()
        .expect("Fix: multi-query production source must precede tests");
    assert!(
        production.contains("pub fn plan_multi_query_execution_with_scratch"),
        "Fix: release callers need a scratch-aware multi-query planning path"
    );
    assert!(
        production.contains("sort_unstable_by_key_if_needed(&mut scratch.grouped_queries"),
        "Fix: deterministic multi-query output should sort retained scratch buckets in place"
    );
}
990
// A reservation of `usize::MAX` entries cannot succeed. The bucket taken from
// the pool must be handed back unchanged, contents included, and the error
// must name the grouped-bucket field.
#[test]
fn reused_query_bucket_returns_to_pool_when_reservation_fails() {
    let retained = vec![query(42, 0xabc, 0x10, 4_096, 8, 16, 4)];
    let mut free_query_buckets = vec![retained.clone()];

    let err = take_reserved_query_bucket(&mut free_query_buckets, usize::MAX)
        .expect_err("impossible query bucket reservation must fail");

    assert!(
        matches!(
            err,
            MultiQueryExecutionError::StorageReserveFailed {
                field: "multi-query grouped query bucket",
                ..
            }
        ),
        "Fix: query bucket reservation failure must surface the grouped-bucket field"
    );
    assert_eq!(
        free_query_buckets,
        vec![retained],
        "failed reservation must return the reusable multi-query query bucket to scratch"
    );
}
1015
// Source-text guard over the production half of the file (everything before
// the first test-cfg attribute): the fallible reservation helpers must be
// used, and a list of infallible allocation spellings must be absent. The
// forbidden needles are assembled with `concat!` so they are not spelled out
// whole in this file.
#[test]
fn multi_query_planner_staging_reserves_fallibly() {
    let production = include_str!("multi_query_execution.rs")
        .split("#[cfg(test)]")
        .next()
        .expect("Fix: multi-query production source must precede tests");

    assert!(
        production.contains("MultiQueryExecutionScratch::try_with_capacity(queries.len())?")
            && production.contains("scratch.try_reserve_query_shape(queries.len())?")
            && production.contains("use crate::reservation_policy::{")
            && production.contains("reserve_typed_vec_to_capacity")
            && production.contains("reserve_typed_hash_map_to_capacity")
            && production.contains("reserve_typed_hash_set_to_capacity")
            && production.contains("StorageReserveFailed")
            && production.contains("const MULTI_QUERY_RESERVATION"),
        "Fix: multi-query execution planning must reserve scratch and output staging fallibly."
    );
    assert!(
        !production.contains(concat!("FxHashMap::with_capacity", "_and_hasher"))
            && !production.contains(concat!("FxHashSet::with_capacity", "_and_hasher"))
            && !production.contains(concat!("Vec::with_capacity", "(query_count)"))
            && !production.contains(concat!(
                "Vec::with_capacity",
                "(scratch.grouped_queries.len())"
            ))
            && !production.contains(concat!("Vec::with_capacity", "(queries.len() - start)"))
            && !production.contains(concat!("groups: vec![", "MultiQueryGroup"))
            && !production.contains(concat!("queries: vec![", "query.query]"))
            && !production
                .contains(concat!("scratch.group_indices", ".reserve(queries.len())"))
            && !production.contains(concat!(
                "scratch.grouped_queries",
                ".reserve(queries.len())"
            ))
            && !production.contains(concat!("scratch.seen_queries", ".reserve(queries.len())")),
        "Fix: multi-query release planning must not use infallible staging allocation."
    );
}
1055
// Source-text guard over the production half of the file: both sorts must go
// through the shared `sort_unstable_by_key_if_needed` helper, and no direct
// method-call sort may remain.
#[test]
fn multi_query_planner_uses_shared_monotonic_sort_fast_path() {
    let production = include_str!("multi_query_execution.rs")
        .split("#[cfg(test)]")
        .next()
        .expect("Fix: multi-query production source must precede tests");

    assert!(
        production.contains("use crate::ordering::sort_unstable_by_key_if_needed;")
            && production.contains("sort_unstable_by_key_if_needed(&mut scratch.grouped_queries")
            && production.contains("sort_unstable_by_key_if_needed(group_queries"),
        "Fix: multi-query planning must reuse the shared monotonic sort fast path for release-order batches."
    );
    assert!(
        !production.contains(".sort_unstable_by_key("),
        "Fix: multi-query planning must not sort already monotonic release batches unconditionally."
    );
}
1074
/// Deterministic generated sweep: builds 768 pseudo-random query batches, plans
/// each one under a generous budget, and checks the plan's structural contracts
/// (group ordering, per-group budget, sorted query ids, every query scheduled
/// exactly once, and plan-level totals matching the per-group values).
#[test]
fn generated_multi_query_plans_preserve_grouping_budget_and_identity_contracts() {
    let mut rng_state = 0x6a09_e667_f3bc_c909_u64;
    for case_index in 0..768usize {
        let query_count = 1 + (next_u64(&mut rng_state) as usize % 64);
        // Upload size remembered per graph slot; zero means "not drawn yet".
        let mut upload_bytes_per_graph = [0_u64; 8];
        let mut queries = Vec::new();
        for query_id in 0..query_count {
            let slot = next_u64(&mut rng_state) as usize % upload_bytes_per_graph.len();
            // The first query touching a graph draws its size; later queries on
            // the same graph reuse it so the batch stays self-consistent.
            if upload_bytes_per_graph[slot] == 0 {
                upload_bytes_per_graph[slot] = 128 + next_u64(&mut rng_state) % 16_384;
            }
            let graph_upload_bytes = upload_bytes_per_graph[slot];
            // Draw the remaining fields in a fixed order so the generated
            // sequence is reproducible.
            let traversal_key = 1 + next_u64(&mut rng_state) % 5;
            let frontier_bytes = next_u64(&mut rng_state) % 512;
            let scratch_bytes = next_u64(&mut rng_state) % 1_024;
            let output_bytes = next_u64(&mut rng_state) % 256;
            queries.push(query(
                query_id as u32,
                // Graph hashes are the one-based slot number, so never zero.
                slot as u64 + 1,
                traversal_key,
                graph_upload_bytes,
                frontier_bytes,
                scratch_bytes,
                output_bytes,
            ));
        }

        let total_graph_bytes = upload_bytes_per_graph.iter().sum::<u64>();
        let budget = total_graph_bytes + (query_count as u64 * 2_048) + 16_384;
        let plan = plan_multi_query_execution(&queries, budget)
            .expect("Fix: generated multi-query plan should fit generous budget");

        assert_eq!(
            plan.launch_count as usize,
            plan.groups.len(),
            "case {case_index}"
        );
        assert!(plan.final_only_host_fence_per_group, "case {case_index}");

        // Groups must be ordered by (graph_layout_hash, traversal_key).
        let groups_in_key_order = plan
            .groups
            .iter()
            .zip(plan.groups.iter().skip(1))
            .all(|(earlier, later)| {
                (earlier.graph_layout_hash, earlier.traversal_key)
                    <= (later.graph_layout_hash, later.traversal_key)
            });
        assert!(groups_in_key_order, "case {case_index}");

        let mut scheduled = vec![false; query_count];
        let mut launches_saved = 0_u32;
        let mut fences_saved = 0_u32;
        let mut largest_resident = 0_u64;
        for group in &plan.groups {
            assert!(group.resident_bytes <= budget, "case {case_index}");

            let ids_ascending = group
                .queries
                .iter()
                .zip(group.queries.iter().skip(1))
                .all(|(left, right)| left <= right);
            assert!(ids_ascending, "case {case_index}");

            launches_saved = launches_saved
                .checked_add(group.avoided_launches)
                .expect("Fix: generated avoided launch sum should fit u32");
            fences_saved = fences_saved
                .checked_add(group.avoided_host_fences)
                .expect("Fix: generated avoided fence sum should fit u32");
            largest_resident = largest_resident.max(group.resident_bytes);

            // Each query id must be in range and appear in exactly one group.
            for &query_id in &group.queries {
                let position = query_id as usize;
                assert!(position < query_count, "case {case_index}");
                assert!(!scheduled[position], "case {case_index}");
                scheduled[position] = true;
            }
        }

        assert!(
            scheduled.iter().all(|&covered| covered),
            "case {case_index}"
        );
        assert_eq!(plan.avoided_launches, launches_saved, "case {case_index}");
        assert_eq!(
            plan.avoided_host_fences, fences_saved,
            "case {case_index}"
        );
        assert_eq!(
            plan.peak_resident_bytes, largest_resident,
            "case {case_index}"
        );
    }
}
1158
/// Each malformed batch below must be rejected with the specific error variant
/// that names the offending query, graph, or budget.
#[test]
fn multi_query_rejects_invalid_inputs_and_budget_overflow() {
    // A graph layout hash of zero is rejected.
    let zero_hash = plan_multi_query_execution(&[query(1, 0, 1, 8, 1, 1, 1)], 128)
        .expect_err("missing graph hash should fail");
    assert_eq!(
        zero_hash,
        MultiQueryExecutionError::ZeroGraphHash { query: 1 }
    );

    // A graph that uploads zero bytes is rejected.
    let zero_upload = plan_multi_query_execution(&[query(1, 1, 1, 0, 1, 1, 1)], 128)
        .expect_err("zero graph bytes should fail");
    assert_eq!(
        zero_upload,
        MultiQueryExecutionError::ZeroGraphUploadBytes { query: 1 }
    );

    // Two queries naming the same graph hash must agree on its upload size.
    let conflicting_batch = [query(1, 1, 1, 8, 1, 1, 1), query(2, 1, 2, 16, 1, 1, 1)];
    let size_mismatch = plan_multi_query_execution(&conflicting_batch, 128)
        .expect_err("same graph hash with conflicting bytes should fail");
    assert_eq!(
        size_mismatch,
        MultiQueryExecutionError::GraphUploadBytesMismatch {
            graph_layout_hash: 1,
            expected_bytes: 8,
            actual_bytes: 16,
            query: 2,
        }
    );

    // The same query id may not appear twice in one batch.
    let repeated_batch = [query(1, 1, 1, 8, 1, 1, 1), query(1, 1, 1, 8, 1, 1, 1)];
    let duplicate = plan_multi_query_execution(&repeated_batch, 128)
        .expect_err("duplicate query should fail");
    assert_eq!(
        duplicate,
        MultiQueryExecutionError::DuplicateQuery { query: 1 }
    );

    // 128 + 16 + 16 + 16 = 176 bytes are required, which exceeds the budget of 127.
    let over_budget = plan_multi_query_execution(&[query(2, 1, 1, 128, 16, 16, 16)], 127)
        .expect_err("over-budget group should fail");
    assert_eq!(
        over_budget,
        MultiQueryExecutionError::OverBudget {
            graph_layout_hash: 1,
            traversal_key: 1,
            required_bytes: 176,
            budget_bytes: 127,
        }
    );
}
1203
1204 fn query(
1205 query: u32,
1206 graph_layout_hash: u64,
1207 traversal_key: u64,
1208 graph_upload_bytes: u64,
1209 frontier_bytes: u64,
1210 scratch_bytes: u64,
1211 output_bytes: u64,
1212 ) -> MultiQuery {
1213 MultiQuery {
1214 query,
1215 graph_layout_hash,
1216 traversal_key,
1217 graph_upload_bytes,
1218 frontier_bytes,
1219 scratch_bytes,
1220 output_bytes,
1221 }
1222 }
1223
/// Steps the test-only pseudo-random generator and returns the new state.
///
/// The update is a 64-bit linear congruential step,
/// `state = state * MULTIPLIER + INCREMENT`, using wrapping arithmetic so it
/// never panics on overflow. The same starting state always yields the same
/// sequence, which keeps generated test cases reproducible.
fn next_u64(state: &mut u64) -> u64 {
    const MULTIPLIER: u64 = 6_364_136_223_846_793_005;
    const INCREMENT: u64 = 1_442_695_040_888_963_407;

    let advanced = state.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);
    *state = advanced;
    advanced
}
1230}