starknet_devnet_server/
rpc_handler.rs

1use std::fmt::{self};
2
3use axum::Json;
4use axum::extract::rejection::JsonRejection;
5use axum::extract::ws::WebSocket;
6use axum::extract::{State, WebSocketUpgrade};
7use axum::response::IntoResponse;
8use futures::{FutureExt, future};
9use serde::de::DeserializeOwned;
10use tracing::{trace, warn};
11
12use crate::rpc_core::error::RpcError;
13use crate::rpc_core::request::{Request, RpcCall, RpcMethodCall};
14use crate::rpc_core::response::{Response, ResponseResult, RpcResponse};
15
16/// Helper trait that is used to execute starknet rpc calls
17#[async_trait::async_trait]
18pub trait RpcHandler: Clone + Send + Sync + 'static {
19    /// The request type to expect
20    type Request: DeserializeOwned + Send + Sync + fmt::Display;
21
22    /// Invoked when the request was received
23    async fn on_request(
24        &self,
25        request: Self::Request,
26        original_call: RpcMethodCall,
27    ) -> ResponseResult;
28
29    /// Invoked for every incoming `RpcMethodCall`
30    ///
31    /// This will attempt to deserialize a `{ "method" : "<name>", "params": "<params>" }` message
32    /// into the `Request` type of this handler. If a `Request` instance was deserialized
33    /// successfully, [`Self::on_request`] will be invoked.
34    ///
35    /// **Note**: override this function if the expected `Request` deviates from `{ "method" :
36    /// "<name>", "params": "<params>" }`
37    async fn on_call(&self, call: RpcMethodCall) -> RpcResponse;
38
39    /// Handles websocket connection, from start to finish.
40    async fn on_websocket(&self, mut socket: WebSocket);
41}
42
43/// Handles incoming JSON-RPC Request
44pub async fn handle<THandler: RpcHandler>(
45    State(handler): State<THandler>,
46    request: Result<Json<Request>, JsonRejection>,
47) -> Json<Response> {
48    match request {
49        Ok(req) => handle_request(req.0, handler)
50            .await
51            .unwrap_or_else(|| Response::error(RpcError::invalid_request()))
52            .into(),
53        Err(err) => match &err {
54            JsonRejection::JsonSyntaxError(e) => {
55                let error_msg = e.to_string();
56                warn!(target: "rpc", "JSON syntax error: {}", error_msg);
57                Response::error(RpcError::parse_error(error_msg)).into()
58            }
59            JsonRejection::JsonDataError(e) => {
60                warn!(target: "rpc", "JSON data error: {}", e);
61                Response::error(RpcError::invalid_request_with_reason(format!("Data error: {}", e)))
62                    .into()
63            }
64            JsonRejection::MissingJsonContentType(e) => {
65                warn!(target: "rpc", "Missing JSON content type: {}", e);
66                Response::error(RpcError::invalid_request_with_reason("Missing content type"))
67                    .into()
68            }
69            _ => {
70                warn!(target: "rpc", "Request rejection: {}", err);
71                Response::error(RpcError::invalid_request_with_reason(err.to_string())).into()
72            }
73        },
74    }
75}
76
77pub async fn handle_socket<THandler: RpcHandler>(
78    ws_upgrade: WebSocketUpgrade,
79    State(handler): State<THandler>,
80) -> impl IntoResponse {
81    tracing::info!("New websocket connection!");
82    ws_upgrade.on_failed_upgrade(|e| tracing::error!("Failed websocket upgrade: {e:?}")).on_upgrade(
83        move |socket| async move {
84            handler.on_websocket(socket).await;
85        },
86    )
87}
88
89/// Handle the JSON-RPC [Request]
90///
91/// This will try to deserialize the payload into the request type of the handler and if successful
92/// invoke the handler.
93pub async fn handle_request<THandler: RpcHandler>(
94    req: Request,
95    handler: THandler,
96) -> Option<Response> {
97    /// processes batch calls
98    fn responses_as_batch(outs: Vec<Option<RpcResponse>>) -> Option<Response> {
99        let batch: Vec<_> = outs.into_iter().flatten().collect();
100        (!batch.is_empty()).then_some(Response::Batch(batch))
101    }
102
103    match req {
104        Request::Single(call) => handle_call(call, handler).await.map(Response::Single),
105        Request::Batch(calls) => {
106            future::join_all(calls.into_iter().map(move |call| handle_call(call, handler.clone())))
107                .map(responses_as_batch)
108                .await
109        }
110    }
111}
112
113/// handle a single RPC method call
114pub(crate) async fn handle_call<THandler: RpcHandler>(
115    call: RpcCall,
116    handler: THandler,
117) -> Option<RpcResponse> {
118    match call {
119        RpcCall::MethodCall(call) => {
120            trace!(target: "rpc", id = ?call.id, method = ?call.method, "handling call");
121            Some(handler.on_call(call).await)
122        }
123        RpcCall::Notification(notification) => {
124            trace!(target: "rpc", method = ?notification.method, "received rpc notification");
125            None
126        }
127        RpcCall::Invalid { id } => {
128            warn!(target: "rpc", ?id, "invalid rpc call");
129            Some(RpcResponse::invalid_request(id))
130        }
131    }
132}