Skip to main content

sp1_prover/worker/
client.rs

1use core::fmt;
2use std::{
3    future::Future,
4    pin::Pin,
5    sync::{Arc, Mutex},
6    task::Poll,
7};
8
9use futures::{prelude::*, stream::FuturesOrdered};
10use hashbrown::{HashMap, HashSet};
11use mti::prelude::{MagicTypeIdExt, V7};
12use opentelemetry::Context;
13use serde::{de::DeserializeOwned, Deserialize, Serialize};
14use sp1_prover_types::{
15    Artifact, ArtifactClient, ArtifactType, ProofRequestStatus, TaskStatus, TaskType,
16};
17use thiserror::Error;
18use tokio::{
19    sync::{mpsc, watch, RwLock},
20    task::AbortHandle,
21};
22
23mod local;
24
25pub use local::*;
26
27use crate::worker::{ProveShardTaskRequest, TaskError};
28
/// Client interface used by prover workers to submit tasks, report task/proof
/// completion, subscribe to task status updates, and exchange per-task messages.
///
/// Implementors must be cloneable and shareable across threads
/// (`Send + Sync + Clone + 'static`), and every returned future must be `Send`.
pub trait WorkerClient: Send + Sync + Clone + 'static {
    /// Submit a single task of type `kind` and return the id assigned to it.
    fn submit_task(
        &self,
        kind: TaskType,
        task: RawTaskRequest,
    ) -> impl Future<Output = anyhow::Result<TaskId>> + Send;

    /// Report that task `task_id` of proof `proof_id` has completed, attaching `metadata`.
    fn complete_task(
        &self,
        proof_id: ProofId,
        task_id: TaskId,
        metadata: TaskMetadata,
    ) -> impl Future<Output = anyhow::Result<()>> + Send;

    /// Report the final `status` of proof `proof_id`.
    ///
    /// `task_id`, if given, identifies the task making the report; `extra_data` is passed
    /// through to the implementation (NOTE(review): its meaning is implementation-defined).
    fn complete_proof(
        &self,
        proof_id: ProofId,
        task_id: Option<TaskId>,
        status: ProofRequestStatus,
        extra_data: impl Into<String> + Send,
    ) -> impl Future<Output = anyhow::Result<()>> + Send;

    /// Create a [`SubscriberBuilder`] for receiving status updates of tasks that belong
    /// to `proof_id`.
    fn subscriber(
        &self,
        proof_id: ProofId,
    ) -> impl Future<Output = anyhow::Result<SubscriberBuilder<Self>>> + Send;

    /// Subscribe to the message stream for a task. The returned receiver's stream
    /// ends when the producer task completes or fails.
    fn subscribe_task_messages(
        &self,
        task_id: &TaskId,
    ) -> impl Future<Output = anyhow::Result<mpsc::UnboundedReceiver<Vec<u8>>>> + Send;

    /// Send a payload on the message channel for this task. Lazily creates the channel entry
    /// if it does not yet exist.
    fn send_task_message(
        &self,
        task_id: &TaskId,
        payload: Vec<u8>,
    ) -> impl Future<Output = anyhow::Result<()>> + Send;

    /// Submit a batch of tasks of the same `kind`.
    ///
    /// One submission future is created per task and all of them are polled concurrently
    /// through a [`FuturesOrdered`], so the returned ids are in the same order as `tasks`.
    /// Short-circuits on the first error in input order.
    fn submit_tasks(
        &self,
        kind: TaskType,
        tasks: impl IntoIterator<Item = RawTaskRequest> + Send,
    ) -> impl Future<Output = anyhow::Result<Vec<TaskId>>> + Send {
        tasks
            .into_iter()
            .map(move |task| self.submit_task(kind, task))
            .collect::<FuturesOrdered<_>>()
            .try_collect()
    }

    /// Submit every task yielded by the stream `tasks`, one at a time.
    ///
    /// Each submission is awaited before the next item is pulled from the stream, so ids
    /// are returned in stream order. Stops at the first error.
    fn submit_all(
        &self,
        kind: TaskType,
        tasks: impl Stream<Item = RawTaskRequest> + Send,
    ) -> impl Future<Output = anyhow::Result<Vec<TaskId>>> + Send {
        tasks.then(move |task| self.submit_task(kind, task)).try_collect()
    }
}
91
92/// Wrapper around an mpsc::UnboundedReceiver<Vec<u8>> that deserializes the payload as `T`.
93pub struct MessageReceiver<T: DeserializeOwned> {
94    rx: mpsc::UnboundedReceiver<Vec<u8>>,
95    _marker: std::marker::PhantomData<T>,
96}
97
98impl<T: DeserializeOwned> MessageReceiver<T> {
99    pub fn new(rx: mpsc::UnboundedReceiver<Vec<u8>>) -> Self {
100        Self { rx, _marker: std::marker::PhantomData }
101    }
102
103    /// Receive and deserialize the next message, returning `None` when the channel is closed.
104    pub async fn recv(&mut self) -> Option<T> {
105        let bytes = self.rx.recv().await?;
106        Some(bincode::deserialize(&bytes).expect("failed to deserialize message channel payload"))
107    }
108}
109
110#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, PartialOrd, Ord)]
111pub struct ProofId(String);
112
113impl ProofId {
114    #[inline]
115    pub fn new(id: impl Into<String>) -> Self {
116        Self(id.into())
117    }
118}
119
120impl fmt::Display for ProofId {
121    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
122        write!(f, "{}", self.0) // TODO: nicely indicate that it is a proof id. Right now, it messes
123                                // with the coordinator communication.
124    }
125}
126
127#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, PartialOrd, Ord)]
128pub struct TaskId(String);
129
130impl TaskId {
131    #[inline]
132    pub fn new(id: impl Into<String>) -> Self {
133        Self(id.into())
134    }
135}
136
137impl fmt::Display for TaskId {
138    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
139        write!(f, "{}", self.0) // TODO: nicely indicate that it is a task id. Right now, it messes
140                                // with the coordinator communication.
141    }
142}
143
144#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, PartialOrd, Ord)]
145pub struct RequesterId(String);
146
147impl RequesterId {
148    #[inline]
149    pub fn new(id: impl Into<String>) -> Self {
150        Self(id.into())
151    }
152}
153
154impl fmt::Display for RequesterId {
155    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
156        write!(f, "{}", self.0)
157    }
158}
159
/// A task submission in raw form: artifact handles for the task's inputs and outputs,
/// plus the context linking it to its proof.
#[derive(Clone)]
pub struct RawTaskRequest {
    /// Input artifacts for the task.
    pub inputs: Vec<Artifact>,
    /// Output artifacts for the task.
    pub outputs: Vec<Artifact>,
    /// Proof, parent-task, and tracing context for the task.
    pub context: TaskContext,
}

