1#[cfg(feature = "runtime-agnostic")]
4use async_codec_lite::{FramedRead, FramedWrite};
5#[cfg(feature = "runtime-agnostic")]
6use futures::io::{AsyncRead, AsyncWrite};
7
8#[cfg(feature = "runtime-tokio")]
9use tokio::io::{AsyncRead, AsyncWrite};
10#[cfg(feature = "runtime-tokio")]
11use tokio_util::codec::{FramedRead, FramedWrite};
12
13use futures::channel::mpsc;
14use futures::{
15 FutureExt, Sink, SinkExt, Stream, StreamExt, TryFutureExt, future, join, stream, stream_select,
16};
17use tower::Service;
18use tracing::error;
19
20use crate::codec::{LanguageServerCodec, ParseError};
21use crate::jsonrpc::{Error, Id, Message, Request, Response};
22use crate::service::{ClientSocket, RequestStream, ResponseSink};
23
/// Number of request futures polled concurrently by `Server::serve` unless overridden
/// with `Server::concurrency_level`.
const DEFAULT_MAX_CONCURRENCY: usize = 4;
/// Buffer size of the channel that carries request futures from the input reader to the
/// concurrent task processor in `Server::serve`.
const MESSAGE_QUEUE_SIZE: usize = 100;
26
27pub trait Loopback {
31 type RequestStream: Stream<Item = Request> + Unpin;
33 type ResponseSink: Sink<Response> + Unpin;
35
36 fn split(self) -> (Self::RequestStream, Self::ResponseSink);
40}
41
/// Lets a [`ClientSocket`] serve as the default loopback of `Server`.
impl Loopback for ClientSocket {
    type RequestStream = RequestStream;
    type ResponseSink = ResponseSink;

