1use crate::error::{NetworkError, NetworkResult};
60use crate::identity::{AgentId, MachineId};
61use crate::trust::TrustDecision;
62use std::collections::HashMap;
63use std::sync::Arc;
64use std::time::{SystemTime, UNIX_EPOCH};
65use tokio::sync::{broadcast, mpsc, RwLock};
66
/// Stream-type tag written as the first byte of every encoded direct message
/// and checked again when a message is decoded.
pub const DIRECT_MESSAGE_STREAM_TYPE: u8 = 0x10;

/// Largest payload, in bytes, that `encode_message` accepts (16 MiB).
pub const MAX_DIRECT_PAYLOAD_SIZE: usize = 16 << 20;
72
73#[derive(Debug, Clone, PartialEq, Eq)]
85pub struct DirectMessage {
86 pub sender: AgentId,
92 pub machine_id: MachineId,
97 pub payload: Vec<u8>,
99 pub received_at: u64,
101 pub verified: bool,
109 pub trust_decision: Option<TrustDecision>,
116}
117
118impl DirectMessage {
119 #[must_use]
124 pub fn new(sender: AgentId, machine_id: MachineId, payload: Vec<u8>) -> Self {
125 Self::new_verified(sender, machine_id, payload, false, None)
126 }
127
128 #[must_use]
130 pub fn new_verified(
131 sender: AgentId,
132 machine_id: MachineId,
133 payload: Vec<u8>,
134 verified: bool,
135 trust_decision: Option<TrustDecision>,
136 ) -> Self {
137 let received_at = SystemTime::now()
138 .duration_since(UNIX_EPOCH)
139 .map(|d| d.as_millis() as u64)
140 .unwrap_or(0);
141
142 Self {
143 sender,
144 machine_id,
145 payload,
146 received_at,
147 verified,
148 trust_decision,
149 }
150 }
151
152 #[must_use]
154 pub fn payload_str(&self) -> Option<&str> {
155 std::str::from_utf8(&self.payload).ok()
156 }
157}
158
159#[derive(Debug)]
164pub struct DirectMessageReceiver {
165 rx: broadcast::Receiver<DirectMessage>,
166}
167
168impl DirectMessageReceiver {
169 pub(crate) fn new(rx: broadcast::Receiver<DirectMessage>) -> Self {
171 Self { rx }
172 }
173
174 pub async fn recv(&mut self) -> Option<DirectMessage> {
178 loop {
179 match self.rx.recv().await {
180 Ok(msg) => return Some(msg),
181 Err(broadcast::error::RecvError::Lagged(n)) => {
182 tracing::warn!("Direct message receiver lagged, skipped {} messages", n);
183 continue;
184 }
185 Err(broadcast::error::RecvError::Closed) => return None,
186 }
187 }
188 }
189
190 pub fn try_recv(&mut self) -> Option<DirectMessage> {
194 self.rx.try_recv().ok()
195 }
196}
197
198impl Clone for DirectMessageReceiver {
199 fn clone(&self) -> Self {
200 Self {
201 rx: self.rx.resubscribe(),
202 }
203 }
204}
205
206#[derive(Debug)]
211pub struct DirectMessaging {
212 machine_to_agent: Arc<RwLock<HashMap<MachineId, AgentId>>>,
215
216 connected_agents: Arc<RwLock<HashMap<AgentId, MachineId>>>,
218
219 message_tx: broadcast::Sender<DirectMessage>,
221
222 internal_tx: mpsc::Sender<DirectMessage>,
224
225 internal_rx: Arc<tokio::sync::Mutex<mpsc::Receiver<DirectMessage>>>,
227}
228
229impl DirectMessaging {
230 #[must_use]
232 pub fn new() -> Self {
233 let (message_tx, _) = broadcast::channel(256);
234 let (internal_tx, internal_rx) = mpsc::channel(256);
235
236 Self {
237 machine_to_agent: Arc::new(RwLock::new(HashMap::new())),
238 connected_agents: Arc::new(RwLock::new(HashMap::new())),
239 message_tx,
240 internal_tx,
241 internal_rx: Arc::new(tokio::sync::Mutex::new(internal_rx)),
242 }
243 }
244
245 pub async fn register_agent(&self, agent_id: AgentId, machine_id: MachineId) {
249 let mut map = self.machine_to_agent.write().await;
250 map.insert(machine_id, agent_id);
251 tracing::debug!(
252 "Registered agent mapping: {:?} -> {:?}",
253 machine_id,
254 agent_id
255 );
256 }
257
258 pub async fn lookup_agent(&self, machine_id: &MachineId) -> Option<AgentId> {
260 let map = self.machine_to_agent.read().await;
261 map.get(machine_id).copied()
262 }
263
264 pub async fn mark_connected(&self, agent_id: AgentId, machine_id: MachineId) {
266 self.register_agent(agent_id, machine_id).await;
268
269 let mut connected = self.connected_agents.write().await;
270 connected.insert(agent_id, machine_id);
271 tracing::info!("Agent connected: {:?}", agent_id);
272 }
273
274 pub async fn mark_disconnected(&self, agent_id: &AgentId) {
276 let mut connected = self.connected_agents.write().await;
277 connected.remove(agent_id);
278 tracing::info!("Agent disconnected: {:?}", agent_id);
279 }
280
281 pub async fn is_connected(&self, agent_id: &AgentId) -> bool {
283 let connected = self.connected_agents.read().await;
284 connected.contains_key(agent_id)
285 }
286
287 pub async fn get_machine_id(&self, agent_id: &AgentId) -> Option<MachineId> {
289 let connected = self.connected_agents.read().await;
290 connected.get(agent_id).copied()
291 }
292
293 pub async fn connected_agents(&self) -> Vec<AgentId> {
295 let connected = self.connected_agents.read().await;
296 connected.keys().copied().collect()
297 }
298
299 pub fn subscribe(&self) -> DirectMessageReceiver {
301 DirectMessageReceiver::new(self.message_tx.subscribe())
302 }
303
304 pub fn subscriber_count(&self) -> usize {
309 self.message_tx.receiver_count()
310 }
311
312 pub async fn handle_incoming(
318 &self,
319 machine_id: MachineId,
320 sender_agent_id: AgentId,
321 payload: Vec<u8>,
322 verified: bool,
323 trust_decision: Option<TrustDecision>,
324 ) {
325 let msg = DirectMessage::new_verified(
326 sender_agent_id,
327 machine_id,
328 payload,
329 verified,
330 trust_decision,
331 );
332
333 if self.message_tx.receiver_count() > 0 {
335 let _ = self.message_tx.send(msg.clone());
336 }
337
338 if self.internal_tx.try_send(msg).is_err() {
350 tracing::trace!("direct internal_tx full or closed, skipping pull-API copy");
351 }
352 }
353
354 pub async fn recv(&self) -> Option<DirectMessage> {
356 let mut rx = self.internal_rx.lock().await;
357 rx.recv().await
358 }
359
360 pub fn encode_message(sender_agent_id: &AgentId, payload: &[u8]) -> NetworkResult<Vec<u8>> {
364 if payload.len() > MAX_DIRECT_PAYLOAD_SIZE {
365 return Err(NetworkError::PayloadTooLarge {
366 size: payload.len(),
367 max: MAX_DIRECT_PAYLOAD_SIZE,
368 });
369 }
370
371 let mut buf = Vec::with_capacity(1 + 32 + payload.len());
372 buf.push(DIRECT_MESSAGE_STREAM_TYPE);
373 buf.extend_from_slice(&sender_agent_id.0);
374 buf.extend_from_slice(payload);
375 Ok(buf)
376 }
377
378 pub fn decode_message(data: &[u8]) -> NetworkResult<(AgentId, Vec<u8>)> {
382 if data.len() < 33 {
384 return Err(NetworkError::InvalidMessage(
385 "Direct message too short".to_string(),
386 ));
387 }
388
389 if data[0] != DIRECT_MESSAGE_STREAM_TYPE {
390 return Err(NetworkError::InvalidMessage(format!(
391 "Invalid stream type byte: expected {}, got {}",
392 DIRECT_MESSAGE_STREAM_TYPE, data[0]
393 )));
394 }
395
396 let mut agent_id_bytes = [0u8; 32];
397 agent_id_bytes.copy_from_slice(&data[1..33]);
398 let sender = AgentId(agent_id_bytes);
399
400 let payload = data[33..].to_vec();
401
402 Ok((sender, payload))
403 }
404}
405
406impl Default for DirectMessaging {
407 fn default() -> Self {
408 Self::new()
409 }
410}
411
#[cfg(test)]
mod tests {
    use super::*;

