Skip to main content

simulator_client/
subscriptions.rs

1use std::{future::Future, time::Duration};
2
3use futures::{SinkExt, StreamExt};
4use serde::Deserialize;
5use simulator_api::EncodedBinary;
6use solana_client::{
7    nonblocking::pubsub_client::PubsubClient,
8    rpc_response::{Response, RpcLogsResponse},
9};
10use solana_commitment_config::CommitmentConfig;
11use solana_rpc_client_api::config::{RpcTransactionLogsConfig, RpcTransactionLogsFilter};
12use thiserror::Error;
13use tokio::{
14    sync::{oneshot, watch},
15    task::JoinHandle,
16};
17use tokio_tungstenite::tungstenite::Message;
18
19use crate::urls::{UrlError, http_to_ws_url};
20
21/// Error establishing a PubSub log subscription.
22#[derive(Debug, Error)]
23pub enum SubscriptionError {
24    #[error(transparent)]
25    InvalidUrl(#[from] UrlError),
26
27    #[error("pubsub connect to {url} failed: {source}")]
28    Connect {
29        url: String,
30        #[source]
31        source: Box<dyn std::error::Error + Send + Sync>,
32    },
33
34    #[error("logs_subscribe failed: {source}")]
35    Subscribe {
36        #[source]
37        source: Box<dyn std::error::Error + Send + Sync>,
38    },
39
40    #[error("subscription task exited unexpectedly before signaling ready")]
41    TaskDropped,
42
43    #[error("session has no rpc_endpoint (was the session created?)")]
44    NoRpcEndpoint,
45}
46
47#[derive(Debug, Error)]
48pub enum SubscriptionRuntimeError {
49    #[error("{kind} subscription for {target} closed unexpectedly")]
50    Closed { kind: &'static str, target: String },
51
52    #[error("{kind} subscription callback worker for {target} failed: {source}")]
53    CallbackWorker {
54        kind: &'static str,
55        target: String,
56        #[source]
57        source: tokio::task::JoinError,
58    },
59}
60
61const SUBSCRIPTION_DRAIN_IDLE_TIMEOUT: Duration = Duration::from_millis(250);
62const SUBSCRIPTION_DRAIN_MAX_DURATION: Duration = Duration::from_secs(5);
63
64type SubscriptionTaskHandle = JoinHandle<Result<(), SubscriptionRuntimeError>>;
65type AccountDiffWs =
66    tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>;
67
68/// A unified subscription handle that can represent any subscription type.
69pub struct SubscriptionHandle {
70    pub join_handle: SubscriptionTaskHandle,
71    pub stop: watch::Sender<bool>,
72}
73
74impl From<LogSubscriptionHandle> for SubscriptionHandle {
75    fn from(h: LogSubscriptionHandle) -> Self {
76        Self {
77            join_handle: h.join_handle,
78            stop: h.stop,
79        }
80    }
81}
82
83impl From<AccountDiffSubscriptionHandle> for SubscriptionHandle {
84    fn from(h: AccountDiffSubscriptionHandle) -> Self {
85        Self {
86            join_handle: h.join_handle,
87            stop: h.stop,
88        }
89    }
90}
91
92/// Handle for a running log subscription background task.
93pub struct LogSubscriptionHandle {
94    /// Background task that drives the subscription and spawns per-notification callbacks.
95    ///
96    /// Resolves after `stop.send(true)` is called, remaining buffered
97    /// notifications are drained, and all spawned callback tasks complete.
98    pub join_handle: SubscriptionTaskHandle,
99
100    /// Send `true` to signal the background task to stop accepting new
101    /// notifications, drain remaining buffered ones, and exit cleanly.
102    pub stop: watch::Sender<bool>,
103}
104
105/// Subscribe to program log notifications and invoke a callback for each one.
106///
107/// Spawns a background task that:
108/// 1. Connects to the PubSub endpoint derived from `rpc_endpoint`.
109/// 2. Subscribes to logs mentioning `program_id`.
110/// 3. For each notification, spawns `on_notification(notification)` as a Tokio task.
111/// 4. When `handle.stop.send(true)` is called, drains remaining buffered
112///    notifications (up to 1s), waits for all spawned tasks, then returns.
113///
114/// Returns after the subscription is established. If setup fails, an error is
115/// returned before any background task is left running.
116///
117/// ## Example
118///
119/// ```no_run
120/// use std::sync::{Arc, Mutex};
121/// use simulator_client::subscribe_program_logs;
122/// use solana_commitment_config::CommitmentConfig;
123///
124/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
125/// let handle = subscribe_program_logs(
126///     "https://api.mainnet-beta.solana.com",
127///     "TokenkegQfeZyiNwAJbNbGKPFXCWuBvf9Ss623VQ5DA",
128///     CommitmentConfig::confirmed(),
129///     |notification| async move {
130///         println!("sig: {}", notification.value.signature);
131///     },
132/// )
133/// .await?;
134///
135/// // ... do other work ...
136///
137/// handle.stop.send(true).ok();
138/// handle.join_handle.await.ok();
139/// # Ok(())
140/// # }
141/// ```
142pub async fn subscribe_program_logs<F, Fut>(
143    rpc_endpoint: &str,
144    program_id: &str,
145    commitment: CommitmentConfig,
146    on_notification: F,
147) -> Result<LogSubscriptionHandle, SubscriptionError>
148where
149    F: Fn(Response<RpcLogsResponse>) -> Fut + Send + Sync + 'static,
150    Fut: Future<Output = ()> + Send + 'static,
151{
152    let ws_url = http_to_ws_url(rpc_endpoint)?;
153    let program_id = program_id.to_string();
154
155    let (ready_tx, ready_rx) = oneshot::channel::<Result<(), SubscriptionError>>();
156    let (stop_tx, mut stop_rx) = watch::channel(false);
157
158    // PubsubClient::logs_subscribe borrows &self, so both the client and the
159    // stream must live inside the spawned task. We report setup success/failure
160    // back through a oneshot channel before entering the notification loop.
161    let join_handle = tokio::spawn(async move {
162        let client = match PubsubClient::new(&ws_url).await {
163            Ok(c) => c,
164            Err(e) => {
165                let _ = ready_tx.send(Err(SubscriptionError::Connect {
166                    url: ws_url,
167                    source: Box::new(e),
168                }));
169                return Ok(());
170            }
171        };
172
173        let (mut stream, _unsubscribe) = match client
174            .logs_subscribe(
175                RpcTransactionLogsFilter::Mentions(vec![program_id.clone()]),
176                RpcTransactionLogsConfig {
177                    commitment: Some(commitment),
178                },
179            )
180            .await
181        {
182            Ok(s) => s,
183            Err(e) => {
184                let _ = ready_tx.send(Err(SubscriptionError::Subscribe {
185                    source: Box::new(e),
186                }));
187                return Ok(());
188            }
189        };
190
191        let _ = ready_tx.send(Ok(()));
192
193        let mut tasks: Vec<JoinHandle<()>> = Vec::new();
194        let kind = "program logs";
195
196        loop {
197            if *stop_rx.borrow() {
198                let drain_deadline = tokio::time::Instant::now() + SUBSCRIPTION_DRAIN_MAX_DURATION;
199                while let Ok(Ok(Some(notification))) = tokio::time::timeout_at(
200                    drain_deadline,
201                    tokio::time::timeout(SUBSCRIPTION_DRAIN_IDLE_TIMEOUT, stream.next()),
202                )
203                .await
204                {
205                    tasks.push(tokio::spawn(on_notification(notification)));
206                }
207                break;
208            }
209
210            let notification = tokio::select! {
211                n = stream.next() => n,
212                _ = stop_rx.changed() => continue,
213            };
214
215            match notification {
216                Some(n) => tasks.push(tokio::spawn(on_notification(n))),
217                None => return Err(subscription_runtime_closed(kind, &program_id)),
218            }
219        }
220
221        // Wait for all in-flight callback tasks to complete.
222        for task in tasks {
223            if let Err(source) = task.await {
224                return Err(callback_worker_failed(kind, &program_id, source));
225            }
226        }
227
228        Ok(())
229    });
230
231    match ready_rx.await {
232        Ok(Ok(())) => Ok(LogSubscriptionHandle {
233            join_handle,
234            stop: stop_tx,
235        }),
236        Ok(Err(e)) => {
237            join_handle.abort();
238            Err(e)
239        }
240        Err(_) => {
241            join_handle.abort();
242            Err(SubscriptionError::TaskDropped)
243        }
244    }
245}
246
247// ── Account diff subscription ────────────────────────────────────────────────
248
249/// Slot context included in every account diff notification.
250#[derive(Debug, Clone, Deserialize)]
251pub struct AccountDiffContext {
252    pub slot: u64,
253}
254
255/// A single account diff notification delivered by `accountDiffSubscribe`.
256#[derive(Debug, Clone, Deserialize)]
257pub struct AccountDiffNotification {
258    pub context: AccountDiffContext,
259    /// The address of the account that changed.
260    pub account: Option<String>,
261    /// Signature of the transaction that triggered this change, if known.
262    pub signature: Option<String>,
263    /// Position of the transaction within its slot, if known.
264    #[serde(default)]
265    pub tx_index: Option<u32>,
266    /// Unix-seconds block time of the slot, if known.
267    #[serde(default)]
268    pub block_time: Option<i64>,
269    /// Account state before the change (absent for newly created accounts).
270    pub pre: Option<serde_json::Value>,
271    /// Account state after the change (absent for deleted accounts).
272    pub post: Option<serde_json::Value>,
273}
274
275#[derive(Debug, Clone, Deserialize)]
276pub struct ActionResultContext {
277    pub slot: u64,
278}
279
280/// One scheduled-action result delivered by `actionSubscribe`.
281#[derive(Debug, Clone, Deserialize)]
282#[serde(rename_all = "camelCase")]
283pub struct ActionResultNotification {
284    pub context: ActionResultContext,
285    pub slot: u64,
286    /// Batch the action fired at; `None` for slot-boundary actions.
287    #[serde(default)]
288    pub batch_index: Option<u32>,
289    /// Index of the action in the session's `actions` list.
290    pub action_index: u32,
291    #[serde(default)]
292    pub label: Option<String>,
293    pub committed: bool,
294    #[serde(default)]
295    pub err: Option<String>,
296    #[serde(default)]
297    pub logs: Vec<String>,
298    pub units_consumed: u64,
299    #[serde(default)]
300    pub fee: Option<u64>,
301    /// Program return data (`{programId, data}`), if any.
302    #[serde(default)]
303    pub return_data: Option<serde_json::Value>,
304    /// Post-execution `UiAccount` JSON per `return_accounts` address, positional.
305    #[serde(default)]
306    pub accounts: Vec<Option<serde_json::Value>>,
307    /// Encoded transaction whose discovery-filter match triggered this action;
308    /// absent for slot-boundary actions. Decode with [`EncodedBinary::decode`]
309    /// (base64 bincode of `TxWithMeta`) to inspect the matching transaction.
310    #[serde(default)]
311    pub matched: Option<EncodedBinary>,
312}
313
314/// A routed account diff notification tied to the subscribed account that produced it.
315#[derive(Debug, Clone)]
316pub struct RoutedAccountDiffNotification {
317    pub account: String,
318    pub notification: AccountDiffNotification,
319}
320
321/// Handle for a running account diff subscription background task.
322///
323/// Send `true` on `stop` to request a clean shutdown, then await `join_handle`.
324pub struct AccountDiffSubscriptionHandle {
325    pub join_handle: SubscriptionTaskHandle,
326    pub stop: watch::Sender<bool>,
327}
328
329fn subscription_runtime_closed(
330    kind: &'static str,
331    target: impl Into<String>,
332) -> SubscriptionRuntimeError {
333    SubscriptionRuntimeError::Closed {
334        kind,
335        target: target.into(),
336    }
337}
338
339fn callback_worker_failed(
340    kind: &'static str,
341    target: impl Into<String>,
342    source: tokio::task::JoinError,
343) -> SubscriptionRuntimeError {
344    SubscriptionRuntimeError::CallbackWorker {
345        kind,
346        target: target.into(),
347        source,
348    }
349}
350
351/// Subscribe to account diff notifications and invoke a callback for each one.
352///
353/// Spawns a background task that:
354/// 1. Connects to the WebSocket endpoint derived from `rpc_endpoint`.
355/// 2. Subscribes to account diffs for the given filter (account or program).
356/// 3. For each notification, spawns `on_notification(notification)` as a Tokio task.
357/// 4. When `handle.stop.send(true)` is called, drains remaining buffered
358///    notifications (up to 1s), waits for all spawned tasks, then returns.
359///
360/// ## Example
361///
362/// ```no_run
363/// use simulator_client::subscribe_account_diffs;
364///
365/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
366/// let handle = subscribe_account_diffs(
367///     "http://localhost:8900/session/abc",
368///     "TokenkegQfeZyiNwAJbNbGKPFXCWuBvf9Ss623VQ5DA",
369///     |notification| async move {
370///         println!("slot={} sig={:?}", notification.context.slot, notification.signature);
371///     },
372/// )
373/// .await?;
374///
375/// handle.stop.send(true).ok();
376/// handle.join_handle.await.ok();
377/// # Ok(())
378/// # }
379/// ```
380pub async fn subscribe_account_diffs<F, Fut>(
381    rpc_endpoint: &str,
382    account: &str,
383    on_notification: F,
384) -> Result<AccountDiffSubscriptionHandle, SubscriptionError>
385where
386    F: Fn(AccountDiffNotification) -> Fut + Send + Sync + 'static,
387    Fut: Future<Output = ()> + Send + 'static,
388{
389    subscribe_account_diffs_many(rpc_endpoint, [account.to_string()], move |notification| {
390        on_notification(notification.notification)
391    })
392    .await
393}
394
395/// Subscribe to account diff notifications for many accounts over a single websocket.
396///
397/// All requested subscriptions must be acknowledged before this returns. Once the
398/// stream is live, any websocket disconnect is treated as a fatal completeness
399/// error instead of silently reconnecting and risking dropped notifications.
400pub async fn subscribe_account_diffs_many<F, Fut, I, S>(
401    rpc_endpoint: &str,
402    accounts: I,
403    on_notification: F,
404) -> Result<AccountDiffSubscriptionHandle, SubscriptionError>
405where
406    F: Fn(RoutedAccountDiffNotification) -> Fut + Send + Sync + 'static,
407    Fut: Future<Output = ()> + Send + 'static,
408    I: IntoIterator<Item = S>,
409    S: Into<String>,
410{
411    let ws_url = http_to_ws_url(rpc_endpoint)?;
412    let accounts = dedup_accounts(accounts);
413    if accounts.is_empty() {
414        let (stop_tx, stop_rx) = watch::channel(false);
415        return Ok(AccountDiffSubscriptionHandle {
416            join_handle: tokio::spawn(async move {
417                let _ = stop_rx;
418                Ok(())
419            }),
420            stop: stop_tx,
421        });
422    }
423
424    let (ready_tx, ready_rx) = oneshot::channel::<Result<(), SubscriptionError>>();
425    let (stop_tx, mut stop_rx) = watch::channel(false);
426    let target = format!("{} accounts", accounts.len());
427
428    let join_handle = tokio::spawn(async move {
429        let (notification_tx, mut notification_rx) = tokio::sync::mpsc::unbounded_channel();
430        let callback_handle = tokio::spawn(async move {
431            while let Some(notification) = notification_rx.recv().await {
432                on_notification(notification).await;
433            }
434        });
435
436        let (mut ws, _) = match tokio_tungstenite::connect_async(&ws_url).await {
437            Ok(connection) => connection,
438            Err(e) => {
439                let _ = ready_tx.send(Err(SubscriptionError::Connect {
440                    url: ws_url,
441                    source: Box::new(e),
442                }));
443                return Ok(());
444            }
445        };
446
447        let subscriptions =
448            match send_account_diff_subscribe_many(&mut ws, &accounts, &notification_tx).await {
449                Ok(subscriptions) => subscriptions,
450                Err(error) => {
451                    let _ = ready_tx.send(Err(error));
452                    return Ok(());
453                }
454            };
455
456        let _ = ready_tx.send(Ok(()));
457
458        if let Err(error) =
459            drive_account_diff_stream_many(&mut ws, &subscriptions, &notification_tx, &mut stop_rx)
460                .await
461        {
462            drop(notification_tx);
463            if let Err(source) = callback_handle.await {
464                return Err(callback_worker_failed("account diff", target, source));
465            }
466            return Err(error);
467        }
468
469        drop(notification_tx);
470        if let Err(source) = callback_handle.await {
471            return Err(callback_worker_failed("account diff", target, source));
472        }
473
474        Ok(())
475    });
476
477    match ready_rx.await {
478        Ok(Ok(())) => Ok(AccountDiffSubscriptionHandle {
479            join_handle,
480            stop: stop_tx,
481        }),
482        Ok(Err(e)) => {
483            join_handle.abort();
484            Err(e)
485        }
486        Err(_) => {
487            join_handle.abort();
488            Err(SubscriptionError::TaskDropped)
489        }
490    }
491}
492
493#[derive(Deserialize)]
494struct AccountDiffMessage {
495    method: String,
496    params: AccountDiffParams,
497}
498
499#[derive(Deserialize)]
500struct AccountDiffParams {
501    subscription: u64,
502    result: AccountDiffNotification,
503}
504
505async fn send_account_diff_subscribe_many(
506    ws: &mut AccountDiffWs,
507    accounts: &[String],
508    notification_tx: &tokio::sync::mpsc::UnboundedSender<RoutedAccountDiffNotification>,
509) -> Result<std::collections::HashMap<u64, String>, SubscriptionError> {
510    #[derive(Deserialize)]
511    struct SubscriptionConfirmation {
512        id: u64,
513        result: Option<u64>,
514    }
515
516    let mut pending: std::collections::HashMap<u64, String> = std::collections::HashMap::new();
517    let mut subscriptions = std::collections::HashMap::with_capacity(accounts.len());
518
519    for (index, account) in accounts.iter().enumerate() {
520        let request_id = (index + 1) as u64;
521        let req = serde_json::json!({
522            "jsonrpc": "2.0",
523            "id": request_id,
524            "method": "accountDiffSubscribe",
525            "params": [account]
526        });
527        ws.send(Message::Text(req.to_string()))
528            .await
529            .map_err(|source| SubscriptionError::Subscribe {
530                source: Box::new(source),
531            })?;
532        pending.insert(request_id, account.clone());
533    }
534
535    while !pending.is_empty() {
536        match ws.next().await {
537            Some(Ok(Message::Text(text))) => {
538                if let Ok(confirmation) = serde_json::from_str::<SubscriptionConfirmation>(&text) {
539                    let Some(account) = pending.remove(&confirmation.id) else {
540                        continue;
541                    };
542                    let Some(subscription_id) = confirmation.result else {
543                        return Err(SubscriptionError::TaskDropped);
544                    };
545                    subscriptions.insert(subscription_id, account);
546                    continue;
547                }
548
549                if let Some(notification) =
550                    parse_routed_account_diff_notification(&text, &subscriptions)
551                {
552                    let _ = notification_tx.send(notification);
553                }
554            }
555            Some(Ok(_)) => {}
556            _ => return Err(SubscriptionError::TaskDropped),
557        }
558    }
559
560    Ok(subscriptions)
561}
562
563async fn drive_account_diff_stream_many(
564    ws: &mut AccountDiffWs,
565    subscriptions: &std::collections::HashMap<u64, String>,
566    notification_tx: &tokio::sync::mpsc::UnboundedSender<RoutedAccountDiffNotification>,
567    stop_rx: &mut watch::Receiver<bool>,
568) -> Result<(), SubscriptionRuntimeError> {
569    loop {
570        if *stop_rx.borrow() {
571            let drain_deadline = tokio::time::Instant::now() + SUBSCRIPTION_DRAIN_MAX_DURATION;
572            loop {
573                match tokio::time::timeout_at(
574                    drain_deadline,
575                    tokio::time::timeout(SUBSCRIPTION_DRAIN_IDLE_TIMEOUT, ws.next()),
576                )
577                .await
578                {
579                    Ok(Ok(Some(Ok(Message::Text(text))))) => {
580                        if let Some(notification) =
581                            parse_routed_account_diff_notification(&text, subscriptions)
582                        {
583                            let _ = notification_tx.send(notification);
584                        }
585                    }
586                    _ => return Ok(()),
587                }
588            }
589        }
590
591        let msg = tokio::select! {
592            m = ws.next() => m,
593            _ = stop_rx.changed() => continue,
594        };
595
596        match msg {
597            Some(Ok(Message::Text(text))) => {
598                if let Some(notification) =
599                    parse_routed_account_diff_notification(&text, subscriptions)
600                {
601                    let _ = notification_tx.send(notification);
602                }
603            }
604            Some(Ok(_)) => {}
605            _ => {
606                return Err(subscription_runtime_closed(
607                    "account diff",
608                    format!("{} accounts", subscriptions.len()),
609                ));
610            }
611        }
612    }
613}
614
615fn parse_account_diff_message(text: &str) -> Option<AccountDiffMessage> {
616    let msg: AccountDiffMessage = serde_json::from_str(text).ok()?;
617    (msg.method == "accountDiffNotification").then_some(msg)
618}
619
620fn parse_routed_account_diff_notification(
621    text: &str,
622    subscriptions: &std::collections::HashMap<u64, String>,
623) -> Option<RoutedAccountDiffNotification> {
624    let msg = parse_account_diff_message(text)?;
625    let account = subscriptions.get(&msg.params.subscription)?.clone();
626    Some(RoutedAccountDiffNotification {
627        account,
628        notification: msg.params.result,
629    })
630}
631
/// Converts `accounts` to owned strings and drops repeats.
///
/// The first occurrence of each account is kept, and input order is preserved.
fn dedup_accounts<I, S>(accounts: I) -> Vec<String>
where
    I: IntoIterator<Item = S>,
    S: Into<String>,
{
    let mut seen = std::collections::HashSet::new();
    let mut ordered = Vec::new();
    for account in accounts {
        let account: String = account.into();
        if !seen.contains(&account) {
            seen.insert(account.clone());
            ordered.push(account);
        }
    }
    ordered
}
644
645// ── Program account diff subscription ────────────────────────────────────────
646
647/// Subscribe to account diff notifications for all accounts owned by a program.
648///
649/// Uses the server-side program filter (`{"address_type": "program"}`), so no
650/// RPC prefetch of program accounts is required.  The callback receives one
651/// [`AccountDiffNotification`] per changed account.
652///
653/// A websocket disconnect is treated as a fatal error — the handle's
654/// `join_handle` resolves with a [`SubscriptionRuntimeError`].
655///
656/// ## Example
657///
658/// ```no_run
659/// use simulator_client::subscribe_program_diffs;
660///
661/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
662/// let handle = subscribe_program_diffs(
663///     "http://localhost:8900/session/abc",
664///     "TokenkegQfeZyiNwAJbNbGKPFXCWuBvf9Ss623VQ5DA",
665///     |notification| async move {
666///         let account = notification.account.unwrap_or_default();
667///         println!("account={account} slot={}", notification.context.slot);
668///     },
669/// )
670/// .await?;
671///
672/// handle.stop.send(true).ok();
673/// handle.join_handle.await.ok();
674/// # Ok(())
675/// # }
676/// ```
677pub async fn subscribe_program_diffs<F, Fut>(
678    rpc_endpoint: &str,
679    program_id: &str,
680    on_notification: F,
681) -> Result<AccountDiffSubscriptionHandle, SubscriptionError>
682where
683    F: Fn(AccountDiffNotification) -> Fut + Send + Sync + 'static,
684    Fut: Future<Output = ()> + Send + 'static,
685{
686    let ws_url = http_to_ws_url(rpc_endpoint)?;
687    let program_id = program_id.to_string();
688
689    let (ready_tx, ready_rx) = oneshot::channel::<Result<(), SubscriptionError>>();
690    let (stop_tx, mut stop_rx) = watch::channel(false);
691
692    let join_handle = tokio::spawn(async move {
693        let (notification_tx, mut notification_rx) = tokio::sync::mpsc::unbounded_channel();
694        let callback_handle = tokio::spawn(async move {
695            while let Some(notification) = notification_rx.recv().await {
696                on_notification(notification).await;
697            }
698        });
699
700        let (mut ws, _) = match tokio_tungstenite::connect_async(&ws_url).await {
701            Ok(connection) => connection,
702            Err(e) => {
703                let _ = ready_tx.send(Err(SubscriptionError::Connect {
704                    url: ws_url,
705                    source: Box::new(e),
706                }));
707                return Ok(());
708            }
709        };
710
711        if let Err(error) = send_program_diff_subscribe(&mut ws, &program_id).await {
712            let _ = ready_tx.send(Err(error));
713            return Ok(());
714        }
715
716        let _ = ready_tx.send(Ok(()));
717
718        if let Err(error) =
719            drive_program_diff_stream(&mut ws, &notification_tx, &mut stop_rx, &program_id).await
720        {
721            drop(notification_tx);
722            if let Err(source) = callback_handle.await {
723                return Err(callback_worker_failed(
724                    "program account diff",
725                    &program_id,
726                    source,
727                ));
728            }
729            return Err(error);
730        }
731
732        drop(notification_tx);
733        if let Err(source) = callback_handle.await {
734            return Err(callback_worker_failed(
735                "program account diff",
736                &program_id,
737                source,
738            ));
739        }
740
741        Ok(())
742    });
743
744    match ready_rx.await {
745        Ok(Ok(())) => Ok(AccountDiffSubscriptionHandle {
746            join_handle,
747            stop: stop_tx,
748        }),
749        Ok(Err(e)) => {
750            join_handle.abort();
751            Err(e)
752        }
753        Err(_) => {
754            join_handle.abort();
755            Err(SubscriptionError::TaskDropped)
756        }
757    }
758}
759
760async fn send_program_diff_subscribe(
761    ws: &mut AccountDiffWs,
762    program_id: &str,
763) -> Result<(), SubscriptionError> {
764    #[derive(Deserialize)]
765    struct SubscriptionConfirmation {
766        result: Option<u64>,
767    }
768
769    let req = serde_json::json!({
770        "jsonrpc": "2.0",
771        "id": 1,
772        "method": "accountDiffSubscribe",
773        "params": [program_id, {"address_type": "program"}]
774    });
775    ws.send(Message::Text(req.to_string()))
776        .await
777        .map_err(|source| SubscriptionError::Subscribe {
778            source: Box::new(source),
779        })?;
780
781    loop {
782        match ws.next().await {
783            Some(Ok(Message::Text(text))) => {
784                match serde_json::from_str::<SubscriptionConfirmation>(&text) {
785                    Ok(SubscriptionConfirmation { result: Some(_) }) => return Ok(()),
786                    Ok(_) => continue,
787                    Err(source) => {
788                        return Err(SubscriptionError::Subscribe {
789                            source: Box::new(source),
790                        });
791                    }
792                }
793            }
794            Some(Ok(_)) => continue,
795            _ => return Err(SubscriptionError::TaskDropped),
796        }
797    }
798}
799
800async fn drive_program_diff_stream(
801    ws: &mut AccountDiffWs,
802    notification_tx: &tokio::sync::mpsc::UnboundedSender<AccountDiffNotification>,
803    stop_rx: &mut watch::Receiver<bool>,
804    program_id: &str,
805) -> Result<(), SubscriptionRuntimeError> {
806    loop {
807        if *stop_rx.borrow() {
808            let drain_deadline = tokio::time::Instant::now() + SUBSCRIPTION_DRAIN_MAX_DURATION;
809            loop {
810                match tokio::time::timeout_at(
811                    drain_deadline,
812                    tokio::time::timeout(SUBSCRIPTION_DRAIN_IDLE_TIMEOUT, ws.next()),
813                )
814                .await
815                {
816                    Ok(Ok(Some(Ok(Message::Text(text))))) => {
817                        if let Some(msg) = parse_account_diff_message(&text) {
818                            let _ = notification_tx.send(msg.params.result);
819                        }
820                    }
821                    _ => return Ok(()),
822                }
823            }
824        }
825
826        let msg = tokio::select! {
827            m = ws.next() => m,
828            _ = stop_rx.changed() => continue,
829        };
830
831        match msg {
832            Some(Ok(Message::Text(text))) => {
833                if let Some(msg) = parse_account_diff_message(&text) {
834                    let _ = notification_tx.send(msg.params.result);
835                }
836            }
837            Some(Ok(_)) => {}
838            _ => {
839                return Err(subscription_runtime_closed(
840                    "program account diff",
841                    program_id,
842                ));
843            }
844        }
845    }
846}
847
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_account_diff_notification_ignores_other_messages() {
        let confirmation = r#"{"jsonrpc":"2.0","result":1,"id":1}"#;
        assert!(parse_account_diff_message(confirmation).is_none());
    }

    #[test]
    fn parse_account_diff_notification_ignores_other_methods() {
        let text = r#"{
            "jsonrpc":"2.0",
            "method":"slotNotification",
            "params":{"subscription":1,"result":{"context":{"slot":1}}}
        }"#;
        assert!(parse_account_diff_message(text).is_none());
    }

    #[test]
    fn parse_account_diff_notification_extracts_payload() {
        let text = r#"{
            "jsonrpc":"2.0",
            "method":"accountDiffNotification",
            "params":{
                "subscription":7,
                "result":{
                    "context":{"slot":123},
                    "signature":"sig",
                    "pre":{"a":1},
                    "post":{"a":2}
                }
            }
        }"#;

        let notification = parse_account_diff_message(text)
            .expect("notification")
            .params
            .result;
        assert_eq!(notification.context.slot, 123);
        assert_eq!(notification.signature.as_deref(), Some("sig"));
        assert_eq!(notification.pre, Some(serde_json::json!({"a": 1})));
        assert_eq!(notification.post, Some(serde_json::json!({"a": 2})));
    }

    #[test]
    fn parse_account_diff_notification_defaults_missing_optional_fields() {
        let text = r#"{
            "jsonrpc":"2.0",
            "method":"accountDiffNotification",
            "params":{"subscription":1,"result":{"context":{"slot":1}}}
        }"#;

        let notification = parse_account_diff_message(text)
            .expect("notification")
            .params
            .result;
        assert_eq!(notification.context.slot, 1);
        assert!(notification.account.is_none());
        assert!(notification.signature.is_none());
        assert!(notification.tx_index.is_none());
        assert!(notification.block_time.is_none());
        assert!(notification.pre.is_none());
        assert!(notification.post.is_none());
    }

    #[test]
    fn parse_account_diff_notification_reads_index_and_block_time() {
        let text = r#"{
            "jsonrpc":"2.0",
            "method":"accountDiffNotification",
            "params":{
                "subscription":3,
                "result":{
                    "context":{"slot":9},
                    "account":"acct",
                    "tx_index":4,
                    "block_time":1700000000,
                    "pre":null,
                    "post":null
                }
            }
        }"#;

        let notification = parse_account_diff_message(text)
            .expect("notification")
            .params
            .result;
        assert_eq!(notification.account.as_deref(), Some("acct"));
        assert_eq!(notification.tx_index, Some(4));
        assert_eq!(notification.block_time, Some(1_700_000_000));
        assert!(notification.pre.is_none());
        assert!(notification.post.is_none());
    }

    #[test]
    fn parse_routed_account_diff_notification_extracts_subscription_account() {
        let text = r#"{
            "jsonrpc":"2.0",
            "method":"accountDiffNotification",
            "params":{
                "subscription":42,
                "result":{
                    "context":{"slot":456},
                    "signature":"sig",
                    "pre":null,
                    "post":{"a":2}
                }
            }
        }"#;
        let subscriptions = std::collections::HashMap::from([(42_u64, "acct".to_string())]);

        let notification =
            parse_routed_account_diff_notification(text, &subscriptions).expect("notification");
        assert_eq!(notification.account, "acct");
        assert_eq!(notification.notification.context.slot, 456);
    }

    #[test]
    fn parse_routed_account_diff_notification_skips_unknown_subscription() {
        let text = r#"{
            "jsonrpc":"2.0",
            "method":"accountDiffNotification",
            "params":{"subscription":99,"result":{"context":{"slot":1}}}
        }"#;
        let subscriptions = std::collections::HashMap::from([(42_u64, "acct".to_string())]);

        assert!(parse_routed_account_diff_notification(text, &subscriptions).is_none());
    }

    #[test]
    fn action_result_notification_uses_camel_case_and_defaults() {
        let text = r#"{
            "context":{"slot":5},
            "slot":5,
            "actionIndex":2,
            "committed":true,
            "unitsConsumed":1200
        }"#;

        let result: ActionResultNotification =
            serde_json::from_str(text).expect("action result");
        assert_eq!(result.context.slot, 5);
        assert_eq!(result.action_index, 2);
        assert_eq!(result.units_consumed, 1200);
        assert!(result.committed);
        assert!(result.batch_index.is_none());
        assert!(result.label.is_none());
        assert!(result.logs.is_empty());
        assert!(result.accounts.is_empty());
        assert!(result.matched.is_none());
    }

    #[test]
    fn dedup_accounts_preserves_first_seen_order() {
        let accounts = dedup_accounts([
            "b".to_string(),
            "a".to_string(),
            "b".to_string(),
            "c".to_string(),
        ]);
        assert_eq!(accounts, vec!["b", "a", "c"]);
    }

    #[test]
    fn dedup_accounts_handles_empty_input() {
        assert!(dedup_accounts(Vec::<String>::new()).is_empty());
    }
}