1use std::collections::HashSet;
31use std::path::Path;
32use std::sync::Arc;
33use std::time::Duration;
34
35use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader, BufWriter};
36use tokio::net::{UnixListener, UnixStream};
37use tokio::sync::{mpsc, oneshot, Mutex, RwLock};
38use tracing::{debug, error, info, warn};
39
40use crate::v2::client::FlowState;
41use crate::v2::pool::CHANNEL_BUFFER_SIZE;
42use crate::v2::uds::{read_message, write_message, MessageType, UdsCapabilities};
43use crate::v2::{AgentCapabilities, AgentPool, PROTOCOL_VERSION_2};
44use crate::{AgentProtocolError, AgentResponse};
45
/// Settings for [`ReverseConnectionListener`] and the clients it creates.
#[derive(Debug, Clone)]
pub struct ReverseConnectionConfig {
    /// Requested listen backlog.
    ///
    /// NOTE(review): `bind_uds` in this file calls `UnixListener::bind`
    /// without passing this value; confirm whether it is applied elsewhere.
    pub backlog: u32,
    /// Maximum time an accepted connection has to deliver its registration
    /// request before the handshake fails with a timeout.
    pub handshake_timeout: Duration,
    /// Maximum number of connections per agent.
    ///
    /// NOTE(review): not enforced by the code in this file; presumably the
    /// agent pool applies it — verify.
    pub max_connections_per_agent: usize,
    /// Agent ids permitted to register. An empty set allows any agent id.
    pub allowed_agents: HashSet<String>,
    /// Reject registrations that do not carry an authentication token.
    ///
    /// NOTE(review): only the token's presence is checked in this file; it
    /// is not verified against any secret here.
    pub require_auth: bool,
    /// How long a registered client waits for an agent's response to each
    /// event before failing with a timeout.
    pub request_timeout: Duration,
}
62
63impl Default for ReverseConnectionConfig {
64 fn default() -> Self {
65 Self {
66 backlog: 128,
67 handshake_timeout: Duration::from_secs(10),
68 max_connections_per_agent: 4,
69 allowed_agents: HashSet::new(),
70 require_auth: false,
71 request_timeout: Duration::from_secs(30),
72 }
73 }
74}
75
/// First message an agent sends after connecting to the reverse listener.
///
/// It is framed as a `HandshakeRequest` message and carries this struct
/// serialized as JSON.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct RegistrationRequest {
    /// Protocol version spoken by the agent; registration is rejected unless
    /// it equals `PROTOCOL_VERSION_2`.
    pub protocol_version: u32,
    /// Identifier the agent registers under. Must be non-empty and, when an
    /// allowlist is configured, present in it.
    pub agent_id: String,
    /// Capabilities advertised by the agent; converted into
    /// `AgentCapabilities` when the connection is added to the pool.
    pub capabilities: UdsCapabilities,
    /// Authentication token; must be present when `require_auth` is set.
    pub auth_token: Option<String>,
    /// Free-form agent metadata; not read by the code in this file.
    pub metadata: Option<serde_json::Value>,
}
90
/// Proxy's reply to a [`RegistrationRequest`], sent as a `HandshakeResponse`
/// message with this struct serialized as JSON.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct RegistrationResponse {
    /// Whether the registration was accepted.
    pub success: bool,
    /// Human-readable rejection reason; `None` on success.
    pub error: Option<String>,
    /// Identifier of the proxy that answered.
    pub proxy_id: String,
    /// Version of the proxy crate that answered.
    pub proxy_version: String,
    /// Identifier assigned to this connection; empty when rejected.
    pub connection_id: String,
}
105
106pub struct ReverseConnectionListener {
108 listener: UnixListener,
109 config: ReverseConnectionConfig,
110 socket_path: String,
111}
112
113impl ReverseConnectionListener {
114 pub async fn bind_uds(
116 path: impl AsRef<Path>,
117 config: ReverseConnectionConfig,
118 ) -> Result<Self, AgentProtocolError> {
119 let path = path.as_ref();
120 let socket_path = path.to_string_lossy().to_string();
121
122 if path.exists() {
124 std::fs::remove_file(path).map_err(|e| {
125 AgentProtocolError::ConnectionFailed(format!(
126 "Failed to remove existing socket {}: {}",
127 socket_path, e
128 ))
129 })?;
130 }
131
132 let listener = UnixListener::bind(path).map_err(|e| {
133 AgentProtocolError::ConnectionFailed(format!(
134 "Failed to bind to {}: {}",
135 socket_path, e
136 ))
137 })?;
138
139 info!(path = %socket_path, "Reverse connection listener bound");
140
141 Ok(Self {
142 listener,
143 config,
144 socket_path,
145 })
146 }
147
148 pub fn socket_path(&self) -> &str {
150 &self.socket_path
151 }
152
153 pub async fn accept_one(&self, pool: &AgentPool) -> Result<String, AgentProtocolError> {
157 let (stream, _addr) =
158 self.listener.accept().await.map_err(|e| {
159 AgentProtocolError::ConnectionFailed(format!("Accept failed: {}", e))
160 })?;
161
162 debug!("Accepted reverse connection");
163
164 self.handle_connection(stream, pool).await
165 }
166
167 pub async fn accept_loop(self: Arc<Self>, pool: Arc<AgentPool>) {
172 info!(path = %self.socket_path, "Starting reverse connection accept loop");
173
174 loop {
175 match self.listener.accept().await {
176 Ok((stream, _addr)) => {
177 let listener = Arc::clone(&self);
178 let pool = Arc::clone(&pool);
179
180 tokio::spawn(async move {
181 match listener.handle_connection(stream, &pool).await {
182 Ok(agent_id) => {
183 info!(agent_id = %agent_id, "Reverse connection registered");
184 }
185 Err(e) => {
186 warn!(error = %e, "Failed to handle reverse connection");
187 }
188 }
189 });
190 }
191 Err(e) => {
192 error!(error = %e, "Accept failed");
193 tokio::time::sleep(Duration::from_millis(100)).await;
195 }
196 }
197 }
198 }
199
200 async fn handle_connection(
202 &self,
203 stream: UnixStream,
204 pool: &AgentPool,
205 ) -> Result<String, AgentProtocolError> {
206 let (read_half, write_half) = stream.into_split();
207 let mut reader = BufReader::new(read_half);
208 let mut writer = BufWriter::new(write_half);
209
210 let registration = tokio::time::timeout(
212 self.config.handshake_timeout,
213 self.read_registration(&mut reader),
214 )
215 .await
216 .map_err(|_| AgentProtocolError::Timeout(self.config.handshake_timeout))??;
217
218 let agent_id = registration.agent_id.clone();
219
220 if let Err(e) = self.validate_registration(®istration) {
222 let response = RegistrationResponse {
223 success: false,
224 error: Some(e.to_string()),
225 proxy_id: "sentinel-proxy".to_string(),
226 proxy_version: env!("CARGO_PKG_VERSION").to_string(),
227 connection_id: String::new(),
228 };
229 self.send_registration_response(&mut writer, &response)
230 .await?;
231 return Err(e);
232 }
233
234 let connection_id = format!(
236 "{}-{:x}",
237 agent_id,
238 std::time::SystemTime::now()
239 .duration_since(std::time::UNIX_EPOCH)
240 .map(|d| d.as_millis())
241 .unwrap_or(0)
242 );
243
244 let response = RegistrationResponse {
246 success: true,
247 error: None,
248 proxy_id: "sentinel-proxy".to_string(),
249 proxy_version: env!("CARGO_PKG_VERSION").to_string(),
250 connection_id: connection_id.clone(),
251 };
252 self.send_registration_response(&mut writer, &response)
253 .await?;
254
255 info!(
256 agent_id = %agent_id,
257 connection_id = %connection_id,
258 "Agent registration successful"
259 );
260
261 let capabilities: AgentCapabilities = registration.capabilities.into();
263
264 let client = ReverseConnectionClient::new(
266 agent_id.clone(),
267 connection_id,
268 capabilities.clone(),
269 reader,
270 writer,
271 self.config.request_timeout,
272 )
273 .await;
274
275 pool.add_reverse_connection(&agent_id, client, capabilities)
277 .await?;
278
279 Ok(agent_id)
280 }
281
282 async fn read_registration<R: AsyncReadExt + Unpin>(
284 &self,
285 reader: &mut R,
286 ) -> Result<RegistrationRequest, AgentProtocolError> {
287 let (msg_type, payload) = read_message(reader).await?;
288
289 if msg_type != MessageType::HandshakeRequest {
290 return Err(AgentProtocolError::InvalidMessage(format!(
291 "Expected registration request (HandshakeRequest), got {:?}",
292 msg_type
293 )));
294 }
295
296 serde_json::from_slice(&payload)
297 .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))
298 }
299
300 async fn send_registration_response<W: AsyncWriteExt + Unpin>(
302 &self,
303 writer: &mut W,
304 response: &RegistrationResponse,
305 ) -> Result<(), AgentProtocolError> {
306 let payload = serde_json::to_vec(response)
307 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
308
309 write_message(writer, MessageType::HandshakeResponse, &payload).await
310 }
311
312 fn validate_registration(
314 &self,
315 registration: &RegistrationRequest,
316 ) -> Result<(), AgentProtocolError> {
317 if registration.protocol_version != PROTOCOL_VERSION_2 {
319 return Err(AgentProtocolError::VersionMismatch {
320 expected: PROTOCOL_VERSION_2,
321 actual: registration.protocol_version,
322 });
323 }
324
325 if registration.agent_id.is_empty() {
327 return Err(AgentProtocolError::InvalidMessage(
328 "Agent ID cannot be empty".to_string(),
329 ));
330 }
331
332 if !self.config.allowed_agents.is_empty()
334 && !self.config.allowed_agents.contains(®istration.agent_id)
335 {
336 return Err(AgentProtocolError::InvalidMessage(format!(
337 "Agent '{}' is not in the allowed list",
338 registration.agent_id
339 )));
340 }
341
342 if self.config.require_auth && registration.auth_token.is_none() {
344 return Err(AgentProtocolError::InvalidMessage(
345 "Authentication required but no token provided".to_string(),
346 ));
347 }
348
349 Ok(())
350 }
351}
352
353impl Drop for ReverseConnectionListener {
354 fn drop(&mut self) {
355 if let Err(e) = std::fs::remove_file(&self.socket_path) {
357 debug!(path = %self.socket_path, error = %e, "Failed to remove socket file on drop");
358 }
359 }
360}
361
362pub struct ReverseConnectionClient {
367 agent_id: String,
368 connection_id: String,
369 capabilities: RwLock<Option<AgentCapabilities>>,
370 pending: Arc<Mutex<std::collections::HashMap<String, oneshot::Sender<AgentResponse>>>>,
371 #[allow(clippy::type_complexity)]
372 outbound_tx: Mutex<Option<mpsc::Sender<(MessageType, Vec<u8>)>>>,
373 connected: RwLock<bool>,
374 timeout: Duration,
375 in_flight: std::sync::atomic::AtomicU64,
376 flow_state: Arc<RwLock<FlowState>>,
378}
379
380impl ReverseConnectionClient {
381 async fn new<R, W>(
383 agent_id: String,
384 connection_id: String,
385 capabilities: AgentCapabilities,
386 mut reader: BufReader<R>,
387 mut writer: BufWriter<W>,
388 timeout: Duration,
389 ) -> Self
390 where
391 R: AsyncReadExt + Unpin + Send + 'static,
392 W: AsyncWriteExt + Unpin + Send + 'static,
393 {
394 let pending: Arc<Mutex<std::collections::HashMap<String, oneshot::Sender<AgentResponse>>>> =
395 Arc::new(Mutex::new(std::collections::HashMap::new()));
396
397 let (tx, mut rx) = mpsc::channel::<(MessageType, Vec<u8>)>(CHANNEL_BUFFER_SIZE);
399
400 let agent_id_clone = agent_id.clone();
402 tokio::spawn(async move {
403 while let Some((msg_type, payload)) = rx.recv().await {
404 if let Err(e) = write_message(&mut writer, msg_type, &payload).await {
405 error!(
406 agent_id = %agent_id_clone,
407 error = %e,
408 "Failed to write to reverse connection"
409 );
410 break;
411 }
412 }
413 debug!(agent_id = %agent_id_clone, "Reverse connection writer ended");
414 });
415
416 let pending_clone = Arc::clone(&pending);
418 let agent_id_clone = agent_id.clone();
419 tokio::spawn(async move {
420 loop {
421 match read_message(&mut reader).await {
422 Ok((msg_type, payload)) => {
423 if msg_type == MessageType::AgentResponse {
424 if let Ok(response) = serde_json::from_slice::<AgentResponse>(&payload)
425 {
426 let correlation_id = response
427 .audit
428 .custom
429 .get("correlation_id")
430 .and_then(|v| v.as_str())
431 .unwrap_or("")
432 .to_string();
433
434 if let Some(sender) =
435 pending_clone.lock().await.remove(&correlation_id)
436 {
437 let _ = sender.send(response);
438 }
439 }
440 }
441 }
442 Err(e) => {
443 if !matches!(e, AgentProtocolError::ConnectionClosed) {
444 error!(
445 agent_id = %agent_id_clone,
446 error = %e,
447 "Error reading from reverse connection"
448 );
449 }
450 break;
451 }
452 }
453 }
454 debug!(agent_id = %agent_id_clone, "Reverse connection reader ended");
455 });
456
457 Self {
458 agent_id,
459 connection_id,
460 capabilities: RwLock::new(Some(capabilities)),
461 pending,
462 outbound_tx: Mutex::new(Some(tx)),
463 connected: RwLock::new(true),
464 timeout,
465 in_flight: std::sync::atomic::AtomicU64::new(0),
466 flow_state: Arc::new(RwLock::new(FlowState::Normal)),
467 }
468 }
469
470 pub fn agent_id(&self) -> &str {
472 &self.agent_id
473 }
474
475 pub fn connection_id(&self) -> &str {
477 &self.connection_id
478 }
479
480 pub async fn is_connected(&self) -> bool {
482 *self.connected.read().await
483 }
484
485 pub async fn capabilities(&self) -> Option<AgentCapabilities> {
487 self.capabilities.read().await.clone()
488 }
489
490 pub async fn is_paused(&self) -> bool {
495 matches!(*self.flow_state.read().await, FlowState::Paused)
496 }
497
498 pub async fn can_accept_requests(&self) -> bool {
502 !self.is_paused().await
503 }
504
505 pub async fn send_request_headers(
507 &self,
508 correlation_id: &str,
509 event: &crate::RequestHeadersEvent,
510 ) -> Result<AgentResponse, AgentProtocolError> {
511 self.send_event(MessageType::RequestHeaders, correlation_id, event)
512 .await
513 }
514
515 pub async fn send_request_body_chunk(
517 &self,
518 correlation_id: &str,
519 event: &crate::RequestBodyChunkEvent,
520 ) -> Result<AgentResponse, AgentProtocolError> {
521 self.send_event(MessageType::RequestBodyChunk, correlation_id, event)
522 .await
523 }
524
525 pub async fn send_response_headers(
527 &self,
528 correlation_id: &str,
529 event: &crate::ResponseHeadersEvent,
530 ) -> Result<AgentResponse, AgentProtocolError> {
531 self.send_event(MessageType::ResponseHeaders, correlation_id, event)
532 .await
533 }
534
535 pub async fn send_response_body_chunk(
537 &self,
538 correlation_id: &str,
539 event: &crate::ResponseBodyChunkEvent,
540 ) -> Result<AgentResponse, AgentProtocolError> {
541 self.send_event(MessageType::ResponseBodyChunk, correlation_id, event)
542 .await
543 }
544
545 async fn send_event<T: serde::Serialize>(
547 &self,
548 msg_type: MessageType,
549 correlation_id: &str,
550 event: &T,
551 ) -> Result<AgentResponse, AgentProtocolError> {
552 let (tx, rx) = oneshot::channel();
553 self.pending
554 .lock()
555 .await
556 .insert(correlation_id.to_string(), tx);
557
558 let mut payload = serde_json::to_value(event)
560 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
561
562 if let Some(obj) = payload.as_object_mut() {
563 obj.insert(
564 "correlation_id".to_string(),
565 serde_json::Value::String(correlation_id.to_string()),
566 );
567 }
568
569 let payload_bytes = serde_json::to_vec(&payload)
570 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
571
572 {
574 let outbound = self.outbound_tx.lock().await;
575 if let Some(tx) = outbound.as_ref() {
576 tx.send((msg_type, payload_bytes))
577 .await
578 .map_err(|_| AgentProtocolError::ConnectionClosed)?;
579 } else {
580 return Err(AgentProtocolError::ConnectionClosed);
581 }
582 }
583
584 self.in_flight
585 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
586
587 let response = tokio::time::timeout(self.timeout, rx)
589 .await
590 .map_err(|_| {
591 self.pending
592 .try_lock()
593 .ok()
594 .map(|mut p| p.remove(correlation_id));
595 AgentProtocolError::Timeout(self.timeout)
596 })?
597 .map_err(|_| AgentProtocolError::ConnectionClosed)?;
598
599 self.in_flight
600 .fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
601
602 Ok(response)
603 }
604
605 pub async fn cancel_request(
607 &self,
608 correlation_id: &str,
609 reason: super::client::CancelReason,
610 ) -> Result<(), AgentProtocolError> {
611 let cancel = serde_json::json!({
612 "correlation_id": correlation_id,
613 "reason": reason as i32,
614 "timestamp_ms": now_ms(),
615 });
616
617 let payload = serde_json::to_vec(&cancel)
618 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
619
620 let outbound = self.outbound_tx.lock().await;
621 if let Some(tx) = outbound.as_ref() {
622 tx.send((MessageType::Cancel, payload))
623 .await
624 .map_err(|_| AgentProtocolError::ConnectionClosed)?;
625 }
626
627 self.pending.lock().await.remove(correlation_id);
628 Ok(())
629 }
630
631 pub async fn cancel_all(
633 &self,
634 reason: super::client::CancelReason,
635 ) -> Result<usize, AgentProtocolError> {
636 let pending_ids: Vec<String> = self.pending.lock().await.keys().cloned().collect();
637 let count = pending_ids.len();
638
639 for correlation_id in pending_ids {
640 let _ = self.cancel_request(&correlation_id, reason).await;
641 }
642
643 Ok(count)
644 }
645
646 pub async fn close(&self) -> Result<(), AgentProtocolError> {
648 *self.connected.write().await = false;
649 *self.outbound_tx.lock().await = None;
650 Ok(())
651 }
652
653 pub fn in_flight(&self) -> u64 {
655 self.in_flight.load(std::sync::atomic::Ordering::Relaxed)
656 }
657}
658
/// Milliseconds since the Unix epoch.
///
/// Returns 0 if the system clock reads earlier than the epoch. Saturates at
/// `u64::MAX` instead of silently truncating the `u128` returned by
/// `Duration::as_millis`.
fn now_ms() -> u64 {
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|elapsed| u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX))
        .unwrap_or(0)
}
665
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_default() {
        let config = ReverseConnectionConfig::default();
        assert_eq!(config.backlog, 128);
        assert_eq!(config.max_connections_per_agent, 4);
        assert!(!config.require_auth);
        assert!(config.allowed_agents.is_empty());
        assert_eq!(config.handshake_timeout, Duration::from_secs(10));
        assert_eq!(config.request_timeout, Duration::from_secs(30));
    }

    #[test]
    fn test_registration_request_serialization() {
        let request = RegistrationRequest {
            protocol_version: 2,
            agent_id: "test-agent".to_string(),
            capabilities: UdsCapabilities {
                agent_id: "test-agent".to_string(),
                name: "Test Agent".to_string(),
                version: "1.0.0".to_string(),
                supported_events: vec![1, 2],
                features: Default::default(),
                limits: Default::default(),
            },
            auth_token: None,
            metadata: None,
        };

        let json = serde_json::to_string(&request).unwrap();
        let parsed: RegistrationRequest = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.agent_id, "test-agent");
        assert_eq!(parsed.protocol_version, 2);
        assert!(parsed.auth_token.is_none());
        assert!(parsed.metadata.is_none());
    }

    #[test]
    fn test_registration_response_serialization() {
        let response = RegistrationResponse {
            success: true,
            error: None,
            proxy_id: "sentinel".to_string(),
            proxy_version: "1.0.0".to_string(),
            connection_id: "conn-123".to_string(),
        };

        let json = serde_json::to_string(&response).unwrap();
        let parsed: RegistrationResponse = serde_json::from_str(&json).unwrap();

        assert!(parsed.success);
        assert_eq!(parsed.connection_id, "conn-123");
    }

    #[test]
    fn test_rejection_response_round_trip() {
        let response = RegistrationResponse {
            success: false,
            error: Some("Agent ID cannot be empty".to_string()),
            proxy_id: "sentinel".to_string(),
            proxy_version: "1.0.0".to_string(),
            connection_id: String::new(),
        };

        let json = serde_json::to_string(&response).unwrap();
        let parsed: RegistrationResponse = serde_json::from_str(&json).unwrap();

        assert!(!parsed.success);
        assert_eq!(parsed.error.as_deref(), Some("Agent ID cannot be empty"));
        assert!(parsed.connection_id.is_empty());
    }

    #[test]
    fn test_now_ms_is_after_2020() {
        // 2020-01-01T00:00:00Z in milliseconds.
        assert!(now_ms() > 1_577_836_800_000);
    }
}