tycho_rpc/endpoint/
mod.rs1use std::time::Duration;
2
3use anyhow::Result;
4use axum::RequestExt;
5use axum::extract::{DefaultBodyLimit, FromRef, Request, State};
6use axum::http::StatusCode;
7use axum::response::{IntoResponse, Response};
8use axum::routing::{get, post};
9use tokio::net::TcpListener;
10
11pub use self::jrpc::JrpcEndpointCache;
12pub use self::proto::ProtoEndpointCache;
13use crate::state::RpcState;
14use crate::util::mime::{APPLICATION_JSON, APPLICATION_PROTOBUF, get_mime_type};
15
16pub mod jrpc;
17pub mod proto;
18
19pub struct RpcEndpointBuilder<C = ()> {
20 common: RpcEndpointBuilderCommon,
21 custom_routes: C,
22}
23
24impl Default for RpcEndpointBuilder {
25 #[inline]
26 fn default() -> Self {
27 Self {
28 common: Default::default(),
29 custom_routes: (),
30 }
31 }
32}
33
34impl RpcEndpointBuilder<()> {
35 pub fn empty() -> Self {
36 Self {
37 common: RpcEndpointBuilderCommon::empty(),
38 custom_routes: (),
39 }
40 }
41
42 pub fn with_custom_routes<S>(
43 self,
44 routes: axum::Router<S>,
45 ) -> RpcEndpointBuilder<axum::Router<S>>
46 where
47 RpcState: FromRef<S>,
48 S: Send + Sync,
49 {
50 RpcEndpointBuilder {
51 common: self.common,
52 custom_routes: routes,
53 }
54 }
55
56 pub async fn bind(self, state: RpcState) -> Result<RpcEndpoint> {
57 let listener = state.bind_socket().await?;
58 Ok(RpcEndpoint::from_parts(
59 listener,
60 self.common.build(),
61 state,
62 ))
63 }
64}
65
66impl<C> RpcEndpointBuilder<C> {
67 pub fn with_healthcheck_route<T: Into<String>>(mut self, route: T) -> Self {
68 self.common.healthcheck_route = Some(route.into());
69 self
70 }
71
72 pub fn with_base_routes<I, T>(mut self, routes: I) -> Self
73 where
74 I: IntoIterator<Item = T>,
75 T: Into<String>,
76 {
77 self.common.base_routes = routes.into_iter().map(Into::into).collect();
78 self
79 }
80}
81
82impl<S> RpcEndpointBuilder<axum::Router<S>>
83where
84 RpcState: FromRef<S>,
85 S: Send + Sync + Clone + 'static,
86{
87 pub async fn bind(self, state: S) -> Result<RpcEndpoint> {
88 let listener = RpcState::from_ref(&state).bind_socket().await?;
89 Ok(RpcEndpoint::from_parts(
90 listener,
91 self.common.build::<S>().merge(self.custom_routes),
92 state,
93 ))
94 }
95}
96
/// Builder settings that do not depend on the custom-routes type parameter.
struct RpcEndpointBuilderCommon {
    /// Path served with the `GET` healthcheck handler, if any.
    healthcheck_route: Option<String>,
    /// Paths served with the `POST` RPC dispatcher.
    base_routes: Vec<String>,
}

impl Default for RpcEndpointBuilderCommon {
    /// Healthcheck on `/`; RPC requests accepted on `/`, `/rpc` and `/proto`.
    fn default() -> Self {
        let base_routes = ["/", "/rpc", "/proto"]
            .iter()
            .map(|route| route.to_string())
            .collect();
        Self {
            healthcheck_route: Some(String::from("/")),
            base_routes,
        }
    }
}
110
111impl RpcEndpointBuilderCommon {
112 pub fn empty() -> Self {
113 Self {
114 healthcheck_route: None,
115 base_routes: Vec::new(),
116 }
117 }
118
119 fn build<S>(self) -> axum::Router<S>
120 where
121 RpcState: FromRef<S>,
122 S: Clone + Send + Sync + 'static,
123 {
124 let mut router = axum::Router::new();
125
126 if let Some(route) = self.healthcheck_route {
127 router = router.route(&route, get(health_check));
128 }
129 for route in self.base_routes {
130 router = router.route(&route, post(common_route));
131 }
132 router = router.merge(jrpc::stream_router::<S>());
133
134 router
135 }
136}
137
138pub struct RpcEndpoint {
139 listener: TcpListener,
140 router: axum::Router<()>,
141}
142
143impl RpcEndpoint {
144 pub fn builder() -> RpcEndpointBuilder {
145 RpcEndpointBuilder::default()
146 }
147
148 pub fn empty_builder() -> RpcEndpointBuilder {
149 RpcEndpointBuilder::empty()
150 }
151
152 pub fn from_parts<S>(listener: TcpListener, router: axum::Router<S>, state: S) -> Self
153 where
154 S: Clone + Send + Sync + 'static,
155 {
156 use tower::ServiceBuilder;
157 use tower_http::cors::CorsLayer;
158 use tower_http::timeout::TimeoutLayer;
159
160 let service = ServiceBuilder::new()
162 .layer(DefaultBodyLimit::max(MAX_REQUEST_SIZE))
163 .layer(CorsLayer::permissive())
164 .layer(TimeoutLayer::with_status_code(
165 StatusCode::REQUEST_TIMEOUT,
166 Duration::from_secs(25),
167 ));
168
169 #[cfg(feature = "compression")]
170 let service = service.layer(tower_http::compression::CompressionLayer::new().gzip(true));
171
172 let router = router.layer(service).with_state(state);
174
175 Self { listener, router }
177 }
178
179 pub async fn serve(self) -> std::io::Result<()> {
180 axum::serve(self.listener, self.router).await
181 }
182}
183
184fn health_check() -> futures_util::future::Ready<impl IntoResponse> {
185 futures_util::future::ready(
186 std::time::SystemTime::now()
187 .duration_since(std::time::UNIX_EPOCH)
188 .expect("system time before Unix epoch")
189 .as_millis()
190 .to_string(),
191 )
192}
193
194async fn common_route(state: State<RpcState>, req: Request) -> Response {
195 use axum::http::StatusCode;
196
197 match get_mime_type(&req) {
198 Some(mime) if mime.starts_with(APPLICATION_JSON) => match req.extract().await {
199 Ok(method) => jrpc::route(state, method).await,
200 Err(e) => e.into_response(),
201 },
202 Some(mime) if mime.starts_with(APPLICATION_PROTOBUF) => match req.extract().await {
203 Ok(request) => proto::route(state, request).await,
204 Err(e) => e.into_response(),
205 },
206 _ => StatusCode::UNSUPPORTED_MEDIA_TYPE.into_response(),
207 }
208}
209
/// Maximum accepted request body size in bytes (256 KiB), enforced through
/// `DefaultBodyLimit` in `RpcEndpoint::from_parts`.
const MAX_REQUEST_SIZE: usize = 256 * 1024;