tower_lsp_server/
service.rs

1//! Service abstraction for language servers.
2
3pub use self::client::{Client, ClientSocket, RequestStream, ResponseSink, progress};
4
5pub use self::pending::Pending;
6pub use self::state::{ServerState, State};
7
8use std::fmt::{self, Debug, Display, Formatter};
9use std::sync::Arc;
10use std::task::{Context, Poll};
11
12use futures::future::{self, BoxFuture, FutureExt};
13use lsp_types::LSPAny;
14use tower::Service;
15
16use crate::LanguageServer;
17use crate::jsonrpc::{
18    Error, ErrorCode, FromParams, IntoResponse, Method, Request, Response, Router,
19};
20
21pub mod layers;
22
23mod client;
24mod pending;
25mod state;
26
/// Error returned when the language server is called after it has already exited.
///
/// The single unit field is private, so values of this type can only be created from within
/// this module (and its submodules).
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ExitedError(());

impl Display for ExitedError {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        write!(f, "language server has exited")
    }
}

impl std::error::Error for ExitedError {}
38
/// Service abstraction for the Language Server Protocol.
///
/// This service takes an incoming JSON-RPC message as input and produces an outgoing message as
/// output. If the incoming message is a server notification or a client response, then the
/// corresponding response will be `None`.
///
/// This implements [`tower::Service`] in order to remain independent from the underlying transport
/// and to facilitate further abstraction with middleware.
///
/// Pending requests can be canceled by issuing a [`$/cancelRequest`] notification.
///
/// [`$/cancelRequest`]: https://microsoft.github.io/language-server-protocol/specification#cancelRequest
///
/// The service shuts down and stops serving requests after the [`exit`] notification is received.
///
/// [`exit`]: https://microsoft.github.io/language-server-protocol/specification#exit
#[derive(Debug)]
pub struct LspService<S> {
    // Router that dispatches each incoming message to the handler registered for its method.
    inner: Router<S, ExitedError>,
    // Server state, consulted before every dispatch to reject calls once the server has exited.
    state: Arc<ServerState>,
}
60
61impl<S: LanguageServer> LspService<S> {
62    /// Creates a new `LspService` with the given server backend, also returning a channel for
63    /// server-to-client communication.
64    pub fn new<F>(init: F) -> (Self, ClientSocket)
65    where
66        F: FnOnce(Client) -> S,
67    {
68        Self::build(init).finish()
69    }
70
71    /// Starts building a new `LspService`.
72    ///
73    /// Returns an `LspServiceBuilder`, which allows adding custom JSON-RPC methods to the server.
74    pub fn build<F>(init: F) -> LspServiceBuilder<S>
75    where
76        F: FnOnce(Client) -> S,
77    {
78        let state = Arc::new(ServerState::new());
79
80        let (client, socket) = Client::new(state.clone());
81        let inner = Router::new(init(client.clone()));
82        let pending = Arc::new(Pending::new());
83
84        LspServiceBuilder {
85            inner: crate::server::generated::register_lsp_methods(
86                inner,
87                state.clone(),
88                pending.clone(),
89                client,
90            ),
91            state,
92            pending,
93            socket,
94        }
95    }
96
97    /// Returns a reference to the inner server.
98    #[must_use]
99    pub fn inner(&self) -> &S {
100        self.inner.inner()
101    }
102}
103
104impl<S: LanguageServer> Service<Request> for LspService<S> {
105    type Response = Option<Response>;
106    type Error = ExitedError;
107    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
108
109    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
110        match self.state.get() {
111            State::Initializing => Poll::Pending,
112            State::Exited => Poll::Ready(Err(ExitedError(()))),
113            _ => self.inner.poll_ready(cx),
114        }
115    }
116
117    fn call(&mut self, req: Request) -> Self::Future {
118        if self.state.get() == State::Exited {
119            return future::err(ExitedError(())).boxed();
120        }
121
122        let fut = self.inner.call(req);
123
124        Box::pin(async move {
125            let response = fut.await?;
126
127            match response.as_ref().and_then(|res| res.error()) {
128                Some(Error {
129                    code: ErrorCode::MethodNotFound,
130                    data: Some(LSPAny::String(m)),
131                    ..
132                }) if m.starts_with("$/") => Ok(None),
133                _ => Ok(response),
134            }
135        })
136    }
137}
138
/// A builder to customize the properties of an `LspService`.
///
/// To construct an `LspServiceBuilder`, refer to [`LspService::build`].
pub struct LspServiceBuilder<S> {
    // Router holding the generated LSP handlers plus any custom methods added so far.
    inner: Router<S, ExitedError>,
    // Server state shared with the `Client` and the registered method layers.
    state: Arc<ServerState>,
    // Shared pending-request tracker handed to the layer of each newly added custom method.
    pending: Arc<Pending>,
    // Server-to-client channel returned to the caller by `finish`.
    socket: ClientSocket,
}
148
149impl<S: LanguageServer> LspServiceBuilder<S> {
150    /// Defines a custom JSON-RPC request or notification with the given method `name` and handler.
151    ///
152    /// # Handler varieties
153    ///
154    /// Fundamentally, any inherent `async fn(&self)` method defined directly on the language
155    /// server backend could be considered a valid method handler.
156    ///
157    /// Handlers may optionally include a single `params` argument. This argument may be of any
158    /// type that implements [`Serialize`](serde::Serialize).
159    ///
160    /// Handlers which return `()` are treated as **notifications**, while those which return
161    /// [`jsonrpc::Result<T>`](crate::jsonrpc::Result) are treated as **requests**.
162    ///
163    /// Similar to the `params` argument, the `T` in the `Result<T>` return values may be of any
164    /// type which implements [`DeserializeOwned`](serde::de::DeserializeOwned). Additionally, this
165    /// type _must_ be convertible into a [`serde_json::Value`] using [`serde_json::to_value`]. If
166    /// this latter constraint is not met, the client will receive a JSON-RPC error response with
167    /// code `-32603` (Internal Error) instead of the expected response.
168    ///
169    /// # Examples
170    ///
171    /// ```rust
172    /// use serde_json::{json, Value};
173    /// use tower_lsp_server::jsonrpc::Result;
174    /// use tower_lsp_server::lsp_types::*;
175    /// use tower_lsp_server::{LanguageServer, LspService};
176    ///
177    /// struct Mock;
178    ///
179    /// // Implementation of `LanguageServer` omitted...
180    /// # impl LanguageServer for Mock {
181    /// #     async fn initialize(&self, _: InitializeParams) -> Result<InitializeResult> {
182    /// #         Ok(InitializeResult::default())
183    /// #     }
184    /// #
185    /// #     async fn shutdown(&self) -> Result<()> {
186    /// #         Ok(())
187    /// #     }
188    /// # }
189    ///
190    /// impl Mock {
191    ///     async fn request(&self) -> Result<i32> {
192    ///         Ok(123)
193    ///     }
194    ///
195    ///     async fn request_params(&self, params: Vec<String>) -> Result<Value> {
196    ///         Ok(json!({"num_elems":params.len()}))
197    ///     }
198    ///
199    ///     async fn notification(&self) {
200    ///         // ...
201    ///     }
202    ///
203    ///     async fn notification_params(&self, params: Value) {
204    ///         // ...
205    /// #       let _ = params;
206    ///     }
207    /// }
208    ///
209    /// let (service, socket) = LspService::build(|_| Mock)
210    ///     .custom_method("custom/request", Mock::request)
211    ///     .custom_method("custom/requestParams", Mock::request_params)
212    ///     .custom_method("custom/notification", Mock::notification)
213    ///     .custom_method("custom/notificationParams", Mock::notification_params)
214    ///     .finish();
215    /// ```
216    #[must_use]
217    pub fn custom_method<P, R, F>(mut self, name: &'static str, callback: F) -> Self
218    where
219        P: FromParams,
220        R: IntoResponse,
221        F: for<'a> Method<&'a S, P, R> + Clone + Send + Sync + 'static,
222    {
223        let layer = layers::Normal::new(self.state.clone(), self.pending.clone());
224        self.inner.method(name, callback, layer);
225        self
226    }
227
228    /// Constructs the `LspService` and returns it, along with a channel for server-to-client
229    /// communication.
230    #[must_use]
231    pub fn finish(self) -> (LspService<S>, ClientSocket) {
232        let Self {
233            inner,
234            state,
235            socket,
236            ..
237        } = self;
238
239        (LspService { inner, state }, socket)
240    }
241}
242
243impl<S: Debug> Debug for LspServiceBuilder<S> {
244    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
245        f.debug_struct("LspServiceBuilder")
246            .field("inner", &self.inner)
247            .finish_non_exhaustive()
248    }
249}
250
#[cfg(test)]
mod tests {
    use lsp_types::*;
    use serde_json::json;
    use tower::ServiceExt;

