volt_client_grpc/
volt_connection.rs1use crate::config::VoltClientConfig;
7use crate::error::{Result, VoltError};
8use crate::proto::{ConnectHello, ConnectRequest};
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::sync::{mpsc, Mutex, RwLock};
12use tokio::time::{interval, Instant};
13
/// Lifecycle state of a [`VoltConnection`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// No active session: the initial state, and the state after a session ends.
    Disconnected,
    /// `connect` has been called and the background task has not yet reported success.
    Connecting,
    /// The background task reports an active session.
    Connected,
    /// A session ended and the background task is waiting before retrying.
    Reconnecting,
}
26
27#[derive(Debug, Clone)]
29pub enum ConnectionEvent {
30 Connected(String),
32 Disconnected,
34 Ping(u64),
36 Error(String),
38 Event(serde_json::Value),
40 InvokeRequest(serde_json::Value),
42}
43
44pub type RequestSender = mpsc::Sender<ConnectRequest>;
46pub type EventReceiver = mpsc::Receiver<ConnectionEvent>;
48
49pub struct VoltConnection {
51 state: Arc<RwLock<ConnectionState>>,
53 connection_id: Arc<RwLock<String>>,
55 _config: VoltClientConfig,
57 auto_retry: bool,
59 ping_interval: Duration,
61 reconnect_interval: Duration,
63 timeout_interval: Duration,
65 dying: Arc<RwLock<bool>>,
67 event_tx: Option<mpsc::Sender<ConnectionEvent>>,
69 request_tx: Arc<Mutex<Option<RequestSender>>>,
71 last_ping: Arc<RwLock<Option<Instant>>>,
73}
74
75impl VoltConnection {
76 pub fn new(config: &VoltClientConfig) -> Self {
78 Self {
79 state: Arc::new(RwLock::new(ConnectionState::Disconnected)),
80 connection_id: Arc::new(RwLock::new(String::new())),
81 _config: config.clone(),
82 auto_retry: config.auto_reconnect,
83 ping_interval: Duration::from_millis(config.ping_interval),
84 reconnect_interval: Duration::from_millis(config.reconnect_interval),
85 timeout_interval: Duration::from_millis(config.timeout_interval),
86 dying: Arc::new(RwLock::new(false)),
87 event_tx: None,
88 request_tx: Arc::new(Mutex::new(None)),
89 last_ping: Arc::new(RwLock::new(None)),
90 }
91 }
92
93 pub async fn state(&self) -> ConnectionState {
95 *self.state.read().await
96 }
97
98 pub async fn connection_id(&self) -> String {
100 self.connection_id.read().await.clone()
101 }
102
103 pub async fn is_connected(&self) -> bool {
105 *self.state.read().await == ConnectionState::Connected
106 }
107
108 pub async fn connect(
112 &mut self,
113 hello_payload: Option<serde_json::Value>,
114 ) -> Result<EventReceiver> {
115 let (event_tx, event_rx) = mpsc::channel(100);
117 self.event_tx = Some(event_tx);
118
119 *self.state.write().await = ConnectionState::Connecting;
121
122 self.start_connection(hello_payload).await?;
124
125 Ok(event_rx)
126 }
127
128 async fn start_connection(&self, _hello_payload: Option<serde_json::Value>) -> Result<()> {
130 let state = self.state.clone();
131 let connection_id = self.connection_id.clone();
132 let event_tx = self.event_tx.clone();
133 let dying = self.dying.clone();
134 let auto_retry = self.auto_retry;
135 let reconnect_interval = self.reconnect_interval;
136 let ping_interval = self.ping_interval;
137 let timeout_interval = self.timeout_interval;
138 let _request_tx = self.request_tx.clone();
139 let last_ping = self.last_ping.clone();
140
141 let hello = ConnectHello {
143 ping_interval: ping_interval.as_millis() as u64,
144 timestamp: std::time::SystemTime::now()
145 .duration_since(std::time::UNIX_EPOCH)
146 .unwrap()
147 .as_millis() as u64,
148 };
149
150 tokio::spawn(async move {
152 loop {
153 if *dying.read().await {
155 tracing::debug!("Connection dying, exiting loop");
156 break;
157 }
158
159 tracing::debug!("Would connect with hello: {:?}", hello);
167
168 *state.write().await = ConnectionState::Connected;
170 *connection_id.write().await = uuid::Uuid::new_v4().to_string();
171
172 if let Some(ref tx) = event_tx {
173 let conn_id = connection_id.read().await.clone();
174 let _ = tx.send(ConnectionEvent::Connected(conn_id)).await;
175 }
176
177 let mut ping_timer = interval(ping_interval);
179 loop {
180 tokio::select! {
181 _ = ping_timer.tick() => {
182 let now = std::time::SystemTime::now()
184 .duration_since(std::time::UNIX_EPOCH)
185 .unwrap()
186 .as_millis() as u64;
187
188 *last_ping.write().await = Some(Instant::now());
189 tracing::trace!("Sending ping at {}", now);
190
191 }
193
194 _ = tokio::time::sleep(timeout_interval) => {
196 let last = last_ping.read().await;
197 if let Some(last_ping_time) = *last {
198 if last_ping_time.elapsed() > timeout_interval {
199 tracing::warn!("Connection timed out");
200 break;
201 }
202 }
203 }
204 }
205
206 if *dying.read().await {
208 break;
209 }
210 }
211
212 *state.write().await = ConnectionState::Disconnected;
214 *connection_id.write().await = String::new();
215
216 if let Some(ref tx) = event_tx {
217 let _ = tx.send(ConnectionEvent::Disconnected).await;
218 }
219
220 if !auto_retry || *dying.read().await {
222 break;
223 }
224
225 *state.write().await = ConnectionState::Reconnecting;
227 tokio::time::sleep(reconnect_interval).await;
228 }
229 });
230
231 Ok(())
232 }
233
234 pub async fn disconnect(&mut self) {
236 *self.dying.write().await = true;
237
238 *self.request_tx.lock().await = None;
240
241 *self.state.write().await = ConnectionState::Disconnected;
242 *self.connection_id.write().await = String::new();
243 }
244
245 pub async fn send(&self, request: ConnectRequest) -> Result<()> {
247 let tx = self.request_tx.lock().await;
248
249 if let Some(ref sender) = *tx {
250 sender
251 .send(request)
252 .await
253 .map_err(|_| VoltError::NotConnected)?;
254 Ok(())
255 } else {
256 Err(VoltError::NotConnected)
257 }
258 }
259}
260
261impl Drop for VoltConnection {
262 fn drop(&mut self) {
263 }
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::VoltConfig;

    /// A freshly constructed connection starts out `Disconnected`.
    #[tokio::test]
    async fn test_connection_state() {
        let connection = VoltConnection::new(&VoltClientConfig {
            client_name: "test".to_string(),
            volt: VoltConfig::default(),
            ..Default::default()
        });

        assert_eq!(connection.state().await, ConnectionState::Disconnected);
    }
}