Skip to main content

volt_client_grpc/
volt_connection.rs

//! Volt connection management.
//!
//! Handles persistent connections to a Volt server, with automatic reconnection
//! and ping/pong keep-alive.
5
6use crate::config::VoltClientConfig;
7use crate::error::{Result, VoltError};
8use crate::proto::{ConnectHello, ConnectRequest};
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::sync::{mpsc, Mutex, RwLock};
12use tokio::time::{interval, Instant};
13
/// Lifecycle state of a Volt connection.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// No connection is open.
    Disconnected,
    /// A connection attempt is in progress.
    Connecting,
    /// The connection is up and active.
    Connected,
    /// The connection dropped and a reconnect attempt is pending.
    Reconnecting,
}
26
27/// Events emitted by the connection
28#[derive(Debug, Clone)]
29pub enum ConnectionEvent {
30    /// Connection established with connection ID
31    Connected(String),
32    /// Connection lost
33    Disconnected,
34    /// Ping received with timestamp
35    Ping(u64),
36    /// Error occurred
37    Error(String),
38    /// Event received from server
39    Event(serde_json::Value),
40    /// Invoke request received
41    InvokeRequest(serde_json::Value),
42}
43
44/// Channel for sending requests on the connection
45pub type RequestSender = mpsc::Sender<ConnectRequest>;
46/// Channel for receiving connection events
47pub type EventReceiver = mpsc::Receiver<ConnectionEvent>;
48
49/// Manages a persistent connection to a Volt server
50pub struct VoltConnection {
51    /// Current connection state
52    state: Arc<RwLock<ConnectionState>>,
53    /// Connection ID assigned by the server
54    connection_id: Arc<RwLock<String>>,
55    /// Configuration
56    _config: VoltClientConfig,
57    /// Whether auto-reconnect is enabled
58    auto_retry: bool,
59    /// Ping interval in milliseconds
60    ping_interval: Duration,
61    /// Reconnect interval in milliseconds
62    reconnect_interval: Duration,
63    /// Timeout interval in milliseconds  
64    timeout_interval: Duration,
65    /// Flag to signal shutdown
66    dying: Arc<RwLock<bool>>,
67    /// Event sender
68    event_tx: Option<mpsc::Sender<ConnectionEvent>>,
69    /// Request sender for the active connection
70    request_tx: Arc<Mutex<Option<RequestSender>>>,
71    /// Last ping timestamp
72    last_ping: Arc<RwLock<Option<Instant>>>,
73}
74
75impl VoltConnection {
76    /// Create a new connection manager
77    pub fn new(config: &VoltClientConfig) -> Self {
78        Self {
79            state: Arc::new(RwLock::new(ConnectionState::Disconnected)),
80            connection_id: Arc::new(RwLock::new(String::new())),
81            _config: config.clone(),
82            auto_retry: config.auto_reconnect,
83            ping_interval: Duration::from_millis(config.ping_interval),
84            reconnect_interval: Duration::from_millis(config.reconnect_interval),
85            timeout_interval: Duration::from_millis(config.timeout_interval),
86            dying: Arc::new(RwLock::new(false)),
87            event_tx: None,
88            request_tx: Arc::new(Mutex::new(None)),
89            last_ping: Arc::new(RwLock::new(None)),
90        }
91    }
92
93    /// Get the current connection state
94    pub async fn state(&self) -> ConnectionState {
95        *self.state.read().await
96    }
97
98    /// Get the connection ID
99    pub async fn connection_id(&self) -> String {
100        self.connection_id.read().await.clone()
101    }
102
103    /// Check if connected
104    pub async fn is_connected(&self) -> bool {
105        *self.state.read().await == ConnectionState::Connected
106    }
107
108    /// Connect to the Volt server
109    ///
110    /// Returns a receiver for connection events
111    pub async fn connect(
112        &mut self,
113        hello_payload: Option<serde_json::Value>,
114    ) -> Result<EventReceiver> {
115        // Create event channel
116        let (event_tx, event_rx) = mpsc::channel(100);
117        self.event_tx = Some(event_tx);
118
119        // Set state to connecting
120        *self.state.write().await = ConnectionState::Connecting;
121
122        // Start connection task
123        self.start_connection(hello_payload).await?;
124
125        Ok(event_rx)
126    }
127
128    /// Start the connection process
129    async fn start_connection(&self, _hello_payload: Option<serde_json::Value>) -> Result<()> {
130        let state = self.state.clone();
131        let connection_id = self.connection_id.clone();
132        let event_tx = self.event_tx.clone();
133        let dying = self.dying.clone();
134        let auto_retry = self.auto_retry;
135        let reconnect_interval = self.reconnect_interval;
136        let ping_interval = self.ping_interval;
137        let timeout_interval = self.timeout_interval;
138        let _request_tx = self.request_tx.clone();
139        let last_ping = self.last_ping.clone();
140
141        // Build hello message
142        let hello = ConnectHello {
143            ping_interval: ping_interval.as_millis() as u64,
144            timestamp: std::time::SystemTime::now()
145                .duration_since(std::time::UNIX_EPOCH)
146                .unwrap()
147                .as_millis() as u64,
148        };
149
150        // Spawn connection task
151        tokio::spawn(async move {
152            loop {
153                // Check if we're shutting down
154                if *dying.read().await {
155                    tracing::debug!("Connection dying, exiting loop");
156                    break;
157                }
158
159                // TODO: Establish actual gRPC connection here
160                // This is a placeholder for the connection logic
161                // In a real implementation, this would:
162                // 1. Create a gRPC channel
163                // 2. Call the Connect RPC
164                // 3. Handle the bidirectional stream
165
166                tracing::debug!("Would connect with hello: {:?}", hello);
167
168                // Simulate connection established
169                *state.write().await = ConnectionState::Connected;
170                *connection_id.write().await = uuid::Uuid::new_v4().to_string();
171
172                if let Some(ref tx) = event_tx {
173                    let conn_id = connection_id.read().await.clone();
174                    let _ = tx.send(ConnectionEvent::Connected(conn_id)).await;
175                }
176
177                // Start ping loop
178                let mut ping_timer = interval(ping_interval);
179                loop {
180                    tokio::select! {
181                        _ = ping_timer.tick() => {
182                            // Send ping
183                            let now = std::time::SystemTime::now()
184                                .duration_since(std::time::UNIX_EPOCH)
185                                .unwrap()
186                                .as_millis() as u64;
187
188                            *last_ping.write().await = Some(Instant::now());
189                            tracing::trace!("Sending ping at {}", now);
190
191                            // TODO: Actually send ping on the stream
192                        }
193
194                        // Check for timeout
195                        _ = tokio::time::sleep(timeout_interval) => {
196                            let last = last_ping.read().await;
197                            if let Some(last_ping_time) = *last {
198                                if last_ping_time.elapsed() > timeout_interval {
199                                    tracing::warn!("Connection timed out");
200                                    break;
201                                }
202                            }
203                        }
204                    }
205
206                    // Check if dying
207                    if *dying.read().await {
208                        break;
209                    }
210                }
211
212                // Connection lost
213                *state.write().await = ConnectionState::Disconnected;
214                *connection_id.write().await = String::new();
215
216                if let Some(ref tx) = event_tx {
217                    let _ = tx.send(ConnectionEvent::Disconnected).await;
218                }
219
220                // Should we reconnect?
221                if !auto_retry || *dying.read().await {
222                    break;
223                }
224
225                // Wait before reconnecting
226                *state.write().await = ConnectionState::Reconnecting;
227                tokio::time::sleep(reconnect_interval).await;
228            }
229        });
230
231        Ok(())
232    }
233
234    /// Disconnect from the server
235    pub async fn disconnect(&mut self) {
236        *self.dying.write().await = true;
237
238        // Close the request channel
239        *self.request_tx.lock().await = None;
240
241        *self.state.write().await = ConnectionState::Disconnected;
242        *self.connection_id.write().await = String::new();
243    }
244
245    /// Send a message on the connection
246    pub async fn send(&self, request: ConnectRequest) -> Result<()> {
247        let tx = self.request_tx.lock().await;
248
249        if let Some(ref sender) = *tx {
250            sender
251                .send(request)
252                .await
253                .map_err(|_| VoltError::NotConnected)?;
254            Ok(())
255        } else {
256            Err(VoltError::NotConnected)
257        }
258    }
259}
260
261impl Drop for VoltConnection {
262    fn drop(&mut self) {
263        // Signal shutdown
264        // Note: This is sync context, so we can't use async
265        // The actual cleanup happens in the spawned task
266    }
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::VoltConfig;

    /// A freshly constructed connection must start out disconnected.
    #[tokio::test]
    async fn test_connection_state() {
        let cfg = VoltClientConfig {
            client_name: String::from("test"),
            volt: VoltConfig::default(),
            ..Default::default()
        };

        let connection = VoltConnection::new(&cfg);
        let initial = connection.state().await;
        assert_eq!(initial, ConnectionState::Disconnected);
    }
}