Skip to main content

vtcode_commons/
async_utils.rs

1//! Async utility functions
2
3use anyhow::{Context, Result};
4use std::future::Future;
5use std::time::Duration;
6use tokio::io::AsyncReadExt;
7use tokio::time::timeout;
8
/// Time budget used by [`with_default_timeout`]: thirty seconds.
pub const DEFAULT_ASYNC_TIMEOUT: Duration = Duration::from_secs(30);
/// Time budget used by [`with_short_timeout`]: five seconds.
pub const SHORT_ASYNC_TIMEOUT: Duration = Duration::from_secs(5);
/// Time budget used by [`with_long_timeout`]: five minutes.
pub const LONG_ASYNC_TIMEOUT: Duration = Duration::from_secs(5 * 60);
12
13/// Execute a future with a timeout and context
14pub async fn with_timeout<F, T>(fut: F, duration: Duration, context: &str) -> Result<T>
15where
16    F: Future<Output = T>,
17{
18    match timeout(duration, fut).await {
19        Ok(result) => Ok(result),
20        Err(_) => anyhow::bail!("Operation timed out after {duration:?}: {context}"),
21    }
22}
23
24/// Execute a future with the default timeout
25pub async fn with_default_timeout<F, T>(fut: F, context: &str) -> Result<T>
26where
27    F: Future<Output = T>,
28{
29    with_timeout(fut, DEFAULT_ASYNC_TIMEOUT, context).await
30}
31
32/// Execute a future with a short timeout
33pub async fn with_short_timeout<F, T>(fut: F, context: &str) -> Result<T>
34where
35    F: Future<Output = T>,
36{
37    with_timeout(fut, SHORT_ASYNC_TIMEOUT, context).await
38}
39
40/// Execute a future with a long timeout
41pub async fn with_long_timeout<F, T>(fut: F, context: &str) -> Result<T>
42where
43    F: Future<Output = T>,
44{
45    with_timeout(fut, LONG_ASYNC_TIMEOUT, context).await
46}
47
48/// Retry an operation with exponential backoff
49pub async fn retry_with_backoff<F, Fut, T>(
50    mut op: F,
51    max_retries: usize,
52    initial_delay: Duration,
53    context: &str,
54) -> Result<T>
55where
56    F: FnMut() -> Fut,
57    Fut: Future<Output = Result<T>>,
58{
59    let mut delay = initial_delay;
60    let mut last_error = None;
61
62    for i in 0..=max_retries {
63        match op().await {
64            Ok(result) => return Ok(result),
65            Err(e) => {
66                last_error = Some(e);
67                if i < max_retries {
68                    tokio::time::sleep(delay).await;
69                    delay *= 2;
70                }
71            }
72        }
73    }
74
75    let err = last_error.unwrap_or_else(|| anyhow::anyhow!("Retry failed without error"));
76    Err(err).with_context(|| format!("Operation failed after {max_retries} retries: {context}"))
77}
78
79/// Sleep with context
80pub async fn sleep_with_context(duration: Duration, _context: &str) {
81    tokio::time::sleep(duration).await;
82}
83
84/// Run multiple futures and wait for all with a timeout
85pub async fn join_all_with_timeout<F, T>(
86    futs: Vec<F>,
87    duration: Duration,
88    context: &str,
89) -> Result<Vec<T>>
90where
91    F: Future<Output = T>,
92{
93    with_timeout(futures::future::join_all(futs), duration, context).await
94}
95
96/// Read exactly `len` bytes from an async reader without zero-initializing
97/// the buffer first.
98///
99/// This avoids the double-write overhead of `vec![0u8; len]` followed by
100/// `read_exact` — `read_buf` appends directly into the `Vec`'s spare capacity,
101/// so the zeroing pass is skipped. For large payloads this can yield
102/// measurable performance gains.
103///
104/// The returned `Vec` has exactly `len` initialized bytes.
105///
106/// This used to be an `unsafe` function that handed `read_exact` a
107/// `&mut [u8]` over uninitialized memory via `from_raw_parts_mut`. That was
108/// unsound: `tokio::io::ReadBuf::new(&mut [u8])` asserts the whole buffer is
109/// initialized, so a reader conforming to the `ReadBuf` contract would be
110/// entitled to read the (uninitialized) `[filled, initialized)` region. The
111/// `read_buf`-based implementation below is fully safe and keeps the same
112/// zero-overhead property, miri-clean by construction.
113///
114/// # Errors
115///
116/// Returns `io::ErrorKind::UnexpectedEof` if the reader reaches EOF before
117/// `len` bytes have been read.
118pub async fn read_exact_uninit<R>(reader: &mut R, len: usize) -> std::io::Result<Vec<u8>>
119where
120    R: tokio::io::AsyncRead + Unpin,
121{
122    let mut buf = Vec::with_capacity(len);
123    // `read_buf` fills the Vec's spare capacity without zero-initializing it,
124    // and is sound with respect to uninitialized memory by construction.
125    while buf.len() < len {
126        let n = reader.read_buf(&mut buf).await?;
127        if n == 0 {
128            return Err(std::io::Error::new(
129                std::io::ErrorKind::UnexpectedEof,
130                format!(
131                    "unexpected EOF before reading {len} bytes (got {})",
132                    buf.len()
133                ),
134            ));
135        }
136    }
137    Ok(buf)
138}
139
// Unit tests for `read_exact_uninit`.
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn read_exact_uninit_round_trips_known_payload() {
        let payload: Vec<u8> = (0..64u8).collect();
        let mut reader = std::io::Cursor::new(payload.clone());
        let got = read_exact_uninit(&mut reader, payload.len())
            .await
            .expect("read full payload");
        assert_eq!(got, payload);
    }

    #[tokio::test]
    async fn read_exact_uninit_reads_across_multiple_poll_reads() {
        // A single `Cursor` hands over everything it holds in one
        // `poll_read`, so on its own it never exercises the accumulation
        // loop. Chaining two cursors forces at least two `read_buf` calls:
        // the first drains `head`, and a later one moves on to `tail`.
        let payload: Vec<u8> = (0..2000u32).map(|i| (i % 256) as u8).collect();
        let (head, tail) = payload.split_at(700);
        let mut reader =
            std::io::Cursor::new(head.to_vec()).chain(std::io::Cursor::new(tail.to_vec()));
        let got = read_exact_uninit(&mut reader, payload.len())
            .await
            .expect("read full payload");
        assert_eq!(got, payload);
    }

    #[tokio::test]
    async fn read_exact_uninit_returns_unexpected_eof_on_short_read() {
        let payload = b"ten bytes!".to_vec();
        let mut reader = std::io::Cursor::new(payload);
        let err = read_exact_uninit(&mut reader, 64)
            .await
            .expect_err("short read must error");
        assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
        // The error should report how many bytes were actually received.
        assert!(err.to_string().contains("(got 10)"), "{err}");
    }

    #[tokio::test]
    async fn read_exact_uninit_returns_unexpected_eof_on_empty_reader() {
        let mut reader = std::io::Cursor::new(Vec::<u8>::new());
        let err = read_exact_uninit(&mut reader, 1)
            .await
            .expect_err("empty reader must error");
        assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
    }

    #[tokio::test]
    async fn read_exact_uninit_zero_len_returns_empty_vec() {
        let mut reader = std::io::Cursor::new(Vec::<u8>::new());
        let got = read_exact_uninit(&mut reader, 0)
            .await
            .expect("zero-length read must succeed");
        assert!(got.is_empty());
    }

    #[tokio::test]
    async fn read_exact_uninit_does_not_consume_past_len() {
        // Bytes after the requested length belong to the caller's next read
        // and must remain in the reader.
        let mut reader = std::io::Cursor::new(b"headerBODY".to_vec());
        let got = read_exact_uninit(&mut reader, 6)
            .await
            .expect("read header");
        assert_eq!(got, b"header");
        assert_eq!(reader.position(), 6);
    }
}