Skip to main content

tripo_api/
wait.rs

1//! `wait_for_task`: poll until terminal status with ETA-driven backoff.
2
3use std::time::{Duration, Instant};
4
5use crate::client::Client;
6use crate::error::{Error, Result};
7use crate::types::{Task, TaskId};
8
9/// Callback invoked after each successful poll.
10pub type ProgressCallback = Box<dyn Fn(&Task) + Send + Sync>;
11
12/// Options for [`Client::wait_for_task`].
13pub struct WaitOptions {
14    /// Overall timeout. `None` → wait forever.
15    pub timeout: Option<Duration>,
16    /// Cap on the polling interval.
17    pub max_interval: Duration,
18    /// Initial polling interval when no ETA is available.
19    pub initial_interval: Duration,
20    /// Called after every poll.
21    pub on_progress: Option<ProgressCallback>,
22}
23
24impl Default for WaitOptions {
25    fn default() -> Self {
26        Self {
27            timeout: None,
28            max_interval: Duration::from_secs(30),
29            initial_interval: Duration::from_secs(2),
30            on_progress: None,
31        }
32    }
33}
34
35impl std::fmt::Debug for WaitOptions {
36    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37        f.debug_struct("WaitOptions")
38            .field("timeout", &self.timeout)
39            .field("max_interval", &self.max_interval)
40            .field("initial_interval", &self.initial_interval)
41            .field("on_progress", &self.on_progress.as_ref().map(|_| "<fn>"))
42            .finish()
43    }
44}
45
46/// Compute the next polling delay given a task.
47///
48/// Mirrors the Python SDK: if `running_left_time` is present, sleep `max(2s, eta/2)`;
49/// otherwise double the previous interval, capped at `max_interval`.
50pub(crate) fn next_interval(
51    task: &Task,
52    previous: Duration,
53    initial: Duration,
54    max_interval: Duration,
55) -> Duration {
56    if let Some(eta) = task.running_left_time {
57        let eta_secs = u64::try_from(eta.max(0)).unwrap_or(0);
58        let half = Duration::from_secs(eta_secs) / 2;
59        half.max(initial).min(max_interval)
60    } else {
61        (previous * 2).min(max_interval)
62    }
63}
64
65impl Client {
66    /// Poll `GET /task/{id}` until the status is terminal or `opts.timeout` is reached.
67    ///
68    /// Returns the final `Task` even for non-success terminal statuses; callers can check
69    /// `task.status`. Use `Error::WaitTimeout` if you want an error on timeout (returned here).
70    #[tracing::instrument(skip(self, opts), fields(task_id = %id))]
71    pub async fn wait_for_task(&self, id: &TaskId, opts: WaitOptions) -> Result<Task> {
72        let started = Instant::now();
73        let mut interval = opts.initial_interval;
74        loop {
75            let task = self.get_task(id).await?;
76            if let Some(cb) = &opts.on_progress {
77                cb(&task);
78            }
79            if task.status.is_terminal() {
80                return Ok(task);
81            }
82            interval = next_interval(&task, interval, opts.initial_interval, opts.max_interval);
83            if let Some(deadline) = opts.timeout {
84                let elapsed = started.elapsed();
85                let Some(remaining) = deadline.checked_sub(elapsed) else {
86                    return Err(Error::WaitTimeout(id.clone()));
87                };
88                if remaining.is_zero() {
89                    return Err(Error::WaitTimeout(id.clone()));
90                }
91                let to_sleep = interval.min(remaining);
92                tokio::time::sleep(to_sleep).await;
93            } else {
94                tokio::time::sleep(interval).await;
95            }
96        }
97    }
98}
99
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use super::*;
    use crate::types::{TaskOutput, TaskStatus};

    /// Floor used by every case below (matches the `WaitOptions` default).
    const FLOOR: Duration = Duration::from_secs(2);
    /// Cap used by every case below (matches the `WaitOptions` default).
    const CAP: Duration = Duration::from_secs(30);

    /// A `Running` task with placeholder metadata and the given server-reported ETA.
    fn task_with_eta(eta: Option<i64>) -> Task {
        Task {
            task_id: "x".into(),
            task_type: "text_to_model".into(),
            status: TaskStatus::Running,
            input: BTreeMap::new(),
            output: TaskOutput::default(),
            progress: 0,
            create_time: 0,
            running_left_time: eta,
            queuing_num: None,
            error_code: None,
            error_msg: None,
        }
    }

    /// Run `next_interval` for `eta` with the shared floor and cap.
    fn interval_after(eta: Option<i64>, previous_secs: u64) -> Duration {
        next_interval(
            &task_with_eta(eta),
            Duration::from_secs(previous_secs),
            FLOOR,
            CAP,
        )
    }

    #[test]
    fn eta_drives_half_of_remaining() {
        assert_eq!(interval_after(Some(40), 2), Duration::from_secs(20));
    }

    #[test]
    fn eta_capped_by_max() {
        assert_eq!(interval_after(Some(600), 2), CAP);
    }

    #[test]
    fn eta_floor_is_initial() {
        assert_eq!(interval_after(Some(1), 2), FLOOR);
    }

    #[test]
    fn without_eta_exponential() {
        // 2 s doubles to 4 s ...
        assert_eq!(interval_after(None, 2), Duration::from_secs(4));
        // ... while 20 s would double to 40 s and is clipped to the cap.
        assert_eq!(interval_after(None, 20), CAP);
    }
}