1use std::{future::Future, time::Duration};
2
3use futures::{SinkExt, StreamExt};
4use serde::Deserialize;
5use simulator_api::EncodedBinary;
6use solana_client::{
7 nonblocking::pubsub_client::PubsubClient,
8 rpc_response::{Response, RpcLogsResponse},
9};
10use solana_commitment_config::CommitmentConfig;
11use solana_rpc_client_api::config::{RpcTransactionLogsConfig, RpcTransactionLogsFilter};
12use thiserror::Error;
13use tokio::{
14 sync::{oneshot, watch},
15 task::JoinHandle,
16};
17use tokio_tungstenite::tungstenite::Message;
18
19use crate::urls::{UrlError, http_to_ws_url};
20
21#[derive(Debug, Error)]
23pub enum SubscriptionError {
24 #[error(transparent)]
25 InvalidUrl(#[from] UrlError),
26
27 #[error("pubsub connect to {url} failed: {source}")]
28 Connect {
29 url: String,
30 #[source]
31 source: Box<dyn std::error::Error + Send + Sync>,
32 },
33
34 #[error("logs_subscribe failed: {source}")]
35 Subscribe {
36 #[source]
37 source: Box<dyn std::error::Error + Send + Sync>,
38 },
39
40 #[error("subscription task exited unexpectedly before signaling ready")]
41 TaskDropped,
42
43 #[error("session has no rpc_endpoint (was the session created?)")]
44 NoRpcEndpoint,
45}
46
47#[derive(Debug, Error)]
48pub enum SubscriptionRuntimeError {
49 #[error("{kind} subscription for {target} closed unexpectedly")]
50 Closed { kind: &'static str, target: String },
51
52 #[error("{kind} subscription callback worker for {target} failed: {source}")]
53 CallbackWorker {
54 kind: &'static str,
55 target: String,
56 #[source]
57 source: tokio::task::JoinError,
58 },
59}
60
61const SUBSCRIPTION_DRAIN_IDLE_TIMEOUT: Duration = Duration::from_millis(250);
62const SUBSCRIPTION_DRAIN_MAX_DURATION: Duration = Duration::from_secs(5);
63
64type SubscriptionTaskHandle = JoinHandle<Result<(), SubscriptionRuntimeError>>;
65type AccountDiffWs =
66 tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>;
67
68pub struct SubscriptionHandle {
70 pub join_handle: SubscriptionTaskHandle,
71 pub stop: watch::Sender<bool>,
72}
73
74impl From<LogSubscriptionHandle> for SubscriptionHandle {
75 fn from(h: LogSubscriptionHandle) -> Self {
76 Self {
77 join_handle: h.join_handle,
78 stop: h.stop,
79 }
80 }
81}
82
83impl From<AccountDiffSubscriptionHandle> for SubscriptionHandle {
84 fn from(h: AccountDiffSubscriptionHandle) -> Self {
85 Self {
86 join_handle: h.join_handle,
87 stop: h.stop,
88 }
89 }
90}
91
92pub struct LogSubscriptionHandle {
94 pub join_handle: SubscriptionTaskHandle,
99
100 pub stop: watch::Sender<bool>,
103}
104
105pub async fn subscribe_program_logs<F, Fut>(
143 rpc_endpoint: &str,
144 program_id: &str,
145 commitment: CommitmentConfig,
146 on_notification: F,
147) -> Result<LogSubscriptionHandle, SubscriptionError>
148where
149 F: Fn(Response<RpcLogsResponse>) -> Fut + Send + Sync + 'static,
150 Fut: Future<Output = ()> + Send + 'static,
151{
152 let ws_url = http_to_ws_url(rpc_endpoint)?;
153 let program_id = program_id.to_string();
154
155 let (ready_tx, ready_rx) = oneshot::channel::<Result<(), SubscriptionError>>();
156 let (stop_tx, mut stop_rx) = watch::channel(false);
157
158 let join_handle = tokio::spawn(async move {
162 let client = match PubsubClient::new(&ws_url).await {
163 Ok(c) => c,
164 Err(e) => {
165 let _ = ready_tx.send(Err(SubscriptionError::Connect {
166 url: ws_url,
167 source: Box::new(e),
168 }));
169 return Ok(());
170 }
171 };
172
173 let (mut stream, _unsubscribe) = match client
174 .logs_subscribe(
175 RpcTransactionLogsFilter::Mentions(vec![program_id.clone()]),
176 RpcTransactionLogsConfig {
177 commitment: Some(commitment),
178 },
179 )
180 .await
181 {
182 Ok(s) => s,
183 Err(e) => {
184 let _ = ready_tx.send(Err(SubscriptionError::Subscribe {
185 source: Box::new(e),
186 }));
187 return Ok(());
188 }
189 };
190
191 let _ = ready_tx.send(Ok(()));
192
193 let mut tasks: Vec<JoinHandle<()>> = Vec::new();
194 let kind = "program logs";
195
196 loop {
197 if *stop_rx.borrow() {
198 let drain_deadline = tokio::time::Instant::now() + SUBSCRIPTION_DRAIN_MAX_DURATION;
199 while let Ok(Ok(Some(notification))) = tokio::time::timeout_at(
200 drain_deadline,
201 tokio::time::timeout(SUBSCRIPTION_DRAIN_IDLE_TIMEOUT, stream.next()),
202 )
203 .await
204 {
205 tasks.push(tokio::spawn(on_notification(notification)));
206 }
207 break;
208 }
209
210 let notification = tokio::select! {
211 n = stream.next() => n,
212 _ = stop_rx.changed() => continue,
213 };
214
215 match notification {
216 Some(n) => tasks.push(tokio::spawn(on_notification(n))),
217 None => return Err(subscription_runtime_closed(kind, &program_id)),
218 }
219 }
220
221 for task in tasks {
223 if let Err(source) = task.await {
224 return Err(callback_worker_failed(kind, &program_id, source));
225 }
226 }
227
228 Ok(())
229 });
230
231 match ready_rx.await {
232 Ok(Ok(())) => Ok(LogSubscriptionHandle {
233 join_handle,
234 stop: stop_tx,
235 }),
236 Ok(Err(e)) => {
237 join_handle.abort();
238 Err(e)
239 }
240 Err(_) => {
241 join_handle.abort();
242 Err(SubscriptionError::TaskDropped)
243 }
244 }
245}
246
247#[derive(Debug, Clone, Deserialize)]
251pub struct AccountDiffContext {
252 pub slot: u64,
253}
254
255#[derive(Debug, Clone, Deserialize)]
257pub struct AccountDiffNotification {
258 pub context: AccountDiffContext,
259 pub account: Option<String>,
261 pub signature: Option<String>,
263 #[serde(default)]
265 pub tx_index: Option<u32>,
266 #[serde(default)]
268 pub block_time: Option<i64>,
269 pub pre: Option<serde_json::Value>,
271 pub post: Option<serde_json::Value>,
273}
274
275#[derive(Debug, Clone, Deserialize)]
276pub struct ActionResultContext {
277 pub slot: u64,
278}
279
280#[derive(Debug, Clone, Deserialize)]
282#[serde(rename_all = "camelCase")]
283pub struct ActionResultNotification {
284 pub context: ActionResultContext,
285 pub slot: u64,
286 #[serde(default)]
288 pub batch_index: Option<u32>,
289 pub action_index: u32,
291 #[serde(default)]
292 pub label: Option<String>,
293 pub committed: bool,
294 #[serde(default)]
295 pub err: Option<String>,
296 #[serde(default)]
297 pub logs: Vec<String>,
298 pub units_consumed: u64,
299 #[serde(default)]
300 pub fee: Option<u64>,
301 #[serde(default)]
303 pub return_data: Option<serde_json::Value>,
304 #[serde(default)]
306 pub accounts: Vec<Option<serde_json::Value>>,
307 #[serde(default)]
311 pub matched: Option<EncodedBinary>,
312}
313
314#[derive(Debug, Clone)]
316pub struct RoutedAccountDiffNotification {
317 pub account: String,
318 pub notification: AccountDiffNotification,
319}
320
321pub struct AccountDiffSubscriptionHandle {
325 pub join_handle: SubscriptionTaskHandle,
326 pub stop: watch::Sender<bool>,
327}
328
329fn subscription_runtime_closed(
330 kind: &'static str,
331 target: impl Into<String>,
332) -> SubscriptionRuntimeError {
333 SubscriptionRuntimeError::Closed {
334 kind,
335 target: target.into(),
336 }
337}
338
339fn callback_worker_failed(
340 kind: &'static str,
341 target: impl Into<String>,
342 source: tokio::task::JoinError,
343) -> SubscriptionRuntimeError {
344 SubscriptionRuntimeError::CallbackWorker {
345 kind,
346 target: target.into(),
347 source,
348 }
349}
350
351pub async fn subscribe_account_diffs<F, Fut>(
381 rpc_endpoint: &str,
382 account: &str,
383 on_notification: F,
384) -> Result<AccountDiffSubscriptionHandle, SubscriptionError>
385where
386 F: Fn(AccountDiffNotification) -> Fut + Send + Sync + 'static,
387 Fut: Future<Output = ()> + Send + 'static,
388{
389 subscribe_account_diffs_many(rpc_endpoint, [account.to_string()], move |notification| {
390 on_notification(notification.notification)
391 })
392 .await
393}
394
395pub async fn subscribe_account_diffs_many<F, Fut, I, S>(
401 rpc_endpoint: &str,
402 accounts: I,
403 on_notification: F,
404) -> Result<AccountDiffSubscriptionHandle, SubscriptionError>
405where
406 F: Fn(RoutedAccountDiffNotification) -> Fut + Send + Sync + 'static,
407 Fut: Future<Output = ()> + Send + 'static,
408 I: IntoIterator<Item = S>,
409 S: Into<String>,
410{
411 let ws_url = http_to_ws_url(rpc_endpoint)?;
412 let accounts = dedup_accounts(accounts);
413 if accounts.is_empty() {
414 let (stop_tx, stop_rx) = watch::channel(false);
415 return Ok(AccountDiffSubscriptionHandle {
416 join_handle: tokio::spawn(async move {
417 let _ = stop_rx;
418 Ok(())
419 }),
420 stop: stop_tx,
421 });
422 }
423
424 let (ready_tx, ready_rx) = oneshot::channel::<Result<(), SubscriptionError>>();
425 let (stop_tx, mut stop_rx) = watch::channel(false);
426 let target = format!("{} accounts", accounts.len());
427
428 let join_handle = tokio::spawn(async move {
429 let (notification_tx, mut notification_rx) = tokio::sync::mpsc::unbounded_channel();
430 let callback_handle = tokio::spawn(async move {
431 while let Some(notification) = notification_rx.recv().await {
432 on_notification(notification).await;
433 }
434 });
435
436 let (mut ws, _) = match tokio_tungstenite::connect_async(&ws_url).await {
437 Ok(connection) => connection,
438 Err(e) => {
439 let _ = ready_tx.send(Err(SubscriptionError::Connect {
440 url: ws_url,
441 source: Box::new(e),
442 }));
443 return Ok(());
444 }
445 };
446
447 let subscriptions =
448 match send_account_diff_subscribe_many(&mut ws, &accounts, ¬ification_tx).await {
449 Ok(subscriptions) => subscriptions,
450 Err(error) => {
451 let _ = ready_tx.send(Err(error));
452 return Ok(());
453 }
454 };
455
456 let _ = ready_tx.send(Ok(()));
457
458 if let Err(error) =
459 drive_account_diff_stream_many(&mut ws, &subscriptions, ¬ification_tx, &mut stop_rx)
460 .await
461 {
462 drop(notification_tx);
463 if let Err(source) = callback_handle.await {
464 return Err(callback_worker_failed("account diff", target, source));
465 }
466 return Err(error);
467 }
468
469 drop(notification_tx);
470 if let Err(source) = callback_handle.await {
471 return Err(callback_worker_failed("account diff", target, source));
472 }
473
474 Ok(())
475 });
476
477 match ready_rx.await {
478 Ok(Ok(())) => Ok(AccountDiffSubscriptionHandle {
479 join_handle,
480 stop: stop_tx,
481 }),
482 Ok(Err(e)) => {
483 join_handle.abort();
484 Err(e)
485 }
486 Err(_) => {
487 join_handle.abort();
488 Err(SubscriptionError::TaskDropped)
489 }
490 }
491}
492
493#[derive(Deserialize)]
494struct AccountDiffMessage {
495 method: String,
496 params: AccountDiffParams,
497}
498
499#[derive(Deserialize)]
500struct AccountDiffParams {
501 subscription: u64,
502 result: AccountDiffNotification,
503}
504
505async fn send_account_diff_subscribe_many(
506 ws: &mut AccountDiffWs,
507 accounts: &[String],
508 notification_tx: &tokio::sync::mpsc::UnboundedSender<RoutedAccountDiffNotification>,
509) -> Result<std::collections::HashMap<u64, String>, SubscriptionError> {
510 #[derive(Deserialize)]
511 struct SubscriptionConfirmation {
512 id: u64,
513 result: Option<u64>,
514 }
515
516 let mut pending: std::collections::HashMap<u64, String> = std::collections::HashMap::new();
517 let mut subscriptions = std::collections::HashMap::with_capacity(accounts.len());
518
519 for (index, account) in accounts.iter().enumerate() {
520 let request_id = (index + 1) as u64;
521 let req = serde_json::json!({
522 "jsonrpc": "2.0",
523 "id": request_id,
524 "method": "accountDiffSubscribe",
525 "params": [account]
526 });
527 ws.send(Message::Text(req.to_string()))
528 .await
529 .map_err(|source| SubscriptionError::Subscribe {
530 source: Box::new(source),
531 })?;
532 pending.insert(request_id, account.clone());
533 }
534
535 while !pending.is_empty() {
536 match ws.next().await {
537 Some(Ok(Message::Text(text))) => {
538 if let Ok(confirmation) = serde_json::from_str::<SubscriptionConfirmation>(&text) {
539 let Some(account) = pending.remove(&confirmation.id) else {
540 continue;
541 };
542 let Some(subscription_id) = confirmation.result else {
543 return Err(SubscriptionError::TaskDropped);
544 };
545 subscriptions.insert(subscription_id, account);
546 continue;
547 }
548
549 if let Some(notification) =
550 parse_routed_account_diff_notification(&text, &subscriptions)
551 {
552 let _ = notification_tx.send(notification);
553 }
554 }
555 Some(Ok(_)) => {}
556 _ => return Err(SubscriptionError::TaskDropped),
557 }
558 }
559
560 Ok(subscriptions)
561}
562
563async fn drive_account_diff_stream_many(
564 ws: &mut AccountDiffWs,
565 subscriptions: &std::collections::HashMap<u64, String>,
566 notification_tx: &tokio::sync::mpsc::UnboundedSender<RoutedAccountDiffNotification>,
567 stop_rx: &mut watch::Receiver<bool>,
568) -> Result<(), SubscriptionRuntimeError> {
569 loop {
570 if *stop_rx.borrow() {
571 let drain_deadline = tokio::time::Instant::now() + SUBSCRIPTION_DRAIN_MAX_DURATION;
572 loop {
573 match tokio::time::timeout_at(
574 drain_deadline,
575 tokio::time::timeout(SUBSCRIPTION_DRAIN_IDLE_TIMEOUT, ws.next()),
576 )
577 .await
578 {
579 Ok(Ok(Some(Ok(Message::Text(text))))) => {
580 if let Some(notification) =
581 parse_routed_account_diff_notification(&text, subscriptions)
582 {
583 let _ = notification_tx.send(notification);
584 }
585 }
586 _ => return Ok(()),
587 }
588 }
589 }
590
591 let msg = tokio::select! {
592 m = ws.next() => m,
593 _ = stop_rx.changed() => continue,
594 };
595
596 match msg {
597 Some(Ok(Message::Text(text))) => {
598 if let Some(notification) =
599 parse_routed_account_diff_notification(&text, subscriptions)
600 {
601 let _ = notification_tx.send(notification);
602 }
603 }
604 Some(Ok(_)) => {}
605 _ => {
606 return Err(subscription_runtime_closed(
607 "account diff",
608 format!("{} accounts", subscriptions.len()),
609 ));
610 }
611 }
612 }
613}
614
615fn parse_account_diff_message(text: &str) -> Option<AccountDiffMessage> {
616 let msg: AccountDiffMessage = serde_json::from_str(text).ok()?;
617 (msg.method == "accountDiffNotification").then_some(msg)
618}
619
620fn parse_routed_account_diff_notification(
621 text: &str,
622 subscriptions: &std::collections::HashMap<u64, String>,
623) -> Option<RoutedAccountDiffNotification> {
624 let msg = parse_account_diff_message(text)?;
625 let account = subscriptions.get(&msg.params.subscription)?.clone();
626 Some(RoutedAccountDiffNotification {
627 account,
628 notification: msg.params.result,
629 })
630}
631
/// Converts every item to a `String` and drops repeats, keeping the first
/// occurrence of each value in its original position.
fn dedup_accounts<I, S>(accounts: I) -> Vec<String>
where
    I: IntoIterator<Item = S>,
    S: Into<String>,
{
    let mut seen = std::collections::BTreeSet::new();
    let mut ordered = Vec::new();
    for raw in accounts {
        let account: String = raw.into();
        // `insert` returns false when the value was already present.
        if seen.insert(account.clone()) {
            ordered.push(account);
        }
    }
    ordered
}
644
645pub async fn subscribe_program_diffs<F, Fut>(
678 rpc_endpoint: &str,
679 program_id: &str,
680 on_notification: F,
681) -> Result<AccountDiffSubscriptionHandle, SubscriptionError>
682where
683 F: Fn(AccountDiffNotification) -> Fut + Send + Sync + 'static,
684 Fut: Future<Output = ()> + Send + 'static,
685{
686 let ws_url = http_to_ws_url(rpc_endpoint)?;
687 let program_id = program_id.to_string();
688
689 let (ready_tx, ready_rx) = oneshot::channel::<Result<(), SubscriptionError>>();
690 let (stop_tx, mut stop_rx) = watch::channel(false);
691
692 let join_handle = tokio::spawn(async move {
693 let (notification_tx, mut notification_rx) = tokio::sync::mpsc::unbounded_channel();
694 let callback_handle = tokio::spawn(async move {
695 while let Some(notification) = notification_rx.recv().await {
696 on_notification(notification).await;
697 }
698 });
699
700 let (mut ws, _) = match tokio_tungstenite::connect_async(&ws_url).await {
701 Ok(connection) => connection,
702 Err(e) => {
703 let _ = ready_tx.send(Err(SubscriptionError::Connect {
704 url: ws_url,
705 source: Box::new(e),
706 }));
707 return Ok(());
708 }
709 };
710
711 if let Err(error) = send_program_diff_subscribe(&mut ws, &program_id).await {
712 let _ = ready_tx.send(Err(error));
713 return Ok(());
714 }
715
716 let _ = ready_tx.send(Ok(()));
717
718 if let Err(error) =
719 drive_program_diff_stream(&mut ws, ¬ification_tx, &mut stop_rx, &program_id).await
720 {
721 drop(notification_tx);
722 if let Err(source) = callback_handle.await {
723 return Err(callback_worker_failed(
724 "program account diff",
725 &program_id,
726 source,
727 ));
728 }
729 return Err(error);
730 }
731
732 drop(notification_tx);
733 if let Err(source) = callback_handle.await {
734 return Err(callback_worker_failed(
735 "program account diff",
736 &program_id,
737 source,
738 ));
739 }
740
741 Ok(())
742 });
743
744 match ready_rx.await {
745 Ok(Ok(())) => Ok(AccountDiffSubscriptionHandle {
746 join_handle,
747 stop: stop_tx,
748 }),
749 Ok(Err(e)) => {
750 join_handle.abort();
751 Err(e)
752 }
753 Err(_) => {
754 join_handle.abort();
755 Err(SubscriptionError::TaskDropped)
756 }
757 }
758}
759
760async fn send_program_diff_subscribe(
761 ws: &mut AccountDiffWs,
762 program_id: &str,
763) -> Result<(), SubscriptionError> {
764 #[derive(Deserialize)]
765 struct SubscriptionConfirmation {
766 result: Option<u64>,
767 }
768
769 let req = serde_json::json!({
770 "jsonrpc": "2.0",
771 "id": 1,
772 "method": "accountDiffSubscribe",
773 "params": [program_id, {"address_type": "program"}]
774 });
775 ws.send(Message::Text(req.to_string()))
776 .await
777 .map_err(|source| SubscriptionError::Subscribe {
778 source: Box::new(source),
779 })?;
780
781 loop {
782 match ws.next().await {
783 Some(Ok(Message::Text(text))) => {
784 match serde_json::from_str::<SubscriptionConfirmation>(&text) {
785 Ok(SubscriptionConfirmation { result: Some(_) }) => return Ok(()),
786 Ok(_) => continue,
787 Err(source) => {
788 return Err(SubscriptionError::Subscribe {
789 source: Box::new(source),
790 });
791 }
792 }
793 }
794 Some(Ok(_)) => continue,
795 _ => return Err(SubscriptionError::TaskDropped),
796 }
797 }
798}
799
800async fn drive_program_diff_stream(
801 ws: &mut AccountDiffWs,
802 notification_tx: &tokio::sync::mpsc::UnboundedSender<AccountDiffNotification>,
803 stop_rx: &mut watch::Receiver<bool>,
804 program_id: &str,
805) -> Result<(), SubscriptionRuntimeError> {
806 loop {
807 if *stop_rx.borrow() {
808 let drain_deadline = tokio::time::Instant::now() + SUBSCRIPTION_DRAIN_MAX_DURATION;
809 loop {
810 match tokio::time::timeout_at(
811 drain_deadline,
812 tokio::time::timeout(SUBSCRIPTION_DRAIN_IDLE_TIMEOUT, ws.next()),
813 )
814 .await
815 {
816 Ok(Ok(Some(Ok(Message::Text(text))))) => {
817 if let Some(msg) = parse_account_diff_message(&text) {
818 let _ = notification_tx.send(msg.params.result);
819 }
820 }
821 _ => return Ok(()),
822 }
823 }
824 }
825
826 let msg = tokio::select! {
827 m = ws.next() => m,
828 _ = stop_rx.changed() => continue,
829 };
830
831 match msg {
832 Some(Ok(Message::Text(text))) => {
833 if let Some(msg) = parse_account_diff_message(&text) {
834 let _ = notification_tx.send(msg.params.result);
835 }
836 }
837 Some(Ok(_)) => {}
838 _ => {
839 return Err(subscription_runtime_closed(
840 "program account diff",
841 program_id,
842 ));
843 }
844 }
845 }
846}
847
#[cfg(test)]
mod tests {
    use super::*;

    /// A subscribe confirmation has no `method`/`params`, so it must not be
    /// mistaken for a notification.
    #[test]
    fn parse_account_diff_notification_ignores_other_messages() {
        let parsed = parse_account_diff_message(r#"{"jsonrpc":"2.0","result":1,"id":1}"#);
        assert!(parsed.is_none());
    }

    /// A well-formed notification yields its payload with every field intact.
    #[test]
    fn parse_account_diff_notification_extracts_payload() {
        let raw = r#"{
            "jsonrpc":"2.0",
            "method":"accountDiffNotification",
            "params":{
                "subscription":7,
                "result":{
                    "context":{"slot":123},
                    "signature":"sig",
                    "pre":{"a":1},
                    "post":{"a":2}
                }
            }
        }"#;

        let message = parse_account_diff_message(raw).expect("notification");
        let payload = message.params.result;
        assert_eq!(payload.context.slot, 123);
        assert_eq!(payload.signature.as_deref(), Some("sig"));
        assert_eq!(payload.pre, Some(serde_json::json!({"a": 1})));
        assert_eq!(payload.post, Some(serde_json::json!({"a": 2})));
    }

    /// The subscription id in the message selects the account attached to the
    /// routed notification.
    #[test]
    fn parse_routed_account_diff_notification_extracts_subscription_account() {
        let raw = r#"{
            "jsonrpc":"2.0",
            "method":"accountDiffNotification",
            "params":{
                "subscription":42,
                "result":{
                    "context":{"slot":456},
                    "signature":"sig",
                    "pre":null,
                    "post":{"a":2}
                }
            }
        }"#;
        let mut subscriptions = std::collections::HashMap::new();
        subscriptions.insert(42_u64, "acct".to_string());

        let routed =
            parse_routed_account_diff_notification(raw, &subscriptions).expect("notification");
        assert_eq!(routed.account, "acct");
        assert_eq!(routed.notification.context.slot, 456);
    }

    /// Duplicates are removed without reordering the surviving entries.
    #[test]
    fn dedup_accounts_preserves_first_seen_order() {
        let input = ["b", "a", "b", "c"].map(String::from);
        let accounts = dedup_accounts(input);
        assert_eq!(accounts, vec!["b", "a", "c"]);
    }
}