1use std::net::SocketAddr;
6use std::sync::Arc;
7use std::time::Duration;
8
9use quinn::{ClientConfig, Connection, Endpoint, TransportConfig};
10use thiserror::Error;
11use tokio::sync::Mutex;
12use tracing::{debug, info, instrument};
13
14use crate::frame::{Frame, FrameError, FramedStream};
15
16#[derive(Debug, Error)]
18pub enum ClientError {
19 #[error("connection error: {0}")]
20 Connection(#[from] quinn::ConnectionError),
21
22 #[error("connect error: {0}")]
23 Connect(#[from] quinn::ConnectError),
24
25 #[error("write error: {0}")]
26 Write(#[from] quinn::WriteError),
27
28 #[error("read error: {0}")]
29 Read(#[from] quinn::ReadExactError),
30
31 #[error("frame error: {0}")]
32 Frame(#[from] FrameError),
33
34 #[error("IO error: {0}")]
35 Io(#[from] std::io::Error),
36
37 #[error("stream closed: {0}")]
38 ClosedStream(#[from] quinn::ClosedStream),
39
40 #[error("no connection established")]
41 NotConnected,
42
43 #[error("invalid server name: {0}")]
44 InvalidServerName(String),
45
46 #[error("connection timed out after {0}ms")]
47 Timeout(u64),
48}
49
/// Connection settings for [`RuntaraClient`].
#[derive(Debug, Clone)]
pub struct RuntaraClientConfig {
    /// Address of the runtara-core server to connect to.
    pub server_addr: SocketAddr,
    /// Server name handed to quinn's `Endpoint::connect` alongside `server_addr`.
    pub server_name: String,
    /// Whether 0-RTT should be used.
    /// NOTE(review): not read by `RuntaraClient` in this file — confirm it is wired up.
    pub enable_0rtt: bool,
    /// Accept any server certificate without verification (set by
    /// `RuntaraClient::localhost`). Insecure outside local development.
    pub dangerous_skip_cert_verification: bool,
    /// QUIC keep-alive interval in milliseconds; 0 means no interval is configured.
    pub keep_alive_interval_ms: u64,
    /// Maximum idle timeout of the QUIC connection, in milliseconds.
    pub idle_timeout_ms: u64,
    /// Upper bound on how long establishing a connection may take, in milliseconds.
    pub connect_timeout_ms: u64,
}
68
69impl Default for RuntaraClientConfig {
70 fn default() -> Self {
71 Self {
72 server_addr: "127.0.0.1:8001".parse().unwrap(),
73 server_name: "localhost".to_string(),
74 enable_0rtt: true,
75 dangerous_skip_cert_verification: false,
76 keep_alive_interval_ms: 10_000,
77 idle_timeout_ms: 30_000,
78 connect_timeout_ms: 10_000,
79 }
80 }
81}
82
83pub struct RuntaraClient {
85 endpoint: Endpoint,
86 connection: Mutex<Option<Connection>>,
87 config: RuntaraClientConfig,
88}
89
90impl RuntaraClient {
91 pub fn new(config: RuntaraClientConfig) -> Result<Self, ClientError> {
93 let mut endpoint = Endpoint::client("0.0.0.0:0".parse().unwrap())?;
94
95 let client_config = Self::build_client_config(&config)?;
96 endpoint.set_default_client_config(client_config);
97
98 Ok(Self {
99 endpoint,
100 connection: Mutex::new(None),
101 config,
102 })
103 }
104
105 pub fn localhost() -> Result<Self, ClientError> {
107 Self::new(RuntaraClientConfig {
108 dangerous_skip_cert_verification: true,
109 ..Default::default()
110 })
111 }
112
113 fn build_client_config(config: &RuntaraClientConfig) -> Result<ClientConfig, ClientError> {
114 let crypto = if config.dangerous_skip_cert_verification {
115 rustls::ClientConfig::builder()
116 .dangerous()
117 .with_custom_certificate_verifier(Arc::new(SkipServerVerification))
118 .with_no_client_auth()
119 } else {
120 let mut roots = rustls::RootCertStore::empty();
121 roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
122 rustls::ClientConfig::builder()
123 .with_root_certificates(roots)
124 .with_no_client_auth()
125 };
126
127 let mut transport = TransportConfig::default();
128 if config.keep_alive_interval_ms > 0 {
129 transport.keep_alive_interval(Some(std::time::Duration::from_millis(
130 config.keep_alive_interval_ms,
131 )));
132 }
133 transport.max_idle_timeout(Some(
134 std::time::Duration::from_millis(config.idle_timeout_ms)
135 .try_into()
136 .unwrap(),
137 ));
138
139 let mut client_config = ClientConfig::new(Arc::new(
140 quinn::crypto::rustls::QuicClientConfig::try_from(crypto).unwrap(),
141 ));
142 client_config.transport_config(Arc::new(transport));
143
144 Ok(client_config)
145 }
146
147 #[instrument(skip(self))]
149 pub async fn connect(&self) -> Result<(), ClientError> {
150 let mut conn_guard = self.connection.lock().await;
151
152 if let Some(ref conn) = *conn_guard
154 && conn.close_reason().is_none()
155 {
156 debug!("reusing existing connection");
157 return Ok(());
158 }
159
160 info!(addr = %self.config.server_addr, "connecting to runtara-core");
161
162 let timeout = Duration::from_millis(self.config.connect_timeout_ms);
163 let connecting = self
164 .endpoint
165 .connect(self.config.server_addr, &self.config.server_name)?;
166
167 let connection = tokio::time::timeout(timeout, connecting)
168 .await
169 .map_err(|_| ClientError::Timeout(self.config.connect_timeout_ms))??;
170
171 info!("connected to runtara-core");
172 *conn_guard = Some(connection);
173 Ok(())
174 }
175
176 async fn get_connection(&self) -> Result<Connection, ClientError> {
178 self.connect().await?;
179 let conn_guard = self.connection.lock().await;
180 conn_guard.clone().ok_or(ClientError::NotConnected)
181 }
182
183 pub async fn open_stream(
185 &self,
186 ) -> Result<FramedStream<(quinn::SendStream, quinn::RecvStream)>, ClientError> {
187 let conn = self.get_connection().await?;
188 let (send, recv) = conn.open_bi().await?;
189 Ok(FramedStream::new((send, recv)))
190 }
191
192 pub async fn open_uni_send(&self) -> Result<FramedStream<quinn::SendStream>, ClientError> {
194 let conn = self.get_connection().await?;
195 let send = conn.open_uni().await?;
196 Ok(FramedStream::new(send))
197 }
198
199 #[instrument(skip(self, request))]
201 pub async fn request<Req: prost::Message, Resp: prost::Message + Default>(
202 &self,
203 request: &Req,
204 ) -> Result<Resp, ClientError> {
205 let conn = self.get_connection().await?;
206 let (mut send, mut recv) = conn.open_bi().await?;
207
208 let frame = Frame::request(request)?;
210 crate::frame::write_frame(&mut send, &frame).await?;
211 send.finish()?;
212
213 let response_frame = crate::frame::read_frame(&mut recv).await?;
215 Ok(response_frame.decode()?)
216 }
217
218 #[instrument(skip(self, request))]
222 pub async fn send_fire_and_forget<Req: prost::Message>(
223 &self,
224 request: &Req,
225 ) -> Result<(), ClientError> {
226 let conn = self.get_connection().await?;
227 let (mut send, _recv) = conn.open_bi().await?;
228
229 let frame = Frame::request(request)?;
231 crate::frame::write_frame(&mut send, &frame).await?;
232 send.finish()?;
233
234 Ok(())
236 }
237
238 pub async fn open_raw_stream(
243 &self,
244 ) -> Result<(quinn::SendStream, quinn::RecvStream), ClientError> {
245 let conn = self.get_connection().await?;
246 Ok(conn.open_bi().await?)
247 }
248
249 pub async fn close(&self) {
251 let mut conn_guard = self.connection.lock().await;
252 if let Some(conn) = conn_guard.take() {
253 conn.close(0u32.into(), b"client closing");
254 }
255 }
256
257 pub async fn is_connected(&self) -> bool {
259 let conn_guard = self.connection.lock().await;
260 if let Some(ref conn) = *conn_guard {
261 conn.close_reason().is_none()
262 } else {
263 false
264 }
265 }
266}
267
268impl Drop for RuntaraClient {
269 fn drop(&mut self) {
270 if let Ok(mut guard) = self.connection.try_lock()
272 && let Some(conn) = guard.take()
273 {
274 conn.close(0u32.into(), b"client dropped");
275 }
276 }
277}
278
/// rustls server-certificate verifier that accepts every certificate and every
/// handshake signature without checking them.
///
/// Installed only when `RuntaraClientConfig::dangerous_skip_cert_verification`
/// is set. It provides no protection against an impersonating server.
#[derive(Debug)]
struct SkipServerVerification;
282
283impl rustls::client::danger::ServerCertVerifier for SkipServerVerification {
284 fn verify_server_cert(
285 &self,
286 _end_entity: &rustls::pki_types::CertificateDer<'_>,
287 _intermediates: &[rustls::pki_types::CertificateDer<'_>],
288 _server_name: &rustls::pki_types::ServerName<'_>,
289 _ocsp_response: &[u8],
290 _now: rustls::pki_types::UnixTime,
291 ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
292 Ok(rustls::client::danger::ServerCertVerified::assertion())
293 }
294
295 fn verify_tls12_signature(
296 &self,
297 _message: &[u8],
298 _cert: &rustls::pki_types::CertificateDer<'_>,
299 _dss: &rustls::DigitallySignedStruct,
300 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
301 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
302 }
303
304 fn verify_tls13_signature(
305 &self,
306 _message: &[u8],
307 _cert: &rustls::pki_types::CertificateDer<'_>,
308 _dss: &rustls::DigitallySignedStruct,
309 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
310 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
311 }
312
313 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
314 vec![
315 rustls::SignatureScheme::RSA_PKCS1_SHA256,
316 rustls::SignatureScheme::RSA_PKCS1_SHA384,
317 rustls::SignatureScheme::RSA_PKCS1_SHA512,
318 rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
319 rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
320 rustls::SignatureScheme::ECDSA_NISTP521_SHA512,
321 rustls::SignatureScheme::RSA_PSS_SHA256,
322 rustls::SignatureScheme::RSA_PSS_SHA384,
323 rustls::SignatureScheme::RSA_PSS_SHA512,
324 rustls::SignatureScheme::ED25519,
325 ]
326 }
327}
328
#[cfg(test)]
mod tests {
    use super::*;

    /// Default configuration with certificate verification switched off.
    fn insecure_config() -> RuntaraClientConfig {
        RuntaraClientConfig {
            dangerous_skip_cert_verification: true,
            ..Default::default()
        }
    }

    /// Builds an insecure client aimed at `addr` with a 100 ms connect timeout.
    /// These tests assume nothing listens on `addr`, so every connect attempt fails.
    fn unreachable_client(addr: &str) -> RuntaraClient {
        RuntaraClient::new(RuntaraClientConfig {
            server_addr: addr.parse().unwrap(),
            connect_timeout_ms: 100,
            ..insecure_config()
        })
        .unwrap()
    }

    #[test]
    fn test_default_config() {
        let defaults = RuntaraClientConfig::default();
        assert_eq!(defaults.server_name, "localhost");
        assert_eq!(defaults.server_addr, "127.0.0.1:8001".parse().unwrap());
    }

    #[test]
    fn test_default_config_all_fields() {
        let RuntaraClientConfig {
            server_addr,
            server_name,
            enable_0rtt,
            dangerous_skip_cert_verification,
            keep_alive_interval_ms,
            idle_timeout_ms,
            connect_timeout_ms,
        } = RuntaraClientConfig::default();

        assert_eq!(server_addr, "127.0.0.1:8001".parse().unwrap());
        assert_eq!(server_name, "localhost");
        assert!(enable_0rtt);
        assert!(!dangerous_skip_cert_verification);
        assert_eq!(keep_alive_interval_ms, 10_000);
        assert_eq!(idle_timeout_ms, 30_000);
        assert_eq!(connect_timeout_ms, 10_000);
    }

    #[test]
    fn test_config_clone() {
        let original = RuntaraClientConfig {
            server_addr: "192.168.1.1:9000".parse().unwrap(),
            server_name: "custom".to_string(),
            enable_0rtt: false,
            dangerous_skip_cert_verification: true,
            keep_alive_interval_ms: 5000,
            idle_timeout_ms: 60000,
            connect_timeout_ms: 3000,
        };
        let copy = original.clone();

        assert_eq!(copy.server_addr, original.server_addr);
        assert_eq!(copy.server_name, original.server_name);
        assert_eq!(copy.enable_0rtt, original.enable_0rtt);
        assert_eq!(
            copy.dangerous_skip_cert_verification,
            original.dangerous_skip_cert_verification
        );
        assert_eq!(copy.keep_alive_interval_ms, original.keep_alive_interval_ms);
        assert_eq!(copy.idle_timeout_ms, original.idle_timeout_ms);
        assert_eq!(copy.connect_timeout_ms, original.connect_timeout_ms);
    }

    #[test]
    fn test_config_debug() {
        let rendered = format!("{:?}", RuntaraClientConfig::default());
        for needle in ["RuntaraClientConfig", "server_addr", "server_name"] {
            assert!(rendered.contains(needle), "missing {needle} in {rendered}");
        }
    }

    #[tokio::test]
    async fn test_client_creation() {
        let client = RuntaraClient::new(insecure_config());
        assert!(client.is_ok(), "Failed to create client: {:?}", client.err());
    }

    #[tokio::test]
    async fn test_client_localhost() {
        let client = RuntaraClient::localhost();
        assert!(
            client.is_ok(),
            "Failed to create localhost client: {:?}",
            client.err()
        );
    }

    #[tokio::test]
    async fn test_client_with_custom_config() {
        let config = RuntaraClientConfig {
            server_addr: "10.0.0.1:8888".parse().unwrap(),
            server_name: "my-server".to_string(),
            enable_0rtt: false,
            dangerous_skip_cert_verification: true,
            keep_alive_interval_ms: 0,
            idle_timeout_ms: 120000,
            connect_timeout_ms: 5000,
        };
        assert!(RuntaraClient::new(config).is_ok());
    }

    #[tokio::test]
    async fn test_client_initial_not_connected() {
        let client = RuntaraClient::new(insecure_config()).unwrap();
        assert!(!client.is_connected().await);
    }

    #[tokio::test]
    async fn test_client_connect_timeout() {
        let client = unreachable_client("127.0.0.1:59998");
        assert!(client.connect().await.is_err());
    }

    #[tokio::test]
    async fn test_client_close_without_connection() {
        let client = RuntaraClient::new(insecure_config()).unwrap();
        client.close().await;
        assert!(!client.is_connected().await);
    }

    #[tokio::test]
    async fn test_open_stream_without_connection() {
        let client = unreachable_client("127.0.0.1:59997");
        assert!(client.open_stream().await.is_err());
    }

    #[tokio::test]
    async fn test_open_uni_send_without_connection() {
        let client = unreachable_client("127.0.0.1:59996");
        assert!(client.open_uni_send().await.is_err());
    }

    #[tokio::test]
    async fn test_open_raw_stream_without_connection() {
        let client = unreachable_client("127.0.0.1:59995");
        assert!(client.open_raw_stream().await.is_err());
    }

    #[test]
    fn test_client_error_display() {
        let cases = [
            (ClientError::NotConnected, "no connection established"),
            (ClientError::Timeout(5000), "connection timed out after 5000ms"),
            (
                ClientError::InvalidServerName("bad-name".to_string()),
                "invalid server name: bad-name",
            ),
        ];
        for (err, expected) in cases {
            assert_eq!(err.to_string(), expected);
        }
    }

    #[test]
    fn test_skip_server_verification_schemes() {
        use rustls::client::danger::ServerCertVerifier;

        let schemes = SkipServerVerification.supported_verify_schemes();
        assert!(!schemes.is_empty());
        for wanted in [
            rustls::SignatureScheme::RSA_PKCS1_SHA256,
            rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
            rustls::SignatureScheme::ED25519,
        ] {
            assert!(schemes.contains(&wanted));
        }
    }

    #[test]
    fn test_skip_server_verification_debug() {
        let rendered = format!("{:?}", SkipServerVerification);
        assert!(rendered.contains("SkipServerVerification"));
    }

    #[test]
    fn test_build_client_config_with_verification() {
        let config = RuntaraClientConfig {
            dangerous_skip_cert_verification: false,
            ..Default::default()
        };
        assert!(RuntaraClient::build_client_config(&config).is_ok());
    }

    #[test]
    fn test_build_client_config_skip_verification() {
        assert!(RuntaraClient::build_client_config(&insecure_config()).is_ok());
    }

    #[test]
    fn test_build_client_config_no_keepalive() {
        let config = RuntaraClientConfig {
            keep_alive_interval_ms: 0,
            ..insecure_config()
        };
        assert!(RuntaraClient::build_client_config(&config).is_ok());
    }
}