    /// Encoding then decoding returns the original sender and payload.
    #[test]
    fn test_encode_decode_roundtrip() {
        let sender = AgentId([42u8; 32]);
        let body: Vec<u8> = b"hello world".to_vec();

        let wire = DirectMessaging::encode_message(&sender, &body).unwrap();

        assert_eq!(wire[0], DIRECT_MESSAGE_STREAM_TYPE);
        assert_eq!(wire.len(), 1 + 32 + body.len());

        let (got_sender, got_body) = DirectMessaging::decode_message(&wire).unwrap();

        assert_eq!(got_sender, sender);
        assert_eq!(got_body, body);
    }

    /// Input shorter than the header is rejected.
    #[test]
    fn test_decode_too_short() {
        let truncated = [DIRECT_MESSAGE_STREAM_TYPE; 10];
        assert!(DirectMessaging::decode_message(&truncated).is_err());
    }

    /// A long-enough frame with the wrong leading type byte is rejected.
    #[test]
    fn test_decode_wrong_type() {
        let mut frame = [0x00u8; 50];
        frame[0] = 0x01;
        assert!(DirectMessaging::decode_message(&frame).is_err());
    }

    /// A payload one byte over the limit cannot be encoded.
    #[test]
    fn test_encode_payload_too_large() {
        let sender = AgentId([1u8; 32]);
        let oversized = vec![0u8; MAX_DIRECT_PAYLOAD_SIZE + 1];
        assert!(DirectMessaging::encode_message(&sender, &oversized).is_err());
    }

