tower_lsp_server/
transport.rs

1//! Generic server for multiplexing bidirectional streams through a transport.
2
3#[cfg(feature = "runtime-agnostic")]
4use async_codec_lite::{FramedRead, FramedWrite};
5#[cfg(feature = "runtime-agnostic")]
6use futures::io::{AsyncRead, AsyncWrite};
7
8#[cfg(feature = "runtime-tokio")]
9use tokio::io::{AsyncRead, AsyncWrite};
10#[cfg(feature = "runtime-tokio")]
11use tokio_util::codec::{FramedRead, FramedWrite};
12
13use futures::channel::mpsc;
14use futures::{
15    FutureExt, Sink, SinkExt, Stream, StreamExt, TryFutureExt, future, join, stream, stream_select,
16};
17use tower::Service;
18use tracing::error;
19
20use crate::codec::{LanguageServerCodec, ParseError};
21use crate::jsonrpc::{Error, Id, Message, Request, Response};
22use crate::service::{ClientSocket, RequestStream, ResponseSink};
23
/// Default number of incoming requests processed concurrently
/// (overridable via [`Server::concurrency_level`]).
const DEFAULT_MAX_CONCURRENCY: usize = 4;
/// Capacity of the queue holding not-yet-driven service call futures in [`Server::serve`].
const MESSAGE_QUEUE_SIZE: usize = 100;
26
/// Trait implemented by client loopback sockets.
///
/// This socket handles the server-to-client half of the bidirectional communication stream.
///
/// [`Server::serve`] consumes both halves: pending server-to-client requests are forwarded
/// to the client over the output transport, while responses decoded from the input transport
/// are routed back through the sink.
pub trait Loopback {
    /// Yields a stream of pending server-to-client requests.
    type RequestStream: Stream<Item = Request> + Unpin;
    /// Routes client-to-server responses back to the server.
    type ResponseSink: Sink<Response> + Unpin;

    /// Splits this socket into two halves capable of operating independently.
    ///
    /// The two halves returned implement the [`Stream`] and [`Sink`] traits, respectively.
    fn split(self) -> (Self::RequestStream, Self::ResponseSink);
}
41
impl Loopback for ClientSocket {
    type RequestStream = RequestStream;
    type ResponseSink = ResponseSink;

