Skip to main content

simulator_client/
subscriptions.rs

1use std::{future::Future, time::Duration};
2
3use futures::{SinkExt, StreamExt};
4use serde::Deserialize;
5use solana_client::{
6    nonblocking::pubsub_client::PubsubClient,
7    rpc_response::{Response, RpcLogsResponse},
8};
9use solana_commitment_config::CommitmentConfig;
10use solana_rpc_client_api::config::{RpcTransactionLogsConfig, RpcTransactionLogsFilter};
11use thiserror::Error;
12use tokio::{
13    sync::{oneshot, watch},
14    task::JoinHandle,
15};
16use tokio_tungstenite::tungstenite::Message;
17
18use crate::urls::{UrlError, http_to_ws_url};
19
20/// Error establishing a PubSub log subscription.
21#[derive(Debug, Error)]
22pub enum SubscriptionError {
23    #[error(transparent)]
24    InvalidUrl(#[from] UrlError),
25
26    #[error("pubsub connect to {url} failed: {source}")]
27    Connect {
28        url: String,
29        #[source]
30        source: Box<dyn std::error::Error + Send + Sync>,
31    },
32
33    #[error("logs_subscribe failed: {source}")]
34    Subscribe {
35        #[source]
36        source: Box<dyn std::error::Error + Send + Sync>,
37    },
38
39    #[error("subscription task exited unexpectedly before signaling ready")]
40    TaskDropped,
41
42    #[error("session has no rpc_endpoint (was the session created?)")]
43    NoRpcEndpoint,
44}
45
46#[derive(Debug, Error)]
47pub enum SubscriptionRuntimeError {
48    #[error("{kind} subscription for {target} closed unexpectedly")]
49    Closed { kind: &'static str, target: String },
50
51    #[error("{kind} subscription callback worker for {target} failed: {source}")]
52    CallbackWorker {
53        kind: &'static str,
54        target: String,
55        #[source]
56        source: tokio::task::JoinError,
57    },
58}
59
60const SUBSCRIPTION_DRAIN_IDLE_TIMEOUT: Duration = Duration::from_millis(250);
61const SUBSCRIPTION_DRAIN_MAX_DURATION: Duration = Duration::from_secs(5);
62
63type SubscriptionTaskHandle = JoinHandle<Result<(), SubscriptionRuntimeError>>;
64type AccountDiffWs =
65    tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>;
66
67/// A unified subscription handle that can represent any subscription type.
68pub struct SubscriptionHandle {
69    pub join_handle: SubscriptionTaskHandle,
70    pub stop: watch::Sender<bool>,
71}
72
73impl From<LogSubscriptionHandle> for SubscriptionHandle {
74    fn from(h: LogSubscriptionHandle) -> Self {
75        Self {
76            join_handle: h.join_handle,
77            stop: h.stop,
78        }
79    }
80}
81
82impl From<AccountDiffSubscriptionHandle> for SubscriptionHandle {
83    fn from(h: AccountDiffSubscriptionHandle) -> Self {
84        Self {
85            join_handle: h.join_handle,
86            stop: h.stop,
87        }
88    }
89}
90
91/// Handle for a running log subscription background task.
92pub struct LogSubscriptionHandle {
93    /// Background task that drives the subscription and spawns per-notification callbacks.
94    ///
95    /// Resolves after `stop.send(true)` is called, remaining buffered
96    /// notifications are drained, and all spawned callback tasks complete.
97    pub join_handle: SubscriptionTaskHandle,
98
99    /// Send `true` to signal the background task to stop accepting new
100    /// notifications, drain remaining buffered ones, and exit cleanly.
101    pub stop: watch::Sender<bool>,
102}
103
104/// Subscribe to program log notifications and invoke a callback for each one.
105///
106/// Spawns a background task that:
107/// 1. Connects to the PubSub endpoint derived from `rpc_endpoint`.
108/// 2. Subscribes to logs mentioning `program_id`.
109/// 3. For each notification, spawns `on_notification(notification)` as a Tokio task.
110/// 4. When `handle.stop.send(true)` is called, drains remaining buffered
111///    notifications (up to 1s), waits for all spawned tasks, then returns.
112///
113/// Returns after the subscription is established. If setup fails, an error is
114/// returned before any background task is left running.
115///
116/// ## Example
117///
118/// ```no_run
119/// use std::sync::{Arc, Mutex};
120/// use simulator_client::subscribe_program_logs;
121/// use solana_commitment_config::CommitmentConfig;
122///
123/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
124/// let handle = subscribe_program_logs(
125///     "https://api.mainnet-beta.solana.com",
126///     "TokenkegQfeZyiNwAJbNbGKPFXCWuBvf9Ss623VQ5DA",
127///     CommitmentConfig::confirmed(),
128///     |notification| async move {
129///         println!("sig: {}", notification.value.signature);
130///     },
131/// )
132/// .await?;
133///
134/// // ... do other work ...
135///
136/// handle.stop.send(true).ok();
137/// handle.join_handle.await.ok();
138/// # Ok(())
139/// # }
140/// ```
141pub async fn subscribe_program_logs<F, Fut>(
142    rpc_endpoint: &str,
143    program_id: &str,
144    commitment: CommitmentConfig,
145    on_notification: F,
146) -> Result<LogSubscriptionHandle, SubscriptionError>
147where
148    F: Fn(Response<RpcLogsResponse>) -> Fut + Send + Sync + 'static,
149    Fut: Future<Output = ()> + Send + 'static,
150{
151    let ws_url = http_to_ws_url(rpc_endpoint)?;
152    let program_id = program_id.to_string();
153
154    let (ready_tx, ready_rx) = oneshot::channel::<Result<(), SubscriptionError>>();
155    let (stop_tx, mut stop_rx) = watch::channel(false);
156
157    // PubsubClient::logs_subscribe borrows &self, so both the client and the
158    // stream must live inside the spawned task. We report setup success/failure
159    // back through a oneshot channel before entering the notification loop.
160    let join_handle = tokio::spawn(async move {
161        let client = match PubsubClient::new(&ws_url).await {
162            Ok(c) => c,
163            Err(e) => {
164                let _ = ready_tx.send(Err(SubscriptionError::Connect {
165                    url: ws_url,
166                    source: Box::new(e),
167                }));
168                return Ok(());
169            }
170        };
171
172        let (mut stream, _unsubscribe) = match client
173            .logs_subscribe(
174                RpcTransactionLogsFilter::Mentions(vec![program_id.clone()]),
175                RpcTransactionLogsConfig {
176                    commitment: Some(commitment),
177                },
178            )
179            .await
180        {
181            Ok(s) => s,
182            Err(e) => {
183                let _ = ready_tx.send(Err(SubscriptionError::Subscribe {
184                    source: Box::new(e),
185                }));
186                return Ok(());
187            }
188        };
189
190        let _ = ready_tx.send(Ok(()));
191
192        let mut tasks: Vec<JoinHandle<()>> = Vec::new();
193        let kind = "program logs";
194
195        loop {
196            if *stop_rx.borrow() {
197                let drain_deadline = tokio::time::Instant::now() + SUBSCRIPTION_DRAIN_MAX_DURATION;
198                while let Ok(Ok(Some(notification))) = tokio::time::timeout_at(
199                    drain_deadline,
200                    tokio::time::timeout(SUBSCRIPTION_DRAIN_IDLE_TIMEOUT, stream.next()),
201                )
202                .await
203                {
204                    tasks.push(tokio::spawn(on_notification(notification)));
205                }
206                break;
207            }
208
209            let notification = tokio::select! {
210                n = stream.next() => n,
211                _ = stop_rx.changed() => continue,
212            };
213
214            match notification {
215                Some(n) => tasks.push(tokio::spawn(on_notification(n))),
216                None => return Err(subscription_runtime_closed(kind, &program_id)),
217            }
218        }
219
220        // Wait for all in-flight callback tasks to complete.
221        for task in tasks {
222            if let Err(source) = task.await {
223                return Err(callback_worker_failed(kind, &program_id, source));
224            }
225        }
226
227        Ok(())
228    });
229
230    match ready_rx.await {
231        Ok(Ok(())) => Ok(LogSubscriptionHandle {
232            join_handle,
233            stop: stop_tx,
234        }),
235        Ok(Err(e)) => {
236            join_handle.abort();
237            Err(e)
238        }
239        Err(_) => {
240            join_handle.abort();
241            Err(SubscriptionError::TaskDropped)
242        }
243    }
244}
245
246// ── Account diff subscription ────────────────────────────────────────────────
247
248/// Slot context included in every account diff notification.
249#[derive(Debug, Clone, Deserialize)]
250pub struct AccountDiffContext {
251    pub slot: u64,
252}
253
254/// A single account diff notification delivered by `accountDiffSubscribe`.
255#[derive(Debug, Clone, Deserialize)]
256pub struct AccountDiffNotification {
257    pub context: AccountDiffContext,
258    /// The address of the account that changed.
259    pub account: Option<String>,
260    /// Signature of the transaction that triggered this change, if known.
261    pub signature: Option<String>,
262    /// Position of the transaction within its slot, if known.
263    #[serde(default)]
264    pub tx_index: Option<u32>,
265    /// Unix-seconds block time of the slot, if known.
266    #[serde(default)]
267    pub block_time: Option<i64>,
268    /// Account state before the change (absent for newly created accounts).
269    pub pre: Option<serde_json::Value>,
270    /// Account state after the change (absent for deleted accounts).
271    pub post: Option<serde_json::Value>,
272}
273
274/// A routed account diff notification tied to the subscribed account that produced it.
275#[derive(Debug, Clone)]
276pub struct RoutedAccountDiffNotification {
277    pub account: String,
278    pub notification: AccountDiffNotification,
279}
280
281/// Handle for a running account diff subscription background task.
282///
283/// Send `true` on `stop` to request a clean shutdown, then await `join_handle`.
284pub struct AccountDiffSubscriptionHandle {
285    pub join_handle: SubscriptionTaskHandle,
286    pub stop: watch::Sender<bool>,
287}
288
289fn subscription_runtime_closed(
290    kind: &'static str,
291    target: impl Into<String>,
292) -> SubscriptionRuntimeError {
293    SubscriptionRuntimeError::Closed {
294        kind,
295        target: target.into(),
296    }
297}
298
299fn callback_worker_failed(
300    kind: &'static str,
301    target: impl Into<String>,
302    source: tokio::task::JoinError,
303) -> SubscriptionRuntimeError {
304    SubscriptionRuntimeError::CallbackWorker {
305        kind,
306        target: target.into(),
307        source,
308    }
309}
310
311/// Subscribe to account diff notifications and invoke a callback for each one.
312///
313/// Spawns a background task that:
314/// 1. Connects to the WebSocket endpoint derived from `rpc_endpoint`.
315/// 2. Subscribes to account diffs for the given filter (account or program).
316/// 3. For each notification, spawns `on_notification(notification)` as a Tokio task.
317/// 4. When `handle.stop.send(true)` is called, drains remaining buffered
318///    notifications (up to 1s), waits for all spawned tasks, then returns.
319///
320/// ## Example
321///
322/// ```no_run
323/// use simulator_client::subscribe_account_diffs;
324///
325/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
326/// let handle = subscribe_account_diffs(
327///     "http://localhost:8900/session/abc",
328///     "TokenkegQfeZyiNwAJbNbGKPFXCWuBvf9Ss623VQ5DA",
329///     |notification| async move {
330///         println!("slot={} sig={:?}", notification.context.slot, notification.signature);
331///     },
332/// )
333/// .await?;
334///
335/// handle.stop.send(true).ok();
336/// handle.join_handle.await.ok();
337/// # Ok(())
338/// # }
339/// ```
340pub async fn subscribe_account_diffs<F, Fut>(
341    rpc_endpoint: &str,
342    account: &str,
343    on_notification: F,
344) -> Result<AccountDiffSubscriptionHandle, SubscriptionError>
345where
346    F: Fn(AccountDiffNotification) -> Fut + Send + Sync + 'static,
347    Fut: Future<Output = ()> + Send + 'static,
348{
349    subscribe_account_diffs_many(rpc_endpoint, [account.to_string()], move |notification| {
350        on_notification(notification.notification)
351    })
352    .await
353}
354
355/// Subscribe to account diff notifications for many accounts over a single websocket.
356///
357/// All requested subscriptions must be acknowledged before this returns. Once the
358/// stream is live, any websocket disconnect is treated as a fatal completeness
359/// error instead of silently reconnecting and risking dropped notifications.
360pub async fn subscribe_account_diffs_many<F, Fut, I, S>(
361    rpc_endpoint: &str,
362    accounts: I,
363    on_notification: F,
364) -> Result<AccountDiffSubscriptionHandle, SubscriptionError>
365where
366    F: Fn(RoutedAccountDiffNotification) -> Fut + Send + Sync + 'static,
367    Fut: Future<Output = ()> + Send + 'static,
368    I: IntoIterator<Item = S>,
369    S: Into<String>,
370{
371    let ws_url = http_to_ws_url(rpc_endpoint)?;
372    let accounts = dedup_accounts(accounts);
373    if accounts.is_empty() {
374        let (stop_tx, stop_rx) = watch::channel(false);
375        return Ok(AccountDiffSubscriptionHandle {
376            join_handle: tokio::spawn(async move {
377                let _ = stop_rx;
378                Ok(())
379            }),
380            stop: stop_tx,
381        });
382    }
383
384    let (ready_tx, ready_rx) = oneshot::channel::<Result<(), SubscriptionError>>();
385    let (stop_tx, mut stop_rx) = watch::channel(false);
386    let target = format!("{} accounts", accounts.len());
387
388    let join_handle = tokio::spawn(async move {
389        let (notification_tx, mut notification_rx) = tokio::sync::mpsc::unbounded_channel();
390        let callback_handle = tokio::spawn(async move {
391            while let Some(notification) = notification_rx.recv().await {
392                on_notification(notification).await;
393            }
394        });
395
396        let (mut ws, _) = match tokio_tungstenite::connect_async(&ws_url).await {
397            Ok(connection) => connection,
398            Err(e) => {
399                let _ = ready_tx.send(Err(SubscriptionError::Connect {
400                    url: ws_url,
401                    source: Box::new(e),
402                }));
403                return Ok(());
404            }
405        };
406
407        let subscriptions =
408            match send_account_diff_subscribe_many(&mut ws, &accounts, &notification_tx).await {
409                Ok(subscriptions) => subscriptions,
410                Err(error) => {
411                    let _ = ready_tx.send(Err(error));
412                    return Ok(());
413                }
414            };
415
416        let _ = ready_tx.send(Ok(()));
417
418        if let Err(error) =
419            drive_account_diff_stream_many(&mut ws, &subscriptions, &notification_tx, &mut stop_rx)
420                .await
421        {
422            drop(notification_tx);
423            if let Err(source) = callback_handle.await {
424                return Err(callback_worker_failed("account diff", target, source));
425            }
426            return Err(error);
427        }
428
429        drop(notification_tx);
430        if let Err(source) = callback_handle.await {
431            return Err(callback_worker_failed("account diff", target, source));
432        }
433
434        Ok(())
435    });
436
437    match ready_rx.await {
438        Ok(Ok(())) => Ok(AccountDiffSubscriptionHandle {
439            join_handle,
440            stop: stop_tx,
441        }),
442        Ok(Err(e)) => {
443            join_handle.abort();
444            Err(e)
445        }
446        Err(_) => {
447            join_handle.abort();
448            Err(SubscriptionError::TaskDropped)
449        }
450    }
451}
452
453#[derive(Deserialize)]
454struct AccountDiffMessage {
455    method: String,
456    params: AccountDiffParams,
457}
458
459#[derive(Deserialize)]
460struct AccountDiffParams {
461    subscription: u64,
462    result: AccountDiffNotification,
463}
464
465async fn send_account_diff_subscribe_many(
466    ws: &mut AccountDiffWs,
467    accounts: &[String],
468    notification_tx: &tokio::sync::mpsc::UnboundedSender<RoutedAccountDiffNotification>,
469) -> Result<std::collections::HashMap<u64, String>, SubscriptionError> {
470    #[derive(Deserialize)]
471    struct SubscriptionConfirmation {
472        id: u64,
473        result: Option<u64>,
474    }
475
476    let mut pending: std::collections::HashMap<u64, String> = std::collections::HashMap::new();
477    let mut subscriptions = std::collections::HashMap::with_capacity(accounts.len());
478
479    for (index, account) in accounts.iter().enumerate() {
480        let request_id = (index + 1) as u64;
481        let req = serde_json::json!({
482            "jsonrpc": "2.0",
483            "id": request_id,
484            "method": "accountDiffSubscribe",
485            "params": [account]
486        });
487        ws.send(Message::Text(req.to_string()))
488            .await
489            .map_err(|source| SubscriptionError::Subscribe {
490                source: Box::new(source),
491            })?;
492        pending.insert(request_id, account.clone());
493    }
494
495    while !pending.is_empty() {
496        match ws.next().await {
497            Some(Ok(Message::Text(text))) => {
498                if let Ok(confirmation) = serde_json::from_str::<SubscriptionConfirmation>(&text) {
499                    let Some(account) = pending.remove(&confirmation.id) else {
500                        continue;
501                    };
502                    let Some(subscription_id) = confirmation.result else {
503                        return Err(SubscriptionError::TaskDropped);
504                    };
505                    subscriptions.insert(subscription_id, account);
506                    continue;
507                }
508
509                if let Some(notification) =
510                    parse_routed_account_diff_notification(&text, &subscriptions)
511                {
512                    let _ = notification_tx.send(notification);
513                }
514            }
515            Some(Ok(_)) => {}
516            _ => return Err(SubscriptionError::TaskDropped),
517        }
518    }
519
520    Ok(subscriptions)
521}
522
523async fn drive_account_diff_stream_many(
524    ws: &mut AccountDiffWs,
525    subscriptions: &std::collections::HashMap<u64, String>,
526    notification_tx: &tokio::sync::mpsc::UnboundedSender<RoutedAccountDiffNotification>,
527    stop_rx: &mut watch::Receiver<bool>,
528) -> Result<(), SubscriptionRuntimeError> {
529    loop {
530        if *stop_rx.borrow() {
531            let drain_deadline = tokio::time::Instant::now() + SUBSCRIPTION_DRAIN_MAX_DURATION;
532            loop {
533                match tokio::time::timeout_at(
534                    drain_deadline,
535                    tokio::time::timeout(SUBSCRIPTION_DRAIN_IDLE_TIMEOUT, ws.next()),
536                )
537                .await
538                {
539                    Ok(Ok(Some(Ok(Message::Text(text))))) => {
540                        if let Some(notification) =
541                            parse_routed_account_diff_notification(&text, subscriptions)
542                        {
543                            let _ = notification_tx.send(notification);
544                        }
545                    }
546                    _ => return Ok(()),
547                }
548            }
549        }
550
551        let msg = tokio::select! {
552            m = ws.next() => m,
553            _ = stop_rx.changed() => continue,
554        };
555
556        match msg {
557            Some(Ok(Message::Text(text))) => {
558                if let Some(notification) =
559                    parse_routed_account_diff_notification(&text, subscriptions)
560                {
561                    let _ = notification_tx.send(notification);
562                }
563            }
564            Some(Ok(_)) => {}
565            _ => {
566                return Err(subscription_runtime_closed(
567                    "account diff",
568                    format!("{} accounts", subscriptions.len()),
569                ));
570            }
571        }
572    }
573}
574
575fn parse_account_diff_message(text: &str) -> Option<AccountDiffMessage> {
576    let msg: AccountDiffMessage = serde_json::from_str(text).ok()?;
577    (msg.method == "accountDiffNotification").then_some(msg)
578}
579
580fn parse_routed_account_diff_notification(
581    text: &str,
582    subscriptions: &std::collections::HashMap<u64, String>,
583) -> Option<RoutedAccountDiffNotification> {
584    let msg = parse_account_diff_message(text)?;
585    let account = subscriptions.get(&msg.params.subscription)?.clone();
586    Some(RoutedAccountDiffNotification {
587        account,
588        notification: msg.params.result,
589    })
590}
591
/// Converts `accounts` to owned strings, dropping repeats while keeping the
/// first occurrence of each in the original order.
fn dedup_accounts<I, S>(accounts: I) -> Vec<String>
where
    I: IntoIterator<Item = S>,
    S: Into<String>,
{
    let mut seen = std::collections::HashSet::new();
    let mut ordered = Vec::new();
    for account in accounts {
        let account: String = account.into();
        if seen.contains(&account) {
            continue;
        }
        seen.insert(account.clone());
        ordered.push(account);
    }
    ordered
}
604
605// ── Program account diff subscription ────────────────────────────────────────
606
607/// Subscribe to account diff notifications for all accounts owned by a program.
608///
609/// Uses the server-side program filter (`{"address_type": "program"}`), so no
610/// RPC prefetch of program accounts is required.  The callback receives one
611/// [`AccountDiffNotification`] per changed account.
612///
613/// A websocket disconnect is treated as a fatal error — the handle's
614/// `join_handle` resolves with a [`SubscriptionRuntimeError`].
615///
616/// ## Example
617///
618/// ```no_run
619/// use simulator_client::subscribe_program_diffs;
620///
621/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
622/// let handle = subscribe_program_diffs(
623///     "http://localhost:8900/session/abc",
624///     "TokenkegQfeZyiNwAJbNbGKPFXCWuBvf9Ss623VQ5DA",
625///     |notification| async move {
626///         let account = notification.account.unwrap_or_default();
627///         println!("account={account} slot={}", notification.context.slot);
628///     },
629/// )
630/// .await?;
631///
632/// handle.stop.send(true).ok();
633/// handle.join_handle.await.ok();
634/// # Ok(())
635/// # }
636/// ```
637pub async fn subscribe_program_diffs<F, Fut>(
638    rpc_endpoint: &str,
639    program_id: &str,
640    on_notification: F,
641) -> Result<AccountDiffSubscriptionHandle, SubscriptionError>
642where
643    F: Fn(AccountDiffNotification) -> Fut + Send + Sync + 'static,
644    Fut: Future<Output = ()> + Send + 'static,
645{
646    let ws_url = http_to_ws_url(rpc_endpoint)?;
647    let program_id = program_id.to_string();
648
649    let (ready_tx, ready_rx) = oneshot::channel::<Result<(), SubscriptionError>>();
650    let (stop_tx, mut stop_rx) = watch::channel(false);
651
652    let join_handle = tokio::spawn(async move {
653        let (notification_tx, mut notification_rx) = tokio::sync::mpsc::unbounded_channel();
654        let callback_handle = tokio::spawn(async move {
655            while let Some(notification) = notification_rx.recv().await {
656                on_notification(notification).await;
657            }
658        });
659
660        let (mut ws, _) = match tokio_tungstenite::connect_async(&ws_url).await {
661            Ok(connection) => connection,
662            Err(e) => {
663                let _ = ready_tx.send(Err(SubscriptionError::Connect {
664                    url: ws_url,
665                    source: Box::new(e),
666                }));
667                return Ok(());
668            }
669        };
670
671        if let Err(error) = send_program_diff_subscribe(&mut ws, &program_id).await {
672            let _ = ready_tx.send(Err(error));
673            return Ok(());
674        }
675
676        let _ = ready_tx.send(Ok(()));
677
678        if let Err(error) =
679            drive_program_diff_stream(&mut ws, &notification_tx, &mut stop_rx, &program_id).await
680        {
681            drop(notification_tx);
682            if let Err(source) = callback_handle.await {
683                return Err(callback_worker_failed(
684                    "program account diff",
685                    &program_id,
686                    source,
687                ));
688            }
689            return Err(error);
690        }
691
692        drop(notification_tx);
693        if let Err(source) = callback_handle.await {
694            return Err(callback_worker_failed(
695                "program account diff",
696                &program_id,
697                source,
698            ));
699        }
700
701        Ok(())
702    });
703
704    match ready_rx.await {
705        Ok(Ok(())) => Ok(AccountDiffSubscriptionHandle {
706            join_handle,
707            stop: stop_tx,
708        }),
709        Ok(Err(e)) => {
710            join_handle.abort();
711            Err(e)
712        }
713        Err(_) => {
714            join_handle.abort();
715            Err(SubscriptionError::TaskDropped)
716        }
717    }
718}
719
720async fn send_program_diff_subscribe(
721    ws: &mut AccountDiffWs,
722    program_id: &str,
723) -> Result<(), SubscriptionError> {
724    #[derive(Deserialize)]
725    struct SubscriptionConfirmation {
726        result: Option<u64>,
727    }
728
729    let req = serde_json::json!({
730        "jsonrpc": "2.0",
731        "id": 1,
732        "method": "accountDiffSubscribe",
733        "params": [program_id, {"address_type": "program"}]
734    });
735    ws.send(Message::Text(req.to_string()))
736        .await
737        .map_err(|source| SubscriptionError::Subscribe {
738            source: Box::new(source),
739        })?;
740
741    loop {
742        match ws.next().await {
743            Some(Ok(Message::Text(text))) => {
744                match serde_json::from_str::<SubscriptionConfirmation>(&text) {
745                    Ok(SubscriptionConfirmation { result: Some(_) }) => return Ok(()),
746                    Ok(_) => continue,
747                    Err(source) => {
748                        return Err(SubscriptionError::Subscribe {
749                            source: Box::new(source),
750                        });
751                    }
752                }
753            }
754            Some(Ok(_)) => continue,
755            _ => return Err(SubscriptionError::TaskDropped),
756        }
757    }
758}
759
760async fn drive_program_diff_stream(
761    ws: &mut AccountDiffWs,
762    notification_tx: &tokio::sync::mpsc::UnboundedSender<AccountDiffNotification>,
763    stop_rx: &mut watch::Receiver<bool>,
764    program_id: &str,
765) -> Result<(), SubscriptionRuntimeError> {
766    loop {
767        if *stop_rx.borrow() {
768            let drain_deadline = tokio::time::Instant::now() + SUBSCRIPTION_DRAIN_MAX_DURATION;
769            loop {
770                match tokio::time::timeout_at(
771                    drain_deadline,
772                    tokio::time::timeout(SUBSCRIPTION_DRAIN_IDLE_TIMEOUT, ws.next()),
773                )
774                .await
775                {
776                    Ok(Ok(Some(Ok(Message::Text(text))))) => {
777                        if let Some(msg) = parse_account_diff_message(&text) {
778                            let _ = notification_tx.send(msg.params.result);
779                        }
780                    }
781                    _ => return Ok(()),
782                }
783            }
784        }
785
786        let msg = tokio::select! {
787            m = ws.next() => m,
788            _ = stop_rx.changed() => continue,
789        };
790
791        match msg {
792            Some(Ok(Message::Text(text))) => {
793                if let Some(msg) = parse_account_diff_message(&text) {
794                    let _ = notification_tx.send(msg.params.result);
795                }
796            }
797            Some(Ok(_)) => {}
798            _ => {
799                return Err(subscription_runtime_closed(
800                    "program account diff",
801                    program_id,
802                ));
803            }
804        }
805    }
806}
807
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_account_diff_notification_ignores_other_messages() {
        let confirmation = r#"{"jsonrpc":"2.0","result":1,"id":1}"#;
        assert!(parse_account_diff_message(confirmation).is_none());
    }

    #[test]
    fn parse_account_diff_message_rejects_other_methods() {
        // Well-formed params, but the method is not an account diff notification.
        let text = r#"{
            "jsonrpc":"2.0",
            "method":"somethingElse",
            "params":{"subscription":1,"result":{"context":{"slot":1}}}
        }"#;
        assert!(parse_account_diff_message(text).is_none());
    }

    #[test]
    fn parse_account_diff_notification_extracts_payload() {
        let text = r#"{
            "jsonrpc":"2.0",
            "method":"accountDiffNotification",
            "params":{
                "subscription":7,
                "result":{
                    "context":{"slot":123},
                    "signature":"sig",
                    "pre":{"a":1},
                    "post":{"a":2}
                }
            }
        }"#;

        let notification = parse_account_diff_message(text)
            .expect("notification")
            .params
            .result;
        assert_eq!(notification.context.slot, 123);
        assert_eq!(notification.signature.as_deref(), Some("sig"));
        assert_eq!(notification.pre, Some(serde_json::json!({"a": 1})));
        assert_eq!(notification.post, Some(serde_json::json!({"a": 2})));
    }

    #[test]
    fn parse_account_diff_notification_defaults_missing_optional_fields() {
        let text = r#"{
            "jsonrpc":"2.0",
            "method":"accountDiffNotification",
            "params":{"subscription":3,"result":{"context":{"slot":9}}}
        }"#;

        let notification = parse_account_diff_message(text)
            .expect("notification")
            .params
            .result;
        assert_eq!(notification.context.slot, 9);
        assert!(notification.account.is_none());
        assert!(notification.signature.is_none());
        assert!(notification.tx_index.is_none());
        assert!(notification.block_time.is_none());
        assert!(notification.pre.is_none());
        assert!(notification.post.is_none());
    }

    #[test]
    fn parse_routed_account_diff_notification_extracts_subscription_account() {
        let text = r#"{
            "jsonrpc":"2.0",
            "method":"accountDiffNotification",
            "params":{
                "subscription":42,
                "result":{
                    "context":{"slot":456},
                    "signature":"sig",
                    "pre":null,
                    "post":{"a":2}
                }
            }
        }"#;
        let subscriptions = std::collections::HashMap::from([(42_u64, "acct".to_string())]);

        let notification =
            parse_routed_account_diff_notification(text, &subscriptions).expect("notification");
        assert_eq!(notification.account, "acct");
        assert_eq!(notification.notification.context.slot, 456);
    }

    #[test]
    fn parse_routed_account_diff_notification_ignores_unknown_subscription() {
        let text = r#"{
            "jsonrpc":"2.0",
            "method":"accountDiffNotification",
            "params":{"subscription":99,"result":{"context":{"slot":1}}}
        }"#;
        let subscriptions = std::collections::HashMap::from([(42_u64, "acct".to_string())]);

        assert!(parse_routed_account_diff_notification(text, &subscriptions).is_none());
    }

    #[test]
    fn dedup_accounts_preserves_first_seen_order() {
        let accounts = dedup_accounts([
            "b".to_string(),
            "a".to_string(),
            "b".to_string(),
            "c".to_string(),
        ]);
        assert_eq!(accounts, vec!["b", "a", "c"]);
    }

    #[test]
    fn dedup_accounts_handles_empty_input() {
        assert!(dedup_accounts(Vec::<String>::new()).is_empty());
    }
}