1use std::collections::HashMap;
7use std::sync::Arc;
8
9use serde::{Deserialize, Serialize};
10use serde_json::json;
11use tokio::sync::{broadcast, RwLock};
12
13use crate::error::{ElectrumError, Result};
14use crate::scripthash::address_to_scripthash;
15use crate::transport::Transport;
16use crate::types::ClientConfig;
17
/// An event published on the [`SubscriptionManager`] broadcast channel.
///
/// Receivers are obtained through `SubscriptionManager::subscribe` (or the
/// `subscribe` methods of the wrappers that delegate to it).
#[derive(Debug, Clone)]
pub enum SubscriptionEvent {
    /// Status update for a tracked address, produced from a
    /// `blockchain.scripthash.subscribe` notification.
    AddressStatus(AddressStatusEvent),
    /// New block header, produced from a `blockchain.headers.subscribe` notification.
    BlockHeader(BlockHeaderEvent),
    /// Change in connection state.
    ///
    /// NOTE(review): nothing in this module constructs this variant; presumably it
    /// is emitted elsewhere (or not yet) — confirm.
    ConnectionStatus(ConnectionStatus),
}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct AddressStatusEvent {
32 pub address: String,
34 pub scripthash: String,
36 pub status: Option<String>,
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct BlockHeaderEvent {
43 pub height: u64,
45 pub hex: String,
47}
48
/// Connection state carried by [`SubscriptionEvent::ConnectionStatus`].
///
/// `Hash` is derived alongside `Eq` so the status can be used as a key in hash
/// maps/sets (e.g. per-state counters).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConnectionStatus {
    /// The connection is established.
    Connected,
    /// The connection is down.
    Disconnected,
    /// A reconnection attempt is in progress.
    Reconnecting,
}
59
/// Tracks Electrum subscriptions over one shared [`Transport`] and republishes
/// incoming notifications as [`SubscriptionEvent`]s on a broadcast channel.
pub struct SubscriptionManager {
    /// Connection used for every request this manager issues.
    transport: Arc<Transport>,
    // `config` is stored by `new` but never read by any method in this module,
    // which is why the lint is suppressed. NOTE(review): either use it (e.g. for
    // reconnection) or drop the field — confirm intent.
    #[allow(dead_code)]
    config: ClientConfig,
    /// Tracked address subscriptions: scripthash -> original address string.
    address_subs: RwLock<HashMap<String, String>>,
    /// Set to `true` after a `blockchain.headers.subscribe` request succeeds; never
    /// reset in this module.
    header_sub_active: RwLock<bool>,
    /// Sending half of the event channel; receivers come from `subscribe`.
    event_tx: broadcast::Sender<SubscriptionEvent>,
    /// Request id counter; starts at 1 and is incremented once per request.
    request_id: std::sync::atomic::AtomicU64,
    /// Cleared by `stop`. In this module it is only read back by `is_running`.
    running: RwLock<bool>,
}
76
77impl SubscriptionManager {
78 pub async fn new(config: ClientConfig) -> Result<Self> {
80 let transport = Arc::new(Transport::connect(config.clone()).await?);
81 let (event_tx, _) = broadcast::channel(1000);
82
83 Ok(Self {
84 transport,
85 config,
86 address_subs: RwLock::new(HashMap::new()),
87 header_sub_active: RwLock::new(false),
88 event_tx,
89 request_id: std::sync::atomic::AtomicU64::new(1),
90 running: RwLock::new(true),
91 })
92 }
93
94 fn next_id(&self) -> u64 {
96 self.request_id.fetch_add(1, std::sync::atomic::Ordering::SeqCst)
97 }
98
99 pub fn subscribe(&self) -> broadcast::Receiver<SubscriptionEvent> {
101 self.event_tx.subscribe()
102 }
103
104 pub async fn subscribe_address(&self, address: &str) -> Result<Option<String>> {
108 let scripthash = address_to_scripthash(address)?;
109
110 let id = self.next_id();
111 let result = self.transport
112 .request(id, "blockchain.scripthash.subscribe", vec![json!(scripthash)])
113 .await?;
114
115 let mut subs = self.address_subs.write().await;
117 subs.insert(scripthash.clone(), address.to_string());
118
119 let status = result.as_str().map(|s| s.to_string());
121
122 Ok(status)
123 }
124
125 pub async fn unsubscribe_address(&self, address: &str) -> Result<bool> {
127 let scripthash = address_to_scripthash(address)?;
128
129 let id = self.next_id();
130 let result = self.transport
131 .request(id, "blockchain.scripthash.unsubscribe", vec![json!(scripthash)])
132 .await?;
133
134 let mut subs = self.address_subs.write().await;
136 subs.remove(&scripthash);
137
138 Ok(result.as_bool().unwrap_or(false))
139 }
140
141 pub async fn subscribe_headers(&self) -> Result<BlockHeaderEvent> {
145 let id = self.next_id();
146 let result = self.transport
147 .request(id, "blockchain.headers.subscribe", vec![])
148 .await?;
149
150 *self.header_sub_active.write().await = true;
151
152 let height = result.get("height")
153 .and_then(|h| h.as_u64())
154 .ok_or_else(|| ElectrumError::InvalidResponse("Missing height".into()))?;
155
156 let hex = result.get("hex")
157 .and_then(|h| h.as_str())
158 .unwrap_or("")
159 .to_string();
160
161 Ok(BlockHeaderEvent { height, hex })
162 }
163
164 pub async fn subscribed_addresses(&self) -> Vec<String> {
166 let subs = self.address_subs.read().await;
167 subs.values().cloned().collect()
168 }
169
170 pub async fn is_headers_subscribed(&self) -> bool {
172 *self.header_sub_active.read().await
173 }
174
175 pub async fn subscription_count(&self) -> usize {
177 let subs = self.address_subs.read().await;
178 let header_active = *self.header_sub_active.read().await;
179 subs.len() + if header_active { 1 } else { 0 }
180 }
181
182 fn broadcast(&self, event: SubscriptionEvent) {
184 let _ = self.event_tx.send(event);
185 }
186
187 pub async fn process_notification(&self, method: &str, params: &[serde_json::Value]) -> Result<()> {
189 match method {
190 "blockchain.scripthash.subscribe" => {
191 if params.len() >= 2 {
192 let scripthash = params[0].as_str().unwrap_or("").to_string();
193 let status = params[1].as_str().map(|s| s.to_string());
194
195 let subs = self.address_subs.read().await;
196 if let Some(address) = subs.get(&scripthash) {
197 self.broadcast(SubscriptionEvent::AddressStatus(AddressStatusEvent {
198 address: address.clone(),
199 scripthash,
200 status,
201 }));
202 }
203 }
204 }
205 "blockchain.headers.subscribe" => {
206 if let Some(header) = params.first() {
207 let height = header.get("height")
208 .and_then(|h| h.as_u64())
209 .unwrap_or(0);
210 let hex = header.get("hex")
211 .and_then(|h| h.as_str())
212 .unwrap_or("")
213 .to_string();
214
215 self.broadcast(SubscriptionEvent::BlockHeader(BlockHeaderEvent {
216 height,
217 hex,
218 }));
219 }
220 }
221 _ => {}
222 }
223
224 Ok(())
225 }
226
227 pub async fn stop(&self) {
229 *self.running.write().await = false;
230 }
231
232 pub async fn is_running(&self) -> bool {
234 *self.running.read().await
235 }
236}
237
/// Builder that connects a [`SubscriptionClient`] and performs an initial set of
/// subscriptions in one step.
pub struct SubscriptionClientBuilder {
    /// Connection settings handed to `SubscriptionManager::new` by `build`.
    config: ClientConfig,
    /// Addresses `build` subscribes to, in insertion order.
    addresses: Vec<String>,
    /// Whether `build` should also subscribe to block headers.
    subscribe_headers: bool,
}
244
245impl SubscriptionClientBuilder {
246 pub fn new(config: ClientConfig) -> Self {
248 Self {
249 config,
250 addresses: Vec::new(),
251 subscribe_headers: false,
252 }
253 }
254
255 pub fn subscribe_address(mut self, address: impl Into<String>) -> Self {
257 self.addresses.push(address.into());
258 self
259 }
260
261 pub fn subscribe_addresses(mut self, addresses: impl IntoIterator<Item = impl Into<String>>) -> Self {
263 self.addresses.extend(addresses.into_iter().map(|a| a.into()));
264 self
265 }
266
267 pub fn subscribe_headers(mut self) -> Self {
269 self.subscribe_headers = true;
270 self
271 }
272
273 pub async fn build(self) -> Result<SubscriptionClient> {
275 let manager = SubscriptionManager::new(self.config).await?;
276
277 for address in &self.addresses {
279 manager.subscribe_address(address).await?;
280 }
281
282 if self.subscribe_headers {
284 manager.subscribe_headers().await?;
285 }
286
287 Ok(SubscriptionClient { manager })
288 }
289}
290
/// High-level subscription client; every method delegates to the wrapped
/// [`SubscriptionManager`].
pub struct SubscriptionClient {
    /// Manager holding the connection and the subscription state.
    manager: SubscriptionManager,
}
295
296impl SubscriptionClient {
297 pub async fn new(config: ClientConfig) -> Result<Self> {
299 let manager = SubscriptionManager::new(config).await?;
300 Ok(Self { manager })
301 }
302
303 pub fn builder(config: ClientConfig) -> SubscriptionClientBuilder {
305 SubscriptionClientBuilder::new(config)
306 }
307
308 pub fn subscribe(&self) -> broadcast::Receiver<SubscriptionEvent> {
310 self.manager.subscribe()
311 }
312
313 pub async fn subscribe_address(&self, address: &str) -> Result<Option<String>> {
315 self.manager.subscribe_address(address).await
316 }
317
318 pub async fn unsubscribe_address(&self, address: &str) -> Result<bool> {
320 self.manager.unsubscribe_address(address).await
321 }
322
323 pub async fn subscribe_headers(&self) -> Result<BlockHeaderEvent> {
325 self.manager.subscribe_headers().await
326 }
327
328 pub async fn subscribed_addresses(&self) -> Vec<String> {
330 self.manager.subscribed_addresses().await
331 }
332
333 pub async fn subscription_count(&self) -> usize {
335 self.manager.subscription_count().await
336 }
337
338 pub async fn stop(&self) {
340 self.manager.stop().await;
341 }
342}
343
/// Convenience wrapper that watches a list of addresses through a
/// [`SubscriptionClient`].
pub struct AddressWatcher {
    /// Client used for every subscribe/unsubscribe request.
    client: SubscriptionClient,
    /// Local record of watched addresses, maintained by `new`, `watch` and `unwatch`.
    addresses: Vec<String>,
}
349
350impl AddressWatcher {
351 pub async fn new(config: ClientConfig, addresses: Vec<String>) -> Result<Self> {
353 let client = SubscriptionClient::new(config).await?;
354
355 for address in &addresses {
356 client.subscribe_address(address).await?;
357 }
358
359 Ok(Self { client, addresses })
360 }
361
362 pub fn subscribe(&self) -> broadcast::Receiver<SubscriptionEvent> {
364 self.client.subscribe()
365 }
366
367 pub fn addresses(&self) -> &[String] {
369 &self.addresses
370 }
371
372 pub async fn watch(&mut self, address: impl Into<String>) -> Result<()> {
374 let addr = address.into();
375 self.client.subscribe_address(&addr).await?;
376 self.addresses.push(addr);
377 Ok(())
378 }
379
380 pub async fn unwatch(&mut self, address: &str) -> Result<()> {
382 self.client.unsubscribe_address(address).await?;
383 self.addresses.retain(|a| a != address);
384 Ok(())
385 }
386
387 pub async fn stop(&self) {
389 self.client.stop().await;
390 }
391}
392
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_address_status_event() {
        let event = AddressStatusEvent {
            address: "1A1zP1eP5QGefi2DMPTfTL5SLmv7DivfNa".to_string(),
            scripthash: "abc123".to_string(),
            status: Some("def456".to_string()),
        };

        assert_eq!(event.address, "1A1zP1eP5QGefi2DMPTfTL5SLmv7DivfNa");
        assert_eq!(event.scripthash, "abc123");
        assert_eq!(event.status.as_deref(), Some("def456"));
    }

    /// The serde derives must round-trip every field, including a `None` status.
    #[test]
    fn test_address_status_event_serde_roundtrip() {
        let event = AddressStatusEvent {
            address: "1A1zP1eP5QGefi2DMPTfTL5SLmv7DivfNa".to_string(),
            scripthash: "abc123".to_string(),
            status: None,
        };

        let encoded = serde_json::to_string(&event).expect("serialize AddressStatusEvent");
        let decoded: AddressStatusEvent =
            serde_json::from_str(&encoded).expect("deserialize AddressStatusEvent");

        assert_eq!(decoded.address, event.address);
        assert_eq!(decoded.scripthash, event.scripthash);
        assert_eq!(decoded.status, None);
    }

    #[test]
    fn test_block_header_event() {
        let event = BlockHeaderEvent {
            height: 800000,
            hex: "0100000000000000".to_string(),
        };

        assert_eq!(event.height, 800000);
        assert_eq!(event.hex, "0100000000000000");
    }

    /// A header object shaped like the server's reply deserializes into the event.
    #[test]
    fn test_block_header_event_from_json() {
        let decoded: BlockHeaderEvent =
            serde_json::from_value(serde_json::json!({ "height": 800000, "hex": "00" }))
                .expect("deserialize BlockHeaderEvent");

        assert_eq!(decoded.height, 800000);
        assert_eq!(decoded.hex, "00");
    }

    #[test]
    fn test_connection_status() {
        assert_eq!(ConnectionStatus::Connected, ConnectionStatus::Connected);
        assert_ne!(ConnectionStatus::Connected, ConnectionStatus::Disconnected);
        assert_ne!(ConnectionStatus::Disconnected, ConnectionStatus::Reconnecting);

        // `Copy`: the original stays usable after being copied.
        let status = ConnectionStatus::Reconnecting;
        let copied = status;
        assert_eq!(status, copied);
    }

    /// Cloning an event keeps both the variant and its payload.
    #[test]
    fn test_subscription_event_clone_preserves_payload() {
        let event = SubscriptionEvent::BlockHeader(BlockHeaderEvent {
            height: 1,
            hex: "ab".to_string(),
        });

        match event.clone() {
            SubscriptionEvent::BlockHeader(header) => {
                assert_eq!(header.height, 1);
                assert_eq!(header.hex, "ab");
            }
            other => panic!("unexpected event: {:?}", other),
        }
    }
}