Skip to main content

sp1_prover/worker/controller/
vk_tree.rs

1use std::collections::{BTreeMap, BTreeSet};
2
3use either::Either;
4use rand::{seq::SliceRandom, SeedableRng};
5use serde::{Deserialize, Serialize};
6use sp1_hypercube::DIGEST_SIZE;
7use sp1_primitives::SP1Field;
8use sp1_prover_types::{ArtifactClient, TaskType};
9
10use crate::{
11    shapes::create_all_input_shapes,
12    worker::{RawTaskRequest, SP1Controller, TaskError, TaskMetadata, WorkerClient},
13};
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct VkeyMapControllerInput {
17    pub range_or_limit: Option<Either<Vec<usize>, usize>>,
18    pub chunk_size: usize,
19    pub reduce_batch_size: usize,
20}
21
22#[derive(Debug, Clone, Serialize, serde::Deserialize)]
23pub struct VkeyMapControllerOutput {
24    pub vk_map: BTreeMap<[SP1Field; DIGEST_SIZE], usize>,
25    pub panic_indices: Vec<usize>,
26}
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct VkeyMapChunkInput {
30    pub reduce_batch_size: usize,
31    pub indices: Vec<usize>,
32    pub total_inputs: usize,
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct VkeyMapChunkOutput {
37    pub vk_set: BTreeSet<[SP1Field; DIGEST_SIZE]>,
38    pub panic_indices: Vec<usize>,
39}
40
41impl<A: ArtifactClient, W: WorkerClient> SP1Controller<A, W> {
42    pub async fn run_sp1_util_vkey_map_controller(
43        &self,
44        request: RawTaskRequest,
45    ) -> Result<TaskMetadata, TaskError>
46where {
47        let subscriber =
48            self.worker_client.subscriber(request.context.proof_id.clone()).await?.per_task();
49        let input =
50            self.artifact_client.download::<VkeyMapControllerInput>(&request.inputs[0]).await?;
51
52        let num_shapes =
53            create_all_input_shapes(self.verifier.core.machine().shape(), input.reduce_batch_size)
54                .into_iter()
55                .collect::<BTreeSet<_>>()
56                .len();
57
58        let limit = input.range_or_limit.unwrap_or(Either::Right(num_shapes));
59
60        let mut all_indices = match limit {
61            Either::Left(range) => range,
62            Either::Right(limit) => (0..limit).collect::<Vec<_>>(),
63        };
64
65        // Randomize the order of the indices
66        {
67            let mut rng = rand::rngs::StdRng::seed_from_u64(0);
68            all_indices.shuffle(&mut rng);
69        }
70
71        let chunks =
72            all_indices.chunks(input.chunk_size).map(|chunk| chunk.to_vec()).collect::<Vec<_>>();
73
74        let inputs = chunks
75            .into_iter()
76            .map(|chunk| VkeyMapChunkInput {
77                reduce_batch_size: input.reduce_batch_size,
78                indices: chunk,
79                total_inputs: num_shapes,
80            })
81            .collect::<Vec<_>>();
82
83        let mut input_artifacts = Vec::new();
84        for input in &inputs {
85            let artifact = self.artifact_client.create_artifact()?;
86            self.artifact_client.upload(&artifact, &input).await?;
87            input_artifacts.push(artifact);
88        }
89
90        let mut output_artifacts = Vec::new();
91        for _ in 0..inputs.len() {
92            let artifact = self.artifact_client.create_artifact()?;
93            output_artifacts.push(artifact);
94        }
95
96        let mut tasks = Vec::new();
97
98        for (task_input, task_output) in input_artifacts.into_iter().zip(output_artifacts.iter()) {
99            let request = RawTaskRequest {
100                inputs: vec![task_input],
101                outputs: vec![task_output.clone()],
102                context: request.context.clone(),
103            };
104            let task = self.worker_client.submit_task(TaskType::UtilVkeyMapChunk, request).await?;
105            tasks.push(task);
106        }
107
108        for task in tasks {
109            subscriber.wait_task(task).await?;
110        }
111
112        let mut outputs = Vec::new();
113        for output_artifact in output_artifacts {
114            let output =
115                self.artifact_client.download::<VkeyMapChunkOutput>(&output_artifact).await?;
116            outputs.push(output);
117        }
118
119        // Merge outputs into a single map and reassign indexes
120        let (vk_maps, panic_indices): (Vec<_>, Vec<_>) =
121            outputs.into_iter().map(|output| (output.vk_set, output.panic_indices)).unzip();
122        let final_vk_map = vk_maps
123            .into_iter()
124            .flatten()
125            // It's important to order the VKeys to ensure consistent indexing.
126            .collect::<BTreeSet<_>>()
127            .into_iter()
128            .enumerate()
129            .map(|(i, vk)| (vk, i))
130            .collect::<BTreeMap<_, _>>();
131        let panic_indices = panic_indices.into_iter().flatten().collect::<Vec<_>>();
132
133        let output = VkeyMapControllerOutput { vk_map: final_vk_map, panic_indices };
134
135        self.artifact_client.upload(&request.outputs[0], output).await?;
136
137        Ok(TaskMetadata::default())
138    }
139}