Skip to main content

sp1_prover/worker/client/
local.rs

1use std::{collections::BTreeMap, sync::Arc};
2
3use hashbrown::{HashMap, HashSet};
4use mti::prelude::{MagicTypeIdExt, V7};
5use sp1_prover_types::{ProofRequestStatus, TaskStatus, TaskType};
6use tokio::sync::{mpsc, watch, RwLock};
7
8use crate::worker::{
9    ProofId, RawTaskRequest, SubscriberBuilder, TaskId, TaskMetadata, WorkerClient,
10};
11
12struct MessageChannelState {
13    tx: mpsc::UnboundedSender<Vec<u8>>,
14    rx: Option<mpsc::UnboundedReceiver<Vec<u8>>>,
15}
16
17type LocalDb =
18    Arc<RwLock<HashMap<TaskId, (watch::Sender<TaskStatus>, watch::Receiver<TaskStatus>)>>>;
19
20type ProofIndex = Arc<RwLock<HashMap<ProofId, HashSet<TaskId>>>>;
21
22pub struct LocalWorkerClientChannels {
23    pub task_receivers: BTreeMap<TaskType, mpsc::Receiver<(TaskId, RawTaskRequest)>>,
24}
25
26pub struct LocalWorkerClientInner {
27    db: LocalDb,
28    proof_index: ProofIndex,
29    input_task_queues: HashMap<TaskType, mpsc::Sender<(TaskId, RawTaskRequest)>>,
30    task_channels: RwLock<HashMap<TaskId, MessageChannelState>>,
31}
32
33impl LocalWorkerClientInner {
34    fn create_id() -> TaskId {
35        TaskId::new("local_worker".create_type_id::<V7>().to_string())
36    }
37
38    fn init() -> (Self, LocalWorkerClientChannels) {
39        let mut task_outputs = BTreeMap::new();
40        let mut task_queues = HashMap::new();
41        for task_type in [
42            TaskType::UnspecifiedTaskType,
43            TaskType::Controller,
44            TaskType::ProveShard,
45            TaskType::RecursionReduce,
46            TaskType::RecursionDeferred,
47            TaskType::ShrinkWrap,
48            TaskType::SetupVkey,
49            TaskType::MarkerDeferredRecord,
50            TaskType::PlonkWrap,
51            TaskType::Groth16Wrap,
52            TaskType::ExecuteOnly,
53            TaskType::UtilVkeyMapChunk,
54            TaskType::UtilVkeyMapController,
55            TaskType::CoreExecute,
56        ] {
57            let (tx, rx) = mpsc::channel(1);
58            task_outputs.insert(task_type, rx);
59            task_queues.insert(task_type, tx);
60        }
61
62        let db = Arc::new(RwLock::new(HashMap::new()));
63        let proof_index = Arc::new(RwLock::new(HashMap::new()));
64        let task_channels = RwLock::new(HashMap::new());
65        let inner = Self { db, proof_index, input_task_queues: task_queues, task_channels };
66        (inner, LocalWorkerClientChannels { task_receivers: task_outputs })
67    }
68}
69
70pub struct LocalWorkerClient {
71    inner: Arc<LocalWorkerClientInner>,
72}
73
74impl LocalWorkerClient {
75    /// Creates a new local worker client.
76    #[must_use]
77    pub fn init() -> (Self, LocalWorkerClientChannels) {
78        let (inner, channels) = LocalWorkerClientInner::init();
79        (Self { inner: Arc::new(inner) }, channels)
80    }
81
82    pub async fn update_task_status(
83        &self,
84        task_id: TaskId,
85        status: TaskStatus,
86    ) -> anyhow::Result<()> {
87        // Get the sender for this task
88        let (status_tx, _) = self
89            .inner
90            .db
91            .read()
92            .await
93            .get(&task_id)
94            .cloned()
95            .ok_or_else(|| anyhow::anyhow!("task does not exist"))?;
96
97        status_tx.send(status).map_err(|_| anyhow::anyhow!("failed to send status to task"))?;
98
99        if matches!(
100            status,
101            TaskStatus::Succeeded | TaskStatus::FailedFatal | TaskStatus::FailedRetryable
102        ) {
103            self.inner.task_channels.write().await.remove(&task_id);
104        }
105
106        Ok(())
107    }
108}
109
110impl Clone for LocalWorkerClient {
111    fn clone(&self) -> Self {
112        Self { inner: self.inner.clone() }
113    }
114}
115
116impl WorkerClient for LocalWorkerClient {
117    async fn submit_task(&self, kind: TaskType, task: RawTaskRequest) -> anyhow::Result<TaskId> {
118        tracing::debug!("submitting task of kind {kind:?}");
119        let task_id = LocalWorkerClientInner::create_id();
120        // Add the task to the proof index.
121        self.inner
122            .proof_index
123            .write()
124            .await
125            .entry(task.context.proof_id.clone())
126            .or_insert_with(HashSet::new)
127            .insert(task_id.clone());
128        // Create a db entry for the task.
129        let (tx, rx) = watch::channel(TaskStatus::Pending);
130        self.inner.db.write().await.insert(task_id.clone(), (tx, rx));
131        // Send the task to the input queue.
132        self.inner.input_task_queues[&kind]
133            .send((task_id.clone(), task))
134            .await
135            .map_err(|e| anyhow::anyhow!("failed to send task of kind {:?} to queue: {e}", kind))?;
136        Ok(task_id)
137    }
138
139    async fn complete_task(
140        &self,
141        _proof_id: ProofId,
142        task_id: TaskId,
143        _metadata: TaskMetadata,
144    ) -> anyhow::Result<()> {
145        self.update_task_status(task_id, TaskStatus::Succeeded).await
146    }
147
148    async fn complete_proof(
149        &self,
150        proof_id: ProofId,
151        _task_id: Option<TaskId>,
152        _status: ProofRequestStatus,
153        _extra_data: impl Into<String> + Send,
154    ) -> anyhow::Result<()> {
155        // Remove the proof from the proof index.
156        let tasks = self
157            .inner
158            .proof_index
159            .write()
160            .await
161            .remove(&proof_id)
162            .ok_or_else(|| anyhow::anyhow!("proof does not exist for id {proof_id}"))?;
163        // Prune the db for all tasks that are related to this proof and clean them up.
164        for task_id in tasks {
165            self.inner.db.write().await.remove(&task_id);
166        }
167        Ok(())
168    }
169
170    async fn subscriber(&self, _proof_id: ProofId) -> anyhow::Result<SubscriberBuilder<Self>> {
171        let (subscriber_input_tx, mut subscriber_input_rx) = mpsc::unbounded_channel();
172        let (subscriber_output_tx, subscriber_output_rx) = mpsc::unbounded_channel();
173
174        tokio::task::spawn({
175            let db = self.inner.db.clone();
176            let output_tx = subscriber_output_tx.clone();
177            async move {
178                while let Some(id) = subscriber_input_rx.recv().await {
179                    // Spawn a task to send the status to the output channel.
180                    let db = db.clone();
181                    let output_tx = output_tx.clone();
182                    tokio::task::spawn(async move {
183                        let (_, mut rx) =
184                            db.read().await.get(&id).cloned().expect("task does not exist");
185                        rx.mark_changed();
186                        while let Ok(()) = rx.changed().await {
187                            let value = *rx.borrow();
188                            if matches!(
189                                value,
190                                TaskStatus::FailedFatal
191                                    | TaskStatus::FailedRetryable
192                                    | TaskStatus::Succeeded
193                            ) {
194                                output_tx.send((id, value)).ok();
195                                return;
196                            }
197                        }
198                    });
199                }
200            }
201        });
202        Ok(SubscriberBuilder::new(self.clone(), subscriber_input_tx, subscriber_output_rx))
203    }
204
205    async fn subscribe_task_messages(
206        &self,
207        task_id: &TaskId,
208    ) -> anyhow::Result<mpsc::UnboundedReceiver<Vec<u8>>> {
209        let mut channels = self.inner.task_channels.write().await;
210        if let Some(state) = channels.get_mut(task_id) {
211            let rx = state
212                .rx
213                .take()
214                .ok_or_else(|| anyhow::anyhow!("task channel already subscribed for {task_id}"))?;
215            return Ok(rx);
216        }
217        let (tx, rx) = mpsc::unbounded_channel();
218        channels.insert(task_id.clone(), MessageChannelState { tx, rx: None });
219        Ok(rx)
220    }
221
222    async fn send_task_message(&self, task_id: &TaskId, payload: Vec<u8>) -> anyhow::Result<()> {
223        let mut channels = self.inner.task_channels.write().await;
224        if let Some(state) = channels.get_mut(task_id) {
225            state.tx.send(payload).map_err(|_| anyhow::anyhow!("task channel receiver dropped"))?;
226        } else {
227            let (tx, rx) = mpsc::unbounded_channel();
228            tx.send(payload).expect("just-created channel cannot be closed");
229            channels.insert(task_id.clone(), MessageChannelState { tx, rx: Some(rx) });
230        }
231        Ok(())
232    }
233}
234
#[cfg(test)]
pub mod test_utils {
    use std::{ops::Range, time::Duration};

