tower_lsp_server/jsonrpc/
router.rs

1//! Lightweight JSON-RPC router service.
2
3use std::collections::HashMap;
4use std::convert::Infallible;
5use std::fmt::{self, Debug, Formatter};
6use std::future::Future;
7use std::marker::PhantomData;
8use std::sync::Arc;
9use std::task::{Context, Poll};
10
11use futures::future::{self, BoxFuture, FutureExt};
12use ls_types::LSPAny;
13use serde::{Serialize, de::DeserializeOwned};
14use tower::{Layer, Service, util::BoxService};
15
16use crate::jsonrpc::ErrorCode;
17
18use super::{Error, Id, Request, Response};
19
20/// A modular JSON-RPC 2.0 request router service.
21pub struct Router<S, E = Infallible> {
22    server: Arc<S>,
23    methods: HashMap<&'static str, BoxService<Request, Option<Response>, E>>,
24}
25
26impl<S: Send + Sync + 'static, E> Router<S, E> {
27    /// Creates a new `Router` with the given shared state.
28    pub fn new(server: S) -> Self {
29        Self {
30            server: Arc::new(server),
31            methods: HashMap::new(),
32        }
33    }
34
35    /// Returns a reference to the inner server.
36    pub fn inner(&self) -> &S {
37        self.server.as_ref()
38    }
39
40    /// Registers a new RPC method which constructs a response with the given `callback`.
41    ///
42    /// The `layer` argument can be used to inject middleware into the method handler, if desired.
43    pub fn method<P, R, F, L>(&mut self, name: &'static str, callback: F, layer: L) -> &mut Self
44    where
45        P: FromParams,
46        R: IntoResponse,
47        F: for<'a> Method<&'a S, P, R> + Clone + Send + Sync + 'static,
48        L: Layer<MethodHandler<P, R, E>>,
49        L::Service: Service<Request, Response = Option<Response>, Error = E> + Send + 'static,
50        <L::Service as Service<Request>>::Future: Send + 'static,
51    {
52        let server = &self.server;
53        self.methods.entry(name).or_insert_with(|| {
54            let server = server.clone();
55            let handler = MethodHandler::new(move |params| {
56                let callback = callback.clone();
57                let server = server.clone();
58                async move { callback.invoke(&*server, params).await }
59            });
60
61            BoxService::new(layer.layer(handler))
62        });
63
64        self
65    }
66}
67
68impl<S: Debug, E> Debug for Router<S, E> {
69    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
70        f.debug_struct("Router")
71            .field("server", &self.server)
72            .field("methods", &self.methods.keys())
73            .finish()
74    }
75}
76
77impl<S, E: Send + 'static> Service<Request> for Router<S, E> {
78    type Response = Option<Response>;
79    type Error = E;
80    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
81
82    fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
83        Poll::Ready(Ok(()))
84    }
85
86    fn call(&mut self, req: Request) -> Self::Future {
87        if let Some(handler) = self.methods.get_mut(req.method()) {
88            handler.call(req)
89        } else {
90            let (method, id, _) = req.into_parts();
91            future::ok(id.map(|id| {
92                let mut error = Error::method_not_found();
93                error.data = Some(LSPAny::from(method));
94                Response::from_error(id, error)
95            }))
96            .boxed()
97        }
98    }
99}
100
101/// Opaque JSON-RPC method handler.
102pub struct MethodHandler<P, R, E> {
103    f: Box<dyn Fn(P) -> BoxFuture<'static, R> + Send>,
104    _marker: PhantomData<E>,
105}
106
107impl<P: FromParams, R: IntoResponse, E> MethodHandler<P, R, E> {
108    fn new<F, Fut>(handler: F) -> Self
109    where
110        F: Fn(P) -> Fut + Send + 'static,
111        Fut: Future<Output = R> + Send + 'static,
112    {
113        Self {
114            f: Box::new(move |p| handler(p).boxed()),
115            _marker: PhantomData,
116        }
117    }
118}
119
120impl<P, R, E> Service<Request> for MethodHandler<P, R, E>
121where
122    P: FromParams,
123    R: IntoResponse,
124    E: Send + 'static,
125{
126    type Response = Option<Response>;
127    type Error = E;
128    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
129
130    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
131        Poll::Ready(Ok(()))
132    }
133
134    fn call(&mut self, req: Request) -> Self::Future {
135        let (_, id, params) = req.into_parts();
136
137        match id {
138            Some(_) if R::is_notification() => return future::ok(().into_response(id)).boxed(),
139            None if !R::is_notification() => return future::ok(None).boxed(),
140            _ => {}
141        }
142
143        let params = match P::from_params(params) {
144            Ok(params) => params,
145            Err(err) => return future::ok(id.map(|id| Response::from_error(id, err))).boxed(),
146        };
147
148        (self.f)(params)
149            .map(move |r| Ok(r.into_response(id)))
150            .boxed()
151    }
152}
153
154/// A trait implemented by all valid JSON-RPC method handlers.
155///
156/// This trait abstracts over the following classes of functions and/or closures:
157///
158/// Signature                                            | Description
159/// -----------------------------------------------------|---------------------------------
160/// `async fn f(&self) -> jsonrpc::Result<R>`            | Request without parameters
161/// `async fn f(&self, params: P) -> jsonrpc::Result<R>` | Request with required parameters
162/// `async fn f(&self)`                                  | Notification without parameters
163/// `async fn f(&self, params: P)`                       | Notification with parameters
164pub trait Method<S, P, R>: private::Sealed {
165    /// The future response value.
166    type Future: Future<Output = R> + Send;
167
168    /// Invokes the method with the given `server` receiver and parameters.
169    fn invoke(&self, server: S, params: P) -> Self::Future;
170}
171
172/// Support parameter-less JSON-RPC methods.
173impl<F, S, R, Fut> Method<S, (), R> for F
174where
175    F: Fn(S) -> Fut,
176    Fut: Future<Output = R> + Send,
177{
178    type Future = Fut;
179
180    #[inline]
181    fn invoke(&self, server: S, (): ()) -> Self::Future {
182        self(server)
183    }
184}
185
186/// Support JSON-RPC methods with `params`.
187impl<F, S, P, R, Fut> Method<S, (P,), R> for F
188where
189    F: Fn(S, P) -> Fut,
190    P: DeserializeOwned,
191    Fut: Future<Output = R> + Send,
192{
193    type Future = Fut;
194
195    #[inline]
196    fn invoke(&self, server: S, params: (P,)) -> Self::Future {
197        self(server, params.0)
198    }
199}
200
201/// A trait implemented by all JSON-RPC method parameters.
202pub trait FromParams: private::Sealed + Send + Sized + 'static {
203    /// Attempts to deserialize `Self` from the `params` value extracted from [`Request`].
204    ///
205    /// # Errors
206    ///
207    /// - If the given parameters don't match the expected shape
208    fn from_params(params: Option<LSPAny>) -> super::Result<Self>;
209}
210
211/// Deserialize non-existent JSON-RPC parameters.
212impl FromParams for () {
213    fn from_params(params: Option<LSPAny>) -> super::Result<Self> {
214        match params {
215            None
216            // See #40: allow lsp clients (e.g. `lsp4j`) to not precisely
217            // respect the specification and set `params` to `null` when it
218            // should not be present at all.
219            | Some(LSPAny::Null) => Ok(()),
220            Some(p) => Err(Error::invalid_params(format!("Unexpected params: {p}"))),
221        }
222    }
223}
224
225/// Deserialize required JSON-RPC parameters.
226impl<P: DeserializeOwned + Send + 'static> FromParams for (P,) {
227    fn from_params(params: Option<LSPAny>) -> super::Result<Self> {
228        params.map_or_else(
229            || Err(Error::invalid_params("Missing params field")),
230            |p| {
231                serde_json::from_value(p)
232                    .map(|params| (params,))
233                    .map_err(|e| Error::invalid_params(e.to_string()))
234            },
235        )
236    }
237}
238
239/// A trait implemented by all JSON-RPC response types.
240pub trait IntoResponse: private::Sealed + Send + 'static {
241    /// Attempts to construct a [`Response`] using `Self` and a corresponding [`Id`].
242    fn into_response(self, id: Option<Id>) -> Option<Response>;
243
244    /// Returns `true` if this is a notification response type.
245    fn is_notification() -> bool;
246}
247
248/// Support JSON-RPC notification methods.
249impl IntoResponse for () {
250    #[expect(clippy::single_option_map, reason = "we cannot change trait signature")]
251    fn into_response(self, id: Option<Id>) -> Option<Response> {
252        id.map(|id| Response::from_error(id, Error::invalid_request()))
253    }
254
255    #[inline]
256    fn is_notification() -> bool {
257        true
258    }
259}
260
261/// Support JSON-RPC request methods.
262impl<R: Serialize + Send + 'static> IntoResponse for Result<R, Error> {
263    fn into_response(self, id: Option<Id>) -> Option<Response> {
264        debug_assert!(id.is_some(), "Requests always contain an `id` field");
265        id.map(|id| {
266            let result = self.and_then(|r| {
267                serde_json::to_value(r).map_err(|e| Error {
268                    code: ErrorCode::InternalError,
269                    message: e.to_string().into(),
270                    data: None,
271                })
272            });
273            Response::from_parts(id, result)
274        })
275    }
276
277    #[inline]
278    fn is_notification() -> bool {
279        false
280    }
281}
282
/// Holds the `Sealed` supertrait required by `Method`, `FromParams` and `IntoResponse`.
mod private {
    /// Marker supertrait; it lives in a private module so it cannot be named elsewhere.
    pub trait Sealed {}

