1use std::{future::Future, time::Duration};
2
3use futures::{SinkExt, StreamExt};
4use serde::Deserialize;
5use solana_client::{
6 nonblocking::pubsub_client::PubsubClient,
7 rpc_response::{Response, RpcLogsResponse},
8};
9use solana_commitment_config::CommitmentConfig;
10use solana_rpc_client_api::config::{RpcTransactionLogsConfig, RpcTransactionLogsFilter};
11use thiserror::Error;
12use tokio::{
13 sync::{oneshot, watch},
14 task::JoinHandle,
15};
16use tokio_tungstenite::tungstenite::Message;
17
18use crate::urls::{UrlError, http_to_ws_url};
19
/// Errors that can prevent a subscription from being established.
///
/// These are returned by the `subscribe_*` functions before a subscription handle is handed
/// out; failures after that point are reported as [`SubscriptionRuntimeError`] through the
/// handle's `join_handle`.
#[derive(Debug, Error)]
pub enum SubscriptionError {
    /// The RPC endpoint could not be converted into a websocket URL.
    #[error(transparent)]
    InvalidUrl(#[from] UrlError),

    /// Opening the websocket connection failed.
    #[error("pubsub connect to {url} failed: {source}")]
    Connect {
        /// Websocket URL that was dialed.
        url: String,
        /// Underlying connection error.
        #[source]
        source: Box<dyn std::error::Error + Send + Sync>,
    },

    /// The subscription request could not be sent or was not accepted.
    #[error("logs_subscribe failed: {source}")]
    Subscribe {
        /// Underlying cause.
        #[source]
        source: Box<dyn std::error::Error + Send + Sync>,
    },

    /// Readiness could not be confirmed, e.g. the background worker exited before it
    /// signaled that the subscription was ready.
    #[error("subscription task exited unexpectedly before signaling ready")]
    TaskDropped,

    /// No RPC endpoint was available to subscribe against.
    ///
    /// NOTE(review): not constructed in the code shown here; presumably raised by
    /// session-level callers — confirm.
    #[error("session has no rpc_endpoint (was the session created?)")]
    NoRpcEndpoint,
}
45
/// Errors reported by a running subscription worker through its `join_handle`.
#[derive(Debug, Error)]
pub enum SubscriptionRuntimeError {
    /// The notification stream ended before a stop was requested.
    #[error("{kind} subscription for {target} closed unexpectedly")]
    Closed {
        /// Human-readable subscription kind (e.g. `"program logs"`).
        kind: &'static str,
        /// What was subscribed to (a program id, or an `"N accounts"` summary).
        target: String,
    },

    /// A task running user callbacks could not be joined (it panicked or was cancelled).
    #[error("{kind} subscription callback worker for {target} failed: {source}")]
    CallbackWorker {
        /// Human-readable subscription kind.
        kind: &'static str,
        /// What was subscribed to.
        target: String,
        /// Join failure of the callback task.
        #[source]
        source: tokio::task::JoinError,
    },
}
59
/// After a stop request, draining ends once no message has arrived for this long.
const SUBSCRIPTION_DRAIN_IDLE_TIMEOUT: Duration = Duration::from_millis(250);
/// Upper bound on the total time spent draining after a stop request.
const SUBSCRIPTION_DRAIN_MAX_DURATION: Duration = Duration::from_secs(5);

/// Join handle of a subscription's background worker task.
type SubscriptionTaskHandle = JoinHandle<Result<(), SubscriptionRuntimeError>>;
/// Raw websocket (plain or TLS over TCP) used for the `accountDiffSubscribe` protocol.
type AccountDiffWs =
    tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>;
66
/// Kind-agnostic handle to a running subscription worker.
///
/// Both [`LogSubscriptionHandle`] and [`AccountDiffSubscriptionHandle`] convert into this type
/// via `From`, which lets callers treat either kind uniformly.
pub struct SubscriptionHandle {
    /// Resolves when the background worker exits; `Err` describes an abnormal end.
    pub join_handle: SubscriptionTaskHandle,
    /// Send `true` to ask the worker to drain in-flight notifications and stop.
    pub stop: watch::Sender<bool>,
}
72
73impl From<LogSubscriptionHandle> for SubscriptionHandle {
74 fn from(h: LogSubscriptionHandle) -> Self {
75 Self {
76 join_handle: h.join_handle,
77 stop: h.stop,
78 }
79 }
80}
81
82impl From<AccountDiffSubscriptionHandle> for SubscriptionHandle {
83 fn from(h: AccountDiffSubscriptionHandle) -> Self {
84 Self {
85 join_handle: h.join_handle,
86 stop: h.stop,
87 }
88 }
89}
90
/// Handle to a running program-logs subscription (see `subscribe_program_logs`).
pub struct LogSubscriptionHandle {
    /// Resolves when the background worker exits; `Err` describes an abnormal end.
    pub join_handle: SubscriptionTaskHandle,

