1use core::fmt;
2use std::{
3 future::Future,
4 pin::Pin,
5 sync::{Arc, Mutex},
6 task::Poll,
7};
8
9use futures::{prelude::*, stream::FuturesOrdered};
10use hashbrown::{HashMap, HashSet};
11use mti::prelude::{MagicTypeIdExt, V7};
12use opentelemetry::Context;
13use serde::{de::DeserializeOwned, Deserialize, Serialize};
14use sp1_prover_types::{
15 Artifact, ArtifactClient, ArtifactType, ProofRequestStatus, TaskStatus, TaskType,
16};
17use thiserror::Error;
18use tokio::{
19 sync::{mpsc, watch, RwLock},
20 task::AbortHandle,
21};
22
23mod local;
24
25pub use local::*;
26
27use crate::worker::{ProveShardTaskRequest, TaskError};
28
29pub trait WorkerClient: Send + Sync + Clone + 'static {
30 fn submit_task(
31 &self,
32 kind: TaskType,
33 task: RawTaskRequest,
34 ) -> impl Future<Output = anyhow::Result<TaskId>> + Send;
35
36 fn complete_task(
37 &self,
38 proof_id: ProofId,
39 task_id: TaskId,
40 metadata: TaskMetadata,
41 ) -> impl Future<Output = anyhow::Result<()>> + Send;
42
43 fn complete_proof(
44 &self,
45 proof_id: ProofId,
46 task_id: Option<TaskId>,
47 status: ProofRequestStatus,
48 extra_data: impl Into<String> + Send,
49 ) -> impl Future<Output = anyhow::Result<()>> + Send;
50
51 fn subscriber(
52 &self,
53 proof_id: ProofId,
54 ) -> impl Future<Output = anyhow::Result<SubscriberBuilder<Self>>> + Send;
55
56 fn subscribe_task_messages(
59 &self,
60 task_id: &TaskId,
61 ) -> impl Future<Output = anyhow::Result<mpsc::UnboundedReceiver<Vec<u8>>>> + Send;
62
63 fn send_task_message(
66 &self,
67 task_id: &TaskId,
68 payload: Vec<u8>,
69 ) -> impl Future<Output = anyhow::Result<()>> + Send;
70
71 fn submit_tasks(
72 &self,
73 kind: TaskType,
74 tasks: impl IntoIterator<Item = RawTaskRequest> + Send,
75 ) -> impl Future<Output = anyhow::Result<Vec<TaskId>>> + Send {
76 tasks
77 .into_iter()
78 .map(move |task| self.submit_task(kind, task))
79 .collect::<FuturesOrdered<_>>()
80 .try_collect()
81 }
82
83 fn submit_all(
84 &self,
85 kind: TaskType,
86 tasks: impl Stream<Item = RawTaskRequest> + Send,
87 ) -> impl Future<Output = anyhow::Result<Vec<TaskId>>> + Send {
88 tasks.then(move |task| self.submit_task(kind, task)).try_collect()
89 }
90}
91
92pub struct MessageReceiver<T: DeserializeOwned> {
94 rx: mpsc::UnboundedReceiver<Vec<u8>>,
95 _marker: std::marker::PhantomData<T>,
96}
97
98impl<T: DeserializeOwned> MessageReceiver<T> {
99 pub fn new(rx: mpsc::UnboundedReceiver<Vec<u8>>) -> Self {
100 Self { rx, _marker: std::marker::PhantomData }
101 }
102
103 pub async fn recv(&mut self) -> Option<T> {
105 let bytes = self.rx.recv().await?;
106 Some(bincode::deserialize(&bytes).expect("failed to deserialize message channel payload"))
107 }
108}
109
110#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, PartialOrd, Ord)]
111pub struct ProofId(String);
112
113impl ProofId {
114 #[inline]
115 pub fn new(id: impl Into<String>) -> Self {
116 Self(id.into())
117 }
118}
119
120impl fmt::Display for ProofId {
121 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
122 write!(f, "{}", self.0) }
125}
126
127#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, PartialOrd, Ord)]
128pub struct TaskId(String);
129
130impl TaskId {
131 #[inline]
132 pub fn new(id: impl Into<String>) -> Self {
133 Self(id.into())
134 }
135}
136
137impl fmt::Display for TaskId {
138 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
139 write!(f, "{}", self.0) }
142}
143
144#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, PartialOrd, Ord)]
145pub struct RequesterId(String);
146
147impl RequesterId {
148 #[inline]
149 pub fn new(id: impl Into<String>) -> Self {
150 Self(id.into())
151 }
152}
153
154impl fmt::Display for RequesterId {
155 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
156 write!(f, "{}", self.0)
157 }
158}
159
/// Untyped description of a task: the artifacts it consumes and produces, plus its context.
#[derive(Clone)]
pub struct RawTaskRequest {
    /// Artifacts the task reads.
    pub inputs: Vec<Artifact>,
    /// Artifacts the task is expected to write.
    pub outputs: Vec<Artifact>,
    /// Identifiers describing where the task belongs (proof, parent task, requester).
    pub context: TaskContext,
}
166
/// Identifying information attached to every task request.
#[derive(Clone)]
pub struct TaskContext {
    /// Proof the task belongs to.
    pub proof_id: ProofId,
    /// Task that spawned this one, if any.
    pub parent_id: Option<TaskId>,
    /// OpenTelemetry context of the parent, if any.
    // NOTE(review): presumably used to link tracing spans across tasks; confirm with callers.
    pub parent_context: Option<Context>,
    /// Requester on whose behalf the task runs.
    pub requester_id: RequesterId,
}
174
/// Optional measurements reported alongside a completed task.
#[derive(Debug, Serialize, Deserialize, Default)]
pub struct TaskMetadata {
    // NOTE(review): presumably the GPU time spent on the task, in milliseconds; confirm.
    pub gpu_ms: Option<u64>,
}
179
180pub struct SubscriberBuilder<W> {
181 client: W,
182 subscriber_tx: mpsc::UnboundedSender<TaskId>,
183 subscriber_rx: mpsc::UnboundedReceiver<(TaskId, TaskStatus)>,
184}
185
186impl<W> SubscriberBuilder<W> {
187 pub fn new(
188 client: W,
189 subscriber_tx: mpsc::UnboundedSender<TaskId>,
190 subscriber_rx: mpsc::UnboundedReceiver<(TaskId, TaskStatus)>,
191 ) -> Self {
192 Self { client, subscriber_tx, subscriber_rx }
193 }
194
195 pub fn per_task(self) -> TaskSubscriber<W> {
196 TaskSubscriber::new(self)
197 }
198
199 pub fn stream(self) -> (StreamSubscriber<W>, EventStream) {
200 StreamSubscriber::new(self)
201 }
202}
203
/// Shared, lock-protected map from a task id to the `watch` channel pair that carries that
/// task's most recent status.
type TaskSubscriberDb =
    Arc<RwLock<HashMap<TaskId, (watch::Sender<TaskStatus>, watch::Receiver<TaskStatus>)>>>;
