1use std::{future::Future, time::Duration};
2
3use futures::{SinkExt, StreamExt};
4use serde::Deserialize;
5use solana_client::{
6 nonblocking::pubsub_client::PubsubClient,
7 rpc_response::{Response, RpcLogsResponse},
8};
9use solana_commitment_config::CommitmentConfig;
10use solana_rpc_client_api::config::{RpcTransactionLogsConfig, RpcTransactionLogsFilter};
11use thiserror::Error;
12use tokio::{
13 sync::{oneshot, watch},
14 task::JoinHandle,
15};
16use tokio_tungstenite::tungstenite::Message;
17
18use crate::urls::{UrlError, http_to_ws_url};
19
/// Errors produced while establishing a subscription, i.e. before one of the
/// `subscribe_*` functions in this module returns a handle.
#[derive(Debug, Error)]
pub enum SubscriptionError {
    /// The RPC endpoint could not be converted into a websocket URL.
    #[error(transparent)]
    InvalidUrl(#[from] UrlError),

    /// Opening the websocket / pubsub connection to `url` failed.
    #[error("pubsub connect to {url} failed: {source}")]
    Connect {
        url: String,
        #[source]
        source: Box<dyn std::error::Error + Send + Sync>,
    },

    /// Sending or confirming a subscription request failed.
    ///
    /// NOTE: the message says `logs_subscribe`, but this variant is also
    /// produced for `accountDiffSubscribe` requests in this module.
    #[error("logs_subscribe failed: {source}")]
    Subscribe {
        #[source]
        source: Box<dyn std::error::Error + Send + Sync>,
    },

    /// Reported when the background subscription task ends without signaling
    /// readiness.
    #[error("subscription task exited unexpectedly before signaling ready")]
    TaskDropped,

    /// No RPC endpoint was available. Not constructed in this module; per its
    /// message it presumably comes from session-based callers — confirm there.
    #[error("session has no rpc_endpoint (was the session created?)")]
    NoRpcEndpoint,
}
45
/// Errors that end a subscription after it was established. They are surfaced
/// through the handle's `join_handle`.
#[derive(Debug, Error)]
pub enum SubscriptionRuntimeError {
    /// The notification stream ended while no stop had been requested.
    /// `kind` names the subscription flavor; `target` is the program id or an
    /// account count.
    #[error("{kind} subscription for {target} closed unexpectedly")]
    Closed { kind: &'static str, target: String },

    /// A notification-callback task could not be joined (it panicked or was
    /// cancelled).
    #[error("{kind} subscription callback worker for {target} failed: {source}")]
    CallbackWorker {
        kind: &'static str,
        target: String,
        #[source]
        source: tokio::task::JoinError,
    },
}
59
/// After a stop request, draining ends once no message arrives within this
/// window.
const SUBSCRIPTION_DRAIN_IDLE_TIMEOUT: Duration = Duration::from_millis(250);
/// Upper bound on the total time spent draining after a stop request.
const SUBSCRIPTION_DRAIN_MAX_DURATION: Duration = Duration::from_secs(5);

/// Join handle of a subscription's background task; resolves to the task's
/// final outcome.
type SubscriptionTaskHandle = JoinHandle<Result<(), SubscriptionRuntimeError>>;
/// Raw websocket used for the `accountDiff*` JSON-RPC methods.
type AccountDiffWs =
    tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>;
66
67pub struct SubscriptionHandle {
69 pub join_handle: SubscriptionTaskHandle,
70 pub stop: watch::Sender<bool>,
71}
72
73impl From<LogSubscriptionHandle> for SubscriptionHandle {
74 fn from(h: LogSubscriptionHandle) -> Self {
75 Self {
76 join_handle: h.join_handle,
77 stop: h.stop,
78 }
79 }
80}
81
82impl From<AccountDiffSubscriptionHandle> for SubscriptionHandle {
83 fn from(h: AccountDiffSubscriptionHandle) -> Self {
84 Self {
85 join_handle: h.join_handle,
86 stop: h.stop,
87 }
88 }
89}
90
91pub struct LogSubscriptionHandle {
93 pub join_handle: SubscriptionTaskHandle,
98
99 pub stop: watch::Sender<bool>,
102}
103
104pub async fn subscribe_program_logs<F, Fut>(
142 rpc_endpoint: &str,
143 program_id: &str,
144 commitment: CommitmentConfig,
145 on_notification: F,
146) -> Result<LogSubscriptionHandle, SubscriptionError>
147where
148 F: Fn(Response<RpcLogsResponse>) -> Fut + Send + Sync + 'static,
149 Fut: Future<Output = ()> + Send + 'static,
150{
151 let ws_url = http_to_ws_url(rpc_endpoint)?;
152 let program_id = program_id.to_string();
153
154 let (ready_tx, ready_rx) = oneshot::channel::<Result<(), SubscriptionError>>();
155 let (stop_tx, mut stop_rx) = watch::channel(false);
156
157 let join_handle = tokio::spawn(async move {
161 let client = match PubsubClient::new(&ws_url).await {
162 Ok(c) => c,
163 Err(e) => {
164 let _ = ready_tx.send(Err(SubscriptionError::Connect {
165 url: ws_url,
166 source: Box::new(e),
167 }));
168 return Ok(());
169 }
170 };
171
172 let (mut stream, _unsubscribe) = match client
173 .logs_subscribe(
174 RpcTransactionLogsFilter::Mentions(vec![program_id.clone()]),
175 RpcTransactionLogsConfig {
176 commitment: Some(commitment),
177 },
178 )
179 .await
180 {
181 Ok(s) => s,
182 Err(e) => {
183 let _ = ready_tx.send(Err(SubscriptionError::Subscribe {
184 source: Box::new(e),
185 }));
186 return Ok(());
187 }
188 };
189
190 let _ = ready_tx.send(Ok(()));
191
192 let mut tasks: Vec<JoinHandle<()>> = Vec::new();
193 let kind = "program logs";
194
195 loop {
196 if *stop_rx.borrow() {
197 let drain_deadline = tokio::time::Instant::now() + SUBSCRIPTION_DRAIN_MAX_DURATION;
198 while let Ok(Ok(Some(notification))) = tokio::time::timeout_at(
199 drain_deadline,
200 tokio::time::timeout(SUBSCRIPTION_DRAIN_IDLE_TIMEOUT, stream.next()),
201 )
202 .await
203 {
204 tasks.push(tokio::spawn(on_notification(notification)));
205 }
206 break;
207 }
208
209 let notification = tokio::select! {
210 n = stream.next() => n,
211 _ = stop_rx.changed() => continue,
212 };
213
214 match notification {
215 Some(n) => tasks.push(tokio::spawn(on_notification(n))),
216 None => return Err(subscription_runtime_closed(kind, &program_id)),
217 }
218 }
219
220 for task in tasks {
222 if let Err(source) = task.await {
223 return Err(callback_worker_failed(kind, &program_id, source));
224 }
225 }
226
227 Ok(())
228 });
229
230 match ready_rx.await {
231 Ok(Ok(())) => Ok(LogSubscriptionHandle {
232 join_handle,
233 stop: stop_tx,
234 }),
235 Ok(Err(e)) => {
236 join_handle.abort();
237 Err(e)
238 }
239 Err(_) => {
240 join_handle.abort();
241 Err(SubscriptionError::TaskDropped)
242 }
243 }
244}
245
/// `context` object attached to an account-diff notification.
#[derive(Debug, Clone, Deserialize)]
pub struct AccountDiffContext {
    /// Slot reported by the server in the notification's `context`.
    pub slot: u64,
}

/// Payload (`params.result`) of an `accountDiffNotification` frame.
///
/// Field meanings below follow the wire names. The exact server-side semantics
/// are not visible here — NOTE(review): confirm against the RPC's
/// `accountDiffSubscribe` documentation.
#[derive(Debug, Clone, Deserialize)]
pub struct AccountDiffNotification {
    pub context: AccountDiffContext,
    /// Account address, when the server includes it.
    pub account: Option<String>,
    /// Transaction signature associated with the change, when present.
    pub signature: Option<String>,
    #[serde(default)]
    pub tx_index: Option<u32>,
    #[serde(default)]
    pub block_time: Option<i64>,
    /// Prior state as raw JSON; may be `null`.
    pub pre: Option<serde_json::Value>,
    /// New state as raw JSON.
    pub post: Option<serde_json::Value>,
}

/// An [`AccountDiffNotification`] tagged with the subscribed account it
/// belongs to. The account is resolved from the notification's subscription id.
#[derive(Debug, Clone)]
pub struct RoutedAccountDiffNotification {
    pub account: String,
    pub notification: AccountDiffNotification,
}

/// Handle returned by the account-diff subscribe functions in this module.
pub struct AccountDiffSubscriptionHandle {
    /// Background task driving the websocket; resolves to the final outcome.
    pub join_handle: SubscriptionTaskHandle,
    /// Send `true` to ask the background task to drain and stop.
    pub stop: watch::Sender<bool>,
}
288
289fn subscription_runtime_closed(
290 kind: &'static str,
291 target: impl Into<String>,
292) -> SubscriptionRuntimeError {
293 SubscriptionRuntimeError::Closed {
294 kind,
295 target: target.into(),
296 }
297}
298
299fn callback_worker_failed(
300 kind: &'static str,
301 target: impl Into<String>,
302 source: tokio::task::JoinError,
303) -> SubscriptionRuntimeError {
304 SubscriptionRuntimeError::CallbackWorker {
305 kind,
306 target: target.into(),
307 source,
308 }
309}
310
311pub async fn subscribe_account_diffs<F, Fut>(
341 rpc_endpoint: &str,
342 account: &str,
343 on_notification: F,
344) -> Result<AccountDiffSubscriptionHandle, SubscriptionError>
345where
346 F: Fn(AccountDiffNotification) -> Fut + Send + Sync + 'static,
347 Fut: Future<Output = ()> + Send + 'static,
348{
349 subscribe_account_diffs_many(rpc_endpoint, [account.to_string()], move |notification| {
350 on_notification(notification.notification)
351 })
352 .await
353}
354
355pub async fn subscribe_account_diffs_many<F, Fut, I, S>(
361 rpc_endpoint: &str,
362 accounts: I,
363 on_notification: F,
364) -> Result<AccountDiffSubscriptionHandle, SubscriptionError>
365where
366 F: Fn(RoutedAccountDiffNotification) -> Fut + Send + Sync + 'static,
367 Fut: Future<Output = ()> + Send + 'static,
368 I: IntoIterator<Item = S>,
369 S: Into<String>,
370{
371 let ws_url = http_to_ws_url(rpc_endpoint)?;
372 let accounts = dedup_accounts(accounts);
373 if accounts.is_empty() {
374 let (stop_tx, stop_rx) = watch::channel(false);
375 return Ok(AccountDiffSubscriptionHandle {
376 join_handle: tokio::spawn(async move {
377 let _ = stop_rx;
378 Ok(())
379 }),
380 stop: stop_tx,
381 });
382 }
383
384 let (ready_tx, ready_rx) = oneshot::channel::<Result<(), SubscriptionError>>();
385 let (stop_tx, mut stop_rx) = watch::channel(false);
386 let target = format!("{} accounts", accounts.len());
387
388 let join_handle = tokio::spawn(async move {
389 let (notification_tx, mut notification_rx) = tokio::sync::mpsc::unbounded_channel();
390 let callback_handle = tokio::spawn(async move {
391 while let Some(notification) = notification_rx.recv().await {
392 on_notification(notification).await;
393 }
394 });
395
396 let (mut ws, _) = match tokio_tungstenite::connect_async(&ws_url).await {
397 Ok(connection) => connection,
398 Err(e) => {
399 let _ = ready_tx.send(Err(SubscriptionError::Connect {
400 url: ws_url,
401 source: Box::new(e),
402 }));
403 return Ok(());
404 }
405 };
406
407 let subscriptions =
408 match send_account_diff_subscribe_many(&mut ws, &accounts, ¬ification_tx).await {
409 Ok(subscriptions) => subscriptions,
410 Err(error) => {
411 let _ = ready_tx.send(Err(error));
412 return Ok(());
413 }
414 };
415
416 let _ = ready_tx.send(Ok(()));
417
418 if let Err(error) =
419 drive_account_diff_stream_many(&mut ws, &subscriptions, ¬ification_tx, &mut stop_rx)
420 .await
421 {
422 drop(notification_tx);
423 if let Err(source) = callback_handle.await {
424 return Err(callback_worker_failed("account diff", target, source));
425 }
426 return Err(error);
427 }
428
429 drop(notification_tx);
430 if let Err(source) = callback_handle.await {
431 return Err(callback_worker_failed("account diff", target, source));
432 }
433
434 Ok(())
435 });
436
437 match ready_rx.await {
438 Ok(Ok(())) => Ok(AccountDiffSubscriptionHandle {
439 join_handle,
440 stop: stop_tx,
441 }),
442 Ok(Err(e)) => {
443 join_handle.abort();
444 Err(e)
445 }
446 Err(_) => {
447 join_handle.abort();
448 Err(SubscriptionError::TaskDropped)
449 }
450 }
451}
452
/// Wire shape of a notification frame. Example:
/// `{"jsonrpc":"2.0","method":"accountDiffNotification","params":{...}}`.
#[derive(Deserialize)]
struct AccountDiffMessage {
    /// JSON-RPC method name; `parse_account_diff_message` only accepts
    /// `accountDiffNotification`.
    method: String,
    params: AccountDiffParams,
}

/// `params` object of an account-diff notification.
#[derive(Deserialize)]
struct AccountDiffParams {
    /// Server-assigned subscription id, as returned in the subscribe
    /// confirmation's `result`.
    subscription: u64,
    result: AccountDiffNotification,
}
464
465async fn send_account_diff_subscribe_many(
466 ws: &mut AccountDiffWs,
467 accounts: &[String],
468 notification_tx: &tokio::sync::mpsc::UnboundedSender<RoutedAccountDiffNotification>,
469) -> Result<std::collections::HashMap<u64, String>, SubscriptionError> {
470 #[derive(Deserialize)]
471 struct SubscriptionConfirmation {
472 id: u64,
473 result: Option<u64>,
474 }
475
476 let mut pending: std::collections::HashMap<u64, String> = std::collections::HashMap::new();
477 let mut subscriptions = std::collections::HashMap::with_capacity(accounts.len());
478
479 for (index, account) in accounts.iter().enumerate() {
480 let request_id = (index + 1) as u64;
481 let req = serde_json::json!({
482 "jsonrpc": "2.0",
483 "id": request_id,
484 "method": "accountDiffSubscribe",
485 "params": [account]
486 });
487 ws.send(Message::Text(req.to_string()))
488 .await
489 .map_err(|source| SubscriptionError::Subscribe {
490 source: Box::new(source),
491 })?;
492 pending.insert(request_id, account.clone());
493 }
494
495 while !pending.is_empty() {
496 match ws.next().await {
497 Some(Ok(Message::Text(text))) => {
498 if let Ok(confirmation) = serde_json::from_str::<SubscriptionConfirmation>(&text) {
499 let Some(account) = pending.remove(&confirmation.id) else {
500 continue;
501 };
502 let Some(subscription_id) = confirmation.result else {
503 return Err(SubscriptionError::TaskDropped);
504 };
505 subscriptions.insert(subscription_id, account);
506 continue;
507 }
508
509 if let Some(notification) =
510 parse_routed_account_diff_notification(&text, &subscriptions)
511 {
512 let _ = notification_tx.send(notification);
513 }
514 }
515 Some(Ok(_)) => {}
516 _ => return Err(SubscriptionError::TaskDropped),
517 }
518 }
519
520 Ok(subscriptions)
521}
522
523async fn drive_account_diff_stream_many(
524 ws: &mut AccountDiffWs,
525 subscriptions: &std::collections::HashMap<u64, String>,
526 notification_tx: &tokio::sync::mpsc::UnboundedSender<RoutedAccountDiffNotification>,
527 stop_rx: &mut watch::Receiver<bool>,
528) -> Result<(), SubscriptionRuntimeError> {
529 loop {
530 if *stop_rx.borrow() {
531 let drain_deadline = tokio::time::Instant::now() + SUBSCRIPTION_DRAIN_MAX_DURATION;
532 loop {
533 match tokio::time::timeout_at(
534 drain_deadline,
535 tokio::time::timeout(SUBSCRIPTION_DRAIN_IDLE_TIMEOUT, ws.next()),
536 )
537 .await
538 {
539 Ok(Ok(Some(Ok(Message::Text(text))))) => {
540 if let Some(notification) =
541 parse_routed_account_diff_notification(&text, subscriptions)
542 {
543 let _ = notification_tx.send(notification);
544 }
545 }
546 _ => return Ok(()),
547 }
548 }
549 }
550
551 let msg = tokio::select! {
552 m = ws.next() => m,
553 _ = stop_rx.changed() => continue,
554 };
555
556 match msg {
557 Some(Ok(Message::Text(text))) => {
558 if let Some(notification) =
559 parse_routed_account_diff_notification(&text, subscriptions)
560 {
561 let _ = notification_tx.send(notification);
562 }
563 }
564 Some(Ok(_)) => {}
565 _ => {
566 return Err(subscription_runtime_closed(
567 "account diff",
568 format!("{} accounts", subscriptions.len()),
569 ));
570 }
571 }
572 }
573}
574
575fn parse_account_diff_message(text: &str) -> Option<AccountDiffMessage> {
576 let msg: AccountDiffMessage = serde_json::from_str(text).ok()?;
577 (msg.method == "accountDiffNotification").then_some(msg)
578}
579
580fn parse_routed_account_diff_notification(
581 text: &str,
582 subscriptions: &std::collections::HashMap<u64, String>,
583) -> Option<RoutedAccountDiffNotification> {
584 let msg = parse_account_diff_message(text)?;
585 let account = subscriptions.get(&msg.params.subscription)?.clone();
586 Some(RoutedAccountDiffNotification {
587 account,
588 notification: msg.params.result,
589 })
590}
591
/// Converts `accounts` into owned strings and removes repeats. The first
/// occurrence of each account keeps its original position.
fn dedup_accounts<I, S>(accounts: I) -> Vec<String>
where
    I: IntoIterator<Item = S>,
    S: Into<String>,
{
    let mut seen = std::collections::BTreeSet::new();
    let mut ordered = Vec::new();
    for item in accounts {
        let account: String = item.into();
        if seen.insert(account.clone()) {
            ordered.push(account);
        }
    }
    ordered
}
604
605pub async fn subscribe_program_diffs<F, Fut>(
638 rpc_endpoint: &str,
639 program_id: &str,
640 on_notification: F,
641) -> Result<AccountDiffSubscriptionHandle, SubscriptionError>
642where
643 F: Fn(AccountDiffNotification) -> Fut + Send + Sync + 'static,
644 Fut: Future<Output = ()> + Send + 'static,
645{
646 let ws_url = http_to_ws_url(rpc_endpoint)?;
647 let program_id = program_id.to_string();
648
649 let (ready_tx, ready_rx) = oneshot::channel::<Result<(), SubscriptionError>>();
650 let (stop_tx, mut stop_rx) = watch::channel(false);
651
652 let join_handle = tokio::spawn(async move {
653 let (notification_tx, mut notification_rx) = tokio::sync::mpsc::unbounded_channel();
654 let callback_handle = tokio::spawn(async move {
655 while let Some(notification) = notification_rx.recv().await {
656 on_notification(notification).await;
657 }
658 });
659
660 let (mut ws, _) = match tokio_tungstenite::connect_async(&ws_url).await {
661 Ok(connection) => connection,
662 Err(e) => {
663 let _ = ready_tx.send(Err(SubscriptionError::Connect {
664 url: ws_url,
665 source: Box::new(e),
666 }));
667 return Ok(());
668 }
669 };
670
671 if let Err(error) = send_program_diff_subscribe(&mut ws, &program_id).await {
672 let _ = ready_tx.send(Err(error));
673 return Ok(());
674 }
675
676 let _ = ready_tx.send(Ok(()));
677
678 if let Err(error) =
679 drive_program_diff_stream(&mut ws, ¬ification_tx, &mut stop_rx, &program_id).await
680 {
681 drop(notification_tx);
682 if let Err(source) = callback_handle.await {
683 return Err(callback_worker_failed(
684 "program account diff",
685 &program_id,
686 source,
687 ));
688 }
689 return Err(error);
690 }
691
692 drop(notification_tx);
693 if let Err(source) = callback_handle.await {
694 return Err(callback_worker_failed(
695 "program account diff",
696 &program_id,
697 source,
698 ));
699 }
700
701 Ok(())
702 });
703
704 match ready_rx.await {
705 Ok(Ok(())) => Ok(AccountDiffSubscriptionHandle {
706 join_handle,
707 stop: stop_tx,
708 }),
709 Ok(Err(e)) => {
710 join_handle.abort();
711 Err(e)
712 }
713 Err(_) => {
714 join_handle.abort();
715 Err(SubscriptionError::TaskDropped)
716 }
717 }
718}
719
720async fn send_program_diff_subscribe(
721 ws: &mut AccountDiffWs,
722 program_id: &str,
723) -> Result<(), SubscriptionError> {
724 #[derive(Deserialize)]
725 struct SubscriptionConfirmation {
726 result: Option<u64>,
727 }
728
729 let req = serde_json::json!({
730 "jsonrpc": "2.0",
731 "id": 1,
732 "method": "accountDiffSubscribe",
733 "params": [program_id, {"address_type": "program"}]
734 });
735 ws.send(Message::Text(req.to_string()))
736 .await
737 .map_err(|source| SubscriptionError::Subscribe {
738 source: Box::new(source),
739 })?;
740
741 loop {
742 match ws.next().await {
743 Some(Ok(Message::Text(text))) => {
744 match serde_json::from_str::<SubscriptionConfirmation>(&text) {
745 Ok(SubscriptionConfirmation { result: Some(_) }) => return Ok(()),
746 Ok(_) => continue,
747 Err(source) => {
748 return Err(SubscriptionError::Subscribe {
749 source: Box::new(source),
750 });
751 }
752 }
753 }
754 Some(Ok(_)) => continue,
755 _ => return Err(SubscriptionError::TaskDropped),
756 }
757 }
758}
759
760async fn drive_program_diff_stream(
761 ws: &mut AccountDiffWs,
762 notification_tx: &tokio::sync::mpsc::UnboundedSender<AccountDiffNotification>,
763 stop_rx: &mut watch::Receiver<bool>,
764 program_id: &str,
765) -> Result<(), SubscriptionRuntimeError> {
766 loop {
767 if *stop_rx.borrow() {
768 let drain_deadline = tokio::time::Instant::now() + SUBSCRIPTION_DRAIN_MAX_DURATION;
769 loop {
770 match tokio::time::timeout_at(
771 drain_deadline,
772 tokio::time::timeout(SUBSCRIPTION_DRAIN_IDLE_TIMEOUT, ws.next()),
773 )
774 .await
775 {
776 Ok(Ok(Some(Ok(Message::Text(text))))) => {
777 if let Some(msg) = parse_account_diff_message(&text) {
778 let _ = notification_tx.send(msg.params.result);
779 }
780 }
781 _ => return Ok(()),
782 }
783 }
784 }
785
786 let msg = tokio::select! {
787 m = ws.next() => m,
788 _ = stop_rx.changed() => continue,
789 };
790
791 match msg {
792 Some(Ok(Message::Text(text))) => {
793 if let Some(msg) = parse_account_diff_message(&text) {
794 let _ = notification_tx.send(msg.params.result);
795 }
796 }
797 Some(Ok(_)) => {}
798 _ => {
799 return Err(subscription_runtime_closed(
800 "program account diff",
801 program_id,
802 ));
803 }
804 }
805 }
806}
807
#[cfg(test)]
mod tests {
    //! Unit tests for the pure parsing and dedup helpers. The network-facing
    //! functions are not exercised here.
    use super::*;

