1use std::{collections::HashSet, marker::PhantomData, time::Instant};
14
15use futures::{SinkExt, StreamExt};
16use serde::{Deserialize, de::DeserializeOwned};
17use solana_transaction_status::EncodedConfirmedTransactionWithStatusMeta;
18use tokio::{
19 net::TcpStream,
20 sync::{mpsc, watch},
21 task::JoinHandle,
22};
23use tokio_tungstenite::{
24 MaybeTlsStream, WebSocketStream, connect_async,
25 tungstenite::{Message, client::IntoClientRequest},
26};
27use tokio_util::sync::CancellationToken;
28use tracing::{debug, warn};
29
30use super::{
31 CONNECT_TIMEOUT, ConnectionStatus, HANDSHAKE_RESPONSE_TIMEOUT, KEEPALIVE_INTERVAL,
32 KEEPALIVE_MISS_DEADLINE, RECONNECT_UPTIME_RESET, ReconnectBudget, cancellable_sleep,
33};
34use crate::{error::err_chain, subscriptions::AccountDiffNotification, urls::http_to_ws_url};
35
36pub struct SubscriptionHandle {
38 pub status: watch::Receiver<ConnectionStatus>,
39 pub notifications: mpsc::Receiver<SubscriptionNotification>,
40 pub join: JoinHandle<()>,
41}
42
43#[derive(Debug)]
44pub enum SubscriptionNotification {
45 Transaction(Box<EncodedConfirmedTransactionWithStatusMeta>),
46 AccountDiff(AccountDiffNotification),
47}
48
49trait SubKind: Send + Sync + 'static {
51 type Notification: DeserializeOwned + Send + 'static;
52 const LABEL: &'static str;
53 const SUBSCRIBE_METHOD: &'static str;
54 const NOTIFICATION_METHOD: &'static str;
55 fn subscribe_params(program_id: &str) -> serde_json::Value;
56 fn into_notification(notification: Self::Notification) -> SubscriptionNotification;
57}
58
59struct AccountDiff;
60impl SubKind for AccountDiff {
61 type Notification = AccountDiffNotification;
62 const LABEL: &'static str = "account-diff";
63 const SUBSCRIBE_METHOD: &'static str = "accountDiffSubscribe";
64 const NOTIFICATION_METHOD: &'static str = "accountDiffNotification";
65 fn subscribe_params(program_id: &str) -> serde_json::Value {
66 serde_json::json!([program_id, {"address_type": "program"}])
67 }
68 fn into_notification(notification: Self::Notification) -> SubscriptionNotification {
69 SubscriptionNotification::AccountDiff(notification)
70 }
71}
72
73struct Transaction;
74impl SubKind for Transaction {
75 type Notification = EncodedConfirmedTransactionWithStatusMeta;
79 const LABEL: &'static str = "transaction";
80 const SUBSCRIBE_METHOD: &'static str = "transactionSubscribe";
81 const NOTIFICATION_METHOD: &'static str = "transactionNotification";
82 fn subscribe_params(program_id: &str) -> serde_json::Value {
83 serde_json::json!([{"mentions": [program_id]}, {"commitment": "confirmed"}])
84 }
85 fn into_notification(notification: Self::Notification) -> SubscriptionNotification {
86 SubscriptionNotification::Transaction(Box::new(notification))
87 }
88}
89
90pub fn spawn_transaction_subscription_manager(
91 rpc_endpoint: String,
92 program_ids: Vec<String>,
93 cancel: CancellationToken,
94) -> SubscriptionHandle {
95 spawn_subscription_manager::<Transaction>(rpc_endpoint, program_ids, cancel)
96}
97
98pub fn spawn_account_diff_subscription_manager(
99 rpc_endpoint: String,
100 program_ids: Vec<String>,
101 cancel: CancellationToken,
102) -> SubscriptionHandle {
103 spawn_subscription_manager::<AccountDiff>(rpc_endpoint, program_ids, cancel)
104}
105
106fn spawn_subscription_manager<K>(
107 rpc_endpoint: String,
108 program_ids: Vec<String>,
109 cancel: CancellationToken,
110) -> SubscriptionHandle
111where
112 K: SubKind,
113{
114 let (notifications_tx, notifications_rx) = mpsc::channel(1024);
115 let (status_tx, status_rx) = watch::channel(ConnectionStatus::Down);
116 let task = Task::<K> {
117 rpc_endpoint,
118 program_ids,
119 notifications_tx,
120 status_tx,
121 cancel,
122 _marker: PhantomData,
123 };
124 let join = tokio::spawn(task.run());
125 SubscriptionHandle {
126 status: status_rx,
127 notifications: notifications_rx,
128 join,
129 }
130}
131
132type Ws = WebSocketStream<MaybeTlsStream<TcpStream>>;
133type Subs = HashSet<u64>;
134
135struct Task<K: SubKind> {
136 rpc_endpoint: String,
137 program_ids: Vec<String>,
138 notifications_tx: mpsc::Sender<SubscriptionNotification>,
139 status_tx: watch::Sender<ConnectionStatus>,
140 cancel: CancellationToken,
143 _marker: PhantomData<fn() -> K>,
144}
145
146impl<K: SubKind> Task<K> {
147 async fn run(self) {
148 let mut budget = ReconnectBudget::new();
149
150 loop {
151 if self.cancel.is_cancelled() {
152 break;
153 }
154 publish(&self.status_tx, ConnectionStatus::Down);
155
156 let connect_result = async {
157 let ws = connect_ws(&self.rpc_endpoint).await?;
158 subscribe::<K>(ws, &self.program_ids).await
159 }
160 .await;
161
162 let (ws, subs) = match connect_result {
163 Ok(v) => v,
164 Err(why) => {
165 if retry_or_fail::<K>(
166 "connect",
167 why,
168 &mut budget,
169 &self.cancel,
170 &self.status_tx,
171 )
172 .await
173 {
174 continue;
175 }
176 break;
177 }
178 };
179
180 publish(&self.status_tx, ConnectionStatus::Up);
181 let connected_at = Instant::now();
182
183 let exit = message_loop::<K>(ws, subs, &self.notifications_tx, &self.cancel).await;
184
185 match exit {
186 MessageLoopExit::Cancelled | MessageLoopExit::Completed => break,
187 MessageLoopExit::ConnectionLost(why) => {
188 if connected_at.elapsed() >= RECONNECT_UPTIME_RESET {
189 budget.reset();
190 }
191 if retry_or_fail::<K>(
192 "connection lost",
193 why,
194 &mut budget,
195 &self.cancel,
196 &self.status_tx,
197 )
198 .await
199 {
200 continue;
201 }
202 break;
203 }
204 }
205 }
206 }
207}
208
209enum MessageLoopExit {
210 Cancelled,
211 ConnectionLost(String),
212 Completed,
215}
216
217async fn message_loop<K: SubKind>(
218 mut ws: Ws,
219 subs: Subs,
220 notifications_tx: &mpsc::Sender<SubscriptionNotification>,
221 cancel: &CancellationToken,
222) -> MessageLoopExit {
223 let mut ping_timer = tokio::time::interval(KEEPALIVE_INTERVAL);
224 ping_timer.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
225 let mut last_inbound = Instant::now();
226 let mut completed: HashSet<u64> = HashSet::new();
229
230 loop {
231 tokio::select! {
232 biased;
233 _ = cancel.cancelled() => return MessageLoopExit::Cancelled,
234
235 _ = ping_timer.tick() => {
236 if last_inbound.elapsed() > KEEPALIVE_MISS_DEADLINE {
237 return MessageLoopExit::ConnectionLost(format!(
238 "no traffic for {:?}", last_inbound.elapsed()
239 ));
240 }
241 if let Err(e) = ws.send(Message::Ping(vec![])).await {
242 return MessageLoopExit::ConnectionLost(format!("ping send: {}", err_chain(&e)));
243 }
244 }
245
246 msg = ws.next() => {
247 last_inbound = Instant::now();
248 match msg {
249 Some(Ok(Message::Text(t))) => {
250 match handle_text::<K>(&t, &subs, notifications_tx, &mut completed).await {
251 TextOutcome::Continue => {}
252 TextOutcome::AllComplete => return MessageLoopExit::Completed,
253 TextOutcome::ChannelClosed => return MessageLoopExit::Cancelled,
254 }
255 }
256 Some(Ok(Message::Binary(b))) => {
257 if let Ok(t) = std::str::from_utf8(&b) {
258 match handle_text::<K>(t, &subs, notifications_tx, &mut completed).await {
259 TextOutcome::Continue => {}
260 TextOutcome::AllComplete => return MessageLoopExit::Completed,
261 TextOutcome::ChannelClosed => return MessageLoopExit::Cancelled,
262 }
263 }
264 }
265 Some(Ok(Message::Pong(_))) | Some(Ok(Message::Ping(_))) => {}
266 Some(Ok(Message::Close(frame))) => {
267 return MessageLoopExit::ConnectionLost(format!("remote close: {frame:?}"));
268 }
269 Some(Ok(Message::Frame(_))) => {}
270 Some(Err(e)) => return MessageLoopExit::ConnectionLost(format!("ws read: {}", err_chain(&e))),
271 None => return MessageLoopExit::ConnectionLost("ws stream ended".into()),
272 }
273 }
274 }
275 }
276}
277
278async fn retry_or_fail<K: SubKind>(
281 phase: &'static str,
282 reason: String,
283 budget: &mut ReconnectBudget,
284 cancel: &CancellationToken,
285 status_tx: &watch::Sender<ConnectionStatus>,
286) -> bool {
287 if let Some(delay) = budget.next_backoff() {
288 warn!(
289 kind = K::LABEL,
290 attempt = budget.attempt(),
291 reason = %reason,
292 ?delay,
293 "subscription {phase}, retrying",
294 );
295 cancellable_sleep(delay, cancel).await
296 } else {
297 publish(
298 status_tx,
299 ConnectionStatus::Failed(format!("{phase}: {reason}")),
300 );
301 false
302 }
303}
304
305fn publish(tx: &watch::Sender<ConnectionStatus>, status: ConnectionStatus) {
306 tx.send_if_modified(|current| {
307 if *current == status {
308 false
309 } else {
310 *current = status;
311 true
312 }
313 });
314}
315
316async fn connect_ws(rpc_endpoint: &str) -> Result<Ws, String> {
317 let ws_url = http_to_ws_url(rpc_endpoint).map_err(|e| err_chain(&e))?;
318 let request = ws_url
319 .into_client_request()
320 .map_err(|e| format!("build request: {}", err_chain(&e)))?;
321
322 let connect = tokio::time::timeout(CONNECT_TIMEOUT, connect_async(request))
323 .await
324 .map_err(|_| format!("connect timeout after {CONNECT_TIMEOUT:?}"))?
325 .map_err(|e| format!("connect: {}", err_chain(&e)))?;
326 Ok(connect.0)
327}
328
329async fn subscribe<K: SubKind>(mut ws: Ws, program_ids: &[String]) -> Result<(Ws, Subs), String> {
330 let mut subs = Subs::new();
331 for (i, program_id) in program_ids.iter().enumerate() {
332 let id = (i + 1) as u64;
333 let req = serde_json::json!({
334 "jsonrpc": "2.0",
335 "id": id,
336 "method": K::SUBSCRIBE_METHOD,
337 "params": K::subscribe_params(program_id),
338 });
339 ws.send(Message::Text(req.to_string()))
340 .await
341 .map_err(|e| format!("subscribe send: {}", err_chain(&e)))?;
342 subs.insert(read_sub_ack(&mut ws, id).await?);
343 }
344 debug!(
345 kind = K::LABEL,
346 count = subs.len(),
347 "subscriptions established"
348 );
349 Ok((ws, subs))
350}
351
352#[derive(Deserialize)]
353struct SubAck {
354 id: u64,
355 result: Option<u64>,
356 #[serde(default)]
357 error: Option<serde_json::Value>,
358}
359
360async fn read_sub_ack(ws: &mut Ws, expected_id: u64) -> Result<u64, String> {
361 let deadline = tokio::time::Instant::now() + HANDSHAKE_RESPONSE_TIMEOUT;
362 loop {
363 let msg = tokio::time::timeout_at(deadline, ws.next())
364 .await
365 .map_err(|_| format!("subscribe ack timeout after {HANDSHAKE_RESPONSE_TIMEOUT:?}"))?;
366
367 let Some(msg) = msg else {
368 return Err("ws ended during subscribe".into());
369 };
370 let msg = msg.map_err(|e| format!("ws read: {}", err_chain(&e)))?;
371
372 if let Message::Text(t) = msg
373 && let Ok(ack) = serde_json::from_str::<SubAck>(&t)
374 {
375 if ack.id != expected_id {
376 continue;
377 }
378 if let Some(err) = ack.error {
379 return Err(format!("subscribe rejected: {err}"));
380 }
381 if let Some(sub_id) = ack.result {
382 return Ok(sub_id);
383 }
384 return Err("subscribe ack missing result".into());
385 }
386 }
387}
388
389enum TextOutcome {
391 Continue,
393 AllComplete,
395 ChannelClosed,
397}
398
399async fn handle_text<K: SubKind>(
403 text: &str,
404 subs: &Subs,
405 notifications_tx: &mpsc::Sender<SubscriptionNotification>,
406 completed: &mut HashSet<u64>,
407) -> TextOutcome {
408 if let Some(n) = parse_notification::<K>(text, subs) {
411 if notifications_tx
412 .send(K::into_notification(n))
413 .await
414 .is_err()
415 {
416 return TextOutcome::ChannelClosed;
417 }
418 return TextOutcome::Continue;
419 }
420
421 if let Some(sub_id) = parse_completion(text)
423 && subs.contains(&sub_id)
424 {
425 completed.insert(sub_id);
426 if subs.iter().all(|id| completed.contains(id)) {
427 return TextOutcome::AllComplete;
428 }
429 }
430 TextOutcome::Continue
431}
432
433fn parse_completion(text: &str) -> Option<u64> {
436 #[derive(Deserialize)]
437 struct Msg {
438 method: String,
439 params: Params,
440 }
441 #[derive(Deserialize)]
442 struct Params {
443 subscription: u64,
444 }
445
446 let msg: Msg = serde_json::from_str(text).ok()?;
447 (msg.method == "subscriptionComplete").then_some(msg.params.subscription)
448}
449
450fn parse_notification<K: SubKind>(text: &str, subs: &Subs) -> Option<K::Notification> {
451 #[derive(Deserialize)]
452 #[serde(bound = "T: DeserializeOwned")]
453 struct Msg<T> {
454 method: String,
455 params: Params<T>,
456 }
457 #[derive(Deserialize)]
458 #[serde(bound = "T: DeserializeOwned")]
459 struct Params<T> {
460 subscription: u64,
461 result: T,
462 }
463
464 let msg: Msg<K::Notification> = serde_json::from_str(text).ok()?;
465 if msg.method != K::NOTIFICATION_METHOD {
466 return None;
467 }
468 if !subs.contains(&msg.params.subscription) {
469 return None;
470 }
471 Some(msg.params.result)
472}
473
#[cfg(test)]
mod tests {
    use super::parse_completion;

    /// A well-formed `subscriptionComplete` message yields its subscription id.
    #[test]
    fn parse_completion_extracts_subscription_id() {
        let completion = concat!(
            r#"{"jsonrpc":"2.0","method":"subscriptionComplete","#,
            r#""params":{"subscription":7}}"#,
        );
        assert_eq!(parse_completion(completion), Some(7));
    }

    /// Notifications of other methods and non-JSON text are not completions.
    #[test]
    fn parse_completion_ignores_other_messages() {
        let non_completions = [
            r#"{"jsonrpc":"2.0","method":"transactionNotification","params":{"subscription":7,"result":{}}}"#,
            "not json",
        ];
        for text in non_completions {
            assert_eq!(parse_completion(text), None, "unexpected completion for {text}");
        }
    }
}