runtara_protocol/
client.rs

1// Copyright (C) 2025 SyncMyOrders Sp. z o.o.
2// SPDX-License-Identifier: AGPL-3.0-or-later
3//! QUIC client helpers for connecting to runtara-core.
4
5use std::net::SocketAddr;
6use std::sync::Arc;
7use std::time::Duration;
8
9use quinn::{ClientConfig, Connection, Endpoint, TransportConfig};
10use thiserror::Error;
11use tokio::sync::Mutex;
12use tracing::{debug, info, instrument};
13
14use crate::frame::{Frame, FrameError, FramedStream};
15
16/// Errors that can occur in the QUIC client
17#[derive(Debug, Error)]
18pub enum ClientError {
19    #[error("connection error: {0}")]
20    Connection(#[from] quinn::ConnectionError),
21
22    #[error("connect error: {0}")]
23    Connect(#[from] quinn::ConnectError),
24
25    #[error("write error: {0}")]
26    Write(#[from] quinn::WriteError),
27
28    #[error("read error: {0}")]
29    Read(#[from] quinn::ReadExactError),
30
31    #[error("frame error: {0}")]
32    Frame(#[from] FrameError),
33
34    #[error("IO error: {0}")]
35    Io(#[from] std::io::Error),
36
37    #[error("stream closed: {0}")]
38    ClosedStream(#[from] quinn::ClosedStream),
39
40    #[error("no connection established")]
41    NotConnected,
42
43    #[error("invalid server name: {0}")]
44    InvalidServerName(String),
45
46    #[error("connection timed out after {0}ms")]
47    Timeout(u64),
48}
49
/// Configuration for the QUIC client.
///
/// [`Default`] targets a local runtara-core at `127.0.0.1:8001`. Use
/// struct-update syntax (`..Default::default()`) to override individual
/// fields.
#[derive(Debug, Clone)]
pub struct RuntaraClientConfig {
    /// Server address to connect to.
    pub server_addr: SocketAddr,
    /// Server name used for TLS verification (use "localhost" for local dev).
    pub server_name: String,
    /// Request 0-RTT for lower latency (requires server support).
    ///
    /// NOTE: this flag is not currently read by `RuntaraClient`. Connecting
    /// always waits for the full handshake regardless of its value.
    pub enable_0rtt: bool,
    /// Skip certificate verification (for development only!).
    ///
    /// When `false`, the server certificate is checked against the bundled
    /// `webpki_roots` trust anchors. A self-signed development certificate
    /// will therefore typically fail verification.
    pub dangerous_skip_cert_verification: bool,
    /// Keep-alive interval in milliseconds (0 disables keep-alives).
    pub keep_alive_interval_ms: u64,
    /// Maximum idle time of a connection, in milliseconds.
    pub idle_timeout_ms: u64,
    /// Upper bound, in milliseconds, on establishing a connection.
    pub connect_timeout_ms: u64,
}

impl Default for RuntaraClientConfig {
    /// Local-development defaults.
    ///
    /// `127.0.0.1:8001`, server name `localhost`, certificate verification
    /// enabled, 10 s keep-alive, 30 s idle timeout and 10 s connect timeout.
    fn default() -> Self {
        Self {
            // Built directly instead of parsing a string, so there is no
            // panic path.
            server_addr: SocketAddr::from(([127, 0, 0, 1], 8001)),
            server_name: "localhost".to_string(),
            enable_0rtt: true,
            dangerous_skip_cert_verification: false,
            keep_alive_interval_ms: 10_000,
            idle_timeout_ms: 30_000,
            connect_timeout_ms: 10_000,
        }
    }
}
82
83/// QUIC client for communicating with runtara-core
84pub struct RuntaraClient {
85    endpoint: Endpoint,
86    connection: Mutex<Option<Connection>>,
87    config: RuntaraClientConfig,
88}
89
90impl RuntaraClient {
91    /// Create a new client with the given configuration
92    pub fn new(config: RuntaraClientConfig) -> Result<Self, ClientError> {
93        let mut endpoint = Endpoint::client("0.0.0.0:0".parse().unwrap())?;
94
95        let client_config = Self::build_client_config(&config)?;
96        endpoint.set_default_client_config(client_config);
97
98        Ok(Self {
99            endpoint,
100            connection: Mutex::new(None),
101            config,
102        })
103    }
104
105    /// Create a client with default configuration for local development
106    pub fn localhost() -> Result<Self, ClientError> {
107        Self::new(RuntaraClientConfig {
108            dangerous_skip_cert_verification: true,
109            ..Default::default()
110        })
111    }
112
113    fn build_client_config(config: &RuntaraClientConfig) -> Result<ClientConfig, ClientError> {
114        let crypto = if config.dangerous_skip_cert_verification {
115            rustls::ClientConfig::builder()
116                .dangerous()
117                .with_custom_certificate_verifier(Arc::new(SkipServerVerification))
118                .with_no_client_auth()
119        } else {
120            let mut roots = rustls::RootCertStore::empty();
121            roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
122            rustls::ClientConfig::builder()
123                .with_root_certificates(roots)
124                .with_no_client_auth()
125        };
126
127        let mut transport = TransportConfig::default();
128        if config.keep_alive_interval_ms > 0 {
129            transport.keep_alive_interval(Some(std::time::Duration::from_millis(
130                config.keep_alive_interval_ms,
131            )));
132        }
133        transport.max_idle_timeout(Some(
134            std::time::Duration::from_millis(config.idle_timeout_ms)
135                .try_into()
136                .unwrap(),
137        ));
138
139        let mut client_config = ClientConfig::new(Arc::new(
140            quinn::crypto::rustls::QuicClientConfig::try_from(crypto).unwrap(),
141        ));
142        client_config.transport_config(Arc::new(transport));
143
144        Ok(client_config)
145    }
146
147    /// Connect to the server
148    #[instrument(skip(self))]
149    pub async fn connect(&self) -> Result<(), ClientError> {
150        let mut conn_guard = self.connection.lock().await;
151
152        // Check if we already have a valid connection
153        if let Some(ref conn) = *conn_guard
154            && conn.close_reason().is_none()
155        {
156            debug!("reusing existing connection");
157            return Ok(());
158        }
159
160        info!(addr = %self.config.server_addr, "connecting to runtara-core");
161
162        let timeout = Duration::from_millis(self.config.connect_timeout_ms);
163        let connecting = self
164            .endpoint
165            .connect(self.config.server_addr, &self.config.server_name)?;
166
167        let connection = tokio::time::timeout(timeout, connecting)
168            .await
169            .map_err(|_| ClientError::Timeout(self.config.connect_timeout_ms))??;
170
171        info!("connected to runtara-core");
172        *conn_guard = Some(connection);
173        Ok(())
174    }
175
176    /// Get the current connection, connecting if necessary
177    async fn get_connection(&self) -> Result<Connection, ClientError> {
178        self.connect().await?;
179        let conn_guard = self.connection.lock().await;
180        conn_guard.clone().ok_or(ClientError::NotConnected)
181    }
182
183    /// Open a new bidirectional stream for a request/response
184    pub async fn open_stream(
185        &self,
186    ) -> Result<FramedStream<(quinn::SendStream, quinn::RecvStream)>, ClientError> {
187        let conn = self.get_connection().await?;
188        let (send, recv) = conn.open_bi().await?;
189        Ok(FramedStream::new((send, recv)))
190    }
191
192    /// Open a unidirectional stream for sending (e.g., events)
193    pub async fn open_uni_send(&self) -> Result<FramedStream<quinn::SendStream>, ClientError> {
194        let conn = self.get_connection().await?;
195        let send = conn.open_uni().await?;
196        Ok(FramedStream::new(send))
197    }
198
199    /// Send a request and receive a response using a new stream
200    #[instrument(skip(self, request))]
201    pub async fn request<Req: prost::Message, Resp: prost::Message + Default>(
202        &self,
203        request: &Req,
204    ) -> Result<Resp, ClientError> {
205        let conn = self.get_connection().await?;
206        let (mut send, mut recv) = conn.open_bi().await?;
207
208        // Send request
209        let frame = Frame::request(request)?;
210        crate::frame::write_frame(&mut send, &frame).await?;
211        send.finish()?;
212
213        // Read response
214        let response_frame = crate::frame::read_frame(&mut recv).await?;
215        Ok(response_frame.decode()?)
216    }
217
218    /// Send a fire-and-forget request (no response expected).
219    ///
220    /// Use this for events that don't require acknowledgement.
221    #[instrument(skip(self, request))]
222    pub async fn send_fire_and_forget<Req: prost::Message>(
223        &self,
224        request: &Req,
225    ) -> Result<(), ClientError> {
226        let conn = self.get_connection().await?;
227        let (mut send, _recv) = conn.open_bi().await?;
228
229        // Send request
230        let frame = Frame::request(request)?;
231        crate::frame::write_frame(&mut send, &frame).await?;
232        send.finish()?;
233
234        // No response expected - just return
235        Ok(())
236    }
237
238    /// Open a raw bidirectional stream for streaming operations.
239    ///
240    /// This returns the raw QUIC streams for advanced use cases like
241    /// streaming large data that doesn't fit in a single frame.
242    pub async fn open_raw_stream(
243        &self,
244    ) -> Result<(quinn::SendStream, quinn::RecvStream), ClientError> {
245        let conn = self.get_connection().await?;
246        Ok(conn.open_bi().await?)
247    }
248
249    /// Close the connection gracefully
250    pub async fn close(&self) {
251        let mut conn_guard = self.connection.lock().await;
252        if let Some(conn) = conn_guard.take() {
253            conn.close(0u32.into(), b"client closing");
254        }
255    }
256
257    /// Check if the client is currently connected
258    pub async fn is_connected(&self) -> bool {
259        let conn_guard = self.connection.lock().await;
260        if let Some(ref conn) = *conn_guard {
261            conn.close_reason().is_none()
262        } else {
263            false
264        }
265    }
266}
267
268impl Drop for RuntaraClient {
269    fn drop(&mut self) {
270        // Close connection on drop (non-async, best effort)
271        if let Ok(mut guard) = self.connection.try_lock()
272            && let Some(conn) = guard.take()
273        {
274            conn.close(0u32.into(), b"client dropped");
275        }
276    }
277}
278
279/// Certificate verifier that skips all verification (for development only!)
280#[derive(Debug)]
281struct SkipServerVerification;
282
283impl rustls::client::danger::ServerCertVerifier for SkipServerVerification {
284    fn verify_server_cert(
285        &self,
286        _end_entity: &rustls::pki_types::CertificateDer<'_>,
287        _intermediates: &[rustls::pki_types::CertificateDer<'_>],
288        _server_name: &rustls::pki_types::ServerName<'_>,
289        _ocsp_response: &[u8],
290        _now: rustls::pki_types::UnixTime,
291    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
292        Ok(rustls::client::danger::ServerCertVerified::assertion())
293    }
294
295    fn verify_tls12_signature(
296        &self,
297        _message: &[u8],
298        _cert: &rustls::pki_types::CertificateDer<'_>,
299        _dss: &rustls::DigitallySignedStruct,
300    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
301        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
302    }
303
304    fn verify_tls13_signature(
305        &self,
306        _message: &[u8],
307        _cert: &rustls::pki_types::CertificateDer<'_>,
308        _dss: &rustls::DigitallySignedStruct,
309    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
310        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
311    }
312
313    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
314        vec![
315            rustls::SignatureScheme::RSA_PKCS1_SHA256,
316            rustls::SignatureScheme::RSA_PKCS1_SHA384,
317            rustls::SignatureScheme::RSA_PKCS1_SHA512,
318            rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
319            rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
320            rustls::SignatureScheme::ECDSA_NISTP521_SHA512,
321            rustls::SignatureScheme::RSA_PSS_SHA256,
322            rustls::SignatureScheme::RSA_PSS_SHA384,
323            rustls::SignatureScheme::RSA_PSS_SHA512,
324            rustls::SignatureScheme::ED25519,
325        ]
326    }
327}
328
#[cfg(test)]
mod tests {
    use super::*;

    /// Default configuration with certificate verification turned off.
    fn insecure_config() -> RuntaraClientConfig {
        RuntaraClientConfig {
            dangerous_skip_cert_verification: true,
            ..Default::default()
        }
    }

    /// Insecure configuration aimed at a local port where no server is
    /// expected to listen. It uses a 100 ms connect timeout so failures are
    /// quick.
    fn unreachable_config(port: u16) -> RuntaraClientConfig {
        RuntaraClientConfig {
            server_addr: format!("127.0.0.1:{port}").parse().unwrap(),
            connect_timeout_ms: 100,
            ..insecure_config()
        }
    }

    #[test]
    fn test_default_config() {
        let cfg = RuntaraClientConfig::default();
        assert_eq!(cfg.server_addr, "127.0.0.1:8001".parse().unwrap());
        assert_eq!(cfg.server_name, "localhost");
    }

    #[test]
    fn test_default_config_all_fields() {
        let cfg = RuntaraClientConfig::default();
        assert_eq!(cfg.server_addr, "127.0.0.1:8001".parse().unwrap());
        assert_eq!(cfg.server_name, "localhost");
        assert!(cfg.enable_0rtt);
        assert!(!cfg.dangerous_skip_cert_verification);
        assert_eq!(cfg.keep_alive_interval_ms, 10_000);
        assert_eq!(cfg.idle_timeout_ms, 30_000);
        assert_eq!(cfg.connect_timeout_ms, 10_000);
    }

    #[test]
    fn test_config_clone() {
        let original = RuntaraClientConfig {
            server_addr: "192.168.1.1:9000".parse().unwrap(),
            server_name: "custom".to_string(),
            enable_0rtt: false,
            dangerous_skip_cert_verification: true,
            keep_alive_interval_ms: 5000,
            idle_timeout_ms: 60000,
            connect_timeout_ms: 3000,
        };
        let copy = original.clone();
        assert_eq!(original.server_addr, copy.server_addr);
        assert_eq!(original.server_name, copy.server_name);
        assert_eq!(original.enable_0rtt, copy.enable_0rtt);
        assert_eq!(
            original.dangerous_skip_cert_verification,
            copy.dangerous_skip_cert_verification
        );
        assert_eq!(original.keep_alive_interval_ms, copy.keep_alive_interval_ms);
        assert_eq!(original.idle_timeout_ms, copy.idle_timeout_ms);
        assert_eq!(original.connect_timeout_ms, copy.connect_timeout_ms);
    }

    #[test]
    fn test_config_debug() {
        let rendered = format!("{:?}", RuntaraClientConfig::default());
        for needle in ["RuntaraClientConfig", "server_addr", "server_name"] {
            assert!(rendered.contains(needle));
        }
    }

    #[tokio::test]
    async fn test_client_creation() {
        let outcome = RuntaraClient::new(insecure_config());
        assert!(
            outcome.is_ok(),
            "Failed to create client: {:?}",
            outcome.err()
        );
    }

    #[tokio::test]
    async fn test_client_localhost() {
        let outcome = RuntaraClient::localhost();
        assert!(
            outcome.is_ok(),
            "Failed to create localhost client: {:?}",
            outcome.err()
        );
    }

    #[tokio::test]
    async fn test_client_with_custom_config() {
        let cfg = RuntaraClientConfig {
            server_addr: "10.0.0.1:8888".parse().unwrap(),
            server_name: "my-server".to_string(),
            enable_0rtt: false,
            dangerous_skip_cert_verification: true,
            keep_alive_interval_ms: 0, // Disable keep-alive
            idle_timeout_ms: 120000,
            connect_timeout_ms: 5000,
        };
        assert!(RuntaraClient::new(cfg).is_ok());
    }

    #[tokio::test]
    async fn test_client_initial_not_connected() {
        let client = RuntaraClient::new(insecure_config()).unwrap();
        assert!(!client.is_connected().await);
    }

    #[tokio::test]
    async fn test_client_connect_timeout() {
        let client = RuntaraClient::new(unreachable_config(59998)).unwrap();
        // Nothing listens on this port, so connecting must fail.
        assert!(client.connect().await.is_err());
    }

    #[tokio::test]
    async fn test_client_close_without_connection() {
        let client = RuntaraClient::new(insecure_config()).unwrap();
        // Closing without a connection should be safe.
        client.close().await;
        assert!(!client.is_connected().await);
    }

    #[tokio::test]
    async fn test_open_stream_without_connection() {
        let client = RuntaraClient::new(unreachable_config(59997)).unwrap();
        // open_stream connects first, and that connection attempt fails.
        assert!(client.open_stream().await.is_err());
    }

    #[tokio::test]
    async fn test_open_uni_send_without_connection() {
        let client = RuntaraClient::new(unreachable_config(59996)).unwrap();
        assert!(client.open_uni_send().await.is_err());
    }

    #[tokio::test]
    async fn test_open_raw_stream_without_connection() {
        let client = RuntaraClient::new(unreachable_config(59995)).unwrap();
        assert!(client.open_raw_stream().await.is_err());
    }

    #[test]
    fn test_client_error_display() {
        let cases = [
            (ClientError::NotConnected, "no connection established"),
            (
                ClientError::Timeout(5000),
                "connection timed out after 5000ms",
            ),
            (
                ClientError::InvalidServerName("bad-name".to_string()),
                "invalid server name: bad-name",
            ),
        ];
        for (err, expected) in cases {
            assert_eq!(err.to_string(), expected);
        }
    }

    #[test]
    fn test_skip_server_verification_schemes() {
        use rustls::client::danger::ServerCertVerifier;
        let schemes = SkipServerVerification.supported_verify_schemes();
        assert!(!schemes.is_empty());
        for wanted in [
            rustls::SignatureScheme::RSA_PKCS1_SHA256,
            rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
            rustls::SignatureScheme::ED25519,
        ] {
            assert!(schemes.contains(&wanted));
        }
    }

    #[test]
    fn test_skip_server_verification_debug() {
        let rendered = format!("{:?}", SkipServerVerification);
        assert!(rendered.contains("SkipServerVerification"));
    }

    #[test]
    fn test_build_client_config_with_verification() {
        let cfg = RuntaraClientConfig {
            dangerous_skip_cert_verification: false,
            ..Default::default()
        };
        // Works with the bundled webpki roots.
        assert!(RuntaraClient::build_client_config(&cfg).is_ok());
    }

    #[test]
    fn test_build_client_config_skip_verification() {
        assert!(RuntaraClient::build_client_config(&insecure_config()).is_ok());
    }

    #[test]
    fn test_build_client_config_no_keepalive() {
        let cfg = RuntaraClientConfig {
            keep_alive_interval_ms: 0,
            ..insecure_config()
        };
        assert!(RuntaraClient::build_client_config(&cfg).is_ok());
    }
}