sp1_prover/worker/controller/
vk_tree.rs1use std::collections::{BTreeMap, BTreeSet};
2
3use either::Either;
4use rand::{seq::SliceRandom, SeedableRng};
5use serde::{Deserialize, Serialize};
6use sp1_hypercube::DIGEST_SIZE;
7use sp1_primitives::SP1Field;
8use sp1_prover_types::{ArtifactClient, TaskType};
9
10use crate::{
11 shapes::create_all_input_shapes,
12 worker::{RawTaskRequest, SP1Controller, TaskError, TaskMetadata, WorkerClient},
13};
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct VkeyMapControllerInput {
17 pub range_or_limit: Option<Either<Vec<usize>, usize>>,
18 pub chunk_size: usize,
19 pub reduce_batch_size: usize,
20}
21
22#[derive(Debug, Clone, Serialize, serde::Deserialize)]
23pub struct VkeyMapControllerOutput {
24 pub vk_map: BTreeMap<[SP1Field; DIGEST_SIZE], usize>,
25 pub panic_indices: Vec<usize>,
26}
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct VkeyMapChunkInput {
30 pub reduce_batch_size: usize,
31 pub indices: Vec<usize>,
32 pub total_inputs: usize,
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct VkeyMapChunkOutput {
37 pub vk_set: BTreeSet<[SP1Field; DIGEST_SIZE]>,
38 pub panic_indices: Vec<usize>,
39}
40
41impl<A: ArtifactClient, W: WorkerClient> SP1Controller<A, W> {
42 pub async fn run_sp1_util_vkey_map_controller(
43 &self,
44 request: RawTaskRequest,
45 ) -> Result<TaskMetadata, TaskError>
46where {
47 let subscriber =
48 self.worker_client.subscriber(request.context.proof_id.clone()).await?.per_task();
49 let input =
50 self.artifact_client.download::<VkeyMapControllerInput>(&request.inputs[0]).await?;
51
52 let num_shapes =
53 create_all_input_shapes(self.verifier.core.machine().shape(), input.reduce_batch_size)
54 .into_iter()
55 .collect::<BTreeSet<_>>()
56 .len();
57
58 let limit = input.range_or_limit.unwrap_or(Either::Right(num_shapes));
59
60 let mut all_indices = match limit {
61 Either::Left(range) => range,
62 Either::Right(limit) => (0..limit).collect::<Vec<_>>(),
63 };
64
65 {
67 let mut rng = rand::rngs::StdRng::seed_from_u64(0);
68 all_indices.shuffle(&mut rng);
69 }
70
71 let chunks =
72 all_indices.chunks(input.chunk_size).map(|chunk| chunk.to_vec()).collect::<Vec<_>>();
73
74 let inputs = chunks
75 .into_iter()
76 .map(|chunk| VkeyMapChunkInput {
77 reduce_batch_size: input.reduce_batch_size,
78 indices: chunk,
79 total_inputs: num_shapes,
80 })
81 .collect::<Vec<_>>();
82
83 let mut input_artifacts = Vec::new();
84 for input in &inputs {
85 let artifact = self.artifact_client.create_artifact()?;
86 self.artifact_client.upload(&artifact, &input).await?;
87 input_artifacts.push(artifact);
88 }
89
90 let mut output_artifacts = Vec::new();
91 for _ in 0..inputs.len() {
92 let artifact = self.artifact_client.create_artifact()?;
93 output_artifacts.push(artifact);
94 }
95
96 let mut tasks = Vec::new();
97
98 for (task_input, task_output) in input_artifacts.into_iter().zip(output_artifacts.iter()) {
99 let request = RawTaskRequest {
100 inputs: vec![task_input],
101 outputs: vec![task_output.clone()],
102 context: request.context.clone(),
103 };
104 let task = self.worker_client.submit_task(TaskType::UtilVkeyMapChunk, request).await?;
105 tasks.push(task);
106 }
107
108 for task in tasks {
109 subscriber.wait_task(task).await?;
110 }
111
112 let mut outputs = Vec::new();
113 for output_artifact in output_artifacts {
114 let output =
115 self.artifact_client.download::<VkeyMapChunkOutput>(&output_artifact).await?;
116 outputs.push(output);
117 }
118
119 let (vk_maps, panic_indices): (Vec<_>, Vec<_>) =
121 outputs.into_iter().map(|output| (output.vk_set, output.panic_indices)).unzip();
122 let final_vk_map = vk_maps
123 .into_iter()
124 .flatten()
125 .collect::<BTreeSet<_>>()
127 .into_iter()
128 .enumerate()
129 .map(|(i, vk)| (vk, i))
130 .collect::<BTreeMap<_, _>>();
131 let panic_indices = panic_indices.into_iter().flatten().collect::<Vec<_>>();
132
133 let output = VkeyMapControllerOutput { vk_map: final_vk_map, panic_indices };
134
135 self.artifact_client.upload(&request.outputs[0], output).await?;
136
137 Ok(TaskMetadata::default())
138 }
139}