1use std::{
2 collections::{BTreeMap, VecDeque},
3 sync::{Arc, Mutex},
4};
5
6use futures::future::try_join_all;
7use hashbrown::HashMap;
8use sp1_hypercube::{
9 air::{ShardBoundary, ShardRange},
10 SP1PcsProofInner, SP1RecursionProof,
11};
12use sp1_primitives::SP1GlobalContext;
13use sp1_prover_types::{Artifact, ArtifactClient, ArtifactId, ArtifactType, TaskStatus, TaskType};
14use sp1_recursion_circuit::machine::SP1ShapedWitnessValues;
15use tokio::{sync::mpsc, task::JoinSet};
16use tracing::Instrument;
17
18use crate::{
19 worker::{
20 MessageReceiver, ProofData, RecursionProverData, ReduceTaskRequest, TaskContext, TaskError,
21 TaskId, WorkerClient,
22 },
23 SP1CircuitWitness, SP1CompressWitness, SP1ProverComponents,
24};
25
26pub struct CompressTask {
27 pub witness: SP1CompressWitness,
28}
29
30#[derive(Debug, Clone)]
36pub struct RecursionProof {
37 pub shard_range: ShardRange,
38 pub proof: Artifact,
39}
40
41#[derive(Clone, Debug)]
47pub struct RangeProofs {
48 pub shard_range: ShardRange,
49 pub proofs: VecDeque<RecursionProof>,
50}
51
52impl RangeProofs {
53 pub fn new(shard_range: ShardRange, proofs: VecDeque<RecursionProof>) -> Self {
54 Self { shard_range, proofs }
55 }
56
57 pub fn as_artifacts(self) -> impl Iterator<Item = Artifact> + Send + Sync {
58 let range_artifact = Artifact::from(
59 serde_json::to_string(&self.shard_range).expect("Failed to serialize shard range"),
60 );
61 std::iter::once(range_artifact).chain(self.proofs.into_iter().flat_map(|proof| {
62 let range_str =
63 serde_json::to_string(&proof.shard_range).expect("Failed to serialize shard range");
64 let range_artifact = Artifact::from(range_str);
65 let proof_artifact = proof.proof;
66 [range_artifact, proof_artifact]
67 }))
68 }
69
70 pub fn from_artifacts(artifacts: &[Artifact]) -> Result<Self, TaskError> {
71 if artifacts.len() % 2 != 1 || artifacts.len() <= 1 {
72 return Err(TaskError::Fatal(anyhow::anyhow!(
73 "Invalid number of artifacts: {:?}",
74 artifacts.len()
75 )));
76 }
77 let shard_range =
78 serde_json::from_str(artifacts[0].id()).map_err(|e| TaskError::Fatal(e.into()))?;
79 let proofs = artifacts[1..]
80 .chunks_exact(2)
81 .map(|chunk| -> Result<RecursionProof, TaskError> {
82 let shard_range =
83 serde_json::from_str(chunk[0].id()).map_err(|e| TaskError::Fatal(e.into()))?;
84 let proof = chunk[1].clone();
85 Ok(RecursionProof { shard_range, proof })
86 })
87 .collect::<Result<VecDeque<RecursionProof>, TaskError>>()?;
88 Ok(RangeProofs { shard_range, proofs })
89 }
90
91 pub fn len(&self) -> usize {
92 self.proofs.len()
93 }
94
95 pub fn is_empty(&self) -> bool {
96 self.proofs.is_empty()
97 }
98
99 pub fn push_right(&mut self, proof: RecursionProof) {
100 assert_eq!(proof.shard_range.end(), self.shard_range.start());
101 self.shard_range = (proof.shard_range.start()..self.shard_range.end()).into();
102 self.proofs.push_front(proof);
103 }
104
105 pub fn push_left(&mut self, proof: RecursionProof) {
106 assert_eq!(proof.shard_range.start(), self.shard_range.end());
107 self.shard_range = (self.shard_range.start()..proof.shard_range.end()).into();
108 self.proofs.push_back(proof);
109 }
110
111 pub fn split_off(&mut self, at: usize) -> Option<Self> {
112 if at >= self.proofs.len() {
113 return None;
114 }
115 let proofs = self.proofs.split_off(at);
117 let range = {
119 let at_start_range = proofs.front().unwrap().shard_range.start();
120 let at_end_range = proofs.iter().last().unwrap().shard_range.end();
121 at_start_range..at_end_range
122 }
123 .into();
124 let new_self_range = {
126 let at_start_range = self.proofs.front().unwrap().shard_range.start();
127 let at_end_range = self.proofs.iter().last().unwrap().shard_range.end();
128 at_start_range..at_end_range
129 };
130 self.shard_range = new_self_range.into();
132 Some(Self { shard_range: range, proofs })
134 }
135
136 pub fn push_both(&mut self, middle: RecursionProof, right: Self) {
137 assert_eq!(middle.shard_range.start(), self.shard_range.end());
138 assert_eq!(right.shard_range.start(), middle.shard_range.end());
139 self.proofs.push_back(middle);
141 for proof in right.proofs {
143 self.proofs.push_back(proof);
144 }
145 self.shard_range = (self.shard_range.start()..right.shard_range.end()).into();
147 }
148
149 pub fn range(&self) -> ShardRange {
150 self.shard_range
151 }
152
153 pub async fn download_witness<C: SP1ProverComponents>(
154 &self,
155 is_complete: bool,
156 artifact_client: &impl ArtifactClient,
157 recursion_data: &RecursionProverData<C>,
158 ) -> Result<SP1CircuitWitness, TaskError> {
159 let proofs = try_join_all(self.proofs.iter().map(|proof| async {
161 let downloaded_proof = artifact_client
162 .download::<SP1RecursionProof<SP1GlobalContext, SP1PcsProofInner>>(&proof.proof)
163 .await?;
164
165 Ok::<_, TaskError>(downloaded_proof)
166 }))
167 .await?;
168
169 let (vks_and_proofs, merkle_proofs): (Vec<_>, Vec<_>) = proofs
172 .into_iter()
173 .map(|proof| ((proof.vk, proof.proof), proof.vk_merkle_proof))
174 .unzip();
175
176 let witness = SP1ShapedWitnessValues { vks_and_proofs, is_complete };
177
178 let witness = recursion_data.append_merkle_proofs_to_witness(witness, merkle_proofs)?;
179
180 let witness = SP1CircuitWitness::Compress(witness);
181 Ok(witness)
182 }
183
184 pub async fn try_delete_proofs(
185 &self,
186 artifact_client: &impl ArtifactClient,
187 ) -> Result<(), TaskError> {
188 try_join_all(self.proofs.iter().map(|proof| async {
189 artifact_client.try_delete(&proof.proof, ArtifactType::UnspecifiedArtifactType).await?;
191 Ok::<_, TaskError>(())
192 }))
193 .await?;
194 Ok(())
195 }
196}
197
198#[derive(Debug)]
200enum Sibling {
201 Left(RangeProofs),
202 Right(RangeProofs),
203 Both(RangeProofs, RangeProofs),
204}
205
206pub(super) struct CompressTree {
234 map: BTreeMap<ShardBoundary, RangeProofs>,
235 batch_size: usize,
236}
237
238impl CompressTree {
239 pub fn new(batch_size: usize) -> Self {
241 Self { map: BTreeMap::new(), batch_size }
242 }
243
244 fn insert(&mut self, proofs: RangeProofs) {
246 self.map.insert(proofs.shard_range.start(), proofs);
247 }
248
249 fn sibling(&mut self, proof: &RecursionProof) -> Option<Sibling> {
255 if let Some(previous) =
257 self.map.range(ShardBoundary::initial()..=proof.shard_range.start()).next_back()
258 {
259 let (start, proofs) = previous;
260 let start = *start;
261 let proofs = proofs.clone();
262
263 if proofs.shard_range.end() == proof.shard_range.start() {
264 let left = self.map.remove(&start).unwrap();
265 if let Some(right) = self.map.remove(&proof.shard_range.end()) {
267 return Some(Sibling::Both(left, right));
268 } else {
269 return Some(Sibling::Left(left));
270 }
271 }
272 }
273 if let Some(right) = self.map.remove(&proof.shard_range.end()) {
275 return Some(Sibling::Right(right));
276 }
277
278 None
280 }
281
282 fn is_complete(
283 &self,
284 range: &ShardRange,
285 pending_tasks: usize,
286 full_range: &Option<ShardRange>,
287 ) -> bool {
288 let is_range_equal = full_range.as_ref().is_some_and(|full| range == full);
289 tracing::debug!(
290 "Checking if complete: Pending tasks: {:?}, map is empty: {:?}, full range is some: {:?}, is_range_equal: {:?}",
291 pending_tasks,
292 self.map.is_empty(),
293 full_range.is_some(),
294 is_range_equal,
295 );
296 (pending_tasks == 0) && self.map.is_empty() && is_range_equal
297 }
298
299 pub async fn reduce_proofs(
323 &mut self,
324 context: TaskContext,
325 output: Artifact,
326 mut core_proofs_rx: MessageReceiver<ProofData>,
327 artifact_client: &impl ArtifactClient,
328 worker_client: &impl WorkerClient,
329 ) -> Result<(), TaskError> {
330 let (core_proofs_subscriber, mut core_proofs_event_stream) =
334 worker_client.subscriber(context.proof_id.clone()).await?.stream();
335 let core_proof_map = Arc::new(Mutex::new(HashMap::<TaskId, RecursionProof>::new()));
336 let mut full_range: Option<ShardRange> = None;
338 let mut max_range = ShardBoundary::initial()..ShardBoundary::initial();
340 let mut pending_tasks = 0;
342 let (proof_tx, mut proof_rx) = mpsc::unbounded_channel::<RecursionProof>();
344 let (subscriber, mut event_stream) =
346 worker_client.subscriber(context.proof_id.clone()).await?.stream();
347 let mut proof_map = HashMap::<TaskId, RecursionProof>::new();
348
349 let mut join_set = JoinSet::<Result<(), TaskError>>::new();
350
351 let (num_core_proofs_tx, mut num_core_proofs_rx) = mpsc::channel(1);
352 join_set.spawn({
354 let core_proof_map = core_proof_map.clone();
355 async move {
356 let mut num_core_proofs = 0;
357 while let Some(proof_data) = core_proofs_rx.recv().await {
358 core_proofs_subscriber
359 .subscribe(proof_data.task_id.clone())
360 .map_err(|e| TaskError::Fatal(e.into()))?;
361 let proof =
362 RecursionProof { shard_range: proof_data.range, proof: proof_data.proof };
363 core_proof_map.lock().unwrap().insert(proof_data.task_id, proof);
364 num_core_proofs += 1;
365 }
366 tracing::debug!(
367 "All core proofs received: number of core proofs: {:?}",
368 num_core_proofs
369 );
370 num_core_proofs_tx.send(num_core_proofs).await.ok();
371 Ok(())
372 }
373 .instrument(tracing::debug_span!("Core proof processing"))
374 });
375
376 let mut num_core_proofs_completed = 0;
377 let mut num_core_proofs: Option<usize> = None;
378 let mut last_core_proof = None;
379 loop {
380 tokio::select! {
381 Some(num_proofs) = num_core_proofs_rx.recv() => {
382 tracing::debug!("Number of core proofs completed: {:?}", num_proofs);
383 num_core_proofs = Some(num_proofs);
384 if num_core_proofs_completed == num_proofs {
387 tracing::debug!("All core proofs completed: {:?}", num_proofs);
388 full_range = Some(max_range.clone().into());
389 tracing::debug!("Setting full range to: {:?}", full_range);
390 if let Some(proof) = last_core_proof.take() {
393 proof_tx.send(proof).map_err(|_| TaskError::Fatal(anyhow::anyhow!("Compress tree panicked")))?;
394 }
395 }
396 }
397 Some(proof) = proof_rx.recv() => {
398 pending_tasks -= 1;
400 if self.is_complete(&proof.shard_range, pending_tasks, &full_range) {
401 return Ok(());
402 }
403 if let Some(sibling) = self.sibling(&proof) {
405 tracing::debug!("Found sibling");
406 let mut proofs = match sibling {
407 Sibling::Left(mut proofs) => {
408 proofs.push_left(proof);
409 proofs
410 }
411 Sibling::Right(mut proofs) => {
412 proofs.push_right(proof);
413 proofs
414 }
415 Sibling::Both(mut proofs, right) => {
416 proofs.push_both(proof, right);
417 proofs
418 }
419 };
420
421 let split = proofs.split_off(self.batch_size);
423 if let Some(split) = split {
424 self.insert(split);
425 }
426
427 if proofs.len() > self.batch_size {
428 tracing::error!("Proofs are larger than the batch size: {:?}", proofs.len());
429 panic!("Proofs are larger than the batch size: {:?}", proofs.len());
430 }
431
432 let is_complete = self.is_complete(&proofs.shard_range, pending_tasks, &full_range);
433 if proofs.len() == self.batch_size || is_complete {
434 let shard_range = proofs.shard_range;
435 let output_artifact = if is_complete { output.clone() } else { artifact_client.create_artifact()? };
437 let task_request = ReduceTaskRequest {
438 range_proofs: proofs,
439 is_complete,
440 output: output_artifact.clone(),
441 context: context.clone(),
442 };
443 let raw_task_request = task_request.into_raw()?;
444 let task_id = worker_client.submit_task(TaskType::RecursionReduce, raw_task_request).await?;
445 proof_map.insert(task_id.clone(), RecursionProof { shard_range, proof: output_artifact });
447 subscriber.subscribe(task_id).map_err(|_| TaskError::Fatal(anyhow::anyhow!("Subscriver closed")))?;
449 pending_tasks += 1;
451 } else {
452 self.insert(proofs);
453 }
454 } else {
455 tracing::debug!("No neighboring range found, adding proof to tree");
456 let mut queue = VecDeque::with_capacity(self.batch_size);
458 let range = proof.shard_range;
459 queue.push_back(proof);
460 let proofs = RangeProofs::new(range, queue);
461 self.insert(proofs);
462 }
463 }
464 Some((task_id, status)) = event_stream.recv() => {
465 if status != TaskStatus::Succeeded {
466 return Err(
467 TaskError::Fatal
468 (anyhow::anyhow!("Reduction task {} failed", task_id))
469 );
470 }
471 let proof = proof_map.remove(&task_id);
472 if let Some(proof) = proof {
473 proof_tx.send(proof).map_err(|_| TaskError::Fatal(anyhow::anyhow!("Compress tree panicked")))?;
475 }
476 else {
477 tracing::debug!("Proof not found for task id: {}", task_id);
478 }
479 }
480
481 Some((task_id, status)) = core_proofs_event_stream.recv() => {
482 if status != TaskStatus::Succeeded {
483 return Err(
484 TaskError::Fatal
485 (anyhow::anyhow!("Core proof task {} failed", task_id))
486 );
487 }
488 let normalize_proof = core_proof_map.lock().unwrap().remove(&task_id);
490 if let Some(normalize_proof) = normalize_proof {
491 let shard_range = &normalize_proof.shard_range;
492 let (start, end) = (shard_range.start(), shard_range.end());
493 if start < max_range.start {
494 max_range.start = start;
495 }
496 if end > max_range.end {
497 max_range.end = end;
498 }
499 let previous_core_proof = last_core_proof.take();
501 last_core_proof = Some(normalize_proof);
502 if let Some(proof) = previous_core_proof {
505 proof_tx.send(proof).map_err(|_| TaskError::Fatal(anyhow::anyhow!("Compress tree panicked")))?;
507 }
508
509 pending_tasks += 1;
511 num_core_proofs_completed += 1;
513 if let Some(num_core_proofs) = num_core_proofs {
516 if num_core_proofs_completed == num_core_proofs {
517 full_range = Some(max_range.clone().into());
518 tracing::debug!("Setting full range to: {:?}", full_range);
519 tracing::debug!("Sending last core proof to proof queue: {:?}", last_core_proof);
521 let last_core_proof = last_core_proof.take().unwrap();
522 proof_tx.send(last_core_proof).map_err(|_| TaskError::Fatal(anyhow::anyhow!("Compress tree panicked")))?;
523 core_proofs_event_stream.close();
525 }
526 }
527 } else {
528 tracing::debug!("Core proof not found for task id: {}", task_id);
529 }
530 }
531 else => {
532 break;
533 }
534 }
535 }
536
537 Err(TaskError::Fatal(anyhow::anyhow!("todo explain this")))
538 }
539}
540
#[cfg(test)]
mod test_utils {
    use std::time::Duration;