    /// Send `true` to ask the worker to drain in-flight notifications, wait for outstanding
    /// callbacks, and stop.
    pub stop: watch::Sender<bool>,
}
103
104pub async fn subscribe_program_logs<F, Fut>(
142 rpc_endpoint: &str,
143 program_id: &str,
144 commitment: CommitmentConfig,
145 on_notification: F,
146) -> Result<LogSubscriptionHandle, SubscriptionError>
147where
148 F: Fn(Response<RpcLogsResponse>) -> Fut + Send + Sync + 'static,
149 Fut: Future<Output = ()> + Send + 'static,
150{
151 let ws_url = http_to_ws_url(rpc_endpoint)?;
152 let program_id = program_id.to_string();
153
154 let (ready_tx, ready_rx) = oneshot::channel::<Result<(), SubscriptionError>>();
155 let (stop_tx, mut stop_rx) = watch::channel(false);
156
157 let join_handle = tokio::spawn(async move {
161 let client = match PubsubClient::new(&ws_url).await {
162 Ok(c) => c,
163 Err(e) => {
164 let _ = ready_tx.send(Err(SubscriptionError::Connect {
165 url: ws_url,
166 source: Box::new(e),
167 }));
168 return Ok(());
169 }
170 };
171
172 let (mut stream, _unsubscribe) = match client
173 .logs_subscribe(
174 RpcTransactionLogsFilter::Mentions(vec![program_id.clone()]),
175 RpcTransactionLogsConfig {
176 commitment: Some(commitment),
177 },
178 )
179 .await
180 {
181 Ok(s) => s,
182 Err(e) => {
183 let _ = ready_tx.send(Err(SubscriptionError::Subscribe {
184 source: Box::new(e),
185 }));
186 return Ok(());
187 }
188 };
189
190 let _ = ready_tx.send(Ok(()));
191
192 let mut tasks: Vec<JoinHandle<()>> = Vec::new();
193 let kind = "program logs";
194
195 loop {
196 if *stop_rx.borrow() {
197 let drain_deadline = tokio::time::Instant::now() + SUBSCRIPTION_DRAIN_MAX_DURATION;
198 while let Ok(Ok(Some(notification))) = tokio::time::timeout_at(
199 drain_deadline,
200 tokio::time::timeout(SUBSCRIPTION_DRAIN_IDLE_TIMEOUT, stream.next()),
201 )
202 .await
203 {
204 tasks.push(tokio::spawn(on_notification(notification)));
205 }
206 break;
207 }
208
209 let notification = tokio::select! {
210 n = stream.next() => n,
211 _ = stop_rx.changed() => continue,
212 };
213
214 match notification {
215 Some(n) => tasks.push(tokio::spawn(on_notification(n))),
216 None => return Err(subscription_runtime_closed(kind, &program_id)),
217 }
218 }
219
220 for task in tasks {
222 if let Err(source) = task.await {
223 return Err(callback_worker_failed(kind, &program_id, source));
224 }
225 }
226
227 Ok(())
228 });
229
230 match ready_rx.await {
231 Ok(Ok(())) => Ok(LogSubscriptionHandle {
232 join_handle,
233 stop: stop_tx,
234 }),
235 Ok(Err(e)) => {
236 join_handle.abort();
237 Err(e)
238 }
239 Err(_) => {
240 join_handle.abort();
241 Err(SubscriptionError::TaskDropped)
242 }
243 }
244}
245
/// `context` object attached to an `accountDiffNotification` payload.
#[derive(Debug, Clone, Deserialize)]
pub struct AccountDiffContext {
    /// Slot reported by the server for this notification.
    pub slot: u64,
}
253
/// Payload (`params.result`) of an `accountDiffNotification` message.
///
/// Every field except `context` is optional and is `None` when absent from the JSON.
#[derive(Debug, Clone, Deserialize)]
pub struct AccountDiffNotification {
    /// Server context, including the slot.
    pub context: AccountDiffContext,
    /// Account address, when the server includes one.
    pub account: Option<String>,
    /// Transaction signature, when the server includes one.
    pub signature: Option<String>,
    /// Opaque JSON `pre` value; presumably the state before the change — shape not defined here.
    pub pre: Option<serde_json::Value>,
    /// Opaque JSON `post` value; presumably the state after the change — shape not defined here.
    pub post: Option<serde_json::Value>,
}
267
/// An account-diff notification tagged with the account whose subscription produced it.
#[derive(Debug, Clone)]
pub struct RoutedAccountDiffNotification {
    /// Account address looked up from the notification's subscription id.
    pub account: String,
    /// The notification payload itself.
    pub notification: AccountDiffNotification,
}
274
/// Handle to a running account-diff subscription worker.
pub struct AccountDiffSubscriptionHandle {
    /// Resolves when the background worker exits; `Err` describes an abnormal end.
    pub join_handle: SubscriptionTaskHandle,
    /// Send `true` to ask the worker to drain in-flight messages and stop.
    pub stop: watch::Sender<bool>,
}
282
283fn subscription_runtime_closed(
284 kind: &'static str,
285 target: impl Into<String>,
286) -> SubscriptionRuntimeError {
287 SubscriptionRuntimeError::Closed {
288 kind,
289 target: target.into(),
290 }
291}
292
293fn callback_worker_failed(
294 kind: &'static str,
295 target: impl Into<String>,
296 source: tokio::task::JoinError,
297) -> SubscriptionRuntimeError {
298 SubscriptionRuntimeError::CallbackWorker {
299 kind,
300 target: target.into(),
301 source,
302 }
303}
304
305pub async fn subscribe_account_diffs<F, Fut>(
335 rpc_endpoint: &str,
336 account: &str,
337 on_notification: F,
338) -> Result<AccountDiffSubscriptionHandle, SubscriptionError>
339where
340 F: Fn(AccountDiffNotification) -> Fut + Send + Sync + 'static,
341 Fut: Future<Output = ()> + Send + 'static,
342{
343 subscribe_account_diffs_many(rpc_endpoint, [account.to_string()], move |notification| {
344 on_notification(notification.notification)
345 })
346 .await
347}
348
349pub async fn subscribe_account_diffs_many<F, Fut, I, S>(
355 rpc_endpoint: &str,
356 accounts: I,
357 on_notification: F,
358) -> Result<AccountDiffSubscriptionHandle, SubscriptionError>
359where
360 F: Fn(RoutedAccountDiffNotification) -> Fut + Send + Sync + 'static,
361 Fut: Future<Output = ()> + Send + 'static,
362 I: IntoIterator<Item = S>,
363 S: Into<String>,
364{
365 let ws_url = http_to_ws_url(rpc_endpoint)?;
366 let accounts = dedup_accounts(accounts);
367 if accounts.is_empty() {
368 let (stop_tx, stop_rx) = watch::channel(false);
369 return Ok(AccountDiffSubscriptionHandle {
370 join_handle: tokio::spawn(async move {
371 let _ = stop_rx;
372 Ok(())
373 }),
374 stop: stop_tx,
375 });
376 }
377
378 let (ready_tx, ready_rx) = oneshot::channel::<Result<(), SubscriptionError>>();
379 let (stop_tx, mut stop_rx) = watch::channel(false);
380 let target = format!("{} accounts", accounts.len());
381
382 let join_handle = tokio::spawn(async move {
383 let (notification_tx, mut notification_rx) = tokio::sync::mpsc::unbounded_channel();
384 let callback_handle = tokio::spawn(async move {
385 while let Some(notification) = notification_rx.recv().await {
386 on_notification(notification).await;
387 }
388 });
389
390 let (mut ws, _) = match tokio_tungstenite::connect_async(&ws_url).await {
391 Ok(connection) => connection,
392 Err(e) => {
393 let _ = ready_tx.send(Err(SubscriptionError::Connect {
394 url: ws_url,
395 source: Box::new(e),
396 }));
397 return Ok(());
398 }
399 };
400
401 let subscriptions =
402 match send_account_diff_subscribe_many(&mut ws, &accounts, ¬ification_tx).await {
403 Ok(subscriptions) => subscriptions,
404 Err(error) => {
405 let _ = ready_tx.send(Err(error));
406 return Ok(());
407 }
408 };
409
410 let _ = ready_tx.send(Ok(()));
411
412 if let Err(error) =
413 drive_account_diff_stream_many(&mut ws, &subscriptions, ¬ification_tx, &mut stop_rx)
414 .await
415 {
416 drop(notification_tx);
417 if let Err(source) = callback_handle.await {
418 return Err(callback_worker_failed("account diff", target, source));
419 }
420 return Err(error);
421 }
422
423 drop(notification_tx);
424 if let Err(source) = callback_handle.await {
425 return Err(callback_worker_failed("account diff", target, source));
426 }
427
428 Ok(())
429 });
430
431 match ready_rx.await {
432 Ok(Ok(())) => Ok(AccountDiffSubscriptionHandle {
433 join_handle,
434 stop: stop_tx,
435 }),
436 Ok(Err(e)) => {
437 join_handle.abort();
438 Err(e)
439 }
440 Err(_) => {
441 join_handle.abort();
442 Err(SubscriptionError::TaskDropped)
443 }
444 }
445}
446
/// Envelope of an incoming JSON-RPC notification frame on the account-diff websocket.
#[derive(Deserialize)]
struct AccountDiffMessage {
    /// JSON-RPC method name; only `"accountDiffNotification"` is accepted downstream.
    method: String,
    /// Notification parameters.
    params: AccountDiffParams,
}
452
/// `params` object of an `accountDiffNotification` frame.
#[derive(Deserialize)]
struct AccountDiffParams {
    /// Server-assigned subscription id, used to route the notification to its account.
    subscription: u64,
    /// The notification payload.
    result: AccountDiffNotification,
}
458
459async fn send_account_diff_subscribe_many(
460 ws: &mut AccountDiffWs,
461 accounts: &[String],
462 notification_tx: &tokio::sync::mpsc::UnboundedSender<RoutedAccountDiffNotification>,
463) -> Result<std::collections::HashMap<u64, String>, SubscriptionError> {
464 #[derive(Deserialize)]
465 struct SubscriptionConfirmation {
466 id: u64,
467 result: Option<u64>,
468 }
469
470 let mut pending: std::collections::HashMap<u64, String> = std::collections::HashMap::new();
471 let mut subscriptions = std::collections::HashMap::with_capacity(accounts.len());
472
473 for (index, account) in accounts.iter().enumerate() {
474 let request_id = (index + 1) as u64;
475 let req = serde_json::json!({
476 "jsonrpc": "2.0",
477 "id": request_id,
478 "method": "accountDiffSubscribe",
479 "params": [account]
480 });
481 ws.send(Message::Text(req.to_string()))
482 .await
483 .map_err(|source| SubscriptionError::Subscribe {
484 source: Box::new(source),
485 })?;
486 pending.insert(request_id, account.clone());
487 }
488
489 while !pending.is_empty() {
490 match ws.next().await {
491 Some(Ok(Message::Text(text))) => {
492 if let Ok(confirmation) = serde_json::from_str::<SubscriptionConfirmation>(&text) {
493 let Some(account) = pending.remove(&confirmation.id) else {
494 continue;
495 };
496 let Some(subscription_id) = confirmation.result else {
497 return Err(SubscriptionError::TaskDropped);
498 };
499 subscriptions.insert(subscription_id, account);
500 continue;
501 }
502
503 if let Some(notification) =
504 parse_routed_account_diff_notification(&text, &subscriptions)
505 {
506 let _ = notification_tx.send(notification);
507 }
508 }
509 Some(Ok(_)) => {}
510 _ => return Err(SubscriptionError::TaskDropped),
511 }
512 }
513
514 Ok(subscriptions)
515}
516
517async fn drive_account_diff_stream_many(
518 ws: &mut AccountDiffWs,
519 subscriptions: &std::collections::HashMap<u64, String>,
520 notification_tx: &tokio::sync::mpsc::UnboundedSender<RoutedAccountDiffNotification>,
521 stop_rx: &mut watch::Receiver<bool>,
522) -> Result<(), SubscriptionRuntimeError> {
523 loop {
524 if *stop_rx.borrow() {
525 let drain_deadline = tokio::time::Instant::now() + SUBSCRIPTION_DRAIN_MAX_DURATION;
526 loop {
527 match tokio::time::timeout_at(
528 drain_deadline,
529 tokio::time::timeout(SUBSCRIPTION_DRAIN_IDLE_TIMEOUT, ws.next()),
530 )
531 .await
532 {
533 Ok(Ok(Some(Ok(Message::Text(text))))) => {
534 if let Some(notification) =
535 parse_routed_account_diff_notification(&text, subscriptions)
536 {
537 let _ = notification_tx.send(notification);
538 }
539 }
540 _ => return Ok(()),
541 }
542 }
543 }
544
545 let msg = tokio::select! {
546 m = ws.next() => m,
547 _ = stop_rx.changed() => continue,
548 };
549
550 match msg {
551 Some(Ok(Message::Text(text))) => {
552 if let Some(notification) =
553 parse_routed_account_diff_notification(&text, subscriptions)
554 {
555 let _ = notification_tx.send(notification);
556 }
557 }
558 Some(Ok(_)) => {}
559 _ => {
560 return Err(subscription_runtime_closed(
561 "account diff",
562 format!("{} accounts", subscriptions.len()),
563 ));
564 }
565 }
566 }
567}
568
569fn parse_account_diff_message(text: &str) -> Option<AccountDiffMessage> {
570 let msg: AccountDiffMessage = serde_json::from_str(text).ok()?;
571 (msg.method == "accountDiffNotification").then_some(msg)
572}
573
574fn parse_routed_account_diff_notification(
575 text: &str,
576 subscriptions: &std::collections::HashMap<u64, String>,
577) -> Option<RoutedAccountDiffNotification> {
578 let msg = parse_account_diff_message(text)?;
579 let account = subscriptions.get(&msg.params.subscription)?.clone();
580 Some(RoutedAccountDiffNotification {
581 account,
582 notification: msg.params.result,
583 })
584}
585
/// Converts `accounts` into owned strings and drops repeats, keeping first-seen order.
fn dedup_accounts<I, S>(accounts: I) -> Vec<String>
where
    I: IntoIterator<Item = S>,
    S: Into<String>,
{
    let mut seen = std::collections::HashSet::new();
    let mut ordered = Vec::new();
    for account in accounts {
        let account: String = account.into();
        if seen.insert(account.clone()) {
            ordered.push(account);
        }
    }
    ordered
}
598
599pub async fn subscribe_program_diffs<F, Fut>(
632 rpc_endpoint: &str,
633 program_id: &str,
634 on_notification: F,
635) -> Result<AccountDiffSubscriptionHandle, SubscriptionError>
636where
637 F: Fn(AccountDiffNotification) -> Fut + Send + Sync + 'static,
638 Fut: Future<Output = ()> + Send + 'static,
639{
640 let ws_url = http_to_ws_url(rpc_endpoint)?;
641 let program_id = program_id.to_string();
642
643 let (ready_tx, ready_rx) = oneshot::channel::<Result<(), SubscriptionError>>();
644 let (stop_tx, mut stop_rx) = watch::channel(false);
645
646 let join_handle = tokio::spawn(async move {
647 let (notification_tx, mut notification_rx) = tokio::sync::mpsc::unbounded_channel();
648 let callback_handle = tokio::spawn(async move {
649 while let Some(notification) = notification_rx.recv().await {
650 on_notification(notification).await;
651 }
652 });
653
654 let (mut ws, _) = match tokio_tungstenite::connect_async(&ws_url).await {
655 Ok(connection) => connection,
656 Err(e) => {
657 let _ = ready_tx.send(Err(SubscriptionError::Connect {
658 url: ws_url,
659 source: Box::new(e),
660 }));
661 return Ok(());
662 }
663 };
664
665 if let Err(error) = send_program_diff_subscribe(&mut ws, &program_id).await {
666 let _ = ready_tx.send(Err(error));
667 return Ok(());
668 }
669
670 let _ = ready_tx.send(Ok(()));
671
672 if let Err(error) =
673 drive_program_diff_stream(&mut ws, ¬ification_tx, &mut stop_rx, &program_id).await
674 {
675 drop(notification_tx);
676 if let Err(source) = callback_handle.await {
677 return Err(callback_worker_failed(
678 "program account diff",
679 &program_id,
680 source,
681 ));
682 }
683 return Err(error);
684 }
685
686 drop(notification_tx);
687 if let Err(source) = callback_handle.await {
688 return Err(callback_worker_failed(
689 "program account diff",
690 &program_id,
691 source,
692 ));
693 }
694
695 Ok(())
696 });
697
698 match ready_rx.await {
699 Ok(Ok(())) => Ok(AccountDiffSubscriptionHandle {
700 join_handle,
701 stop: stop_tx,
702 }),
703 Ok(Err(e)) => {
704 join_handle.abort();
705 Err(e)
706 }
707 Err(_) => {
708 join_handle.abort();
709 Err(SubscriptionError::TaskDropped)
710 }
711 }
712}
713
714async fn send_program_diff_subscribe(
715 ws: &mut AccountDiffWs,
716 program_id: &str,
717) -> Result<(), SubscriptionError> {
718 #[derive(Deserialize)]
719 struct SubscriptionConfirmation {
720 result: Option<u64>,
721 }
722
723 let req = serde_json::json!({
724 "jsonrpc": "2.0",
725 "id": 1,
726 "method": "accountDiffSubscribe",
727 "params": [program_id, {"address_type": "program"}]
728 });
729 ws.send(Message::Text(req.to_string()))
730 .await
731 .map_err(|source| SubscriptionError::Subscribe {
732 source: Box::new(source),
733 })?;
734
735 loop {
736 match ws.next().await {
737 Some(Ok(Message::Text(text))) => {
738 match serde_json::from_str::<SubscriptionConfirmation>(&text) {
739 Ok(SubscriptionConfirmation { result: Some(_) }) => return Ok(()),
740 Ok(_) => continue,
741 Err(source) => {
742 return Err(SubscriptionError::Subscribe {
743 source: Box::new(source),
744 });
745 }
746 }
747 }
748 Some(Ok(_)) => continue,
749 _ => return Err(SubscriptionError::TaskDropped),
750 }
751 }
752}
753
754async fn drive_program_diff_stream(
755 ws: &mut AccountDiffWs,
756 notification_tx: &tokio::sync::mpsc::UnboundedSender<AccountDiffNotification>,
757 stop_rx: &mut watch::Receiver<bool>,
758 program_id: &str,
759) -> Result<(), SubscriptionRuntimeError> {
760 loop {
761 if *stop_rx.borrow() {
762 let drain_deadline = tokio::time::Instant::now() + SUBSCRIPTION_DRAIN_MAX_DURATION;
763 loop {
764 match tokio::time::timeout_at(
765 drain_deadline,
766 tokio::time::timeout(SUBSCRIPTION_DRAIN_IDLE_TIMEOUT, ws.next()),
767 )
768 .await
769 {
770 Ok(Ok(Some(Ok(Message::Text(text))))) => {
771 if let Some(msg) = parse_account_diff_message(&text) {
772 let _ = notification_tx.send(msg.params.result);
773 }
774 }
775 _ => return Ok(()),
776 }
777 }
778 }
779
780 let msg = tokio::select! {
781 m = ws.next() => m,
782 _ = stop_rx.changed() => continue,
783 };
784
785 match msg {
786 Some(Ok(Message::Text(text))) => {
787 if let Some(msg) = parse_account_diff_message(&text) {
788 let _ = notification_tx.send(msg.params.result);
789 }
790 }
791 Some(Ok(_)) => {}
792 _ => {
793 return Err(subscription_runtime_closed(
794 "program account diff",
795 program_id,
796 ));
797 }
798 }
799 }
800}
801
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_account_diff_notification_ignores_other_messages() {
        let confirmation = r#"{"jsonrpc":"2.0","result":1,"id":1}"#;
        assert!(parse_account_diff_message(confirmation).is_none());
    }