206
207#[derive(Clone)]
209#[allow(clippy::type_complexity)]
210pub struct TaskSubscriber<W> {
211 client: W,
212 request_map: TaskSubscriberDb,
213 subscriber_tx: mpsc::UnboundedSender<TaskId>,
214 abort_handle: AbortHandle,
215}
216
217impl<W> TaskSubscriber<W> {
218 #[inline]
220 pub const fn client(&self) -> &W {
221 &self.client
222 }
223
224 pub fn new(builder: SubscriberBuilder<W>) -> Self {
226 let SubscriberBuilder { client, subscriber_tx, mut subscriber_rx, .. } = builder;
227 let request_map = Arc::new(RwLock::new(HashMap::<
229 TaskId,
230 (watch::Sender<TaskStatus>, watch::Receiver<TaskStatus>),
231 >::new()));
232 let handle = tokio::task::spawn({
234 let request_map = request_map.clone();
235 async move {
236 while let Some((task_id, status)) = subscriber_rx.recv().await {
237 let (sender, _) = request_map
239 .read()
240 .await
241 .get(&task_id)
242 .cloned()
243 .expect("task should be in request map");
244 sender.send(status).ok();
246 }
247 }
248 });
249 let abort_handle = handle.abort_handle();
250
251 Self { client, request_map, subscriber_tx, abort_handle }
252 }
253
254 pub fn close(&self) {
258 self.abort_handle.abort();
259 }
260
261 pub async fn wait_task(&self, task_id: TaskId) -> Result<TaskStatus, TaskError> {
265 self.request_map
266 .write()
267 .await
268 .entry(task_id.clone())
269 .or_insert_with(|| watch::channel(TaskStatus::UnspecifiedStatus));
270
271 let (_, mut watch) = self
272 .request_map
273 .read()
274 .await
275 .get(&task_id)
276 .cloned()
277 .ok_or(TaskError::Fatal(anyhow::anyhow!("task does not exist")))?;
278
279 self.subscriber_tx.send(task_id.clone()).map_err(|e| {
281 TaskError::Fatal(anyhow::anyhow!("failed to send task id to inner subscriber: {e}"))
282 })?;
283
284 watch.mark_changed();
285 while let Ok(()) = watch.changed().await {
286 let v = *watch.borrow();
287 if matches!(
288 v,
289 TaskStatus::FailedFatal | TaskStatus::FailedRetryable | TaskStatus::Succeeded
290 ) {
291 return Ok(v);
292 }
293 }
294 Err(TaskError::Fatal(anyhow::anyhow!("task status lost for task {task_id}")))
295 }
296}
297
298#[derive(Debug, Error)]
299#[error("failed to subscribe to task {0}")]
300pub struct SubscribeError(#[from] mpsc::error::SendError<TaskId>);
301
302#[derive(Clone)]
304pub struct StreamSubscriber<W> {
305 client: W,
306 subscriber_tx: mpsc::UnboundedSender<TaskId>,
307}
308
309impl<W> StreamSubscriber<W> {
310 #[inline]
312 pub const fn client(&self) -> &W {
313 &self.client
314 }
315
316 fn new(builder: SubscriberBuilder<W>) -> (Self, EventStream) {
318 let SubscriberBuilder { client, subscriber_tx, subscriber_rx, .. } = builder;
319 (Self { client, subscriber_tx }, EventStream { subscriber_rx })
320 }
321
322 pub fn subscribe(&self, task_id: TaskId) -> Result<(), SubscribeError> {
323 self.subscriber_tx.send(task_id)?;
324 Ok(())
325 }
326}
327
328pub struct EventStream {
329 subscriber_rx: mpsc::UnboundedReceiver<(TaskId, TaskStatus)>,
330}
331
332impl EventStream {
333 pub async fn recv(&mut self) -> Option<(TaskId, TaskStatus)> {
334 self.subscriber_rx.recv().await
335 }
336
337 pub fn blocking_recv(&mut self) -> Option<(TaskId, TaskStatus)> {
338 self.subscriber_rx.blocking_recv()
339 }
340
341 pub fn close(&mut self) {
342 self.subscriber_rx.close();
343 }
344}
345
346impl Stream for EventStream {
347 type Item = (TaskId, TaskStatus);
348
349 fn poll_next(
350 mut self: Pin<&mut Self>,
351 cx: &mut std::task::Context<'_>,
352 ) -> Poll<Option<Self::Item>> {
353 self.subscriber_rx.poll_recv(cx)
354 }
355}
356
/// Per-task message channel state kept by `TrivialWorkerClient`.
struct TrivialMessageChannel {
    /// Sending half used to push payloads for the task.
    tx: mpsc::UnboundedSender<Vec<u8>>,
    /// Receiving half, parked here until a subscriber takes it; `None` once it has been
    /// handed out.
    rx: Option<mpsc::UnboundedReceiver<Vec<u8>>>,
}
361
362#[derive(Clone)]
364pub struct TrivialWorkerClient {
365 inner: Arc<Mutex<HashSet<TaskId>>>,
366 task_sender: mpsc::Sender<(TaskType, RawTaskRequest)>,
367 task_channels: Arc<Mutex<HashMap<TaskId, TrivialMessageChannel>>>,
368}
369
370impl TrivialWorkerClient {
371 pub fn new<A: ArtifactClient>(task_capacity: usize, artifact_client: A) -> Self {
372 let (task_sender, mut task_receiver) =
373 mpsc::channel::<(TaskType, RawTaskRequest)>(task_capacity);
374
375 tokio::task::spawn(async move {
376 while let Some((kind, task)) = task_receiver.recv().await {
377 match kind {
378 TaskType::ProveShard => {
379 let request = ProveShardTaskRequest::from_raw(task).unwrap();
380 artifact_client
382 .delete(&request.record, ArtifactType::UnspecifiedArtifactType)
383 .await
384 .unwrap();
385 }
386 TaskType::MarkerDeferredRecord => {}
387 _ => unimplemented!("task type not supported"),
388 }
389 }
390 });
391
392 Self {
393 inner: Arc::new(Mutex::new(HashSet::new())),
394 task_sender,
395 task_channels: Arc::new(Mutex::new(HashMap::new())),
396 }
397 }
398}
399
400impl WorkerClient for TrivialWorkerClient {
401 async fn submit_task(&self, kind: TaskType, task: RawTaskRequest) -> anyhow::Result<TaskId> {
402 let task_id = TaskId::new("task".create_type_id::<V7>().to_string());
403 self.inner.lock().unwrap().insert(task_id.clone());
404 self.task_sender.send((kind, task)).await.unwrap();
405 Ok(task_id)
406 }
407
408 async fn complete_task(
409 &self,
410 _proof_id: ProofId,
411 _task_id: TaskId,
412 _metadata: TaskMetadata,
413 ) -> anyhow::Result<()> {
414 Ok(())
415 }
416
417 async fn complete_proof(
418 &self,
419 _proof_id: ProofId,
420 _task_id: Option<TaskId>,
421 _status: ProofRequestStatus,
422 _extra_data: impl Into<String> + Send,
423 ) -> anyhow::Result<()> {
424 Ok(())
425 }
426
427 async fn subscriber(&self, _proof_id: ProofId) -> anyhow::Result<SubscriberBuilder<Self>> {
428 let (sub_input_tx, mut sub_input_rx) = mpsc::unbounded_channel();
429 let (sub_output_tx, sub_output_rx) = mpsc::unbounded_channel();
430
431 let task_map = self.inner.clone();
432 tokio::task::spawn(async move {
433 while let Some(task_id) = sub_input_rx.recv().await {
434 if task_map.lock().unwrap().contains(&task_id) {
437 sub_output_tx.send((task_id, TaskStatus::Succeeded)).unwrap();
438 } else {
439 sub_output_tx.send((task_id, TaskStatus::Pending)).unwrap();
440 }
441 }
442 });
443
444 Ok(SubscriberBuilder::new(self.clone(), sub_input_tx, sub_output_rx))
445 }
446
447 async fn subscribe_task_messages(
448 &self,
449 task_id: &TaskId,
450 ) -> anyhow::Result<mpsc::UnboundedReceiver<Vec<u8>>> {
451 let mut channels = self.task_channels.lock().unwrap();
452 if let Some(state) = channels.get_mut(task_id) {
453 let rx = state
454 .rx
455 .take()
456 .ok_or_else(|| anyhow::anyhow!("task channel already subscribed for {task_id}"))?;
457 return Ok(rx);
458 }
459 let (tx, rx) = mpsc::unbounded_channel();
460 channels.insert(task_id.clone(), TrivialMessageChannel { tx, rx: None });
461 Ok(rx)
462 }
463
464 async fn send_task_message(&self, task_id: &TaskId, payload: Vec<u8>) -> anyhow::Result<()> {
465 let mut channels = self.task_channels.lock().unwrap();
466 if let Some(state) = channels.get_mut(task_id) {
467 state.tx.send(payload).map_err(|_| anyhow::anyhow!("task channel receiver dropped"))?;
468 } else {
469 let (tx, rx) = mpsc::unbounded_channel();
470 tx.send(payload).expect("just-created channel cannot be closed");
471 channels.insert(task_id.clone(), TrivialMessageChannel { tx, rx: Some(rx) });
472 }
473 Ok(())
474 }
475}
476
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use mti::prelude::{MagicTypeIdExt, V7};
    use sp1_prover_types::{ArtifactClient, InMemoryArtifactClient};

    use super::*;

    /// Worker client backed by a single in-process worker loop, used to exercise subscribers.
    #[derive(Clone)]
    #[allow(clippy::type_complexity)]
    pub struct TestWorkerClient {
        /// Queue feeding submitted tasks to the worker loop.
        input_tx: mpsc::UnboundedSender<(TaskId, RawTaskRequest)>,
        /// Status channel of every submitted task, keyed by task id.
        db: TaskSubscriberDb,
    }

    /// Operations understood by the test worker loop.
    #[derive(Serialize, Deserialize, Clone, Copy)]
    pub enum TestTaskKind {
        /// Add one to the worker's counter.
        Increment,
        /// Upload the worker's current counter to the task's output artifact.
        Read,
    }

    /// Typed form of a test task.
    #[derive(Serialize, Deserialize)]
    pub struct TestTask {
        pub kind: TestTaskKind,
    }

    impl TestTask {
        /// Encodes the task as a [`RawTaskRequest`]: the kind is uploaded as the single input
        /// artifact, and `Read` tasks additionally get one output artifact.
        pub async fn into_raw(self, client: &impl ArtifactClient) -> RawTaskRequest {
            let input_artifact = client.create_artifact().expect("failed to create input artifact");
            client.upload(&input_artifact, self.kind).await.unwrap();
            let outputs = match self.kind {
                TestTaskKind::Read => {
                    vec![client.create_artifact().expect("failed to create output artifact")]
                }
                TestTaskKind::Increment => Vec::new(),
            };
            let context = TaskContext {
                proof_id: ProofId::new("test_proof_id"),
                parent_id: None,
                parent_context: None,
                requester_id: RequesterId::new("test_requester_id"),
            };
            RawTaskRequest { inputs: vec![input_artifact], outputs, context }
        }

        /// Decodes a raw request: downloads the kind from the first input artifact and returns
        /// it together with the first output artifact, if there is one.
        async fn from_raw(
            raw: RawTaskRequest,
            client: &impl ArtifactClient,
        ) -> (Self, Option<Artifact>) {
            let RawTaskRequest { inputs, outputs, .. } = raw;
            let kind = client.download::<TestTaskKind>(&inputs[0]).await.unwrap();
            let first_output = outputs.into_iter().next();
            (Self { kind }, first_output)
        }
    }

    impl TestWorkerClient {
        /// Creates the client and spawns its worker loop, which processes tasks one at a time
        /// and marks each as `Succeeded` once handled.
        fn new(artifact_client: impl ArtifactClient) -> Self {
            let (input_tx, mut task_rx) = mpsc::unbounded_channel();
            let db: TaskSubscriberDb = Arc::new(RwLock::new(HashMap::new()));

            let worker_db = db.clone();
            tokio::task::spawn(async move {
                let mut counter: usize = 0;
                while let Some((id, raw)) = task_rx.recv().await {
                    let (task, output) = TestTask::from_raw(raw, &artifact_client).await;
                    match task.kind {
                        TestTaskKind::Increment => counter += 1,
                        TestTaskKind::Read => {
                            let out_artifact = output.unwrap();
                            artifact_client.upload(&out_artifact, counter).await.unwrap();
                        }
                    }
                    // Both kinds finish the same way: publish success on the task's channel.
                    let (status_tx, _) =
                        worker_db.read().await.get(&id).cloned().expect("task does not exist");
                    status_tx.send(TaskStatus::Succeeded).unwrap();
                }
            });

            Self { input_tx, db }
        }
    }

    impl WorkerClient for TestWorkerClient {
        /// Registers a status channel (initially `Pending`) for a new task id and queues the
        /// task for the worker loop.
        async fn submit_task(
            &self,
            _kind: TaskType,
            task: RawTaskRequest,
        ) -> anyhow::Result<TaskId> {
            let task_id = TaskId::new("task".create_type_id::<V7>().to_string());
            let status_channel = watch::channel(TaskStatus::Pending);
            self.db.write().await.insert(task_id.clone(), status_channel);
            self.input_tx.send((task_id.clone(), task)).unwrap();
            Ok(task_id)
        }

        /// Not needed by the tests.
        async fn complete_task(
            &self,
            _proof_id: ProofId,
            _task_id: TaskId,
            _metadata: TaskMetadata,
        ) -> anyhow::Result<()> {
            unimplemented!()
        }

        /// Not needed by the tests.
        async fn complete_proof(
            &self,
            _proof_id: ProofId,
            _task_id: Option<TaskId>,
            _status: ProofRequestStatus,
            _extra_data: impl Into<String> + Send,
        ) -> anyhow::Result<()> {
            unimplemented!()
        }

        /// Creates a subscription that, for every requested task id, spawns a watcher which
        /// reports the task's first terminal status.
        async fn subscriber(&self, _proof_id: ProofId) -> anyhow::Result<SubscriberBuilder<Self>> {
            let (request_tx, mut request_rx) = mpsc::unbounded_channel();
            let (event_tx, event_rx) = mpsc::unbounded_channel();

            let db = self.db.clone();
            tokio::task::spawn(async move {
                while let Some(id) = request_rx.recv().await {
                    let db = db.clone();
                    let event_tx = event_tx.clone();
                    tokio::task::spawn(async move {
                        let (_, mut status_rx) =
                            db.read().await.get(&id).cloned().expect("task does not exist");
                        // Inspect the current value first, then every later change.
                        status_rx.mark_changed();
                        while status_rx.changed().await.is_ok() {
                            let status = *status_rx.borrow();
                            let is_terminal = matches!(
                                status,
                                TaskStatus::FailedFatal
                                    | TaskStatus::FailedRetryable
                                    | TaskStatus::Succeeded
                            );
                            if is_terminal {
                                event_tx.send((id, status)).ok();
                                break;
                            }
                        }
                    });
                }
            });
            Ok(SubscriberBuilder::new(self.clone(), request_tx, event_rx))
        }

        /// Returns a receiver whose sender is dropped immediately.
        async fn subscribe_task_messages(
            &self,
            _task_id: &TaskId,
        ) -> anyhow::Result<mpsc::UnboundedReceiver<Vec<u8>>> {
            let (_tx, rx) = mpsc::unbounded_channel();
            Ok(rx)
        }

        /// Discards the payload.
        async fn send_task_message(
            &self,
            _task_id: &TaskId,
            _payload: Vec<u8>,
        ) -> anyhow::Result<()> {
            Ok(())
        }
    }

    /// Submits ten staggered increment tasks and one read task, then checks that every task
    /// is reported as succeeded and that the value read never exceeds the increment count.
    #[tokio::test]
    #[allow(clippy::print_stdout)]
    async fn test_worker_client() {
        let artifact_client = InMemoryArtifactClient::default();
        let worker_client = TestWorkerClient::new(artifact_client.clone());
        let increment_task =
            TestTask { kind: TestTaskKind::Increment }.into_raw(&artifact_client).await;
        let read_task = TestTask { kind: TestTaskKind::Read }.into_raw(&artifact_client).await;

        let subscriber =
            worker_client.subscriber(ProofId::new("dummy proof id")).await.unwrap().per_task();

        // Increment number `i` is submitted after `100 * i` milliseconds.
        let increment_handles: Vec<_> = (0..10u64)
            .map(|i| {
                let subscriber = subscriber.clone();
                let increment_task = increment_task.clone();
                tokio::task::spawn(async move {
                    tokio::time::sleep(Duration::from_millis(100 * i)).await;
                    subscriber
                        .client()
                        .submit_task(TaskType::UnspecifiedTaskType, increment_task)
                        .await
                        .unwrap()
                })
            })
            .collect();

        // Submit the read while the increments are still trickling in.
        tokio::time::sleep(Duration::from_millis(300)).await;
        let read_task_id = subscriber
            .client()
            .submit_task(TaskType::UnspecifiedTaskType, read_task.clone())
            .await
            .unwrap();

        let status = subscriber.wait_task(read_task_id).await.unwrap();
        assert_eq!(status, TaskStatus::Succeeded);

        let mut increment_task_ids = Vec::with_capacity(increment_handles.len());
        for handle in increment_handles {
            increment_task_ids.push(handle.await.unwrap());
        }
        for task_id in increment_task_ids {
            let status = subscriber.wait_task(task_id).await.unwrap();
            assert_eq!(status, TaskStatus::Succeeded);
        }

        let (_, output) = TestTask::from_raw(read_task, &artifact_client).await;
        let output = output.unwrap();
        let value: usize = artifact_client.download(&output).await.unwrap();
        println!("value: {}", value);
        assert!(value <= 10);
    }
}