    /// A registered machine id resolves to its agent id.
    #[tokio::test]
    async fn test_register_and_lookup() {
        let messaging = DirectMessaging::new();
        let agent = AgentId([1u8; 32]);
        let machine = MachineId([2u8; 32]);

        messaging.register_agent(agent, machine).await;

        assert_eq!(messaging.lookup_agent(&machine).await, Some(agent));
    }

    /// Connection state follows `mark_connected` / `mark_disconnected`.
    #[tokio::test]
    async fn test_connection_tracking() {
        let messaging = DirectMessaging::new();
        let agent = AgentId([1u8; 32]);
        let machine = MachineId([2u8; 32]);

        assert!(!messaging.is_connected(&agent).await);

        messaging.mark_connected(agent, machine).await;
        assert!(messaging.is_connected(&agent).await);
        assert_eq!(messaging.get_machine_id(&agent).await, Some(machine));

        let listed = messaging.connected_agents().await;
        assert_eq!(listed, vec![agent]);

        messaging.mark_disconnected(&agent).await;
        assert!(!messaging.is_connected(&agent).await);
    }

    /// A subscriber created before `handle_incoming` receives the message
    /// with all metadata intact.
    #[tokio::test]
    async fn test_message_subscription() {
        let messaging = DirectMessaging::new();
        let mut subscription = messaging.subscribe();

        let from = AgentId([1u8; 32]);
        let machine = MachineId([2u8; 32]);
        let body: Vec<u8> = b"test message".to_vec();

        messaging
            .handle_incoming(machine, from, body.clone(), true, None)
            .await;

        let received = subscription.recv().await.unwrap();
        assert_eq!(received.sender, from);
        assert_eq!(received.machine_id, machine);
        assert_eq!(received.payload, body);
        assert!(received.verified);
        assert!(received.trust_decision.is_none());
    }

    /// `payload_str` yields text for UTF-8 payloads and `None` otherwise.
    #[test]
    fn test_direct_message_payload_str() {
        let agent = AgentId([1u8; 32]);
        let machine = MachineId([2u8; 32]);

        let text = DirectMessage::new(agent, machine, b"hello".to_vec());
        assert_eq!(text.payload_str(), Some("hello"));

        let binary = DirectMessage::new(agent, machine, vec![0xff, 0xfe]);
        assert!(binary.payload_str().is_none());
    }
}