sp1_prover/worker/client/
local.rs1use std::{collections::BTreeMap, sync::Arc};
2
3use hashbrown::{HashMap, HashSet};
4use mti::prelude::{MagicTypeIdExt, V7};
5use sp1_prover_types::{ProofRequestStatus, TaskStatus, TaskType};
6use tokio::sync::{mpsc, watch, RwLock};
7
8use crate::worker::{
9 ProofId, RawTaskRequest, SubscriberBuilder, TaskId, TaskMetadata, WorkerClient,
10};
11
/// Per-task byte-message channel.
///
/// `tx` is used by `send_task_message` to push payloads; `rx` holds the
/// receiver until a subscriber claims it via `subscribe_task_messages`
/// (after which it is `None`, so a second subscription fails).
struct MessageChannelState {
    tx: mpsc::UnboundedSender<Vec<u8>>,
    rx: Option<mpsc::UnboundedReceiver<Vec<u8>>>,
}
16
/// Shared map from task id to the watch pair used to broadcast that task's
/// status updates. The receiver is kept so late subscribers can be cloned one.
type LocalDb =
    Arc<RwLock<HashMap<TaskId, (watch::Sender<TaskStatus>, watch::Receiver<TaskStatus>)>>>;

/// Shared map from proof id to the set of task ids submitted for that proof.
type ProofIndex = Arc<RwLock<HashMap<ProofId, HashSet<TaskId>>>>;
21
/// Receiver halves of the per-task-type queues, returned by
/// [`LocalWorkerClient::init`] and consumed by the local workers that
/// actually execute submitted tasks.
pub struct LocalWorkerClientChannels {
    pub task_receivers: BTreeMap<TaskType, mpsc::Receiver<(TaskId, RawTaskRequest)>>,
}
25
/// Shared state behind a `LocalWorkerClient`.
pub struct LocalWorkerClientInner {
    // Status watch channels, keyed by task id.
    db: LocalDb,
    // Task ids grouped by the proof they belong to.
    proof_index: ProofIndex,
    // Sender halves matching `LocalWorkerClientChannels::task_receivers`.
    input_task_queues: HashMap<TaskType, mpsc::Sender<(TaskId, RawTaskRequest)>>,
    // Byte-message channels per task, created lazily on subscribe/send.
    task_channels: RwLock<HashMap<TaskId, MessageChannelState>>,
}
32
33impl LocalWorkerClientInner {
34 fn create_id() -> TaskId {
35 TaskId::new("local_worker".create_type_id::<V7>().to_string())
36 }
37
38 fn init() -> (Self, LocalWorkerClientChannels) {
39 let mut task_outputs = BTreeMap::new();
40 let mut task_queues = HashMap::new();
41 for task_type in [
42 TaskType::UnspecifiedTaskType,
43 TaskType::Controller,
44 TaskType::ProveShard,
45 TaskType::RecursionReduce,
46 TaskType::RecursionDeferred,
47 TaskType::ShrinkWrap,
48 TaskType::SetupVkey,
49 TaskType::MarkerDeferredRecord,
50 TaskType::PlonkWrap,
51 TaskType::Groth16Wrap,
52 TaskType::ExecuteOnly,
53 TaskType::UtilVkeyMapChunk,
54 TaskType::UtilVkeyMapController,
55 TaskType::CoreExecute,
56 ] {
57 let (tx, rx) = mpsc::channel(1);
58 task_outputs.insert(task_type, rx);
59 task_queues.insert(task_type, tx);
60 }
61
62 let db = Arc::new(RwLock::new(HashMap::new()));
63 let proof_index = Arc::new(RwLock::new(HashMap::new()));
64 let task_channels = RwLock::new(HashMap::new());
65 let inner = Self { db, proof_index, input_task_queues: task_queues, task_channels };
66 (inner, LocalWorkerClientChannels { task_receivers: task_outputs })
67 }
68}
69
/// In-process `WorkerClient` implementation backed by tokio channels.
/// Cheap to clone: all state lives behind a single `Arc`.
pub struct LocalWorkerClient {
    inner: Arc<LocalWorkerClientInner>,
}
73
74impl LocalWorkerClient {
75 #[must_use]
77 pub fn init() -> (Self, LocalWorkerClientChannels) {
78 let (inner, channels) = LocalWorkerClientInner::init();
79 (Self { inner: Arc::new(inner) }, channels)
80 }
81
82 pub async fn update_task_status(
83 &self,
84 task_id: TaskId,
85 status: TaskStatus,
86 ) -> anyhow::Result<()> {
87 let (status_tx, _) = self
89 .inner
90 .db
91 .read()
92 .await
93 .get(&task_id)
94 .cloned()
95 .ok_or_else(|| anyhow::anyhow!("task does not exist"))?;
96
97 status_tx.send(status).map_err(|_| anyhow::anyhow!("failed to send status to task"))?;
98
99 if matches!(
100 status,
101 TaskStatus::Succeeded | TaskStatus::FailedFatal | TaskStatus::FailedRetryable
102 ) {
103 self.inner.task_channels.write().await.remove(&task_id);
104 }
105
106 Ok(())
107 }
108}
109
// Manual impl: `#[derive(Clone)]` would add a spurious
// `LocalWorkerClientInner: Clone` bound, but cloning the `Arc` only bumps the
// refcount regardless of the inner type.
impl Clone for LocalWorkerClient {
    fn clone(&self) -> Self {
        Self { inner: self.inner.clone() }
    }
}
115
116impl WorkerClient for LocalWorkerClient {
117 async fn submit_task(&self, kind: TaskType, task: RawTaskRequest) -> anyhow::Result<TaskId> {
118 tracing::debug!("submitting task of kind {kind:?}");
119 let task_id = LocalWorkerClientInner::create_id();
120 self.inner
122 .proof_index
123 .write()
124 .await
125 .entry(task.context.proof_id.clone())
126 .or_insert_with(HashSet::new)
127 .insert(task_id.clone());
128 let (tx, rx) = watch::channel(TaskStatus::Pending);
130 self.inner.db.write().await.insert(task_id.clone(), (tx, rx));
131 self.inner.input_task_queues[&kind]
133 .send((task_id.clone(), task))
134 .await
135 .map_err(|e| anyhow::anyhow!("failed to send task of kind {:?} to queue: {e}", kind))?;
136 Ok(task_id)
137 }
138
139 async fn complete_task(
140 &self,
141 _proof_id: ProofId,
142 task_id: TaskId,
143 _metadata: TaskMetadata,
144 ) -> anyhow::Result<()> {
145 self.update_task_status(task_id, TaskStatus::Succeeded).await
146 }
147
148 async fn complete_proof(
149 &self,
150 proof_id: ProofId,
151 _task_id: Option<TaskId>,
152 _status: ProofRequestStatus,
153 _extra_data: impl Into<String> + Send,
154 ) -> anyhow::Result<()> {
155 let tasks = self
157 .inner
158 .proof_index
159 .write()
160 .await
161 .remove(&proof_id)
162 .ok_or_else(|| anyhow::anyhow!("proof does not exist for id {proof_id}"))?;
163 for task_id in tasks {
165 self.inner.db.write().await.remove(&task_id);
166 }
167 Ok(())
168 }
169
170 async fn subscriber(&self, _proof_id: ProofId) -> anyhow::Result<SubscriberBuilder<Self>> {
171 let (subscriber_input_tx, mut subscriber_input_rx) = mpsc::unbounded_channel();
172 let (subscriber_output_tx, subscriber_output_rx) = mpsc::unbounded_channel();
173
174 tokio::task::spawn({
175 let db = self.inner.db.clone();
176 let output_tx = subscriber_output_tx.clone();
177 async move {
178 while let Some(id) = subscriber_input_rx.recv().await {
179 let db = db.clone();
181 let output_tx = output_tx.clone();
182 tokio::task::spawn(async move {
183 let (_, mut rx) =
184 db.read().await.get(&id).cloned().expect("task does not exist");
185 rx.mark_changed();
186 while let Ok(()) = rx.changed().await {
187 let value = *rx.borrow();
188 if matches!(
189 value,
190 TaskStatus::FailedFatal
191 | TaskStatus::FailedRetryable
192 | TaskStatus::Succeeded
193 ) {
194 output_tx.send((id, value)).ok();
195 return;
196 }
197 }
198 });
199 }
200 }
201 });
202 Ok(SubscriberBuilder::new(self.clone(), subscriber_input_tx, subscriber_output_rx))
203 }
204
205 async fn subscribe_task_messages(
206 &self,
207 task_id: &TaskId,
208 ) -> anyhow::Result<mpsc::UnboundedReceiver<Vec<u8>>> {
209 let mut channels = self.inner.task_channels.write().await;
210 if let Some(state) = channels.get_mut(task_id) {
211 let rx = state
212 .rx
213 .take()
214 .ok_or_else(|| anyhow::anyhow!("task channel already subscribed for {task_id}"))?;
215 return Ok(rx);
216 }
217 let (tx, rx) = mpsc::unbounded_channel();
218 channels.insert(task_id.clone(), MessageChannelState { tx, rx: None });
219 Ok(rx)
220 }
221
222 async fn send_task_message(&self, task_id: &TaskId, payload: Vec<u8>) -> anyhow::Result<()> {
223 let mut channels = self.inner.task_channels.write().await;
224 if let Some(state) = channels.get_mut(task_id) {
225 state.tx.send(payload).map_err(|_| anyhow::anyhow!("task channel receiver dropped"))?;
226 } else {
227 let (tx, rx) = mpsc::unbounded_channel();
228 tx.send(payload).expect("just-created channel cannot be closed");
229 channels.insert(task_id.clone(), MessageChannelState { tx, rx: Some(rx) });
230 }
231 Ok(())
232 }
233}
234
#[cfg(test)]
pub mod test_utils {
    use std::{ops::Range, time::Duration};

