1#[cfg(feature = "runtime-agnostic")]
4use async_codec_lite::{FramedRead, FramedWrite};
5#[cfg(feature = "runtime-agnostic")]
6use futures::io::{AsyncRead, AsyncWrite};
7
8#[cfg(feature = "runtime-tokio")]
9use tokio::io::{AsyncRead, AsyncWrite};
10#[cfg(feature = "runtime-tokio")]
11use tokio_util::codec::{FramedRead, FramedWrite};
12
13use futures::channel::mpsc;
14use futures::{
15 FutureExt, Sink, SinkExt, Stream, StreamExt, TryFutureExt, future, join, stream, stream_select,
16};
17use tower::Service;
18use tracing::error;
19
20use crate::codec::{LanguageServerCodec, ParseError};
21use crate::jsonrpc::{Error, Id, Message, Request, Response};
22use crate::service::{ClientSocket, RequestStream, ResponseSink};
23
/// Number of service futures `Server::serve` drives at once unless overridden
/// with `Server::concurrency_level`.
const DEFAULT_MAX_CONCURRENCY: usize = 4;

/// Buffer size passed to `mpsc::channel` for the queue of pending service futures
/// sitting between the input reader and the task processor.
const MESSAGE_QUEUE_SIZE: usize = 100;
26
27pub trait Loopback {
31 type RequestStream: Stream<Item = Request> + Unpin;
33 type ResponseSink: Sink<Response> + Unpin;
35
36 fn split(self) -> (Self::RequestStream, Self::ResponseSink);
40}
41
42impl Loopback for ClientSocket {
43 type RequestStream = RequestStream;
44 type ResponseSink = ResponseSink;
45
46 #[inline]
47 fn split(self) -> (Self::RequestStream, Self::ResponseSink) {
48 self.split()
49 }
50}
51
52#[derive(Debug)]
54pub struct Server<I, O, L = ClientSocket> {
55 stdin: I,
56 stdout: O,
57 loopback: L,
58 max_concurrency: usize,
59}
60
61impl<I, O, L> Server<I, O, L>
62where
63 I: AsyncRead + Unpin,
64 O: AsyncWrite,
65 L: Loopback,
66 <L::ResponseSink as Sink<Response>>::Error: std::error::Error,
67{
68 pub fn new(stdin: I, stdout: O, socket: L) -> Self {
70 Server {
71 stdin,
72 stdout,
73 loopback: socket,
74 max_concurrency: DEFAULT_MAX_CONCURRENCY,
75 }
76 }
77
78 pub fn concurrency_level(mut self, max: usize) -> Self {
99 self.max_concurrency = max;
100 self
101 }
102
103 pub async fn serve<T>(self, mut service: T)
105 where
106 T: Service<Request, Response = Option<Response>> + Send + 'static,
107 T::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
108 T::Future: Send,
109 {
110 let (client_requests, mut client_responses) = self.loopback.split();
111 let (client_requests, client_abort) = stream::abortable(client_requests);
112 let (mut responses_tx, responses_rx) = mpsc::channel(0);
113 let (mut server_tasks_tx, server_tasks_rx) = mpsc::channel(MESSAGE_QUEUE_SIZE);
114
115 let mut framed_stdin = FramedRead::new(self.stdin, LanguageServerCodec::default());
116 let framed_stdout = FramedWrite::new(self.stdout, LanguageServerCodec::default());
117
118 let process_server_tasks = server_tasks_rx
119 .buffer_unordered(self.max_concurrency)
120 .filter_map(future::ready)
121 .map(|res| Ok(Message::Response(res)))
122 .forward(responses_tx.clone().sink_map_err(|_| unreachable!()))
123 .map(|_| ());
124
125 let print_output = stream_select!(responses_rx, client_requests.map(Message::Request))
126 .map(Ok)
127 .forward(framed_stdout.sink_map_err(|e| error!("failed to encode message: {}", e)))
128 .map(|_| ());
129
130 let read_input = async {
131 while let Some(msg) = framed_stdin.next().await {
132 match msg {
133 Ok(Message::Request(req)) => {
134 if let Err(err) = future::poll_fn(|cx| service.poll_ready(cx)).await {
135 error!("{}", display_sources(err.into().as_ref()));
136 return;
137 }
138
139 let fut = service.call(req).unwrap_or_else(|err| {
140 error!("{}", display_sources(err.into().as_ref()));
141 None
142 });
143
144 server_tasks_tx.send(fut).await.unwrap();
145 }
146 Ok(Message::Response(res)) => {
147 if let Err(err) = client_responses.send(res).await {
148 error!("{}", display_sources(&err));
149 return;
150 }
151 }
152 Err(err) => {
153 error!("failed to decode message: {}", err);
154 let res = Response::from_error(Id::Null, to_jsonrpc_error(err));
155 responses_tx.send(Message::Response(res)).await.unwrap();
156 }
157 }
158 }
159
160 server_tasks_tx.disconnect();
161 responses_tx.disconnect();
162 client_abort.abort();
163 };
164
165 join!(print_output, read_input, process_server_tasks);
166 }
167}
168
/// Renders `error` followed by each error in its `source()` chain, separated by `": "`.
///
/// For example, an error "outer" caused by "inner" renders as `"outer: inner"`.
fn display_sources(error: &dyn std::error::Error) -> String {
    let mut rendered = error.to_string();
    let mut cause = error.source();
    while let Some(inner) = cause {
        rendered.push_str(": ");
        rendered.push_str(&inner.to_string());
        cause = inner.source();
    }
    rendered
}
176
/// Maps a decoding failure to the JSON-RPC error reported to the peer.
///
/// A `ParseError::Body` whose `is_data()` is true becomes "invalid request".
/// Presumably such an error means the body was JSON but had the wrong shape; confirm
/// against the codec. Every other failure becomes "parse error".
#[cfg(feature = "runtime-tokio")]
fn to_jsonrpc_error(err: ParseError) -> Error {
    if matches!(&err, ParseError::Body(body) if body.is_data()) {
        Error::invalid_request()
    } else {
        Error::parse_error()
    }
}
184
/// Maps a decoding failure to the JSON-RPC error reported to the peer.
///
/// The framing layer's error is inspected through its `source()`, which is downcast
/// to `ParseError`. A `ParseError::Body` whose `is_data()` is true becomes "invalid
/// request"; anything else, including an error with no `ParseError` source, becomes
/// "parse error".
#[cfg(feature = "runtime-agnostic")]
fn to_jsonrpc_error(err: impl std::error::Error) -> Error {
    let parse_error = err
        .source()
        .and_then(|cause| cause.downcast_ref::<ParseError>());

    if matches!(parse_error, Some(ParseError::Body(body)) if body.is_data()) {
        Error::invalid_request()
    } else {
        Error::parse_error()
    }
}
192
#[cfg(test)]
mod tests {
    use std::task::{Context, Poll};