    use super::*;
    use crate::jsonrpc::Result;

    #[derive(Debug)]
    struct Mock;

    impl LanguageServer for Mock {
        async fn initialize(&self, _: InitializeParams) -> Result<InitializeResult> {
            Ok(InitializeResult::default())
        }

        async fn shutdown(&self) -> Result<()> {
            Ok(())
        }

        // Never resolves, so a `codeAction/resolve` request can only finish by being canceled.
        async fn code_action_resolve(&self, _: CodeAction) -> Result<CodeAction> {
            future::pending().await
        }
    }

    #[expect(clippy::unused_async)]
    impl Mock {
        // Echoes its params back; registered as a custom method in `serves_custom_requests`.
        async fn custom_request(&self, params: i32) -> Result<i32> {
            Ok(params)
        }
    }

    /// Builds an `initialize` request with empty client capabilities and the given `id`.
    fn initialize_request(id: i64) -> Request {
        Request::build("initialize")
            .params(json!({"capabilities":{}}))
            .id(id)
            .finish()
    }

    /// Waits until `service` is ready, then dispatches `request` and awaits its outcome.
    async fn dispatch(
        service: &mut LspService<Mock>,
        request: Request,
    ) -> std::result::Result<Option<Response>, ExitedError> {
        service.ready().await.unwrap().call(request).await
    }

