starknet_devnet_server/server.rs

1use std::time::Duration;
2
3use axum::Router;
4use axum::body::{Body, Bytes};
5use axum::extract::{DefaultBodyLimit, Request};
6use axum::http::{HeaderValue, StatusCode};
7use axum::middleware::Next;
8use axum::response::{IntoResponse, Response};
9use axum::routing::{IntoMakeService, get, post};
10use http_body_util::BodyExt;
11use reqwest::{Method, header};
12use tokio::net::TcpListener;
13use tower_http::cors::CorsLayer;
14use tower_http::timeout::TimeoutLayer;
15use tower_http::trace::TraceLayer;
16
17use crate::api::JsonRpcHandler;
18use crate::rpc_handler::RpcHandler;
19use crate::{ServerConfig, rpc_handler};
20pub type StarknetDevnetServer = axum::serve::Serve<TcpListener, IntoMakeService<Router>, Router>;
21
22fn json_rpc_routes<TJsonRpcHandler: RpcHandler>(json_rpc_handler: TJsonRpcHandler) -> Router {
23    Router::new()
24        .route("/", post(rpc_handler::handle::<TJsonRpcHandler>))
25        .route("/rpc", post(rpc_handler::handle::<TJsonRpcHandler>))
26        .route("/ws", get(rpc_handler::handle_socket::<TJsonRpcHandler>))
27        .with_state(json_rpc_handler)
28}
29
30/// Configures an [axum::Server] that handles related JSON-RPC calls and web API calls via HTTP.
31pub async fn serve_http_json_rpc(
32    tcp_listener: TcpListener,
33    server_config: &ServerConfig,
34    json_rpc_handler: JsonRpcHandler,
35) -> StarknetDevnetServer {
36    let mut routes = Router::new()
37        .route("/is_alive", get(|| async { "Alive!!!" })) // Only REST endpoint to simplify liveness probe
38        .merge(json_rpc_routes(json_rpc_handler.clone()))
39        .layer(TraceLayer::new_for_http());
40
41    if server_config.log_response {
42        routes = routes.layer(axum::middleware::from_fn(response_logging_middleware));
43    };
44
45    routes = routes
46        .layer(TimeoutLayer::with_status_code(
47            StatusCode::REQUEST_TIMEOUT,
48            Duration::from_secs(server_config.timeout.into()),
49        ))
50        .layer(DefaultBodyLimit::disable())
51        .layer(
52            // More details: https://docs.rs/tower-http/latest/tower_http/cors/index.html
53            CorsLayer::new()
54                .allow_origin(HeaderValue::from_static("*"))
55                .allow_headers(vec![header::CONTENT_TYPE])
56                .allow_methods(vec![Method::GET, Method::POST]),
57        );
58
59    if server_config.log_request {
60        routes = routes.layer(axum::middleware::from_fn(request_logging_middleware));
61    }
62
63    axum::serve(tcp_listener, routes.into_make_service())
64}
65
66async fn log_body_and_path<T>(
67    body: T,
68    uri_option: Option<axum::http::Uri>,
69) -> Result<axum::body::Body, (StatusCode, String)>
70where
71    T: axum::body::HttpBody<Data = Bytes>,
72    T::Error: std::fmt::Display,
73{
74    let bytes = match body.collect().await {
75        Ok(collected) => collected.to_bytes(),
76        Err(err) => {
77            return Err((StatusCode::INTERNAL_SERVER_ERROR, err.to_string()));
78        }
79    };
80
81    if let Ok(body_str) = std::str::from_utf8(&bytes) {
82        if let Some(uri) = uri_option {
83            tracing::info!("{} {}", uri, body_str);
84        } else {
85            tracing::info!("{}", body_str);
86        }
87    } else {
88        tracing::error!("Failed to convert body to string");
89    }
90
91    Ok(Body::from(bytes))
92}
93
94async fn request_logging_middleware(
95    request: Request,
96    next: Next,
97) -> Result<impl IntoResponse, (StatusCode, String)> {
98    let (parts, body) = request.into_parts();
99
100    let body = log_body_and_path(body, Some(parts.uri.clone())).await?;
101    Ok(next.run(Request::from_parts(parts, body)).await)
102}
103
104async fn response_logging_middleware(
105    request: Request,
106    next: Next,
107) -> Result<impl IntoResponse, (StatusCode, String)> {
108    let response = next.run(request).await;
109
110    let (parts, body) = response.into_parts();
111
112    let body = log_body_and_path(body, None).await?;
113
114    let response = Response::from_parts(parts, body);
115    Ok(response)
116}