Skip to main content

simulator_client/
subscriptions.rs

1use std::{future::Future, time::Duration};
2
3use futures::{SinkExt, StreamExt};
4use serde::Deserialize;
5use solana_client::{
6    nonblocking::pubsub_client::PubsubClient,
7    rpc_response::{Response, RpcLogsResponse},
8};
9use solana_commitment_config::CommitmentConfig;
10use solana_rpc_client_api::config::{RpcTransactionLogsConfig, RpcTransactionLogsFilter};
11use thiserror::Error;
12use tokio::{
13    sync::{oneshot, watch},
14    task::JoinHandle,
15};
16use tokio_tungstenite::tungstenite::Message;
17
18use crate::urls::{UrlError, http_to_ws_url};
19
/// Error establishing a PubSub subscription.
///
/// Returned by the `subscribe_*` functions in this module (program logs and
/// account diffs alike) when setup fails before the subscription is live.
#[derive(Debug, Error)]
pub enum SubscriptionError {
    /// Deriving the WebSocket URL from the RPC endpoint failed.
    #[error(transparent)]
    InvalidUrl(#[from] UrlError),

    /// Opening the connection to `url` failed; `source` is the transport error.
    #[error("pubsub connect to {url} failed: {source}")]
    Connect {
        url: String,
        #[source]
        source: Box<dyn std::error::Error + Send + Sync>,
    },

    /// The subscribe request could not be sent or was not accepted.
    #[error("logs_subscribe failed: {source}")]
    Subscribe {
        #[source]
        source: Box<dyn std::error::Error + Send + Sync>,
    },

    /// The background task went away without reporting whether setup succeeded.
    #[error("subscription task exited unexpectedly before signaling ready")]
    TaskDropped,

    /// Not produced in this file. NOTE(review): presumably raised by session
    /// code when no RPC endpoint is available — confirm against the caller.
    #[error("session has no rpc_endpoint (was the session created?)")]
    NoRpcEndpoint,
}
45
/// Error that ends an already-established subscription.
///
/// This is the error type a subscription's `join_handle` resolves to.
#[derive(Debug, Error)]
pub enum SubscriptionRuntimeError {
    /// The notification stream ended or failed without a stop request.
    /// `kind` names the subscription type and `target` what was subscribed to.
    #[error("{kind} subscription for {target} closed unexpectedly")]
    Closed { kind: &'static str, target: String },

    /// A task running the user callback could not be joined (for example
    /// because the callback panicked).
    #[error("{kind} subscription callback worker for {target} failed: {source}")]
    CallbackWorker {
        kind: &'static str,
        target: String,
        #[source]
        source: tokio::task::JoinError,
    },
}
59
/// While draining after a stop request, how long to wait for the next message
/// before treating the stream as idle and ending the drain.
const SUBSCRIPTION_DRAIN_IDLE_TIMEOUT: Duration = Duration::from_millis(250);
/// Upper bound on the total time spent draining after a stop request.
const SUBSCRIPTION_DRAIN_MAX_DURATION: Duration = Duration::from_secs(5);

/// Join handle of a subscription's background task.
type SubscriptionTaskHandle = JoinHandle<Result<(), SubscriptionRuntimeError>>;
/// WebSocket stream type used by the account diff subscriptions.
type AccountDiffWs =
    tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>;
66
/// A unified subscription handle that can represent any subscription type.
///
/// Built from a [`LogSubscriptionHandle`] or an
/// [`AccountDiffSubscriptionHandle`] via `From`.
pub struct SubscriptionHandle {
    /// Background task driving the subscription; resolves with the
    /// subscription's final result.
    pub join_handle: SubscriptionTaskHandle,
    /// Send `true` to request a clean shutdown of the background task.
    pub stop: watch::Sender<bool>,
}
72
73impl From<LogSubscriptionHandle> for SubscriptionHandle {
74    fn from(h: LogSubscriptionHandle) -> Self {
75        Self {
76            join_handle: h.join_handle,
77            stop: h.stop,
78        }
79    }
80}
81
82impl From<AccountDiffSubscriptionHandle> for SubscriptionHandle {
83    fn from(h: AccountDiffSubscriptionHandle) -> Self {
84        Self {
85            join_handle: h.join_handle,
86            stop: h.stop,
87        }
88    }
89}
90
/// Handle for a running log subscription background task.
pub struct LogSubscriptionHandle {
    /// Background task that drives the subscription and spawns per-notification callbacks.
    ///
    /// Resolves after `stop.send(true)` is called, remaining buffered
    /// notifications are drained, and all spawned callback tasks complete.
    /// Resolves to an error instead if the notification stream closes without
    /// a stop request or a callback task cannot be joined.
    pub join_handle: SubscriptionTaskHandle,

