sentinel_agent_protocol/v2/
reverse.rs

//! Reverse connection support for Agent Protocol v2.
//!
//! This module allows agents to connect to the proxy instead of the proxy
//! connecting to agents. This is useful for:
//!
//! - Agents behind NAT or firewalls
//! - Dynamic agent scaling (agents register on startup)
//! - Simpler agent deployment (no need to expose agent ports)
//!
//! # Protocol
//!
//! 1. Proxy starts a listener (currently Unix Domain Sockets only)
//! 2. Agent connects and sends a `RegistrationRequest` (framed as a
//!    `HandshakeRequest` message)
//! 3. Proxy validates it and replies with a `RegistrationResponse` (framed as a
//!    `HandshakeResponse` message)
//! 4. On success, the connection is added to the `AgentPool`
//! 5. The connection is then used bidirectionally like a normal connection
//!
//! # Example
//!
//! ```ignore
//! use std::sync::Arc;
//! use sentinel_agent_protocol::v2::{AgentPool, ReverseConnectionConfig, ReverseConnectionListener};
//!
//! let pool = Arc::new(AgentPool::new());
//! let listener = Arc::new(
//!     ReverseConnectionListener::bind_uds(
//!         "/var/run/sentinel/agents.sock",
//!         ReverseConnectionConfig::default(),
//!     )
//!     .await?,
//! );
//!
//! // Never returns; spawn it as a task to keep accepting in the background.
//! listener.accept_loop(pool).await;
//! ```
29
30use std::collections::HashSet;
31use std::path::Path;
32use std::sync::Arc;
33use std::time::Duration;
34
35use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader, BufWriter};
36use tokio::net::{UnixListener, UnixStream};
37use tokio::sync::{mpsc, oneshot, Mutex, RwLock};
38use tracing::{debug, error, info, warn};
39
40use crate::v2::client::FlowState;
41use crate::v2::uds::{read_message, write_message, MessageType, UdsCapabilities};
42use crate::v2::pool::CHANNEL_BUFFER_SIZE;
43use crate::v2::{AgentCapabilities, AgentPool, PROTOCOL_VERSION_2};
44use crate::{AgentProtocolError, AgentResponse};
45
/// Tunables for the reverse connection listener.
///
/// Start from [`ReverseConnectionConfig::default`] and override individual
/// fields with struct-update syntax when only a few values need to change.
#[derive(Clone, Debug)]
pub struct ReverseConnectionConfig {
    /// Requested depth of the accept queue for pending connections.
    ///
    /// NOTE(review): the listener in this module binds with
    /// `UnixListener::bind`, which takes no backlog, so this value is not
    /// applied here.
    pub backlog: u32,
    /// How long an agent may take to deliver its registration request.
    pub handshake_timeout: Duration,
    /// Maximum number of simultaneous connections a single agent may hold.
    ///
    /// NOTE(review): not enforced anywhere in this module; presumably enforced
    /// by the agent pool — confirm.
    pub max_connections_per_agent: usize,
    /// Agent IDs permitted to register; an empty set admits every agent.
    pub allowed_agents: HashSet<String>,
    /// When set, registrations that carry no `auth_token` are rejected.
    ///
    /// The token's value is not verified against any credential by this module.
    pub require_auth: bool,
    /// Response timeout applied to every request sent over an accepted connection.
    pub request_timeout: Duration,
}

