1use crate::chains::dex::DexTokenData;
7use crate::chains::{ChainClientFactory, DexDataSource};
8use crate::web::AppState;
9use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
10use axum::extract::{Query, State};
11use axum::response::IntoResponse;
12use serde::Deserialize;
13use std::sync::Arc;
14use std::time::Duration;
15
16#[derive(Debug, Deserialize)]
18pub struct MonitorQuery {
19 pub token: String,
21 #[serde(default = "default_chain")]
23 pub chain: String,
24 #[serde(default = "default_refresh")]
26 pub refresh: u64,
27}
28
/// Serde default for `MonitorQuery::chain`: the `"ethereum"` chain.
fn default_chain() -> String {
    String::from("ethereum")
}
32
/// Serde default for `MonitorQuery::refresh`: five seconds between updates.
fn default_refresh() -> u64 {
    const FIVE_SECONDS: u64 = 5;
    FIVE_SECONDS
}
36
37pub async fn ws_handler(
39 ws: WebSocketUpgrade,
40 State(state): State<Arc<AppState>>,
41 Query(params): Query<MonitorQuery>,
42) -> impl IntoResponse {
43 ws.on_upgrade(move |socket| handle_socket(socket, state, params))
44}
45
46async fn handle_socket(mut socket: WebSocket, state: Arc<AppState>, params: MonitorQuery) {
51 let dex_client: Box<dyn DexDataSource> = state.factory.create_dex_client();
52 let refresh = Duration::from_secs(params.refresh.max(1));
53
54 let token_input = params.token.clone();
56 let chain = params.chain.clone();
57
58 let init_msg = serde_json::json!({
60 "type": "connected",
61 "token": token_input,
62 "chain": chain,
63 "refresh_secs": params.refresh,
64 });
65 if socket
66 .send(Message::Text(init_msg.to_string()))
67 .await
68 .is_err()
69 {
70 return;
71 }
72
73 loop {
74 let data: crate::error::Result<DexTokenData> =
76 dex_client.get_token_data(&chain, &token_input).await;
77
78 let msg = match data {
79 Ok(token_data) => {
80 serde_json::json!({
81 "type": "update",
82 "timestamp": chrono::Utc::now().to_rfc3339(),
83 "token": {
84 "symbol": token_data.symbol,
85 "name": token_data.name,
86 "address": token_data.address,
87 },
88 "price_usd": token_data.price_usd,
89 "price_change_24h": token_data.price_change_24h,
90 "price_change_6h": token_data.price_change_6h,
91 "price_change_1h": token_data.price_change_1h,
92 "volume_24h": token_data.volume_24h,
93 "volume_6h": token_data.volume_6h,
94 "volume_1h": token_data.volume_1h,
95 "liquidity_usd": token_data.liquidity_usd,
96 "market_cap": token_data.market_cap,
97 "buys_24h": token_data.total_buys_24h,
98 "sells_24h": token_data.total_sells_24h,
99 "buys_1h": token_data.total_buys_1h,
100 "sells_1h": token_data.total_sells_1h,
101 "pairs": token_data.pairs.iter().take(5).map(|p| {
102 serde_json::json!({
103 "dex": p.dex_name,
104 "base": p.base_token,
105 "quote": p.quote_token,
106 "price_usd": p.price_usd,
107 "volume_24h": p.volume_24h,
108 "liquidity_usd": p.liquidity_usd,
109 })
110 }).collect::<Vec<_>>(),
111 })
112 }
113 Err(e) => {
114 serde_json::json!({
115 "type": "error",
116 "message": e.to_string(),
117 })
118 }
119 };
120
121 if socket.send(Message::Text(msg.to_string())).await.is_err() {
122 break;
124 }
125
126 tokio::select! {
128 _ = tokio::time::sleep(refresh) => {},
129 msg = socket.recv() => {
130 match msg {
131 Some(Ok(Message::Close(_))) | None => break,
132 Some(Ok(Message::Text(text))) => {
133 if let Ok(cmd) = serde_json::from_str::<serde_json::Value>(&text)
135 && cmd.get("type").and_then(|t| t.as_str()) == Some("ping")
136 {
137 let pong = serde_json::json!({"type": "pong"});
138 let _ = socket.send(Message::Text(pong.to_string())).await;
139 }
140 }
141 _ => {}
142 }
143 }
144 }
145 }
146}
147
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes a `MonitorQuery`, panicking with context if the JSON is rejected.
    fn parse(json: serde_json::Value) -> MonitorQuery {
        serde_json::from_value(json).expect("JSON should deserialize into MonitorQuery")
    }

    #[test]
    fn test_default_chain() {
        assert_eq!(default_chain(), "ethereum");
    }

    #[test]
    fn test_default_refresh() {
        assert_eq!(default_refresh(), 5);
    }

    #[test]
    fn test_deserialize_monitor_query_full() {
        let query = parse(serde_json::json!({
            "token": "USDC",
            "chain": "solana",
            "refresh": 10
        }));
        assert_eq!(query.token, "USDC");
        assert_eq!(query.chain, "solana");
        assert_eq!(query.refresh, 10);
    }

    #[test]
    fn test_deserialize_monitor_query_minimal() {
        let query = parse(serde_json::json!({ "token": "ETH" }));
        assert_eq!(query.token, "ETH");
        assert_eq!(query.chain, "ethereum");
        assert_eq!(query.refresh, 5);
    }

    #[test]
    fn test_deserialize_monitor_query_custom_refresh() {
        let query = parse(serde_json::json!({
            "token": "BTC",
            "refresh": 30
        }));
        assert_eq!(query.token, "BTC");
        assert_eq!(query.chain, "ethereum");
        assert_eq!(query.refresh, 30);
    }

    // `token` has no serde default, so omitting it must be rejected.
    #[test]
    fn test_deserialize_monitor_query_requires_token() {
        let json = serde_json::json!({ "chain": "solana" });
        assert!(serde_json::from_value::<MonitorQuery>(json).is_err());
    }

    // `refresh` is unsigned, so a negative interval must be rejected, not wrapped.
    #[test]
    fn test_deserialize_monitor_query_rejects_negative_refresh() {
        let json = serde_json::json!({ "token": "ETH", "refresh": -1 });
        assert!(serde_json::from_value::<MonitorQuery>(json).is_err());
    }
}