runpod_sdk/serverless/job.rs
1//! Job tracking and result retrieval
2
3use std::future::Future;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::task::{Context, Poll};
7
8use serde::de::DeserializeOwned;
9use serde_json::Value;
10
11use super::{JobStatus, JobStatusResponse, RunRequest, RunResponse, StreamChunk, StreamResponse};
12use crate::{Result, RunpodClient};
13
/// Target string passed as `target:` to every `tracing` event emitted from
/// this module. Only compiled when the `tracing` feature is enabled.
#[cfg(feature = "tracing")]
const TRACING_TARGET: &str = "runpod_sdk::serverless::job";
16
pin_project_lite::pin_project! {
    /// A job submitted to a serverless endpoint.
    ///
    /// Implements `Future` to allow awaiting the job result directly.
    ///
    /// # Examples
    ///
    /// ```no_run
    /// # use runpod_sdk::{RunpodClient, Result};
    /// # use runpod_sdk::serverless::ServerlessEndpoint;
    /// # use serde_json::json;
    /// # async fn example() -> Result<()> {
    /// let client = RunpodClient::from_env()?;
    /// let endpoint = ServerlessEndpoint::new("ENDPOINT_ID", client);
    /// let job = endpoint.run(&json!({"prompt": "Hello"}))?;
    ///
    /// // Await the job to get the output
    /// let output: serde_json::Value = job.await?;
    /// println!("Output: {:?}", output);
    /// # Ok(())
    /// # }
    /// ```
    pub struct ServerlessJob {
        // Identifier of the endpoint the job targets; used as the first
        // segment of every request path built in this module.
        endpoint_id: Arc<String>,
        // `None` when the job is constructed; filled in by the `Future`
        // implementation once the `run` request has returned an ID.
        job_id: Option<String>,
        // Payload to submit. `Some` at construction and taken out when the
        // submission request is built.
        input: Option<Value>,
        // Client used to issue every API request for this job.
        client: RunpodClient,
        // Current stage of the submit/poll state machine driven by
        // `Future::poll`. Marked `#[pin]`, so the projection exposes it as
        // `Pin<&mut JobState>`.
        #[pin]
        state: JobState,
    }
}
48
49enum JobState {
50 NotSubmitted,
51 Submitting,
52 Polling,
53 Ready(Option<Value>),
54 Failed(crate::Error),
55}
56
57impl ServerlessJob {
58 /// Creates a new Job instance with input to be submitted
59 pub(crate) fn new(endpoint_id: Arc<String>, input: Value, client: RunpodClient) -> Self {
60 Self {
61 endpoint_id,
62 job_id: None,
63 input: Some(input),
64 client,
65 state: JobState::NotSubmitted,
66 }
67 }
68
69 /// Returns the job ID (if the job has been submitted)
70 pub fn job_id(&self) -> Option<&str> {
71 self.job_id.as_deref()
72 }
73
74 /// Returns the endpoint ID
75 pub fn endpoint_id(&self) -> &str {
76 &self.endpoint_id
77 }
78
79 /// Fetches the current job state from the specified endpoint.
80 async fn fetch_job(&self, source: &str) -> Result<serde_json::Value> {
81 let job_id = self
82 .job_id
83 .as_ref()
84 .ok_or_else(|| crate::Error::Job("Job has not been submitted yet".to_string()))?;
85 let path = format!("{}/{}/{}", self.endpoint_id, source, job_id);
86
87 let response = self.client.get_api(&path).send().await?;
88 let response = response.error_for_status()?;
89 let data: Value = response.json().await?;
90
91 Ok(data)
92 }
93
94 /// Returns the current status of the job.
95 ///
96 /// # Example
97 ///
98 /// ```no_run
99 /// # use runpod_sdk::{RunpodClient, Result};
100 /// # use runpod_sdk::serverless::{ServerlessEndpoint, JobStatus};
101 /// # use serde_json::json;
102 /// # async fn example() -> Result<()> {
103 /// let client = RunpodClient::from_env()?;
104 /// let endpoint = ServerlessEndpoint::new("ENDPOINT_ID", client);
105 /// let job = endpoint.run(&json!({"prompt": "Hello"}))?;
106 ///
107 /// let status = job.status().await?;
108 /// match status {
109 /// JobStatus::Completed => println!("Job finished"),
110 /// JobStatus::Failed => println!("Job failed"),
111 /// JobStatus::InProgress => println!("Job is running"),
112 /// _ => println!("Job status: {}", status),
113 /// }
114 /// # Ok(())
115 /// # }
116 /// ```
117 pub async fn status(&self) -> Result<JobStatus> {
118 #[cfg(feature = "tracing")]
119 tracing::debug!(
120 target: TRACING_TARGET,
121 job_id = ?self.job_id,
122 endpoint_id = %self.endpoint_id,
123 "Fetching job status"
124 );
125
126 let data = self.fetch_job("status").await?;
127 let response: JobStatusResponse = serde_json::from_value(data)?;
128
129 #[cfg(feature = "tracing")]
130 tracing::debug!(
131 target: TRACING_TARGET,
132 job_id = ?self.job_id,
133 status = %response.status,
134 "Job status retrieved"
135 );
136
137 Ok(response.status)
138 }
139
140 /// Returns the output of the job.
141 ///
142 /// # Example
143 ///
144 /// ```no_run
145 /// # use runpod_sdk::{RunpodClient, Result};
146 /// # use runpod_sdk::serverless::ServerlessEndpoint;
147 /// # use serde::{Deserialize, Serialize};
148 /// # use serde_json::json;
149 /// #
150 /// # #[derive(Deserialize)]
151 /// # struct Output {
152 /// # text: String,
153 /// # }
154 /// #
155 /// # async fn example() -> Result<()> {
156 /// let client = RunpodClient::from_env()?;
157 /// let endpoint = ServerlessEndpoint::new("ENDPOINT_ID", client);
158 /// let job = endpoint.run(&json!({"prompt": "Hello"}))?;
159 ///
160 /// let output: Output = job.output().await?;
161 /// println!("Result: {}", output.text);
162 /// # Ok(())
163 /// # }
164 /// ```
165 pub async fn output<O>(&self) -> Result<O>
166 where
167 O: DeserializeOwned,
168 {
169 let data = self.fetch_job("status").await?;
170 let response: JobStatusResponse = serde_json::from_value(data)?;
171
172 match response.output {
173 Some(output) => Ok(serde_json::from_value(output)?),
174 None => Err(crate::Error::Serialization(
175 serde_json::from_str::<Value>("\"Job has no output\"").unwrap_err(),
176 )),
177 }
178 }
179
180 /// Returns stream chunks from a streaming job.
181 ///
182 /// # Example
183 ///
184 /// ```no_run
185 /// # use runpod_sdk::{RunpodClient, Result};
186 /// # use runpod_sdk::serverless::{ServerlessEndpoint, JobStatus};
187 /// # use serde_json::json;
188 /// # async fn example() -> Result<()> {
189 /// let client = RunpodClient::from_env()?;
190 /// let endpoint = ServerlessEndpoint::new("ENDPOINT_ID", client);
191 /// let job = endpoint.run(&json!({"prompt": "Generate text"}))?;
192 ///
193 /// loop {
194 /// let (status, chunks) = job.stream().await?;
195 ///
196 /// for chunk in chunks {
197 /// println!("Chunk: {:?}", chunk.output);
198 /// }
199 ///
200 /// if status.is_final() {
201 /// break;
202 /// }
203 ///
204 /// std::thread::sleep(std::time::Duration::from_secs(1));
205 /// }
206 /// # Ok(())
207 /// # }
208 /// ```
209 pub async fn stream(&self) -> Result<(JobStatus, Vec<StreamChunk>)> {
210 let data = self.fetch_job("stream").await?;
211 let response: StreamResponse = serde_json::from_value(data)?;
212 Ok((response.status, response.stream))
213 }
214
215 /// Cancels the job.
216 ///
217 /// # Example
218 ///
219 /// ```no_run
220 /// # use runpod_sdk::{RunpodClient, Result};
221 /// # use runpod_sdk::serverless::ServerlessEndpoint;
222 /// # use serde_json::json;
223 /// # async fn example() -> Result<()> {
224 /// let client = RunpodClient::from_env()?;
225 /// let endpoint = ServerlessEndpoint::new("ENDPOINT_ID", client);
226 /// let job = endpoint.run(&json!({"prompt": "Long running task"}))?;
227 ///
228 /// job.cancel().await?;
229 /// println!("Job cancelled");
230 /// # Ok(())
231 /// # }
232 /// ```
233 pub async fn cancel(&self) -> Result<()> {
234 let job_id = self
235 .job_id
236 .as_ref()
237 .ok_or_else(|| crate::Error::Job("Job has not been submitted yet".to_string()))?;
238 let path = format!("{}/cancel/{}", self.endpoint_id, job_id);
239
240 let response = self.client.post_api(&path).send().await?;
241 response.error_for_status()?;
242 Ok(())
243 }
244}
245
246impl Future for ServerlessJob {
247 type Output = Result<Value>;
248
249 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
250 let mut this = self.project();
251
252 match this.state.as_mut().get_mut() {
253 JobState::NotSubmitted => {
254 *this.state.get_mut() = JobState::Submitting;
255 cx.waker().wake_by_ref();
256 Poll::Pending
257 }
258 JobState::Submitting => {
259 // Submit the job
260 let endpoint_id = Arc::clone(this.endpoint_id);
261 let input = this.input.take().expect("Input should be present");
262 let client = this.client.clone();
263
264 let fut = async move {
265 #[cfg(feature = "tracing")]
266 tracing::debug!(
267 target: TRACING_TARGET,
268 endpoint_id = %endpoint_id,
269 "Submitting job to endpoint"
270 );
271
272 let path = format!("{}/run", endpoint_id);
273
274 let payload = RunRequest { input };
275
276 let response = client.post_api(&path).json(&payload).send().await?;
277
278 let response = response.error_for_status()?;
279 let run_response: RunResponse = response.json().await?;
280
281 #[cfg(feature = "tracing")]
282 tracing::info!(
283 target: TRACING_TARGET,
284 endpoint_id = %endpoint_id,
285 job_id = %run_response.id,
286 "Job submitted successfully"
287 );
288
289 Ok::<_, crate::Error>(run_response.id)
290 };
291
292 let mut pinned = Box::pin(fut);
293 match pinned.as_mut().poll(cx) {
294 Poll::Ready(Ok(job_id)) => {
295 *this.job_id = Some(job_id);
296 *this.state.get_mut() = JobState::Polling;
297 cx.waker().wake_by_ref();
298 Poll::Pending
299 }
300 Poll::Ready(Err(e)) => {
301 #[cfg(feature = "tracing")]
302 tracing::error!(
303 target: TRACING_TARGET,
304 endpoint_id = %this.endpoint_id,
305 error = %e,
306 "Failed to submit job"
307 );
308
309 *this.state.get_mut() = JobState::Failed(e);
310 cx.waker().wake_by_ref();
311 Poll::Pending
312 }
313 Poll::Pending => Poll::Pending,
314 }
315 }
316 JobState::Polling => {
317 // Create a future to fetch the job status
318 let endpoint_id = Arc::clone(this.endpoint_id);
319 let job_id = this.job_id.as_ref().expect("Job ID should be set").clone();
320 let client = this.client.clone();
321
322 let fut = async move {
323 let path = format!("{}/status/{}", endpoint_id, job_id);
324 let response = client.get_api(&path).send().await?;
325 let response = response.error_for_status()?;
326 let data: Value = response.json().await?;
327 let response: JobStatusResponse = serde_json::from_value(data)?;
328
329 Ok::<_, crate::Error>((response.status, response.output))
330 };
331
332 // Pin and poll the future
333 let mut pinned = Box::pin(fut);
334 match pinned.as_mut().poll(cx) {
335 Poll::Ready(Ok((status, output))) => {
336 if status.is_final() {
337 #[cfg(feature = "tracing")]
338 tracing::info!(
339 target: TRACING_TARGET,
340 job_id = ?this.job_id,
341 status = %status,
342 "Job reached final state"
343 );
344
345 *this.state.get_mut() = JobState::Ready(output);
346 cx.waker().wake_by_ref();
347 Poll::Pending
348 } else {
349 #[cfg(feature = "tracing")]
350 tracing::trace!(
351 target: TRACING_TARGET,
352 job_id = ?this.job_id,
353 status = %status,
354 "Job still in progress, continuing to poll"
355 );
356
357 // Still polling, wake up later
358 cx.waker().wake_by_ref();
359 Poll::Pending
360 }
361 }
362 Poll::Ready(Err(e)) => {
363 #[cfg(feature = "tracing")]
364 tracing::error!(
365 target: TRACING_TARGET,
366 job_id = ?this.job_id,
367 error = %e,
368 "Failed to poll job status"
369 );
370
371 *this.state.get_mut() = JobState::Failed(e);
372 cx.waker().wake_by_ref();
373 Poll::Pending
374 }
375 Poll::Pending => Poll::Pending,
376 }
377 }
378 JobState::Ready(output) => {
379 let output = output.take();
380 match output {
381 Some(val) => Poll::Ready(Ok(val)),
382 None => Poll::Ready(Err(crate::Error::Job("Job has no output".to_string()))),
383 }
384 }
385 JobState::Failed(_) => {
386 if let JobState::Failed(e) =
387 std::mem::replace(this.state.get_mut(), JobState::NotSubmitted)
388 {
389 Poll::Ready(Err(e))
390 } else {
391 unreachable!()
392 }
393 }
394 }
395 }
396}
397
398impl std::fmt::Debug for ServerlessJob {
399 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
400 f.debug_struct("Job")
401 .field("endpoint_id", &self.endpoint_id)
402 .field("job_id", &self.job_id)
403 .finish()
404 }
405}