1use std::{collections::HashSet, marker::PhantomData, sync::Arc, time::Instant};
17
18use futures::{SinkExt, StreamExt};
19use serde::{Deserialize, de::DeserializeOwned};
20use simulator_api::{
21 subscribe_config::{Compression, SubscribeConfig},
22 ws_compression::WsStreamDecompressor,
23};
24use solana_transaction_status::EncodedConfirmedTransactionWithStatusMeta;
25use tokio::{
26 net::TcpStream,
27 sync::{mpsc, watch},
28 task::JoinHandle,
29};
30use tokio_tungstenite::{
31 MaybeTlsStream, WebSocketStream, connect_async,
32 tungstenite::{Message, client::IntoClientRequest},
33};
34use tokio_util::sync::CancellationToken;
35use tracing::{debug, warn};
36
37use super::{
38 CONNECT_TIMEOUT, ConnectionStatus, HANDSHAKE_RESPONSE_TIMEOUT, KEEPALIVE_INTERVAL,
39 KEEPALIVE_MISS_DEADLINE, RECONNECT_UNGATED_ATTEMPTS, RECONNECT_UPTIME_RESET, ReconnectBudget,
40 ReconnectCoordinator, cancellable_sleep,
41};
42use crate::{
43 error::err_chain,
44 subscriptions::{AccountDiffNotification, ActionResultNotification},
45 urls::http_to_ws_url,
46};
47
/// Handle to a spawned subscription manager task.
pub struct SubscriptionHandle {
    /// Connection status published by the task (`Down`, `Up`, or `Failed(..)`).
    pub status: watch::Receiver<ConnectionStatus>,
    /// Decoded notifications, in receive order (bounded channel).
    pub notifications: mpsc::Receiver<SubscriptionNotification>,
    /// The background task. It finishes on cancellation, when the notification
    /// receiver is dropped, when every subscription completes, or when the
    /// reconnect budget is exhausted.
    pub join: JoinHandle<()>,
}
54
/// A decoded notification from one of the supported subscription kinds.
#[derive(Debug)]
pub enum SubscriptionNotification {
    /// Payload of a `transactionNotification` (stored boxed).
    Transaction(Box<EncodedConfirmedTransactionWithStatusMeta>),
    /// Payload of an `accountDiffNotification`.
    AccountDiff(AccountDiffNotification),
    /// Payload of an `actionNotification`.
    ActionResult(ActionResultNotification),
}
61
/// Static description of one subscription flavour: the JSON-RPC methods it
/// uses, how its subscribe params are built, and how its notification payloads
/// are surfaced to callers.
trait SubKind: Send + Sync + 'static {
    /// Payload deserialized from `params.result` of a notification frame.
    type Notification: DeserializeOwned + Send + 'static;
    /// Short name used as the `kind` field in log events.
    const LABEL: &'static str;
    /// JSON-RPC method sent to open one subscription.
    const SUBSCRIBE_METHOD: &'static str;
    /// JSON-RPC `method` value that identifies this kind's notifications.
    const NOTIFICATION_METHOD: &'static str;
    /// Builds the `params` value for one subscribe request; `subscribe`
    /// applies a `SubscribeConfig` on top of it before sending.
    fn subscribe_params(program_id: &str) -> serde_json::Value;
    /// Wraps a decoded payload in the public notification enum.
    fn into_notification(notification: Self::Notification) -> SubscriptionNotification;
    /// Slot of the payload; used to track the replay point for resubscribes.
    fn slot_of(notification: &Self::Notification) -> u64;
}
75
76struct AccountDiff;
77impl SubKind for AccountDiff {
78 type Notification = AccountDiffNotification;
79 const LABEL: &'static str = "account-diff";
80 const SUBSCRIBE_METHOD: &'static str = "accountDiffSubscribe";
81 const NOTIFICATION_METHOD: &'static str = "accountDiffNotification";
82 fn subscribe_params(program_id: &str) -> serde_json::Value {
83 serde_json::json!([program_id, {"address_type": "program"}])
84 }
85 fn into_notification(notification: Self::Notification) -> SubscriptionNotification {
86 SubscriptionNotification::AccountDiff(notification)
87 }
88 fn slot_of(notification: &Self::Notification) -> u64 {
89 notification.context.slot
90 }
91}
92
93struct Transaction;
94impl SubKind for Transaction {
95 type Notification = EncodedConfirmedTransactionWithStatusMeta;
99 const LABEL: &'static str = "transaction";
100 const SUBSCRIBE_METHOD: &'static str = "transactionSubscribe";
101 const NOTIFICATION_METHOD: &'static str = "transactionNotification";
102 fn subscribe_params(program_id: &str) -> serde_json::Value {
103 serde_json::json!([{"mentions": [program_id]}, {"commitment": "confirmed"}])
104 }
105 fn into_notification(notification: Self::Notification) -> SubscriptionNotification {
106 SubscriptionNotification::Transaction(Box::new(notification))
107 }
108 fn slot_of(notification: &Self::Notification) -> u64 {
109 notification.slot
110 }
111}
112
113struct ActionResult;
114impl SubKind for ActionResult {
115 type Notification = ActionResultNotification;
116 const LABEL: &'static str = "action";
117 const SUBSCRIBE_METHOD: &'static str = "actionSubscribe";
118 const NOTIFICATION_METHOD: &'static str = "actionNotification";
119 fn subscribe_params(_program_id: &str) -> serde_json::Value {
120 serde_json::json!([serde_json::Value::Null, {}])
123 }
124 fn into_notification(notification: Self::Notification) -> SubscriptionNotification {
125 SubscriptionNotification::ActionResult(notification)
126 }
127 fn slot_of(notification: &Self::Notification) -> u64 {
128 notification.context.slot
129 }
130}
131
132pub fn spawn_transaction_subscription_manager(
133 rpc_endpoint: String,
134 program_ids: Vec<String>,
135 cancel: CancellationToken,
136 coordinator: Option<Arc<ReconnectCoordinator>>,
137) -> SubscriptionHandle {
138 spawn_subscription_manager::<Transaction>(rpc_endpoint, program_ids, cancel, coordinator)
139}
140
141pub fn spawn_action_subscription_manager(
143 rpc_endpoint: String,
144 cancel: CancellationToken,
145 coordinator: Option<Arc<ReconnectCoordinator>>,
146) -> SubscriptionHandle {
147 spawn_subscription_manager::<ActionResult>(
148 rpc_endpoint,
149 vec![String::new()],
150 cancel,
151 coordinator,
152 )
153}
154
155pub fn spawn_account_diff_subscription_manager(
156 rpc_endpoint: String,
157 program_ids: Vec<String>,
158 cancel: CancellationToken,
159 coordinator: Option<Arc<ReconnectCoordinator>>,
160) -> SubscriptionHandle {
161 spawn_subscription_manager::<AccountDiff>(rpc_endpoint, program_ids, cancel, coordinator)
162}
163
164fn spawn_subscription_manager<K>(
165 rpc_endpoint: String,
166 program_ids: Vec<String>,
167 cancel: CancellationToken,
168 coordinator: Option<Arc<ReconnectCoordinator>>,
169) -> SubscriptionHandle
170where
171 K: SubKind,
172{
173 let (notifications_tx, notifications_rx) = mpsc::channel(1024);
174 let (status_tx, status_rx) = watch::channel(ConnectionStatus::Down);
175 let task = Task::<K> {
176 rpc_endpoint,
177 program_ids,
178 notifications_tx,
179 status_tx,
180 cancel,
181 coordinator,
182 _marker: PhantomData,
183 };
184 let join = tokio::spawn(task.run());
185 SubscriptionHandle {
186 status: status_rx,
187 notifications: notifications_rx,
188 join,
189 }
190}
191
/// Client websocket stream (plain TCP or TLS) used for subscriptions.
type Ws = WebSocketStream<MaybeTlsStream<TcpStream>>;
/// Subscription ids returned in the subscribe acks of one connection.
type Subs = HashSet<u64>;
194
/// State owned by one subscription manager task.
struct Task<K: SubKind> {
    /// RPC endpoint; passed through `http_to_ws_url` on every connect.
    rpc_endpoint: String,
    /// One subscribe request is sent per entry, in order.
    program_ids: Vec<String>,
    /// Sender for decoded notifications (the handle owns the receiver).
    notifications_tx: mpsc::Sender<SubscriptionNotification>,
    /// Publishes connection status transitions.
    status_tx: watch::Sender<ConnectionStatus>,
    /// Stops the task; checked at the top of each reconnect iteration and
    /// raced against connecting, backoff sleeps, and streaming.
    cancel: CancellationToken,
    /// Optional shared gate. Once `RECONNECT_UNGATED_ATTEMPTS` attempts are
    /// used, reconnects wait for a slot from it; while streaming, the task
    /// holds the value returned by `enter()`.
    coordinator: Option<Arc<ReconnectCoordinator>>,
    /// Binds `K` without storing a value; `fn() -> K` keeps `Task` `Send` and
    /// `Sync` independently of `K`.
    _marker: PhantomData<fn() -> K>,
}
210
211impl<K: SubKind> Task<K> {
212 async fn run(self) {
213 let mut budget = ReconnectBudget::new();
214 let mut replay_from_slot: Option<u64> = None;
218
219 loop {
220 if self.cancel.is_cancelled() {
221 break;
222 }
223 publish(&self.status_tx, ConnectionStatus::Down);
224
225 let reconnect_slot = match &self.coordinator {
235 Some(coord) if budget.attempt() >= RECONNECT_UNGATED_ATTEMPTS => {
236 let parked_at = Instant::now();
237 let Some(slot) = coord.reconnect_slot(&self.cancel).await else {
238 break; };
240 budget.discount_parked(parked_at.elapsed());
242 Some(slot)
243 }
244 _ => None,
245 };
246
247 let connect_result = tokio::select! {
250 biased;
251 _ = self.cancel.cancelled() => None,
252 result = async {
253 let ws = connect_ws(&self.rpc_endpoint).await?;
254 subscribe::<K>(ws, &self.program_ids, replay_from_slot).await
255 } => Some(result),
256 };
257 let Some(connect_result) = connect_result else {
258 break;
259 };
260
261 let Subscribed { ws, subs, pending } = match connect_result {
262 Ok(v) => v,
263 Err(why) => {
264 drop(reconnect_slot);
265 if retry_or_fail::<K>(
266 "connect",
267 why,
268 &mut budget,
269 &self.cancel,
270 &self.status_tx,
271 )
272 .await
273 {
274 continue;
275 }
276 break;
277 }
278 };
279
280 let streaming = self.coordinator.as_ref().map(|coord| coord.enter());
286 drop(reconnect_slot);
287 publish(&self.status_tx, ConnectionStatus::Up);
288 let connected_at = Instant::now();
289
290 let exit = message_loop::<K>(
291 ws,
292 subs,
293 pending,
294 &self.notifications_tx,
295 &self.cancel,
296 &mut replay_from_slot,
297 )
298 .await;
299
300 drop(streaming);
301
302 match exit {
303 MessageLoopExit::Cancelled | MessageLoopExit::Completed => break,
304 MessageLoopExit::ConnectionLost(why) => {
305 if connected_at.elapsed() >= RECONNECT_UPTIME_RESET {
306 budget.reset();
307 }
308 if retry_or_fail::<K>(
309 "connection lost",
310 why,
311 &mut budget,
312 &self.cancel,
313 &self.status_tx,
314 )
315 .await
316 {
317 continue;
318 }
319 break;
320 }
321 }
322 }
323 }
324}
325
/// Why `message_loop` returned.
enum MessageLoopExit {
    /// Cancellation was requested, or the notification receiver was dropped.
    Cancelled,
    /// The socket failed, closed, or went silent; the string is the reason.
    /// The caller may reconnect.
    ConnectionLost(String),
    /// Every subscription on the connection reported `subscriptionComplete`.
    Completed,
}
333
334async fn message_loop<K: SubKind>(
335 mut ws: Ws,
336 subs: Subs,
337 pending: Vec<Message>,
338 notifications_tx: &mpsc::Sender<SubscriptionNotification>,
339 cancel: &CancellationToken,
340 replay_from_slot: &mut Option<u64>,
341) -> MessageLoopExit {
342 let mut ping_timer = tokio::time::interval(KEEPALIVE_INTERVAL);
343 ping_timer.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
344 let mut last_inbound = Instant::now();
345 let mut completed: HashSet<u64> = HashSet::new();
348 let mut decompressor = match WsStreamDecompressor::new() {
354 Ok(decompressor) => decompressor,
355 Err(e) => return MessageLoopExit::ConnectionLost(format!("zstd decoder init: {e}")),
356 };
357
358 for msg in pending {
361 let outcome = process_data_frame::<K>(
362 msg,
363 &subs,
364 notifications_tx,
365 &mut completed,
366 replay_from_slot,
367 &mut decompressor,
368 )
369 .await;
370 if let Some(exit) = frame_outcome_to_exit(outcome) {
371 return exit;
372 }
373 }
374
375 loop {
376 tokio::select! {
377 biased;
378 _ = cancel.cancelled() => return MessageLoopExit::Cancelled,
379
380 _ = ping_timer.tick() => {
381 if last_inbound.elapsed() > KEEPALIVE_MISS_DEADLINE {
382 return MessageLoopExit::ConnectionLost(format!(
383 "no traffic for {:?}", last_inbound.elapsed()
384 ));
385 }
386 if let Err(e) = ws.send(Message::Ping(vec![])).await {
387 return MessageLoopExit::ConnectionLost(format!("ping send: {}", err_chain(&e)));
388 }
389 }
390
391 msg = ws.next() => {
392 last_inbound = Instant::now();
393 match msg {
394 Some(Ok(frame @ (Message::Text(_) | Message::Binary(_)))) => {
395 let outcome = process_data_frame::<K>(
396 frame,
397 &subs,
398 notifications_tx,
399 &mut completed,
400 replay_from_slot,
401 &mut decompressor,
402 )
403 .await;
404 if let Some(exit) = frame_outcome_to_exit(outcome) {
405 return exit;
406 }
407 }
408 Some(Ok(Message::Pong(_))) | Some(Ok(Message::Ping(_))) => {}
409 Some(Ok(Message::Close(frame))) => {
410 return MessageLoopExit::ConnectionLost(format!("remote close: {frame:?}"));
411 }
412 Some(Ok(Message::Frame(_))) => {}
413 Some(Err(e)) => return MessageLoopExit::ConnectionLost(format!("ws read: {}", err_chain(&e))),
414 None => return MessageLoopExit::ConnectionLost("ws stream ended".into()),
415 }
416 }
417 }
418 }
419}
420
421fn frame_outcome_to_exit(outcome: Result<TextOutcome, String>) -> Option<MessageLoopExit> {
423 match outcome {
424 Ok(TextOutcome::Continue) => None,
425 Ok(TextOutcome::AllComplete) => Some(MessageLoopExit::Completed),
426 Ok(TextOutcome::ChannelClosed) => Some(MessageLoopExit::Cancelled),
427 Err(why) => Some(MessageLoopExit::ConnectionLost(why)),
428 }
429}
430
431async fn process_data_frame<K: SubKind>(
436 msg: Message,
437 subs: &Subs,
438 notifications_tx: &mpsc::Sender<SubscriptionNotification>,
439 completed: &mut HashSet<u64>,
440 replay_from_slot: &mut Option<u64>,
441 decompressor: &mut WsStreamDecompressor,
442) -> Result<TextOutcome, String> {
443 match msg {
444 Message::Text(t) => {
445 Ok(handle_text::<K>(&t, subs, notifications_tx, completed, replay_from_slot).await)
446 }
447 Message::Binary(b) => {
449 let decoded = decompressor
450 .decompress(&b)
451 .map_err(|e| format!("ws decompress: {e}"))?;
452 match std::str::from_utf8(&decoded) {
453 Ok(t) => {
454 Ok(
455 handle_text::<K>(t, subs, notifications_tx, completed, replay_from_slot)
456 .await,
457 )
458 }
459 Err(_) => Ok(TextOutcome::Continue),
460 }
461 }
462 _ => Ok(TextOutcome::Continue),
463 }
464}
465
466async fn retry_or_fail<K: SubKind>(
469 phase: &'static str,
470 reason: String,
471 budget: &mut ReconnectBudget,
472 cancel: &CancellationToken,
473 status_tx: &watch::Sender<ConnectionStatus>,
474) -> bool {
475 if let Some(delay) = budget.next_backoff() {
476 warn!(
477 kind = K::LABEL,
478 attempt = budget.attempt(),
479 reason = %reason,
480 ?delay,
481 "subscription {phase}, retrying",
482 );
483 cancellable_sleep(delay, cancel).await
484 } else {
485 publish(
486 status_tx,
487 ConnectionStatus::Failed(format!("{phase}: {reason}")),
488 );
489 false
490 }
491}
492
493fn publish(tx: &watch::Sender<ConnectionStatus>, status: ConnectionStatus) {
494 tx.send_if_modified(|current| {
495 if *current == status {
496 false
497 } else {
498 *current = status;
499 true
500 }
501 });
502}
503
504async fn connect_ws(rpc_endpoint: &str) -> Result<Ws, String> {
505 let ws_url = http_to_ws_url(rpc_endpoint).map_err(|e| err_chain(&e))?;
506 let request = ws_url
507 .into_client_request()
508 .map_err(|e| format!("build request: {}", err_chain(&e)))?;
509
510 let connect = tokio::time::timeout(CONNECT_TIMEOUT, connect_async(request))
511 .await
512 .map_err(|_| format!("connect timeout after {CONNECT_TIMEOUT:?}"))?
513 .map_err(|e| format!("connect: {}", err_chain(&e)))?;
514 Ok(connect.0)
515}
516
/// Outcome of a successful subscribe handshake.
struct Subscribed {
    /// The connected socket, ready for streaming.
    ws: Ws,
    /// Subscription ids acknowledged by the server.
    subs: Subs,
    /// Text/binary frames that arrived while waiting for acks, in arrival
    /// order; `message_loop` processes them before reading new frames.
    pending: Vec<Message>,
}
526
527async fn subscribe<K: SubKind>(
528 mut ws: Ws,
529 program_ids: &[String],
530 replay_from_slot: Option<u64>,
531) -> Result<Subscribed, String> {
532 let mut subs = Subs::new();
533 let mut pending = Vec::new();
536 for (i, program_id) in program_ids.iter().enumerate() {
537 let id = (i + 1) as u64;
538 let mut params = K::subscribe_params(program_id);
539 SubscribeConfig {
544 replay_from_slot: replay_from_slot.map(|slot| slot as i64),
545 compression: Some(Compression::Zstd),
546 }
547 .apply_to(&mut params);
548 let req = serde_json::json!({
549 "jsonrpc": "2.0",
550 "id": id,
551 "method": K::SUBSCRIBE_METHOD,
552 "params": params,
553 });
554 ws.send(Message::Text(req.to_string()))
555 .await
556 .map_err(|e| format!("subscribe send: {}", err_chain(&e)))?;
557 subs.insert(read_sub_ack(&mut ws, id, &mut pending).await?);
558 }
559 debug!(
560 kind = K::LABEL,
561 count = subs.len(),
562 "subscriptions established"
563 );
564 Ok(Subscribed { ws, subs, pending })
565}
566
/// JSON-RPC response to a subscribe request.
#[derive(Deserialize)]
struct SubAck {
    /// Request id echoed back by the server.
    id: u64,
    /// Subscription id on success.
    result: Option<u64>,
    /// Error object when the request was rejected.
    #[serde(default)]
    error: Option<serde_json::Value>,
}
574
575async fn read_sub_ack(
581 ws: &mut Ws,
582 expected_id: u64,
583 pending: &mut Vec<Message>,
584) -> Result<u64, String> {
585 let deadline = tokio::time::Instant::now() + HANDSHAKE_RESPONSE_TIMEOUT;
586 loop {
587 let msg = tokio::time::timeout_at(deadline, ws.next())
588 .await
589 .map_err(|_| format!("subscribe ack timeout after {HANDSHAKE_RESPONSE_TIMEOUT:?}"))?;
590
591 let Some(msg) = msg else {
592 return Err("ws ended during subscribe".into());
593 };
594 let msg = msg.map_err(|e| format!("ws read: {}", err_chain(&e)))?;
595
596 if let Message::Text(t) = &msg
598 && let Ok(ack) = serde_json::from_str::<SubAck>(t)
599 {
600 if ack.id != expected_id {
601 continue;
602 }
603 if let Some(err) = ack.error {
604 return Err(format!("subscribe rejected: {err}"));
605 }
606 return ack
607 .result
608 .ok_or_else(|| "subscribe ack missing result".to_string());
609 }
610 if matches!(msg, Message::Text(_) | Message::Binary(_)) {
611 pending.push(msg);
612 }
613 }
614}
615
/// Result of handling one decoded text payload.
enum TextOutcome {
    /// Keep reading.
    Continue,
    /// Every subscription on this connection has reported `subscriptionComplete`.
    AllComplete,
    /// The notification receiver was dropped, so nothing more can be delivered.
    ChannelClosed,
}
625
626async fn handle_text<K: SubKind>(
630 text: &str,
631 subs: &Subs,
632 notifications_tx: &mpsc::Sender<SubscriptionNotification>,
633 completed: &mut HashSet<u64>,
634 replay_from_slot: &mut Option<u64>,
635) -> TextOutcome {
636 if let Some(n) = parse_notification::<K>(text, subs) {
639 let slot = K::slot_of(&n);
648 if notifications_tx
649 .send(K::into_notification(n))
650 .await
651 .is_err()
652 {
653 return TextOutcome::ChannelClosed;
654 }
655 *replay_from_slot = Some(replay_from_slot.map_or(slot, |prev| prev.max(slot)));
656 return TextOutcome::Continue;
657 }
658
659 if let Some(sub_id) = parse_completion(text)
661 && subs.contains(&sub_id)
662 {
663 completed.insert(sub_id);
664 if subs.iter().all(|id| completed.contains(id)) {
665 return TextOutcome::AllComplete;
666 }
667 }
668 TextOutcome::Continue
669}
670
671fn parse_completion(text: &str) -> Option<u64> {
674 #[derive(Deserialize)]
675 struct Msg {
676 method: String,
677 params: Params,
678 }
679 #[derive(Deserialize)]
680 struct Params {
681 subscription: u64,
682 }
683
684 let msg: Msg = serde_json::from_str(text).ok()?;
685 (msg.method == "subscriptionComplete").then_some(msg.params.subscription)
686}
687
688fn parse_notification<K: SubKind>(text: &str, subs: &Subs) -> Option<K::Notification> {
689 #[derive(Deserialize)]
690 #[serde(bound = "T: DeserializeOwned")]
691 struct Msg<T> {
692 method: String,
693 params: Params<T>,
694 }
695 #[derive(Deserialize)]
696 #[serde(bound = "T: DeserializeOwned")]
697 struct Params<T> {
698 subscription: u64,
699 result: T,
700 }
701
702 let msg: Msg<K::Notification> = serde_json::from_str(text).ok()?;
703 if msg.method != K::NOTIFICATION_METHOD {
704 return None;
705 }
706 if !subs.contains(&msg.params.subscription) {
707 return None;
708 }
709 Some(msg.params.result)
710}
711
#[cfg(test)]
mod tests {
    use super::parse_completion;

    #[test]
    fn parse_completion_extracts_subscription_id() {
        let completion =
            r#"{"jsonrpc":"2.0","method":"subscriptionComplete","params":{"subscription":7}}"#;
        let parsed = parse_completion(completion);
        assert_eq!(parsed, Some(7));
    }

    #[test]
    fn parse_completion_ignores_other_messages() {
        let unrelated = [
            r#"{"jsonrpc":"2.0","method":"transactionNotification","params":{"subscription":7,"result":{}}}"#,
            "not json",
        ];
        for text in unrelated {
            assert_eq!(parse_completion(text), None, "unexpected completion for {text}");
        }
    }
}
729}