    use sp1_core_machine::utils::setup_logger;
    use sp1_prover_types::InMemoryArtifactClient;

    use crate::{
        shapes::DEFAULT_ARITY,
        worker::{test_utils::mock_worker_client, ProofId, ProveShardTaskRequest, RequesterId},
    };

    use super::*;

    /// Upper bound on one compress-tree run. A liveness bug then fails the test instead of
    /// hanging it.
    const REDUCE_TIMEOUT: Duration = Duration::from_secs(600);

    /// Submits a dummy `ProveShard` task for `range` and announces it on `core_proofs_tx` as a
    /// bincode-serialized [`ProofData`].
    async fn create_dummy_prove_shard_task(
        range: ShardRange,
        elf_artifact: Artifact,
        common_input_artifact: Artifact,
        context: TaskContext,
        core_proofs_tx: &mpsc::UnboundedSender<Vec<u8>>,
        worker_client: &impl WorkerClient,
        artifact_client: &impl ArtifactClient,
    ) {
        let record_artifact = artifact_client.create_artifact().unwrap();
        let proof_artifact = artifact_client.create_artifact().unwrap();

        let request = ProveShardTaskRequest {
            elf: elf_artifact.clone(),
            common_input: common_input_artifact.clone(),
            record: record_artifact,
            output: proof_artifact.clone(),
            deferred_marker_task: Artifact::from("dummy marker task".to_string()),
            deferred_output: Artifact::from("dummy output artifact".to_string()),
            context: context.clone(),
        };

        let task = request.into_raw().unwrap();

        let task_id = worker_client.submit_task(TaskType::ProveShard, task).await.unwrap();
        let proof_data = ProofData { task_id, range, proof: proof_artifact };
        let payload = bincode::serialize(&proof_data).unwrap();
        core_proofs_tx.send(payload).unwrap();
    }

