1use std::sync::Arc;
2
3use axum::extract::ws::{Message, WebSocket};
4use futures::stream::SplitSink;
5use futures::{SinkExt, StreamExt};
6use starknet_core::StarknetBlock;
7use starknet_core::starknet::starknet_config::DumpOn;
8use starknet_types::emitted_event::SubscriptionEmittedEvent;
9use starknet_types::rpc::block::{BlockId, BlockTag, ReorgData};
10use starknet_types::rpc::transactions::TransactionFinalityStatus;
11use tokio::sync::Mutex;
12use tracing::{info, trace};
13
14use crate::api::models::{
15 AccountAddressInput, BlockAndClassHashInput, BlockAndContractAddressInput, BlockAndIndexInput,
16 BlockIdInput, BroadcastedDeclareTransactionEnumWrapper, BroadcastedDeclareTransactionInput,
17 BroadcastedDeployAccountTransactionEnumWrapper, BroadcastedDeployAccountTransactionInput,
18 BroadcastedInvokeTransactionEnumWrapper, BroadcastedInvokeTransactionInput, CallInput,
19 ClassHashInput, DevnetSpecRequest, EstimateFeeInput, EventsInput, GetStorageInput,
20 JsonRpcRequest, JsonRpcResponse, JsonRpcWsRequest, LoadPath, SimulateTransactionsInput,
21 StarknetSpecRequest, ToRpcResponseResult, TransactionHashInput, to_json_rpc_request,
22};
23use crate::api::origin_forwarder::OriginForwarder;
24use crate::api::{Api, ApiError, error};
25use crate::dump_util::dump_event;
26use crate::restrictive_mode::is_json_rpc_method_restricted;
27use crate::rpc_core;
28use crate::rpc_core::error::{ErrorCode, RpcError};
29use crate::rpc_core::request::RpcMethodCall;
30use crate::rpc_core::response::{ResponseResult, RpcResponse};
31use crate::rpc_handler::RpcHandler;
32use crate::subscribe::{
33 NewTransactionNotification, NewTransactionReceiptNotification, NewTransactionStatus,
34 NotificationData, SocketId,
35};
36
37#[derive(Clone)]
41pub struct JsonRpcHandler {
42 pub api: Api,
43 pub origin_caller: Option<OriginForwarder>,
44}
45
46#[async_trait::async_trait]
47impl RpcHandler for JsonRpcHandler {
48 type Request = JsonRpcRequest;
49
50 async fn on_request(
51 &self,
52 request: Self::Request,
53 original_call: RpcMethodCall,
54 ) -> ResponseResult {
55 info!(target: "rpc", "received method in on_request {}", request);
56
57 if !self.allows_method(&original_call.method) {
58 return ResponseResult::Error(RpcError::new(ErrorCode::MethodForbidden));
59 }
60
61 let is_request_forwardable = request.is_forwardable_to_origin(); let is_request_dumpable = request.is_dumpable();
63
64 let old_latest_block = if request.requires_notifying() {
66 Some(self.get_block_by_tag(BlockTag::Latest).await)
67 } else {
68 None
69 };
70
71 let old_pre_confirmed_block =
72 if request.requires_notifying() && self.api.config.uses_pre_confirmed_block() {
73 Some(self.get_block_by_tag(BlockTag::PreConfirmed).await)
74 } else {
75 None
76 };
77
78 let starknet_resp = self.execute(request).await;
79
80 if let (Err(err), Some(forwarder)) = (&starknet_resp, &self.origin_caller) {
82 if err.is_forwardable_to_origin() && is_request_forwardable {
83 return forwarder.call(&original_call).await;
86 }
87 }
88
89 if starknet_resp.is_ok() && is_request_dumpable {
90 if let Err(e) = self.update_dump(&original_call).await {
91 return ResponseResult::Error(e);
92 }
93 }
94
95 if let Err(e) = self.broadcast_changes(old_latest_block, old_pre_confirmed_block).await {
96 return ResponseResult::Error(e.api_error_to_rpc_error());
97 }
98
99 starknet_resp.to_rpc_result()
100 }
101
102 async fn on_call(&self, call: RpcMethodCall) -> RpcResponse {
103 let id = call.id.clone();
104 trace!(target: "rpc", id = ?id, method = ?call.method, "received method call");
105
106 match to_json_rpc_request(&call) {
107 Ok(req) => {
108 let result = self.on_request(req, call).await;
109 RpcResponse::new(id, result)
110 }
111 Err(e) => RpcResponse::from_rpc_error(e, id),
112 }
113 }
114
115 async fn on_websocket(&self, socket: WebSocket) {
116 let (socket_writer, mut socket_reader) = socket.split();
117 let socket_writer = Arc::new(Mutex::new(socket_writer));
118
119 let socket_id = self.api.sockets.lock().await.insert(socket_writer.clone());
120
121 let mut socket_safely_closed = false;
123 while let Some(msg) = socket_reader.next().await {
124 match msg {
125 Ok(Message::Text(text)) => {
126 self.on_websocket_call(text.as_bytes(), socket_writer.clone(), socket_id).await;
127 }
128 Ok(Message::Binary(bytes)) => {
129 self.on_websocket_call(&bytes, socket_writer.clone(), socket_id).await;
130 }
131 Ok(Message::Close(_)) => {
132 socket_safely_closed = true;
133 break;
134 }
135 other => {
136 tracing::error!("Socket handler got an unexpected message: {other:?}")
137 }
138 }
139 }
140
141 self.api.sockets.lock().await.remove(&socket_id);
142 if socket_safely_closed {
143 tracing::info!("Websocket disconnected");
144 } else {
145 tracing::error!("Failed socket read");
146 }
147 }
148}
149
150impl JsonRpcHandler {
151 pub fn new(api: Api) -> JsonRpcHandler {
152 let origin_caller = if let (Some(url), Some(block_number)) =
153 (&api.config.fork_config.url, api.config.fork_config.block_number)
154 {
155 Some(OriginForwarder::new(url.clone(), block_number))
156 } else {
157 None
158 };
159
160 JsonRpcHandler { api, origin_caller }
161 }
162
163 async fn get_block_by_tag(&self, tag: BlockTag) -> StarknetBlock {
167 let starknet = self.api.starknet.lock().await;
168 match starknet.get_block(&BlockId::Tag(tag)) {
169 Ok(block) => block.clone(),
170 _ => StarknetBlock::create_empty_accepted(),
171 }
172 }
173
174 async fn broadcast_pre_confirmed_tx_changes(
175 &self,
176 old_pre_confirmed_block: StarknetBlock,
177 ) -> Result<(), error::ApiError> {
178 let new_pre_confirmed_block = self.get_block_by_tag(BlockTag::PreConfirmed).await;
179 let old_pre_confirmed_txs = old_pre_confirmed_block.get_transactions();
180 let new_pre_confirmed_txs = new_pre_confirmed_block.get_transactions();
181
182 if new_pre_confirmed_txs.len() > old_pre_confirmed_txs.len() {
183 #[allow(clippy::expect_used)]
184 let new_tx_hash = new_pre_confirmed_txs.last().expect("has at least one element");
185
186 let mut notifications = vec![];
187 let starknet = self.api.starknet.lock().await;
188
189 let status = starknet
190 .get_transaction_execution_and_finality_status(*new_tx_hash)
191 .map_err(error::ApiError::StarknetDevnetError)?;
192 notifications.push(NotificationData::TransactionStatus(NewTransactionStatus {
193 transaction_hash: *new_tx_hash,
194 status,
195 }));
196
197 let tx = starknet
198 .get_transaction_by_hash(*new_tx_hash)
199 .map_err(error::ApiError::StarknetDevnetError)?;
200 notifications.push(NotificationData::NewTransaction(NewTransactionNotification {
201 tx: tx.clone(),
202 finality_status: TransactionFinalityStatus::PreConfirmed,
203 }));
204
205 let receipt = starknet
206 .get_transaction_receipt_by_hash(new_tx_hash)
207 .map_err(error::ApiError::StarknetDevnetError)?;
208
209 notifications.push(NotificationData::NewTransactionReceipt(
210 NewTransactionReceiptNotification {
211 tx_receipt: receipt,
212 sender_address: tx.get_sender_address(),
213 },
214 ));
215
216 let events = starknet.get_unlimited_events(
217 Some(BlockId::Tag(BlockTag::PreConfirmed)),
218 Some(BlockId::Tag(BlockTag::PreConfirmed)),
219 None,
220 None,
221 None, )?;
223
224 drop(starknet); for emitted_event in events.into_iter().filter(|e| &e.transaction_hash == new_tx_hash) {
227 notifications.push(NotificationData::Event(SubscriptionEmittedEvent {
228 emitted_event,
229 finality_status: TransactionFinalityStatus::PreConfirmed,
230 }));
231 }
232
233 self.api.sockets.lock().await.notify_subscribers(¬ifications).await;
234 }
235
236 Ok(())
237 }
238
239 async fn broadcast_latest_changes(
240 &self,
241 new_latest_block: StarknetBlock,
242 ) -> Result<(), error::ApiError> {
243 let block_header = (&new_latest_block).into();
244 let mut notifications = vec![NotificationData::NewHeads(block_header)];
245
246 let starknet = self.api.starknet.lock().await;
247
248 let finality_status = TransactionFinalityStatus::AcceptedOnL2;
249 let latest_txs = new_latest_block.get_transactions();
250 for tx_hash in latest_txs {
251 let tx = starknet
252 .get_transaction_by_hash(*tx_hash)
253 .map_err(error::ApiError::StarknetDevnetError)?;
254 notifications.push(NotificationData::NewTransaction(NewTransactionNotification {
255 tx: tx.clone(),
256 finality_status,
257 }));
258
259 let status = starknet
260 .get_transaction_execution_and_finality_status(*tx_hash)
261 .map_err(error::ApiError::StarknetDevnetError)?;
262 notifications.push(NotificationData::TransactionStatus(NewTransactionStatus {
263 transaction_hash: *tx_hash,
264 status,
265 }));
266
267 let tx_receipt = starknet
268 .get_transaction_receipt_by_hash(tx_hash)
269 .map_err(error::ApiError::StarknetDevnetError)?;
270 notifications.push(NotificationData::NewTransactionReceipt(
271 NewTransactionReceiptNotification {
272 tx_receipt,
273 sender_address: tx.get_sender_address(),
274 },
275 ));
276 }
277
278 let events = starknet.get_unlimited_events(
279 Some(BlockId::Tag(BlockTag::Latest)),
280 Some(BlockId::Tag(BlockTag::Latest)),
281 None,
282 None,
283 None, )?;
285
286 drop(starknet); for emitted_event in events {
289 notifications.push(NotificationData::Event(SubscriptionEmittedEvent {
290 emitted_event,
291 finality_status,
292 }));
293 }
294
295 self.api.sockets.lock().await.notify_subscribers(¬ifications).await;
296 Ok(())
297 }
298
299 async fn broadcast_changes(
301 &self,
302 old_latest_block: Option<StarknetBlock>,
303 old_pre_confirmed_block: Option<StarknetBlock>,
304 ) -> Result<(), error::ApiError> {
305 let Some(old_latest_block) = old_latest_block else {
306 return Ok(());
307 };
308
309 if let Some(old_pre_confirmed_block) = old_pre_confirmed_block {
310 self.broadcast_pre_confirmed_tx_changes(old_pre_confirmed_block).await?;
311 }
312
313 let new_latest_block = self.get_block_by_tag(BlockTag::Latest).await;
314
315 match new_latest_block.block_number().cmp(&old_latest_block.block_number()) {
316 std::cmp::Ordering::Less => {
317 self.broadcast_reorg(old_latest_block, new_latest_block).await?
318 }
319 std::cmp::Ordering::Equal => { }
320 std::cmp::Ordering::Greater => self.broadcast_latest_changes(new_latest_block).await?,
321 }
322
323 Ok(())
324 }
325
326 async fn broadcast_reorg(
327 &self,
328 old_latest_block: StarknetBlock,
329 new_latest_block: StarknetBlock,
330 ) -> Result<(), ApiError> {
331 let last_aborted_block_hash =
332 *self.api.starknet.lock().await.last_aborted_block_hash().ok_or(
333 ApiError::StarknetDevnetError(
334 starknet_core::error::Error::UnexpectedInternalError {
335 msg: "Aborted block hash should be defined.".into(),
336 },
337 ),
338 )?;
339
340 let notification = NotificationData::Reorg(ReorgData {
341 starting_block_hash: last_aborted_block_hash,
342 starting_block_number: new_latest_block.block_number().unchecked_next(),
343 ending_block_hash: old_latest_block.block_hash(),
344 ending_block_number: old_latest_block.block_number(),
345 });
346
347 self.api.sockets.lock().await.notify_subscribers(&[notification]).await;
348 Ok(())
349 }
350
351 async fn execute(&self, req: JsonRpcRequest) -> Result<JsonRpcResponse, error::ApiError> {
353 trace!(target: "JsonRpcHandler::execute", "executing request");
354 match req {
355 JsonRpcRequest::StarknetSpecRequest(req) => self.execute_starknet_spec(req).await,
356 JsonRpcRequest::DevnetSpecRequest(req) => self.execute_devnet_spec(req).await,
357 }
358 }
359
360 async fn execute_starknet_spec(
361 &self,
362 req: StarknetSpecRequest,
363 ) -> Result<JsonRpcResponse, error::ApiError> {
364 match req {
365 StarknetSpecRequest::SpecVersion => self.spec_version(),
366 StarknetSpecRequest::BlockWithTransactionHashes(block) => {
367 self.get_block_with_tx_hashes(block.block_id).await
368 }
369 StarknetSpecRequest::BlockWithFullTransactions(block) => {
370 self.get_block_with_txs(block.block_id).await
371 }
372 StarknetSpecRequest::BlockWithReceipts(block) => {
373 self.get_block_with_receipts(block.block_id).await
374 }
375 StarknetSpecRequest::StateUpdate(block) => self.get_state_update(block.block_id).await,
376 StarknetSpecRequest::StorageAt(GetStorageInput { contract_address, key, block_id }) => {
377 self.get_storage_at(contract_address, key, block_id).await
378 }
379 StarknetSpecRequest::TransactionStatusByHash(TransactionHashInput {
380 transaction_hash,
381 }) => self.get_transaction_status_by_hash(transaction_hash).await,
382 StarknetSpecRequest::TransactionByHash(TransactionHashInput { transaction_hash }) => {
383 self.get_transaction_by_hash(transaction_hash).await
384 }
385 StarknetSpecRequest::TransactionByBlockAndIndex(BlockAndIndexInput {
386 block_id,
387 index,
388 }) => self.get_transaction_by_block_id_and_index(block_id, index).await,
389 StarknetSpecRequest::TransactionReceiptByTransactionHash(TransactionHashInput {
390 transaction_hash,
391 }) => self.get_transaction_receipt_by_hash(transaction_hash).await,
392 StarknetSpecRequest::ClassByHash(BlockAndClassHashInput { block_id, class_hash }) => {
393 self.get_class(block_id, class_hash).await
394 }
395 StarknetSpecRequest::CompiledCasmByClassHash(ClassHashInput { class_hash }) => {
396 self.get_compiled_casm(class_hash).await
397 }
398 StarknetSpecRequest::ClassHashAtContractAddress(BlockAndContractAddressInput {
399 block_id,
400 contract_address,
401 }) => self.get_class_hash_at(block_id, contract_address).await,
402 StarknetSpecRequest::ClassAtContractAddress(BlockAndContractAddressInput {
403 block_id,
404 contract_address,
405 }) => self.get_class_at(block_id, contract_address).await,
406 StarknetSpecRequest::BlockTransactionCount(block) => {
407 self.get_block_txs_count(block.block_id).await
408 }
409 StarknetSpecRequest::Call(CallInput { request, block_id }) => {
410 self.call(block_id, request).await
411 }
412 StarknetSpecRequest::EstimateFee(EstimateFeeInput {
413 request,
414 block_id,
415 simulation_flags,
416 }) => self.estimate_fee(block_id, request, simulation_flags).await,
417 StarknetSpecRequest::BlockNumber => self.block_number().await,
418 StarknetSpecRequest::BlockHashAndNumber => self.block_hash_and_number().await,
419 StarknetSpecRequest::ChainId => self.chain_id().await,
420 StarknetSpecRequest::Syncing => self.syncing().await,
421 StarknetSpecRequest::Events(EventsInput { filter }) => self.get_events(filter).await,
422 StarknetSpecRequest::ContractNonce(BlockAndContractAddressInput {
423 block_id,
424 contract_address,
425 }) => self.get_nonce(block_id, contract_address).await,
426 StarknetSpecRequest::AddDeclareTransaction(BroadcastedDeclareTransactionInput {
427 declare_transaction,
428 }) => {
429 let BroadcastedDeclareTransactionEnumWrapper::Declare(broadcasted_transaction) =
430 declare_transaction;
431 self.add_declare_transaction(broadcasted_transaction).await
432 }
433 StarknetSpecRequest::AddDeployAccountTransaction(
434 BroadcastedDeployAccountTransactionInput { deploy_account_transaction },
435 ) => {
436 let BroadcastedDeployAccountTransactionEnumWrapper::DeployAccount(
437 broadcasted_transaction,
438 ) = deploy_account_transaction;
439 self.add_deploy_account_transaction(broadcasted_transaction).await
440 }
441 StarknetSpecRequest::AddInvokeTransaction(BroadcastedInvokeTransactionInput {
442 invoke_transaction,
443 }) => {
444 let BroadcastedInvokeTransactionEnumWrapper::Invoke(broadcasted_transaction) =
445 invoke_transaction;
446 self.add_invoke_transaction(broadcasted_transaction).await
447 }
448 StarknetSpecRequest::EstimateMessageFee(request) => {
449 self.estimate_message_fee(request.get_block_id(), request.get_raw_message().clone())
450 .await
451 }
452 StarknetSpecRequest::SimulateTransactions(SimulateTransactionsInput {
453 block_id,
454 transactions,
455 simulation_flags,
456 }) => self.simulate_transactions(block_id, transactions, simulation_flags).await,
457 StarknetSpecRequest::TraceTransaction(TransactionHashInput { transaction_hash }) => {
458 self.get_trace_transaction(transaction_hash).await
459 }
460 StarknetSpecRequest::BlockTransactionTraces(BlockIdInput { block_id }) => {
461 self.get_trace_block_transactions(block_id).await
462 }
463 StarknetSpecRequest::MessagesStatusByL1Hash(data) => {
464 self.get_messages_status(data).await
465 }
466 StarknetSpecRequest::StorageProof(data) => self.get_storage_proof(data).await,
467 }
468 }
469
470 async fn execute_devnet_spec(
471 &self,
472 req: DevnetSpecRequest,
473 ) -> Result<JsonRpcResponse, error::ApiError> {
474 match req {
475 DevnetSpecRequest::ImpersonateAccount(AccountAddressInput { account_address }) => {
476 self.impersonate_account(account_address).await
477 }
478 DevnetSpecRequest::StopImpersonateAccount(AccountAddressInput { account_address }) => {
479 self.stop_impersonating_account(account_address).await
480 }
481 DevnetSpecRequest::AutoImpersonate => self.set_auto_impersonate(true).await,
482 DevnetSpecRequest::StopAutoImpersonate => self.set_auto_impersonate(false).await,
483 DevnetSpecRequest::Dump(path) => self.dump(path).await,
484 DevnetSpecRequest::Load(LoadPath { path }) => self.load(path).await,
485 DevnetSpecRequest::PostmanLoadL1MessagingContract(data) => {
486 self.postman_load(data).await
487 }
488 DevnetSpecRequest::PostmanFlush(data) => self.postman_flush(data).await,
489 DevnetSpecRequest::PostmanSendMessageToL2(message) => {
490 self.postman_send_message_to_l2(message).await
491 }
492 DevnetSpecRequest::PostmanConsumeMessageFromL2(message) => {
493 self.postman_consume_message_from_l2(message).await
494 }
495 DevnetSpecRequest::CreateBlock => self.create_block().await,
496 DevnetSpecRequest::AbortBlocks(data) => self.abort_blocks(data).await,
497 DevnetSpecRequest::AcceptOnL1(data) => self.accept_on_l1(data).await,
498 DevnetSpecRequest::SetGasPrice(data) => self.set_gas_price(data).await,
499 DevnetSpecRequest::Restart(data) => self.restart(data).await,
500 DevnetSpecRequest::SetTime(data) => self.set_time(data).await,
501 DevnetSpecRequest::IncreaseTime(data) => self.increase_time(data).await,
502 DevnetSpecRequest::PredeployedAccounts(data) => {
503 self.get_predeployed_accounts(data).await
504 }
505 DevnetSpecRequest::AccountBalance(data) => self.get_account_balance(data).await,
506 DevnetSpecRequest::Mint(data) => self.mint(data).await,
507 DevnetSpecRequest::DevnetConfig => self.get_devnet_config().await,
508 }
509 }
510
511 async fn on_websocket_call(
513 &self,
514 bytes: &[u8],
515 ws: Arc<Mutex<SplitSink<WebSocket, Message>>>,
516 socket_id: SocketId,
517 ) {
518 let error_serialized = match serde_json::from_slice(bytes) {
519 Ok(rpc_call) => match self.on_websocket_rpc_call(&rpc_call, socket_id).await {
520 Ok(_) => return,
521 Err(e) => serde_json::to_string(&RpcResponse::from_rpc_error(e, rpc_call.id))
522 .unwrap_or_default(),
523 },
524 Err(e) => serde_json::to_string(&RpcResponse::from_rpc_error(
525 RpcError::parse_error(e.to_string()),
526 rpc_core::request::Id::Null,
527 ))
528 .unwrap_or_default(),
529 };
530
531 if let Err(e) = ws.lock().await.send(Message::Text(error_serialized.into())).await {
532 tracing::error!("Error sending websocket message: {e}");
533 }
534 }
535
536 fn allows_method(&self, method: &str) -> bool {
537 if let Some(restricted_methods) = &self.api.server_config.restricted_methods {
538 if is_json_rpc_method_restricted(method, restricted_methods) {
539 return false;
540 }
541 }
542
543 true
544 }
545
546 async fn on_websocket_rpc_call(
551 &self,
552 call: &RpcMethodCall,
553 socket_id: SocketId,
554 ) -> Result<(), RpcError> {
555 trace!(target: "rpc", id = ?call.id, method = ?call.method, "received websocket call");
556
557 let req: JsonRpcWsRequest = to_json_rpc_request(call)?;
558 match req {
559 JsonRpcWsRequest::OneTimeRequest(req) => {
560 let resp_result = self.on_request(*req, call.clone()).await;
561 let mut sockets = self.api.sockets.lock().await;
562
563 let socket_context =
564 sockets.get_mut(&socket_id).map_err(|e| e.api_error_to_rpc_error())?;
565
566 match resp_result {
567 ResponseResult::Success(result_value) => {
568 socket_context.send_rpc_response(result_value, call.id.clone()).await;
569 Ok(())
570 }
571 ResponseResult::Error(rpc_error) => Err(rpc_error),
572 }
573 }
574 JsonRpcWsRequest::SubscriptionRequest(req) => self
575 .execute_ws_subscription(req, call.id.clone(), socket_id)
576 .await
577 .map_err(|e| e.api_error_to_rpc_error()),
578 }
579 }
580
581 async fn update_dump(&self, event: &RpcMethodCall) -> Result<(), RpcError> {
582 match self.api.config.dump_on {
583 Some(DumpOn::Block) => {
584 let dump_path = self
585 .api
586 .config
587 .dump_path
588 .as_deref()
589 .ok_or(RpcError::internal_error_with("Undefined dump_path"))?;
590
591 dump_event(event, dump_path).map_err(|e| {
592 let msg = format!("Failed dumping of {}: {e}", event.method);
593 RpcError::internal_error_with(msg)
594 })?;
595 }
596 Some(DumpOn::Request | DumpOn::Exit) => {
597 self.api.dumpable_events.lock().await.push(event.clone())
598 }
599 None => (),
600 }
601
602 Ok(())
603 }
604
605 pub async fn re_execute(&self, events: &[RpcMethodCall]) -> Result<(), RpcError> {
606 for event in events {
607 if let ResponseResult::Error(e) = self.on_call(event.clone()).await.result {
608 return Err(e);
609 }
610 }
611 Ok(())
612 }
613}