//! Distributed runtime: cluster membership tracking, agent location lookup,
//! and cross-node message routing.

use crate::{ZoeyError, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, RwLock};
use std::time::Duration;
use tokio::sync::mpsc;
use tracing::{debug, info, warn};

/// Information about a single node in the cluster.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeInfo {
    /// Unique identifier of the node.
    pub id: uuid::Uuid,

    /// Human-readable node name.
    pub name: String,

    /// Network address the node is reachable at (e.g. `127.0.0.1:8080`).
    pub address: String,

    /// Current health status.
    pub status: NodeStatus,

    /// IDs of the agents hosted on this node.
    pub agents: Vec<uuid::Uuid>,

    /// CPU usage as a fraction in `0.0..=1.0`.
    pub cpu_usage: f32,

    /// Memory usage as a fraction in `0.0..=1.0`.
    pub memory_usage: f32,

    /// Unix timestamp (seconds) of the last heartbeat received from the node.
    pub last_heartbeat: i64,
}

/// Health status of a cluster node.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum NodeStatus {
    /// Node is operating normally.
    Healthy,

    /// Node is under heavy load but still accepting work.
    Degraded,

    /// Node is overloaded and should not receive new work.
    Unhealthy,

    /// Node is not reachable.
    Offline,
}

/// A message routed between agents, possibly across nodes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedMessage {
    /// Unique message identifier.
    pub id: uuid::Uuid,

    /// Node that sent the message.
    pub from_node: uuid::Uuid,

    /// Node that should receive the message.
    pub to_node: uuid::Uuid,

    /// Agent that sent the message.
    pub from_agent: uuid::Uuid,

    /// Agent that should receive the message.
    pub to_agent: uuid::Uuid,

    /// Arbitrary JSON payload.
    pub payload: serde_json::Value,

    /// Application-defined message type tag.
    pub message_type: String,

    /// Unix timestamp (seconds) at which the message was created.
    pub timestamp: i64,
}

/// Coordinates cluster membership and routes messages between agents that may
/// live on different nodes.
pub struct DistributedRuntime {
    /// ID of the local node.
    node_id: uuid::Uuid,

    /// Known cluster nodes, keyed by node ID.
    nodes: Arc<RwLock<HashMap<uuid::Uuid, NodeInfo>>>,

    /// Sender side of the internal message queue.
    message_tx: mpsc::UnboundedSender<DistributedMessage>,

    /// Receiver side of the internal message queue.
    message_rx: Arc<RwLock<mpsc::UnboundedReceiver<DistributedMessage>>>,

    /// Maps an agent ID to the node that currently hosts it.
    agent_locations: Arc<RwLock<HashMap<uuid::Uuid, uuid::Uuid>>>,

    /// Number of queued messages that have not yet been received.
    pending_count: Arc<AtomicUsize>,

    /// Total messages sent since the last stats reset.
    messages_sent: Arc<AtomicUsize>,

    /// Total messages received since the last stats reset.
    messages_received: Arc<AtomicUsize>,
}

impl DistributedRuntime {
    /// Creates a runtime for the local node identified by `node_id`.
    pub fn new(node_id: uuid::Uuid) -> Self {
        let (tx, rx) = mpsc::unbounded_channel();

        Self {
            node_id,
            nodes: Arc::new(RwLock::new(HashMap::new())),
            message_tx: tx,
            message_rx: Arc::new(RwLock::new(rx)),
            agent_locations: Arc::new(RwLock::new(HashMap::new())),
            pending_count: Arc::new(AtomicUsize::new(0)),
            messages_sent: Arc::new(AtomicUsize::new(0)),
            messages_received: Arc::new(AtomicUsize::new(0)),
        }
    }

    /// Registers a node and records the location of every agent it hosts.
    pub fn register_node(&self, node: NodeInfo) -> Result<()> {
        info!("Registering node {} at {}", node.name, node.address);
        debug!("Node {} has {} agents", node.id, node.agents.len());

        for agent_id in &node.agents {
            debug!("Mapping agent {} to node {}", agent_id, node.id);
            self.agent_locations
                .write()
                .unwrap()
                .insert(*agent_id, node.id);
        }

        self.nodes.write().unwrap().insert(node.id, node);
        debug!(
            "Total nodes in cluster: {}",
            self.nodes.read().unwrap().len()
        );

        Ok(())
    }

    /// Removes a node and the location entries for its agents.
    pub fn unregister_node(&self, node_id: uuid::Uuid) -> Result<()> {
        info!("Unregistering node {}", node_id);

        if let Some(node) = self.nodes.write().unwrap().remove(&node_id) {
            for agent_id in &node.agents {
                self.agent_locations.write().unwrap().remove(agent_id);
            }
        }

        Ok(())
    }

    /// Queues a message for `to_agent`, resolving the node that hosts it.
    ///
    /// Returns an error if the target agent is not known to the cluster.
    pub async fn send_to_agent(
        &self,
        from_agent: uuid::Uuid,
        to_agent: uuid::Uuid,
        payload: serde_json::Value,
        message_type: String,
    ) -> Result<()> {
        debug!(
            "Sending {} message from agent {} to agent {}",
            message_type, from_agent, to_agent
        );

        let to_node = self
            .agent_locations
            .read()
            .unwrap()
            .get(&to_agent)
            .copied()
            .ok_or_else(|| {
                ZoeyError::not_found(format!("Agent {} not found in cluster", to_agent))
            })?;

        debug!("Target agent {} is on node {}", to_agent, to_node);

        let message = DistributedMessage {
            id: uuid::Uuid::new_v4(),
            from_node: self.node_id,
            to_node,
            from_agent,
            to_agent,
            payload,
            message_type: message_type.clone(),
            timestamp: chrono::Utc::now().timestamp(),
        };

        self.message_tx
            .send(message)
            .map_err(|e| ZoeyError::other(format!("Failed to send message: {}", e)))?;

        self.pending_count.fetch_add(1, Ordering::SeqCst);
        self.messages_sent.fetch_add(1, Ordering::SeqCst);

        debug!(
            "Message queued successfully (pending: {})",
            self.pending_count.load(Ordering::SeqCst)
        );
        Ok(())
    }

    /// Returns the node currently hosting `agent_id`, if known.
    pub fn get_agent_node(&self, agent_id: uuid::Uuid) -> Option<uuid::Uuid> {
        self.agent_locations.read().unwrap().get(&agent_id).copied()
    }

    /// Returns a snapshot of all known nodes.
    pub fn get_nodes(&self) -> Vec<NodeInfo> {
        self.nodes.read().unwrap().values().cloned().collect()
    }

    /// Returns the nodes currently reporting `Healthy` status.
    pub fn get_healthy_nodes(&self) -> Vec<NodeInfo> {
        self.nodes
            .read()
            .unwrap()
            .values()
            .filter(|n| n.status == NodeStatus::Healthy)
            .cloned()
            .collect()
    }

    /// Picks the healthy node with the lowest combined CPU and memory load.
    pub fn find_best_node(&self) -> Option<uuid::Uuid> {
        let nodes = self.get_healthy_nodes();

        if nodes.is_empty() {
            warn!("No healthy nodes available for load balancing");
            return None;
        }

        debug!("Finding best node among {} healthy nodes", nodes.len());

        nodes
            .iter()
            .min_by(|a, b| {
                let load_a = a.cpu_usage + a.memory_usage;
                let load_b = b.cpu_usage + b.memory_usage;
                load_a
                    .partial_cmp(&load_b)
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .map(|n| {
                let load = n.cpu_usage + n.memory_usage;
                debug!("Selected node {} with load {:.2}", n.name, load);
                n.id
            })
    }

    /// Records a heartbeat for `node_id`, updating its load figures and
    /// recomputing its health status from the reported usage.
    pub fn heartbeat(&self, node_id: uuid::Uuid, cpu_usage: f32, memory_usage: f32) -> Result<()> {
        if let Some(node) = self.nodes.write().unwrap().get_mut(&node_id) {
            let old_status = node.status;
            node.cpu_usage = cpu_usage;
            node.memory_usage = memory_usage;
            node.last_heartbeat = chrono::Utc::now().timestamp();

            // Evaluate the stricter (Unhealthy) threshold before the Degraded one,
            // so usage above 0.95 is not classified as merely Degraded.
            node.status = if cpu_usage > 0.95 || memory_usage > 0.95 {
                NodeStatus::Unhealthy
            } else if cpu_usage > 0.9 || memory_usage > 0.9 {
                NodeStatus::Degraded
            } else {
                NodeStatus::Healthy
            };

            if old_status != node.status {
                info!(
                    "Node {} status changed: {:?} -> {:?}",
                    node.name, old_status, node.status
                );
            }
            debug!(
                "Node {} heartbeat: CPU {:.1}%, Memory {:.1}%",
                node.name,
                cpu_usage * 100.0,
                memory_usage * 100.0
            );
        } else {
            warn!("Received heartbeat from unknown node {}", node_id);
        }

        Ok(())
    }

    /// Returns the IDs of nodes whose last heartbeat is older than `timeout_seconds`.
    pub fn check_node_health(&self, timeout_seconds: i64) -> Vec<uuid::Uuid> {
        let now = chrono::Utc::now().timestamp();
        let mut dead_nodes = Vec::new();

        for (node_id, node) in self.nodes.read().unwrap().iter() {
            if now - node.last_heartbeat > timeout_seconds {
                warn!(
                    "Node {} hasn't sent a heartbeat for {} seconds",
                    node.name,
                    now - node.last_heartbeat
                );
                dead_nodes.push(*node_id);
            }
        }

        dead_nodes
    }

    /// Attempts to dequeue one message without blocking.
    pub fn try_recv_message(&self) -> Option<DistributedMessage> {
        let mut rx = self.message_rx.write().unwrap();
        match rx.try_recv() {
            Ok(msg) => {
                self.pending_count.fetch_sub(1, Ordering::SeqCst);
                self.messages_received.fetch_add(1, Ordering::SeqCst);
                Some(msg)
            }
            Err(_) => None,
        }
    }

    /// Runs a receive loop, passing every queued message to `handler`.
    ///
    /// The queue is polled without holding the lock across `.await`; when the
    /// queue is empty the loop sleeps briefly before polling again. Returns an
    /// error once the message channel is disconnected.
    pub async fn receive_messages<F>(&self, mut handler: F) -> Result<()>
    where
        F: FnMut(
                DistributedMessage,
            ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>
            + Send,
    {
        loop {
            let message = {
                let mut rx = self.message_rx.write().unwrap();
                match rx.try_recv() {
                    Ok(msg) => Some(msg),
                    Err(mpsc::error::TryRecvError::Empty) => None,
                    Err(mpsc::error::TryRecvError::Disconnected) => {
                        warn!("Message channel disconnected");
                        return Err(ZoeyError::other("Message channel disconnected"));
                    }
                }
            };

            if let Some(msg) = message {
                debug!("Received message {} from node {}", msg.id, msg.from_node);
                handler(msg).await?;
            } else {
                tokio::time::sleep(Duration::from_millis(10)).await;
            }
        }
    }

    /// Drains the queue, invoking `handler` for each message addressed to this
    /// node. Messages addressed to other nodes are logged and dropped. Returns
    /// the number of successfully handled messages.
    pub async fn process_pending_messages<F>(&self, handler: F) -> Result<usize>
    where
        F: Fn(&DistributedMessage) -> Result<()>,
    {
        let mut processed = 0;

        loop {
            let message = self.try_recv_message();

            match message {
                Some(msg) => {
                    debug!("Processing message {} type={}", msg.id, msg.message_type);

                    if msg.to_node != self.node_id {
                        warn!(
                            "Received message for wrong node: expected {}, got {}",
                            self.node_id, msg.to_node
                        );
                        continue;
                    }

                    match handler(&msg) {
                        Ok(_) => {
                            processed += 1;
                            debug!("Successfully processed message {}", msg.id);
                        }
                        Err(e) => {
                            warn!("Failed to process message {}: {}", msg.id, e);
                        }
                    }
                }
                None => {
                    break;
                }
            }
        }

        if processed > 0 {
            info!("Processed {} distributed message(s)", processed);
        }

        Ok(processed)
    }

    /// Number of messages currently queued but not yet received.
    pub fn pending_message_count(&self) -> usize {
        self.pending_count.load(Ordering::SeqCst)
    }

    /// Total number of messages sent since the last stats reset.
    pub fn total_messages_sent(&self) -> usize {
        self.messages_sent.load(Ordering::SeqCst)
    }

    /// Total number of messages received since the last stats reset.
    pub fn total_messages_received(&self) -> usize {
        self.messages_received.load(Ordering::SeqCst)
    }

    /// Returns a snapshot of the message counters.
    pub fn get_message_stats(&self) -> MessageStats {
        MessageStats {
            pending: self.pending_count.load(Ordering::SeqCst),
            sent: self.messages_sent.load(Ordering::SeqCst),
            received: self.messages_received.load(Ordering::SeqCst),
        }
    }

    /// Resets the sent/received counters (the pending count is left untouched).
    pub fn reset_message_stats(&self) {
        self.messages_sent.store(0, Ordering::SeqCst);
        self.messages_received.store(0, Ordering::SeqCst);
        info!("Message statistics reset for node {}", self.node_id);
    }
}

/// Snapshot of a runtime's message counters.
#[derive(Debug, Clone, Copy)]
pub struct MessageStats {
    /// Messages queued but not yet received.
    pub pending: usize,
    /// Messages sent since the last reset.
    pub sent: usize,
    /// Messages received since the last reset.
    pub received: usize,
}

/// Configuration for cluster coordination.
#[derive(Debug, Clone)]
pub struct ClusterConfig {
    /// How often nodes send heartbeats.
    pub heartbeat_interval: Duration,

    /// How long a node may stay silent before it is considered dead.
    pub node_timeout: Duration,

    /// Whether agents are automatically rebalanced across nodes.
    pub auto_rebalance: bool,

    /// Number of replicas to maintain for each agent.
    pub replication_factor: usize,
}

impl Default for ClusterConfig {
    fn default() -> Self {
        Self {
            heartbeat_interval: Duration::from_secs(5),
            node_timeout: Duration::from_secs(30),
            auto_rebalance: true,
            replication_factor: 1,
        }
    }
}

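// Illustrative sketch only: one way a host process could drive periodic health
// checks from a `ClusterConfig`. The function name and the evict-on-timeout
// policy are hypothetical, not part of this module's API.
#[allow(dead_code)]
async fn evict_dead_nodes_loop(
    runtime: Arc<DistributedRuntime>,
    config: ClusterConfig,
) -> Result<()> {
    loop {
        tokio::time::sleep(config.heartbeat_interval).await;
        // Any node silent for longer than `node_timeout` is treated as dead
        // and removed from the cluster view.
        for node_id in runtime.check_node_health(config.node_timeout.as_secs() as i64) {
            warn!("Evicting unresponsive node {}", node_id);
            runtime.unregister_node(node_id)?;
        }
    }
}
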
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_distributed_runtime() {
        let runtime = DistributedRuntime::new(uuid::Uuid::new_v4());
        assert_eq!(runtime.get_nodes().len(), 0);
    }

    #[test]
    fn test_node_registration() {
        let runtime = DistributedRuntime::new(uuid::Uuid::new_v4());

        let node = NodeInfo {
            id: uuid::Uuid::new_v4(),
            name: "node1".to_string(),
            address: "127.0.0.1:8080".to_string(),
            status: NodeStatus::Healthy,
            agents: vec![],
            cpu_usage: 0.5,
            memory_usage: 0.6,
            last_heartbeat: chrono::Utc::now().timestamp(),
        };

        runtime.register_node(node).unwrap();

        assert_eq!(runtime.get_nodes().len(), 1);
        assert_eq!(runtime.get_healthy_nodes().len(), 1);
    }

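    // Sketch of a heartbeat test: load figures are recorded and the status
    // thresholds (above 0.95 -> Unhealthy, above 0.9 -> Degraded) are applied.
    #[test]
    fn test_heartbeat_updates_status() {
        let runtime = DistributedRuntime::new(uuid::Uuid::new_v4());
        let node_id = uuid::Uuid::new_v4();

        runtime
            .register_node(NodeInfo {
                id: node_id,
                name: "node1".to_string(),
                address: "127.0.0.1:8080".to_string(),
                status: NodeStatus::Healthy,
                agents: vec![],
                cpu_usage: 0.1,
                memory_usage: 0.1,
                last_heartbeat: chrono::Utc::now().timestamp(),
            })
            .unwrap();

        let status_of = |rt: &DistributedRuntime| rt.get_nodes()[0].status;

        runtime.heartbeat(node_id, 0.5, 0.5).unwrap();
        assert_eq!(status_of(&runtime), NodeStatus::Healthy);

        runtime.heartbeat(node_id, 0.92, 0.5).unwrap();
        assert_eq!(status_of(&runtime), NodeStatus::Degraded);

        runtime.heartbeat(node_id, 0.97, 0.5).unwrap();
        assert_eq!(status_of(&runtime), NodeStatus::Unhealthy);
    }
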
    #[test]
    fn test_load_balancing() {
        let runtime = DistributedRuntime::new(uuid::Uuid::new_v4());

        let node1 = NodeInfo {
            id: uuid::Uuid::new_v4(),
            name: "node1".to_string(),
            address: "127.0.0.1:8080".to_string(),
            status: NodeStatus::Healthy,
            agents: vec![],
            cpu_usage: 0.8,
            memory_usage: 0.7,
            last_heartbeat: chrono::Utc::now().timestamp(),
        };

        let node2 = NodeInfo {
            id: uuid::Uuid::new_v4(),
            name: "node2".to_string(),
            address: "127.0.0.1:8081".to_string(),
            status: NodeStatus::Healthy,
            agents: vec![],
            cpu_usage: 0.3,
            memory_usage: 0.4,
            last_heartbeat: chrono::Utc::now().timestamp(),
        };

        runtime.register_node(node1).unwrap();
        runtime.register_node(node2.clone()).unwrap();

        let best = runtime.find_best_node().unwrap();
        assert_eq!(best, node2.id);
    }

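    // Sketch of an agent-location test: registering a node records where its
    // agents live, and unregistering removes those entries again.
    #[test]
    fn test_agent_locations() {
        let runtime = DistributedRuntime::new(uuid::Uuid::new_v4());
        let agent = uuid::Uuid::new_v4();
        let node_id = uuid::Uuid::new_v4();

        runtime
            .register_node(NodeInfo {
                id: node_id,
                name: "node1".to_string(),
                address: "127.0.0.1:8080".to_string(),
                status: NodeStatus::Healthy,
                agents: vec![agent],
                cpu_usage: 0.2,
                memory_usage: 0.2,
                last_heartbeat: chrono::Utc::now().timestamp(),
            })
            .unwrap();

        assert_eq!(runtime.get_agent_node(agent), Some(node_id));

        runtime.unregister_node(node_id).unwrap();
        assert_eq!(runtime.get_agent_node(agent), None);
    }
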
    #[tokio::test]
    async fn test_cross_node_messaging() {
        let runtime = DistributedRuntime::new(uuid::Uuid::new_v4());

        let agent1 = uuid::Uuid::new_v4();
        let agent2 = uuid::Uuid::new_v4();

        let node1 = NodeInfo {
            id: uuid::Uuid::new_v4(),
            name: "node1".to_string(),
            address: "127.0.0.1:8080".to_string(),
            status: NodeStatus::Healthy,
            agents: vec![agent1],
            cpu_usage: 0.5,
            memory_usage: 0.5,
            last_heartbeat: chrono::Utc::now().timestamp(),
        };

        let node2 = NodeInfo {
            id: uuid::Uuid::new_v4(),
            name: "node2".to_string(),
            address: "127.0.0.1:8081".to_string(),
            status: NodeStatus::Healthy,
            agents: vec![agent2],
            cpu_usage: 0.5,
            memory_usage: 0.5,
            last_heartbeat: chrono::Utc::now().timestamp(),
        };

        runtime.register_node(node1).unwrap();
        runtime.register_node(node2).unwrap();

        let result = runtime
            .send_to_agent(
                agent1,
                agent2,
                serde_json::json!({"message": "Hello from another node!"}),
                "greeting".to_string(),
            )
            .await;

        assert!(result.is_ok());
    }
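
    // Sketch of a send/receive round trip: a message queued with send_to_agent()
    // can be dequeued with try_recv_message(), and the counters reflect it.
    #[tokio::test]
    async fn test_message_roundtrip_and_stats() {
        let node_id = uuid::Uuid::new_v4();
        let runtime = DistributedRuntime::new(node_id);

        let agent_a = uuid::Uuid::new_v4();
        let agent_b = uuid::Uuid::new_v4();

        runtime
            .register_node(NodeInfo {
                id: node_id,
                name: "local".to_string(),
                address: "127.0.0.1:9000".to_string(),
                status: NodeStatus::Healthy,
                agents: vec![agent_a, agent_b],
                cpu_usage: 0.1,
                memory_usage: 0.1,
                last_heartbeat: chrono::Utc::now().timestamp(),
            })
            .unwrap();

        runtime
            .send_to_agent(
                agent_a,
                agent_b,
                serde_json::json!({"ping": true}),
                "ping".to_string(),
            )
            .await
            .unwrap();
        assert_eq!(runtime.pending_message_count(), 1);

        let msg = runtime.try_recv_message().expect("queued message");
        assert_eq!(msg.to_agent, agent_b);
        assert_eq!(msg.to_node, node_id);

        let stats = runtime.get_message_stats();
        assert_eq!(stats.sent, 1);
        assert_eq!(stats.received, 1);
        assert_eq!(stats.pending, 0);
    }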
}