impl Default for ReverseConnectionConfig {
    /// Defaults: backlog 128, 10 s handshake, 4 connections per agent,
    /// no allow-list, no auth requirement, 30 s request timeout.
    fn default() -> Self {
        let handshake_timeout = Duration::from_secs(10);
        let request_timeout = Duration::from_secs(30);

        ReverseConnectionConfig {
            allowed_agents: HashSet::default(),
            backlog: 128,
            handshake_timeout,
            max_connections_per_agent: 4,
            request_timeout,
            require_auth: false,
        }
    }
}
75
76/// Registration request sent by agent when connecting to proxy.
77#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
78pub struct RegistrationRequest {
79    /// Protocol version the agent supports
80    pub protocol_version: u32,
81    /// Agent's unique identifier
82    pub agent_id: String,
83    /// Agent capabilities
84    pub capabilities: UdsCapabilities,
85    /// Optional authentication token
86    pub auth_token: Option<String>,
87    /// Optional metadata
88    pub metadata: Option<serde_json::Value>,
89}
90
91/// Registration response sent by proxy to agent.
92#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
93pub struct RegistrationResponse {
94    /// Whether registration was successful
95    pub success: bool,
96    /// Error message if registration failed
97    pub error: Option<String>,
98    /// Proxy identifier
99    pub proxy_id: String,
100    /// Proxy version
101    pub proxy_version: String,
102    /// Assigned connection ID (for debugging)
103    pub connection_id: String,
104}
105
106/// Listener for reverse agent connections over Unix Domain Socket.
107pub struct ReverseConnectionListener {
108    listener: UnixListener,
109    config: ReverseConnectionConfig,
110    socket_path: String,
111}
112
113impl ReverseConnectionListener {
114    /// Bind to a Unix Domain Socket path.
115    pub async fn bind_uds(
116        path: impl AsRef<Path>,
117        config: ReverseConnectionConfig,
118    ) -> Result<Self, AgentProtocolError> {
119        let path = path.as_ref();
120        let socket_path = path.to_string_lossy().to_string();
121
122        // Remove existing socket file if present
123        if path.exists() {
124            std::fs::remove_file(path).map_err(|e| {
125                AgentProtocolError::ConnectionFailed(format!(
126                    "Failed to remove existing socket {}: {}",
127                    socket_path, e
128                ))
129            })?;
130        }
131
132        let listener = UnixListener::bind(path).map_err(|e| {
133            AgentProtocolError::ConnectionFailed(format!(
134                "Failed to bind to {}: {}",
135                socket_path, e
136            ))
137        })?;
138
139        info!(path = %socket_path, "Reverse connection listener bound");
140
141        Ok(Self {
142            listener,
143            config,
144            socket_path,
145        })
146    }
147
148    /// Get the socket path.
149    pub fn socket_path(&self) -> &str {
150        &self.socket_path
151    }
152
153    /// Accept a single connection and register it with the pool.
154    ///
155    /// Returns the agent_id of the registered agent on success.
156    pub async fn accept_one(&self, pool: &AgentPool) -> Result<String, AgentProtocolError> {
157        let (stream, _addr) = self.listener.accept().await.map_err(|e| {
158            AgentProtocolError::ConnectionFailed(format!("Accept failed: {}", e))
159        })?;
160
161        debug!("Accepted reverse connection");
162
163        self.handle_connection(stream, pool).await
164    }
165
166    /// Run the accept loop, registering connections with the pool.
167    ///
168    /// This method runs forever, accepting connections and spawning tasks
169    /// to handle them.
170    pub async fn accept_loop(self: Arc<Self>, pool: Arc<AgentPool>) {
171        info!(path = %self.socket_path, "Starting reverse connection accept loop");
172
173        loop {
174            match self.listener.accept().await {
175                Ok((stream, _addr)) => {
176                    let listener = Arc::clone(&self);
177                    let pool = Arc::clone(&pool);
178
179                    tokio::spawn(async move {
180                        match listener.handle_connection(stream, &pool).await {
181                            Ok(agent_id) => {
182                                info!(agent_id = %agent_id, "Reverse connection registered");
183                            }
184                            Err(e) => {
185                                warn!(error = %e, "Failed to handle reverse connection");
186                            }
187                        }
188                    });
189                }
190                Err(e) => {
191                    error!(error = %e, "Accept failed");
192                    // Brief delay before retrying
193                    tokio::time::sleep(Duration::from_millis(100)).await;
194                }
195            }
196        }
197    }
198
199    /// Handle an accepted connection.
200    async fn handle_connection(
201        &self,
202        stream: UnixStream,
203        pool: &AgentPool,
204    ) -> Result<String, AgentProtocolError> {
205        let (read_half, write_half) = stream.into_split();
206        let mut reader = BufReader::new(read_half);
207        let mut writer = BufWriter::new(write_half);
208
209        // Read registration request with timeout
210        let registration = tokio::time::timeout(
211            self.config.handshake_timeout,
212            self.read_registration(&mut reader),
213        )
214        .await
215        .map_err(|_| {
216            AgentProtocolError::Timeout(self.config.handshake_timeout)
217        })??;
218
219        let agent_id = registration.agent_id.clone();
220
221        // Validate registration
222        if let Err(e) = self.validate_registration(&registration) {
223            let response = RegistrationResponse {
224                success: false,
225                error: Some(e.to_string()),
226                proxy_id: "sentinel-proxy".to_string(),
227                proxy_version: env!("CARGO_PKG_VERSION").to_string(),
228                connection_id: String::new(),
229            };
230            self.send_registration_response(&mut writer, &response).await?;
231            return Err(e);
232        }
233
234        // Generate connection ID
235        let connection_id = format!(
236            "{}-{:x}",
237            agent_id,
238            std::time::SystemTime::now()
239                .duration_since(std::time::UNIX_EPOCH)
240                .map(|d| d.as_millis())
241                .unwrap_or(0)
242        );
243
244        // Send success response
245        let response = RegistrationResponse {
246            success: true,
247            error: None,
248            proxy_id: "sentinel-proxy".to_string(),
249            proxy_version: env!("CARGO_PKG_VERSION").to_string(),
250            connection_id: connection_id.clone(),
251        };
252        self.send_registration_response(&mut writer, &response).await?;
253
254        info!(
255            agent_id = %agent_id,
256            connection_id = %connection_id,
257            "Agent registration successful"
258        );
259
260        // Convert capabilities
261        let capabilities: AgentCapabilities = registration.capabilities.into();
262
263        // Create the reverse connection client wrapper
264        let client = ReverseConnectionClient::new(
265            agent_id.clone(),
266            connection_id,
267            capabilities.clone(),
268            reader,
269            writer,
270            self.config.request_timeout,
271        )
272        .await;
273
274        // Add to pool
275        pool.add_reverse_connection(&agent_id, client, capabilities)
276            .await?;
277
278        Ok(agent_id)
279    }
280
281    /// Read registration request from stream.
282    async fn read_registration<R: AsyncReadExt + Unpin>(
283        &self,
284        reader: &mut R,
285    ) -> Result<RegistrationRequest, AgentProtocolError> {
286        let (msg_type, payload) = read_message(reader).await?;
287
288        if msg_type != MessageType::HandshakeRequest {
289            return Err(AgentProtocolError::InvalidMessage(format!(
290                "Expected registration request (HandshakeRequest), got {:?}",
291                msg_type
292            )));
293        }
294
295        serde_json::from_slice(&payload)
296            .map_err(|e| AgentProtocolError::InvalidMessage(e.to_string()))
297    }
298
299    /// Send registration response.
300    async fn send_registration_response<W: AsyncWriteExt + Unpin>(
301        &self,
302        writer: &mut W,
303        response: &RegistrationResponse,
304    ) -> Result<(), AgentProtocolError> {
305        let payload = serde_json::to_vec(response)
306            .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
307
308        write_message(writer, MessageType::HandshakeResponse, &payload).await
309    }
310
311    /// Validate a registration request.
312    fn validate_registration(
313        &self,
314        registration: &RegistrationRequest,
315    ) -> Result<(), AgentProtocolError> {
316        // Check protocol version
317        if registration.protocol_version != PROTOCOL_VERSION_2 {
318            return Err(AgentProtocolError::VersionMismatch {
319                expected: PROTOCOL_VERSION_2,
320                actual: registration.protocol_version,
321            });
322        }
323
324        // Check agent ID is not empty
325        if registration.agent_id.is_empty() {
326            return Err(AgentProtocolError::InvalidMessage(
327                "Agent ID cannot be empty".to_string(),
328            ));
329        }
330
331        // Check if agent is in allowed list (if configured)
332        if !self.config.allowed_agents.is_empty()
333            && !self.config.allowed_agents.contains(&registration.agent_id)
334        {
335            return Err(AgentProtocolError::InvalidMessage(format!(
336                "Agent '{}' is not in the allowed list",
337                registration.agent_id
338            )));
339        }
340
341        // Check authentication if required
342        if self.config.require_auth && registration.auth_token.is_none() {
343            return Err(AgentProtocolError::InvalidMessage(
344                "Authentication required but no token provided".to_string(),
345            ));
346        }
347
348        Ok(())
349    }
350}
351
352impl Drop for ReverseConnectionListener {
353    fn drop(&mut self) {
354        // Clean up socket file
355        if let Err(e) = std::fs::remove_file(&self.socket_path) {
356            debug!(path = %self.socket_path, error = %e, "Failed to remove socket file on drop");
357        }
358    }
359}
360
361/// Client wrapper for a reverse connection.
362///
363/// This wraps an accepted connection and provides the same interface
364/// as AgentClientV2Uds but for inbound connections.
365pub struct ReverseConnectionClient {
366    agent_id: String,
367    connection_id: String,
368    capabilities: RwLock<Option<AgentCapabilities>>,
369    pending: Arc<Mutex<std::collections::HashMap<String, oneshot::Sender<AgentResponse>>>>,
370    outbound_tx: Mutex<Option<mpsc::Sender<(MessageType, Vec<u8>)>>>,
371    connected: RwLock<bool>,
372    timeout: Duration,
373    in_flight: std::sync::atomic::AtomicU64,
374    /// Flow control state - tracks if agent has requested pause
375    flow_state: Arc<RwLock<FlowState>>,
376}
377
378impl ReverseConnectionClient {
379    /// Create a new reverse connection client from an accepted stream.
380    async fn new<R, W>(
381        agent_id: String,
382        connection_id: String,
383        capabilities: AgentCapabilities,
384        mut reader: BufReader<R>,
385        mut writer: BufWriter<W>,
386        timeout: Duration,
387    ) -> Self
388    where
389        R: AsyncReadExt + Unpin + Send + 'static,
390        W: AsyncWriteExt + Unpin + Send + 'static,
391    {
392        let pending: Arc<Mutex<std::collections::HashMap<String, oneshot::Sender<AgentResponse>>>> =
393            Arc::new(Mutex::new(std::collections::HashMap::new()));
394
395        // Create message channel
396        let (tx, mut rx) = mpsc::channel::<(MessageType, Vec<u8>)>(CHANNEL_BUFFER_SIZE);
397
398        // Spawn writer task
399        let agent_id_clone = agent_id.clone();
400        tokio::spawn(async move {
401            while let Some((msg_type, payload)) = rx.recv().await {
402                if let Err(e) = write_message(&mut writer, msg_type, &payload).await {
403                    error!(
404                        agent_id = %agent_id_clone,
405                        error = %e,
406                        "Failed to write to reverse connection"
407                    );
408                    break;
409                }
410            }
411            debug!(agent_id = %agent_id_clone, "Reverse connection writer ended");
412        });
413
414        // Spawn reader task
415        let pending_clone = Arc::clone(&pending);
416        let agent_id_clone = agent_id.clone();
417        tokio::spawn(async move {
418            loop {
419                match read_message(&mut reader).await {
420                    Ok((msg_type, payload)) => {
421                        if msg_type == MessageType::AgentResponse {
422                            if let Ok(response) = serde_json::from_slice::<AgentResponse>(&payload) {
423                                let correlation_id = response
424                                    .audit
425                                    .custom
426                                    .get("correlation_id")
427                                    .and_then(|v| v.as_str())
428                                    .unwrap_or("")
429                                    .to_string();
430
431                                if let Some(sender) =
432                                    pending_clone.lock().await.remove(&correlation_id)
433                                {
434                                    let _ = sender.send(response);
435                                }
436                            }
437                        }
438                    }
439                    Err(e) => {
440                        if !matches!(e, AgentProtocolError::ConnectionClosed) {
441                            error!(
442                                agent_id = %agent_id_clone,
443                                error = %e,
444                                "Error reading from reverse connection"
445                            );
446                        }
447                        break;
448                    }
449                }
450            }
451            debug!(agent_id = %agent_id_clone, "Reverse connection reader ended");
452        });
453
454        Self {
455            agent_id,
456            connection_id,
457            capabilities: RwLock::new(Some(capabilities)),
458            pending,
459            outbound_tx: Mutex::new(Some(tx)),
460            connected: RwLock::new(true),
461            timeout,
462            in_flight: std::sync::atomic::AtomicU64::new(0),
463            flow_state: Arc::new(RwLock::new(FlowState::Normal)),
464        }
465    }
466
467    /// Get the agent ID.
468    pub fn agent_id(&self) -> &str {
469        &self.agent_id
470    }
471
472    /// Get the connection ID.
473    pub fn connection_id(&self) -> &str {
474        &self.connection_id
475    }
476
477    /// Check if connected.
478    pub async fn is_connected(&self) -> bool {
479        *self.connected.read().await
480    }
481
482    /// Get capabilities.
483    pub async fn capabilities(&self) -> Option<AgentCapabilities> {
484        self.capabilities.read().await.clone()
485    }
486
487    /// Check if the agent has requested flow control pause.
488    ///
489    /// Returns true if the agent sent a `FlowAction::Pause` signal,
490    /// indicating it cannot accept more requests.
491    pub async fn is_paused(&self) -> bool {
492        matches!(*self.flow_state.read().await, FlowState::Paused)
493    }
494
495    /// Check if the transport can accept new requests.
496    ///
497    /// Returns false if the agent has requested a flow control pause.
498    pub async fn can_accept_requests(&self) -> bool {
499        !self.is_paused().await
500    }
501
502    /// Send a request headers event.
503    pub async fn send_request_headers(
504        &self,
505        correlation_id: &str,
506        event: &crate::RequestHeadersEvent,
507    ) -> Result<AgentResponse, AgentProtocolError> {
508        self.send_event(MessageType::RequestHeaders, correlation_id, event)
509            .await
510    }
511
512    /// Send a request body chunk event.
513    pub async fn send_request_body_chunk(
514        &self,
515        correlation_id: &str,
516        event: &crate::RequestBodyChunkEvent,
517    ) -> Result<AgentResponse, AgentProtocolError> {
518        self.send_event(MessageType::RequestBodyChunk, correlation_id, event)
519            .await
520    }
521
522    /// Send a response headers event.
523    pub async fn send_response_headers(
524        &self,
525        correlation_id: &str,
526        event: &crate::ResponseHeadersEvent,
527    ) -> Result<AgentResponse, AgentProtocolError> {
528        self.send_event(MessageType::ResponseHeaders, correlation_id, event)
529            .await
530    }
531
532    /// Send a response body chunk event.
533    pub async fn send_response_body_chunk(
534        &self,
535        correlation_id: &str,
536        event: &crate::ResponseBodyChunkEvent,
537    ) -> Result<AgentResponse, AgentProtocolError> {
538        self.send_event(MessageType::ResponseBodyChunk, correlation_id, event)
539            .await
540    }
541
542    /// Send an event and wait for response.
543    async fn send_event<T: serde::Serialize>(
544        &self,
545        msg_type: MessageType,
546        correlation_id: &str,
547        event: &T,
548    ) -> Result<AgentResponse, AgentProtocolError> {
549        let (tx, rx) = oneshot::channel();
550        self.pending
551            .lock()
552            .await
553            .insert(correlation_id.to_string(), tx);
554
555        // Serialize event with correlation ID
556        let mut payload = serde_json::to_value(event)
557            .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
558
559        if let Some(obj) = payload.as_object_mut() {
560            obj.insert(
561                "correlation_id".to_string(),
562                serde_json::Value::String(correlation_id.to_string()),
563            );
564        }
565
566        let payload_bytes = serde_json::to_vec(&payload)
567            .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
568
569        // Send message
570        {
571            let outbound = self.outbound_tx.lock().await;
572            if let Some(tx) = outbound.as_ref() {
573                tx.send((msg_type, payload_bytes))
574                    .await
575                    .map_err(|_| AgentProtocolError::ConnectionClosed)?;
576            } else {
577                return Err(AgentProtocolError::ConnectionClosed);
578            }
579        }
580
581        self.in_flight
582            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
583
584        // Wait for response with timeout
585        let response = tokio::time::timeout(self.timeout, rx)
586            .await
587            .map_err(|_| {
588                self.pending
589                    .try_lock()
590                    .ok()
591                    .map(|mut p| p.remove(correlation_id));
592                AgentProtocolError::Timeout(self.timeout)
593            })?
594            .map_err(|_| AgentProtocolError::ConnectionClosed)?;
595
596        self.in_flight
597            .fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
598
599        Ok(response)
600    }
601
602    /// Cancel a specific request.
603    pub async fn cancel_request(
604        &self,
605        correlation_id: &str,
606        reason: super::client::CancelReason,
607    ) -> Result<(), AgentProtocolError> {
608        let cancel = serde_json::json!({
609            "correlation_id": correlation_id,
610            "reason": reason as i32,
611            "timestamp_ms": now_ms(),
612        });
613
614        let payload = serde_json::to_vec(&cancel)
615            .map_err(|e| AgentProtocolError::Serialization(e.to_string()))?;
616
617        let outbound = self.outbound_tx.lock().await;
618        if let Some(tx) = outbound.as_ref() {
619            tx.send((MessageType::Cancel, payload))
620                .await
621                .map_err(|_| AgentProtocolError::ConnectionClosed)?;
622        }
623
624        self.pending.lock().await.remove(correlation_id);
625        Ok(())
626    }
627
628    /// Cancel all in-flight requests.
629    pub async fn cancel_all(
630        &self,
631        reason: super::client::CancelReason,
632    ) -> Result<usize, AgentProtocolError> {
633        let pending_ids: Vec<String> = self.pending.lock().await.keys().cloned().collect();
634        let count = pending_ids.len();
635
636        for correlation_id in pending_ids {
637            let _ = self.cancel_request(&correlation_id, reason).await;
638        }
639
640        Ok(count)
641    }
642
643    /// Close the connection.
644    pub async fn close(&self) -> Result<(), AgentProtocolError> {
645        *self.connected.write().await = false;
646        *self.outbound_tx.lock().await = None;
647        Ok(())
648    }
649
650    /// Get in-flight request count.
651    pub fn in_flight(&self) -> u64 {
652        self.in_flight.load(std::sync::atomic::Ordering::Relaxed)
653    }
654}
655
/// Milliseconds since the Unix epoch according to the system clock.
///
/// Returns 0 if the clock reads earlier than the epoch.
fn now_ms() -> u64 {
    std::time::UNIX_EPOCH
        .elapsed()
        .map_or(0, |since_epoch| since_epoch.as_millis() as u64)
}
662
#[cfg(test)]
mod tests {
    use super::*;

