1pub mod policy_gate;
6
7use async_trait::async_trait;
8use parking_lot::RwLock;
9use std::collections::HashMap;
10use std::sync::Arc;
11use std::time::{Duration, SystemTime};
12use tokio::sync::{mpsc, oneshot, Notify};
13use tokio::time::{interval, timeout};
14
15use crate::crypto::Aes256GcmCrypto;
16use crate::types::*;
17use ed25519_dalek::{SigningKey, VerifyingKey};
18use rand::rngs::OsRng;
19use rand::RngCore;
20
21#[async_trait]
23pub trait CommunicationBus {
24 async fn send_message(&self, message: SecureMessage) -> Result<MessageId, CommunicationError>;
26
27 async fn receive_messages(
29 &self,
30 agent_id: AgentId,
31 ) -> Result<Vec<SecureMessage>, CommunicationError>;
32
33 async fn subscribe(&self, agent_id: AgentId, topic: String) -> Result<(), CommunicationError>;
35
36 async fn unsubscribe(&self, agent_id: AgentId, topic: String)
38 -> Result<(), CommunicationError>;
39
40 async fn publish(
42 &self,
43 topic: String,
44 message: SecureMessage,
45 ) -> Result<(), CommunicationError>;
46
47 async fn get_delivery_status(
49 &self,
50 message_id: MessageId,
51 ) -> Result<DeliveryStatus, CommunicationError>;
52
53 async fn register_agent(&self, agent_id: AgentId) -> Result<(), CommunicationError>;
55
56 async fn unregister_agent(&self, agent_id: AgentId) -> Result<(), CommunicationError>;
58
59 async fn request(
61 &self,
62 target_agent: AgentId,
63 request_payload: bytes::Bytes,
64 timeout_duration: Duration,
65 ) -> Result<bytes::Bytes, CommunicationError>;
66
67 async fn shutdown(&self) -> Result<(), CommunicationError>;
69
70 async fn check_health(&self) -> Result<ComponentHealth, CommunicationError>;
72
73 fn create_internal_message(
75 &self,
76 sender: AgentId,
77 recipient: AgentId,
78 payload_data: bytes::Bytes,
79 message_type: MessageType,
80 ttl: std::time::Duration,
81 ) -> SecureMessage;
82}
83
/// Limits and feature switches for a communication bus.
#[derive(Debug, Clone)]
pub struct CommunicationConfig {
    /// Largest payload, in bytes, that `send_message` accepts.
    pub max_message_size: usize,
    /// Age after which queued messages are swept by the cleanup pass.
    pub message_ttl: Duration,
    /// Maximum number of messages a single agent queue may hold.
    pub max_queue_size: usize,
    /// Delivery timeout. NOTE(review): not read by the code visible in this file.
    pub delivery_timeout: Duration,
    /// Retry budget. NOTE(review): not read by the code visible in this file.
    pub retry_attempts: u32,
    /// Encryption switch. NOTE(review): not read by the code visible in this file.
    pub enable_encryption: bool,
    /// Compression switch. NOTE(review): not read by the code visible in this file.
    pub enable_compression: bool,
    /// Size handed to `DeadLetterQueue::new` when the bus is built.
    pub dead_letter_queue_size: usize,
}