    #[test]
    fn parse_account_diff_message_rejects_other_methods() {
        // Well-formed params, but the wrong notification method.
        let text = r#"{
            "jsonrpc":"2.0",
            "method":"slotNotification",
            "params":{
                "subscription":7,
                "result":{"context":{"slot":1}}
            }
        }"#;
        assert!(parse_account_diff_message(text).is_none());
    }

    #[test]
    fn parse_account_diff_notification_extracts_payload() {
        let text = r#"{
            "jsonrpc":"2.0",
            "method":"accountDiffNotification",
            "params":{
                "subscription":7,
                "result":{
                    "context":{"slot":123},
                    "signature":"sig",
                    "pre":{"a":1},
                    "post":{"a":2}
                }
            }
        }"#;

        let notification = parse_account_diff_message(text)
            .expect("notification")
            .params
            .result;
        assert_eq!(notification.context.slot, 123);
        assert_eq!(notification.signature.as_deref(), Some("sig"));
        assert_eq!(notification.pre, Some(serde_json::json!({"a": 1})));
        assert_eq!(notification.post, Some(serde_json::json!({"a": 2})));
    }

    #[test]
    fn parse_routed_account_diff_notification_extracts_subscription_account() {
        let text = r#"{
            "jsonrpc":"2.0",
            "method":"accountDiffNotification",
            "params":{
                "subscription":42,
                "result":{
                    "context":{"slot":456},
                    "signature":"sig",
                    "pre":null,
                    "post":{"a":2}
                }
            }
        }"#;
        let subscriptions = std::collections::HashMap::from([(42_u64, "acct".to_string())]);

        let notification =
            parse_routed_account_diff_notification(text, &subscriptions).expect("notification");
        assert_eq!(notification.account, "acct");
        assert_eq!(notification.notification.context.slot, 456);
    }

    #[test]
    fn parse_routed_account_diff_notification_skips_unknown_subscription() {
        let text = r#"{
            "jsonrpc":"2.0",
            "method":"accountDiffNotification",
            "params":{
                "subscription":9,
                "result":{"context":{"slot":1}}
            }
        }"#;
        let subscriptions = std::collections::HashMap::from([(42_u64, "acct".to_string())]);

        assert!(parse_routed_account_diff_notification(text, &subscriptions).is_none());
    }

    #[test]
    fn dedup_accounts_preserves_first_seen_order() {
        let accounts = dedup_accounts([
            "b".to_string(),
            "a".to_string(),
            "b".to_string(),
            "c".to_string(),
        ]);
        assert_eq!(accounts, vec!["b", "a", "c"]);
    }

    #[test]
    fn dedup_accounts_of_empty_input_is_empty() {
        assert!(dedup_accounts(Vec::<String>::new()).is_empty());
    }
}