1use std::collections::HashMap;
4use std::convert::Infallible;
5use std::fmt::{self, Debug, Formatter};
6use std::future::Future;
7use std::marker::PhantomData;
8use std::sync::Arc;
9use std::task::{Context, Poll};
10
11use futures::future::{self, BoxFuture, FutureExt};
12use ls_types::LSPAny;
13use serde::{Serialize, de::DeserializeOwned};
14use tower::{Layer, Service, util::BoxService};
15
16use crate::jsonrpc::ErrorCode;
17
18use super::{Error, Id, Request, Response};
19
20pub struct Router<S, E = Infallible> {
22 server: Arc<S>,
23 methods: HashMap<&'static str, BoxService<Request, Option<Response>, E>>,
24}
25
26impl<S: Send + Sync + 'static, E> Router<S, E> {
27 pub fn new(server: S) -> Self {
29 Self {
30 server: Arc::new(server),
31 methods: HashMap::new(),
32 }
33 }
34
35 pub fn inner(&self) -> &S {
37 self.server.as_ref()
38 }
39
40 pub fn method<P, R, F, L>(&mut self, name: &'static str, callback: F, layer: L) -> &mut Self
44 where
45 P: FromParams,
46 R: IntoResponse,
47 F: for<'a> Method<&'a S, P, R> + Clone + Send + Sync + 'static,
48 L: Layer<MethodHandler<P, R, E>>,
49 L::Service: Service<Request, Response = Option<Response>, Error = E> + Send + 'static,
50 <L::Service as Service<Request>>::Future: Send + 'static,
51 {
52 let server = &self.server;
53 self.methods.entry(name).or_insert_with(|| {
54 let server = server.clone();
55 let handler = MethodHandler::new(move |params| {
56 let callback = callback.clone();
57 let server = server.clone();
58 async move { callback.invoke(&*server, params).await }
59 });
60
61 BoxService::new(layer.layer(handler))
62 });
63
64 self
65 }
66}
67
68impl<S: Debug, E> Debug for Router<S, E> {
69 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
70 f.debug_struct("Router")
71 .field("server", &self.server)
72 .field("methods", &self.methods.keys())
73 .finish()
74 }
75}
76
77impl<S, E: Send + 'static> Service<Request> for Router<S, E> {
78 type Response = Option<Response>;
79 type Error = E;
80 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
81
82 fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
83 Poll::Ready(Ok(()))
84 }
85
86 fn call(&mut self, req: Request) -> Self::Future {
87 if let Some(handler) = self.methods.get_mut(req.method()) {
88 handler.call(req)
89 } else {
90 let (method, id, _) = req.into_parts();
91 future::ok(id.map(|id| {
92 let mut error = Error::method_not_found();
93 error.data = Some(LSPAny::from(method));
94 Response::from_error(id, error)
95 }))
96 .boxed()
97 }
98 }
99}
100
101pub struct MethodHandler<P, R, E> {
103 f: Box<dyn Fn(P) -> BoxFuture<'static, R> + Send>,
104 _marker: PhantomData<E>,
105}
106
107impl<P: FromParams, R: IntoResponse, E> MethodHandler<P, R, E> {
108 fn new<F, Fut>(handler: F) -> Self
109 where
110 F: Fn(P) -> Fut + Send + 'static,
111 Fut: Future<Output = R> + Send + 'static,
112 {
113 Self {
114 f: Box::new(move |p| handler(p).boxed()),
115 _marker: PhantomData,
116 }
117 }
118}
119
120impl<P, R, E> Service<Request> for MethodHandler<P, R, E>
121where
122 P: FromParams,
123 R: IntoResponse,
124 E: Send + 'static,
125{
126 type Response = Option<Response>;
127 type Error = E;
128 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
129
130 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
131 Poll::Ready(Ok(()))
132 }
133
134 fn call(&mut self, req: Request) -> Self::Future {
135 let (_, id, params) = req.into_parts();
136
137 match id {
138 Some(_) if R::is_notification() => return future::ok(().into_response(id)).boxed(),
139 None if !R::is_notification() => return future::ok(None).boxed(),
140 _ => {}
141 }
142
143 let params = match P::from_params(params) {
144 Ok(params) => params,
145 Err(err) => return future::ok(id.map(|id| Response::from_error(id, err))).boxed(),
146 };
147
148 (self.f)(params)
149 .map(move |r| Ok(r.into_response(id)))
150 .boxed()
151 }
152}
153
154pub trait Method<S, P, R>: private::Sealed {
165 type Future: Future<Output = R> + Send;
167
168 fn invoke(&self, server: S, params: P) -> Self::Future;
170}
171
172impl<F, S, R, Fut> Method<S, (), R> for F
174where
175 F: Fn(S) -> Fut,
176 Fut: Future<Output = R> + Send,
177{
178 type Future = Fut;
179
180 #[inline]
181 fn invoke(&self, server: S, (): ()) -> Self::Future {
182 self(server)
183 }
184}
185
186impl<F, S, P, R, Fut> Method<S, (P,), R> for F
188where
189 F: Fn(S, P) -> Fut,
190 P: DeserializeOwned,
191 Fut: Future<Output = R> + Send,
192{
193 type Future = Fut;
194
195 #[inline]
196 fn invoke(&self, server: S, params: (P,)) -> Self::Future {
197 self(server, params.0)
198 }
199}
200
201pub trait FromParams: private::Sealed + Send + Sized + 'static {
203 fn from_params(params: Option<LSPAny>) -> super::Result<Self>;
209}
210
211impl FromParams for () {
213 fn from_params(params: Option<LSPAny>) -> super::Result<Self> {
214 match params {
215 None
216 | Some(LSPAny::Null) => Ok(()),
220 Some(p) => Err(Error::invalid_params(format!("Unexpected params: {p}"))),
221 }
222 }
223}
224
225impl<P: DeserializeOwned + Send + 'static> FromParams for (P,) {
227 fn from_params(params: Option<LSPAny>) -> super::Result<Self> {
228 params.map_or_else(
229 || Err(Error::invalid_params("Missing params field")),
230 |p| {
231 serde_json::from_value(p)
232 .map(|params| (params,))
233 .map_err(|e| Error::invalid_params(e.to_string()))
234 },
235 )
236 }
237}
238
239pub trait IntoResponse: private::Sealed + Send + 'static {
241 fn into_response(self, id: Option<Id>) -> Option<Response>;
243
244 fn is_notification() -> bool;
246}
247
248impl IntoResponse for () {
250 #[expect(clippy::single_option_map, reason = "we cannot change trait signature")]
251 fn into_response(self, id: Option<Id>) -> Option<Response> {
252 id.map(|id| Response::from_error(id, Error::invalid_request()))
253 }
254
255 #[inline]
256 fn is_notification() -> bool {
257 true
258 }
259}
260
261impl<R: Serialize + Send + 'static> IntoResponse for Result<R, Error> {
263 fn into_response(self, id: Option<Id>) -> Option<Response> {
264 debug_assert!(id.is_some(), "Requests always contain an `id` field");
265 id.map(|id| {
266 let result = self.and_then(|r| {
267 serde_json::to_value(r).map_err(|e| Error {
268 code: ErrorCode::InternalError,
269 message: e.to_string().into(),
270 data: None,
271 })
272 });
273 Response::from_parts(id, result)
274 })
275 }
276
277 #[inline]
278 fn is_notification() -> bool {
279 false
280 }
281}
282
mod private {
    /// Supertrait of the public handler traits; it cannot be named outside
    /// this module.
    ///
    /// NOTE(review): the blanket impl below covers every sized type, so this
    /// bound does not by itself restrict which types implement the public
    /// traits.
    pub trait Sealed {}