    use rand::Rng;

    use super::*;

    /// Builds a [`LocalWorkerClient`] whose task queues are drained by mock workers.
    ///
    /// For each task type listed below, a background task receives submitted tasks. For each
    /// task it sleeps for a duration sampled from that type's range in `random_interval`, then
    /// marks the task as succeeded. Task types not listed (e.g. `UtilVkeyMapChunk`) get no
    /// worker. Their receivers are dropped when this function returns, so submitting such a
    /// task fails.
    ///
    /// # Panics
    /// Panics if `random_interval` lacks an entry for one of the listed task types. A mock
    /// worker panics if completing a task fails (e.g. its proof was already completed).
    pub fn mock_worker_client(
        mut random_interval: HashMap<TaskType, Range<Duration>>,
    ) -> LocalWorkerClient {
        let (worker_client, mut channels) = LocalWorkerClient::init();

        for task_type in [
            TaskType::Controller,
            TaskType::SetupVkey,
            TaskType::ProveShard,
            TaskType::MarkerDeferredRecord,
            TaskType::RecursionReduce,
            TaskType::RecursionDeferred,
            TaskType::ShrinkWrap,
            TaskType::PlonkWrap,
            TaskType::Groth16Wrap,
            TaskType::ExecuteOnly,
            TaskType::CoreExecute,
        ] {
            let mut rx = channels.task_receivers.remove(&task_type).unwrap_or_else(|| {
                panic!("LocalWorkerClient::init created no queue for task type {task_type:?}")
            });
            let interval = random_interval
                .remove(&task_type)
                .unwrap_or_else(|| panic!("missing mock interval for task type {task_type:?}"));
            let worker_client = worker_client.clone();
            tokio::task::spawn(async move {
                while let Some((task_id, request)) = rx.recv().await {
                    let client = worker_client.clone();
                    let interval = interval.clone();
                    tokio::spawn(async move {
                        // Keep the thread-local RNG out of scope across the `.await` below.
                        let duration = {
                            let mut rng = rand::thread_rng();
                            rng.gen_range(interval)
                        };
                        tokio::time::sleep(duration).await;
                        client
                            .complete_task(
                                request.context.proof_id,
                                task_id,
                                TaskMetadata { gpu_ms: None },
                            )
                            .await
                            .expect("mock worker failed to complete task");
                    });
                }
            });
        }

        worker_client
    }
}