    /// Send `true` to signal the background task to stop accepting new
    /// notifications, drain remaining buffered ones, and exit cleanly.
    pub stop: watch::Sender<bool>,
}
103
104/// Subscribe to program log notifications and invoke a callback for each one.
105///
106/// Spawns a background task that:
107/// 1. Connects to the PubSub endpoint derived from `rpc_endpoint`.
108/// 2. Subscribes to logs mentioning `program_id`.
109/// 3. For each notification, spawns `on_notification(notification)` as a Tokio task.
110/// 4. When `handle.stop.send(true)` is called, drains remaining buffered
111///    notifications (up to 1s), waits for all spawned tasks, then returns.
112///
113/// Returns after the subscription is established. If setup fails, an error is
114/// returned before any background task is left running.
115///
116/// ## Example
117///
118/// ```no_run
119/// use std::sync::{Arc, Mutex};
120/// use simulator_client::subscribe_program_logs;
121/// use solana_commitment_config::CommitmentConfig;
122///
123/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
124/// let handle = subscribe_program_logs(
125///     "https://api.mainnet-beta.solana.com",
126///     "TokenkegQfeZyiNwAJbNbGKPFXCWuBvf9Ss623VQ5DA",
127///     CommitmentConfig::confirmed(),
128///     |notification| async move {
129///         println!("sig: {}", notification.value.signature);
130///     },
131/// )
132/// .await?;
133///
134/// // ... do other work ...
135///
136/// handle.stop.send(true).ok();
137/// handle.join_handle.await.ok();
138/// # Ok(())
139/// # }
140/// ```
141pub async fn subscribe_program_logs<F, Fut>(
142    rpc_endpoint: &str,
143    program_id: &str,
144    commitment: CommitmentConfig,
145    on_notification: F,
146) -> Result<LogSubscriptionHandle, SubscriptionError>
147where
148    F: Fn(Response<RpcLogsResponse>) -> Fut + Send + Sync + 'static,
149    Fut: Future<Output = ()> + Send + 'static,
150{
151    let ws_url = http_to_ws_url(rpc_endpoint)?;
152    let program_id = program_id.to_string();
153
154    let (ready_tx, ready_rx) = oneshot::channel::<Result<(), SubscriptionError>>();
155    let (stop_tx, mut stop_rx) = watch::channel(false);
156
157    // PubsubClient::logs_subscribe borrows &self, so both the client and the
158    // stream must live inside the spawned task. We report setup success/failure
159    // back through a oneshot channel before entering the notification loop.
160    let join_handle = tokio::spawn(async move {
161        let client = match PubsubClient::new(&ws_url).await {
162            Ok(c) => c,
163            Err(e) => {
164                let _ = ready_tx.send(Err(SubscriptionError::Connect {
165                    url: ws_url,
166                    source: Box::new(e),
167                }));
168                return Ok(());
169            }
170        };
171
172        let (mut stream, _unsubscribe) = match client
173            .logs_subscribe(
174                RpcTransactionLogsFilter::Mentions(vec![program_id.clone()]),
175                RpcTransactionLogsConfig {
176                    commitment: Some(commitment),
177                },
178            )
179            .await
180        {
181            Ok(s) => s,
182            Err(e) => {
183                let _ = ready_tx.send(Err(SubscriptionError::Subscribe {
184                    source: Box::new(e),
185                }));
186                return Ok(());
187            }
188        };
189
190        let _ = ready_tx.send(Ok(()));
191
192        let mut tasks: Vec<JoinHandle<()>> = Vec::new();
193        let kind = "program logs";
194
195        loop {
196            if *stop_rx.borrow() {
197                let drain_deadline = tokio::time::Instant::now() + SUBSCRIPTION_DRAIN_MAX_DURATION;
198                while let Ok(Ok(Some(notification))) = tokio::time::timeout_at(
199                    drain_deadline,
200                    tokio::time::timeout(SUBSCRIPTION_DRAIN_IDLE_TIMEOUT, stream.next()),
201                )
202                .await
203                {
204                    tasks.push(tokio::spawn(on_notification(notification)));
205                }
206                break;
207            }
208
209            let notification = tokio::select! {
210                n = stream.next() => n,
211                _ = stop_rx.changed() => continue,
212            };
213
214            match notification {
215                Some(n) => tasks.push(tokio::spawn(on_notification(n))),
216                None => return Err(subscription_runtime_closed(kind, &program_id)),
217            }
218        }
219
220        // Wait for all in-flight callback tasks to complete.
221        for task in tasks {
222            if let Err(source) = task.await {
223                return Err(callback_worker_failed(kind, &program_id, source));
224            }
225        }
226
227        Ok(())
228    });
229
230    match ready_rx.await {
231        Ok(Ok(())) => Ok(LogSubscriptionHandle {
232            join_handle,
233            stop: stop_tx,
234        }),
235        Ok(Err(e)) => {
236            join_handle.abort();
237            Err(e)
238        }
239        Err(_) => {
240            join_handle.abort();
241            Err(SubscriptionError::TaskDropped)
242        }
243    }
244}
245
246// ── Account diff subscription ────────────────────────────────────────────────
247
/// Slot context included in every account diff notification.
#[derive(Debug, Clone, Deserialize)]
pub struct AccountDiffContext {
    /// Slot number reported with the notification.
    pub slot: u64,
}
253
/// A single account diff notification delivered by `accountDiffSubscribe`.
///
/// Deserialized from the `result` member of an `accountDiffNotification`
/// message. Every field except `context` is optional in the payload.
#[derive(Debug, Clone, Deserialize)]
pub struct AccountDiffNotification {
    /// Slot context of the notification.
    pub context: AccountDiffContext,
    /// The address of the account that changed.
    pub account: Option<String>,
    /// Signature of the transaction that triggered this change, if known.
    pub signature: Option<String>,
    /// Account state before the change (absent for newly created accounts).
    pub pre: Option<serde_json::Value>,
    /// Account state after the change (absent for deleted accounts).
    pub post: Option<serde_json::Value>,
}
267
/// A routed account diff notification tied to the subscribed account that produced it.
#[derive(Debug, Clone)]
pub struct RoutedAccountDiffNotification {
    /// The subscribed account, looked up from the notification's subscription id.
    pub account: String,
    /// The notification payload as received.
    pub notification: AccountDiffNotification,
}
274
/// Handle for a running account diff subscription background task.
///
/// Send `true` on `stop` to request a clean shutdown, then await `join_handle`.
pub struct AccountDiffSubscriptionHandle {
    /// Background task driving the websocket; resolves with the
    /// subscription's final result.
    pub join_handle: SubscriptionTaskHandle,
    /// Stop signal observed by the background task.
    pub stop: watch::Sender<bool>,
}
282
283fn subscription_runtime_closed(
284    kind: &'static str,
285    target: impl Into<String>,
286) -> SubscriptionRuntimeError {
287    SubscriptionRuntimeError::Closed {
288        kind,
289        target: target.into(),
290    }
291}
292
293fn callback_worker_failed(
294    kind: &'static str,
295    target: impl Into<String>,
296    source: tokio::task::JoinError,
297) -> SubscriptionRuntimeError {
298    SubscriptionRuntimeError::CallbackWorker {
299        kind,
300        target: target.into(),
301        source,
302    }
303}
304
305/// Subscribe to account diff notifications and invoke a callback for each one.
306///
307/// Spawns a background task that:
308/// 1. Connects to the WebSocket endpoint derived from `rpc_endpoint`.
309/// 2. Subscribes to account diffs for the given filter (account or program).
310/// 3. For each notification, spawns `on_notification(notification)` as a Tokio task.
311/// 4. When `handle.stop.send(true)` is called, drains remaining buffered
312///    notifications (up to 1s), waits for all spawned tasks, then returns.
313///
314/// ## Example
315///
316/// ```no_run
317/// use simulator_client::subscribe_account_diffs;
318///
319/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
320/// let handle = subscribe_account_diffs(
321///     "http://localhost:8900/session/abc",
322///     "TokenkegQfeZyiNwAJbNbGKPFXCWuBvf9Ss623VQ5DA",
323///     |notification| async move {
324///         println!("slot={} sig={:?}", notification.context.slot, notification.signature);
325///     },
326/// )
327/// .await?;
328///
329/// handle.stop.send(true).ok();
330/// handle.join_handle.await.ok();
331/// # Ok(())
332/// # }
333/// ```
334pub async fn subscribe_account_diffs<F, Fut>(
335    rpc_endpoint: &str,
336    account: &str,
337    on_notification: F,
338) -> Result<AccountDiffSubscriptionHandle, SubscriptionError>
339where
340    F: Fn(AccountDiffNotification) -> Fut + Send + Sync + 'static,
341    Fut: Future<Output = ()> + Send + 'static,
342{
343    subscribe_account_diffs_many(rpc_endpoint, [account.to_string()], move |notification| {
344        on_notification(notification.notification)
345    })
346    .await
347}
348
349/// Subscribe to account diff notifications for many accounts over a single websocket.
350///
351/// All requested subscriptions must be acknowledged before this returns. Once the
352/// stream is live, any websocket disconnect is treated as a fatal completeness
353/// error instead of silently reconnecting and risking dropped notifications.
354pub async fn subscribe_account_diffs_many<F, Fut, I, S>(
355    rpc_endpoint: &str,
356    accounts: I,
357    on_notification: F,
358) -> Result<AccountDiffSubscriptionHandle, SubscriptionError>
359where
360    F: Fn(RoutedAccountDiffNotification) -> Fut + Send + Sync + 'static,
361    Fut: Future<Output = ()> + Send + 'static,
362    I: IntoIterator<Item = S>,
363    S: Into<String>,
364{
365    let ws_url = http_to_ws_url(rpc_endpoint)?;
366    let accounts = dedup_accounts(accounts);
367    if accounts.is_empty() {
368        let (stop_tx, stop_rx) = watch::channel(false);
369        return Ok(AccountDiffSubscriptionHandle {
370            join_handle: tokio::spawn(async move {
371                let _ = stop_rx;
372                Ok(())
373            }),
374            stop: stop_tx,
375        });
376    }
377
378    let (ready_tx, ready_rx) = oneshot::channel::<Result<(), SubscriptionError>>();
379    let (stop_tx, mut stop_rx) = watch::channel(false);
380    let target = format!("{} accounts", accounts.len());
381
382    let join_handle = tokio::spawn(async move {
383        let (notification_tx, mut notification_rx) = tokio::sync::mpsc::unbounded_channel();
384        let callback_handle = tokio::spawn(async move {
385            while let Some(notification) = notification_rx.recv().await {
386                on_notification(notification).await;
387            }
388        });
389
390        let (mut ws, _) = match tokio_tungstenite::connect_async(&ws_url).await {
391            Ok(connection) => connection,
392            Err(e) => {
393                let _ = ready_tx.send(Err(SubscriptionError::Connect {
394                    url: ws_url,
395                    source: Box::new(e),
396                }));
397                return Ok(());
398            }
399        };
400
401        let subscriptions =
402            match send_account_diff_subscribe_many(&mut ws, &accounts, &notification_tx).await {
403                Ok(subscriptions) => subscriptions,
404                Err(error) => {
405                    let _ = ready_tx.send(Err(error));
406                    return Ok(());
407                }
408            };
409
410        let _ = ready_tx.send(Ok(()));
411
412        if let Err(error) =
413            drive_account_diff_stream_many(&mut ws, &subscriptions, &notification_tx, &mut stop_rx)
414                .await
415        {
416            drop(notification_tx);
417            if let Err(source) = callback_handle.await {
418                return Err(callback_worker_failed("account diff", target, source));
419            }
420            return Err(error);
421        }
422
423        drop(notification_tx);
424        if let Err(source) = callback_handle.await {
425            return Err(callback_worker_failed("account diff", target, source));
426        }
427
428        Ok(())
429    });
430
431    match ready_rx.await {
432        Ok(Ok(())) => Ok(AccountDiffSubscriptionHandle {
433            join_handle,
434            stop: stop_tx,
435        }),
436        Ok(Err(e)) => {
437            join_handle.abort();
438            Err(e)
439        }
440        Err(_) => {
441            join_handle.abort();
442            Err(SubscriptionError::TaskDropped)
443        }
444    }
445}
446
/// Wire shape of a server-pushed message: a `method` name plus its `params`.
/// Other members of the JSON object are not read.
#[derive(Deserialize)]
struct AccountDiffMessage {
    /// Method name; only `"accountDiffNotification"` is accepted by
    /// `parse_account_diff_message`.
    method: String,
    params: AccountDiffParams,
}
452
/// `params` member of an account diff notification message.
#[derive(Deserialize)]
struct AccountDiffParams {
    /// Subscription id the notification belongs to; used to route it back to
    /// the subscribed account.
    subscription: u64,
    /// The notification payload.
    result: AccountDiffNotification,
}
458
459async fn send_account_diff_subscribe_many(
460    ws: &mut AccountDiffWs,
461    accounts: &[String],
462    notification_tx: &tokio::sync::mpsc::UnboundedSender<RoutedAccountDiffNotification>,
463) -> Result<std::collections::HashMap<u64, String>, SubscriptionError> {
464    #[derive(Deserialize)]
465    struct SubscriptionConfirmation {
466        id: u64,
467        result: Option<u64>,
468    }
469
470    let mut pending: std::collections::HashMap<u64, String> = std::collections::HashMap::new();
471    let mut subscriptions = std::collections::HashMap::with_capacity(accounts.len());
472
473    for (index, account) in accounts.iter().enumerate() {
474        let request_id = (index + 1) as u64;
475        let req = serde_json::json!({
476            "jsonrpc": "2.0",
477            "id": request_id,
478            "method": "accountDiffSubscribe",
479            "params": [account]
480        });
481        ws.send(Message::Text(req.to_string()))
482            .await
483            .map_err(|source| SubscriptionError::Subscribe {
484                source: Box::new(source),
485            })?;
486        pending.insert(request_id, account.clone());
487    }
488
489    while !pending.is_empty() {
490        match ws.next().await {
491            Some(Ok(Message::Text(text))) => {
492                if let Ok(confirmation) = serde_json::from_str::<SubscriptionConfirmation>(&text) {
493                    let Some(account) = pending.remove(&confirmation.id) else {
494                        continue;
495                    };
496                    let Some(subscription_id) = confirmation.result else {
497                        return Err(SubscriptionError::TaskDropped);
498                    };
499                    subscriptions.insert(subscription_id, account);
500                    continue;
501                }
502
503                if let Some(notification) =
504                    parse_routed_account_diff_notification(&text, &subscriptions)
505                {
506                    let _ = notification_tx.send(notification);
507                }
508            }
509            Some(Ok(_)) => {}
510            _ => return Err(SubscriptionError::TaskDropped),
511        }
512    }
513
514    Ok(subscriptions)
515}
516
517async fn drive_account_diff_stream_many(
518    ws: &mut AccountDiffWs,
519    subscriptions: &std::collections::HashMap<u64, String>,
520    notification_tx: &tokio::sync::mpsc::UnboundedSender<RoutedAccountDiffNotification>,
521    stop_rx: &mut watch::Receiver<bool>,
522) -> Result<(), SubscriptionRuntimeError> {
523    loop {
524        if *stop_rx.borrow() {
525            let drain_deadline = tokio::time::Instant::now() + SUBSCRIPTION_DRAIN_MAX_DURATION;
526            loop {
527                match tokio::time::timeout_at(
528                    drain_deadline,
529                    tokio::time::timeout(SUBSCRIPTION_DRAIN_IDLE_TIMEOUT, ws.next()),
530                )
531                .await
532                {
533                    Ok(Ok(Some(Ok(Message::Text(text))))) => {
534                        if let Some(notification) =
535                            parse_routed_account_diff_notification(&text, subscriptions)
536                        {
537                            let _ = notification_tx.send(notification);
538                        }
539                    }
540                    _ => return Ok(()),
541                }
542            }
543        }
544
545        let msg = tokio::select! {
546            m = ws.next() => m,
547            _ = stop_rx.changed() => continue,
548        };
549
550        match msg {
551            Some(Ok(Message::Text(text))) => {
552                if let Some(notification) =
553                    parse_routed_account_diff_notification(&text, subscriptions)
554                {
555                    let _ = notification_tx.send(notification);
556                }
557            }
558            Some(Ok(_)) => {}
559            _ => {
560                return Err(subscription_runtime_closed(
561                    "account diff",
562                    format!("{} accounts", subscriptions.len()),
563                ));
564            }
565        }
566    }
567}
568
569fn parse_account_diff_message(text: &str) -> Option<AccountDiffMessage> {
570    let msg: AccountDiffMessage = serde_json::from_str(text).ok()?;
571    (msg.method == "accountDiffNotification").then_some(msg)
572}
573
574fn parse_routed_account_diff_notification(
575    text: &str,
576    subscriptions: &std::collections::HashMap<u64, String>,
577) -> Option<RoutedAccountDiffNotification> {
578    let msg = parse_account_diff_message(text)?;
579    let account = subscriptions.get(&msg.params.subscription)?.clone();
580    Some(RoutedAccountDiffNotification {
581        account,
582        notification: msg.params.result,
583    })
584}
585
/// Converts `accounts` into owned strings and removes duplicates, keeping the
/// first occurrence of each account in its original position.
fn dedup_accounts<I, S>(accounts: I) -> Vec<String>
where
    I: IntoIterator<Item = S>,
    S: Into<String>,
{
    let mut seen = std::collections::BTreeSet::new();
    let mut ordered = Vec::new();
    for candidate in accounts {
        let candidate: String = candidate.into();
        // `insert` reports whether the value was new to the set.
        if seen.insert(candidate.clone()) {
            ordered.push(candidate);
        }
    }
    ordered
}
598
599// ── Program account diff subscription ────────────────────────────────────────
600
601/// Subscribe to account diff notifications for all accounts owned by a program.
602///
603/// Uses the server-side program filter (`{"address_type": "program"}`), so no
604/// RPC prefetch of program accounts is required.  The callback receives one
605/// [`AccountDiffNotification`] per changed account.
606///
607/// A websocket disconnect is treated as a fatal error — the handle's
608/// `join_handle` resolves with a [`SubscriptionRuntimeError`].
609///
610/// ## Example
611///
612/// ```no_run
613/// use simulator_client::subscribe_program_diffs;
614///
615/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
616/// let handle = subscribe_program_diffs(
617///     "http://localhost:8900/session/abc",
618///     "TokenkegQfeZyiNwAJbNbGKPFXCWuBvf9Ss623VQ5DA",
619///     |notification| async move {
620///         let account = notification.account.unwrap_or_default();
621///         println!("account={account} slot={}", notification.context.slot);
622///     },
623/// )
624/// .await?;
625///
626/// handle.stop.send(true).ok();
627/// handle.join_handle.await.ok();
628/// # Ok(())
629/// # }
630/// ```
631pub async fn subscribe_program_diffs<F, Fut>(
632    rpc_endpoint: &str,
633    program_id: &str,
634    on_notification: F,
635) -> Result<AccountDiffSubscriptionHandle, SubscriptionError>
636where
637    F: Fn(AccountDiffNotification) -> Fut + Send + Sync + 'static,
638    Fut: Future<Output = ()> + Send + 'static,
639{
640    let ws_url = http_to_ws_url(rpc_endpoint)?;
641    let program_id = program_id.to_string();
642
643    let (ready_tx, ready_rx) = oneshot::channel::<Result<(), SubscriptionError>>();
644    let (stop_tx, mut stop_rx) = watch::channel(false);
645
646    let join_handle = tokio::spawn(async move {
647        let (notification_tx, mut notification_rx) = tokio::sync::mpsc::unbounded_channel();
648        let callback_handle = tokio::spawn(async move {
649            while let Some(notification) = notification_rx.recv().await {
650                on_notification(notification).await;
651            }
652        });
653
654        let (mut ws, _) = match tokio_tungstenite::connect_async(&ws_url).await {
655            Ok(connection) => connection,
656            Err(e) => {
657                let _ = ready_tx.send(Err(SubscriptionError::Connect {
658                    url: ws_url,
659                    source: Box::new(e),
660                }));
661                return Ok(());
662            }
663        };
664
665        if let Err(error) = send_program_diff_subscribe(&mut ws, &program_id).await {
666            let _ = ready_tx.send(Err(error));
667            return Ok(());
668        }
669
670        let _ = ready_tx.send(Ok(()));
671
672        if let Err(error) =
673            drive_program_diff_stream(&mut ws, &notification_tx, &mut stop_rx, &program_id).await
674        {
675            drop(notification_tx);
676            if let Err(source) = callback_handle.await {
677                return Err(callback_worker_failed(
678                    "program account diff",
679                    &program_id,
680                    source,
681                ));
682            }
683            return Err(error);
684        }
685
686        drop(notification_tx);
687        if let Err(source) = callback_handle.await {
688            return Err(callback_worker_failed(
689                "program account diff",
690                &program_id,
691                source,
692            ));
693        }
694
695        Ok(())
696    });
697
698    match ready_rx.await {
699        Ok(Ok(())) => Ok(AccountDiffSubscriptionHandle {
700            join_handle,
701            stop: stop_tx,
702        }),
703        Ok(Err(e)) => {
704            join_handle.abort();
705            Err(e)
706        }
707        Err(_) => {
708            join_handle.abort();
709            Err(SubscriptionError::TaskDropped)
710        }
711    }
712}
713
714async fn send_program_diff_subscribe(
715    ws: &mut AccountDiffWs,
716    program_id: &str,
717) -> Result<(), SubscriptionError> {
718    #[derive(Deserialize)]
719    struct SubscriptionConfirmation {
720        result: Option<u64>,
721    }
722
723    let req = serde_json::json!({
724        "jsonrpc": "2.0",
725        "id": 1,
726        "method": "accountDiffSubscribe",
727        "params": [program_id, {"address_type": "program"}]
728    });
729    ws.send(Message::Text(req.to_string()))
730        .await
731        .map_err(|source| SubscriptionError::Subscribe {
732            source: Box::new(source),
733        })?;
734
735    loop {
736        match ws.next().await {
737            Some(Ok(Message::Text(text))) => {
738                match serde_json::from_str::<SubscriptionConfirmation>(&text) {
739                    Ok(SubscriptionConfirmation { result: Some(_) }) => return Ok(()),
740                    Ok(_) => continue,
741                    Err(source) => {
742                        return Err(SubscriptionError::Subscribe {
743                            source: Box::new(source),
744                        });
745                    }
746                }
747            }
748            Some(Ok(_)) => continue,
749            _ => return Err(SubscriptionError::TaskDropped),
750        }
751    }
752}
753
754async fn drive_program_diff_stream(
755    ws: &mut AccountDiffWs,
756    notification_tx: &tokio::sync::mpsc::UnboundedSender<AccountDiffNotification>,
757    stop_rx: &mut watch::Receiver<bool>,
758    program_id: &str,
759) -> Result<(), SubscriptionRuntimeError> {
760    loop {
761        if *stop_rx.borrow() {
762            let drain_deadline = tokio::time::Instant::now() + SUBSCRIPTION_DRAIN_MAX_DURATION;
763            loop {
764                match tokio::time::timeout_at(
765                    drain_deadline,
766                    tokio::time::timeout(SUBSCRIPTION_DRAIN_IDLE_TIMEOUT, ws.next()),
767                )
768                .await
769                {
770                    Ok(Ok(Some(Ok(Message::Text(text))))) => {
771                        if let Some(msg) = parse_account_diff_message(&text) {
772                            let _ = notification_tx.send(msg.params.result);
773                        }
774                    }
775                    _ => return Ok(()),
776                }
777            }
778        }
779
780        let msg = tokio::select! {
781            m = ws.next() => m,
782            _ = stop_rx.changed() => continue,
783        };
784
785        match msg {
786            Some(Ok(Message::Text(text))) => {
787                if let Some(msg) = parse_account_diff_message(&text) {
788                    let _ = notification_tx.send(msg.params.result);
789                }
790            }
791            Some(Ok(_)) => {}
792            _ => {
793                return Err(subscription_runtime_closed(
794                    "program account diff",
795                    program_id,
796                ));
797            }
798        }
799    }
800}
801
// Unit tests for the pure parsing and de-duplication helpers; nothing here
// opens a network connection.
#[cfg(test)]
mod tests {
    use super::*;