    impl<X> Sealed for X {}
}
287
#[cfg(test)]
mod tests {
    use serde::{Deserialize, Serialize};
    use serde_json::json;
    use tower::ServiceExt;
    use tower::layer::layer_fn;

    use super::*;

    #[derive(Deserialize, Serialize)]
    struct Params {
        foo: i32,
        bar: String,
    }

    struct Mock;

    #[expect(clippy::unused_async)]
    impl Mock {
        async fn request(&self) -> Result<LSPAny, Error> {
            Ok(LSPAny::Null)
        }

        async fn request_params(&self, params: Params) -> Result<Params, Error> {
            Ok(params)
        }

        async fn notification(&self) {}

        async fn notification_params(&self, _params: Params) {}
    }

    #[tokio::test(flavor = "current_thread")]
    async fn routes_requests() {
        let mut router: Router<Mock> = Router::new(Mock);
        router
            .method("first", Mock::request, layer_fn(|s| s))
            .method("second", Mock::request_params, layer_fn(|s| s));

        let request = Request::build("first").id(0).finish();
        let response = router.ready().await.unwrap().call(request).await;
        assert_eq!(
            response,
            Ok(Some(Response::from_ok(0.into(), LSPAny::Null)))
        );

        let params = json!({"foo": -123i32, "bar": "hello world"});
        let with_params = Request::build("second")
            .params(params.clone())
            .id(1)
            .finish();
        let response = router.ready().await.unwrap().call(with_params).await;
        assert_eq!(response, Ok(Some(Response::from_ok(1.into(), params))));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn routes_notifications() {
        let mut router: Router<Mock> = Router::new(Mock);
        router
            .method("first", Mock::notification, layer_fn(|s| s))
            .method("second", Mock::notification_params, layer_fn(|s| s));

        let request = Request::build("first").finish();
        let response = router.ready().await.unwrap().call(request).await;
        assert_eq!(response, Ok(None));

        let params = json!({"foo": -123i32, "bar": "hello world"});
        let with_params = Request::build("second").params(params).finish();
        let response = router.ready().await.unwrap().call(with_params).await;
        assert_eq!(response, Ok(None));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn rejects_request_with_invalid_params() {
        let mut router: Router<Mock> = Router::new(Mock);
        router.method("request", Mock::request_params, layer_fn(|s| s));

        let invalid_params = Request::build("request")
            .params(json!("wrong"))
            .id(0)
            .finish();

        let response = router.ready().await.unwrap().call(invalid_params).await;
        assert_eq!(
            response,
            Ok(Some(Response::from_error(
                0.into(),
                Error::invalid_params("invalid type: string \"wrong\", expected struct Params"),
            )))
        );
    }

    /// A notification handler whose params fail to decode must stay silent:
    /// there is no `id` to attach an error reply to. This goes through the
    /// `P::from_params` error path in `MethodHandler::call`.
    #[tokio::test(flavor = "current_thread")]
    async fn ignores_notification_with_invalid_params() {
        let mut router: Router<Mock> = Router::new(Mock);
        router.method("notification", Mock::notification_params, layer_fn(|s| s));

        let invalid_params = Request::build("notification")
            .params(json!("wrong"))
            .finish();

        let response = router.ready().await.unwrap().call(invalid_params).await;
        assert_eq!(response, Ok(None));
    }

    /// A notification routed to a *request* handler is dropped before its
    /// params are decoded, so invalid params make no difference.
    #[tokio::test(flavor = "current_thread")]
    async fn ignores_notification_sent_to_request_handler() {
        let mut router: Router<Mock> = Router::new(Mock);
        router.method("request", Mock::request_params, layer_fn(|s| s));

        let invalid_params = Request::build("request")
            .params(json!("wrong"))
            .finish();

        let response = router.ready().await.unwrap().call(invalid_params).await;
        assert_eq!(response, Ok(None));
    }

    /// Handlers without params accept an explicit `null` but reject any other
    /// params value with an invalid-params error.
    #[tokio::test(flavor = "current_thread")]
    async fn unit_params_accept_null_and_reject_values() {
        let mut router: Router<Mock> = Router::new(Mock);
        router.method("request", Mock::request, layer_fn(|s| s));

        let null_params = Request::build("request").params(json!(null)).id(0).finish();
        let response = router.ready().await.unwrap().call(null_params).await;
        assert_eq!(
            response,
            Ok(Some(Response::from_ok(0.into(), LSPAny::Null)))
        );

        let unexpected = Request::build("request").params(json!([1])).id(1).finish();
        let response = router.ready().await.unwrap().call(unexpected).await;
        assert_eq!(
            response,
            Ok(Some(Response::from_error(
                1.into(),
                Error::invalid_params("Unexpected params: [1]"),
            )))
        );
    }

    #[tokio::test(flavor = "current_thread")]
    async fn handles_incorrect_request_types() {
        let mut router: Router<Mock> = Router::new(Mock);
        router
            .method("first", Mock::request, layer_fn(|s| s))
            .method("second", Mock::notification, layer_fn(|s| s));

        let request = Request::build("first").finish();
        let response = router.ready().await.unwrap().call(request).await;
        assert_eq!(response, Ok(None));

        let request = Request::build("second").id(0).finish();
        let response = router.ready().await.unwrap().call(request).await;
        assert_eq!(
            response,
            Ok(Some(Response::from_error(
                0.into(),
                Error::invalid_request(),
            )))
        );
    }

    #[tokio::test(flavor = "current_thread")]
    async fn responds_to_nonexistent_request() {
        let mut router: Router<Mock> = Router::new(Mock);

        let request = Request::build("foo").id(0).finish();
        let response = router.ready().await.unwrap().call(request).await;
        let mut error = Error::method_not_found();
        error.data = Some("foo".into());
        assert_eq!(response, Ok(Some(Response::from_error(0.into(), error))));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn ignores_nonexistent_notification() {
        let mut router: Router<Mock> = Router::new(Mock);

        let request = Request::build("foo").finish();
        let response = router.ready().await.unwrap().call(request).await;
        assert_eq!(response, Ok(None));
    }
}