1use std::sync::{
22 atomic::{AtomicU64, Ordering},
23 Arc,
24};
25use std::time::Duration;
26
27use axum::{
28 extract::{
29 ws::{Message, WebSocket},
30 State, WebSocketUpgrade,
31 },
32 response::Response,
33 routing::get,
34 Router,
35};
36use futures_util::{SinkExt, StreamExt};
37use serde_json::{json, Value};
38use tokio::net::TcpListener;
39use tokio::sync::{mpsc, Mutex};
40use tokio::task::JoinHandle;
41use tracing::{info, warn};
42
/// Pause inserted before every upstream poll made on behalf of a subscription.
const POLL_INTERVAL: Duration = Duration::from_millis(500);

/// Maximum number of polls a `signatureSubscribe` task performs before giving up
/// (240 polls spaced `POLL_INTERVAL` apart: about two minutes plus request time).
const MAX_POLLS: u32 = 240;

/// Process-wide counter handing out subscription ids, starting at 1.
static NEXT_SUB_ID: AtomicU64 = AtomicU64::new(1);
57
/// Configuration shared by every WebSocket connection handler.
///
/// The value is cloned into each spawned poll task, so it only carries cheap,
/// owned settings.
#[derive(Clone, Debug)]
pub struct WsState {
    /// URL that the polling JSON-RPC requests are POSTed to.
    pub upstream_url: String,
    /// Timeout configured on the HTTP client each poll task builds.
    pub rpc_timeout: Duration,
}
63
64pub async fn run_ws(
69 port: u16,
70 upstream_url: String,
71 rpc_timeout: Duration,
72) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
73 let state = WsState {
74 upstream_url,
75 rpc_timeout,
76 };
77 let app = Router::new().route("/", get(ws_upgrade)).with_state(state);
78 let addr = std::net::SocketAddr::from(([0, 0, 0, 0], port));
79 let listener = TcpListener::bind(&addr).await?;
80 info!("tidepool WS listening on ws://{addr}");
81 axum::serve(listener, app).await?;
82 Ok(())
83}
84
85async fn ws_upgrade(ws: WebSocketUpgrade, State(state): State<WsState>) -> Response {
86 ws.on_upgrade(move |socket| handle_connection(socket, state))
87}
88
89#[allow(clippy::too_many_lines)]
92async fn handle_connection(socket: WebSocket, state: WsState) {
93 let (mut sink, mut stream) = socket.split();
94 let (tx, mut rx) = mpsc::unbounded_channel::<Message>();
95
96 let subs: Arc<Mutex<std::collections::HashMap<u64, JoinHandle<()>>>> =
99 Arc::new(Mutex::new(std::collections::HashMap::new()));
100
101 let write_task = tokio::spawn(async move {
106 while let Some(msg) = rx.recv().await {
107 if sink.send(msg).await.is_err() {
108 break;
109 }
110 }
111 });
112
113 while let Some(Ok(msg)) = stream.next().await {
115 let Message::Text(text) = msg else {
116 if matches!(msg, Message::Close(_)) {
119 break;
120 }
121 continue;
122 };
123
124 let Ok(req) = serde_json::from_str::<Value>(&text) else {
125 continue;
126 };
127 let method = req.get("method").and_then(Value::as_str).unwrap_or("");
128 let id = req.get("id").cloned().unwrap_or(Value::Null);
129
130 match method {
131 "signatureSubscribe" => {
132 let sub_id = NEXT_SUB_ID.fetch_add(1, Ordering::Relaxed);
133 let Some(signature) = req
134 .get("params")
135 .and_then(Value::as_array)
136 .and_then(|a| a.first())
137 .and_then(Value::as_str)
138 .map(String::from)
139 else {
140 send(&tx, &error_msg(&id, -32602, "missing signature param"));
141 continue;
142 };
143 let commitment = req
144 .get("params")
145 .and_then(Value::as_array)
146 .and_then(|a| a.get(1))
147 .and_then(|v| v.get("commitment"))
148 .and_then(Value::as_str)
149 .unwrap_or("finalized")
150 .to_string();
151
152 send(
154 &tx,
155 &json!({ "jsonrpc": "2.0", "id": id, "result": sub_id }),
156 );
157
158 let poll_tx = tx.clone();
160 let state_clone = state.clone();
161 let subs_clone = Arc::clone(&subs);
162 let handle = tokio::spawn(async move {
163 poll_signature(sub_id, signature, commitment, state_clone, poll_tx).await;
164 subs_clone.lock().await.remove(&sub_id);
168 });
169 subs.lock().await.insert(sub_id, handle);
170 }
171
172 "accountSubscribe" => {
173 let sub_id = NEXT_SUB_ID.fetch_add(1, Ordering::Relaxed);
174 let Some(pubkey) = req
175 .get("params")
176 .and_then(Value::as_array)
177 .and_then(|a| a.first())
178 .and_then(Value::as_str)
179 .map(String::from)
180 else {
181 send(&tx, &error_msg(&id, -32602, "missing account pubkey param"));
182 continue;
183 };
184 let opts = req
185 .get("params")
186 .and_then(Value::as_array)
187 .and_then(|a| a.get(1))
188 .cloned()
189 .unwrap_or(Value::Null);
190 let commitment = opts
191 .get("commitment")
192 .and_then(Value::as_str)
193 .unwrap_or("finalized")
194 .to_string();
195 let encoding = opts
198 .get("encoding")
199 .and_then(Value::as_str)
200 .unwrap_or("base64")
201 .to_string();
202
203 send(
204 &tx,
205 &json!({ "jsonrpc": "2.0", "id": id, "result": sub_id }),
206 );
207
208 let poll_tx = tx.clone();
209 let state_clone = state.clone();
210 let subs_clone = Arc::clone(&subs);
211 let handle = tokio::spawn(async move {
212 poll_account(sub_id, pubkey, commitment, encoding, state_clone, poll_tx).await;
213 subs_clone.lock().await.remove(&sub_id);
214 });
215 subs.lock().await.insert(sub_id, handle);
216 }
217
218 "logsSubscribe" => {
219 let sub_id = NEXT_SUB_ID.fetch_add(1, Ordering::Relaxed);
220 let params = req.get("params").and_then(Value::as_array);
221 let filter = params
222 .and_then(|a| a.first())
223 .cloned()
224 .unwrap_or(Value::Null);
225 let mention = match &filter {
229 Value::Object(map) => map
230 .get("mentions")
231 .and_then(Value::as_array)
232 .and_then(|a| a.first())
233 .and_then(Value::as_str)
234 .map(String::from),
235 Value::String(s) if s == "all" || s == "allWithVotes" => {
236 send(
237 &tx,
238 &error_msg(
239 &id,
240 -32601,
241 "logsSubscribe with filter 'all' / 'allWithVotes' is not \
242 polyfilled by the tidepool WS shim; use { mentions: [pubkey] }",
243 ),
244 );
245 continue;
246 }
247 _ => None,
248 };
249 let Some(mention) = mention else {
250 send(
251 &tx,
252 &error_msg(
253 &id,
254 -32602,
255 "logsSubscribe requires `{ mentions: [pubkey] }` filter",
256 ),
257 );
258 continue;
259 };
260 let commitment = params
261 .and_then(|a| a.get(1))
262 .and_then(|v| v.get("commitment"))
263 .and_then(Value::as_str)
264 .unwrap_or("finalized")
265 .to_string();
266
267 send(
268 &tx,
269 &json!({ "jsonrpc": "2.0", "id": id, "result": sub_id }),
270 );
271
272 let poll_tx = tx.clone();
273 let state_clone = state.clone();
274 let subs_clone = Arc::clone(&subs);
275 let handle = tokio::spawn(async move {
276 poll_logs(sub_id, mention, commitment, state_clone, poll_tx).await;
277 subs_clone.lock().await.remove(&sub_id);
278 });
279 subs.lock().await.insert(sub_id, handle);
280 }
281
282 "signatureUnsubscribe" | "accountUnsubscribe" | "logsUnsubscribe" => {
283 let Some(sub_id) = req
284 .get("params")
285 .and_then(Value::as_array)
286 .and_then(|a| a.first())
287 .and_then(Value::as_u64)
288 else {
289 send(&tx, &error_msg(&id, -32602, "missing subscription id"));
290 continue;
291 };
292 let removed = subs.lock().await.remove(&sub_id);
293 let was_present = removed.is_some();
294 if let Some(handle) = removed {
295 handle.abort();
296 }
297 send(
298 &tx,
299 &json!({
300 "jsonrpc": "2.0",
301 "id": id,
302 "result": was_present
303 }),
304 );
305 }
306
307 _ => {
310 send(
311 &tx,
312 &error_msg(
313 &id,
314 -32601,
315 &format!("method '{method}' is not supported by the tidepool WS polyfill"),
316 ),
317 );
318 }
319 }
320 }
321
322 let mut subs = subs.lock().await;
325 for (_, handle) in subs.drain() {
326 handle.abort();
327 }
328 drop(tx);
329 let _ = write_task.await;
330}
331
332async fn poll_signature(
335 sub_id: u64,
336 signature: String,
337 commitment: String,
338 state: WsState,
339 tx: mpsc::UnboundedSender<Message>,
340) {
341 let client = match reqwest::Client::builder()
342 .timeout(state.rpc_timeout)
343 .build()
344 {
345 Ok(c) => c,
346 Err(e) => {
347 warn!(err = %e, "failed to build reqwest client for ws polling");
348 return;
349 }
350 };
351 for _ in 0..MAX_POLLS {
352 tokio::time::sleep(POLL_INTERVAL).await;
353 let body = json!({
354 "jsonrpc": "2.0",
355 "id": 1,
356 "method": "getSignatureStatuses",
357 "params": [[signature], { "searchTransactionHistory": true }]
358 });
359 let Ok(resp) = client.post(&state.upstream_url).json(&body).send().await else {
360 continue;
361 };
362 let Ok(json): Result<Value, _> = resp.json().await else {
363 continue;
364 };
365 let Some(statuses) = json
366 .get("result")
367 .and_then(|r| r.get("value"))
368 .and_then(Value::as_array)
369 else {
370 continue;
371 };
372 let Some(status) = statuses.first() else {
373 continue;
374 };
375 if status.is_null() {
376 continue; }
378 let status_conf = status
379 .get("confirmationStatus")
380 .and_then(Value::as_str)
381 .unwrap_or("");
382 if commitment_matches(&commitment, status_conf) {
383 let notif = json!({
385 "jsonrpc": "2.0",
386 "method": "signatureNotification",
387 "params": {
388 "result": {
389 "context": json.get("result").and_then(|r| r.get("context")).cloned().unwrap_or(Value::Null),
390 "value": { "err": status.get("err").cloned().unwrap_or(Value::Null) }
391 },
392 "subscription": sub_id
393 }
394 });
395 send(&tx, ¬if);
396 return;
397 }
398 }
399 warn!(sub_id, signature, "signatureSubscribe poll timed out");
400}
401
402async fn poll_account(
415 sub_id: u64,
416 pubkey: String,
417 commitment: String,
418 encoding: String,
419 state: WsState,
420 tx: mpsc::UnboundedSender<Message>,
421) {
422 let client = match reqwest::Client::builder()
423 .timeout(state.rpc_timeout)
424 .build()
425 {
426 Ok(c) => c,
427 Err(e) => {
428 warn!(err = %e, "failed to build reqwest client for account polling");
429 return;
430 }
431 };
432 let mut last: Option<Value> = None;
433 loop {
434 tokio::time::sleep(POLL_INTERVAL).await;
435 let body = json!({
436 "jsonrpc": "2.0",
437 "id": 1,
438 "method": "getAccountInfo",
439 "params": [pubkey, { "commitment": commitment, "encoding": encoding }]
440 });
441 let Ok(resp) = client.post(&state.upstream_url).json(&body).send().await else {
442 continue;
443 };
444 let Ok(json): Result<Value, _> = resp.json().await else {
445 continue;
446 };
447 let Some(result) = json.get("result") else {
448 continue;
449 };
450 let value = result.get("value").cloned().unwrap_or(Value::Null);
454 if last.as_ref() == Some(&value) {
455 continue;
456 }
457 last = Some(value.clone());
458 let notif = json!({
459 "jsonrpc": "2.0",
460 "method": "accountNotification",
461 "params": {
462 "result": {
463 "context": result.get("context").cloned().unwrap_or(Value::Null),
464 "value": value
465 },
466 "subscription": sub_id
467 }
468 });
469 send(&tx, ¬if);
470 }
471}
472
473async fn poll_logs(
485 sub_id: u64,
486 mention: String,
487 commitment: String,
488 state: WsState,
489 tx: mpsc::UnboundedSender<Message>,
490) {
491 let client = match reqwest::Client::builder()
492 .timeout(state.rpc_timeout)
493 .build()
494 {
495 Ok(c) => c,
496 Err(e) => {
497 warn!(err = %e, "failed to build reqwest client for logs polling");
498 return;
499 }
500 };
501 let mut last_seen: Option<String> = None;
502 loop {
503 tokio::time::sleep(POLL_INTERVAL).await;
504 let sigs_body = json!({
505 "jsonrpc": "2.0",
506 "id": 1,
507 "method": "getSignaturesForAddress",
508 "params": [mention, { "commitment": commitment, "limit": 25 }]
509 });
510 let Ok(resp) = client
511 .post(&state.upstream_url)
512 .json(&sigs_body)
513 .send()
514 .await
515 else {
516 continue;
517 };
518 let Ok(json): Result<Value, _> = resp.json().await else {
519 continue;
520 };
521 let Some(entries) = json.get("result").and_then(Value::as_array) else {
522 continue;
523 };
524
525 let mut new_sigs: Vec<String> = Vec::new();
529 for entry in entries.iter().rev() {
530 let Some(sig) = entry.get("signature").and_then(Value::as_str) else {
531 continue;
532 };
533 if last_seen.as_deref() == Some(sig) {
534 new_sigs.clear();
535 continue;
536 }
537 new_sigs.push(sig.to_string());
538 }
539 if last_seen.is_none() {
543 if let Some(sig) = entries
544 .first()
545 .and_then(|e| e.get("signature"))
546 .and_then(Value::as_str)
547 {
548 last_seen = Some(sig.to_string());
549 }
550 continue;
551 }
552
553 for sig in &new_sigs {
554 if let Some(notif) =
555 fetch_logs_notification(&client, &state, &commitment, sub_id, sig).await
556 {
557 send(&tx, ¬if);
558 }
559 }
560 if let Some(last) = new_sigs.last() {
561 last_seen = Some(last.clone());
562 }
563 }
564}
565
566async fn fetch_logs_notification(
569 client: &reqwest::Client,
570 state: &WsState,
571 commitment: &str,
572 sub_id: u64,
573 signature: &str,
574) -> Option<Value> {
575 let body = json!({
576 "jsonrpc": "2.0",
577 "id": 1,
578 "method": "getTransaction",
579 "params": [
580 signature,
581 { "commitment": commitment, "encoding": "json", "maxSupportedTransactionVersion": 0 }
582 ]
583 });
584 let resp = client
585 .post(&state.upstream_url)
586 .json(&body)
587 .send()
588 .await
589 .ok()?;
590 let json: Value = resp.json().await.ok()?;
591 let result = json.get("result")?;
592 let slot = result.get("slot").and_then(Value::as_u64).unwrap_or(0);
593 let meta = result.get("meta").cloned().unwrap_or(Value::Null);
594 let err = meta.get("err").cloned().unwrap_or(Value::Null);
595 let logs = meta
596 .get("logMessages")
597 .cloned()
598 .unwrap_or(Value::Array(Vec::new()));
599 Some(json!({
600 "jsonrpc": "2.0",
601 "method": "logsNotification",
602 "params": {
603 "result": {
604 "context": { "slot": slot },
605 "value": {
606 "signature": signature,
607 "err": err,
608 "logs": logs
609 }
610 },
611 "subscription": sub_id
612 }
613 }))
614}
615
/// Returns `true` when the `actual` confirmation status is at least as strong as
/// the `requested` commitment (`processed` < `confirmed` < `finalized`).
///
/// * An `actual` value that is not one of the three levels never matches, so a
///   missing or unknown status cannot trigger a notification.
/// * `requested` also accepts the legacy names `recent` (processed),
///   `single` / `singleGossip` (confirmed) and `max` / `root` (finalized).
/// * Any other `requested` value is treated as `finalized`, the strictest level
///   and the default used when a request names no commitment.
fn commitment_matches(requested: &str, actual: &str) -> bool {
    let reached: u8 = match actual {
        "processed" => 1,
        "confirmed" => 2,
        "finalized" => 3,
        _ => return false,
    };
    let wanted: u8 = match requested {
        "processed" | "recent" => 1,
        "confirmed" | "single" | "singleGossip" => 2,
        // "finalized", "max", "root" and anything unrecognised.
        _ => 3,
    };
    reached >= wanted
}
628
629fn send(tx: &mpsc::UnboundedSender<Message>, value: &Value) {
632 let _ = tx.send(Message::Text(value.to_string().into()));
633}
634
635fn error_msg(id: &Value, code: i32, message: &str) -> Value {
636 json!({
637 "jsonrpc": "2.0",
638 "id": id,
639 "error": { "code": code, "message": message }
640 })
641}