runpod_sdk/serverless/
job.rs

1//! Job tracking and result retrieval
2
3use std::future::Future;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::task::{Context, Poll};
7
8use serde::de::DeserializeOwned;
9use serde_json::Value;
10
11use super::{JobStatus, JobStatusResponse, RunRequest, RunResponse, StreamChunk, StreamResponse};
12use crate::{Result, RunpodClient};
13
14#[cfg(feature = "tracing")]
15const TRACING_TARGET: &str = "runpod_sdk::serverless::job";
16
pin_project_lite::pin_project! {
    /// A job submitted to a serverless endpoint.
    ///
    /// Implements `Future` to allow awaiting the job result directly.
    /// A job starts out unsubmitted: `job_id` is `None` until the `Future`
    /// implementation has been polled far enough for the run request to succeed.
    ///
    /// # Examples
    ///
    /// ```no_run
    /// # use runpod_sdk::{RunpodClient, Result};
    /// # use runpod_sdk::serverless::ServerlessEndpoint;
    /// # use serde_json::json;
    /// # async fn example() -> Result<()> {
    /// let client = RunpodClient::from_env()?;
    /// let endpoint = ServerlessEndpoint::new("ENDPOINT_ID", client);
    /// let job = endpoint.run(&json!({"prompt": "Hello"}))?;
    ///
    /// // Await the job to get the output
    /// let output: serde_json::Value = job.await?;
    /// println!("Output: {:?}", output);
    /// # Ok(())
    /// # }
    /// ```
    pub struct ServerlessJob {
        // Endpoint the job targets; used as the first segment of every request path.
        endpoint_id: Arc<String>,
        // ID returned by the run request; `None` until the job has been submitted.
        job_id: Option<String>,
        // Payload for the run request; taken (left as `None`) when the job is submitted.
        input: Option<Value>,
        // Client used to issue the HTTP requests.
        client: RunpodClient,
        // Where the job currently is in its submit/poll lifecycle when driven as a `Future`.
        #[pin]
        state: JobState,
    }
}
48
49enum JobState {
50    NotSubmitted,
51    Submitting,
52    Polling,
53    Ready(Option<Value>),
54    Failed(crate::Error),
55}
56
57impl ServerlessJob {
58    /// Creates a new Job instance with input to be submitted
59    pub(crate) fn new(endpoint_id: Arc<String>, input: Value, client: RunpodClient) -> Self {
60        Self {
61            endpoint_id,
62            job_id: None,
63            input: Some(input),
64            client,
65            state: JobState::NotSubmitted,
66        }
67    }
68
69    /// Returns the job ID (if the job has been submitted)
70    pub fn job_id(&self) -> Option<&str> {
71        self.job_id.as_deref()
72    }
73
74    /// Returns the endpoint ID
75    pub fn endpoint_id(&self) -> &str {
76        &self.endpoint_id
77    }
78
79    /// Fetches the current job state from the specified endpoint.
80    async fn fetch_job(&self, source: &str) -> Result<serde_json::Value> {
81        let job_id = self
82            .job_id
83            .as_ref()
84            .ok_or_else(|| crate::Error::Job("Job has not been submitted yet".to_string()))?;
85        let path = format!("{}/{}/{}", self.endpoint_id, source, job_id);
86
87        let response = self.client.get_api(&path).send().await?;
88        let response = response.error_for_status()?;
89        let data: Value = response.json().await?;
90
91        Ok(data)
92    }
93
94    /// Returns the current status of the job.
95    ///
96    /// # Example
97    ///
98    /// ```no_run
99    /// # use runpod_sdk::{RunpodClient, Result};
100    /// # use runpod_sdk::serverless::{ServerlessEndpoint, JobStatus};
101    /// # use serde_json::json;
102    /// # async fn example() -> Result<()> {
103    /// let client = RunpodClient::from_env()?;
104    /// let endpoint = ServerlessEndpoint::new("ENDPOINT_ID", client);
105    /// let job = endpoint.run(&json!({"prompt": "Hello"}))?;
106    ///
107    /// let status = job.status().await?;
108    /// match status {
109    ///     JobStatus::Completed => println!("Job finished"),
110    ///     JobStatus::Failed => println!("Job failed"),
111    ///     JobStatus::InProgress => println!("Job is running"),
112    ///     _ => println!("Job status: {}", status),
113    /// }
114    /// # Ok(())
115    /// # }
116    /// ```
117    pub async fn status(&self) -> Result<JobStatus> {
118        #[cfg(feature = "tracing")]
119        tracing::debug!(
120            target: TRACING_TARGET,
121            job_id = ?self.job_id,
122            endpoint_id = %self.endpoint_id,
123            "Fetching job status"
124        );
125
126        let data = self.fetch_job("status").await?;
127        let response: JobStatusResponse = serde_json::from_value(data)?;
128
129        #[cfg(feature = "tracing")]
130        tracing::debug!(
131            target: TRACING_TARGET,
132            job_id = ?self.job_id,
133            status = %response.status,
134            "Job status retrieved"
135        );
136
137        Ok(response.status)
138    }
139
140    /// Returns the output of the job.
141    ///
142    /// # Example
143    ///
144    /// ```no_run
145    /// # use runpod_sdk::{RunpodClient, Result};
146    /// # use runpod_sdk::serverless::ServerlessEndpoint;
147    /// # use serde::{Deserialize, Serialize};
148    /// # use serde_json::json;
149    /// #
150    /// # #[derive(Deserialize)]
151    /// # struct Output {
152    /// #     text: String,
153    /// # }
154    /// #
155    /// # async fn example() -> Result<()> {
156    /// let client = RunpodClient::from_env()?;
157    /// let endpoint = ServerlessEndpoint::new("ENDPOINT_ID", client);
158    /// let job = endpoint.run(&json!({"prompt": "Hello"}))?;
159    ///
160    /// let output: Output = job.output().await?;
161    /// println!("Result: {}", output.text);
162    /// # Ok(())
163    /// # }
164    /// ```
165    pub async fn output<O>(&self) -> Result<O>
166    where
167        O: DeserializeOwned,
168    {
169        let data = self.fetch_job("status").await?;
170        let response: JobStatusResponse = serde_json::from_value(data)?;
171
172        match response.output {
173            Some(output) => Ok(serde_json::from_value(output)?),
174            None => Err(crate::Error::Serialization(
175                serde_json::from_str::<Value>("\"Job has no output\"").unwrap_err(),
176            )),
177        }
178    }
179
180    /// Returns stream chunks from a streaming job.
181    ///
182    /// # Example
183    ///
184    /// ```no_run
185    /// # use runpod_sdk::{RunpodClient, Result};
186    /// # use runpod_sdk::serverless::{ServerlessEndpoint, JobStatus};
187    /// # use serde_json::json;
188    /// # async fn example() -> Result<()> {
189    /// let client = RunpodClient::from_env()?;
190    /// let endpoint = ServerlessEndpoint::new("ENDPOINT_ID", client);
191    /// let job = endpoint.run(&json!({"prompt": "Generate text"}))?;
192    ///
193    /// loop {
194    ///     let (status, chunks) = job.stream().await?;
195    ///
196    ///     for chunk in chunks {
197    ///         println!("Chunk: {:?}", chunk.output);
198    ///     }
199    ///
200    ///     if status.is_final() {
201    ///         break;
202    ///     }
203    ///
204    ///     std::thread::sleep(std::time::Duration::from_secs(1));
205    /// }
206    /// # Ok(())
207    /// # }
208    /// ```
209    pub async fn stream(&self) -> Result<(JobStatus, Vec<StreamChunk>)> {
210        let data = self.fetch_job("stream").await?;
211        let response: StreamResponse = serde_json::from_value(data)?;
212        Ok((response.status, response.stream))
213    }
214
215    /// Cancels the job.
216    ///
217    /// # Example
218    ///
219    /// ```no_run
220    /// # use runpod_sdk::{RunpodClient, Result};
221    /// # use runpod_sdk::serverless::ServerlessEndpoint;
222    /// # use serde_json::json;
223    /// # async fn example() -> Result<()> {
224    /// let client = RunpodClient::from_env()?;
225    /// let endpoint = ServerlessEndpoint::new("ENDPOINT_ID", client);
226    /// let job = endpoint.run(&json!({"prompt": "Long running task"}))?;
227    ///
228    /// job.cancel().await?;
229    /// println!("Job cancelled");
230    /// # Ok(())
231    /// # }
232    /// ```
233    pub async fn cancel(&self) -> Result<()> {
234        let job_id = self
235            .job_id
236            .as_ref()
237            .ok_or_else(|| crate::Error::Job("Job has not been submitted yet".to_string()))?;
238        let path = format!("{}/cancel/{}", self.endpoint_id, job_id);
239
240        let response = self.client.post_api(&path).send().await?;
241        response.error_for_status()?;
242        Ok(())
243    }
244}
245
246impl Future for ServerlessJob {
247    type Output = Result<Value>;
248
249    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
250        let mut this = self.project();
251
252        match this.state.as_mut().get_mut() {
253            JobState::NotSubmitted => {
254                *this.state.get_mut() = JobState::Submitting;
255                cx.waker().wake_by_ref();
256                Poll::Pending
257            }
258            JobState::Submitting => {
259                // Submit the job
260                let endpoint_id = Arc::clone(this.endpoint_id);
261                let input = this.input.take().expect("Input should be present");
262                let client = this.client.clone();
263
264                let fut = async move {
265                    #[cfg(feature = "tracing")]
266                    tracing::debug!(
267                        target: TRACING_TARGET,
268                        endpoint_id = %endpoint_id,
269                        "Submitting job to endpoint"
270                    );
271
272                    let path = format!("{}/run", endpoint_id);
273
274                    let payload = RunRequest { input };
275
276                    let response = client.post_api(&path).json(&payload).send().await?;
277
278                    let response = response.error_for_status()?;
279                    let run_response: RunResponse = response.json().await?;
280
281                    #[cfg(feature = "tracing")]
282                    tracing::info!(
283                        target: TRACING_TARGET,
284                        endpoint_id = %endpoint_id,
285                        job_id = %run_response.id,
286                        "Job submitted successfully"
287                    );
288
289                    Ok::<_, crate::Error>(run_response.id)
290                };
291
292                let mut pinned = Box::pin(fut);
293                match pinned.as_mut().poll(cx) {
294                    Poll::Ready(Ok(job_id)) => {
295                        *this.job_id = Some(job_id);
296                        *this.state.get_mut() = JobState::Polling;
297                        cx.waker().wake_by_ref();
298                        Poll::Pending
299                    }
300                    Poll::Ready(Err(e)) => {
301                        #[cfg(feature = "tracing")]
302                        tracing::error!(
303                            target: TRACING_TARGET,
304                            endpoint_id = %this.endpoint_id,
305                            error = %e,
306                            "Failed to submit job"
307                        );
308
309                        *this.state.get_mut() = JobState::Failed(e);
310                        cx.waker().wake_by_ref();
311                        Poll::Pending
312                    }
313                    Poll::Pending => Poll::Pending,
314                }
315            }
316            JobState::Polling => {
317                // Create a future to fetch the job status
318                let endpoint_id = Arc::clone(this.endpoint_id);
319                let job_id = this.job_id.as_ref().expect("Job ID should be set").clone();
320                let client = this.client.clone();
321
322                let fut = async move {
323                    let path = format!("{}/status/{}", endpoint_id, job_id);
324                    let response = client.get_api(&path).send().await?;
325                    let response = response.error_for_status()?;
326                    let data: Value = response.json().await?;
327                    let response: JobStatusResponse = serde_json::from_value(data)?;
328
329                    Ok::<_, crate::Error>((response.status, response.output))
330                };
331
332                // Pin and poll the future
333                let mut pinned = Box::pin(fut);
334                match pinned.as_mut().poll(cx) {
335                    Poll::Ready(Ok((status, output))) => {
336                        if status.is_final() {
337                            #[cfg(feature = "tracing")]
338                            tracing::info!(
339                                target: TRACING_TARGET,
340                                job_id = ?this.job_id,
341                                status = %status,
342                                "Job reached final state"
343                            );
344
345                            *this.state.get_mut() = JobState::Ready(output);
346                            cx.waker().wake_by_ref();
347                            Poll::Pending
348                        } else {
349                            #[cfg(feature = "tracing")]
350                            tracing::trace!(
351                                target: TRACING_TARGET,
352                                job_id = ?this.job_id,
353                                status = %status,
354                                "Job still in progress, continuing to poll"
355                            );
356
357                            // Still polling, wake up later
358                            cx.waker().wake_by_ref();
359                            Poll::Pending
360                        }
361                    }
362                    Poll::Ready(Err(e)) => {
363                        #[cfg(feature = "tracing")]
364                        tracing::error!(
365                            target: TRACING_TARGET,
366                            job_id = ?this.job_id,
367                            error = %e,
368                            "Failed to poll job status"
369                        );
370
371                        *this.state.get_mut() = JobState::Failed(e);
372                        cx.waker().wake_by_ref();
373                        Poll::Pending
374                    }
375                    Poll::Pending => Poll::Pending,
376                }
377            }
378            JobState::Ready(output) => {
379                let output = output.take();
380                match output {
381                    Some(val) => Poll::Ready(Ok(val)),
382                    None => Poll::Ready(Err(crate::Error::Job("Job has no output".to_string()))),
383                }
384            }
385            JobState::Failed(_) => {
386                if let JobState::Failed(e) =
387                    std::mem::replace(this.state.get_mut(), JobState::NotSubmitted)
388                {
389                    Poll::Ready(Err(e))
390                } else {
391                    unreachable!()
392                }
393            }
394        }
395    }
396}
397
398impl std::fmt::Debug for ServerlessJob {
399    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
400        f.debug_struct("Job")
401            .field("endpoint_id", &self.endpoint_id)
402            .field("job_id", &self.job_id)
403            .finish()
404    }
405}