1use std::time::{Duration, Instant};
4
5use crate::client::Client;
6use crate::error::{Error, Result};
7use crate::types::{Task, TaskId};
8
9pub type ProgressCallback = Box<dyn Fn(&Task) + Send + Sync>;
11
12pub struct WaitOptions {
14 pub timeout: Option<Duration>,
16 pub max_interval: Duration,
18 pub initial_interval: Duration,
20 pub on_progress: Option<ProgressCallback>,
22}
23
24impl Default for WaitOptions {
25 fn default() -> Self {
26 Self {
27 timeout: None,
28 max_interval: Duration::from_secs(30),
29 initial_interval: Duration::from_secs(2),
30 on_progress: None,
31 }
32 }
33}
34
35impl std::fmt::Debug for WaitOptions {
36 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37 f.debug_struct("WaitOptions")
38 .field("timeout", &self.timeout)
39 .field("max_interval", &self.max_interval)
40 .field("initial_interval", &self.initial_interval)
41 .field("on_progress", &self.on_progress.as_ref().map(|_| "<fn>"))
42 .finish()
43 }
44}
45
46pub(crate) fn next_interval(
51 task: &Task,
52 previous: Duration,
53 initial: Duration,
54 max_interval: Duration,
55) -> Duration {
56 if let Some(eta) = task.running_left_time {
57 let eta_secs = u64::try_from(eta.max(0)).unwrap_or(0);
58 let half = Duration::from_secs(eta_secs) / 2;
59 half.max(initial).min(max_interval)
60 } else {
61 (previous * 2).min(max_interval)
62 }
63}
64
65impl Client {
66 #[tracing::instrument(skip(self, opts), fields(task_id = %id))]
71 pub async fn wait_for_task(&self, id: &TaskId, opts: WaitOptions) -> Result<Task> {
72 let started = Instant::now();
73 let mut interval = opts.initial_interval;
74 loop {
75 let task = self.get_task(id).await?;
76 if let Some(cb) = &opts.on_progress {
77 cb(&task);
78 }
79 if task.status.is_terminal() {
80 return Ok(task);
81 }
82 interval = next_interval(&task, interval, opts.initial_interval, opts.max_interval);
83 if let Some(deadline) = opts.timeout {
84 let elapsed = started.elapsed();
85 let Some(remaining) = deadline.checked_sub(elapsed) else {
86 return Err(Error::WaitTimeout(id.clone()));
87 };
88 if remaining.is_zero() {
89 return Err(Error::WaitTimeout(id.clone()));
90 }
91 let to_sleep = interval.min(remaining);
92 tokio::time::sleep(to_sleep).await;
93 } else {
94 tokio::time::sleep(interval).await;
95 }
96 }
97 }
98}
99
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use super::*;
    use crate::types::{TaskOutput, TaskStatus};

    /// Default-like bounds used by most cases below.
    const INITIAL: Duration = Duration::from_secs(2);
    const MAX: Duration = Duration::from_secs(30);

    /// Build a running task whose only varying field is the ETA.
    fn task_with_eta(eta: Option<i64>) -> Task {
        Task {
            task_id: "x".into(),
            task_type: "text_to_model".into(),
            status: TaskStatus::Running,
            input: BTreeMap::new(),
            output: TaskOutput::default(),
            progress: 0,
            create_time: 0,
            running_left_time: eta,
            queuing_num: None,
            error_code: None,
            error_msg: None,
        }
    }

    /// `next_interval` with the 2 s floor and 30 s cap.
    fn interval(eta: Option<i64>, previous: Duration) -> Duration {
        next_interval(&task_with_eta(eta), previous, INITIAL, MAX)
    }

    #[test]
    fn eta_drives_half_of_remaining() {
        assert_eq!(interval(Some(40), INITIAL), Duration::from_secs(20));
    }

    #[test]
    fn eta_capped_by_max() {
        assert_eq!(interval(Some(600), INITIAL), MAX);
    }

    #[test]
    fn eta_floor_is_initial() {
        assert_eq!(interval(Some(1), INITIAL), INITIAL);
    }

    #[test]
    fn negative_or_zero_eta_falls_back_to_initial() {
        assert_eq!(interval(Some(-5), INITIAL), INITIAL);
        assert_eq!(interval(Some(0), INITIAL), INITIAL);
    }

    #[test]
    fn odd_eta_halves_to_sub_second_precision() {
        assert_eq!(interval(Some(41), INITIAL), Duration::from_millis(20_500));
    }

    #[test]
    fn max_wins_when_initial_exceeds_it() {
        let d = next_interval(
            &task_with_eta(Some(1)),
            Duration::from_secs(10),
            Duration::from_secs(10),
            Duration::from_secs(5),
        );
        assert_eq!(d, Duration::from_secs(5));
    }

    #[test]
    fn without_eta_exponential() {
        assert_eq!(interval(None, INITIAL), Duration::from_secs(4));
        assert_eq!(interval(None, Duration::from_secs(20)), MAX);
    }
}