1use std::collections::HashSet;
31use std::path::Path;
32use std::sync::Arc;
33use std::time::Duration;
34
35use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader, BufWriter};
36use tokio::net::{UnixListener, UnixStream};
37use tokio::sync::{mpsc, oneshot, Mutex, RwLock};
38use tracing::{debug, error, info, warn};
39
40use crate::v2::client::FlowState;
41use crate::v2::uds::{read_message, write_message, MessageType, UdsCapabilities};
42use crate::v2::pool::CHANNEL_BUFFER_SIZE;
43use crate::v2::{AgentCapabilities, AgentPool, PROTOCOL_VERSION_2};
44use crate::{AgentProtocolError, AgentResponse};
45
/// Settings for the reverse-connection listener and the clients it creates.
#[derive(Debug, Clone)]
pub struct ReverseConnectionConfig {
    /// Listen backlog hint (default 128).
    /// NOTE(review): nothing in this file passes it to the socket — confirm
    /// where (or whether) it is applied.
    pub backlog: u32,
    /// Maximum time an agent has to deliver its registration message after
    /// its connection is accepted.
    pub handshake_timeout: Duration,
    /// Intended cap on simultaneous connections per agent ID (default 4).
    /// NOTE(review): not enforced in this file — presumably checked by the
    /// pool; confirm.
    pub max_connections_per_agent: usize,
    /// Agent IDs that may register. An empty set allows every agent.
    pub allowed_agents: HashSet<String>,
    /// When `true`, a registration that carries no `auth_token` is rejected.
    pub require_auth: bool,
    /// How long a request sent over a registered connection waits for the
    /// agent's response before timing out.
    pub request_timeout: Duration,
}
62
63impl Default for ReverseConnectionConfig {
64 fn default() -> Self {
65 Self {
66 backlog: 128,
67 handshake_timeout: Duration::from_secs(10),
68 max_connections_per_agent: 4,
69 allowed_agents: HashSet::new(),
70 require_auth: false,
71 request_timeout: Duration::from_secs(30),
72 }
73 }
74}
75
/// First message an agent sends after connecting; it is carried in a
/// `HandshakeRequest` frame as JSON.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct RegistrationRequest {
    /// Protocol version spoken by the agent; must equal `PROTOCOL_VERSION_2`
    /// to be accepted.
    pub protocol_version: u32,
    /// Identifier of the registering agent. Must be non-empty and, when an
    /// allow-list is configured, present in it.
    pub agent_id: String,
    /// Capabilities advertised by the agent; converted into
    /// `AgentCapabilities` once the registration is accepted.
    pub capabilities: UdsCapabilities,
    /// Authentication token; must be present when `require_auth` is enabled.
    pub auth_token: Option<String>,
    /// Free-form JSON supplied by the agent. Not inspected in this file.
    pub metadata: Option<serde_json::Value>,
}
90
/// Reply to a [`RegistrationRequest`], carried in a `HandshakeResponse` frame
/// as JSON.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct RegistrationResponse {
    /// `true` when the registration was accepted.
    pub success: bool,
    /// Human-readable rejection reason; `None` on success.
    pub error: Option<String>,
    /// Identifier of the proxy that answered.
    pub proxy_id: String,
    /// Version string of the proxy that answered.
    pub proxy_version: String,
    /// Identifier assigned to this connection; empty when the registration
    /// was rejected.
    pub connection_id: String,
}
105
/// Accepts agent-initiated ("reverse") connections on a Unix domain socket
/// and registers them with an `AgentPool` after a handshake.
pub struct ReverseConnectionListener {
    /// Bound Unix socket on which agents connect.
    listener: UnixListener,
    /// Handshake and validation settings.
    config: ReverseConnectionConfig,
    /// Filesystem path of the socket; used for logging and removed on drop.
    socket_path: String,
}
112
113impl ReverseConnectionListener {
114 pub async fn bind_uds(
116 path: impl AsRef<Path>,
117 config: ReverseConnectionConfig,
118 ) -> Result<Self, AgentProtocolError> {
119 let path = path.as_ref();
120 let socket_path = path.to_string_lossy().to_string();
121
122 if path.exists() {
124 std::fs::remove_file(path).map_err(|e| {
125 AgentProtocolError::ConnectionFailed(format!(
126 "Failed to remove existing socket {}: {}",
127 socket_path, e
128 ))
129 })?;
130 }
131
132 let listener = UnixListener::bind(path).map_err(|e| {
133 AgentProtocolError::ConnectionFailed(format!(
134 "Failed to bind to {}: {}",
135 socket_path, e
136 ))
137 })?;
138
139 info!(path = %socket_path, "Reverse connection listener bound");
140
141 Ok(Self {
142 listener,
143 config,
144 socket_path,
145 })
146 }
147
148 pub fn socket_path(&self) -> &str {
150 &self.socket_path
151 }
152
153 pub async fn accept_one(&self, pool: &AgentPool) -> Result<String, AgentProtocolError> {
157 let (stream, _addr) = self.listener.accept().await.map_err(|e| {
158 AgentProtocolError::ConnectionFailed(format!("Accept failed: {}", e))
159 })?;
160
161 debug!("Accepted reverse connection");
162
163 self.handle_connection(stream, pool).await
164 }
165
166 pub async fn accept_loop(self: Arc<Self>, pool: Arc<AgentPool>) {
171 info!(path = %self.socket_path, "Starting reverse connection accept loop");
172
173 loop {
174 match self.listener.accept().await {
175 Ok((stream, _addr)) => {
176 let listener = Arc::clone(&self);
177 let pool = Arc::clone(&pool);
178
179 tokio::spawn(async move {
180 match listener.handle_connection(stream, &pool).await {
181 Ok(agent_id) => {
182 info!(agent_id = %agent_id, "Reverse connection registered");
183 }
184 Err(e) => {
185 warn!(error = %e, "Failed to handle reverse connection");
186 }
187 }
188 });
189 }
190 Err(e) => {
191 error!(error = %e, "Accept failed");
192 tokio::time::sleep(Duration::from_millis(100)).await;
194 }
195 }
196 }
197 }
198
199 async fn handle_connection(
201 &self,
202 stream: UnixStream,
203 pool: &AgentPool,
204 ) -> Result<String, AgentProtocolError> {
205 let (read_half, write_half) = stream.into_split();
206 let mut reader = BufReader::new(read_half);
207 let mut writer = BufWriter::new(write_half);
208
209 let registration = tokio::time::timeout(
211 self.config.handshake_timeout,
212 self.read_registration(&mut reader),
213 )
214 .await
215 .map_err(|_| {
216 AgentProtocolError::Timeout(self.config.handshake_timeout)
217 })??;
218
219 let agent_id = registration.agent_id.clone();
220
221 if let Err(e) = self.validate_registration(®istration) {
223 let response = RegistrationResponse {
224 success: false,
225 error: Some(e.to_string()),
226 proxy_id: "sentinel-proxy".to_string(),
227 proxy_version: env!("CARGO_PKG_VERSION").to_string(),
228 connection_id: String::new(),
229 };
230 self.send_registration_response(&mut writer, &response).await?;
231 return Err(e);
232 }
233
234 let connection_id = format!(
236 "{}-{:x}",
237 agent_id,
238 std::time::SystemTime::now()
239 .duration_since(std::time::UNIX_EPOCH)
240 .map(|d| d.as_millis())
241 .unwrap_or(0)
242 );
243
244 let response = RegistrationResponse {
246 success: true,
247 error: None,
248 proxy_id: "sentinel-proxy".to_string(),
249 proxy_version: env!("CARGO_PKG_VERSION").to_string(),
250 connection_id: connection_id.clone(),
251 };
252 self.send_registration_response(&mut writer, &response).await?;
253
254 info!(
255 agent_id = %agent_id,
256 connection_id = %connection_id,
257 "Agent registration successful"
258 );
259
260 let capabilities: AgentCapabilities = registration.capabilities.into();
262
263 let client = ReverseConnectionClient::new(
265 agent_id.clone(),
266 connection_id,
267 capabilities.clone(),
268 reader,
269 writer,
270 self.config.request_timeout,
271 )
272 .await;
273
274 pool.add_reverse_connection(&agent_id, client, capabilities)
276 .await?;
277
278 Ok(agent_id)
279 }
280
281 async fn read_registration<R: AsyncReadExt + Unpin>(
283 &self,
284 reader: &mut R,
285 ) -> Result<RegistrationRequest, AgentProtocolError> {
286 let (msg_type, payload) = read_message(reader).await?;
287
288 if msg_type != MessageType::HandshakeRequest {
289 return Err(AgentProtocolError::InvalidMessage(format!(
290 "Expected registration request (HandshakeRequest), got {:?}",
291 msg_type
292 )));
293 }
294
295 serde_json::from_slice(&payload)
296 .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))
297 }
298
299 async fn send_registration_response<W: AsyncWriteExt + Unpin>(
301 &self,
302 writer: &mut W,
303 response: &RegistrationResponse,
304 ) -> Result<(), AgentProtocolError> {
305 let payload = serde_json::to_vec(response)
306 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
307
308 write_message(writer, MessageType::HandshakeResponse, &payload).await
309 }
310
311 fn validate_registration(
313 &self,
314 registration: &RegistrationRequest,
315 ) -> Result<(), AgentProtocolError> {
316 if registration.protocol_version != PROTOCOL_VERSION_2 {
318 return Err(AgentProtocolError::VersionMismatch {
319 expected: PROTOCOL_VERSION_2,
320 actual: registration.protocol_version,
321 });
322 }
323
324 if registration.agent_id.is_empty() {
326 return Err(AgentProtocolError::InvalidMessage(
327 "Agent ID cannot be empty".to_string(),
328 ));
329 }
330
331 if !self.config.allowed_agents.is_empty()
333 && !self.config.allowed_agents.contains(®istration.agent_id)
334 {
335 return Err(AgentProtocolError::InvalidMessage(format!(
336 "Agent '{}' is not in the allowed list",
337 registration.agent_id
338 )));
339 }
340
341 if self.config.require_auth && registration.auth_token.is_none() {
343 return Err(AgentProtocolError::InvalidMessage(
344 "Authentication required but no token provided".to_string(),
345 ));
346 }
347
348 Ok(())
349 }
350}
351
352impl Drop for ReverseConnectionListener {
353 fn drop(&mut self) {
354 if let Err(e) = std::fs::remove_file(&self.socket_path) {
356 debug!(path = %self.socket_path, error = %e, "Failed to remove socket file on drop");
357 }
358 }
359}
360
/// Proxy-side handle for one registered reverse connection.
///
/// Requests are queued to a background writer task; a background reader task
/// matches incoming `AgentResponse` frames to waiting callers by correlation
/// ID.
pub struct ReverseConnectionClient {
    /// ID the agent registered with.
    agent_id: String,
    /// Identifier assigned to this connection during registration.
    connection_id: String,
    /// Capabilities advertised at registration; set to `Some` on construction.
    capabilities: RwLock<Option<AgentCapabilities>>,
    /// Callers waiting for a response, keyed by correlation ID. Shared with
    /// the reader task, which removes an entry when its response arrives.
    pending: Arc<Mutex<std::collections::HashMap<String, oneshot::Sender<AgentResponse>>>>,
    /// Queue feeding the writer task; `None` once `close()` has been called.
    outbound_tx: Mutex<Option<mpsc::Sender<(MessageType, Vec<u8>)>>>,
    /// `true` from construction until `close()` is called.
    /// NOTE(review): not updated when the background tasks exit on their own.
    connected: RwLock<bool>,
    /// How long each request waits for its response.
    timeout: Duration,
    /// Number of requests currently awaiting a response, as reported by
    /// `in_flight()`.
    in_flight: std::sync::atomic::AtomicU64,
    /// Flow-control state; starts as `FlowState::Normal`. Only read in this
    /// file (by `is_paused`).
    flow_state: Arc<RwLock<FlowState>>,
}
377
378impl ReverseConnectionClient {
379 async fn new<R, W>(
381 agent_id: String,
382 connection_id: String,
383 capabilities: AgentCapabilities,
384 mut reader: BufReader<R>,
385 mut writer: BufWriter<W>,
386 timeout: Duration,
387 ) -> Self
388 where
389 R: AsyncReadExt + Unpin + Send + 'static,
390 W: AsyncWriteExt + Unpin + Send + 'static,
391 {
392 let pending: Arc<Mutex<std::collections::HashMap<String, oneshot::Sender<AgentResponse>>>> =
393 Arc::new(Mutex::new(std::collections::HashMap::new()));
394
395 let (tx, mut rx) = mpsc::channel::<(MessageType, Vec<u8>)>(CHANNEL_BUFFER_SIZE);
397
398 let agent_id_clone = agent_id.clone();
400 tokio::spawn(async move {
401 while let Some((msg_type, payload)) = rx.recv().await {
402 if let Err(e) = write_message(&mut writer, msg_type, &payload).await {
403 error!(
404 agent_id = %agent_id_clone,
405 error = %e,
406 "Failed to write to reverse connection"
407 );
408 break;
409 }
410 }
411 debug!(agent_id = %agent_id_clone, "Reverse connection writer ended");
412 });
413
414 let pending_clone = Arc::clone(&pending);
416 let agent_id_clone = agent_id.clone();
417 tokio::spawn(async move {
418 loop {
419 match read_message(&mut reader).await {
420 Ok((msg_type, payload)) => {
421 if msg_type == MessageType::AgentResponse {
422 if let Ok(response) = serde_json::from_slice::<AgentResponse>(&payload) {
423 let correlation_id = response
424 .audit
425 .custom
426 .get("correlation_id")
427 .and_then(|v| v.as_str())
428 .unwrap_or("")
429 .to_string();
430
431 if let Some(sender) =
432 pending_clone.lock().await.remove(&correlation_id)
433 {
434 let _ = sender.send(response);
435 }
436 }
437 }
438 }
439 Err(e) => {
440 if !matches!(e, AgentProtocolError::ConnectionClosed) {
441 error!(
442 agent_id = %agent_id_clone,
443 error = %e,
444 "Error reading from reverse connection"
445 );
446 }
447 break;
448 }
449 }
450 }
451 debug!(agent_id = %agent_id_clone, "Reverse connection reader ended");
452 });
453
454 Self {
455 agent_id,
456 connection_id,
457 capabilities: RwLock::new(Some(capabilities)),
458 pending,
459 outbound_tx: Mutex::new(Some(tx)),
460 connected: RwLock::new(true),
461 timeout,
462 in_flight: std::sync::atomic::AtomicU64::new(0),
463 flow_state: Arc::new(RwLock::new(FlowState::Normal)),
464 }
465 }
466
467 pub fn agent_id(&self) -> &str {
469 &self.agent_id
470 }
471
472 pub fn connection_id(&self) -> &str {
474 &self.connection_id
475 }
476
477 pub async fn is_connected(&self) -> bool {
479 *self.connected.read().await
480 }
481
482 pub async fn capabilities(&self) -> Option<AgentCapabilities> {
484 self.capabilities.read().await.clone()
485 }
486
487 pub async fn is_paused(&self) -> bool {
492 matches!(*self.flow_state.read().await, FlowState::Paused)
493 }
494
495 pub async fn can_accept_requests(&self) -> bool {
499 !self.is_paused().await
500 }
501
502 pub async fn send_request_headers(
504 &self,
505 correlation_id: &str,
506 event: &crate::RequestHeadersEvent,
507 ) -> Result<AgentResponse, AgentProtocolError> {
508 self.send_event(MessageType::RequestHeaders, correlation_id, event)
509 .await
510 }
511
512 pub async fn send_request_body_chunk(
514 &self,
515 correlation_id: &str,
516 event: &crate::RequestBodyChunkEvent,
517 ) -> Result<AgentResponse, AgentProtocolError> {
518 self.send_event(MessageType::RequestBodyChunk, correlation_id, event)
519 .await
520 }
521
522 pub async fn send_response_headers(
524 &self,
525 correlation_id: &str,
526 event: &crate::ResponseHeadersEvent,
527 ) -> Result<AgentResponse, AgentProtocolError> {
528 self.send_event(MessageType::ResponseHeaders, correlation_id, event)
529 .await
530 }
531
532 pub async fn send_response_body_chunk(
534 &self,
535 correlation_id: &str,
536 event: &crate::ResponseBodyChunkEvent,
537 ) -> Result<AgentResponse, AgentProtocolError> {
538 self.send_event(MessageType::ResponseBodyChunk, correlation_id, event)
539 .await
540 }
541
542 async fn send_event<T: serde::Serialize>(
544 &self,
545 msg_type: MessageType,
546 correlation_id: &str,
547 event: &T,
548 ) -> Result<AgentResponse, AgentProtocolError> {
549 let (tx, rx) = oneshot::channel();
550 self.pending
551 .lock()
552 .await
553 .insert(correlation_id.to_string(), tx);
554
555 let mut payload = serde_json::to_value(event)
557 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
558
559 if let Some(obj) = payload.as_object_mut() {
560 obj.insert(
561 "correlation_id".to_string(),
562 serde_json::Value::String(correlation_id.to_string()),
563 );
564 }
565
566 let payload_bytes = serde_json::to_vec(&payload)
567 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
568
569 {
571 let outbound = self.outbound_tx.lock().await;
572 if let Some(tx) = outbound.as_ref() {
573 tx.send((msg_type, payload_bytes))
574 .await
575 .map_err(|_| AgentProtocolError::ConnectionClosed)?;
576 } else {
577 return Err(AgentProtocolError::ConnectionClosed);
578 }
579 }
580
581 self.in_flight
582 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
583
584 let response = tokio::time::timeout(self.timeout, rx)
586 .await
587 .map_err(|_| {
588 self.pending
589 .try_lock()
590 .ok()
591 .map(|mut p| p.remove(correlation_id));
592 AgentProtocolError::Timeout(self.timeout)
593 })?
594 .map_err(|_| AgentProtocolError::ConnectionClosed)?;
595
596 self.in_flight
597 .fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
598
599 Ok(response)
600 }
601
602 pub async fn cancel_request(
604 &self,
605 correlation_id: &str,
606 reason: super::client::CancelReason,
607 ) -> Result<(), AgentProtocolError> {
608 let cancel = serde_json::json!({
609 "correlation_id": correlation_id,
610 "reason": reason as i32,
611 "timestamp_ms": now_ms(),
612 });
613
614 let payload = serde_json::to_vec(&cancel)
615 .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
616
617 let outbound = self.outbound_tx.lock().await;
618 if let Some(tx) = outbound.as_ref() {
619 tx.send((MessageType::Cancel, payload))
620 .await
621 .map_err(|_| AgentProtocolError::ConnectionClosed)?;
622 }
623
624 self.pending.lock().await.remove(correlation_id);
625 Ok(())
626 }
627
628 pub async fn cancel_all(
630 &self,
631 reason: super::client::CancelReason,
632 ) -> Result<usize, AgentProtocolError> {
633 let pending_ids: Vec<String> = self.pending.lock().await.keys().cloned().collect();
634 let count = pending_ids.len();
635
636 for correlation_id in pending_ids {
637 let _ = self.cancel_request(&correlation_id, reason).await;
638 }
639
640 Ok(count)
641 }
642
643 pub async fn close(&self) -> Result<(), AgentProtocolError> {
645 *self.connected.write().await = false;
646 *self.outbound_tx.lock().await = None;
647 Ok(())
648 }
649
650 pub fn in_flight(&self) -> u64 {
652 self.in_flight.load(std::sync::atomic::Ordering::Relaxed)
653 }
654}
655
/// Returns the current wall-clock time as milliseconds since the Unix epoch,
/// or 0 when the system clock reads earlier than the epoch.
fn now_ms() -> u64 {
    let since_epoch = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH);
    match since_epoch {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
662
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration exposes the documented stock values.
    #[test]
    fn test_config_default() {
        let defaults = ReverseConnectionConfig::default();

        assert!(!defaults.require_auth);
        assert_eq!(defaults.backlog, 128);
        assert_eq!(defaults.max_connections_per_agent, 4);
    }

    /// A registration request survives a JSON round trip.
    #[test]
    fn test_registration_request_serialization() {
        let capabilities = UdsCapabilities {
            agent_id: "test-agent".to_string(),
            name: "Test Agent".to_string(),
            version: "1.0.0".to_string(),
            supported_events: vec![1, 2],
            features: Default::default(),
            limits: Default::default(),
        };
        let original = RegistrationRequest {
            protocol_version: 2,
            agent_id: "test-agent".to_string(),
            capabilities,
            auth_token: None,
            metadata: None,
        };

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: RegistrationRequest = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.protocol_version, 2);
        assert_eq!(decoded.agent_id, "test-agent");
    }

    /// A registration response survives a JSON round trip.
    #[test]
    fn test_registration_response_serialization() {
        let original = RegistrationResponse {
            success: true,
            error: None,
            proxy_id: "sentinel".to_string(),
            proxy_version: "1.0.0".to_string(),
            connection_id: "conn-123".to_string(),
        };

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: RegistrationResponse = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.connection_id, "conn-123");
        assert!(decoded.success);
    }
}