Skip to main content

sentinel_agent_protocol/v2/
reverse.rs

1//! Reverse connection support for Agent Protocol v2.
2//!
3//! This module allows agents to connect to the proxy instead of the proxy
4//! connecting to agents. This is useful for:
5//!
6//! - Agents behind NAT or firewalls
7//! - Dynamic agent scaling (agents register on startup)
8//! - Simpler agent deployment (no need to expose agent ports)
9//!
10//! # Protocol
11//!
12//! 1. Proxy starts a listener (UDS or TCP)
13//! 2. Agent connects and sends a `RegistrationRequest`
14//! 3. Proxy validates and responds with `RegistrationResponse`
15//! 4. On success, the connection is added to the AgentPool
16//! 5. The connection is used bidirectionally like a normal connection
17//!
18//! # Example
19//!
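//! ```ignore
//! use std::sync::Arc;
//!
//! use sentinel_agent_protocol::v2::reverse::ReverseConnectionConfig;
//! use sentinel_agent_protocol::v2::{AgentPool, ReverseConnectionListener};
//!
//! let pool = Arc::new(AgentPool::new());
//! let listener = Arc::new(
//!     ReverseConnectionListener::bind_uds(
//!         "/var/run/sentinel/agents.sock",
//!         ReverseConnectionConfig::default(),
//!     )
//!     .await?,
//! );
//!
//! // Accept connections; this future only ends when it is dropped or cancelled.
//! listener.accept_loop(pool).await;
//! ```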
29
30use std::collections::HashSet;
31use std::path::Path;
32use std::sync::Arc;
33use std::time::Duration;
34
35use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader, BufWriter};
36use tokio::net::{UnixListener, UnixStream};
37use tokio::sync::{mpsc, oneshot, Mutex, RwLock};
38use tracing::{debug, error, info, warn};
39
40use crate::v2::client::FlowState;
41use crate::v2::pool::CHANNEL_BUFFER_SIZE;
42use crate::v2::uds::{read_message, write_message, MessageType, UdsCapabilities};
43use crate::v2::{AgentCapabilities, AgentPool, PROTOCOL_VERSION_2};
44use crate::{AgentProtocolError, AgentResponse};
45
/// Tunable settings for [`ReverseConnectionListener`].
///
/// Use [`Default::default`] for a permissive baseline and override individual
/// fields with struct-update syntax.
#[derive(Clone, Debug)]
pub struct ReverseConnectionConfig {
    /// Upper bound on connections waiting in the accept queue.
    ///
    /// NOTE(review): nothing in this file reads this field when binding the
    /// socket; confirm where (or whether) it is applied.
    pub backlog: u32,
    /// How long an agent has to deliver its registration request after the
    /// connection is accepted.
    pub handshake_timeout: Duration,
    /// Cap on simultaneous connections from one agent.
    ///
    /// NOTE(review): not enforced by the listener code in this file;
    /// presumably checked by the pool. Verify.
    pub max_connections_per_agent: usize,
    /// Agent IDs permitted to register. An empty set disables the check and
    /// lets every agent ID through.
    pub allowed_agents: HashSet<String>,
    /// When `true`, a registration without an auth token is rejected.
    pub require_auth: bool,
    /// Per-request response timeout handed to each accepted connection.
    pub request_timeout: Duration,
}
62
63impl Default for ReverseConnectionConfig {
64    fn default() -> Self {
65        Self {
66            backlog: 128,
67            handshake_timeout: Duration::from_secs(10),
68            max_connections_per_agent: 4,
69            allowed_agents: HashSet::new(),
70            require_auth: false,
71            request_timeout: Duration::from_secs(30),
72        }
73    }
74}
75
76/// Registration request sent by agent when connecting to proxy.
77#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
78pub struct RegistrationRequest {
79    /// Protocol version the agent supports
80    pub protocol_version: u32,
81    /// Agent's unique identifier
82    pub agent_id: String,
83    /// Agent capabilities
84    pub capabilities: UdsCapabilities,
85    /// Optional authentication token
86    pub auth_token: Option<String>,
87    /// Optional metadata
88    pub metadata: Option<serde_json::Value>,
89}
90
91/// Registration response sent by proxy to agent.
92#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
93pub struct RegistrationResponse {
94    /// Whether registration was successful
95    pub success: bool,
96    /// Error message if registration failed
97    pub error: Option<String>,
98    /// Proxy identifier
99    pub proxy_id: String,
100    /// Proxy version
101    pub proxy_version: String,
102    /// Assigned connection ID (for debugging)
103    pub connection_id: String,
104}
105
106/// Listener for reverse agent connections over Unix Domain Socket.
107pub struct ReverseConnectionListener {
108    listener: UnixListener,
109    config: ReverseConnectionConfig,
110    socket_path: String,
111}
112
113impl ReverseConnectionListener {
114    /// Bind to a Unix Domain Socket path.
115    pub async fn bind_uds(
116        path: impl AsRef<Path>,
117        config: ReverseConnectionConfig,
118    ) -> Result<Self, AgentProtocolError> {
119        let path = path.as_ref();
120        let socket_path = path.to_string_lossy().to_string();
121
122        // Remove existing socket file if present
123        if path.exists() {
124            std::fs::remove_file(path).map_err(|e| {
125                AgentProtocolError::ConnectionFailed(format!(
126                    "Failed to remove existing socket {}: {}",
127                    socket_path, e
128                ))
129            })?;
130        }
131
132        let listener = UnixListener::bind(path).map_err(|e| {
133            AgentProtocolError::ConnectionFailed(format!(
134                "Failed to bind to {}: {}",
135                socket_path, e
136            ))
137        })?;
138
139        info!(path = %socket_path, "Reverse connection listener bound");
140
141        Ok(Self {
142            listener,
143            config,
144            socket_path,
145        })
146    }
147
148    /// Get the socket path.
149    pub fn socket_path(&self) -> &str {
150        &self.socket_path
151    }
152
153    /// Accept a single connection and register it with the pool.
154    ///
155    /// Returns the agent_id of the registered agent on success.
156    pub async fn accept_one(&self, pool: &AgentPool) -> Result<String, AgentProtocolError> {
157        let (stream, _addr) =
158            self.listener.accept().await.map_err(|e| {
159                AgentProtocolError::ConnectionFailed(format!("Accept failed: {}", e))
160            })?;
161
162        debug!("Accepted reverse connection");
163
164        self.handle_connection(stream, pool).await
165    }
166
167    /// Run the accept loop, registering connections with the pool.
168    ///
169    /// This method runs forever, accepting connections and spawning tasks
170    /// to handle them.
171    pub async fn accept_loop(self: Arc<Self>, pool: Arc<AgentPool>) {
172        info!(path = %self.socket_path, "Starting reverse connection accept loop");
173
174        loop {
175            match self.listener.accept().await {
176                Ok((stream, _addr)) => {
177                    let listener = Arc::clone(&self);
178                    let pool = Arc::clone(&pool);
179
180                    tokio::spawn(async move {
181                        match listener.handle_connection(stream, &pool).await {
182                            Ok(agent_id) => {
183                                info!(agent_id = %agent_id, "Reverse connection registered");
184                            }
185                            Err(e) => {
186                                warn!(error = %e, "Failed to handle reverse connection");
187                            }
188                        }
189                    });
190                }
191                Err(e) => {
192                    error!(error = %e, "Accept failed");
193                    // Brief delay before retrying
194                    tokio::time::sleep(Duration::from_millis(100)).await;
195                }
196            }
197        }
198    }
199
200    /// Handle an accepted connection.
201    async fn handle_connection(
202        &self,
203        stream: UnixStream,
204        pool: &AgentPool,
205    ) -> Result<String, AgentProtocolError> {
206        let (read_half, write_half) = stream.into_split();
207        let mut reader = BufReader::new(read_half);
208        let mut writer = BufWriter::new(write_half);
209
210        // Read registration request with timeout
211        let registration = tokio::time::timeout(
212            self.config.handshake_timeout,
213            self.read_registration(&mut reader),
214        )
215        .await
216        .map_err(|_| AgentProtocolError::Timeout(self.config.handshake_timeout))??;
217
218        let agent_id = registration.agent_id.clone();
219
220        // Validate registration
221        if let Err(e) = self.validate_registration(&registration) {
222            let response = RegistrationResponse {
223                success: false,
224                error: Some(e.to_string()),
225                proxy_id: "sentinel-proxy".to_string(),
226                proxy_version: env!("CARGO_PKG_VERSION").to_string(),
227                connection_id: String::new(),
228            };
229            self.send_registration_response(&mut writer, &response)
230                .await?;
231            return Err(e);
232        }
233
234        // Generate connection ID
235        let connection_id = format!(
236            "{}-{:x}",
237            agent_id,
238            std::time::SystemTime::now()
239                .duration_since(std::time::UNIX_EPOCH)
240                .map(|d| d.as_millis())
241                .unwrap_or(0)
242        );
243
244        // Send success response
245        let response = RegistrationResponse {
246            success: true,
247            error: None,
248            proxy_id: "sentinel-proxy".to_string(),
249            proxy_version: env!("CARGO_PKG_VERSION").to_string(),
250            connection_id: connection_id.clone(),
251        };
252        self.send_registration_response(&mut writer, &response)
253            .await?;
254
255        info!(
256            agent_id = %agent_id,
257            connection_id = %connection_id,
258            "Agent registration successful"
259        );
260
261        // Convert capabilities
262        let capabilities: AgentCapabilities = registration.capabilities.into();
263
264        // Create the reverse connection client wrapper
265        let client = ReverseConnectionClient::new(
266            agent_id.clone(),
267            connection_id,
268            capabilities.clone(),
269            reader,
270            writer,
271            self.config.request_timeout,
272        )
273        .await;
274
275        // Add to pool
276        pool.add_reverse_connection(&agent_id, client, capabilities)
277            .await?;
278
279        Ok(agent_id)
280    }
281
282    /// Read registration request from stream.
283    async fn read_registration<R: AsyncReadExt + Unpin>(
284        &self,
285        reader: &mut R,
286    ) -> Result<RegistrationRequest, AgentProtocolError> {
287        let (msg_type, payload) = read_message(reader).await?;
288
289        if msg_type != MessageType::HandshakeRequest {
290            return Err(AgentProtocolError::InvalidMessage(format!(
291                "Expected registration request (HandshakeRequest), got {:?}",
292                msg_type
293            )));
294        }
295
296        serde_json::from_slice(&payload)
297            .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))
298    }
299
300    /// Send registration response.
301    async fn send_registration_response<W: AsyncWriteExt + Unpin>(
302        &self,
303        writer: &mut W,
304        response: &RegistrationResponse,
305    ) -> Result<(), AgentProtocolError> {
306        let payload = serde_json::to_vec(response)
307            .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
308
309        write_message(writer, MessageType::HandshakeResponse, &payload).await
310    }
311
312    /// Validate a registration request.
313    fn validate_registration(
314        &self,
315        registration: &RegistrationRequest,
316    ) -> Result<(), AgentProtocolError> {
317        // Check protocol version
318        if registration.protocol_version != PROTOCOL_VERSION_2 {
319            return Err(AgentProtocolError::VersionMismatch {
320                expected: PROTOCOL_VERSION_2,
321                actual: registration.protocol_version,
322            });
323        }
324
325        // Check agent ID is not empty
326        if registration.agent_id.is_empty() {
327            return Err(AgentProtocolError::InvalidMessage(
328                "Agent ID cannot be empty".to_string(),
329            ));
330        }
331
332        // Check if agent is in allowed list (if configured)
333        if !self.config.allowed_agents.is_empty()
334            && !self.config.allowed_agents.contains(&registration.agent_id)
335        {
336            return Err(AgentProtocolError::InvalidMessage(format!(
337                "Agent '{}' is not in the allowed list",
338                registration.agent_id
339            )));
340        }
341
342        // Check authentication if required
343        if self.config.require_auth && registration.auth_token.is_none() {
344            return Err(AgentProtocolError::InvalidMessage(
345                "Authentication required but no token provided".to_string(),
346            ));
347        }
348
349        Ok(())
350    }
351}
352
353impl Drop for ReverseConnectionListener {
354    fn drop(&mut self) {
355        // Clean up socket file
356        if let Err(e) = std::fs::remove_file(&self.socket_path) {
357            debug!(path = %self.socket_path, error = %e, "Failed to remove socket file on drop");
358        }
359    }
360}
361
362/// Client wrapper for a reverse connection.
363///
364/// This wraps an accepted connection and provides the same interface
365/// as AgentClientV2Uds but for inbound connections.
366pub struct ReverseConnectionClient {
367    agent_id: String,
368    connection_id: String,
369    capabilities: RwLock<Option<AgentCapabilities>>,
370    pending: Arc<Mutex<std::collections::HashMap<String, oneshot::Sender<AgentResponse>>>>,
371    #[allow(clippy::type_complexity)]
372    outbound_tx: Mutex<Option<mpsc::Sender<(MessageType, Vec<u8>)>>>,
373    connected: RwLock<bool>,
374    timeout: Duration,
375    in_flight: std::sync::atomic::AtomicU64,
376    /// Flow control state - tracks if agent has requested pause
377    flow_state: Arc<RwLock<FlowState>>,
378}
379
380impl ReverseConnectionClient {
381    /// Create a new reverse connection client from an accepted stream.
382    async fn new<R, W>(
383        agent_id: String,
384        connection_id: String,
385        capabilities: AgentCapabilities,
386        mut reader: BufReader<R>,
387        mut writer: BufWriter<W>,
388        timeout: Duration,
389    ) -> Self
390    where
391        R: AsyncReadExt + Unpin + Send + 'static,
392        W: AsyncWriteExt + Unpin + Send + 'static,
393    {
394        let pending: Arc<Mutex<std::collections::HashMap<String, oneshot::Sender<AgentResponse>>>> =
395            Arc::new(Mutex::new(std::collections::HashMap::new()));
396
397        // Create message channel
398        let (tx, mut rx) = mpsc::channel::<(MessageType, Vec<u8>)>(CHANNEL_BUFFER_SIZE);
399
400        // Spawn writer task
401        let agent_id_clone = agent_id.clone();
402        tokio::spawn(async move {
403            while let Some((msg_type, payload)) = rx.recv().await {
404                if let Err(e) = write_message(&mut writer, msg_type, &payload).await {
405                    error!(
406                        agent_id = %agent_id_clone,
407                        error = %e,
408                        "Failed to write to reverse connection"
409                    );
410                    break;
411                }
412            }
413            debug!(agent_id = %agent_id_clone, "Reverse connection writer ended");
414        });
415
416        // Spawn reader task
417        let pending_clone = Arc::clone(&pending);
418        let agent_id_clone = agent_id.clone();
419        tokio::spawn(async move {
420            loop {
421                match read_message(&mut reader).await {
422                    Ok((msg_type, payload)) => {
423                        if msg_type == MessageType::AgentResponse {
424                            if let Ok(response) = serde_json::from_slice::<AgentResponse>(&payload)
425                            {
426                                let correlation_id = response
427                                    .audit
428                                    .custom
429                                    .get("correlation_id")
430                                    .and_then(|v| v.as_str())
431                                    .unwrap_or("")
432                                    .to_string();
433
434                                if let Some(sender) =
435                                    pending_clone.lock().await.remove(&correlation_id)
436                                {
437                                    let _ = sender.send(response);
438                                }
439                            }
440                        }
441                    }
442                    Err(e) => {
443                        if !matches!(e, AgentProtocolError::ConnectionClosed) {
444                            error!(
445                                agent_id = %agent_id_clone,
446                                error = %e,
447                                "Error reading from reverse connection"
448                            );
449                        }
450                        break;
451                    }
452                }
453            }
454            debug!(agent_id = %agent_id_clone, "Reverse connection reader ended");
455        });
456
457        Self {
458            agent_id,
459            connection_id,
460            capabilities: RwLock::new(Some(capabilities)),
461            pending,
462            outbound_tx: Mutex::new(Some(tx)),
463            connected: RwLock::new(true),
464            timeout,
465            in_flight: std::sync::atomic::AtomicU64::new(0),
466            flow_state: Arc::new(RwLock::new(FlowState::Normal)),
467        }
468    }
469
470    /// Get the agent ID.
471    pub fn agent_id(&self) -> &str {
472        &self.agent_id
473    }
474
475    /// Get the connection ID.
476    pub fn connection_id(&self) -> &str {
477        &self.connection_id
478    }
479
480    /// Check if connected.
481    pub async fn is_connected(&self) -> bool {
482        *self.connected.read().await
483    }
484
485    /// Get capabilities.
486    pub async fn capabilities(&self) -> Option<AgentCapabilities> {
487        self.capabilities.read().await.clone()
488    }
489
490    /// Check if the agent has requested flow control pause.
491    ///
492    /// Returns true if the agent sent a `FlowAction::Pause` signal,
493    /// indicating it cannot accept more requests.
494    pub async fn is_paused(&self) -> bool {
495        matches!(*self.flow_state.read().await, FlowState::Paused)
496    }
497
498    /// Check if the transport can accept new requests.
499    ///
500    /// Returns false if the agent has requested a flow control pause.
501    pub async fn can_accept_requests(&self) -> bool {
502        !self.is_paused().await
503    }
504
505    /// Send a request headers event.
506    pub async fn send_request_headers(
507        &self,
508        correlation_id: &str,
509        event: &crate::RequestHeadersEvent,
510    ) -> Result<AgentResponse, AgentProtocolError> {
511        self.send_event(MessageType::RequestHeaders, correlation_id, event)
512            .await
513    }
514
515    /// Send a request body chunk event.
516    pub async fn send_request_body_chunk(
517        &self,
518        correlation_id: &str,
519        event: &crate::RequestBodyChunkEvent,
520    ) -> Result<AgentResponse, AgentProtocolError> {
521        self.send_event(MessageType::RequestBodyChunk, correlation_id, event)
522            .await
523    }
524
525    /// Send a response headers event.
526    pub async fn send_response_headers(
527        &self,
528        correlation_id: &str,
529        event: &crate::ResponseHeadersEvent,
530    ) -> Result<AgentResponse, AgentProtocolError> {
531        self.send_event(MessageType::ResponseHeaders, correlation_id, event)
532            .await
533    }
534
535    /// Send a response body chunk event.
536    pub async fn send_response_body_chunk(
537        &self,
538        correlation_id: &str,
539        event: &crate::ResponseBodyChunkEvent,
540    ) -> Result<AgentResponse, AgentProtocolError> {
541        self.send_event(MessageType::ResponseBodyChunk, correlation_id, event)
542            .await
543    }
544
545    /// Send an event and wait for response.
546    async fn send_event<T: serde::Serialize>(
547        &self,
548        msg_type: MessageType,
549        correlation_id: &str,
550        event: &T,
551    ) -> Result<AgentResponse, AgentProtocolError> {
552        let (tx, rx) = oneshot::channel();
553        self.pending
554            .lock()
555            .await
556            .insert(correlation_id.to_string(), tx);
557
558        // Serialize event with correlation ID
559        let mut payload = serde_json::to_value(event)
560            .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
561
562        if let Some(obj) = payload.as_object_mut() {
563            obj.insert(
564                "correlation_id".to_string(),
565                serde_json::Value::String(correlation_id.to_string()),
566            );
567        }
568
569        let payload_bytes = serde_json::to_vec(&payload)
570            .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
571
572        // Send message
573        {
574            let outbound = self.outbound_tx.lock().await;
575            if let Some(tx) = outbound.as_ref() {
576                tx.send((msg_type, payload_bytes))
577                    .await
578                    .map_err(|_| AgentProtocolError::ConnectionClosed)?;
579            } else {
580                return Err(AgentProtocolError::ConnectionClosed);
581            }
582        }
583
584        self.in_flight
585            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
586
587        // Wait for response with timeout
588        let response = tokio::time::timeout(self.timeout, rx)
589            .await
590            .map_err(|_| {
591                self.pending
592                    .try_lock()
593                    .ok()
594                    .map(|mut p| p.remove(correlation_id));
595                AgentProtocolError::Timeout(self.timeout)
596            })?
597            .map_err(|_| AgentProtocolError::ConnectionClosed)?;
598
599        self.in_flight
600            .fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
601
602        Ok(response)
603    }
604
605    /// Cancel a specific request.
606    pub async fn cancel_request(
607        &self,
608        correlation_id: &str,
609        reason: super::client::CancelReason,
610    ) -> Result<(), AgentProtocolError> {
611        let cancel = serde_json::json!({
612            "correlation_id": correlation_id,
613            "reason": reason as i32,
614            "timestamp_ms": now_ms(),
615        });
616
617        let payload = serde_json::to_vec(&cancel)
618            .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
619
620        let outbound = self.outbound_tx.lock().await;
621        if let Some(tx) = outbound.as_ref() {
622            tx.send((MessageType::Cancel, payload))
623                .await
624                .map_err(|_| AgentProtocolError::ConnectionClosed)?;
625        }
626
627        self.pending.lock().await.remove(correlation_id);
628        Ok(())
629    }
630
631    /// Cancel all in-flight requests.
632    pub async fn cancel_all(
633        &self,
634        reason: super::client::CancelReason,
635    ) -> Result<usize, AgentProtocolError> {
636        let pending_ids: Vec<String> = self.pending.lock().await.keys().cloned().collect();
637        let count = pending_ids.len();
638
639        for correlation_id in pending_ids {
640            let _ = self.cancel_request(&correlation_id, reason).await;
641        }
642
643        Ok(count)
644    }
645
646    /// Close the connection.
647    pub async fn close(&self) -> Result<(), AgentProtocolError> {
648        *self.connected.write().await = false;
649        *self.outbound_tx.lock().await = None;
650        Ok(())
651    }
652
653    /// Get in-flight request count.
654    pub fn in_flight(&self) -> u64 {
655        self.in_flight.load(std::sync::atomic::Ordering::Relaxed)
656    }
657}
658
/// Current wall-clock time as milliseconds since the Unix epoch.
///
/// Returns 0 if the system clock reads earlier than the epoch. The value is
/// clamped to `u64::MAX` before narrowing, so the conversion saturates rather
/// than silently wrapping.
fn now_ms() -> u64 {
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|elapsed| elapsed.as_millis().min(u128::from(u64::MAX)) as u64)
        .unwrap_or(0)
}
665
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration exposes the documented baseline values.
    #[test]
    fn test_config_default() {
        let ReverseConnectionConfig {
            backlog,
            max_connections_per_agent,
            require_auth,
            ..
        } = ReverseConnectionConfig::default();

        assert_eq!(backlog, 128);
        assert_eq!(max_connections_per_agent, 4);
        assert!(!require_auth);
    }

    /// A registration request survives a JSON round trip.
    #[test]
    fn test_registration_request_serialization() {
        let capabilities = UdsCapabilities {
            agent_id: "test-agent".to_string(),
            name: "Test Agent".to_string(),
            version: "1.0.0".to_string(),
            supported_events: vec![1, 2],
            features: Default::default(),
            limits: Default::default(),
        };
        let request = RegistrationRequest {
            protocol_version: 2,
            agent_id: "test-agent".to_string(),
            capabilities,
            auth_token: None,
            metadata: None,
        };

        let encoded = serde_json::to_string(&request).unwrap();
        let decoded: RegistrationRequest = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.agent_id, "test-agent");
        assert_eq!(decoded.protocol_version, 2);
    }

    /// A registration response survives a JSON round trip.
    #[test]
    fn test_registration_response_serialization() {
        let response = RegistrationResponse {
            success: true,
            error: None,
            proxy_id: "sentinel".to_string(),
            proxy_version: "1.0.0".to_string(),
            connection_id: "conn-123".to_string(),
        };

        let encoded = serde_json::to_string(&response).unwrap();
        let decoded: RegistrationResponse = serde_json::from_str(&encoded).unwrap();

        assert!(decoded.success);
        assert_eq!(decoded.connection_id, "conn-123");
    }
}