tower_lsp_f/
transport.rs

1//! Generic server for multiplexing bidirectional streams through a transport.
2
3#[cfg(feature = "runtime-agnostic")]
4use async_codec_lite::{FramedRead, FramedWrite};
5#[cfg(feature = "runtime-agnostic")]
6use futures::io::{AsyncRead, AsyncWrite};
7
8#[cfg(feature = "runtime-tokio")]
9use tokio::io::{AsyncRead, AsyncWrite};
10#[cfg(feature = "runtime-tokio")]
11use tokio_util::codec::{FramedRead, FramedWrite};
12
13use futures::channel::mpsc;
14use futures::{
15    FutureExt, Sink, SinkExt, Stream, StreamExt, TryFutureExt, future, join, stream, stream_select,
16};
17use tower::Service;
18use tracing::error;
19
20use crate::codec::{LanguageServerCodec, ParseError};
21use crate::jsonrpc::{Error, Id, Message, Request, Response};
22use crate::service::{ClientSocket, RequestStream, ResponseSink};
23
/// Number of incoming requests a [`Server`] processes concurrently unless overridden via
/// [`Server::concurrency_level`].
const DEFAULT_MAX_CONCURRENCY: usize = 4;

/// Capacity of the channel that queues service futures between input decoding and execution.
const MESSAGE_QUEUE_SIZE: usize = 100;
26
27/// Trait implemented by client loopback sockets.
28///
29/// This socket handles the server-to-client half of the bidirectional communication stream.
30pub trait Loopback {
31    /// Yields a stream of pending server-to-client requests.
32    type RequestStream: Stream<Item = Request> + Unpin;
33    /// Routes client-to-server responses back to the server.
34    type ResponseSink: Sink<Response> + Unpin;
35
36    /// Splits this socket into two halves capable of operating independently.
37    ///
38    /// The two halves returned implement the [`Stream`] and [`Sink`] traits, respectively.
39    fn split(self) -> (Self::RequestStream, Self::ResponseSink);
40}
41
42impl Loopback for ClientSocket {
43    type RequestStream = RequestStream;
44    type ResponseSink = ResponseSink;
45
46    #[inline]
47    fn split(self) -> (Self::RequestStream, Self::ResponseSink) {
48        self.split()
49    }
50}
51
52/// Server for processing requests and responses on standard I/O or TCP.
53#[derive(Debug)]
54pub struct Server<I, O, L = ClientSocket> {
55    stdin: I,
56    stdout: O,
57    loopback: L,
58    max_concurrency: usize,
59}
60
61impl<I, O, L> Server<I, O, L>
62where
63    I: AsyncRead + Unpin,
64    O: AsyncWrite,
65    L: Loopback,
66    <L::ResponseSink as Sink<Response>>::Error: std::error::Error,
67{
68    /// Creates a new `Server` with the given `stdin` and `stdout` handles.
69    pub fn new(stdin: I, stdout: O, socket: L) -> Self {
70        Server {
71            stdin,
72            stdout,
73            loopback: socket,
74            max_concurrency: DEFAULT_MAX_CONCURRENCY,
75        }
76    }
77
78    /// Sets the server concurrency limit to `max`.
79    ///
80    /// This setting specifies how many incoming requests may be processed concurrently. Setting
81    /// this value to `1` forces all requests to be processed sequentially, thereby implicitly
82    /// disabling support for the [`$/cancelRequest`] notification.
83    ///
84    /// [`$/cancelRequest`]: https://microsoft.github.io/language-server-protocol/specification#cancelRequest
85    ///
86    /// If not explicitly specified, `max` defaults to 4.
87    ///
88    /// # Preference over standard `tower` middleware
89    ///
90    /// The [`ConcurrencyLimit`] and [`Buffer`] middlewares provided by `tower` rely on
91    /// [`tokio::spawn`] in common usage, while this library aims to be executor agnostic and to
92    /// support exotic targets currently incompatible with `tokio`, such as WASM. As such, `Server`
93    /// includes its own concurrency facilities that don't require a global executor to be present.
94    ///
95    /// [`ConcurrencyLimit`]: https://docs.rs/tower/latest/tower/limit/concurrency/struct.ConcurrencyLimit.html
96    /// [`Buffer`]: https://docs.rs/tower/latest/tower/buffer/index.html
97    /// [`tokio::spawn`]: https://docs.rs/tokio/latest/tokio/fn.spawn.html
98    pub fn concurrency_level(mut self, max: usize) -> Self {
99        self.max_concurrency = max;
100        self
101    }
102
103    /// Spawns the service with messages read through `stdin` and responses written to `stdout`.
104    pub async fn serve<T>(self, mut service: T)
105    where
106        T: Service<Request, Response = Option<Response>> + Send + 'static,
107        T::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
108        T::Future: Send,
109    {
110        let (client_requests, mut client_responses) = self.loopback.split();
111        let (client_requests, client_abort) = stream::abortable(client_requests);
112        let (mut responses_tx, responses_rx) = mpsc::channel(0);
113        let (mut server_tasks_tx, server_tasks_rx) = mpsc::channel(MESSAGE_QUEUE_SIZE);
114
115        let mut framed_stdin = FramedRead::new(self.stdin, LanguageServerCodec::default());
116        let framed_stdout = FramedWrite::new(self.stdout, LanguageServerCodec::default());
117
118        let process_server_tasks = server_tasks_rx
119            .buffer_unordered(self.max_concurrency)
120            .filter_map(future::ready)
121            .map(|res| Ok(Message::Response(res)))
122            .forward(responses_tx.clone().sink_map_err(|_| unreachable!()))
123            .map(|_| ());
124
125        let print_output = stream_select!(responses_rx, client_requests.map(Message::Request))
126            .map(Ok)
127            .forward(framed_stdout.sink_map_err(|e| error!("failed to encode message: {}", e)))
128            .map(|_| ());
129
130        let read_input = async {
131            while let Some(msg) = framed_stdin.next().await {
132                match msg {
133                    Ok(Message::Request(req)) => {
134                        if let Err(err) = future::poll_fn(|cx| service.poll_ready(cx)).await {
135                            error!("{}", display_sources(err.into().as_ref()));
136                            return;
137                        }
138
139                        let fut = service.call(req).unwrap_or_else(|err| {
140                            error!("{}", display_sources(err.into().as_ref()));
141                            None
142                        });
143
144                        server_tasks_tx.send(fut).await.unwrap();
145                    }
146                    Ok(Message::Response(res)) => {
147                        if let Err(err) = client_responses.send(res).await {
148                            error!("{}", display_sources(&err));
149                            return;
150                        }
151                    }
152                    Err(err) => {
153                        error!("failed to decode message: {}", err);
154                        let res = Response::from_error(Id::Null, to_jsonrpc_error(err));
155                        responses_tx.send(Message::Response(res)).await.unwrap();
156                    }
157                }
158            }
159
160            server_tasks_tx.disconnect();
161            responses_tx.disconnect();
162            client_abort.abort();
163        };
164
165        join!(print_output, read_input, process_server_tasks);
166    }
167}
168
/// Renders `error` followed by every error in its `source()` chain, separated by `": "`.
///
/// An error displaying as `outer` whose source displays as `inner` renders as `"outer: inner"`.
fn display_sources(error: &dyn std::error::Error) -> String {
    std::iter::successors(Some(error), |&current| current.source())
        .map(|link| link.to_string())
        .collect::<Vec<_>>()
        .join(": ")
}
176
/// Maps a decode failure to the JSON-RPC error reported back to the client.
///
/// A body error whose `is_data()` is true becomes Invalid Request; every other failure becomes
/// Parse error. (Presumably the body error is a `serde_json::Error`, where `is_data` means
/// syntactically valid JSON of the wrong shape — confirm against `ParseError`.)
#[cfg(feature = "runtime-tokio")]
fn to_jsonrpc_error(err: ParseError) -> Error {
    let body_is_data = matches!(&err, ParseError::Body(body) if body.is_data());
    if body_is_data {
        Error::invalid_request()
    } else {
        Error::parse_error()
    }
}
184
/// Maps a framing-layer decode failure to the JSON-RPC error reported back to the client.
///
/// The underlying [`ParseError`] is looked up as the immediate `source()` of `err`. A body
/// error whose `is_data()` is true becomes Invalid Request; anything else (including a missing
/// or differently-typed source) becomes Parse error.
#[cfg(feature = "runtime-agnostic")]
fn to_jsonrpc_error(err: impl std::error::Error) -> Error {
    let parse_error = err
        .source()
        .and_then(|source| source.downcast_ref::<ParseError>());

    if let Some(ParseError::Body(body)) = parse_error {
        if body.is_data() {
            return Error::invalid_request();
        }
    }

    Error::parse_error()
}
192
#[cfg(test)]
mod tests {
    use std::task::{Context, Poll};