    #[inline]
    fn split(self) -> (Self::RequestStream, Self::ResponseSink) {
        // Delegates to the inherent `ClientSocket::split`; inherent methods take
        // precedence over trait methods during resolution, so this does not recurse.
        self.split()
    }
}
51
/// Server for processing requests and responses on standard I/O or TCP.
#[derive(Debug)]
pub struct Server<I, O, L = ClientSocket> {
    /// Input half of the transport; client-to-server messages are decoded from here.
    stdin: I,
    /// Output half of the transport; server-to-client messages are encoded into here.
    stdout: O,
    /// Loopback socket routing server-initiated requests and their responses.
    loopback: L,
    /// Maximum number of requests processed concurrently (see [`Self::concurrency_level`]).
    max_concurrency: usize,
}
60
impl<I, O, L> Server<I, O, L>
where
    I: AsyncRead + Unpin,
    O: AsyncWrite,
    L: Loopback,
    <L::ResponseSink as Sink<Response>>::Error: std::error::Error,
{
    /// Creates a new `Server` with the given `stdin` and `stdout` handles.
    ///
    /// The concurrency limit starts at [`DEFAULT_MAX_CONCURRENCY`]; use
    /// [`Self::concurrency_level`] to change it.
    pub const fn new(stdin: I, stdout: O, socket: L) -> Self {
        Self {
            stdin,
            stdout,
            loopback: socket,
            max_concurrency: DEFAULT_MAX_CONCURRENCY,
        }
    }

    /// Sets the server concurrency limit to `max`.
    ///
    /// This setting specifies how many incoming requests may be processed concurrently. Setting
    /// this value to `1` forces all requests to be processed sequentially, thereby implicitly
    /// disabling support for the [`$/cancelRequest`] notification.
    ///
    /// [`$/cancelRequest`]: https://microsoft.github.io/language-server-protocol/specification#cancelRequest
    ///
    /// If not explicitly specified, `max` defaults to 4.
    ///
    /// # Preference over standard `tower` middleware
    ///
    /// The [`ConcurrencyLimit`] and [`Buffer`] middlewares provided by `tower` rely on
    /// [`tokio::spawn`] in common usage, while this library aims to be executor agnostic and to
    /// support exotic targets currently incompatible with `tokio`, such as WASM. As such, `Server`
    /// includes its own concurrency facilities that don't require a global executor to be present.
    ///
    /// [`ConcurrencyLimit`]: https://docs.rs/tower/latest/tower/limit/concurrency/struct.ConcurrencyLimit.html
    /// [`Buffer`]: https://docs.rs/tower/latest/tower/buffer/index.html
    /// [`tokio::spawn`]: https://docs.rs/tokio/latest/tokio/fn.spawn.html
    #[must_use]
    pub const fn concurrency_level(mut self, max: usize) -> Self {
        self.max_concurrency = max;
        self
    }

    /// Spawns the service with messages read through `stdin` and responses written to `stdout`.
    ///
    /// Internally this drives three cooperating futures to completion on the current task
    /// (no executor `spawn` is required): one reading and dispatching input, one driving
    /// queued service calls, and one writing framed output.
    pub async fn serve<T>(self, mut service: T)
    where
        T: Service<Request, Response = Option<Response>> + Send + 'static,
        T::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
        T::Future: Send,
    {
        // Split the loopback into its server-to-client request stream and
        // client-to-server response sink; wrap the stream so it can be
        // cancelled once the input loop ends.
        let (client_requests, mut client_responses) = self.loopback.split();
        let (client_requests, client_abort) = stream::abortable(client_requests);
        // Zero-capacity (rendezvous) channel merging responses produced by
        // concurrent service calls with decode-error replies.
        let (mut responses_tx, responses_rx) = mpsc::channel(0);
        // Bounded queue of not-yet-driven service call futures.
        let (mut server_tasks_tx, server_tasks_rx) = mpsc::channel(MESSAGE_QUEUE_SIZE);

        let mut framed_stdin = FramedRead::new(self.stdin, LanguageServerCodec::default());
        let framed_stdout = FramedWrite::new(self.stdout, LanguageServerCodec::default());

        // Drive up to `max_concurrency` queued service futures at once.
        // `None` results (notifications produce no response) are dropped by
        // `filter_map` before forwarding into the response channel.
        let process_server_tasks = server_tasks_rx
            .buffer_unordered(self.max_concurrency)
            .filter_map(future::ready)
            .map(|res| Ok(Message::Response(res)))
            .forward(responses_tx.clone());

        // Interleave outgoing responses with server-to-client requests and
        // write them all, framed, to stdout.
        let print_output = stream_select!(responses_rx, client_requests.map(Message::Request))
            .map(Ok)
            .forward(framed_stdout);

        let read_input = async {
            while let Some(msg) = framed_stdin.next().await {
                match msg {
                    Ok(Message::Request(req)) => {
                        // Bail out entirely if the service reports itself unusable.
                        if let Err(err) = future::poll_fn(|cx| service.poll_ready(cx)).await {
                            error!("{}", display_sources(err.into().as_ref()));
                            return;
                        }

                        // per specifications, the server must exit immediately after the exit notification
                        // some clients will not close stdin and thus keep framed_stdin waiting forever
                        // we break early here so that control can be yielded back immediately
                        let will_exit = req.method() == "exit";

                        // Service errors are logged and mapped to "no response"
                        // rather than tearing down the whole server.
                        let fut = service.call(req).unwrap_or_else(|err| {
                            error!("{}", display_sources(err.into().as_ref()));
                            None
                        });

                        // Queue the future; `process_server_tasks` drives it.
                        // A send failure means the receiver is gone, so there
                        // is nothing useful to do with the error.
                        let _ = server_tasks_tx.send(fut).await;

                        if will_exit {
                            break;
                        }
                    }
                    Ok(Message::Response(res)) => {
                        // Responses arriving on stdin answer server-initiated
                        // requests; route them back through the loopback sink.
                        if let Err(err) = client_responses.send(res).await {
                            error!("{}", display_sources(&err));
                            return;
                        }
                    }
                    Err(err) => {
                        // The request could not be decoded, so its ID is
                        // unknown; reply with an error addressed to `null`.
                        error!("failed to decode message: {}", err);
                        let res = Response::from_error(Id::Null, to_jsonrpc_error(err));
                        let _ = responses_tx.send(Message::Response(res)).await;
                    }
                }
            }

            // Close both senders so `process_server_tasks` and `print_output`
            // can run to completion, and cancel any still-pending loopback
            // requests so the merged output stream terminates.
            server_tasks_tx.disconnect();
            responses_tx.disconnect();
            client_abort.abort();
        };

        join!(
            process_server_tasks.map(|_| ()),
            print_output.map(|_| ()),
            read_input
        );
    }
}
180
/// Renders `error` followed by its entire [`source`](std::error::Error::source)
/// chain as a single `": "`-separated string, e.g. `"outer: middle: inner"`.
fn display_sources(error: &dyn std::error::Error) -> String {
    let mut rendered = error.to_string();
    let mut next = error.source();
    while let Some(cause) = next {
        rendered.push_str(": ");
        rendered.push_str(&cause.to_string());
        next = cause.source();
    }
    rendered
}
187
/// Maps a codec decode failure to the corresponding JSON-RPC error object.
///
/// Body errors whose `is_data()` flag is set indicate the payload was readable
/// JSON but structurally invalid, which maps to "Invalid Request"; every other
/// failure (framing or syntax) maps to "Parse error".
#[cfg(feature = "runtime-tokio")]
fn to_jsonrpc_error(err: ParseError) -> Error {
    if let ParseError::Body(body_err) = &err {
        if body_err.is_data() {
            return Error::invalid_request();
        }
    }
    Error::parse_error()
}
195
/// Maps a codec decode failure to the corresponding JSON-RPC error object.
///
/// The underlying [`ParseError`] is recovered by downcasting the error's
/// source; body errors whose `is_data()` flag is set map to "Invalid Request",
/// while everything else maps to "Parse error".
#[cfg(feature = "runtime-agnostic")]
fn to_jsonrpc_error(err: impl std::error::Error) -> Error {
    let parse_err = err.source().and_then(|e| e.downcast_ref::<ParseError>());
    match parse_err {
        Some(ParseError::Body(body_err)) if body_err.is_data() => Error::invalid_request(),
        _ => Error::parse_error(),
    }
}
203
#[cfg(test)]
mod tests {
    use std::task::{Context, Poll};

