1use std::time::{Duration, Instant};
2
/// Failure modes of the polling helpers in this module.
///
/// `E` is the error type produced by the user-supplied closure being polled.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PollError<E> {
    /// The deadline passed before the polled condition was satisfied
    /// (produced by `poll_until`).
    Timeout,
    /// The polled closure returned an error; polling stopped immediately and
    /// the closure's error is carried here unchanged.
    ConditionError(E),
}
11
12impl<E> std::fmt::Display for PollError<E>
13where
14 E: std::fmt::Display,
15{
16 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
17 match self {
18 PollError::Timeout => write!(f, "Operation timed out"),
19 PollError::ConditionError(e) => write!(f, "Condition error: {}", e),
20 }
21 }
22}
23
24impl<E> std::error::Error for PollError<E>
25where
26 E: std::error::Error + 'static,
27{
28 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
29 match self {
30 PollError::Timeout => None,
31 PollError::ConditionError(e) => Some(e),
32 }
33 }
34}
35
36pub fn poll_until<F, E>(
66 mut condition: F,
67 timeout: Duration,
68 poll_interval: Duration,
69) -> Result<(), PollError<E>>
70where
71 F: FnMut() -> Result<bool, E>,
72{
73 let start = Instant::now();
74
75 loop {
76 if start.elapsed() >= timeout {
77 return Err(PollError::Timeout);
78 }
79
80 match condition() {
81 Ok(true) => return Ok(()),
82 Ok(false) => {
83 std::thread::sleep(poll_interval);
84 }
85 Err(e) => return Err(PollError::ConditionError(e)),
86 }
87 }
88}
89
90pub fn poll_with_timeout<F, T, E>(
130 mut operation: F,
131 timeout: Duration,
132 poll_interval: Duration,
133) -> Result<Option<T>, PollError<E>>
134where
135 F: FnMut() -> Result<Option<T>, E>,
136{
137 let start = Instant::now();
138
139 loop {
140 if start.elapsed() >= timeout {
141 return Ok(None);
142 }
143
144 match operation() {
145 Ok(Some(result)) => return Ok(Some(result)),
146 Ok(None) => {
147 std::thread::sleep(poll_interval);
148 }
149 Err(e) => return Err(PollError::ConditionError(e)),
150 }
151 }
152}
153
#[cfg(test)]
mod tests {
    use super::*;

    // The polled closures are `FnMut` and run on the calling thread, so a plain
    // captured counter is sufficient; no `Arc<Mutex<_>>` is needed.

    #[test]
    fn test_poll_until_success() {
        let mut count = 0;

        let result = poll_until(
            || {
                count += 1;
                Ok::<bool, &str>(count >= 3)
            },
            Duration::from_millis(500),
            Duration::from_millis(10),
        );

        assert!(result.is_ok());
        // Polling stops as soon as the condition holds.
        assert_eq!(count, 3);
    }

    #[test]
    fn test_poll_until_timeout() {
        let result = poll_until(
            || Ok::<bool, &str>(false),
            Duration::from_millis(50),
            Duration::from_millis(10),
        );

        assert!(matches!(result, Err(PollError::Timeout)));
    }

    #[test]
    fn test_poll_until_error() {
        let mut calls = 0;

        let result = poll_until(
            || {
                calls += 1;
                Err::<bool, &str>("test error")
            },
            Duration::from_millis(100),
            Duration::from_millis(10),
        );

        assert!(matches!(result, Err(PollError::ConditionError("test error"))));
        // An error aborts polling immediately.
        assert_eq!(calls, 1);
    }

    #[test]
    fn test_poll_with_timeout_success() {
        let mut count = 0;

        let result = poll_with_timeout(
            || {
                count += 1;

                if count >= 3 {
                    Ok::<Option<i32>, &str>(Some(count))
                } else {
                    Ok(None)
                }
            },
            Duration::from_millis(500),
            Duration::from_millis(10),
        );

        assert!(matches!(result, Ok(Some(3))));
    }

    #[test]
    fn test_poll_with_timeout_timeout() {
        let result = poll_with_timeout(
            || Ok::<Option<()>, &str>(None),
            Duration::from_millis(50),
            Duration::from_millis(10),
        );

        // A timeout is reported as `Ok(None)`, not as an error.
        assert!(matches!(result, Ok(None)));
    }

    #[test]
    fn test_poll_with_timeout_error() {
        let mut calls = 0;

        let result = poll_with_timeout(
            || {
                calls += 1;
                Err::<Option<()>, &str>("test error")
            },
            Duration::from_millis(100),
            Duration::from_millis(10),
        );

        assert!(matches!(result, Err(PollError::ConditionError("test error"))));
        assert_eq!(calls, 1);
    }
}
246}