    // Blanket impl: every type satisfies `Sealed`, so the bound itself does not restrict
    // which types may implement the public traits. NOTE(review): confirm this loose
    // sealing is intentional.
    impl<T> Sealed for T {}
}
287
#[cfg(test)]
mod tests {
    use serde::{Deserialize, Serialize};
    use serde_json::json;
    use tower::ServiceExt;
    use tower::layer::layer_fn;

    use super::*;

    #[derive(Deserialize, Serialize)]
    struct Params {
        foo: i32,
        bar: String,
    }

    struct Mock;

    #[expect(clippy::unused_async)]
    impl Mock {
        async fn request(&self) -> Result<LSPAny, Error> {
            Ok(LSPAny::Null)
        }

        async fn request_params(&self, params: Params) -> Result<Params, Error> {
            Ok(params)
        }

        async fn notification(&self) {}

        async fn notification_params(&self, _params: Params) {}
    }

    #[tokio::test(flavor = "current_thread")]
    async fn routes_requests() {
        let mut router: Router<Mock> = Router::new(Mock);
        router
            .method("first", Mock::request, layer_fn(|s| s))
            .method("second", Mock::request_params, layer_fn(|s| s));

        let request = Request::build("first").id(0).finish();
        let response = router.ready().await.unwrap().call(request).await;
        assert_eq!(
            response,
            Ok(Some(Response::from_ok(0.into(), LSPAny::Null)))
        );

        let params = json!({"foo": -123i32, "bar": "hello world"});
        let with_params = Request::build("second")
            .params(params.clone())
            .id(1)
            .finish();
        let response = router.ready().await.unwrap().call(with_params).await;
        assert_eq!(response, Ok(Some(Response::from_ok(1.into(), params))));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn routes_notifications() {
        let mut router: Router<Mock> = Router::new(Mock);
        router
            .method("first", Mock::notification, layer_fn(|s| s))
            .method("second", Mock::notification_params, layer_fn(|s| s));

        let request = Request::build("first").finish();
        let response = router.ready().await.unwrap().call(request).await;
        assert_eq!(response, Ok(None));

        let params = json!({"foo": -123i32, "bar": "hello world"});
        let with_params = Request::build("second").params(params).finish();
        let response = router.ready().await.unwrap().call(with_params).await;
        assert_eq!(response, Ok(None));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn rejects_request_with_invalid_params() {
        let mut router: Router<Mock> = Router::new(Mock);
        router.method("request", Mock::request_params, layer_fn(|s| s));

        let invalid_params = Request::build("request")
            .params(json!("wrong"))
            .id(0)
            .finish();

        let response = router.ready().await.unwrap().call(invalid_params).await;
        assert_eq!(
            response,
            Ok(Some(Response::from_error(
                0.into(),
                Error::invalid_params("invalid type: string \"wrong\", expected struct Params"),
            )))
        );
    }

    #[tokio::test(flavor = "current_thread")]
    async fn rejects_request_with_missing_params() {
        let mut router: Router<Mock> = Router::new(Mock);
        router.method("request", Mock::request_params, layer_fn(|s| s));

        let missing_params = Request::build("request").id(0).finish();

        let response = router.ready().await.unwrap().call(missing_params).await;
        assert_eq!(
            response,
            Ok(Some(Response::from_error(
                0.into(),
                Error::invalid_params("Missing params field"),
            )))
        );
    }

    #[tokio::test(flavor = "current_thread")]
    async fn ignores_notification_with_invalid_params() {
        let mut router: Router<Mock> = Router::new(Mock);
        // Register a genuine notification handler so this exercises the invalid-params
        // path; a request handler would drop the message earlier via the
        // request/notification mismatch check.
        router.method("notification", Mock::notification_params, layer_fn(|s| s));

        let invalid_params = Request::build("notification")
            .params(json!("wrong"))
            .finish();

        let response = router.ready().await.unwrap().call(invalid_params).await;
        assert_eq!(response, Ok(None));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn accepts_null_params_for_parameterless_request() {
        let mut router: Router<Mock> = Router::new(Mock);
        router.method("request", Mock::request, layer_fn(|s| s));

        // See #40: some clients send `"params": null` instead of omitting the field.
        let null_params = Request::build("request")
            .params(LSPAny::Null)
            .id(0)
            .finish();

        let response = router.ready().await.unwrap().call(null_params).await;
        assert_eq!(
            response,
            Ok(Some(Response::from_ok(0.into(), LSPAny::Null)))
        );
    }

    #[tokio::test(flavor = "current_thread")]
    async fn handles_incorrect_request_types() {
        let mut router: Router<Mock> = Router::new(Mock);
        router
            .method("first", Mock::request, layer_fn(|s| s))
            .method("second", Mock::notification, layer_fn(|s| s));

        let request = Request::build("first").finish();
        let response = router.ready().await.unwrap().call(request).await;
        assert_eq!(response, Ok(None));

        let request = Request::build("second").id(0).finish();
        let response = router.ready().await.unwrap().call(request).await;
        assert_eq!(
            response,
            Ok(Some(Response::from_error(
                0.into(),
                Error::invalid_request(),
            )))
        );
    }

    #[tokio::test(flavor = "current_thread")]
    async fn keeps_first_registration_for_duplicate_method() {
        let mut router: Router<Mock> = Router::new(Mock);
        router
            .method("dup", Mock::request, layer_fn(|s| s))
            .method("dup", Mock::notification, layer_fn(|s| s));

        // The request handler registered first must still answer.
        let request = Request::build("dup").id(0).finish();
        let response = router.ready().await.unwrap().call(request).await;
        assert_eq!(
            response,
            Ok(Some(Response::from_ok(0.into(), LSPAny::Null)))
        );
    }

    #[tokio::test(flavor = "current_thread")]
    async fn responds_to_nonexistent_request() {
        let mut router: Router<Mock> = Router::new(Mock);

        let request = Request::build("foo").id(0).finish();
        let response = router.ready().await.unwrap().call(request).await;
        let mut error = Error::method_not_found();
        error.data = Some("foo".into());
        assert_eq!(response, Ok(Some(Response::from_error(0.into(), error))));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn ignores_nonexistent_notification() {
        let mut router: Router<Mock> = Router::new(Mock);

        let request = Request::build("foo").finish();
        let response = router.ready().await.unwrap().call(request).await;
        assert_eq!(response, Ok(None));
    }
}