    /// Capabilities fixture shared by the serialization tests.
    fn test_capabilities() -> UdsCapabilities {
        UdsCapabilities {
            agent_id: "test-agent".to_string(),
            name: "Test Agent".to_string(),
            version: "1.0.0".to_string(),
            supported_events: vec![1, 2],
            features: Default::default(),
            limits: Default::default(),
        }
    }

    #[test]
    fn test_config_default() {
        let config = ReverseConnectionConfig::default();
        assert_eq!(config.backlog, 128);
        assert_eq!(config.max_connections_per_agent, 4);
        assert!(!config.require_auth);
        assert!(config.allowed_agents.is_empty());
        assert_eq!(config.handshake_timeout, Duration::from_secs(10));
        assert_eq!(config.request_timeout, Duration::from_secs(30));
    }

    #[test]
    fn test_registration_request_serialization() {
        let request = RegistrationRequest {
            protocol_version: 2,
            agent_id: "test-agent".to_string(),
            capabilities: test_capabilities(),
            auth_token: None,
            metadata: None,
        };

        let json = serde_json::to_string(&request).unwrap();
        let parsed: RegistrationRequest = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.agent_id, "test-agent");
        assert_eq!(parsed.protocol_version, 2);
    }

    #[test]
    fn test_registration_request_optional_fields_may_be_omitted() {
        // Agents may leave out `auth_token` and `metadata` entirely.
        let json = serde_json::json!({
            "protocol_version": 2,
            "agent_id": "test-agent",
            "capabilities": serde_json::to_value(test_capabilities()).unwrap(),
        });

        let parsed: RegistrationRequest = serde_json::from_value(json).unwrap();

        assert!(parsed.auth_token.is_none());
        assert!(parsed.metadata.is_none());
    }

    #[test]
    fn test_registration_response_serialization() {
        let response = RegistrationResponse {
            success: true,
            error: None,
            proxy_id: "sentinel".to_string(),
            proxy_version: "1.0.0".to_string(),
            connection_id: "conn-123".to_string(),
        };

        let json = serde_json::to_string(&response).unwrap();
        let parsed: RegistrationResponse = serde_json::from_str(&json).unwrap();

        assert!(parsed.success);
        assert_eq!(parsed.connection_id, "conn-123");
    }

    #[test]
    fn test_registration_rejection_serialization() {
        let response = RegistrationResponse {
            success: false,
            error: Some("Agent ID cannot be empty".to_string()),
            proxy_id: "sentinel".to_string(),
            proxy_version: "1.0.0".to_string(),
            connection_id: String::new(),
        };

        let json = serde_json::to_string(&response).unwrap();
        let parsed: RegistrationResponse = serde_json::from_str(&json).unwrap();

        assert!(!parsed.success);
        assert_eq!(parsed.error.as_deref(), Some("Agent ID cannot be empty"));
        assert!(parsed.connection_id.is_empty());
    }
}