/// Context attached to every [`RawTaskRequest`].
#[derive(Clone)]
pub struct TaskContext {
    /// The proof this task belongs to.
    pub proof_id: ProofId,
    /// The parent task, if any.
    pub parent_id: Option<TaskId>,
    /// OpenTelemetry context of the parent, if any (presumably used for trace
    /// propagation — confirm with callers).
    pub parent_context: Option<Context>,
    /// Who requested the proof.
    pub requester_id: RequesterId,
}

/// Metadata reported alongside a completed task (see `WorkerClient::complete_task`).
#[derive(Debug, Serialize, Deserialize, Default)]
pub struct TaskMetadata {
    /// NOTE(review): presumably GPU time spent on the task, in milliseconds; `None` if
    /// not reported.
    pub gpu_ms: Option<u64>,
}
179
180pub struct SubscriberBuilder<W> {
181    client: W,
182    subscriber_tx: mpsc::UnboundedSender<TaskId>,
183    subscriber_rx: mpsc::UnboundedReceiver<(TaskId, TaskStatus)>,
184}
185
186impl<W> SubscriberBuilder<W> {
187    pub fn new(
188        client: W,
189        subscriber_tx: mpsc::UnboundedSender<TaskId>,
190        subscriber_rx: mpsc::UnboundedReceiver<(TaskId, TaskStatus)>,
191    ) -> Self {
192        Self { client, subscriber_tx, subscriber_rx }
193    }
194
195    pub fn per_task(self) -> TaskSubscriber<W> {
196        TaskSubscriber::new(self)
197    }
198
199    pub fn stream(self) -> (StreamSubscriber<W>, EventStream) {
200        StreamSubscriber::new(self)
201    }
202}
203
204type TaskSubscriberDb =
205    Arc<RwLock<HashMap<TaskId, (watch::Sender<TaskStatus>, watch::Receiver<TaskStatus>)>>>;
206
207// TODO: maybe traitify this struct to allow more flexibility in implementations.
208#[derive(Clone)]
209#[allow(clippy::type_complexity)]
210pub struct TaskSubscriber<W> {
211    client: W,
212    request_map: TaskSubscriberDb,
213    subscriber_tx: mpsc::UnboundedSender<TaskId>,
214    abort_handle: AbortHandle,
215}
216
217impl<W> TaskSubscriber<W> {
218    /// Get a reference to the client.
219    #[inline]
220    pub const fn client(&self) -> &W {
221        &self.client
222    }
223
224    /// Create a new task subscriber.
225    pub fn new(builder: SubscriberBuilder<W>) -> Self {
226        let SubscriberBuilder { client, subscriber_tx, mut subscriber_rx, .. } = builder;
227        // Create stores to map all incoming status requests and subscribers.
228        let request_map = Arc::new(RwLock::new(HashMap::<
229            TaskId,
230            (watch::Sender<TaskStatus>, watch::Receiver<TaskStatus>),
231        >::new()));
232        // Spawn a blocking task to update the status map when new statuses are received.
233        let handle = tokio::task::spawn({
234            let request_map = request_map.clone();
235            async move {
236                while let Some((task_id, status)) = subscriber_rx.recv().await {
237                    // Send an update to the request map.
238                    let (sender, _) = request_map
239                        .read()
240                        .await
241                        .get(&task_id)
242                        .cloned()
243                        .expect("task should be in request map");
244                    // Send the status to the requester, it's ok if the receiver is dropped.
245                    sender.send(status).ok();
246                }
247            }
248        });
249        let abort_handle = handle.abort_handle();
250
251        Self { client, request_map, subscriber_tx, abort_handle }
252    }
253
254    /// Close the task subscriber.
255    ///
256    /// The subsctiber will no longer receive updates on the status of the tasks.
257    pub fn close(&self) {
258        self.abort_handle.abort();
259    }
260
261    /// Wait for a task to complete.
262    ///
263    /// This function will return a `WaitTask` that can be used to wait for the task to complete.
264    pub async fn wait_task(&self, task_id: TaskId) -> Result<TaskStatus, TaskError> {
265        self.request_map
266            .write()
267            .await
268            .entry(task_id.clone())
269            .or_insert_with(|| watch::channel(TaskStatus::UnspecifiedStatus));
270
271        let (_, mut watch) = self
272            .request_map
273            .read()
274            .await
275            .get(&task_id)
276            .cloned()
277            .ok_or(TaskError::Fatal(anyhow::anyhow!("task does not exist")))?;
278
279        // Send the task id to the inner subscriber.
280        self.subscriber_tx.send(task_id.clone()).map_err(|e| {
281            TaskError::Fatal(anyhow::anyhow!("failed to send task id to inner subscriber: {e}"))
282        })?;
283
284        watch.mark_changed();
285        while let Ok(()) = watch.changed().await {
286            let v = *watch.borrow();
287            if matches!(
288                v,
289                TaskStatus::FailedFatal | TaskStatus::FailedRetryable | TaskStatus::Succeeded
290            ) {
291                return Ok(v);
292            }
293        }
294        Err(TaskError::Fatal(anyhow::anyhow!("task status lost for task {task_id}")))
295    }
296}
297
298#[derive(Debug, Error)]
299#[error("failed to subscribe to task {0}")]
300pub struct SubscribeError(#[from] mpsc::error::SendError<TaskId>);
301
302// TODO: maybe traitify this struct to allow more flexibility in implementations.
303#[derive(Clone)]
304pub struct StreamSubscriber<W> {
305    client: W,
306    subscriber_tx: mpsc::UnboundedSender<TaskId>,
307}
308
309impl<W> StreamSubscriber<W> {
310    /// Get a reference to the client.
311    #[inline]
312    pub const fn client(&self) -> &W {
313        &self.client
314    }
315
316    /// Create a new task subscriber.
317    fn new(builder: SubscriberBuilder<W>) -> (Self, EventStream) {
318        let SubscriberBuilder { client, subscriber_tx, subscriber_rx, .. } = builder;
319        (Self { client, subscriber_tx }, EventStream { subscriber_rx })
320    }
321
322    pub fn subscribe(&self, task_id: TaskId) -> Result<(), SubscribeError> {
323        self.subscriber_tx.send(task_id)?;
324        Ok(())
325    }
326}
327
328pub struct EventStream {
329    subscriber_rx: mpsc::UnboundedReceiver<(TaskId, TaskStatus)>,
330}
331
332impl EventStream {
333    pub async fn recv(&mut self) -> Option<(TaskId, TaskStatus)> {
334        self.subscriber_rx.recv().await
335    }
336
337    pub fn blocking_recv(&mut self) -> Option<(TaskId, TaskStatus)> {
338        self.subscriber_rx.blocking_recv()
339    }
340
341    pub fn close(&mut self) {
342        self.subscriber_rx.close();
343    }
344}
345
346impl Stream for EventStream {
347    type Item = (TaskId, TaskStatus);
348
349    fn poll_next(
350        mut self: Pin<&mut Self>,
351        cx: &mut std::task::Context<'_>,
352    ) -> Poll<Option<Self::Item>> {
353        self.subscriber_rx.poll_recv(cx)
354    }
355}
356
357struct TrivialMessageChannel {
358    tx: mpsc::UnboundedSender<Vec<u8>>,
359    rx: Option<mpsc::UnboundedReceiver<Vec<u8>>>,
360}
361
362/// A trivial client that can be used for testing.
363#[derive(Clone)]
364pub struct TrivialWorkerClient {
365    inner: Arc<Mutex<HashSet<TaskId>>>,
366    task_sender: mpsc::Sender<(TaskType, RawTaskRequest)>,
367    task_channels: Arc<Mutex<HashMap<TaskId, TrivialMessageChannel>>>,
368}
369
370impl TrivialWorkerClient {
371    pub fn new<A: ArtifactClient>(task_capacity: usize, artifact_client: A) -> Self {
372        let (task_sender, mut task_receiver) =
373            mpsc::channel::<(TaskType, RawTaskRequest)>(task_capacity);
374
375        tokio::task::spawn(async move {
376            while let Some((kind, task)) = task_receiver.recv().await {
377                match kind {
378                    TaskType::ProveShard => {
379                        let request = ProveShardTaskRequest::from_raw(task).unwrap();
380                        // remove the record artifact from the client
381                        artifact_client
382                            .delete(&request.record, ArtifactType::UnspecifiedArtifactType)
383                            .await
384                            .unwrap();
385                    }
386                    TaskType::MarkerDeferredRecord => {}
387                    _ => unimplemented!("task type not supported"),
388                }
389            }
390        });
391
392        Self {
393            inner: Arc::new(Mutex::new(HashSet::new())),
394            task_sender,
395            task_channels: Arc::new(Mutex::new(HashMap::new())),
396        }
397    }
398}
399
400impl WorkerClient for TrivialWorkerClient {
401    async fn submit_task(&self, kind: TaskType, task: RawTaskRequest) -> anyhow::Result<TaskId> {
402        let task_id = TaskId::new("task".create_type_id::<V7>().to_string());
403        self.inner.lock().unwrap().insert(task_id.clone());
404        self.task_sender.send((kind, task)).await.unwrap();
405        Ok(task_id)
406    }
407
408    async fn complete_task(
409        &self,
410        _proof_id: ProofId,
411        _task_id: TaskId,
412        _metadata: TaskMetadata,
413    ) -> anyhow::Result<()> {
414        Ok(())
415    }
416
417    async fn complete_proof(
418        &self,
419        _proof_id: ProofId,
420        _task_id: Option<TaskId>,
421        _status: ProofRequestStatus,
422        _extra_data: impl Into<String> + Send,
423    ) -> anyhow::Result<()> {
424        Ok(())
425    }
426
427    async fn subscriber(&self, _proof_id: ProofId) -> anyhow::Result<SubscriberBuilder<Self>> {
428        let (sub_input_tx, mut sub_input_rx) = mpsc::unbounded_channel();
429        let (sub_output_tx, sub_output_rx) = mpsc::unbounded_channel();
430
431        let task_map = self.inner.clone();
432        tokio::task::spawn(async move {
433            while let Some(task_id) = sub_input_rx.recv().await {
434                // Get the input artifacts
435
436                if task_map.lock().unwrap().contains(&task_id) {
437                    sub_output_tx.send((task_id, TaskStatus::Succeeded)).unwrap();
438                } else {
439                    sub_output_tx.send((task_id, TaskStatus::Pending)).unwrap();
440                }
441            }
442        });
443
444        Ok(SubscriberBuilder::new(self.clone(), sub_input_tx, sub_output_rx))
445    }
446
447    async fn subscribe_task_messages(
448        &self,
449        task_id: &TaskId,
450    ) -> anyhow::Result<mpsc::UnboundedReceiver<Vec<u8>>> {
451        let mut channels = self.task_channels.lock().unwrap();
452        if let Some(state) = channels.get_mut(task_id) {
453            let rx = state
454                .rx
455                .take()
456                .ok_or_else(|| anyhow::anyhow!("task channel already subscribed for {task_id}"))?;
457            return Ok(rx);
458        }
459        let (tx, rx) = mpsc::unbounded_channel();
460        channels.insert(task_id.clone(), TrivialMessageChannel { tx, rx: None });
461        Ok(rx)
462    }
463
464    async fn send_task_message(&self, task_id: &TaskId, payload: Vec<u8>) -> anyhow::Result<()> {
465        let mut channels = self.task_channels.lock().unwrap();
466        if let Some(state) = channels.get_mut(task_id) {
467            state.tx.send(payload).map_err(|_| anyhow::anyhow!("task channel receiver dropped"))?;
468        } else {
469            let (tx, rx) = mpsc::unbounded_channel();
470            tx.send(payload).expect("just-created channel cannot be closed");
471            channels.insert(task_id.clone(), TrivialMessageChannel { tx, rx: Some(rx) });
472        }
473        Ok(())
474    }
475}
476
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use mti::prelude::{MagicTypeIdExt, V7};
    use sp1_prover_types::{ArtifactClient, InMemoryArtifactClient};

    use super::*;

    /// A simple test worker: a single background task that owns a counter and processes
    /// submitted tasks one at a time, in submission order.
    ///
    /// This client supports two tasks:
    ///    - Increment the counter
    ///    - Read the current value (uploaded to the task's output artifact)
    #[derive(Clone)]
    #[allow(clippy::type_complexity)]
    pub struct TestWorkerClient {
        // Queue feeding the background worker task.
        input_tx: mpsc::UnboundedSender<(TaskId, RawTaskRequest)>,
        // Per-task status channels; the worker sends `Succeeded` on them when done.
        db: TaskSubscriberDb,
    }

    /// The operation a [`TestTask`] performs.
    #[derive(Serialize, Deserialize, Clone, Copy)]
    pub enum TestTaskKind {
        Increment,
        Read,
    }

    /// A test task; its kind is stored in the task's first input artifact.
    #[derive(Serialize, Deserialize)]
    pub struct TestTask {
        pub kind: TestTaskKind,
    }

    impl TestTask {
        /// Upload the task kind as an input artifact and build the raw request. `Read`
        /// tasks also get an output artifact to receive the counter value.
        pub async fn into_raw(self, client: &impl ArtifactClient) -> RawTaskRequest {
            let input_artifact = client.create_artifact().expect("failed to create input artifact");
            client.upload(&input_artifact, self.kind).await.unwrap();
            let outputs = if let TestTaskKind::Read = self.kind {
                let artifact = client.create_artifact().expect("failed to create output artifact");
                vec![artifact]
            } else {
                vec![]
            };
            RawTaskRequest {
                inputs: vec![input_artifact],
                outputs,
                context: TaskContext {
                    proof_id: ProofId::new("test_proof_id"),
                    parent_id: None,
                    parent_context: None,
                    requester_id: RequesterId::new("test_requester_id"),
                },
            }
        }

        /// Inverse of `into_raw`: download the kind from the first input artifact and
        /// return it together with the first output artifact, if any.
        async fn from_raw(
            raw: RawTaskRequest,
            client: &impl ArtifactClient,
        ) -> (Self, Option<Artifact>) {
            let kind = client.download::<TestTaskKind>(&raw.inputs[0]).await.unwrap();
            (Self { kind }, raw.outputs.into_iter().next())
        }
    }

    impl TestWorkerClient {
        /// Spawn the background worker loop and return a client connected to it.
        fn new(artifact_client: impl ArtifactClient) -> Self {
            let (tx, mut rx) = mpsc::unbounded_channel();
            let db = Arc::new(RwLock::new(HashMap::<
                TaskId,
                (watch::Sender<TaskStatus>, watch::Receiver<TaskStatus>),
            >::new()));

            tokio::task::spawn({
                let db = db.clone();
                async move {
                    let mut counter: usize = 0;
                    // Tasks are handled sequentially, in the order they were submitted.
                    while let Some((id, task)) = rx.recv().await {
                        let (task, output) = TestTask::from_raw(task, &artifact_client).await;
                        match task.kind {
                            TestTaskKind::Increment => {
                                counter += 1;
                                let (tx, _) =
                                    db.read().await.get(&id).cloned().expect("task does not exist");
                                tx.send(TaskStatus::Succeeded).unwrap();
                            }
                            TestTaskKind::Read => {
                                let out_artifact = output.unwrap();
                                artifact_client.upload(&out_artifact, counter).await.unwrap();
                                let (tx, _) =
                                    db.read().await.get(&id).cloned().expect("task does not exist");
                                tx.send(TaskStatus::Succeeded).unwrap();
                            }
                        }
                    }
                }
            });

            Self { input_tx: tx, db }
        }
    }

    impl WorkerClient for TestWorkerClient {
        /// Register a `Pending` status channel for a fresh id, then queue the task.
        async fn submit_task(
            &self,
            _kind: TaskType,
            task: RawTaskRequest,
        ) -> anyhow::Result<TaskId> {
            let task_id = TaskId::new("task".create_type_id::<V7>().to_string());
            // Add the task to the db.
            let (tx, rx) = watch::channel(TaskStatus::Pending);
            self.db.write().await.insert(task_id.clone(), (tx, rx));
            self.input_tx.send((task_id.clone(), task)).unwrap();
            Ok(task_id)
        }

        async fn complete_task(
            &self,
            _proof_id: ProofId,
            _task_id: TaskId,
            _metadata: TaskMetadata,
        ) -> anyhow::Result<()> {
            unimplemented!()
        }

        async fn complete_proof(
            &self,
            _proof_id: ProofId,
            _task_id: Option<TaskId>,
            _status: ProofRequestStatus,
            _extra_data: impl Into<String> + Send,
        ) -> anyhow::Result<()> {
            unimplemented!()
        }

        /// For each subscribed id, spawn a watcher that forwards the first terminal status
        /// of that task to the output channel.
        async fn subscriber(&self, _proof_id: ProofId) -> anyhow::Result<SubscriberBuilder<Self>> {
            let (subscriber_input_tx, mut subscriber_input_rx) = mpsc::unbounded_channel();
            let (subscriber_output_tx, subscriber_output_rx) = mpsc::unbounded_channel();

            tokio::task::spawn({
                let db = self.db.clone();
                let output_tx = subscriber_output_tx.clone();
                async move {
                    while let Some(id) = subscriber_input_rx.recv().await {
                        // Spawn a task to send the status to the output channel.
                        let db = db.clone();
                        let output_tx = output_tx.clone();
                        tokio::task::spawn(async move {
                            let (_, mut rx) =
                                db.read().await.get(&id).cloned().expect("task does not exist");
                            // Inspect the current value first, in case the task already finished.
                            rx.mark_changed();
                            while let Ok(()) = rx.changed().await {
                                let value = *rx.borrow();
                                if matches!(
                                    value,
                                    TaskStatus::FailedFatal
                                        | TaskStatus::FailedRetryable
                                        | TaskStatus::Succeeded
                                ) {
                                    output_tx.send((id, value)).ok();
                                    return;
                                }
                            }
                        });
                    }
                }
            });
            Ok(SubscriberBuilder::new(self.clone(), subscriber_input_tx, subscriber_output_rx))
        }

        /// Stub: returns a receiver whose sender is dropped immediately, so it yields nothing.
        async fn subscribe_task_messages(
            &self,
            _task_id: &TaskId,
        ) -> anyhow::Result<mpsc::UnboundedReceiver<Vec<u8>>> {
            let (_tx, rx) = mpsc::unbounded_channel();
            Ok(rx)
        }

        /// Stub: discards the payload.
        async fn send_task_message(
            &self,
            _task_id: &TaskId,
            _payload: Vec<u8>,
        ) -> anyhow::Result<()> {
            Ok(())
        }
    }

    #[tokio::test]
    #[allow(clippy::print_stdout)]
    async fn test_worker_client() {
        let artifact_client = InMemoryArtifactClient::default();
        let worker_client = TestWorkerClient::new(artifact_client.clone());
        let increment_task = TestTask { kind: TestTaskKind::Increment };
        let increment_task = increment_task.into_raw(&artifact_client).await;
        let read_task = TestTask { kind: TestTaskKind::Read };
        let read_task = read_task.into_raw(&artifact_client).await;

        // Create a subscriber to receive the task status.
        let subscriber =
            worker_client.subscriber(ProofId::new("dummy proof id")).await.unwrap().per_task();

        // Submit ten increment tasks from spawned tasks, staggered 100 ms apart.
        let mut increment_tasks = vec![];
        for i in 0..10 {
            let subscriber = subscriber.clone();
            let increment_task = increment_task.clone();
            let handle = tokio::task::spawn(async move {
                tokio::time::sleep(Duration::from_millis(100 * i)).await;
                subscriber
                    .client()
                    .submit_task(TaskType::UnspecifiedTaskType, increment_task.clone())
                    .await
                    .unwrap()
            });
            increment_tasks.push(handle);
        }
        // Submit the read task while increments are still arriving.
        tokio::time::sleep(Duration::from_millis(300)).await;
        let read_task_id = subscriber
            .client()
            .submit_task(TaskType::UnspecifiedTaskType, read_task.clone())
            .await
            .unwrap();

        // Read the value once the read task is complete.

        // Get the status of the read task.
        let status = subscriber.wait_task(read_task_id).await.unwrap();
        // Assert that the read task is complete.
        assert_eq!(status, TaskStatus::Succeeded);
        // Assert that the status of the increment tasks is complete.
        let mut increment_task_ids = vec![];
        for handle in increment_tasks {
            let task_id = handle.await.unwrap();
            increment_task_ids.push(task_id);
        }
        for task_id in increment_task_ids {
            let status = subscriber.wait_task(task_id).await.unwrap();
            assert_eq!(status, TaskStatus::Succeeded);
        }
        // Read the value from the artifact client.
        let (_, output) = TestTask::from_raw(read_task, &artifact_client).await;
        let output = output.unwrap();
        let value: usize = artifact_client.download(&output).await.unwrap();
        println!("value: {}", value);
        // The read races with the staggered increments, so only the upper bound is exact.
        assert!(value <= 10);
    }
}