impl Default for CommunicationConfig {
    /// 1 MiB payloads, 1 h TTL, 10 000 messages per queue, 30 s delivery
    /// timeout, 3 retries, encryption and compression on, 1 000 dead letters.
    fn default() -> Self {
        let one_mebibyte = 1 << 20;
        let one_hour = Duration::from_secs(60 * 60);
        let thirty_seconds = Duration::from_secs(30);

        CommunicationConfig {
            max_message_size: one_mebibyte,
            message_ttl: one_hour,
            max_queue_size: 10_000,
            delivery_timeout: thirty_seconds,
            retry_attempts: 3,
            enable_encryption: true,
            enable_compression: true,
            dead_letter_queue_size: 1_000,
        }
    }
}
111
112pub struct DefaultCommunicationBus {
114 config: CommunicationConfig,
115 message_queues: Arc<RwLock<HashMap<AgentId, MessageQueue>>>,
116 subscriptions: Arc<RwLock<HashMap<String, Vec<AgentId>>>>,
117 message_tracker: Arc<RwLock<HashMap<MessageId, MessageTracker>>>,
118 dead_letter_queue: Arc<RwLock<DeadLetterQueue>>,
119 pending_requests: Arc<RwLock<HashMap<RequestId, oneshot::Sender<bytes::Bytes>>>>,
120 event_sender: mpsc::UnboundedSender<CommunicationEvent>,
121 shutdown_notify: Arc<Notify>,
122 is_running: Arc<RwLock<bool>>,
123 signing_key: SigningKey,
124 verifying_key: VerifyingKey,
125 system_agent_id: AgentId,
126 #[allow(dead_code)]
127 crypto: Aes256GcmCrypto,
128}
129
130impl DefaultCommunicationBus {
131 pub async fn new(config: CommunicationConfig) -> Result<Self, CommunicationError> {
133 let message_queues = Arc::new(RwLock::new(HashMap::new()));
134 let subscriptions = Arc::new(RwLock::new(HashMap::new()));
135 let message_tracker = Arc::new(RwLock::new(HashMap::new()));
136 let dead_letter_queue = Arc::new(RwLock::new(DeadLetterQueue::new(
137 config.dead_letter_queue_size,
138 )));
139 let pending_requests = Arc::new(RwLock::new(HashMap::new()));
140 let (event_sender, event_receiver) = mpsc::unbounded_channel();
141 let shutdown_notify = Arc::new(Notify::new());
142 let is_running = Arc::new(RwLock::new(true));
143
144 let mut secret_bytes = [0u8; 32];
146 OsRng.fill_bytes(&mut secret_bytes);
147 let signing_key = SigningKey::from_bytes(&secret_bytes);
148 let verifying_key = signing_key.verifying_key();
149
150 let system_agent_id = AgentId::new();
152
153 let crypto = Aes256GcmCrypto::new();
154
155 let bus = Self {
156 config,
157 message_queues,
158 subscriptions,
159 message_tracker,
160 dead_letter_queue,
161 pending_requests,
162 event_sender,
163 shutdown_notify,
164 is_running,
165 signing_key,
166 verifying_key,
167 system_agent_id,
168 crypto,
169 };
170
171 bus.start_event_loop(event_receiver).await;
173 bus.start_cleanup_loop().await;
174
175 Ok(bus)
176 }
177
178 async fn start_event_loop(
180 &self,
181 mut event_receiver: mpsc::UnboundedReceiver<CommunicationEvent>,
182 ) {
183 let message_queues = self.message_queues.clone();
184 let subscriptions = self.subscriptions.clone();
185 let message_tracker = self.message_tracker.clone();
186 let dead_letter_queue = self.dead_letter_queue.clone();
187 let pending_requests = self.pending_requests.clone();
188 let shutdown_notify = self.shutdown_notify.clone();
189 let config = self.config.clone();
190
191 tokio::spawn(async move {
192 loop {
193 tokio::select! {
194 event = event_receiver.recv() => {
195 if let Some(event) = event {
196 Self::process_communication_event(
197 event,
198 &message_queues,
199 &subscriptions,
200 &message_tracker,
201 &dead_letter_queue,
202 &pending_requests,
203 &config,
204 ).await;
205 } else {
206 break;
207 }
208 }
209 _ = shutdown_notify.notified() => {
210 break;
211 }
212 }
213 }
214 });
215 }
216
217 async fn start_cleanup_loop(&self) {
219 let message_queues = self.message_queues.clone();
220 let message_tracker = self.message_tracker.clone();
221 let dead_letter_queue = self.dead_letter_queue.clone();
222 let shutdown_notify = self.shutdown_notify.clone();
223 let is_running = self.is_running.clone();
224 let message_ttl = self.config.message_ttl;
225
226 tokio::spawn(async move {
227 let mut interval = interval(Duration::from_secs(60)); loop {
230 tokio::select! {
231 _ = interval.tick() => {
232 if !*is_running.read() {
233 break;
234 }
235
236 Self::cleanup_expired_messages(&message_queues, &message_tracker, &dead_letter_queue, message_ttl).await;
237 }
238 _ = shutdown_notify.notified() => {
239 break;
240 }
241 }
242 }
243 });
244 }
245
246 async fn process_communication_event(
248 event: CommunicationEvent,
249 message_queues: &Arc<RwLock<HashMap<AgentId, MessageQueue>>>,
250 subscriptions: &Arc<RwLock<HashMap<String, Vec<AgentId>>>>,
251 message_tracker: &Arc<RwLock<HashMap<MessageId, MessageTracker>>>,
252 dead_letter_queue: &Arc<RwLock<DeadLetterQueue>>,
253 pending_requests: &Arc<RwLock<HashMap<RequestId, oneshot::Sender<bytes::Bytes>>>>,
254 config: &CommunicationConfig,
255 ) {
256 match event {
257 CommunicationEvent::MessageSent { message } => {
258 let recipient = message.recipient;
259 let message_id = message.id;
260
261 if let MessageType::Response(request_id) = &message.message_type {
263 if let Some(sender) = pending_requests.write().remove(request_id) {
264 let _ = sender.send(message.payload.data.clone());
266 tracing::debug!(
267 "Response {} sent for request {:?}",
268 message_id,
269 request_id
270 );
271 return;
272 }
273 }
274
275 let mut tracker_map = message_tracker.write();
278 let mut queues = message_queues.write();
279
280 tracker_map.insert(message_id, MessageTracker::new(message.clone()));
281
282 if let Some(recipient_id) = recipient {
284 if let Some(queue) = queues.get_mut(&recipient_id) {
285 if queue.can_accept_message(config) {
286 queue.add_message(message);
287
288 if let Some(tracker) = tracker_map.get_mut(&message_id) {
290 tracker.status = DeliveryStatus::Delivered;
291 tracker.delivered_at = Some(SystemTime::now());
292 }
293
294 tracing::debug!(
295 "Message {} delivered to agent {}",
296 message_id,
297 recipient_id
298 );
299 } else {
300 dead_letter_queue
302 .write()
303 .add_message(message, DeadLetterReason::QueueFull);
304
305 if let Some(tracker) = tracker_map.get_mut(&message_id) {
306 tracker.status = DeliveryStatus::Failed;
307 tracker.failure_reason = Some("Queue full".to_string());
308 }
309
310 tracing::warn!(
311 "Message {} failed to deliver: queue full for agent {}",
312 message_id,
313 recipient_id
314 );
315 }
316 } else {
317 dead_letter_queue
319 .write()
320 .add_message(message, DeadLetterReason::AgentNotFound);
321
322 if let Some(tracker) = tracker_map.get_mut(&message_id) {
323 tracker.status = DeliveryStatus::Failed;
324 tracker.failure_reason = Some("Agent not registered".to_string());
325 }
326
327 tracing::warn!(
328 "Message {} failed to deliver: agent {:?} not registered",
329 message_id,
330 recipient
331 );
332 }
333 } else {
334 dead_letter_queue
336 .write()
337 .add_message(message, DeadLetterReason::AgentNotFound);
338
339 if let Some(tracker) = message_tracker.write().get_mut(&message_id) {
340 tracker.status = DeliveryStatus::Failed;
341 tracker.failure_reason = Some("Agent not registered".to_string());
342 }
343
344 tracing::warn!(
345 "Message {} failed to deliver: agent {:?} not registered",
346 message_id,
347 recipient
348 );
349 }
350 }
351 CommunicationEvent::TopicPublished { topic, message } => {
352 let subscribers = subscriptions
353 .read()
354 .get(&topic)
355 .cloned()
356 .unwrap_or_default();
357 let subscriber_count = subscribers.len();
358
359 for subscriber in &subscribers {
360 let mut subscriber_message = message.clone();
361 subscriber_message.recipient = Some(*subscriber);
362 subscriber_message.id = MessageId::new();
363
364 Box::pin(Self::process_communication_event(
366 CommunicationEvent::MessageSent {
367 message: subscriber_message,
368 },
369 message_queues,
370 subscriptions,
371 message_tracker,
372 dead_letter_queue,
373 pending_requests,
374 config,
375 ))
376 .await;
377 }
378
379 tracing::debug!(
380 "Published message to topic {} for {} subscribers",
381 topic,
382 subscriber_count
383 );
384 }
385 CommunicationEvent::AgentRegistered { agent_id } => {
386 message_queues.write().insert(agent_id, MessageQueue::new());
387 tracing::info!("Registered agent {} for communication", agent_id);
388 }
389 CommunicationEvent::AgentUnregistered { agent_id } => {
390 message_queues.write().remove(&agent_id);
391
392 let mut subs = subscriptions.write();
394 for subscribers in subs.values_mut() {
395 subscribers.retain(|&id| id != agent_id);
396 }
397
398 tracing::info!("Unregistered agent {} from communication", agent_id);
399 }
400 }
401 }
402
403 async fn cleanup_expired_messages(
405 message_queues: &Arc<RwLock<HashMap<AgentId, MessageQueue>>>,
406 message_tracker: &Arc<RwLock<HashMap<MessageId, MessageTracker>>>,
407 dead_letter_queue: &Arc<RwLock<DeadLetterQueue>>,
408 message_ttl: Duration,
409 ) {
410 let now = SystemTime::now();
411 let mut expired_messages = Vec::new();
412
413 {
415 let mut queues = message_queues.write();
416 let mut stale_queues = 0;
417 for queue in queues.values_mut() {
418 let expired = queue.remove_expired_messages(now, message_ttl);
419 expired_messages.extend(expired);
420
421 if queue.is_stale(message_ttl * 3) {
423 stale_queues += 1;
424 }
425 }
426
427 if stale_queues > 0 {
428 tracing::debug!("Found {} stale message queues", stale_queues);
429 }
430 }
431
432 {
434 let mut dlq = dead_letter_queue.write();
435 for message in expired_messages {
436 dlq.add_message(message.clone(), DeadLetterReason::Expired);
437
438 if let Some(tracker) = message_tracker.write().get_mut(&message.id) {
440 tracker.status = DeliveryStatus::Failed;
441 tracker.failure_reason = Some("Message expired".to_string());
442 }
443 }
444 }
445
446 {
448 let mut tracker = message_tracker.write();
449 let mut retry_candidates = Vec::new();
450
451 tracker.retain(|message_id, t| {
452 let age = t.get_age();
453 if age < message_ttl * 2 {
454 if t.should_retry(message_ttl) {
456 retry_candidates.push(*message_id);
457
458 let msg = t.get_message();
460 tracing::debug!(
461 "Message {} eligible for retry: size={} bytes, age={:?}s, sender={}",
462 message_id,
463 t.get_message_size(),
464 t.get_age().as_secs(),
465 msg.sender
466 );
467 }
468 true
469 } else {
470 false
471 }
472 });
473
474 if !retry_candidates.is_empty() {
476 tracing::debug!(
477 "Found {} messages eligible for retry",
478 retry_candidates.len()
479 );
480 }
481 }
482 }
483
484 fn send_event(&self, event: CommunicationEvent) -> Result<(), CommunicationError> {
486 self.event_sender
487 .send(event)
488 .map_err(|_| CommunicationError::EventProcessingFailed {
489 reason: "Failed to send communication event".into(),
490 })
491 }
492
493 fn generate_nonce() -> Vec<u8> {
495 use aes_gcm::{aead::AeadCore, Aes256Gcm};
496 let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
497 nonce.to_vec()
498 }
499
500 fn sign_message_data(&self, data: &[u8]) -> MessageSignature {
502 use ed25519_dalek::Signer;
503
504 let signature = self.signing_key.sign(data);
505 MessageSignature {
506 signature: signature.to_bytes().to_vec(),
507 algorithm: SignatureAlgorithm::Ed25519,
508 public_key: self.verifying_key.to_bytes().to_vec(),
509 }
510 }
511
512 fn create_secure_request_message(
514 &self,
515 target_agent: AgentId,
516 request_id: RequestId,
517 request_payload: bytes::Bytes,
518 timeout_duration: Duration,
519 ) -> Result<SecureMessage, CommunicationError> {
520 Ok(self.create_internal_message(
521 self.system_agent_id,
522 target_agent,
523 request_payload,
524 MessageType::Request(request_id),
525 timeout_duration,
526 ))
527 }
528}
529
530#[async_trait]
531impl CommunicationBus for DefaultCommunicationBus {
532 async fn send_message(&self, message: SecureMessage) -> Result<MessageId, CommunicationError> {
533 if !*self.is_running.read() {
534 return Err(CommunicationError::ShuttingDown);
535 }
536
537 if message.payload.data.len() > self.config.max_message_size {
539 return Err(CommunicationError::MessageTooLarge {
540 size: message.payload.data.len(),
541 max_size: self.config.max_message_size,
542 });
543 }
544
545 let message_id = message.id;
546
547 self.send_event(CommunicationEvent::MessageSent { message })?;
548
549 Ok(message_id)
550 }
551
552 async fn receive_messages(
553 &self,
554 agent_id: AgentId,
555 ) -> Result<Vec<SecureMessage>, CommunicationError> {
556 let mut queues = self.message_queues.write();
557 if let Some(queue) = queues.get_mut(&agent_id) {
558 Ok(queue.drain_messages())
559 } else {
560 Err(CommunicationError::AgentNotRegistered { agent_id })
561 }
562 }
563
564 async fn subscribe(&self, agent_id: AgentId, topic: String) -> Result<(), CommunicationError> {
565 let mut subscriptions = self.subscriptions.write();
566 subscriptions
567 .entry(topic.clone())
568 .or_default()
569 .push(agent_id);
570
571 tracing::info!("Agent {} subscribed to topic {}", agent_id, topic);
572 Ok(())
573 }
574
575 async fn unsubscribe(
576 &self,
577 agent_id: AgentId,
578 topic: String,
579 ) -> Result<(), CommunicationError> {
580 let mut subscriptions = self.subscriptions.write();
581 if let Some(subscribers) = subscriptions.get_mut(&topic) {
582 subscribers.retain(|&id| id != agent_id);
583 if subscribers.is_empty() {
584 subscriptions.remove(&topic);
585 }
586 }
587
588 tracing::info!("Agent {} unsubscribed from topic {}", agent_id, topic);
589 Ok(())
590 }
591
592 async fn publish(
593 &self,
594 topic: String,
595 message: SecureMessage,
596 ) -> Result<(), CommunicationError> {
597 if !*self.is_running.read() {
598 return Err(CommunicationError::ShuttingDown);
599 }
600
601 self.send_event(CommunicationEvent::TopicPublished { topic, message })?;
602 Ok(())
603 }
604
605 async fn get_delivery_status(
606 &self,
607 message_id: MessageId,
608 ) -> Result<DeliveryStatus, CommunicationError> {
609 self.message_tracker
610 .read()
611 .get(&message_id)
612 .map(|tracker| tracker.status.clone())
613 .ok_or(CommunicationError::MessageNotFound { message_id })
614 }
615
616 async fn register_agent(&self, agent_id: AgentId) -> Result<(), CommunicationError> {
617 self.send_event(CommunicationEvent::AgentRegistered { agent_id })?;
618 Ok(())
619 }
620
621 async fn unregister_agent(&self, agent_id: AgentId) -> Result<(), CommunicationError> {
622 self.send_event(CommunicationEvent::AgentUnregistered { agent_id })?;
623 Ok(())
624 }
625
626 async fn request(
627 &self,
628 target_agent: AgentId,
629 request_payload: bytes::Bytes,
630 timeout_duration: Duration,
631 ) -> Result<bytes::Bytes, CommunicationError> {
632 if !*self.is_running.read() {
633 return Err(CommunicationError::ShuttingDown);
634 }
635
636 let request_id = RequestId::new();
638 let (response_sender, response_receiver) = oneshot::channel();
639
640 self.pending_requests
642 .write()
643 .insert(request_id, response_sender);
644
645 let request_message = self.create_secure_request_message(
647 target_agent,
648 request_id,
649 request_payload,
650 timeout_duration,
651 )?;
652
653 self.send_message(request_message).await?;
655
656 match timeout(timeout_duration, response_receiver).await {
658 Ok(Ok(response_payload)) => Ok(response_payload),
659 Ok(Err(_)) => {
660 self.pending_requests.write().remove(&request_id);
662 Err(CommunicationError::RequestCancelled { request_id })
663 }
664 Err(_) => {
665 self.pending_requests.write().remove(&request_id);
667 Err(CommunicationError::RequestTimeout {
668 request_id,
669 timeout: timeout_duration,
670 })
671 }
672 }
673 }
674
675 async fn shutdown(&self) -> Result<(), CommunicationError> {
676 tracing::info!("Shutting down communication bus");
677
678 *self.is_running.write() = false;
679 self.shutdown_notify.notify_waiters();
680
681 let agent_ids: Vec<AgentId> = self.message_queues.read().keys().copied().collect();
683
684 for agent_id in agent_ids {
685 if let Err(e) = self.unregister_agent(agent_id).await {
686 tracing::error!(
687 "Failed to unregister agent {} during shutdown: {}",
688 agent_id,
689 e
690 );
691 }
692 }
693
694 Ok(())
695 }
696
697 async fn check_health(&self) -> Result<ComponentHealth, CommunicationError> {
698 let is_running = *self.is_running.read();
699 if !is_running {
700 return Ok(ComponentHealth::unhealthy(
701 "Communication bus is shut down".to_string(),
702 ));
703 }
704
705 let queue_count = self.message_queues.read().len();
706 let topic_count = self.subscriptions.read().len();
707 let tracker_count = self.message_tracker.read().len();
708 let pending_requests = self.pending_requests.read().len();
709
710 let mut total_queued_messages = 0;
712 let mut full_queues = 0;
713
714 {
715 let queues = self.message_queues.read();
716 for queue in queues.values() {
717 total_queued_messages += queue.messages.len();
718 if queue.messages.len() >= self.config.max_queue_size * 9 / 10 {
719 full_queues += 1;
721 }
722 }
723 }
724
725 let dead_letter_count = self.dead_letter_queue.read().messages.len();
726
727 let status = if dead_letter_count > 100 {
728 ComponentHealth::degraded(format!(
729 "High dead letter queue: {} messages",
730 dead_letter_count
731 ))
732 } else if full_queues > 0 {
733 ComponentHealth::degraded(format!("{} message queues near capacity", full_queues))
734 } else if pending_requests > 50 {
735 ComponentHealth::degraded(format!("Many pending requests: {}", pending_requests))
736 } else {
737 ComponentHealth::healthy(Some(format!(
738 "{} agents registered, {} active topics",
739 queue_count, topic_count
740 )))
741 };
742
743 Ok(status
744 .with_metric("registered_agents".to_string(), queue_count.to_string())
745 .with_metric("active_topics".to_string(), topic_count.to_string())
746 .with_metric(
747 "queued_messages".to_string(),
748 total_queued_messages.to_string(),
749 )
750 .with_metric("pending_requests".to_string(), pending_requests.to_string())
751 .with_metric("dead_letters".to_string(), dead_letter_count.to_string())
752 .with_metric("message_trackers".to_string(), tracker_count.to_string()))
753 }
754
755 fn create_internal_message(
756 &self,
757 sender: AgentId,
758 recipient: AgentId,
759 payload_data: bytes::Bytes,
760 message_type: MessageType,
761 ttl: Duration,
762 ) -> SecureMessage {
763 let nonce = Self::generate_nonce();
764
765 let payload = EncryptedPayload {
766 data: payload_data,
767 nonce,
768 encryption_algorithm: EncryptionAlgorithm::Aes256Gcm,
769 };
770
771 let message_data_to_sign = [payload.data.as_ref(), &payload.nonce].concat();
773 let signature = self.sign_message_data(&message_data_to_sign);
774
775 SecureMessage {
776 id: MessageId::new(),
777 sender,
778 recipient: Some(recipient),
779 topic: None,
780 message_type,
781 payload,
782 signature,
783 ttl,
784 timestamp: SystemTime::now(),
785 }
786 }
787}
788
789#[derive(Debug, Clone)]
791struct MessageQueue {
792 messages: Vec<SecureMessage>,
793 created_at: SystemTime,
794}
795
796impl MessageQueue {
797 fn new() -> Self {
798 Self {
799 messages: Vec::new(),
800 created_at: SystemTime::now(),
801 }
802 }
803
804 fn can_accept_message(&self, config: &CommunicationConfig) -> bool {
805 self.messages.len() < config.max_queue_size
806 }
807
808 fn add_message(&mut self, message: SecureMessage) {
809 self.messages.push(message);
810 }
811
812 fn drain_messages(&mut self) -> Vec<SecureMessage> {
813 std::mem::take(&mut self.messages)
814 }
815
816 fn remove_expired_messages(&mut self, now: SystemTime, ttl: Duration) -> Vec<SecureMessage> {
817 let mut expired = Vec::new();
818
819 self.messages.retain(|message| {
820 let age = now.duration_since(message.timestamp).unwrap_or_default();
821 if age > ttl {
822 expired.push(message.clone());
823 false
824 } else {
825 true
826 }
827 });
828
829 expired
830 }
831
832 fn get_queue_age(&self) -> Duration {
833 SystemTime::now()
834 .duration_since(self.created_at)
835 .unwrap_or_default()
836 }
837
838 fn is_stale(&self, max_age: Duration) -> bool {
839 self.get_queue_age() > max_age
840 }
841}
842
843#[derive(Debug, Clone)]
845struct MessageTracker {
846 message: SecureMessage,
847 status: DeliveryStatus,
848 created_at: SystemTime,
849 delivered_at: Option<SystemTime>,
850 failure_reason: Option<String>,
851}
852
853impl MessageTracker {
854 fn new(message: SecureMessage) -> Self {
855 Self {
856 message,
857 status: DeliveryStatus::Pending,
858 created_at: SystemTime::now(),
859 delivered_at: None,
860 failure_reason: None,
861 }
862 }
863
864 fn get_message(&self) -> &SecureMessage {
866 &self.message
867 }
868
869 fn get_message_size(&self) -> usize {
871 self.message.payload.data.len()
872 }
873
874 fn get_age(&self) -> Duration {
876 SystemTime::now()
877 .duration_since(self.created_at)
878 .unwrap_or_default()
879 }
880
881 fn should_retry(&self, max_age: Duration) -> bool {
883 matches!(self.status, DeliveryStatus::Failed) && self.get_age() < max_age
884 }
885}
886
/// Delivery state recorded for a tracked message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DeliveryStatus {
    /// Tracked but not yet placed in a recipient queue (initial state).
    Pending,
    /// Placed in the recipient's queue.
    Delivered,
    /// Could not be delivered; the tracker's failure reason says why.
    Failed,
    /// Expired. NOTE(review): not assigned by any code visible in this file.
    Expired,
}
895
896#[derive(Debug, Clone)]
898enum CommunicationEvent {
899 MessageSent {
900 message: SecureMessage,
901 },
902 TopicPublished {
903 topic: String,
904 message: SecureMessage,
905 },
906 AgentRegistered {
907 agent_id: AgentId,
908 },
909 AgentUnregistered {
910 agent_id: AgentId,
911 },
912}
913
914#[cfg(test)]
915mod tests {
916 use super::*;
917 use crate::types::{EncryptedPayload, MessageType};
918
919 fn create_test_message(sender: AgentId, recipient: AgentId) -> SecureMessage {
920 use crate::types::RequestId;
921 use aes_gcm::{aead::AeadCore, Aes256Gcm};
922 use ed25519_dalek::Signer;
923
924 let mut secret_bytes = [0u8; 32];
925 OsRng.fill_bytes(&mut secret_bytes);
926 let signing_key = SigningKey::from_bytes(&secret_bytes);
927 let verifying_key = signing_key.verifying_key();
928
929 let nonce = Aes256Gcm::generate_nonce(&mut OsRng).to_vec();
930 let data: bytes::Bytes = b"test message".to_vec().into();
931
932 let message_data_to_sign = [data.as_ref(), &nonce].concat();
933 let signature = signing_key.sign(&message_data_to_sign);
934
935 SecureMessage {
936 id: MessageId::new(),
937 sender,
938 recipient: Some(recipient),
939 message_type: MessageType::Request(RequestId::new()),
940 topic: Some("test".to_string()),
941 payload: EncryptedPayload {
942 data,
943 nonce,
944 encryption_algorithm: EncryptionAlgorithm::Aes256Gcm,
945 },
946 signature: MessageSignature {
947 signature: signature.to_bytes().to_vec(),
948 algorithm: SignatureAlgorithm::Ed25519,
949 public_key: verifying_key.to_bytes().to_vec(),
950 },
951 ttl: Duration::from_secs(3600),
952 timestamp: SystemTime::now(),
953 }
954 }
955
956 #[tokio::test]
957 async fn test_agent_registration() {
958 let bus = DefaultCommunicationBus::new(CommunicationConfig::default())
959 .await
960 .unwrap();
961 let agent_id = AgentId::new();
962
963 let result = bus.register_agent(agent_id).await;
964 assert!(result.is_ok());
965
966 tokio::time::sleep(Duration::from_millis(50)).await;
968
969 let messages = bus.receive_messages(agent_id).await;
971 assert!(messages.is_ok());
972 }
973
974 #[tokio::test]
975 async fn test_message_sending() {
976 let bus = DefaultCommunicationBus::new(CommunicationConfig::default())
977 .await
978 .unwrap();
979 let sender = AgentId::new();
980 let recipient = AgentId::new();
981
982 bus.register_agent(sender).await.unwrap();
984 bus.register_agent(recipient).await.unwrap();
985
986 tokio::time::sleep(Duration::from_millis(50)).await;
987
988 let message = create_test_message(sender, recipient);
990 let message_id = bus.send_message(message).await.unwrap();
991
992 tokio::time::sleep(Duration::from_millis(50)).await;
993
994 let status = bus.get_delivery_status(message_id).await.unwrap();
996 assert_eq!(status, DeliveryStatus::Delivered);
997
998 let messages = bus.receive_messages(recipient).await.unwrap();
1000 assert_eq!(messages.len(), 1);
1001 assert_eq!(messages[0].sender, sender);
1002 }
1003
1004 #[tokio::test]
1005 async fn test_topic_subscription() {
1006 let bus = DefaultCommunicationBus::new(CommunicationConfig::default())
1007 .await
1008 .unwrap();
1009 let publisher = AgentId::new();
1010 let subscriber1 = AgentId::new();
1011 let subscriber2 = AgentId::new();
1012
1013 bus.register_agent(publisher).await.unwrap();
1015 bus.register_agent(subscriber1).await.unwrap();
1016 bus.register_agent(subscriber2).await.unwrap();
1017
1018 let topic = "test_topic".to_string();
1020 bus.subscribe(subscriber1, topic.clone()).await.unwrap();
1021 bus.subscribe(subscriber2, topic.clone()).await.unwrap();
1022
1023 tokio::time::sleep(Duration::from_millis(50)).await;
1024
1025 let message = create_test_message(publisher, AgentId::new()); bus.publish(topic, message).await.unwrap();
1028
1029 tokio::time::sleep(Duration::from_millis(50)).await;
1030
1031 let messages1 = bus.receive_messages(subscriber1).await.unwrap();
1033 let messages2 = bus.receive_messages(subscriber2).await.unwrap();
1034
1035 assert_eq!(messages1.len(), 1);
1036 assert_eq!(messages2.len(), 1);
1037 assert_eq!(messages1[0].sender, publisher);
1038 assert_eq!(messages2[0].sender, publisher);
1039 }
1040
1041 #[tokio::test]
1042 async fn test_message_size_limit() {
1043 let config = CommunicationConfig {
1044 max_message_size: 100, ..Default::default()
1046 };
1047
1048 let bus = DefaultCommunicationBus::new(config).await.unwrap();
1049 let sender = AgentId::new();
1050 let recipient = AgentId::new();
1051
1052 bus.register_agent(sender).await.unwrap();
1053 bus.register_agent(recipient).await.unwrap();
1054
1055 let mut message = create_test_message(sender, recipient);
1057 message.payload.data = vec![0u8; 200].into(); let result = bus.send_message(message).await;
1060 assert!(result.is_err());
1061
1062 if let Err(CommunicationError::MessageTooLarge { size, max_size }) = result {
1063 assert_eq!(size, 200);
1064 assert_eq!(max_size, 100);
1065 } else {
1066 panic!("Expected MessageTooLarge error");
1067 }
1068 }
1069
1070 #[tokio::test]
1071 async fn test_agent_unregistration() {
1072 let bus = DefaultCommunicationBus::new(CommunicationConfig::default())
1073 .await
1074 .unwrap();
1075 let agent_id = AgentId::new();
1076
1077 bus.register_agent(agent_id).await.unwrap();
1079 tokio::time::sleep(Duration::from_millis(50)).await;
1080
1081 bus.unregister_agent(agent_id).await.unwrap();
1082 tokio::time::sleep(Duration::from_millis(50)).await;
1083
1084 let result = bus.receive_messages(agent_id).await;
1086 assert!(result.is_err());
1087 }
1088
1089 #[tokio::test]
1090 async fn test_request_response_timeout() {
1091 let bus = DefaultCommunicationBus::new(CommunicationConfig::default())
1092 .await
1093 .unwrap();
1094 let target_agent = AgentId::new();
1095
1096 bus.register_agent(target_agent).await.unwrap();
1098 tokio::time::sleep(Duration::from_millis(50)).await;
1099
1100 let request_payload = bytes::Bytes::from("test request");
1102 let timeout = Duration::from_millis(100);
1103
1104 let result = bus.request(target_agent, request_payload, timeout).await;
1105 assert!(result.is_err());
1106
1107 if let Err(CommunicationError::RequestTimeout {
1108 request_id: _,
1109 timeout: actual_timeout,
1110 }) = result
1111 {
1112 assert_eq!(actual_timeout, timeout);
1113 } else {
1114 panic!("Expected RequestTimeout error");
1115 }
1116 }
1117
1118 #[tokio::test]
1119 async fn test_request_response_success() {
1120 let bus = DefaultCommunicationBus::new(CommunicationConfig::default())
1121 .await
1122 .unwrap();
1123 let requester = AgentId::new();
1124 let responder = AgentId::new();
1125
1126 bus.register_agent(requester).await.unwrap();
1128 bus.register_agent(responder).await.unwrap();
1129 tokio::time::sleep(Duration::from_millis(50)).await;
1130
1131 let request_payload = bytes::Bytes::from("test request");
1132 let response_payload = bytes::Bytes::from("test response");
1133
1134 let bus_clone = Arc::new(bus);
1136 let request_bus = bus_clone.clone();
1137 let request_handle = tokio::spawn(async move {
1138 request_bus
1139 .request(responder, request_payload, Duration::from_secs(5))
1140 .await
1141 });
1142
1143 tokio::time::sleep(Duration::from_millis(100)).await;
1145
1146 let messages = bus_clone.receive_messages(responder).await.unwrap();
1148 assert_eq!(messages.len(), 1);
1149 assert!(matches!(messages[0].message_type, MessageType::Request(_)));
1150
1151 if let MessageType::Request(request_id) = &messages[0].message_type {
1153 let response_message = bus_clone.create_internal_message(
1154 responder,
1155 requester,
1156 response_payload.clone(),
1157 MessageType::Response(*request_id),
1158 Duration::from_secs(3600),
1159 );
1160
1161 bus_clone.send_message(response_message).await.unwrap();
1162 }
1163
1164 let result = request_handle.await.unwrap();
1166 assert!(result.is_ok());
1167 assert_eq!(result.unwrap(), response_payload);
1168 }
1169}