    #[cfg(feature = "runtime-agnostic")]
    use futures::io::Cursor;
    #[cfg(feature = "runtime-tokio")]
    use std::io::Cursor;

    use futures::future::Ready;
    use futures::{future, sink, stream};

    use super::*;

    const REQUEST: &str = r#"{"jsonrpc":"2.0","method":"initialize","params":{},"id":1}"#;
    const RESPONSE: &str = r#"{"jsonrpc":"2.0","result":{"capabilities":{}},"id":1}"#;

    /// Service that is always ready and answers every request with the canned `RESPONSE`.
    #[derive(Debug)]
    struct MockService;

    impl Service<Request> for MockService {
        type Response = Option<Response>;
        type Error = String;
        type Future = Ready<Result<Self::Response, Self::Error>>;

        fn poll_ready(&mut self, _: &mut Context) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _: Request) -> Self::Future {
            let canned: Response = serde_json::from_str(RESPONSE).unwrap();
            future::ok(Some(canned))
        }
    }

    /// Loopback that emits a fixed list of server-to-client requests and discards responses.
    struct MockLoopback(Vec<Request>);

    impl Loopback for MockLoopback {
        type RequestStream = stream::Iter<std::vec::IntoIter<Request>>;
        type ResponseSink = sink::Drain<Response>;

        fn split(self) -> (Self::RequestStream, Self::ResponseSink) {
            let MockLoopback(requests) = self;
            (stream::iter(requests), sink::drain())
        }
    }

    /// Wraps `body` in the `Content-Length` header framing used on the wire.
    fn frame_message(body: &str) -> Vec<u8> {
        format!("Content-Length: {}\r\n\r\n{}", body.len(), body).into_bytes()
    }

    fn mock_request() -> Vec<u8> {
        frame_message(REQUEST)
    }

    fn mock_response() -> Vec<u8> {
        frame_message(RESPONSE)
    }

    /// Returns a `stdin` holding one framed request and an empty `stdout`.
    fn mock_stdio() -> (Cursor<Vec<u8>>, Vec<u8>) {
        let stdin = Cursor::new(mock_request());
        let stdout = Vec::new();
        (stdin, stdout)
    }

    #[tokio::test(flavor = "current_thread")]
    async fn serves_on_stdio() {
        let (mut stdin, mut stdout) = mock_stdio();
        let server = Server::new(&mut stdin, &mut stdout, MockLoopback(vec![]));
        server.serve(MockService).await;

        assert_eq!(stdin.position(), 80);
        assert_eq!(stdout, mock_response());
    }

    #[tokio::test(flavor = "current_thread")]
    async fn interleaves_messages() {
        let outgoing: Request = serde_json::from_str(REQUEST).unwrap();
        let socket = MockLoopback(vec![outgoing]);

        let (mut stdin, mut stdout) = mock_stdio();
        let server = Server::new(&mut stdin, &mut stdout, socket);
        server.serve(MockService).await;

        assert_eq!(stdin.position(), 80);
        let expected = [mock_request(), mock_response()].concat();
        assert_eq!(stdout, expected);
    }

    #[tokio::test(flavor = "current_thread")]
    async fn handles_invalid_json() {
        let invalid = r#"{"jsonrpc":"2.0","method":"#;
        let mut stdin = Cursor::new(frame_message(invalid));
        let mut stdout = Vec::new();

        let server = Server::new(&mut stdin, &mut stdout, MockLoopback(vec![]));
        server.serve(MockService).await;

        assert_eq!(stdin.position(), 48);
        let err = r#"{"jsonrpc":"2.0","error":{"code":-32700,"message":"Parse error"},"id":null}"#;
        assert_eq!(stdout, frame_message(err));
    }
}