    #[cfg(feature = "runtime-agnostic")]
    use futures::io::Cursor;
    #[cfg(feature = "runtime-tokio")]
    use std::io::Cursor;

    use futures::future::Ready;
    use futures::{future, sink, stream};

    use super::*;

    const REQUEST: &str = r#"{"jsonrpc":"2.0","method":"initialize","params":{},"id":1}"#;
    const RESPONSE: &str = r#"{"jsonrpc":"2.0","result":{"capabilities":{}},"id":1}"#;

    /// Service that is always ready and answers every request with `RESPONSE`.
    #[derive(Debug)]
    struct MockService;

    impl Service<Request> for MockService {
        type Response = Option<Response>;
        type Error = String;
        type Future = Ready<Result<Self::Response, Self::Error>>;

        fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _: Request) -> Self::Future {
            let canned: Response = serde_json::from_str(RESPONSE).unwrap();
            future::ok(Some(canned))
        }
    }

    /// Loopback that emits the wrapped requests and discards all responses.
    struct MockLoopback(Vec<Request>);

    impl Loopback for MockLoopback {
        type RequestStream = stream::Iter<std::vec::IntoIter<Request>>;
        type ResponseSink = sink::Drain<Response>;

        fn split(self) -> (Self::RequestStream, Self::ResponseSink) {
            let MockLoopback(requests) = self;
            (stream::iter(requests), sink::drain())
        }
    }

    /// Prefixes `body` with a `Content-Length` header, as the codec frames messages.
    fn with_header(body: &str) -> Vec<u8> {
        format!("Content-Length: {}\r\n\r\n{}", body.len(), body).into_bytes()
    }

    fn mock_request() -> Vec<u8> {
        with_header(REQUEST)
    }

    fn mock_response() -> Vec<u8> {
        with_header(RESPONSE)
    }

    /// Input containing one framed `REQUEST`, plus an empty output buffer.
    fn mock_stdio() -> (Cursor<Vec<u8>>, Vec<u8>) {
        let input = Cursor::new(mock_request());
        let output = Vec::new();
        (input, output)
    }

    #[tokio::test(flavor = "current_thread")]
    async fn serves_on_stdio() {
        let (mut stdin, mut stdout) = mock_stdio();
        let server = Server::new(&mut stdin, &mut stdout, MockLoopback(vec![]));
        server.serve(MockService).await;

        assert_eq!(stdin.position(), 80);
        assert_eq!(stdout, mock_response());
    }

    #[tokio::test(flavor = "current_thread")]
    async fn interleaves_messages() {
        let outgoing: Request = serde_json::from_str(REQUEST).unwrap();
        let socket = MockLoopback(vec![outgoing]);

        let (mut stdin, mut stdout) = mock_stdio();
        let server = Server::new(&mut stdin, &mut stdout, socket);
        server.serve(MockService).await;

        assert_eq!(stdin.position(), 80);
        let mut expected = mock_request();
        expected.extend(mock_response());
        assert_eq!(stdout, expected);
    }

    #[tokio::test(flavor = "current_thread")]
    async fn handles_invalid_json() {
        let invalid = r#"{"jsonrpc":"2.0","method":"#;
        let mut stdin = Cursor::new(with_header(invalid));
        let mut stdout = Vec::new();

        let server = Server::new(&mut stdin, &mut stdout, MockLoopback(vec![]));
        server.serve(MockService).await;

        assert_eq!(stdin.position(), 48);
        let err = r#"{"jsonrpc":"2.0","error":{"code":-32700,"message":"Parse error"},"id":null}"#;
        assert_eq!(stdout, with_header(err));
    }
}