1use starknet_core::error::Error;
2use starknet_rs_core::types::{
3 BlockId as ImportedBlockId, Felt, L1DataAvailabilityMode as ImportedL1DataAvailabilityMode,
4 MaybePreConfirmedBlockWithTxHashes,
5};
6use starknet_rs_providers::{Provider, ProviderError};
7use starknet_types::contract_address::ContractAddress;
8use starknet_types::emitted_event::{EmittedEvent, SubscriptionEmittedEvent};
9use starknet_types::felt::TransactionHash;
10use starknet_types::rpc::block::{BlockHeader, BlockId, BlockStatus, BlockTag};
11use starknet_types::rpc::transactions::TransactionFinalityStatus;
12use starknet_types::starknet_api::block::{BlockNumber, BlockTimestamp};
13use starknet_types::starknet_api::core::{
14 EventCommitment, ReceiptCommitment, StateDiffCommitment, TransactionCommitment,
15};
16use starknet_types::starknet_api::data_availability::L1DataAvailabilityMode;
17use starknet_types::starknet_api::hash::PoseidonHash;
18
19use super::JsonRpcHandler;
20use super::error::ApiError;
21use super::models::{
22 EventsSubscriptionInput, SubscriptionBlockIdInput, SubscriptionIdInput, TransactionHashInput,
23 TransactionReceiptSubscriptionInput, TransactionSubscriptionInput,
24};
25use crate::api::models::JsonRpcSubscriptionRequest;
26use crate::rpc_core::request::Id;
27use crate::subscribe::{
28 AddressFilter, NewTransactionStatus, NotificationData, SocketId, StatusFilter, Subscription,
29};
30
31impl JsonRpcHandler {
33 pub async fn execute_ws_subscription(
34 &self,
35 request: JsonRpcSubscriptionRequest,
36 rpc_request_id: Id,
37 socket_id: SocketId,
38 ) -> Result<(), ApiError> {
39 match request {
40 JsonRpcSubscriptionRequest::NewHeads(data) => {
41 self.subscribe_new_heads(data, rpc_request_id, socket_id).await
42 }
43 JsonRpcSubscriptionRequest::TransactionStatus(TransactionHashInput {
44 transaction_hash,
45 }) => self.subscribe_tx_status(transaction_hash, rpc_request_id, socket_id).await,
46 JsonRpcSubscriptionRequest::NewTransactions(data) => {
47 self.subscribe_new_txs(data, rpc_request_id, socket_id).await
48 }
49 JsonRpcSubscriptionRequest::NewTransactionReceipts(data) => {
50 self.subscribe_new_tx_receipts(data, rpc_request_id, socket_id).await
51 }
52 JsonRpcSubscriptionRequest::Events(data) => {
53 self.subscribe_events(data, rpc_request_id, socket_id).await
54 }
55 JsonRpcSubscriptionRequest::Unsubscribe(SubscriptionIdInput { subscription_id }) => {
56 let mut sockets = self.api.sockets.lock().await;
57 let socket_context = sockets.get_mut(&socket_id)?;
58 socket_context.unsubscribe(rpc_request_id, subscription_id).await
59 }
60 }
61 }
62
63 async fn get_origin_block_header_by_id(&self, id: BlockId) -> Result<BlockHeader, ApiError> {
64 let origin_caller = self.origin_caller.as_ref().ok_or_else(|| {
65 ApiError::StarknetDevnetError(Error::UnexpectedInternalError {
66 msg: "No origin caller available".into(),
67 })
68 })?;
69 match origin_caller
70 .starknet_client
71 .get_block_with_tx_hashes(ImportedBlockId::from(id))
72 .await
73 {
74 Ok(MaybePreConfirmedBlockWithTxHashes::Block(origin_block)) => {
75 let origin_header = BlockHeader {
76 block_hash: origin_block.block_hash,
77 parent_hash: origin_block.parent_hash,
78 block_number: BlockNumber(origin_block.block_number),
79 l1_gas_price: origin_block.l1_gas_price.into(),
80 l2_gas_price: origin_block.l2_gas_price.into(),
81 new_root: origin_block.new_root,
82 sequencer_address: ContractAddress::new_unchecked(
83 origin_block.sequencer_address,
84 ),
85 timestamp: BlockTimestamp(origin_block.timestamp),
86 starknet_version: origin_block.starknet_version,
87 l1_data_gas_price: origin_block.l1_data_gas_price.into(),
88 l1_da_mode: match origin_block.l1_da_mode {
89 ImportedL1DataAvailabilityMode::Calldata => {
90 L1DataAvailabilityMode::Calldata
91 }
92 ImportedL1DataAvailabilityMode::Blob => L1DataAvailabilityMode::Blob,
93 },
94 n_transactions: origin_block.transaction_count,
95 n_events: origin_block.event_count,
96 state_diff_length: origin_block.state_diff_length,
97 state_diff_commitment: StateDiffCommitment(PoseidonHash(
98 origin_block.state_diff_commitment,
99 )),
100 transaction_commitment: TransactionCommitment(
101 origin_block.transaction_commitment,
102 ),
103 event_commitment: EventCommitment(origin_block.event_commitment),
104 receipt_commitment: ReceiptCommitment(origin_block.receipt_commitment),
105 };
106 Ok(origin_header)
107 }
108 Err(ProviderError::StarknetError(
109 starknet_rs_core::types::StarknetError::BlockNotFound,
110 )) => Err(ApiError::BlockNotFound),
111 other => Err(ApiError::StarknetDevnetError(
112 starknet_core::error::Error::UnexpectedInternalError {
113 msg: format!("Failed retrieval of block from forking origin. Got: {other:?}"),
114 },
115 )),
116 }
117 }
118
119 async fn get_local_block_header_by_id(&self, id: &BlockId) -> Result<BlockHeader, ApiError> {
120 let starknet = self.api.starknet.lock().await;
121
122 let block = match starknet.get_block(id) {
123 Ok(block) => match block.status() {
124 BlockStatus::Rejected => return Err(ApiError::BlockNotFound),
125 _ => Ok::<_, ApiError>(block),
126 },
127 Err(Error::NoBlock) => Err(ApiError::BlockNotFound),
128 Err(other) => Err(ApiError::StarknetDevnetError(other)),
129 }?;
130
131 Ok(block.into())
132 }
133
134 async fn get_validated_block_number_range(
137 &self,
138 mut starting_block_id: BlockId,
139 ) -> Result<(u64, u64, Option<(u64, u64)>), ApiError> {
140 starting_block_id = match starting_block_id {
142 BlockId::Tag(BlockTag::PreConfirmed) => BlockId::Tag(BlockTag::Latest),
143 other => other,
144 };
145
146 let query_block_number = match starting_block_id {
147 BlockId::Number(n) => n,
148 block_id => match self.get_local_block_header_by_id(&block_id).await {
149 Ok(block) => block.block_number.0,
150 Err(ApiError::BlockNotFound) if self.origin_caller.is_some() => {
151 self.get_origin_block_header_by_id(block_id).await?.block_number.0
152 }
153 Err(other) => return Err(other),
154 },
155 };
156
157 let starknet = self.api.starknet.lock().await;
158 let latest_block_number =
159 starknet.get_block(&BlockId::Tag(BlockTag::Latest))?.block_number().0;
160 drop(starknet);
161
162 let (fork_url, fork_block_number) =
163 (self.api.config.fork_config.url.clone(), self.api.config.fork_config.block_number);
164
165 if query_block_number > latest_block_number {
166 return Err(ApiError::BlockNotFound);
167 }
168 if latest_block_number - query_block_number > 1024 {
169 return Err(ApiError::TooManyBlocksBack);
170 }
171
172 let origin_block_range = match (fork_url, fork_block_number) {
174 (Some(_url), Some(fork_block_number)) => {
175 if query_block_number <= fork_block_number {
178 Some((query_block_number, fork_block_number))
179 } else {
180 None
181 }
182 }
183 _ => None, };
185
186 let validated_start_block_number =
187 if let Some(origin) = origin_block_range { origin.1 + 1 } else { query_block_number };
188
189 Ok((validated_start_block_number, latest_block_number, origin_block_range))
190 }
191
192 async fn fetch_origin_heads(
193 &self,
194 start_block: u64,
195 end_block: u64,
196 ) -> Result<Vec<BlockHeader>, ApiError> {
197 let mut headers = Vec::new();
198 for block_n in start_block..=end_block {
199 let block_id = BlockId::Number(block_n);
200 headers.push(self.get_origin_block_header_by_id(block_id).await?);
201 }
202 Ok(headers)
203 }
204
205 async fn subscribe_new_heads(
211 &self,
212 block_input: Option<SubscriptionBlockIdInput>,
213 rpc_request_id: Id,
214 socket_id: SocketId,
215 ) -> Result<(), ApiError> {
216 let block_id = if let Some(SubscriptionBlockIdInput { block_id }) = block_input {
217 block_id.into()
218 } else {
219 BlockId::Tag(BlockTag::Latest)
221 };
222
223 let (query_block_number, latest_block_number, origin_range) =
224 self.get_validated_block_number_range(block_id).await?;
225
226 let mut sockets = self.api.sockets.lock().await;
228 let socket_context = sockets.get_mut(&socket_id)?;
229 let subscription_id =
230 socket_context.subscribe(rpc_request_id, Subscription::NewHeads).await;
231
232 if let BlockId::Tag(_) = block_id {
233 return Ok(());
235 }
236
237 let mut headers = Vec::new();
238 if let Some((origin_start, origin_end)) = origin_range {
239 let origin_headers = self.fetch_origin_heads(origin_start, origin_end).await?;
242 headers.extend(origin_headers.iter().cloned());
243 }
244
245 let starknet = self.api.starknet.lock().await;
248 for block_n in query_block_number..=latest_block_number {
249 let old_block = starknet
250 .get_block(&BlockId::Number(block_n))
251 .map_err(ApiError::StarknetDevnetError)?;
252
253 headers.push(old_block.into());
254 }
255
256 for header in headers {
257 let notification = NotificationData::NewHeads(header);
258 socket_context.notify(subscription_id, notification).await;
259 }
260
261 Ok(())
262 }
263
264 pub async fn subscribe_new_txs(
266 &self,
267 maybe_subscription_input: Option<TransactionSubscriptionInput>,
268 rpc_request_id: Id,
269 socket_id: SocketId,
270 ) -> Result<(), ApiError> {
271 let status_filter = StatusFilter::new(
272 maybe_subscription_input
273 .as_ref()
274 .and_then(|input| input.finality_status.as_ref())
275 .map_or_else(
276 || vec![TransactionFinalityStatus::AcceptedOnL2],
277 |statuses| {
278 statuses.iter().cloned().map(TransactionFinalityStatus::from).collect()
279 },
280 ),
281 );
282
283 let address_filter = AddressFilter::new(
284 maybe_subscription_input
285 .and_then(|subscription_input| subscription_input.sender_address)
286 .unwrap_or_default(),
287 );
288
289 let mut sockets = self.api.sockets.lock().await;
290 let socket_context = sockets.get_mut(&socket_id)?;
291
292 let subscription = Subscription::NewTransactions { address_filter, status_filter };
293 socket_context.subscribe(rpc_request_id, subscription).await;
294
295 Ok(())
296 }
297
298 pub async fn subscribe_new_tx_receipts(
300 &self,
301 maybe_subscription_input: Option<TransactionReceiptSubscriptionInput>,
302 rpc_request_id: Id,
303 socket_id: SocketId,
304 ) -> Result<(), ApiError> {
305 let status_filter = StatusFilter::new(
306 maybe_subscription_input
307 .as_ref()
308 .and_then(|input| input.finality_status.as_ref())
309 .map_or_else(
310 || vec![TransactionFinalityStatus::AcceptedOnL2],
311 |statuses| {
312 statuses.iter().cloned().map(TransactionFinalityStatus::from).collect()
313 },
314 ),
315 );
316
317 let address_filter = AddressFilter::new(
318 maybe_subscription_input
319 .and_then(|subscription_input| subscription_input.sender_address)
320 .unwrap_or_default(),
321 );
322
323 let mut sockets = self.api.sockets.lock().await;
324 let socket_context = sockets.get_mut(&socket_id)?;
325
326 let subscription = Subscription::NewTransactionReceipts { address_filter, status_filter };
327 socket_context.subscribe(rpc_request_id, subscription).await;
328
329 Ok(())
330 }
331
332 async fn subscribe_tx_status(
333 &self,
334 transaction_hash: TransactionHash,
335 rpc_request_id: Id,
336 socket_id: SocketId,
337 ) -> Result<(), ApiError> {
338 let mut sockets = self.api.sockets.lock().await;
340 let socket_context = sockets.get_mut(&socket_id)?;
341
342 let subscription = Subscription::TransactionStatus { transaction_hash };
343 let subscription_id = socket_context.subscribe(rpc_request_id, subscription).await;
344
345 let starknet = self.api.starknet.lock().await;
346
347 if let Some(tx) = starknet.transactions.get(&transaction_hash) {
348 let notification = NotificationData::TransactionStatus(NewTransactionStatus {
349 transaction_hash,
350 status: tx.get_status(),
351 });
352 socket_context.notify(subscription_id, notification).await;
353 } else {
354 tracing::debug!("Tx status subscription: tx not yet received")
355 }
356
357 Ok(())
358 }
359
360 async fn fetch_origin_events(
361 &self,
362 from_block: u64,
363 to_block: u64,
364 address: Option<ContractAddress>,
365 keys_filter: Option<Vec<Vec<Felt>>>,
366 ) -> Result<Vec<EmittedEvent>, ApiError> {
367 const DEFAULT_CHUNK_SIZE: u64 = 1000;
368 let mut continuation_token: Option<String> = None;
369 let mut all_events = Vec::new();
370
371 loop {
373 let events_chunk = self
374 .fetch_origin_events_chunk(
375 from_block,
376 to_block,
377 continuation_token,
378 address,
379 keys_filter.clone(),
380 DEFAULT_CHUNK_SIZE,
381 )
382 .await?;
383
384 all_events.extend(events_chunk.events);
386
387 match events_chunk.continuation_token {
389 Some(token) if token == "0" => break,
390 Some(token) => continuation_token = Some(token),
391 None => break,
392 }
393 }
394
395 Ok(all_events)
396 }
397
398 async fn subscribe_events(
399 &self,
400 maybe_subscription_input: Option<EventsSubscriptionInput>,
401 rpc_request_id: Id,
402 socket_id: SocketId,
403 ) -> Result<(), ApiError> {
404 let address = maybe_subscription_input
405 .as_ref()
406 .and_then(|subscription_input| subscription_input.from_address);
407
408 let starting_block_id = maybe_subscription_input
409 .as_ref()
410 .and_then(|subscription_input| subscription_input.block_id.as_ref().map(BlockId::from))
411 .unwrap_or(BlockId::Tag(BlockTag::Latest));
412
413 let (validated_start_block_number, _, origin_range) =
414 self.get_validated_block_number_range(starting_block_id).await?;
415
416 let keys_filter = maybe_subscription_input
417 .as_ref()
418 .and_then(|subscription_input| subscription_input.keys.clone());
419
420 let finality_status = maybe_subscription_input
421 .and_then(|subscription_input| subscription_input.finality_status)
422 .unwrap_or(TransactionFinalityStatus::AcceptedOnL2);
423
424 let mut sockets = self.api.sockets.lock().await;
425 let socket_context = sockets.get_mut(&socket_id)?;
426 let subscription = Subscription::Events {
427 address,
428 keys_filter: keys_filter.clone(),
429 status_filter: StatusFilter::new(vec![finality_status]),
430 };
431 let subscription_id = socket_context.subscribe(rpc_request_id, subscription).await;
432
433 let origin_events = if let Some((origin_start, origin_end)) = origin_range {
435 Some(
436 self.fetch_origin_events(origin_start, origin_end, address, keys_filter.clone())
437 .await?,
438 )
439 } else {
440 None
441 };
442
443 let local_events = self.api.starknet.lock().await.get_unlimited_events(
445 Some(BlockId::Number(validated_start_block_number)),
446 Some(BlockId::Tag(BlockTag::PreConfirmed)), address,
448 keys_filter,
449 Some(finality_status),
450 )?;
451
452 if let Some(origin_events) = origin_events {
454 for event in origin_events {
455 let notification_data = NotificationData::Event(SubscriptionEmittedEvent {
456 emitted_event: event,
457 finality_status,
458 });
459 socket_context.notify(subscription_id, notification_data).await;
460 }
461 }
462
463 for event in local_events {
465 let notification_data = NotificationData::Event(SubscriptionEmittedEvent {
466 emitted_event: event,
467 finality_status,
468 });
469 socket_context.notify(subscription_id, notification_data).await;
470 }
471
472 Ok(())
473 }
474}