starknet_devnet_server/
rpc_handler.rs1use std::fmt::{self};
2
3use axum::Json;
4use axum::extract::rejection::JsonRejection;
5use axum::extract::ws::WebSocket;
6use axum::extract::{State, WebSocketUpgrade};
7use axum::response::IntoResponse;
8use futures::{FutureExt, future};
9use serde::de::DeserializeOwned;
10use tracing::{trace, warn};
11
12use crate::rpc_core::error::RpcError;
13use crate::rpc_core::request::{Request, RpcCall, RpcMethodCall};
14use crate::rpc_core::response::{Response, ResponseResult, RpcResponse};
15
16#[async_trait::async_trait]
18pub trait RpcHandler: Clone + Send + Sync + 'static {
19 type Request: DeserializeOwned + Send + Sync + fmt::Display;
21
22 async fn on_request(
24 &self,
25 request: Self::Request,
26 original_call: RpcMethodCall,
27 ) -> ResponseResult;
28
29 async fn on_call(&self, call: RpcMethodCall) -> RpcResponse;
38
39 async fn on_websocket(&self, mut socket: WebSocket);
41}
42
43pub async fn handle<THandler: RpcHandler>(
45 State(handler): State<THandler>,
46 request: Result<Json<Request>, JsonRejection>,
47) -> Json<Response> {
48 match request {
49 Ok(req) => handle_request(req.0, handler)
50 .await
51 .unwrap_or_else(|| Response::error(RpcError::invalid_request()))
52 .into(),
53 Err(err) => match &err {
54 JsonRejection::JsonSyntaxError(e) => {
55 let error_msg = e.to_string();
56 warn!(target: "rpc", "JSON syntax error: {}", error_msg);
57 Response::error(RpcError::parse_error(error_msg)).into()
58 }
59 JsonRejection::JsonDataError(e) => {
60 warn!(target: "rpc", "JSON data error: {}", e);
61 Response::error(RpcError::invalid_request_with_reason(format!("Data error: {}", e)))
62 .into()
63 }
64 JsonRejection::MissingJsonContentType(e) => {
65 warn!(target: "rpc", "Missing JSON content type: {}", e);
66 Response::error(RpcError::invalid_request_with_reason("Missing content type"))
67 .into()
68 }
69 _ => {
70 warn!(target: "rpc", "Request rejection: {}", err);
71 Response::error(RpcError::invalid_request_with_reason(err.to_string())).into()
72 }
73 },
74 }
75}
76
77pub async fn handle_socket<THandler: RpcHandler>(
78 ws_upgrade: WebSocketUpgrade,
79 State(handler): State<THandler>,
80) -> impl IntoResponse {
81 tracing::info!("New websocket connection!");
82 ws_upgrade.on_failed_upgrade(|e| tracing::error!("Failed websocket upgrade: {e:?}")).on_upgrade(
83 move |socket| async move {
84 handler.on_websocket(socket).await;
85 },
86 )
87}
88
89pub async fn handle_request<THandler: RpcHandler>(
94 req: Request,
95 handler: THandler,
96) -> Option<Response> {
97 fn responses_as_batch(outs: Vec<Option<RpcResponse>>) -> Option<Response> {
99 let batch: Vec<_> = outs.into_iter().flatten().collect();
100 (!batch.is_empty()).then_some(Response::Batch(batch))
101 }
102
103 match req {
104 Request::Single(call) => handle_call(call, handler).await.map(Response::Single),
105 Request::Batch(calls) => {
106 future::join_all(calls.into_iter().map(move |call| handle_call(call, handler.clone())))
107 .map(responses_as_batch)
108 .await
109 }
110 }
111}
112
113pub(crate) async fn handle_call<THandler: RpcHandler>(
115 call: RpcCall,
116 handler: THandler,
117) -> Option<RpcResponse> {
118 match call {
119 RpcCall::MethodCall(call) => {
120 trace!(target: "rpc", id = ?call.id, method = ?call.method, "handling call");
121 Some(handler.on_call(call).await)
122 }
123 RpcCall::Notification(notification) => {
124 trace!(target: "rpc", method = ?notification.method, "received rpc notification");
125 None
126 }
127 RpcCall::Invalid { id } => {
128 warn!(target: "rpc", ?id, "invalid rpc call");
129 Some(RpcResponse::invalid_request(id))
130 }
131 }
132}