    #[test]
    fn parse_account_diff_notification_ignores_other_messages() {
        let confirmation = r#"{"jsonrpc":"2.0","result":1,"id":1}"#;
        assert!(parse_account_diff_message(confirmation).is_none());
    }

    #[test]
    fn parse_account_diff_notification_ignores_other_methods() {
        let text = r#"{
            "jsonrpc":"2.0",
            "method":"somethingElse",
            "params":{"subscription":1,"result":{"context":{"slot":1}}}
        }"#;
        assert!(parse_account_diff_message(text).is_none());
    }

    #[test]
    fn parse_account_diff_notification_extracts_payload() {
        let text = r#"{
            "jsonrpc":"2.0",
            "method":"accountDiffNotification",
            "params":{
                "subscription":7,
                "result":{
                    "context":{"slot":123},
                    "signature":"sig",
                    "pre":{"a":1},
                    "post":{"a":2}
                }
            }
        }"#;

        let notification = parse_account_diff_message(text)
            .expect("notification")
            .params
            .result;
        assert_eq!(notification.context.slot, 123);
        assert_eq!(notification.signature.as_deref(), Some("sig"));
        assert_eq!(notification.pre, Some(serde_json::json!({"a": 1})));
        assert_eq!(notification.post, Some(serde_json::json!({"a": 2})));
    }