    // NOTE(review): this relies on `ClientSocket` having an inherent `split` method with the
    // same signature; inherent methods win method resolution, so this delegates to it. If that
    // inherent method were ever removed, this call would recurse into itself — confirm.
    #[inline]
    fn split(self) -> (Self::RequestStream, Self::ResponseSink) {
        self.split()
    }
}
51
/// Serves a `tower::Service` over an async input/output pair, framing messages with
/// `LanguageServerCodec`.
///
/// `L` is the [`Loopback`] used for server-to-client requests; it defaults to [`ClientSocket`].
#[derive(Debug)]
pub struct Server<I, O, L = ClientSocket> {
    // Reader the client's messages are decoded from.
    stdin: I,
    // Writer that outgoing messages are encoded to.
    stdout: O,
    // Supplies server-to-client requests and receives the client's responses.
    loopback: L,
    // Maximum number of request futures polled at the same time.
    max_concurrency: usize,
}
60
61impl<I, O, L> Server<I, O, L>
62where
63 I: AsyncRead + Unpin,
64 O: AsyncWrite,
65 L: Loopback,
66 <L::ResponseSink as Sink<Response>>::Error: std::error::Error,
67{
68 pub const fn new(stdin: I, stdout: O, socket: L) -> Self {
70 Self {
71 stdin,
72 stdout,
73 loopback: socket,
74 max_concurrency: DEFAULT_MAX_CONCURRENCY,
75 }
76 }
77
78 #[must_use]
99 pub const fn concurrency_level(mut self, max: usize) -> Self {
100 self.max_concurrency = max;
101 self
102 }
103
104 pub async fn serve<T>(self, mut service: T)
106 where
107 T: Service<Request, Response = Option<Response>> + Send + 'static,
108 T::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
109 T::Future: Send,
110 {
111 let (client_requests, mut client_responses) = self.loopback.split();
112 let (client_requests, client_abort) = stream::abortable(client_requests);
113 let (mut responses_tx, responses_rx) = mpsc::channel(0);
114 let (mut server_tasks_tx, server_tasks_rx) = mpsc::channel(MESSAGE_QUEUE_SIZE);
115
116 let mut framed_stdin = FramedRead::new(self.stdin, LanguageServerCodec::default());
117 let framed_stdout = FramedWrite::new(self.stdout, LanguageServerCodec::default());
118
119 let process_server_tasks = server_tasks_rx
120 .buffer_unordered(self.max_concurrency)
121 .filter_map(future::ready)
122 .map(|res| Ok(Message::Response(res)))
123 .forward(responses_tx.clone());
124
125 let print_output = stream_select!(responses_rx, client_requests.map(Message::Request))
126 .map(Ok)
127 .forward(framed_stdout);
128
129 let read_input = async {
130 while let Some(msg) = framed_stdin.next().await {
131 match msg {
132 Ok(Message::Request(req)) => {
133 if let Err(err) = future::poll_fn(|cx| service.poll_ready(cx)).await {
134 error!("{}", display_sources(err.into().as_ref()));
135 return;
136 }
137
138 let will_exit = req.method() == "exit";
142
143 let fut = service.call(req).unwrap_or_else(|err| {
144 error!("{}", display_sources(err.into().as_ref()));
145 None
146 });
147
148 let _ = server_tasks_tx.send(fut).await;
149
150 if will_exit {
151 break;
152 }
153 }
154 Ok(Message::Response(res)) => {
155 if let Err(err) = client_responses.send(res).await {
156 error!("{}", display_sources(&err));
157 return;
158 }
159 }
160 Err(err) => {
161 error!("failed to decode message: {}", err);
162 let res = Response::from_error(Id::Null, to_jsonrpc_error(err));
163 let _ = responses_tx.send(Message::Response(res)).await;
164 }
165 }
166 }
167
168 server_tasks_tx.disconnect();
169 responses_tx.disconnect();
170 client_abort.abort();
171 };
172
173 join!(
174 process_server_tasks.map(|_| ()),
175 print_output.map(|_| ()),
176 read_input
177 );
178 }
179}
180
/// Renders `error` followed by each error in its `source()` chain, joined with `": "`.
fn display_sources(error: &dyn std::error::Error) -> String {
    let mut rendered = error.to_string();
    let mut next = error.source();
    while let Some(cause) = next {
        rendered.push_str(": ");
        rendered.push_str(&cause.to_string());
        next = cause.source();
    }
    rendered
}
187
/// Maps a message decoding failure to the JSON-RPC error returned to the client.
///
/// A `ParseError::Body` whose inner error reports `is_data()` becomes "Invalid Request";
/// every other failure becomes "Parse error".
#[cfg(feature = "runtime-tokio")]
fn to_jsonrpc_error(err: ParseError) -> Error {
    let invalid_request = matches!(&err, ParseError::Body(body) if body.is_data());
    if invalid_request {
        Error::invalid_request()
    } else {
        Error::parse_error()
    }
}
195
/// Maps a framed-read failure to the JSON-RPC error returned to the client.
///
/// The wrapped `ParseError` is recovered by downcasting `err.source()`. A `ParseError::Body`
/// whose inner error reports `is_data()` becomes "Invalid Request"; anything else, including
/// a missing or non-`ParseError` source, becomes "Parse error".
#[cfg(feature = "runtime-agnostic")]
fn to_jsonrpc_error(err: impl std::error::Error) -> Error {
    let parsed = err
        .source()
        .and_then(|cause| cause.downcast_ref::<ParseError>());
    if matches!(parsed, Some(ParseError::Body(body)) if body.is_data()) {
        Error::invalid_request()
    } else {
        Error::parse_error()
    }
}
203
#[cfg(test)]
mod tests {
    use std::task::{Context, Poll};

    #[cfg(feature = "runtime-agnostic")]
    use futures::io::Cursor;
    #[cfg(feature = "runtime-tokio")]
    use std::io::Cursor;

    use futures::future::Ready;
    use futures::{future, sink, stream};

    use super::*;

    // 58-byte request body; framed by `mock_request` into 80 bytes (22-byte header + body).
    const REQUEST: &str = r#"{"jsonrpc":"2.0","method":"initialize","params":{},"id":1}"#;
    const RESPONSE: &str = r#"{"jsonrpc":"2.0","result":{"capabilities":{}},"id":1}"#;

    /// Service that is always ready and answers every request with `RESPONSE`.
    #[derive(Debug)]
    struct MockService;

    impl Service<Request> for MockService {
        type Response = Option<Response>;
        type Error = String;
        type Future = Ready<Result<Self::Response, Self::Error>>;

        fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _: Request) -> Self::Future {
            let response = serde_json::from_str(RESPONSE).unwrap();
            future::ok(Some(response))
        }
    }

    /// Loopback that yields the given requests once and discards all client responses.
    struct MockLoopback(Vec<Request>);

    impl Loopback for MockLoopback {
        type RequestStream = stream::Iter<std::vec::IntoIter<Request>>;
        type ResponseSink = sink::Drain<Response>;

        fn split(self) -> (Self::RequestStream, Self::ResponseSink) {
            (stream::iter(self.0), sink::drain())
        }
    }

    /// `REQUEST` framed with a `Content-Length` header.
    fn mock_request() -> Vec<u8> {
        format!("Content-Length: {}\r\n\r\n{}", REQUEST.len(), REQUEST).into_bytes()
    }

    /// `RESPONSE` framed with a `Content-Length` header.
    fn mock_response() -> Vec<u8> {
        format!("Content-Length: {}\r\n\r\n{}", RESPONSE.len(), RESPONSE).into_bytes()
    }

    /// In-memory stdin holding one framed request, and an empty stdout buffer.
    fn mock_stdio() -> (Cursor<Vec<u8>>, Vec<u8>) {
        (Cursor::new(mock_request()), Vec::new())
    }

    /// Reader that hands out its bytes and then returns `Pending` forever without
    /// registering a waker, i.e. it never signals EOF. This simulates a stdin that stays open.
    struct DetachedCursor(Vec<u8>);

    #[cfg(feature = "runtime-tokio")]
    impl AsyncRead for DetachedCursor {
        fn poll_read(
            mut self: std::pin::Pin<&mut Self>,
            _: &mut Context<'_>,
            buf: &mut tokio::io::ReadBuf<'_>,
        ) -> Poll<std::io::Result<()>> {
            if self.0.is_empty() {
                return Poll::Pending;
            }

            // `put_slice` panics if `buf` has less room than the remaining bytes; the
            // messages used in these tests are short.
            buf.put_slice(&self.0);
            self.0.clear();
            Poll::Ready(Ok(()))
        }
    }

    #[cfg(feature = "runtime-agnostic")]
    impl AsyncRead for DetachedCursor {
        fn poll_read(
            mut self: std::pin::Pin<&mut Self>,
            _: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<std::io::Result<usize>> {
            if self.0.is_empty() {
                return Poll::Pending;
            }

            // Copy as much as fits and keep the rest for the next read.
            let len = std::cmp::min(buf.len(), self.0.len());
            let after = self.0.split_off(len);
            buf[..len].copy_from_slice(&self.0);
            self.0 = after;
            Poll::Ready(Ok(len))
        }
    }

    #[tokio::test(flavor = "current_thread")]
    async fn serves_on_stdio() {
        let (mut stdin, mut stdout) = mock_stdio();
        Server::new(&mut stdin, &mut stdout, MockLoopback(vec![]))
            .serve(MockService)
            .await;

        // 80 == mock_request().len(): the whole framed request was consumed.
        assert_eq!(stdin.position(), 80);
        assert_eq!(stdout, mock_response());
    }

    #[tokio::test(flavor = "current_thread")]
    async fn interleaves_messages() {
        let socket = MockLoopback(vec![serde_json::from_str(REQUEST).unwrap()]);

        let (mut stdin, mut stdout) = mock_stdio();
        Server::new(&mut stdin, &mut stdout, socket)
            .serve(MockService)
            .await;

        assert_eq!(stdin.position(), 80);
        // The loopback's request is expected on stdout before the service's response.
        let output: Vec<_> = mock_request().into_iter().chain(mock_response()).collect();
        assert_eq!(stdout, output);
    }

    #[tokio::test(flavor = "current_thread")]
    async fn handles_invalid_json() {
        // Truncated JSON body (26 bytes), framed into 48 bytes.
        let invalid = r#"{"jsonrpc":"2.0","method":"#;
        let message = format!("Content-Length: {}\r\n\r\n{}", invalid.len(), invalid).into_bytes();
        let (mut stdin, mut stdout) = (Cursor::new(message), Vec::new());

        Server::new(&mut stdin, &mut stdout, MockLoopback(vec![]))
            .serve(MockService)
            .await;

        assert_eq!(stdin.position(), 48);
        // The server answers with a JSON-RPC "Parse error" (-32700) and a null id.
        let err = r#"{"jsonrpc":"2.0","error":{"code":-32700,"message":"Parse error"},"id":null}"#;
        let output = format!("Content-Length: {}\r\n\r\n{}", err.len(), err).into_bytes();
        assert_eq!(stdout, output);
    }

    #[tokio::test(flavor = "current_thread")]
    async fn stops_promptly_after_exit_notification() {
        let req = r#"{"jsonrpc":"2.0","method":"exit"}"#;
        let message = format!("Content-Length: {}\r\n\r\n{}", req.len(), req).into_bytes();
        // stdin never reaches EOF, so only the `exit` handling can end `serve`.
        let (mut stdin, mut stdout) = (DetachedCursor(message), Vec::new());

        assert!(
            tokio::time::timeout(
                std::time::Duration::from_secs(1),
                Server::new(&mut stdin, &mut stdout, MockLoopback(vec![])).serve(MockService)
            )
            .await
            .is_ok(),
            "waited for more than 1 second for exit"
        );
    }
}