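//! tako_socket: helpers for exchanging newline-delimited JSON (JSONL) messages over async
//! byte streams, plus a simple per-connection request/response loop for Unix sockets.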

1use serde::Serialize;
2use serde::de::DeserializeOwned;
3use std::future::Future;
4use tokio::io::BufReader;
5use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncWrite, AsyncWriteExt};
6use tokio::net::UnixStream;
7
/// Default cap, in bytes, on a single JSON line accepted by `read_json_line` (1 MiB).
pub const DEFAULT_MAX_LINE_BYTES: usize = 1 << 20;
9
10pub async fn read_json_line_with_limit<R, T>(
11    reader: &mut R,
12    max_bytes: usize,
13) -> std::io::Result<Option<T>>
14where
15    R: AsyncBufRead + Unpin,
16    T: DeserializeOwned,
17{
18    // Read incrementally, checking the limit at each chunk boundary so a
19    // malicious sender without a newline cannot force unbounded allocation.
20    let mut buf = Vec::new();
21    loop {
22        let available = reader.fill_buf().await?;
23        if available.is_empty() {
24            // EOF
25            if buf.is_empty() {
26                return Ok(None);
27            }
28            break;
29        }
30        if let Some(pos) = available.iter().position(|&b| b == b'\n') {
31            buf.extend_from_slice(&available[..=pos]);
32            reader.consume(pos + 1);
33            break;
34        }
35        buf.extend_from_slice(available);
36        let consumed = available.len();
37        reader.consume(consumed);
38        if buf.len() > max_bytes {
39            return Err(std::io::Error::new(
40                std::io::ErrorKind::InvalidData,
41                format!(
42                    "json line exceeds max length ({} > {})",
43                    buf.len(),
44                    max_bytes
45                ),
46            ));
47        }
48    }
49    if buf.is_empty() {
50        return Ok(None);
51    }
52    if buf.len() > max_bytes {
53        return Err(std::io::Error::new(
54            std::io::ErrorKind::InvalidData,
55            format!(
56                "json line exceeds max length ({} > {})",
57                buf.len(),
58                max_bytes
59            ),
60        ));
61    }
62
63    let s = std::str::from_utf8(&buf)
64        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
65
66    serde_json::from_str::<T>(s)
67        .map(Some)
68        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
69}
70
71pub async fn read_json_line<R, T>(reader: &mut R) -> std::io::Result<Option<T>>
72where
73    R: AsyncBufRead + Unpin,
74    T: DeserializeOwned,
75{
76    read_json_line_with_limit(reader, DEFAULT_MAX_LINE_BYTES).await
77}
78
79pub async fn write_json_line<W, T>(writer: &mut W, value: &T) -> std::io::Result<()>
80where
81    W: AsyncWrite + Unpin,
82    T: Serialize,
83{
84    let json = serde_json::to_string(value)
85        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
86    writer.write_all(json.as_bytes()).await?;
87    writer.write_all(b"\n").await?;
88    Ok(())
89}
90
91pub async fn serve_jsonl_connection<Req, Resp, F, Fut, InvalidResp>(
92    stream: UnixStream,
93    handler: F,
94    invalid_response: InvalidResp,
95) -> std::io::Result<()>
96where
97    Req: DeserializeOwned,
98    Resp: Serialize,
99    F: Fn(Req) -> Fut,
100    Fut: Future<Output = Resp>,
101    InvalidResp: Fn(std::io::Error) -> Resp,
102{
103    let (reader, mut writer) = stream.into_split();
104    let mut reader = BufReader::new(reader);
105
106    loop {
107        let Some(req) = (match read_json_line::<_, Req>(&mut reader).await {
108            Ok(v) => v,
109            Err(e) if e.kind() == std::io::ErrorKind::InvalidData => {
110                let resp = invalid_response(e);
111                let _ = write_json_line(&mut writer, &resp).await;
112                continue;
113            }
114            Err(e) => return Err(e),
115        }) else {
116            break;
117        };
118
119        let resp = handler(req).await;
120        write_json_line(&mut writer, &resp).await?;
121    }
122
123    Ok(())
124}
125
#[cfg(test)]
mod tests {
    // `BufReader`, `UnixStream` and the tokio I/O extension traits come in through the parent
    // module's imports.
    use super::*;

    #[tokio::test]
    async fn roundtrips_struct_over_jsonl() {
        #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
        struct Msg {
            kind: String,
            n: u64,
        }

        let (a, b) = tokio::io::duplex(1024);
        let (mut ar, mut aw) = tokio::io::split(a);
        let (mut br, mut bw) = tokio::io::split(b);
        let mut ar = BufReader::new(&mut ar);
        let mut br = BufReader::new(&mut br);

        let a_send = Msg {
            kind: "hello".to_string(),
            n: 42,
        };
        write_json_line(&mut aw, &a_send).await.unwrap();
        let b_recv: Msg = read_json_line(&mut br).await.unwrap().unwrap();
        assert_eq!(b_recv, a_send);

        let b_send = Msg {
            kind: "world".to_string(),
            n: 7,
        };
        write_json_line(&mut bw, &b_send).await.unwrap();
        let a_recv: Msg = read_json_line(&mut ar).await.unwrap().unwrap();
        assert_eq!(a_recv, b_send);
    }

    #[tokio::test]
    async fn returns_invalid_data_on_bad_json() {
        let (a, b) = tokio::io::duplex(1024);
        // The unused read half is kept bound so it lives until the end of the test.
        let (_ar, mut aw) = tokio::io::split(a);
        let (mut br, _bw) = tokio::io::split(b);
        let mut br = BufReader::new(&mut br);

        aw.write_all(b"{not json}\n").await.unwrap();

        let err = read_json_line::<_, serde_json::Value>(&mut br)
            .await
            .unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[tokio::test]
    async fn errors_when_line_exceeds_limit() {
        let (a, b) = tokio::io::duplex(1024 * 1024);
        let (_ar, mut aw) = tokio::io::split(a);
        let (mut br, _bw) = tokio::io::split(b);
        let mut br = BufReader::new(&mut br);

        // Write a line bigger than our limit.
        let big = "a".repeat(33);
        aw.write_all(big.as_bytes()).await.unwrap();
        aw.write_all(b"\n").await.unwrap();

        let err = read_json_line_with_limit::<_, serde_json::Value>(&mut br, 32)
            .await
            .unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[tokio::test]
    async fn returns_none_on_clean_eof() {
        let mut input: &[u8] = b"";
        let got = read_json_line::<_, serde_json::Value>(&mut input)
            .await
            .unwrap();
        assert!(got.is_none());
    }

    #[tokio::test]
    async fn parses_final_line_without_newline() {
        let mut input: &[u8] = b"{\"n\":5}";
        let got = read_json_line::<_, serde_json::Value>(&mut input)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(got, serde_json::json!({ "n": 5 }));

        // The unterminated line was fully consumed, so the next read reports EOF.
        let next = read_json_line::<_, serde_json::Value>(&mut input)
            .await
            .unwrap();
        assert!(next.is_none());
    }

    #[tokio::test]
    async fn reads_consecutive_lines_in_order() {
        let mut input: &[u8] = b"1\n2\n";
        let first: u64 = read_json_line(&mut input).await.unwrap().unwrap();
        let second: u64 = read_json_line(&mut input).await.unwrap().unwrap();
        assert_eq!((first, second), (1, 2));
        assert!(read_json_line::<_, u64>(&mut input).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn write_json_line_appends_single_newline() {
        let mut out: Vec<u8> = Vec::new();
        write_json_line(&mut out, &serde_json::json!({ "a": 1 }))
            .await
            .unwrap();
        assert_eq!(out, b"{\"a\":1}\n");
    }

    #[tokio::test]
    async fn serve_jsonl_connection_handles_invalid_and_valid_requests() {
        #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
        struct Req {
            n: u64,
        }
        #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
        struct Resp {
            ok: bool,
            n: u64,
        }

        let (a, b) = UnixStream::pair().unwrap();
        let h = tokio::spawn(async move {
            serve_jsonl_connection(
                a,
                |req: Req| async move { Resp { ok: true, n: req.n } },
                |_e| Resp { ok: false, n: 0 },
            )
            .await
            .unwrap();
        });

        let (r, mut w) = b.into_split();
        let mut r = BufReader::new(r);

        // Invalid JSON should yield an error response.
        w.write_all(b"{not json}\n").await.unwrap();
        let resp: Resp = read_json_line(&mut r).await.unwrap().unwrap();
        assert_eq!(resp, Resp { ok: false, n: 0 });

        // Valid JSON should roundtrip.
        write_json_line(&mut w, &Req { n: 7 }).await.unwrap();
        let resp: Resp = read_json_line(&mut r).await.unwrap().unwrap();
        assert_eq!(resp, Resp { ok: true, n: 7 });

        drop(w);
        h.await.unwrap();
    }
}