    #[test]
    fn parse_routed_account_diff_notification_extracts_subscription_account() {
        let text = r#"{
            "jsonrpc":"2.0",
            "method":"accountDiffNotification",
            "params":{
                "subscription":42,
                "result":{
                    "context":{"slot":456},
                    "signature":"sig",
                    "pre":null,
                    "post":{"a":2}
                }
            }
        }"#;
        let subscriptions = std::collections::HashMap::from([(42_u64, "acct".to_string())]);

        let notification =
            parse_routed_account_diff_notification(text, &subscriptions).expect("notification");
        assert_eq!(notification.account, "acct");
        assert_eq!(notification.notification.context.slot, 456);
    }

    #[test]
    fn parse_routed_account_diff_notification_ignores_unknown_subscription() {
        let text = r#"{
            "jsonrpc":"2.0",
            "method":"accountDiffNotification",
            "params":{"subscription":9,"result":{"context":{"slot":1}}}
        }"#;
        let subscriptions = std::collections::HashMap::from([(42_u64, "acct".to_string())]);
        assert!(parse_routed_account_diff_notification(text, &subscriptions).is_none());
    }

    #[test]
    fn dedup_accounts_preserves_first_seen_order() {
        let accounts = dedup_accounts([
            "b".to_string(),
            "a".to_string(),
            "b".to_string(),
            "c".to_string(),
        ]);
        assert_eq!(accounts, vec!["b", "a", "c"]);
    }

    #[test]
    fn dedup_accounts_handles_empty_input() {
        assert!(dedup_accounts(Vec::<String>::new()).is_empty());
    }
}