1use serde::Serialize;
2use serde::de::DeserializeOwned;
3use std::future::Future;
4use tokio::io::BufReader;
5use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncWrite, AsyncWriteExt};
6use tokio::net::UnixStream;
7
/// Default maximum size, in bytes, of one JSON line (trailing newline included)
/// accepted by [`read_json_line`]: 1 MiB.
pub const DEFAULT_MAX_LINE_BYTES: usize = 1 << 20;
9
10pub async fn read_json_line_with_limit<R, T>(
11 reader: &mut R,
12 max_bytes: usize,
13) -> std::io::Result<Option<T>>
14where
15 R: AsyncBufRead + Unpin,
16 T: DeserializeOwned,
17{
18 let mut buf = Vec::new();
21 loop {
22 let available = reader.fill_buf().await?;
23 if available.is_empty() {
24 if buf.is_empty() {
26 return Ok(None);
27 }
28 break;
29 }
30 if let Some(pos) = available.iter().position(|&b| b == b'\n') {
31 buf.extend_from_slice(&available[..=pos]);
32 reader.consume(pos + 1);
33 break;
34 }
35 buf.extend_from_slice(available);
36 let consumed = available.len();
37 reader.consume(consumed);
38 if buf.len() > max_bytes {
39 return Err(std::io::Error::new(
40 std::io::ErrorKind::InvalidData,
41 format!(
42 "json line exceeds max length ({} > {})",
43 buf.len(),
44 max_bytes
45 ),
46 ));
47 }
48 }
49 if buf.is_empty() {
50 return Ok(None);
51 }
52 if buf.len() > max_bytes {
53 return Err(std::io::Error::new(
54 std::io::ErrorKind::InvalidData,
55 format!(
56 "json line exceeds max length ({} > {})",
57 buf.len(),
58 max_bytes
59 ),
60 ));
61 }
62
63 let s = std::str::from_utf8(&buf)
64 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
65
66 serde_json::from_str::<T>(s)
67 .map(Some)
68 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
69}
70
71pub async fn read_json_line<R, T>(reader: &mut R) -> std::io::Result<Option<T>>
72where
73 R: AsyncBufRead + Unpin,
74 T: DeserializeOwned,
75{
76 read_json_line_with_limit(reader, DEFAULT_MAX_LINE_BYTES).await
77}
78
79pub async fn write_json_line<W, T>(writer: &mut W, value: &T) -> std::io::Result<()>
80where
81 W: AsyncWrite + Unpin,
82 T: Serialize,
83{
84 let json = serde_json::to_string(value)
85 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
86 writer.write_all(json.as_bytes()).await?;
87 writer.write_all(b"\n").await?;
88 Ok(())
89}
90
91pub async fn serve_jsonl_connection<Req, Resp, F, Fut, InvalidResp>(
92 stream: UnixStream,
93 handler: F,
94 invalid_response: InvalidResp,
95) -> std::io::Result<()>
96where
97 Req: DeserializeOwned,
98 Resp: Serialize,
99 F: Fn(Req) -> Fut,
100 Fut: Future<Output = Resp>,
101 InvalidResp: Fn(std::io::Error) -> Resp,
102{
103 let (reader, mut writer) = stream.into_split();
104 let mut reader = BufReader::new(reader);
105
106 loop {
107 let Some(req) = (match read_json_line::<_, Req>(&mut reader).await {
108 Ok(v) => v,
109 Err(e) if e.kind() == std::io::ErrorKind::InvalidData => {
110 let resp = invalid_response(e);
111 let _ = write_json_line(&mut writer, &resp).await;
112 continue;
113 }
114 Err(e) => return Err(e),
115 }) else {
116 break;
117 };
118
119 let resp = handler(req).await;
120 write_json_line(&mut writer, &resp).await?;
121 }
122
123 Ok(())
124}
125
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::BufReader;

    /// A value written on one end of a duplex pipe is read back intact on the
    /// other end, in both directions.
    #[tokio::test]
    async fn roundtrips_struct_over_jsonl() {
        #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
        struct Msg {
            kind: String,
            n: u64,
        }

        let (left, right) = tokio::io::duplex(1024);
        let (mut left_rd, mut left_wr) = tokio::io::split(left);
        let (mut right_rd, mut right_wr) = tokio::io::split(right);
        let mut left_rd = BufReader::new(&mut left_rd);
        let mut right_rd = BufReader::new(&mut right_rd);

        // left -> right
        let hello = Msg {
            kind: "hello".to_string(),
            n: 42,
        };
        write_json_line(&mut left_wr, &hello).await.unwrap();
        let got: Msg = read_json_line(&mut right_rd).await.unwrap().unwrap();
        assert_eq!(got, hello);

        // right -> left
        let world = Msg {
            kind: "world".to_string(),
            n: 7,
        };
        write_json_line(&mut right_wr, &world).await.unwrap();
        let got: Msg = read_json_line(&mut left_rd).await.unwrap().unwrap();
        assert_eq!(got, world);
    }

    /// A syntactically broken line is reported as `InvalidData`.
    #[tokio::test]
    async fn returns_invalid_data_on_bad_json() {
        let (client, server) = tokio::io::duplex(1024);
        let (_client_rd, mut client_wr) = tokio::io::split(client);
        let (mut server_rd, _server_wr) = tokio::io::split(server);
        let mut server_rd = BufReader::new(&mut server_rd);

        client_wr.write_all(b"{not json}\n").await.unwrap();

        let err = read_json_line::<_, serde_json::Value>(&mut server_rd)
            .await
            .unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    /// A line longer than the configured limit is rejected with `InvalidData`.
    #[tokio::test]
    async fn errors_when_line_exceeds_limit() {
        const LIMIT: usize = 32;

        let (client, server) = tokio::io::duplex(1024 * 1024);
        let (_client_rd, mut client_wr) = tokio::io::split(client);
        let (mut server_rd, _server_wr) = tokio::io::split(server);
        let mut server_rd = BufReader::new(&mut server_rd);

        // One byte over the limit even before the newline is counted.
        let oversized = "a".repeat(LIMIT + 1);
        client_wr.write_all(oversized.as_bytes()).await.unwrap();
        client_wr.write_all(b"\n").await.unwrap();

        let err = read_json_line_with_limit::<_, serde_json::Value>(&mut server_rd, LIMIT)
            .await
            .unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    /// The server answers a malformed line with the `invalid_response` value and
    /// keeps going, then answers a valid request through the handler, and finally
    /// returns without error once the client closes its write half.
    #[tokio::test]
    async fn serve_jsonl_connection_handles_invalid_and_valid_requests() {
        #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
        struct Req {
            n: u64,
        }
        #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
        struct Resp {
            ok: bool,
            n: u64,
        }

        let (server_sock, client_sock) = UnixStream::pair().unwrap();
        let server = tokio::spawn(async move {
            serve_jsonl_connection(
                server_sock,
                |req: Req| async move { Resp { ok: true, n: req.n } },
                |_err| Resp { ok: false, n: 0 },
            )
            .await
            .unwrap();
        });

        let (client_rd, mut client_wr) = client_sock.into_split();
        let mut client_rd = BufReader::new(client_rd);

        client_wr.write_all(b"{not json}\n").await.unwrap();
        let reply: Resp = read_json_line(&mut client_rd).await.unwrap().unwrap();
        assert_eq!(reply, Resp { ok: false, n: 0 });

        write_json_line(&mut client_wr, &Req { n: 7 }).await.unwrap();
        let reply: Resp = read_json_line(&mut client_rd).await.unwrap().unwrap();
        assert_eq!(reply, Resp { ok: true, n: 7 });

        // Dropping the write half shuts it down, so the server sees EOF and exits.
        drop(client_wr);
        server.await.unwrap();
    }
}