    use rand::Rng;

    use super::*;

    /// Spawns a mock worker per task type that "completes" every received
    /// task after a random delay drawn from its configured interval.
    pub fn mock_worker_client(
        mut random_interval: HashMap<TaskType, Range<Duration>>,
    ) -> LocalWorkerClient {
        let (worker_client, mut channels) = LocalWorkerClient::init();

        for task_type in [
            TaskType::Controller,
            TaskType::SetupVkey,
            TaskType::ProveShard,
            TaskType::MarkerDeferredRecord,
            TaskType::RecursionReduce,
            TaskType::RecursionDeferred,
            TaskType::ShrinkWrap,
            TaskType::PlonkWrap,
            TaskType::Groth16Wrap,
            TaskType::ExecuteOnly,
            TaskType::CoreExecute,
        ] {
            let mut task_rx = channels.task_receivers.remove(&task_type).unwrap();
            let delay_range = random_interval.remove(&task_type).unwrap();
            let client = worker_client.clone();
            tokio::task::spawn(async move {
                // One completer task per received item so a slow "task" never
                // blocks the queue.
                while let Some((task_id, request)) = task_rx.recv().await {
                    let completer = client.clone();
                    let delay_range = delay_range.clone();
                    tokio::spawn(async move {
                        // Scope the RNG so the non-Send ThreadRng is dropped
                        // before the await point.
                        let delay = {
                            let mut rng = rand::thread_rng();
                            rng.gen_range(delay_range)
                        };
                        tokio::time::sleep(delay).await;
                        completer
                            .complete_task(
                                request.context.proof_id,
                                task_id,
                                TaskMetadata { gpu_ms: None },
                            )
                            .await
                            .unwrap();
                    });
                }
            });
        }

        worker_client
    }
}