    #[cfg(feature = "runtime-agnostic")]
    use futures::io::Cursor;
    #[cfg(feature = "runtime-tokio")]
    use std::io::Cursor;

    use futures::future::Ready;
    use futures::{future, sink, stream};

    use super::*;

    // Canonical request/response pair used as fixtures throughout these tests.
    const REQUEST: &str = r#"{"jsonrpc":"2.0","method":"initialize","params":{},"id":1}"#;
    const RESPONSE: &str = r#"{"jsonrpc":"2.0","result":{"capabilities":{}},"id":1}"#;

    /// Service that is always ready and answers every request with `RESPONSE`.
    #[derive(Debug)]
    struct MockService;

    impl Service<Request> for MockService {
        type Response = Option<Response>;
        type Error = String;
        type Future = Ready<Result<Self::Response, Self::Error>>;

        fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _: Request) -> Self::Future {
            let response = serde_json::from_str(RESPONSE).unwrap();
            future::ok(Some(response))
        }
    }

    /// Loopback that yields the given server-to-client requests and silently
    /// discards any responses routed back through its sink.
    struct MockLoopback(Vec<Request>);

    impl Loopback for MockLoopback {
        type RequestStream = stream::Iter<std::vec::IntoIter<Request>>;
        type ResponseSink = sink::Drain<Response>;

        fn split(self) -> (Self::RequestStream, Self::ResponseSink) {
            (stream::iter(self.0), sink::drain())
        }
    }

    /// Encodes `REQUEST` with its `Content-Length` framing header (80 bytes total).
    fn mock_request() -> Vec<u8> {
        format!("Content-Length: {}\r\n\r\n{}", REQUEST.len(), REQUEST).into_bytes()
    }

    /// Encodes `RESPONSE` with its `Content-Length` framing header.
    fn mock_response() -> Vec<u8> {
        format!("Content-Length: {}\r\n\r\n{}", RESPONSE.len(), RESPONSE).into_bytes()
    }

    /// Returns a readable "stdin" preloaded with the framed request and an
    /// empty writable "stdout" buffer.
    fn mock_stdio() -> (Cursor<Vec<u8>>, Vec<u8>) {
        (Cursor::new(mock_request()), Vec::new())
    }

    // Simulates a still-live stdin that the client didn't drop.
    struct DetachedCursor(Vec<u8>);

    #[cfg(feature = "runtime-tokio")]
    impl AsyncRead for DetachedCursor {
        fn poll_read(
            mut self: std::pin::Pin<&mut Self>,
            _: &mut Context<'_>,
            buf: &mut tokio::io::ReadBuf<'_>,
        ) -> Poll<std::io::Result<()>> {
            // Returning `Pending` without registering a waker is acceptable
            // here: the test only needs stdin to appear open forever, never
            // to yield more data after the initial payload.
            if self.0.is_empty() {
                return Poll::Pending;
            }

            buf.put_slice(&self.0);
            self.0.clear();
            Poll::Ready(Ok(()))
        }
    }

    #[cfg(feature = "runtime-agnostic")]
    impl AsyncRead for DetachedCursor {
        fn poll_read(
            mut self: std::pin::Pin<&mut Self>,
            _: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<std::io::Result<usize>> {
            // Same caveat as the tokio impl: `Pending` without a waker keeps
            // the stream "open" for the duration of the test.
            if self.0.is_empty() {
                return Poll::Pending;
            }

            // Copy as much as fits into `buf`, keeping the rest for later polls.
            let len = std::cmp::min(buf.len(), self.0.len());
            let after = self.0.split_off(len);
            buf[..len].copy_from_slice(&self.0);
            self.0 = after;
            Poll::Ready(Ok(len))
        }
    }

    #[tokio::test(flavor = "current_thread")]
    async fn serves_on_stdio() {
        let (mut stdin, mut stdout) = mock_stdio();
        Server::new(&mut stdin, &mut stdout, MockLoopback(vec![]))
            .serve(MockService)
            .await;

        // The entire 80-byte framed request was consumed and exactly one
        // framed response was written.
        assert_eq!(stdin.position(), 80);
        assert_eq!(stdout, mock_response());
    }

    #[tokio::test(flavor = "current_thread")]
    async fn interleaves_messages() {
        // Preload one server-to-client request into the loopback.
        let socket = MockLoopback(vec![serde_json::from_str(REQUEST).unwrap()]);

        let (mut stdin, mut stdout) = mock_stdio();
        Server::new(&mut stdin, &mut stdout, socket)
            .serve(MockService)
            .await;

        assert_eq!(stdin.position(), 80);
        // The loopback request is written to stdout ahead of the service's
        // response to the stdin request.
        let output: Vec<_> = mock_request().into_iter().chain(mock_response()).collect();
        assert_eq!(stdout, output);
    }

    #[tokio::test(flavor = "current_thread")]
    async fn handles_invalid_json() {
        // Truncated JSON body: well-framed but undecodable.
        let invalid = r#"{"jsonrpc":"2.0","method":"#;
        let message = format!("Content-Length: {}\r\n\r\n{}", invalid.len(), invalid).into_bytes();
        let (mut stdin, mut stdout) = (Cursor::new(message), Vec::new());

        Server::new(&mut stdin, &mut stdout, MockLoopback(vec![]))
            .serve(MockService)
            .await;

        // All 48 framed bytes were consumed, and a Parse error response with a
        // `null` ID (the request ID is unrecoverable) was emitted.
        assert_eq!(stdin.position(), 48);
        let err = r#"{"jsonrpc":"2.0","error":{"code":-32700,"message":"Parse error"},"id":null}"#;
        let output = format!("Content-Length: {}\r\n\r\n{}", err.len(), err).into_bytes();
        assert_eq!(stdout, output);
    }

    #[tokio::test(flavor = "current_thread")]
    async fn stops_promptly_after_exit_notification() {
        let req = r#"{"jsonrpc":"2.0","method":"exit"}"#;
        let message = format!("Content-Length: {}\r\n\r\n{}", req.len(), req).into_bytes();
        // `DetachedCursor` never reports EOF, so this only terminates if the
        // server breaks out of its read loop on the exit notification itself.
        let (mut stdin, mut stdout) = (DetachedCursor(message), Vec::new());

        assert!(
            tokio::time::timeout(
                std::time::Duration::from_secs(1),
                Server::new(&mut stdin, &mut stdout, MockLoopback(vec![])).serve(MockService)
            )
            .await
            .is_ok(),
            "waited for more than 1 second for exit"
        );
    }
}