Skip to main content

simulator_client/managed/
subscription.rs

1//! SubscriptionManager — owns subscription WebSockets and keeps them alive.
2//!
3//! Two variants:
4//! - account-diff subscription (`accountDiffSubscribe`) — account state capture
5//! - transaction subscription (`transactionSubscribe`) — full transaction
6//!   capture, delivers what `getTransaction` would return in one push so the
7//!   client can skip the per-tx fetch entirely
8//!
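//! Both follow the same reconnect + keepalive pattern. On reconnect, all
//! configured subscriptions are re-established from scratch; we do not attempt
//! to replay notifications missed during the gap. Once every subscription on a
//! connection has delivered its `subscriptionComplete` marker, the manager
//! stops without reconnecting.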
12
13use std::{collections::HashSet, marker::PhantomData, time::Instant};
14
15use futures::{SinkExt, StreamExt};
16use serde::{Deserialize, de::DeserializeOwned};
17use solana_transaction_status::EncodedConfirmedTransactionWithStatusMeta;
18use tokio::{
19    net::TcpStream,
20    sync::{mpsc, watch},
21    task::JoinHandle,
22};
23use tokio_tungstenite::{
24    MaybeTlsStream, WebSocketStream, connect_async,
25    tungstenite::{Message, client::IntoClientRequest},
26};
27use tokio_util::sync::CancellationToken;
28use tracing::{debug, warn};
29
30use super::{
31    CONNECT_TIMEOUT, ConnectionStatus, HANDSHAKE_RESPONSE_TIMEOUT, KEEPALIVE_INTERVAL,
32    KEEPALIVE_MISS_DEADLINE, RECONNECT_UPTIME_RESET, ReconnectBudget, cancellable_sleep,
33};
34use crate::{error::err_chain, subscriptions::AccountDiffNotification, urls::http_to_ws_url};
35
36/// Handle to a running subscription manager task.
37pub struct SubscriptionHandle {
38    pub status: watch::Receiver<ConnectionStatus>,
39    pub notifications: mpsc::Receiver<SubscriptionNotification>,
40    pub join: JoinHandle<()>,
41}
42
43#[derive(Debug)]
44pub enum SubscriptionNotification {
45    Transaction(Box<EncodedConfirmedTransactionWithStatusMeta>),
46    AccountDiff(AccountDiffNotification),
47}
48
49/// Per-flavor differences between `accountDiffSubscribe` and `transactionSubscribe`.
50trait SubKind: Send + Sync + 'static {
51    type Notification: DeserializeOwned + Send + 'static;
52    const LABEL: &'static str;
53    const SUBSCRIBE_METHOD: &'static str;
54    const NOTIFICATION_METHOD: &'static str;
55    fn subscribe_params(program_id: &str) -> serde_json::Value;
56    fn into_notification(notification: Self::Notification) -> SubscriptionNotification;
57}
58
59struct AccountDiff;
60impl SubKind for AccountDiff {
61    type Notification = AccountDiffNotification;
62    const LABEL: &'static str = "account-diff";
63    const SUBSCRIBE_METHOD: &'static str = "accountDiffSubscribe";
64    const NOTIFICATION_METHOD: &'static str = "accountDiffNotification";
65    fn subscribe_params(program_id: &str) -> serde_json::Value {
66        serde_json::json!([program_id, {"address_type": "program"}])
67    }
68    fn into_notification(notification: Self::Notification) -> SubscriptionNotification {
69        SubscriptionNotification::AccountDiff(notification)
70    }
71}
72
73struct Transaction;
74impl SubKind for Transaction {
75    /// Wire shape is identical to the `getTransaction` RPC response, so we can
76    /// reuse `transaction_from_encoded` to build the output record directly
77    /// from the push notification — no follow-up fetch required.
78    type Notification = EncodedConfirmedTransactionWithStatusMeta;
79    const LABEL: &'static str = "transaction";
80    const SUBSCRIBE_METHOD: &'static str = "transactionSubscribe";
81    const NOTIFICATION_METHOD: &'static str = "transactionNotification";
82    fn subscribe_params(program_id: &str) -> serde_json::Value {
83        serde_json::json!([{"mentions": [program_id]}, {"commitment": "confirmed"}])
84    }
85    fn into_notification(notification: Self::Notification) -> SubscriptionNotification {
86        SubscriptionNotification::Transaction(Box::new(notification))
87    }
88}
89
90pub fn spawn_transaction_subscription_manager(
91    rpc_endpoint: String,
92    program_ids: Vec<String>,
93    cancel: CancellationToken,
94) -> SubscriptionHandle {
95    spawn_subscription_manager::<Transaction>(rpc_endpoint, program_ids, cancel)
96}
97
98pub fn spawn_account_diff_subscription_manager(
99    rpc_endpoint: String,
100    program_ids: Vec<String>,
101    cancel: CancellationToken,
102) -> SubscriptionHandle {
103    spawn_subscription_manager::<AccountDiff>(rpc_endpoint, program_ids, cancel)
104}
105
106fn spawn_subscription_manager<K>(
107    rpc_endpoint: String,
108    program_ids: Vec<String>,
109    cancel: CancellationToken,
110) -> SubscriptionHandle
111where
112    K: SubKind,
113{
114    let (notifications_tx, notifications_rx) = mpsc::channel(1024);
115    let (status_tx, status_rx) = watch::channel(ConnectionStatus::Down);
116    let task = Task::<K> {
117        rpc_endpoint,
118        program_ids,
119        notifications_tx,
120        status_tx,
121        cancel,
122        _marker: PhantomData,
123    };
124    let join = tokio::spawn(task.run());
125    SubscriptionHandle {
126        status: status_rx,
127        notifications: notifications_rx,
128        join,
129    }
130}
131
132type Ws = WebSocketStream<MaybeTlsStream<TcpStream>>;
133type Subs = HashSet<u64>;
134
135struct Task<K: SubKind> {
136    rpc_endpoint: String,
137    program_ids: Vec<String>,
138    notifications_tx: mpsc::Sender<SubscriptionNotification>,
139    status_tx: watch::Sender<ConnectionStatus>,
140    /// Session-scoped cancel; fires both on user Ctrl-C *and* on normal
141    /// session completion. Stops the connect/message loop either way.
142    cancel: CancellationToken,
143    _marker: PhantomData<fn() -> K>,
144}
145
146impl<K: SubKind> Task<K> {
147    async fn run(self) {
148        let mut budget = ReconnectBudget::new();
149
150        loop {
151            if self.cancel.is_cancelled() {
152                break;
153            }
154            publish(&self.status_tx, ConnectionStatus::Down);
155
156            let connect_result = async {
157                let ws = connect_ws(&self.rpc_endpoint).await?;
158                subscribe::<K>(ws, &self.program_ids).await
159            }
160            .await;
161
162            let (ws, subs) = match connect_result {
163                Ok(v) => v,
164                Err(why) => {
165                    if retry_or_fail::<K>(
166                        "connect",
167                        why,
168                        &mut budget,
169                        &self.cancel,
170                        &self.status_tx,
171                    )
172                    .await
173                    {
174                        continue;
175                    }
176                    break;
177                }
178            };
179
180            publish(&self.status_tx, ConnectionStatus::Up);
181            let connected_at = Instant::now();
182
183            let exit = message_loop::<K>(ws, subs, &self.notifications_tx, &self.cancel).await;
184
185            match exit {
186                MessageLoopExit::Cancelled | MessageLoopExit::Completed => break,
187                MessageLoopExit::ConnectionLost(why) => {
188                    if connected_at.elapsed() >= RECONNECT_UPTIME_RESET {
189                        budget.reset();
190                    }
191                    if retry_or_fail::<K>(
192                        "connection lost",
193                        why,
194                        &mut budget,
195                        &self.cancel,
196                        &self.status_tx,
197                    )
198                    .await
199                    {
200                        continue;
201                    }
202                    break;
203                }
204            }
205        }
206    }
207}
208
/// Why `message_loop` returned.
enum MessageLoopExit {
    /// Every subscription on this connection delivered its end-of-stream
    /// terminal. Stop cleanly without reconnecting.
    Completed,
    /// The session was cancelled, or the notification receiver went away.
    Cancelled,
    /// The socket failed or went quiet; carries a human-readable reason. The
    /// caller may reconnect, subject to its retry budget.
    ConnectionLost(String),
}
216
217async fn message_loop<K: SubKind>(
218    mut ws: Ws,
219    subs: Subs,
220    notifications_tx: &mpsc::Sender<SubscriptionNotification>,
221    cancel: &CancellationToken,
222) -> MessageLoopExit {
223    let mut ping_timer = tokio::time::interval(KEEPALIVE_INTERVAL);
224    ping_timer.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
225    let mut last_inbound = Instant::now();
226    // Subscriptions whose terminal `subscriptionComplete` marker has arrived.
227    // Once this covers every entry in `subs`, the stream is fully drained.
228    let mut completed: HashSet<u64> = HashSet::new();
229
230    loop {
231        tokio::select! {
232            biased;
233            _ = cancel.cancelled() => return MessageLoopExit::Cancelled,
234
235            _ = ping_timer.tick() => {
236                if last_inbound.elapsed() > KEEPALIVE_MISS_DEADLINE {
237                    return MessageLoopExit::ConnectionLost(format!(
238                        "no traffic for {:?}", last_inbound.elapsed()
239                    ));
240                }
241                if let Err(e) = ws.send(Message::Ping(vec![])).await {
242                    return MessageLoopExit::ConnectionLost(format!("ping send: {}", err_chain(&e)));
243                }
244            }
245
246            msg = ws.next() => {
247                last_inbound = Instant::now();
248                match msg {
249                    Some(Ok(Message::Text(t))) => {
250                        match handle_text::<K>(&t, &subs, notifications_tx, &mut completed).await {
251                            TextOutcome::Continue => {}
252                            TextOutcome::AllComplete => return MessageLoopExit::Completed,
253                            TextOutcome::ChannelClosed => return MessageLoopExit::Cancelled,
254                        }
255                    }
256                    Some(Ok(Message::Binary(b))) => {
257                        if let Ok(t) = std::str::from_utf8(&b) {
258                            match handle_text::<K>(t, &subs, notifications_tx, &mut completed).await {
259                                TextOutcome::Continue => {}
260                                TextOutcome::AllComplete => return MessageLoopExit::Completed,
261                                TextOutcome::ChannelClosed => return MessageLoopExit::Cancelled,
262                            }
263                        }
264                    }
265                    Some(Ok(Message::Pong(_))) | Some(Ok(Message::Ping(_))) => {}
266                    Some(Ok(Message::Close(frame))) => {
267                        return MessageLoopExit::ConnectionLost(format!("remote close: {frame:?}"));
268                    }
269                    Some(Ok(Message::Frame(_))) => {}
270                    Some(Err(e)) => return MessageLoopExit::ConnectionLost(format!("ws read: {}", err_chain(&e))),
271                    None => return MessageLoopExit::ConnectionLost("ws stream ended".into()),
272                }
273            }
274        }
275    }
276}
277
278/// Sleep for the next backoff interval, or publish `Failed` and return false if
279/// the retry budget is exhausted. Returns true if the caller should retry.
280async fn retry_or_fail<K: SubKind>(
281    phase: &'static str,
282    reason: String,
283    budget: &mut ReconnectBudget,
284    cancel: &CancellationToken,
285    status_tx: &watch::Sender<ConnectionStatus>,
286) -> bool {
287    if let Some(delay) = budget.next_backoff() {
288        warn!(
289            kind = K::LABEL,
290            attempt = budget.attempt(),
291            reason = %reason,
292            ?delay,
293            "subscription {phase}, retrying",
294        );
295        cancellable_sleep(delay, cancel).await
296    } else {
297        publish(
298            status_tx,
299            ConnectionStatus::Failed(format!("{phase}: {reason}")),
300        );
301        false
302    }
303}
304
305fn publish(tx: &watch::Sender<ConnectionStatus>, status: ConnectionStatus) {
306    tx.send_if_modified(|current| {
307        if *current == status {
308            false
309        } else {
310            *current = status;
311            true
312        }
313    });
314}
315
316async fn connect_ws(rpc_endpoint: &str) -> Result<Ws, String> {
317    let ws_url = http_to_ws_url(rpc_endpoint).map_err(|e| err_chain(&e))?;
318    let request = ws_url
319        .into_client_request()
320        .map_err(|e| format!("build request: {}", err_chain(&e)))?;
321
322    let connect = tokio::time::timeout(CONNECT_TIMEOUT, connect_async(request))
323        .await
324        .map_err(|_| format!("connect timeout after {CONNECT_TIMEOUT:?}"))?
325        .map_err(|e| format!("connect: {}", err_chain(&e)))?;
326    Ok(connect.0)
327}
328
329async fn subscribe<K: SubKind>(mut ws: Ws, program_ids: &[String]) -> Result<(Ws, Subs), String> {
330    let mut subs = Subs::new();
331    for (i, program_id) in program_ids.iter().enumerate() {
332        let id = (i + 1) as u64;
333        let req = serde_json::json!({
334            "jsonrpc": "2.0",
335            "id": id,
336            "method": K::SUBSCRIBE_METHOD,
337            "params": K::subscribe_params(program_id),
338        });
339        ws.send(Message::Text(req.to_string()))
340            .await
341            .map_err(|e| format!("subscribe send: {}", err_chain(&e)))?;
342        subs.insert(read_sub_ack(&mut ws, id).await?);
343    }
344    debug!(
345        kind = K::LABEL,
346        count = subs.len(),
347        "subscriptions established"
348    );
349    Ok((ws, subs))
350}
351
352#[derive(Deserialize)]
353struct SubAck {
354    id: u64,
355    result: Option<u64>,
356    #[serde(default)]
357    error: Option<serde_json::Value>,
358}
359
360async fn read_sub_ack(ws: &mut Ws, expected_id: u64) -> Result<u64, String> {
361    let deadline = tokio::time::Instant::now() + HANDSHAKE_RESPONSE_TIMEOUT;
362    loop {
363        let msg = tokio::time::timeout_at(deadline, ws.next())
364            .await
365            .map_err(|_| format!("subscribe ack timeout after {HANDSHAKE_RESPONSE_TIMEOUT:?}"))?;
366
367        let Some(msg) = msg else {
368            return Err("ws ended during subscribe".into());
369        };
370        let msg = msg.map_err(|e| format!("ws read: {}", err_chain(&e)))?;
371
372        if let Message::Text(t) = msg
373            && let Ok(ack) = serde_json::from_str::<SubAck>(&t)
374        {
375            if ack.id != expected_id {
376                continue;
377            }
378            if let Some(err) = ack.error {
379                return Err(format!("subscribe rejected: {err}"));
380            }
381            if let Some(sub_id) = ack.result {
382                return Ok(sub_id);
383            }
384            return Err("subscribe ack missing result".into());
385        }
386    }
387}
388
/// What [`handle_text`] made of one inbound text frame.
enum TextOutcome {
    /// The downstream notifications channel is closed — nobody is listening.
    ChannelClosed,
    /// Every subscription on this connection has delivered its terminal marker.
    AllComplete,
    /// Nothing to act on: a notification was forwarded, or the frame was ignored.
    Continue,
}
398
399/// Handle one inbound text frame: forward a matching notification, or record a
400/// terminal `subscriptionComplete` marker. Returns [`TextOutcome::AllComplete`]
401/// once a terminal has been seen for every subscription in `subs`.
402async fn handle_text<K: SubKind>(
403    text: &str,
404    subs: &Subs,
405    notifications_tx: &mpsc::Sender<SubscriptionNotification>,
406    completed: &mut HashSet<u64>,
407) -> TextOutcome {
408    // Try a data notification first — that's the overwhelmingly common frame,
409    // so the hot path parses the payload once.
410    if let Some(n) = parse_notification::<K>(text, subs) {
411        if notifications_tx
412            .send(K::into_notification(n))
413            .await
414            .is_err()
415        {
416            return TextOutcome::ChannelClosed;
417        }
418        return TextOutcome::Continue;
419    }
420
421    // Otherwise it may be the terminal end-of-stream marker.
422    if let Some(sub_id) = parse_completion(text)
423        && subs.contains(&sub_id)
424    {
425        completed.insert(sub_id);
426        if subs.iter().all(|id| completed.contains(id)) {
427            return TextOutcome::AllComplete;
428        }
429    }
430    TextOutcome::Continue
431}
432
433/// Parse a terminal end-of-stream marker, returning the subscription id it
434/// targets. Shape: `{"method":"subscriptionComplete","params":{"subscription":N}}`.
435fn parse_completion(text: &str) -> Option<u64> {
436    #[derive(Deserialize)]
437    struct Msg {
438        method: String,
439        params: Params,
440    }
441    #[derive(Deserialize)]
442    struct Params {
443        subscription: u64,
444    }
445
446    let msg: Msg = serde_json::from_str(text).ok()?;
447    (msg.method == "subscriptionComplete").then_some(msg.params.subscription)
448}
449
450fn parse_notification<K: SubKind>(text: &str, subs: &Subs) -> Option<K::Notification> {
451    #[derive(Deserialize)]
452    #[serde(bound = "T: DeserializeOwned")]
453    struct Msg<T> {
454        method: String,
455        params: Params<T>,
456    }
457    #[derive(Deserialize)]
458    #[serde(bound = "T: DeserializeOwned")]
459    struct Params<T> {
460        subscription: u64,
461        result: T,
462    }
463
464    let msg: Msg<K::Notification> = serde_json::from_str(text).ok()?;
465    if msg.method != K::NOTIFICATION_METHOD {
466        return None;
467    }
468    if !subs.contains(&msg.params.subscription) {
469        return None;
470    }
471    Some(msg.params.result)
472}
473
#[cfg(test)]
mod tests {
    use super::parse_completion;

    #[test]
    fn parse_completion_extracts_subscription_id() {
        let text =
            r#"{"jsonrpc":"2.0","method":"subscriptionComplete","params":{"subscription":7}}"#;
        assert_eq!(parse_completion(text), Some(7));
    }

    #[test]
    fn parse_completion_ignores_other_messages() {
        let notification = r#"{"jsonrpc":"2.0","method":"transactionNotification","params":{"subscription":7,"result":{}}}"#;
        assert_eq!(parse_completion(notification), None);
        assert_eq!(parse_completion("not json"), None);
    }

    #[test]
    fn parse_completion_rejects_malformed_markers() {
        // Right method, but no subscription id.
        let missing_id = r#"{"jsonrpc":"2.0","method":"subscriptionComplete","params":{}}"#;
        assert_eq!(parse_completion(missing_id), None);

        // Subscription id that does not fit in a u64.
        let negative_id =
            r#"{"jsonrpc":"2.0","method":"subscriptionComplete","params":{"subscription":-1}}"#;
        assert_eq!(parse_completion(negative_id), None);

        // A JSON-RPC response has no `method`, so it is never a marker.
        let response = r#"{"jsonrpc":"2.0","id":1,"result":7}"#;
        assert_eq!(parse_completion(response), None);
    }
}