Skip to main content

tpcp_std/
node.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::sync::atomic::{AtomicUsize, Ordering};
4use tokio::sync::RwLock;
5use futures_util::{SinkExt, StreamExt};
6use tokio_tungstenite::{connect_async, accept_async, tungstenite::Message};
7use serde_json;
8use tpcp_core::{AgentIdentity, Intent, TPCPEnvelope, MessageHeader, PROTOCOL_VERSION};
9use tpcp_core::identity::{canonical_json, sign, verify};
10use ed25519_dalek::SigningKey;
11use uuid::Uuid;
12use crate::DLQ;
13
14type HandlerFn = Arc<dyn Fn(TPCPEnvelope) + Send + Sync + 'static>;
15
16/// Async TPCP node backed by tokio and tokio-tungstenite.
17pub struct TPCPNode {
18    pub identity: AgentIdentity,
19    pub memory: Arc<RwLock<tpcp_core::LWWMap>>,
20    pub dlq: Arc<DLQ>,
21    handlers: Arc<RwLock<HashMap<Intent, HandlerFn>>>,
22    /// Optional Ed25519 signing key. When present, `send_message` attaches a signature.
23    signing_key: Option<Arc<SigningKey>>,
24    /// Known peer public keys: agent_id → base64-encoded public key.
25    /// Inbound messages from registered peers are signature-verified before dispatch.
26    peer_keys: Arc<RwLock<HashMap<String, String>>>,
27    /// Replay-protection window: message_id → unix timestamp (seconds) of first receipt.
28    seen_messages: Arc<RwLock<HashMap<String, i64>>>,
29    /// Maximum number of concurrent inbound WebSocket connections.
30    max_connections: usize,
31    /// Current count of active inbound connections.
32    active_connections: Arc<AtomicUsize>,
33}
34
35impl TPCPNode {
36    /// Creates a new TPCPNode with the given identity and no signing key.
37    pub fn new(identity: AgentIdentity) -> Self {
38        Self {
39            identity,
40            memory: Arc::new(RwLock::new(tpcp_core::LWWMap::new())),
41            dlq: Arc::new(DLQ::new()),
42            handlers: Arc::new(RwLock::new(HashMap::new())),
43            signing_key: None,
44            peer_keys: Arc::new(RwLock::new(HashMap::new())),
45            seen_messages: Arc::new(RwLock::new(HashMap::new())),
46            max_connections: 100,
47            active_connections: Arc::new(AtomicUsize::new(0)),
48        }
49    }
50
51    /// Creates a TPCPNode that signs all outbound messages.
52    pub fn with_signing_key(identity: AgentIdentity, key: SigningKey) -> Self {
53        Self {
54            identity,
55            memory: Arc::new(RwLock::new(tpcp_core::LWWMap::new())),
56            dlq: Arc::new(DLQ::new()),
57            handlers: Arc::new(RwLock::new(HashMap::new())),
58            signing_key: Some(Arc::new(key)),
59            peer_keys: Arc::new(RwLock::new(HashMap::new())),
60            seen_messages: Arc::new(RwLock::new(HashMap::new())),
61            max_connections: 100,
62            active_connections: Arc::new(AtomicUsize::new(0)),
63        }
64    }
65
66    /// Registers a peer's public key for inbound signature verification.
67    pub async fn register_peer_key(&self, agent_id: &str, public_key_b64: &str) {
68        self.peer_keys.write().await.insert(agent_id.to_string(), public_key_b64.to_string());
69    }
70
71    /// Registers a handler for a specific intent.
72    pub async fn register_handler<F>(&self, intent: Intent, handler: F)
73    where
74        F: Fn(TPCPEnvelope) + Send + Sync + 'static,
75    {
76        self.handlers.write().await.insert(intent, Arc::new(handler));
77    }
78
79    /// Connects as a WebSocket client to the given URL.
80    pub async fn connect(&self, url: &str) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
81        let (ws_stream, _) = connect_async(url).await?;
82        let (_, mut read) = ws_stream.split();
83        let handlers = Arc::clone(&self.handlers);
84        let dlq = Arc::clone(&self.dlq);
85        let peer_keys = Arc::clone(&self.peer_keys);
86        let seen_messages = Arc::clone(&self.seen_messages);
87        tokio::spawn(async move {
88            while let Some(Ok(msg)) = read.next().await {
89                if let Message::Text(text) = msg {
90                    if let Ok(env) = serde_json::from_str::<TPCPEnvelope>(&text) {
91                        Self::dispatch_env(env, &handlers, &dlq, &peer_keys, &seen_messages).await;
92                    }
93                }
94            }
95        });
96        Ok(())
97    }
98
99    /// Starts listening on the given address (e.g. "127.0.0.1:9001").
100    pub async fn listen(&self, addr: &str) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
101        let listener = tokio::net::TcpListener::bind(addr).await?;
102        let handlers = Arc::clone(&self.handlers);
103        let dlq = Arc::clone(&self.dlq);
104        let peer_keys = Arc::clone(&self.peer_keys);
105        let seen_messages = Arc::clone(&self.seen_messages);
106        let max_connections = self.max_connections;
107        let active_connections = Arc::clone(&self.active_connections);
108        tokio::spawn(async move {
109            while let Ok((stream, _)) = listener.accept().await {
110                // Enforce connection limit before upgrading to WebSocket.
111                if active_connections.load(Ordering::Acquire) >= max_connections {
112                    eprintln!("[TPCPNode] max_connections ({}) reached — rejecting new connection", max_connections);
113                    // Drop the raw TcpStream; this closes the TCP connection.
114                    drop(stream);
115                    continue;
116                }
117
118                let ws = match accept_async(stream).await {
119                    Ok(ws) => ws,
120                    Err(_) => continue,
121                };
122                let (_, mut read) = ws.split();
123                let handlers = Arc::clone(&handlers);
124                let dlq = Arc::clone(&dlq);
125                let peer_keys = Arc::clone(&peer_keys);
126                let seen_messages = Arc::clone(&seen_messages);
127                let active_connections_inner = Arc::clone(&active_connections);
128                active_connections_inner.fetch_add(1, Ordering::AcqRel);
129                tokio::spawn(async move {
130                    while let Some(Ok(msg)) = read.next().await {
131                        if let Message::Text(text) = msg {
132                            if let Ok(env) = serde_json::from_str::<TPCPEnvelope>(&text) {
133                                Self::dispatch_env(env, &handlers, &dlq, &peer_keys, &seen_messages).await;
134                            }
135                        }
136                    }
137                    active_connections_inner.fetch_sub(1, Ordering::AcqRel);
138                });
139            }
140        });
141        Ok(())
142    }
143
144    /// Sends a message to the given WebSocket URL.
145    /// `receiver_id` is the target agent's agent_id (not the URL).
146    pub async fn send_message(
147        &self,
148        url: &str,
149        receiver_id: &str,
150        intent: Intent,
151        payload: serde_json::Value,
152    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
153        let (mut ws_stream, _) = connect_async(url).await?;
154        let signature = self.signing_key.as_ref().map(|key| {
155            let canonical = canonical_json(&payload);
156            sign(key, &canonical)
157        });
158        let envelope = TPCPEnvelope {
159            header: MessageHeader {
160                message_id: uuid_v4(),
161                sender_id: self.identity.agent_id.clone(),
162                receiver_id: receiver_id.to_string(),
163                intent,
164                timestamp: chrono::Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Millis, true),
165                ttl: 30,
166                protocol_version: PROTOCOL_VERSION.to_string(),
167            },
168            payload,
169            signature,
170            ack_info: None,
171            chunk_info: None,
172        };
173        let text = serde_json::to_string(&envelope)?;
174        ws_stream.send(Message::Text(text)).await?;
175        ws_stream.close(None).await?;
176        Ok(())
177    }
178
179    async fn dispatch_env(
180        env: TPCPEnvelope,
181        handlers: &RwLock<HashMap<Intent, HandlerFn>>,
182        dlq: &DLQ,
183        peer_keys: &RwLock<HashMap<String, String>>,
184        seen_messages: &RwLock<HashMap<String, i64>>,
185    ) {
186        // --- Signature enforcement ---
187        // If the envelope carries a non-empty signature, the sender's public key
188        // MUST be registered. Fail closed: unknown-peer signed messages go to DLQ.
189        let has_signature = env.signature.as_deref().map(|s| !s.is_empty()).unwrap_or(false);
190        let sender_key = peer_keys.read().await.get(&env.header.sender_id).cloned();
191
192        if has_signature && sender_key.is_none() {
193            eprintln!(
194                "[TPCPNode] signed envelope from unknown peer '{}' (message_id: {}) — routing to DLQ",
195                env.header.sender_id, env.header.message_id
196            );
197            if !dlq.enqueue(env) {
198                eprintln!("[TPCPNode] DLQ full — dropping signed-unknown-peer envelope");
199            }
200            return;
201        }
202
203        if let Some(pub_key_b64) = sender_key {
204            let canonical = canonical_json(&env.payload);
205            let sig_ok = env.signature.as_deref()
206                .map(|sig| verify(&pub_key_b64, &canonical, sig))
207                .unwrap_or(false);
208            if !sig_ok {
209                // Invalid or missing signature from a known peer — route to DLQ.
210                if !dlq.enqueue(env) {
211                    eprintln!("[TPCPNode] DLQ full — dropping invalid-signature envelope");
212                }
213                return;
214            }
215        }
216
217        // --- Replay protection ---
218        let now = chrono::Utc::now().timestamp();
219        let message_id = env.header.message_id.clone();
220        {
221            let mut seen = seen_messages.write().await;
222            // Evict entries older than 300 seconds.
223            seen.retain(|_, &mut ts| now - ts < 300);
224            // Check for duplicate.
225            if seen.contains_key(&message_id) {
226                // Silently drop replayed messages.
227                return;
228            }
229            // Record first-seen timestamp.
230            seen.insert(message_id, now);
231        }
232
233        let handler = handlers.read().await.get(&env.header.intent).cloned();
234        match handler {
235            Some(h) => h(env),
236            None => {
237                if !dlq.enqueue(env) {
238                    eprintln!("[TPCPNode] DLQ full — dropping unhandled envelope");
239                }
240            }
241        }
242    }
243}
244
245fn uuid_v4() -> String {
246    Uuid::new_v4().to_string()
247}
248