1use std::collections::HashMap;
2use std::sync::Arc;
3
4use axum::extract::ws::{Message, WebSocket};
5use futures::SinkExt;
6use futures::stream::SplitSink;
7use serde::{self, Deserialize, Serialize};
8use starknet_core::starknet::events::check_if_filter_applies_for_event;
9use starknet_rs_core::types::Felt;
10use starknet_types::contract_address::ContractAddress;
11use starknet_types::emitted_event::SubscriptionEmittedEvent;
12use starknet_types::felt::TransactionHash;
13use starknet_types::rpc::block::{BlockHeader, ReorgData};
14use starknet_types::rpc::transaction_receipt::TransactionReceipt;
15use starknet_types::rpc::transactions::{
16 TransactionFinalityStatus, TransactionStatus, TransactionWithHash,
17};
18use tokio::sync::Mutex;
19
20use crate::api::error::ApiError;
21use crate::api::models::SubscriptionId;
22use crate::rpc_core::request::Id;
23
/// Identifier of a registered websocket connection (drawn at random in `SocketCollection::insert`).
pub type SocketId = u64;
25
26#[derive(Default)]
27pub struct SocketCollection {
28 sockets: HashMap<SocketId, SocketContext>,
29}
30
31impl SocketCollection {
32 pub fn get_mut(&mut self, socket_id: &SocketId) -> Result<&mut SocketContext, ApiError> {
33 self.sockets.get_mut(socket_id).ok_or(ApiError::StarknetDevnetError(
34 starknet_core::error::Error::UnexpectedInternalError {
35 msg: format!("Unregistered socket ID: {socket_id}"),
36 },
37 ))
38 }
39
40 pub fn insert(&mut self, socket_writer: Arc<Mutex<SplitSink<WebSocket, Message>>>) -> SocketId {
42 let socket_id = rand::random();
43 self.sockets.insert(socket_id, SocketContext::from_sender(socket_writer));
44 socket_id
45 }
46
47 pub fn remove(&mut self, socket_id: &SocketId) {
48 self.sockets.remove(socket_id);
49 }
50
51 pub async fn notify_subscribers(&self, notifications: &[NotificationData]) {
52 for (_, socket_context) in self.sockets.iter() {
53 for notification in notifications {
54 socket_context.notify_subscribers(notification).await;
55 }
56 }
57 }
58
59 pub fn clear(&mut self) {
60 self.sockets
61 .iter_mut()
62 .for_each(|(_, socket_context)| socket_context.subscriptions.clear());
63 tracing::info!("Websocket memory cleared. No subscribers.");
64 }
65}
66
67#[derive(Debug)]
68pub struct AddressFilter {
69 address_container: Vec<ContractAddress>,
70}
71
72impl AddressFilter {
73 pub(crate) fn new(address_container: Vec<ContractAddress>) -> Self {
74 Self { address_container }
75 }
76 pub(crate) fn passes(&self, address: &ContractAddress) -> bool {
77 self.address_container.is_empty() || self.address_container.contains(address)
78 }
79}
80
81#[derive(Debug, Clone)]
82pub struct StatusFilter {
83 status_container: Vec<TransactionFinalityStatus>,
84}
85
86impl StatusFilter {
87 pub(crate) fn new(status_container: Vec<TransactionFinalityStatus>) -> Self {
88 Self { status_container }
89 }
90
91 pub(crate) fn passes(&self, status: &TransactionFinalityStatus) -> bool {
92 self.status_container.is_empty() || self.status_container.contains(status)
93 }
94}
95
/// A subscription registered by a websocket client. `Subscription::matches` decides which
/// `NotificationData` values are delivered to it.
#[derive(Debug)]
pub enum Subscription {
    /// Receives `NotificationData::NewHeads` (new block headers).
    NewHeads,
    /// Receives status updates of a single transaction.
    TransactionStatus {
        /// Hash of the tracked transaction; compared against each status notification's hash.
        transaction_hash: TransactionHash,
    },
    /// Receives new transactions, filtered by sender address and finality status.
    NewTransactions {
        address_filter: AddressFilter,
        status_filter: StatusFilter,
    },
    /// Receives new transaction receipts, filtered by sender address and finality status.
    NewTransactionReceipts {
        address_filter: AddressFilter,
        status_filter: StatusFilter,
    },
    /// Receives emitted events, filtered by emitter address, keys and finality status.
    Events {
        /// Presumably `None` means "any address"; the semantics are implemented in
        /// `check_if_filter_applies_for_event` — confirm there.
        address: Option<ContractAddress>,
        /// Key filter forwarded to `check_if_filter_applies_for_event`.
        keys_filter: Option<Vec<Vec<Felt>>>,
        status_filter: StatusFilter,
    },
}
116
117impl Subscription {
118 fn confirm(&self, id: SubscriptionId) -> SubscriptionConfirmation {
119 match self {
120 Subscription::NewHeads => SubscriptionConfirmation::NewSubscription(id),
121 Subscription::TransactionStatus { .. } => SubscriptionConfirmation::NewSubscription(id),
122 Subscription::NewTransactions { .. } => SubscriptionConfirmation::NewSubscription(id),
123 Subscription::NewTransactionReceipts { .. } => {
124 SubscriptionConfirmation::NewSubscription(id)
125 }
126 Subscription::Events { .. } => SubscriptionConfirmation::NewSubscription(id),
127 }
128 }
129
130 pub fn matches(&self, notification: &NotificationData) -> bool {
131 match (self, notification) {
132 (Subscription::NewHeads, NotificationData::NewHeads(_)) => true,
133 (
134 Subscription::TransactionStatus { transaction_hash: subscription_hash },
135 NotificationData::TransactionStatus(notification),
136 ) => subscription_hash == ¬ification.transaction_hash,
137 (
138 Subscription::NewTransactions { address_filter, status_filter },
139 NotificationData::NewTransaction(NewTransactionNotification {
140 tx,
141 finality_status,
142 }),
143 ) => match tx.get_sender_address() {
144 Some(address) => {
145 address_filter.passes(&address) && status_filter.passes(finality_status)
146 }
147 None => true,
148 },
149 (
150 Subscription::NewTransactionReceipts { address_filter, status_filter },
151 NotificationData::NewTransactionReceipt(NewTransactionReceiptNotification {
152 tx_receipt,
153 sender_address,
154 }),
155 ) => {
156 status_filter.passes(tx_receipt.finality_status())
157 && match sender_address {
158 Some(address) => address_filter.passes(address),
159 None => true,
160 }
161 }
162 (
163 Subscription::Events { address, keys_filter, status_filter },
164 NotificationData::Event(event_with_finality_status),
165 ) => {
166 let event = (&event_with_finality_status.emitted_event).into();
167 check_if_filter_applies_for_event(address, keys_filter, &event)
168 && status_filter.passes(&event_with_finality_status.finality_status)
169 }
170 (
171 Subscription::NewHeads
172 | Subscription::TransactionStatus { .. }
173 | Subscription::Events { .. }
174 | Subscription::NewTransactions { .. }
175 | Subscription::NewTransactionReceipts { .. },
176 NotificationData::Reorg(_),
177 ) => true, _ => false,
179 }
180 }
181}
182
/// `result` payload of a (un)subscribe response. Untagged, so each variant serializes as its
/// bare inner value.
#[derive(Debug, Serialize)]
#[serde(untagged)]
#[cfg_attr(test, derive(Deserialize))]
pub(crate) enum SubscriptionConfirmation {
    /// ID of the newly created subscription.
    NewSubscription(SubscriptionId),
    /// Outcome of an unsubscribe request (`SocketContext::unsubscribe` sends `true`).
    Unsubscription(bool),
}
190
/// Payload of a transaction-status notification: which transaction, and its current status.
#[derive(Debug, Clone, Serialize)]
#[cfg_attr(test, derive(Deserialize))]
pub struct NewTransactionStatus {
    /// Hash matched against `Subscription::TransactionStatus::transaction_hash`.
    pub transaction_hash: TransactionHash,
    pub status: TransactionStatus,
}
197
/// A transaction hash paired with the transaction's sender address.
///
/// The custom `Serialize` impl emits only `hash`; the custom `Deserialize` impl reads a bare
/// hash and sets `sender_address` to `None`.
#[derive(Debug, Clone)]
pub struct TransactionHashWrapper {
    pub hash: TransactionHash,
    /// Not part of the serialized form.
    pub sender_address: Option<ContractAddress>,
}
203
204impl Serialize for TransactionHashWrapper {
205 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
206 where
207 S: serde::Serializer,
208 {
209 self.hash.serialize(serializer)
210 }
211}
212
213impl<'de> Deserialize<'de> for TransactionHashWrapper {
214 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
215 where
216 D: serde::Deserializer<'de>,
217 {
218 let hash = Felt::deserialize(deserializer)?;
219
220 Ok(TransactionHashWrapper { hash, sender_address: None })
221 }
222}
223
/// Finality statuses excluding any L1 variant, (de)serialized in SCREAMING_SNAKE_CASE.
/// Converts into `TransactionFinalityStatus` via `From`.
/// NOTE(review): presumably used to restrict client-supplied subscription filters — confirm
/// against callers.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum TransactionFinalityStatusWithoutL1 {
    PreConfirmed,
    AcceptedOnL2,
}
230
/// Transaction statuses excluding any L1 variant, (de)serialized in SCREAMING_SNAKE_CASE.
/// Converts into `TransactionFinalityStatus` via `From` (each variant maps to the
/// same-named variant).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum TransactionStatusWithoutL1 {
    Received,
    Candidate,
    PreConfirmed,
    AcceptedOnL2,
}
239
240impl From<TransactionFinalityStatusWithoutL1> for TransactionFinalityStatus {
241 fn from(status: TransactionFinalityStatusWithoutL1) -> Self {
242 match status {
243 TransactionFinalityStatusWithoutL1::PreConfirmed => Self::PreConfirmed,
244 TransactionFinalityStatusWithoutL1::AcceptedOnL2 => Self::AcceptedOnL2,
245 }
246 }
247}
248
249impl From<TransactionStatusWithoutL1> for TransactionFinalityStatus {
250 fn from(status: TransactionStatusWithoutL1) -> Self {
251 match status {
252 TransactionStatusWithoutL1::Received => Self::Received,
253 TransactionStatusWithoutL1::Candidate => Self::Candidate,
254 TransactionStatusWithoutL1::PreConfirmed => Self::PreConfirmed,
255 TransactionStatusWithoutL1::AcceptedOnL2 => Self::AcceptedOnL2,
256 }
257 }
258}
259
/// Payload of a new-transaction notification: the transaction's own fields (flattened into the
/// same JSON object) plus its finality status.
#[derive(Debug, Clone, Serialize)]
#[cfg_attr(test, derive(Deserialize))]
pub struct NewTransactionNotification {
    #[serde(flatten)]
    pub tx: TransactionWithHash,
    /// Checked against `Subscription::NewTransactions::status_filter`.
    pub finality_status: TransactionFinalityStatus,
}
267
/// A transaction receipt together with its sender address.
///
/// The sender address is only used for address filtering in `Subscription::matches`;
/// `SocketContext::notify` sends just `tx_receipt` to the client.
#[derive(Debug, Clone)]
pub struct NewTransactionReceiptNotification {
    pub tx_receipt: TransactionReceipt,
    /// `None` when no sender is known; such receipts bypass the address filter.
    pub sender_address: Option<ContractAddress>,
}
273
/// Data to be fanned out to subscribers. `Subscription::matches` decides which subscriptions
/// receive each value, and `SocketContext::notify` turns it into a `SubscriptionNotification`.
#[derive(Debug, Clone)]
pub enum NotificationData {
    NewHeads(BlockHeader),
    TransactionStatus(NewTransactionStatus),
    NewTransaction(NewTransactionNotification),
    NewTransactionReceipt(NewTransactionReceiptNotification),
    Event(SubscriptionEmittedEvent),
    /// Delivered to every subscription kind.
    Reorg(ReorgData),
}
283
/// A message pushed to a websocket client.
///
/// Untagged: `Confirmation` serializes as `{"id": …, "result": …}` and `Notification` as the
/// inner notification object. The `"jsonrpc"` member is added afterwards by
/// `to_serialized_rpc_response`.
#[derive(Debug, Serialize)]
#[serde(untagged)]
#[cfg_attr(test, derive(Deserialize))]
pub(crate) enum SubscriptionResponse {
    /// Reply to a subscribe/unsubscribe request, echoing the request's ID.
    Confirmation {
        #[serde(rename = "id")]
        rpc_request_id: Id,
        result: SubscriptionConfirmation,
    },
    /// Boxed notification payload.
    Notification(Box<SubscriptionNotification>),
}
295
/// Server-initiated subscription notification.
///
/// Adjacently tagged: the variant's `rename` string becomes the JSON-RPC `"method"`, and the
/// variant's fields (`subscription_id`, `result`) are placed under `"params"`.
#[derive(Serialize, Debug)]
#[cfg_attr(test, derive(Deserialize))]
#[serde(tag = "method", content = "params")]
pub(crate) enum SubscriptionNotification {
    #[serde(rename = "starknet_subscriptionNewHeads")]
    NewHeads { subscription_id: SubscriptionId, result: BlockHeader },
    #[serde(rename = "starknet_subscriptionTransactionStatus")]
    TransactionStatus { subscription_id: SubscriptionId, result: NewTransactionStatus },
    #[serde(rename = "starknet_subscriptionNewTransaction")]
    NewTransaction { subscription_id: SubscriptionId, result: NewTransactionNotification },
    // Carries only the receipt; the sender address used for filtering is not sent.
    #[serde(rename = "starknet_subscriptionNewTransactionReceipts")]
    NewTransactionReceipt { subscription_id: SubscriptionId, result: TransactionReceipt },
    #[serde(rename = "starknet_subscriptionEvents")]
    Event { subscription_id: SubscriptionId, result: SubscriptionEmittedEvent },
    #[serde(rename = "starknet_subscriptionReorg")]
    Reorg { subscription_id: SubscriptionId, result: ReorgData },
}
313
314impl SubscriptionResponse {
315 fn to_serialized_rpc_response(&self) -> serde_json::Value {
316 let mut resp = serde_json::json!(self);
317
318 resp["jsonrpc"] = "2.0".into();
319 resp
320 }
321}
322
/// Per-connection state: the write half of the websocket and the subscriptions registered on it.
pub struct SocketContext {
    /// Write half of the split websocket, shared via `Arc` behind a tokio (async) mutex.
    sender: Arc<Mutex<SplitSink<WebSocket, Message>>>,
    /// Active subscriptions of this socket, keyed by their randomly drawn ID.
    subscriptions: HashMap<SubscriptionId, Subscription>,
}
328
329impl SocketContext {
330 pub fn from_sender(sender: Arc<Mutex<SplitSink<WebSocket, Message>>>) -> Self {
331 Self { sender, subscriptions: HashMap::new() }
332 }
333
334 async fn send_serialized(&self, resp: String) {
335 if let Err(e) = self.sender.lock().await.send(Message::Text(resp.into())).await {
336 tracing::error!("Failed writing to socket: {}", e.to_string());
337 }
338 }
339
340 pub async fn send_rpc_response(&self, result: serde_json::Value, id: Id) {
341 let resp_serialized = serde_json::json!({
342 "jsonrpc": "2.0",
343 "id": id,
344 "result": result,
345 })
346 .to_string();
347
348 tracing::trace!(target: "ws.json-rpc-api", response = %resp_serialized, "JSON-RPC response");
349 self.send_serialized(resp_serialized).await;
350 }
351
352 async fn send_subscription_response(&self, resp: SubscriptionResponse) {
353 let resp_serialized = resp.to_serialized_rpc_response().to_string();
354
355 tracing::trace!(target: "ws.subscriptions", response = %resp_serialized, "subscription response");
356 self.send_serialized(resp_serialized).await;
357 }
358
359 pub async fn subscribe(
360 &mut self,
361 rpc_request_id: Id,
362 subscription: Subscription,
363 ) -> SubscriptionId {
364 loop {
365 let subscription_id: SubscriptionId = rand::random::<u64>().into();
366 if self.subscriptions.contains_key(&subscription_id) {
367 continue;
368 }
369
370 let confirmation = subscription.confirm(subscription_id);
371 self.subscriptions.insert(subscription_id, subscription);
372
373 self.send_subscription_response(SubscriptionResponse::Confirmation {
374 rpc_request_id,
375 result: confirmation,
376 })
377 .await;
378
379 return subscription_id;
380 }
381 }
382
383 pub async fn unsubscribe(
384 &mut self,
385 rpc_request_id: Id,
386 subscription_id: SubscriptionId,
387 ) -> Result<(), ApiError> {
388 self.subscriptions.remove(&subscription_id).ok_or(ApiError::InvalidSubscriptionId)?;
389 self.send_subscription_response(SubscriptionResponse::Confirmation {
390 rpc_request_id,
391 result: SubscriptionConfirmation::Unsubscription(true),
392 })
393 .await;
394 Ok(())
395 }
396
397 pub async fn notify(&self, subscription_id: SubscriptionId, data: NotificationData) {
398 let notification_data = match data {
399 NotificationData::NewHeads(block_header) => {
400 SubscriptionNotification::NewHeads { subscription_id, result: block_header }
401 }
402
403 NotificationData::TransactionStatus(new_transaction_status) => {
404 SubscriptionNotification::TransactionStatus {
405 subscription_id,
406 result: new_transaction_status,
407 }
408 }
409
410 NotificationData::NewTransaction(tx_notification) => {
411 SubscriptionNotification::NewTransaction {
412 subscription_id,
413 result: tx_notification,
414 }
415 }
416
417 NotificationData::NewTransactionReceipt(tx_receipt_notification) => {
418 SubscriptionNotification::NewTransactionReceipt {
419 subscription_id,
420 result: tx_receipt_notification.tx_receipt,
421 }
422 }
423
424 NotificationData::Event(emitted_event) => {
425 SubscriptionNotification::Event { subscription_id, result: emitted_event }
426 }
427
428 NotificationData::Reorg(reorg_data) => {
429 SubscriptionNotification::Reorg { subscription_id, result: reorg_data }
430 }
431 };
432
433 self.send_subscription_response(SubscriptionResponse::Notification(Box::new(
434 notification_data,
435 )))
436 .await;
437 }
438
439 pub async fn notify_subscribers(&self, notification: &NotificationData) {
440 for (subscription_id, subscription) in self.subscriptions.iter() {
441 if subscription.matches(notification) {
442 self.notify(*subscription_id, notification.clone()).await;
443 }
444 }
445 }
446}