    #[tokio::test]
    async fn test_compress_tree() {
        setup_logger();
        let num_core_shards = 200;
        let core_start_delay = Duration::from_millis(10);
        let num_memory_shards = 40;
        let memory_start_delay = Duration::from_millis(500);
        let num_precompile_shards = 20;
        let precompile_start_delay = Duration::from_millis(500);
        let num_deferred_shards = 100;
        let deferred_start_delay = Duration::from_millis(1);
        let num_iterations = 1;
        let random_intervals = HashMap::from([
            (TaskType::Controller, Duration::from_millis(20)..Duration::from_millis(100)),
            (TaskType::SetupVkey, Duration::from_millis(20)..Duration::from_millis(100)),
            (TaskType::RecursionReduce, Duration::from_millis(100)..Duration::from_millis(200)),
            (TaskType::ProveShard, Duration::from_millis(200)..Duration::from_millis(500)),
            (TaskType::MarkerDeferredRecord, Duration::from_millis(20)..Duration::from_millis(100)),
            (TaskType::RecursionDeferred, Duration::from_millis(20)..Duration::from_millis(100)),
            (TaskType::ShrinkWrap, Duration::from_millis(20)..Duration::from_millis(100)),
            (TaskType::PlonkWrap, Duration::from_millis(20)..Duration::from_millis(100)),
            (TaskType::Groth16Wrap, Duration::from_millis(20)..Duration::from_millis(100)),
            (TaskType::ExecuteOnly, Duration::from_millis(20)..Duration::from_millis(100)),
            (TaskType::CoreExecute, Duration::from_millis(20)..Duration::from_millis(100)),
        ]);

        for _ in 0..num_iterations {
            let worker_client = mock_worker_client(random_intervals.clone());

            let artifact_client = InMemoryArtifactClient::new();

            let mut compress_tree = CompressTree::new(DEFAULT_ARITY);

            let context = TaskContext {
                proof_id: ProofId::new("test_compress_tree"),
                parent_id: None,
                parent_context: None,
                requester_id: RequesterId::new("test_compress_tree"),
            };

            let (core_proofs_tx, core_proofs_rx_inner) = mpsc::unbounded_channel::<Vec<u8>>();
            let core_proofs_rx = MessageReceiver::<ProofData>::new(core_proofs_rx_inner);

            let elf_artifact = artifact_client.create_artifact().unwrap();
            let common_input_artifact = artifact_client.create_artifact().unwrap();

            let core_ranges: Vec<ShardRange> = (1..=num_core_shards)
                .map(|i| ShardRange {
                    timestamp_range: (i, i + 1),
                    initialized_address_range: (0, 0),
                    finalized_address_range: (0, 0),
                    initialized_page_index_range: (0, 0),
                    finalized_page_index_range: (0, 0),
                    deferred_proof_range: (num_deferred_shards, num_deferred_shards),
                })
                .collect();
            let memory_ranges: Vec<ShardRange> = (0..num_memory_shards)
                .map(|i| ShardRange {
                    timestamp_range: (num_core_shards + 1, num_core_shards + 1),
                    initialized_address_range: (i, i + 1),
                    finalized_address_range: (i, i + 1),
                    initialized_page_index_range: (0, 0),
                    finalized_page_index_range: (0, 0),
                    deferred_proof_range: (num_deferred_shards, num_deferred_shards),
                })
                .collect();
            let precompile_ranges = vec![ShardRange::precompile(); num_precompile_shards];
            let deferred_ranges: Vec<ShardRange> =
                (0..num_deferred_shards).map(|i| ShardRange::deferred(i, i + 1)).collect();

            // Each producer waits for its start delay, then submits its shards in order.
            let producers = [
                (core_start_delay, core_ranges),
                (memory_start_delay, memory_ranges),
                (precompile_start_delay, precompile_ranges),
                (deferred_start_delay, deferred_ranges),
            ];

            let mut producer_handles = Vec::with_capacity(producers.len());
            for (start_delay, ranges) in producers {
                let worker_client = worker_client.clone();
                let artifact_client = artifact_client.clone();
                let elf_artifact = elf_artifact.clone();
                let common_input_artifact = common_input_artifact.clone();
                let context = context.clone();
                let core_proofs_tx = core_proofs_tx.clone();
                producer_handles.push(tokio::task::spawn(async move {
                    tokio::time::sleep(start_delay).await;
                    for range in ranges {
                        create_dummy_prove_shard_task(
                            range,
                            elf_artifact.clone(),
                            common_input_artifact.clone(),
                            context.clone(),
                            &core_proofs_tx,
                            &worker_client,
                            &artifact_client,
                        )
                        .await;
                    }
                }));
            }
            // The core proof stream ends only once every sender is dropped. Drop ours so the
            // tree can learn the total number of core proofs.
            drop(core_proofs_tx);

            let output = artifact_client.create_artifact().unwrap();

            tokio::time::timeout(
                REDUCE_TIMEOUT,
                compress_tree.reduce_proofs(
                    context,
                    output,
                    core_proofs_rx,
                    &artifact_client,
                    &worker_client,
                ),
            )
            .await
            .expect("compress tree did not finish before the timeout")
            .unwrap();

            // A panicking producer would silently truncate the input; make that a test failure.
            for handle in producer_handles {
                handle.await.expect("shard producer task panicked");
            }
        }
    }
}