    /// Sends `initialize` with ID 1 and asserts that the server accepted it.
    async fn initialize_service(service: &mut LspService<Mock>) {
        let accepted = Response::from_ok(1.into(), json!({"capabilities":{}}));
        let outcome = dispatch(service, initialize_request(1)).await;
        assert_eq!(outcome, Ok(Some(accepted)));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn initializes_only_once() {
        let (mut service, _) = LspService::new(|_| Mock);
        initialize_service(&mut service).await;

        let second = dispatch(&mut service, initialize_request(1)).await;
        let rejected = Response::from_error(1.into(), Error::invalid_request());
        assert_eq!(second, Ok(Some(rejected)));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn refuses_requests_after_shutdown() {
        let (mut service, _) = LspService::new(|_| Mock);
        initialize_service(&mut service).await;

        let shutdown = Request::build("shutdown").id(1).finish();

        let first = dispatch(&mut service, shutdown.clone()).await;
        let acknowledged = Response::from_ok(1.into(), json!(null));
        assert_eq!(first, Ok(Some(acknowledged)));

        let second = dispatch(&mut service, shutdown).await;
        let rejected = Response::from_error(1.into(), Error::invalid_request());
        assert_eq!(second, Ok(Some(rejected)));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn exit_notification() {
        let (mut service, _) = LspService::new(|_| Mock);

        let exit = Request::build("exit").finish();
        let outcome = dispatch(&mut service, exit.clone()).await;
        assert_eq!(outcome, Ok(None));

        let readiness = future::poll_fn(|cx| service.poll_ready(cx)).await;
        assert_eq!(readiness, Err(ExitedError(())));
        assert_eq!(service.call(exit).await, Err(ExitedError(())));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn cancels_pending_requests() {
        let (mut service, _) = LspService::new(|_| Mock);
        initialize_service(&mut service).await;

        let never_resolves = Request::build("codeAction/resolve")
            .params(json!({"title":""}))
            .id(1)
            .finish();

        let cancel = Request::build("$/cancelRequest")
            .params(json!({"id":1i32}))
            .finish();

        // Both futures are created before either is awaited so that they run concurrently.
        let pending_fut = service.ready().await.unwrap().call(never_resolves);
        let cancel_fut = service.ready().await.unwrap().call(cancel);
        let (pending_outcome, cancel_outcome) = futures::join!(pending_fut, cancel_fut);

        let canceled = Response::from_error(1.into(), Error::request_cancelled());
        assert_eq!(pending_outcome, Ok(Some(canceled)));
        assert_eq!(cancel_outcome, Ok(None));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn serves_custom_requests() {
        let (mut service, _) = LspService::build(|_| Mock)
            .custom_method("custom", Mock::custom_request)
            .finish();
        initialize_service(&mut service).await;

        let custom = Request::build("custom").params(123i32).id(1).finish();
        let outcome = dispatch(&mut service, custom).await;
        let echoed = Response::from_ok(1.into(), json!(123i32));
        assert_eq!(outcome, Ok(Some(echoed)));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn get_inner() {
        let (service, _) = LspService::build(|_| Mock).finish();

        let backend = service.inner();
        backend
            .initialize(InitializeParams::default())
            .await
            .unwrap();
    }
}