1use std::{collections::HashSet, marker::PhantomData, sync::Arc, time::Instant};
17
18use futures::{SinkExt, StreamExt};
19use serde::{Deserialize, de::DeserializeOwned};
20use simulator_api::{
21 subscribe_config::{Compression, SubscribeConfig},
22 ws_compression::WsStreamDecompressor,
23};
24use solana_transaction_status::EncodedConfirmedTransactionWithStatusMeta;
25use tokio::{
26 net::TcpStream,
27 sync::{mpsc, watch},
28 task::JoinHandle,
29};
30use tokio_tungstenite::{
31 MaybeTlsStream, WebSocketStream, connect_async,
32 tungstenite::{Message, client::IntoClientRequest},
33};
34use tokio_util::sync::CancellationToken;
35use tracing::{debug, warn};
36
37use super::{
38 CONNECT_TIMEOUT, ConnectionStatus, HANDSHAKE_RESPONSE_TIMEOUT, KEEPALIVE_INTERVAL,
39 KEEPALIVE_MISS_DEADLINE, RECONNECT_UNGATED_ATTEMPTS, RECONNECT_UPTIME_RESET, ReconnectBudget,
40 ReconnectCoordinator, cancellable_sleep,
41};
42use crate::{error::err_chain, subscriptions::AccountDiffNotification, urls::http_to_ws_url};
43
44pub struct SubscriptionHandle {
46 pub status: watch::Receiver<ConnectionStatus>,
47 pub notifications: mpsc::Receiver<SubscriptionNotification>,
48 pub join: JoinHandle<()>,
49}
50
51#[derive(Debug)]
52pub enum SubscriptionNotification {
53 Transaction(Box<EncodedConfirmedTransactionWithStatusMeta>),
54 AccountDiff(AccountDiffNotification),
55}
56
57trait SubKind: Send + Sync + 'static {
59 type Notification: DeserializeOwned + Send + 'static;
60 const LABEL: &'static str;
61 const SUBSCRIBE_METHOD: &'static str;
62 const NOTIFICATION_METHOD: &'static str;
63 fn subscribe_params(program_id: &str) -> serde_json::Value;
64 fn into_notification(notification: Self::Notification) -> SubscriptionNotification;
65 fn slot_of(notification: &Self::Notification) -> u64;
69}
70
71struct AccountDiff;
72impl SubKind for AccountDiff {
73 type Notification = AccountDiffNotification;
74 const LABEL: &'static str = "account-diff";
75 const SUBSCRIBE_METHOD: &'static str = "accountDiffSubscribe";
76 const NOTIFICATION_METHOD: &'static str = "accountDiffNotification";
77 fn subscribe_params(program_id: &str) -> serde_json::Value {
78 serde_json::json!([program_id, {"address_type": "program"}])
79 }
80 fn into_notification(notification: Self::Notification) -> SubscriptionNotification {
81 SubscriptionNotification::AccountDiff(notification)
82 }
83 fn slot_of(notification: &Self::Notification) -> u64 {
84 notification.context.slot
85 }
86}
87
88struct Transaction;
89impl SubKind for Transaction {
90 type Notification = EncodedConfirmedTransactionWithStatusMeta;
94 const LABEL: &'static str = "transaction";
95 const SUBSCRIBE_METHOD: &'static str = "transactionSubscribe";
96 const NOTIFICATION_METHOD: &'static str = "transactionNotification";
97 fn subscribe_params(program_id: &str) -> serde_json::Value {
98 serde_json::json!([{"mentions": [program_id]}, {"commitment": "confirmed"}])
99 }
100 fn into_notification(notification: Self::Notification) -> SubscriptionNotification {
101 SubscriptionNotification::Transaction(Box::new(notification))
102 }
103 fn slot_of(notification: &Self::Notification) -> u64 {
104 notification.slot
105 }
106}
107
108pub fn spawn_transaction_subscription_manager(
109 rpc_endpoint: String,
110 program_ids: Vec<String>,
111 cancel: CancellationToken,
112 coordinator: Option<Arc<ReconnectCoordinator>>,
113) -> SubscriptionHandle {
114 spawn_subscription_manager::<Transaction>(rpc_endpoint, program_ids, cancel, coordinator)
115}
116
117pub fn spawn_account_diff_subscription_manager(
118 rpc_endpoint: String,
119 program_ids: Vec<String>,
120 cancel: CancellationToken,
121 coordinator: Option<Arc<ReconnectCoordinator>>,
122) -> SubscriptionHandle {
123 spawn_subscription_manager::<AccountDiff>(rpc_endpoint, program_ids, cancel, coordinator)
124}
125
126fn spawn_subscription_manager<K>(
127 rpc_endpoint: String,
128 program_ids: Vec<String>,
129 cancel: CancellationToken,
130 coordinator: Option<Arc<ReconnectCoordinator>>,
131) -> SubscriptionHandle
132where
133 K: SubKind,
134{
135 let (notifications_tx, notifications_rx) = mpsc::channel(1024);
136 let (status_tx, status_rx) = watch::channel(ConnectionStatus::Down);
137 let task = Task::<K> {
138 rpc_endpoint,
139 program_ids,
140 notifications_tx,
141 status_tx,
142 cancel,
143 coordinator,
144 _marker: PhantomData,
145 };
146 let join = tokio::spawn(task.run());
147 SubscriptionHandle {
148 status: status_rx,
149 notifications: notifications_rx,
150 join,
151 }
152}
153
154type Ws = WebSocketStream<MaybeTlsStream<TcpStream>>;
155type Subs = HashSet<u64>;
156
157struct Task<K: SubKind> {
158 rpc_endpoint: String,
159 program_ids: Vec<String>,
160 notifications_tx: mpsc::Sender<SubscriptionNotification>,
161 status_tx: watch::Sender<ConnectionStatus>,
162 cancel: CancellationToken,
165 coordinator: Option<Arc<ReconnectCoordinator>>,
170 _marker: PhantomData<fn() -> K>,
171}
172
173impl<K: SubKind> Task<K> {
174 async fn run(self) {
175 let mut budget = ReconnectBudget::new();
176 let mut replay_from_slot: Option<u64> = None;
180
181 loop {
182 if self.cancel.is_cancelled() {
183 break;
184 }
185 publish(&self.status_tx, ConnectionStatus::Down);
186
187 let reconnect_slot = match &self.coordinator {
197 Some(coord) if budget.attempt() >= RECONNECT_UNGATED_ATTEMPTS => {
198 let parked_at = Instant::now();
199 let Some(slot) = coord.reconnect_slot(&self.cancel).await else {
200 break; };
202 budget.discount_parked(parked_at.elapsed());
204 Some(slot)
205 }
206 _ => None,
207 };
208
209 let connect_result = tokio::select! {
212 biased;
213 _ = self.cancel.cancelled() => None,
214 result = async {
215 let ws = connect_ws(&self.rpc_endpoint).await?;
216 subscribe::<K>(ws, &self.program_ids, replay_from_slot).await
217 } => Some(result),
218 };
219 let Some(connect_result) = connect_result else {
220 break;
221 };
222
223 let Subscribed { ws, subs, pending } = match connect_result {
224 Ok(v) => v,
225 Err(why) => {
226 drop(reconnect_slot);
227 if retry_or_fail::<K>(
228 "connect",
229 why,
230 &mut budget,
231 &self.cancel,
232 &self.status_tx,
233 )
234 .await
235 {
236 continue;
237 }
238 break;
239 }
240 };
241
242 let streaming = self.coordinator.as_ref().map(|coord| coord.enter());
248 drop(reconnect_slot);
249 publish(&self.status_tx, ConnectionStatus::Up);
250 let connected_at = Instant::now();
251
252 let exit = message_loop::<K>(
253 ws,
254 subs,
255 pending,
256 &self.notifications_tx,
257 &self.cancel,
258 &mut replay_from_slot,
259 )
260 .await;
261
262 drop(streaming);
263
264 match exit {
265 MessageLoopExit::Cancelled | MessageLoopExit::Completed => break,
266 MessageLoopExit::ConnectionLost(why) => {
267 if connected_at.elapsed() >= RECONNECT_UPTIME_RESET {
268 budget.reset();
269 }
270 if retry_or_fail::<K>(
271 "connection lost",
272 why,
273 &mut budget,
274 &self.cancel,
275 &self.status_tx,
276 )
277 .await
278 {
279 continue;
280 }
281 break;
282 }
283 }
284 }
285 }
286}
287
/// Why [`message_loop`] returned.
enum MessageLoopExit {
    /// The cancellation token fired, or the notification receiver was dropped.
    Cancelled,
    /// The connection failed or went silent; carries a human-readable reason.
    ConnectionLost(String),
    /// Every subscription on the connection reported `subscriptionComplete`.
    Completed,
}
295
296async fn message_loop<K: SubKind>(
297 mut ws: Ws,
298 subs: Subs,
299 pending: Vec<Message>,
300 notifications_tx: &mpsc::Sender<SubscriptionNotification>,
301 cancel: &CancellationToken,
302 replay_from_slot: &mut Option<u64>,
303) -> MessageLoopExit {
304 let mut ping_timer = tokio::time::interval(KEEPALIVE_INTERVAL);
305 ping_timer.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
306 let mut last_inbound = Instant::now();
307 let mut completed: HashSet<u64> = HashSet::new();
310 let mut decompressor = match WsStreamDecompressor::new() {
316 Ok(decompressor) => decompressor,
317 Err(e) => return MessageLoopExit::ConnectionLost(format!("zstd decoder init: {e}")),
318 };
319
320 for msg in pending {
323 let outcome = process_data_frame::<K>(
324 msg,
325 &subs,
326 notifications_tx,
327 &mut completed,
328 replay_from_slot,
329 &mut decompressor,
330 )
331 .await;
332 if let Some(exit) = frame_outcome_to_exit(outcome) {
333 return exit;
334 }
335 }
336
337 loop {
338 tokio::select! {
339 biased;
340 _ = cancel.cancelled() => return MessageLoopExit::Cancelled,
341
342 _ = ping_timer.tick() => {
343 if last_inbound.elapsed() > KEEPALIVE_MISS_DEADLINE {
344 return MessageLoopExit::ConnectionLost(format!(
345 "no traffic for {:?}", last_inbound.elapsed()
346 ));
347 }
348 if let Err(e) = ws.send(Message::Ping(vec![])).await {
349 return MessageLoopExit::ConnectionLost(format!("ping send: {}", err_chain(&e)));
350 }
351 }
352
353 msg = ws.next() => {
354 last_inbound = Instant::now();
355 match msg {
356 Some(Ok(frame @ (Message::Text(_) | Message::Binary(_)))) => {
357 let outcome = process_data_frame::<K>(
358 frame,
359 &subs,
360 notifications_tx,
361 &mut completed,
362 replay_from_slot,
363 &mut decompressor,
364 )
365 .await;
366 if let Some(exit) = frame_outcome_to_exit(outcome) {
367 return exit;
368 }
369 }
370 Some(Ok(Message::Pong(_))) | Some(Ok(Message::Ping(_))) => {}
371 Some(Ok(Message::Close(frame))) => {
372 return MessageLoopExit::ConnectionLost(format!("remote close: {frame:?}"));
373 }
374 Some(Ok(Message::Frame(_))) => {}
375 Some(Err(e)) => return MessageLoopExit::ConnectionLost(format!("ws read: {}", err_chain(&e))),
376 None => return MessageLoopExit::ConnectionLost("ws stream ended".into()),
377 }
378 }
379 }
380 }
381}
382
383fn frame_outcome_to_exit(outcome: Result<TextOutcome, String>) -> Option<MessageLoopExit> {
385 match outcome {
386 Ok(TextOutcome::Continue) => None,
387 Ok(TextOutcome::AllComplete) => Some(MessageLoopExit::Completed),
388 Ok(TextOutcome::ChannelClosed) => Some(MessageLoopExit::Cancelled),
389 Err(why) => Some(MessageLoopExit::ConnectionLost(why)),
390 }
391}
392
393async fn process_data_frame<K: SubKind>(
398 msg: Message,
399 subs: &Subs,
400 notifications_tx: &mpsc::Sender<SubscriptionNotification>,
401 completed: &mut HashSet<u64>,
402 replay_from_slot: &mut Option<u64>,
403 decompressor: &mut WsStreamDecompressor,
404) -> Result<TextOutcome, String> {
405 match msg {
406 Message::Text(t) => {
407 Ok(handle_text::<K>(&t, subs, notifications_tx, completed, replay_from_slot).await)
408 }
409 Message::Binary(b) => {
411 let decoded = decompressor
412 .decompress(&b)
413 .map_err(|e| format!("ws decompress: {e}"))?;
414 match std::str::from_utf8(&decoded) {
415 Ok(t) => {
416 Ok(
417 handle_text::<K>(t, subs, notifications_tx, completed, replay_from_slot)
418 .await,
419 )
420 }
421 Err(_) => Ok(TextOutcome::Continue),
422 }
423 }
424 _ => Ok(TextOutcome::Continue),
425 }
426}
427
428async fn retry_or_fail<K: SubKind>(
431 phase: &'static str,
432 reason: String,
433 budget: &mut ReconnectBudget,
434 cancel: &CancellationToken,
435 status_tx: &watch::Sender<ConnectionStatus>,
436) -> bool {
437 if let Some(delay) = budget.next_backoff() {
438 warn!(
439 kind = K::LABEL,
440 attempt = budget.attempt(),
441 reason = %reason,
442 ?delay,
443 "subscription {phase}, retrying",
444 );
445 cancellable_sleep(delay, cancel).await
446 } else {
447 publish(
448 status_tx,
449 ConnectionStatus::Failed(format!("{phase}: {reason}")),
450 );
451 false
452 }
453}
454
455fn publish(tx: &watch::Sender<ConnectionStatus>, status: ConnectionStatus) {
456 tx.send_if_modified(|current| {
457 if *current == status {
458 false
459 } else {
460 *current = status;
461 true
462 }
463 });
464}
465
466async fn connect_ws(rpc_endpoint: &str) -> Result<Ws, String> {
467 let ws_url = http_to_ws_url(rpc_endpoint).map_err(|e| err_chain(&e))?;
468 let request = ws_url
469 .into_client_request()
470 .map_err(|e| format!("build request: {}", err_chain(&e)))?;
471
472 let connect = tokio::time::timeout(CONNECT_TIMEOUT, connect_async(request))
473 .await
474 .map_err(|_| format!("connect timeout after {CONNECT_TIMEOUT:?}"))?
475 .map_err(|e| format!("connect: {}", err_chain(&e)))?;
476 Ok(connect.0)
477}
478
479struct Subscribed {
481 ws: Ws,
482 subs: Subs,
484 pending: Vec<Message>,
487}
488
489async fn subscribe<K: SubKind>(
490 mut ws: Ws,
491 program_ids: &[String],
492 replay_from_slot: Option<u64>,
493) -> Result<Subscribed, String> {
494 let mut subs = Subs::new();
495 let mut pending = Vec::new();
498 for (i, program_id) in program_ids.iter().enumerate() {
499 let id = (i + 1) as u64;
500 let mut params = K::subscribe_params(program_id);
501 SubscribeConfig {
506 replay_from_slot: replay_from_slot.map(|slot| slot as i64),
507 compression: Some(Compression::Zstd),
508 }
509 .apply_to(&mut params);
510 let req = serde_json::json!({
511 "jsonrpc": "2.0",
512 "id": id,
513 "method": K::SUBSCRIBE_METHOD,
514 "params": params,
515 });
516 ws.send(Message::Text(req.to_string()))
517 .await
518 .map_err(|e| format!("subscribe send: {}", err_chain(&e)))?;
519 subs.insert(read_sub_ack(&mut ws, id, &mut pending).await?);
520 }
521 debug!(
522 kind = K::LABEL,
523 count = subs.len(),
524 "subscriptions established"
525 );
526 Ok(Subscribed { ws, subs, pending })
527}
528
529#[derive(Deserialize)]
530struct SubAck {
531 id: u64,
532 result: Option<u64>,
533 #[serde(default)]
534 error: Option<serde_json::Value>,
535}
536
537async fn read_sub_ack(
543 ws: &mut Ws,
544 expected_id: u64,
545 pending: &mut Vec<Message>,
546) -> Result<u64, String> {
547 let deadline = tokio::time::Instant::now() + HANDSHAKE_RESPONSE_TIMEOUT;
548 loop {
549 let msg = tokio::time::timeout_at(deadline, ws.next())
550 .await
551 .map_err(|_| format!("subscribe ack timeout after {HANDSHAKE_RESPONSE_TIMEOUT:?}"))?;
552
553 let Some(msg) = msg else {
554 return Err("ws ended during subscribe".into());
555 };
556 let msg = msg.map_err(|e| format!("ws read: {}", err_chain(&e)))?;
557
558 if let Message::Text(t) = &msg
560 && let Ok(ack) = serde_json::from_str::<SubAck>(t)
561 {
562 if ack.id != expected_id {
563 continue;
564 }
565 if let Some(err) = ack.error {
566 return Err(format!("subscribe rejected: {err}"));
567 }
568 return ack
569 .result
570 .ok_or_else(|| "subscribe ack missing result".to_string());
571 }
572 if matches!(msg, Message::Text(_) | Message::Binary(_)) {
573 pending.push(msg);
574 }
575 }
576}
577
/// Result of handling one decoded JSON text message.
enum TextOutcome {
    /// Nothing terminal happened; keep reading.
    Continue,
    /// Every subscription on the connection has reported `subscriptionComplete`.
    AllComplete,
    /// The notification receiver was dropped, so nothing can be delivered any more.
    ChannelClosed,
}
587
588async fn handle_text<K: SubKind>(
592 text: &str,
593 subs: &Subs,
594 notifications_tx: &mpsc::Sender<SubscriptionNotification>,
595 completed: &mut HashSet<u64>,
596 replay_from_slot: &mut Option<u64>,
597) -> TextOutcome {
598 if let Some(n) = parse_notification::<K>(text, subs) {
601 let slot = K::slot_of(&n);
610 if notifications_tx
611 .send(K::into_notification(n))
612 .await
613 .is_err()
614 {
615 return TextOutcome::ChannelClosed;
616 }
617 *replay_from_slot = Some(replay_from_slot.map_or(slot, |prev| prev.max(slot)));
618 return TextOutcome::Continue;
619 }
620
621 if let Some(sub_id) = parse_completion(text)
623 && subs.contains(&sub_id)
624 {
625 completed.insert(sub_id);
626 if subs.iter().all(|id| completed.contains(id)) {
627 return TextOutcome::AllComplete;
628 }
629 }
630 TextOutcome::Continue
631}
632
633fn parse_completion(text: &str) -> Option<u64> {
636 #[derive(Deserialize)]
637 struct Msg {
638 method: String,
639 params: Params,
640 }
641 #[derive(Deserialize)]
642 struct Params {
643 subscription: u64,
644 }
645
646 let msg: Msg = serde_json::from_str(text).ok()?;
647 (msg.method == "subscriptionComplete").then_some(msg.params.subscription)
648}
649
650fn parse_notification<K: SubKind>(text: &str, subs: &Subs) -> Option<K::Notification> {
651 #[derive(Deserialize)]
652 #[serde(bound = "T: DeserializeOwned")]
653 struct Msg<T> {
654 method: String,
655 params: Params<T>,
656 }
657 #[derive(Deserialize)]
658 #[serde(bound = "T: DeserializeOwned")]
659 struct Params<T> {
660 subscription: u64,
661 result: T,
662 }
663
664 let msg: Msg<K::Notification> = serde_json::from_str(text).ok()?;
665 if msg.method != K::NOTIFICATION_METHOD {
666 return None;
667 }
668 if !subs.contains(&msg.params.subscription) {
669 return None;
670 }
671 Some(msg.params.result)
672}
673
#[cfg(test)]
mod tests {
    use super::parse_completion;

    #[test]
    fn parse_completion_extracts_subscription_id() {
        let completion =
            r#"{"jsonrpc":"2.0","method":"subscriptionComplete","params":{"subscription":7}}"#;
        assert_eq!(Some(7), parse_completion(completion));
    }

    #[test]
    fn parse_completion_ignores_other_messages() {
        let inputs = [
            r#"{"jsonrpc":"2.0","method":"transactionNotification","params":{"subscription":7,"result":{}}}"#,
            "not json",
        ];
        for text in inputs {
            assert_eq!(parse_completion(text), None, "input: {text}");
        }
    }
}