    // A subscribe confirmation has no `method`/`params`, so it must not be
    // mistaken for a notification.
    #[test]
    fn parse_account_diff_notification_ignores_other_messages() {
        let confirmation = r#"{"jsonrpc":"2.0","result":1,"id":1}"#;
        assert!(parse_account_diff_message(confirmation).is_none());
    }

    // A well-formed notification yields its payload; the absent `account`
    // member is tolerated.
    #[test]
    fn parse_account_diff_notification_extracts_payload() {
        let text = r#"{
            "jsonrpc":"2.0",
            "method":"accountDiffNotification",
            "params":{
                "subscription":7,
                "result":{
                    "context":{"slot":123},
                    "signature":"sig",
                    "pre":{"a":1},
                    "post":{"a":2}
                }
            }
        }"#;

        let notification = parse_account_diff_message(text)
            .expect("notification")
            .params
            .result;
        assert_eq!(notification.context.slot, 123);
        assert_eq!(notification.signature.as_deref(), Some("sig"));
        assert_eq!(notification.pre, Some(serde_json::json!({"a": 1})));
        assert_eq!(notification.post, Some(serde_json::json!({"a": 2})));
    }

    // The subscription id in the message selects the account from the map.
    #[test]
    fn parse_routed_account_diff_notification_extracts_subscription_account() {
        let text = r#"{
            "jsonrpc":"2.0",
            "method":"accountDiffNotification",
            "params":{
                "subscription":42,
                "result":{
                    "context":{"slot":456},
                    "signature":"sig",
                    "pre":null,
                    "post":{"a":2}
                }
            }
        }"#;
        let subscriptions = std::collections::HashMap::from([(42_u64, "acct".to_string())]);

        let notification =
            parse_routed_account_diff_notification(text, &subscriptions).expect("notification");
        assert_eq!(notification.account, "acct");
        assert_eq!(notification.notification.context.slot, 456);
    }

    // Duplicates are dropped without reordering the remaining accounts.
    #[test]
    fn dedup_accounts_preserves_first_seen_order() {
        let accounts = dedup_accounts([
            "b".to_string(),
            "a".to_string(),
            "b".to_string(),
            "c".to_string(),
        ]);
        assert_eq!(accounts, vec!["b", "a", "c"]);
    }
}