// zoey_core/distributed.rs

//! Distributed Runtime Support
//!
//! Enables agents to run across multiple nodes/processes.
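//!
//! # Example
//!
//! A minimal, illustrative sketch of registering a node and routing a message
//! to an agent. It assumes this module is exported as `zoey_core::distributed`,
//! that the crate's `Result` alias is public, and that a Tokio runtime is
//! available:
//!
//! ```no_run
//! use zoey_core::distributed::{DistributedRuntime, NodeInfo, NodeStatus};
//!
//! #[tokio::main]
//! async fn main() -> zoey_core::Result<()> {
//!     let runtime = DistributedRuntime::new(uuid::Uuid::new_v4());
//!
//!     // Register a node that hosts one agent.
//!     let agent = uuid::Uuid::new_v4();
//!     runtime.register_node(NodeInfo {
//!         id: uuid::Uuid::new_v4(),
//!         name: "node1".to_string(),
//!         address: "127.0.0.1:8080".to_string(),
//!         status: NodeStatus::Healthy,
//!         agents: vec![agent],
//!         cpu_usage: 0.2,
//!         memory_usage: 0.3,
//!         last_heartbeat: chrono::Utc::now().timestamp(),
//!     })?;
//!
//!     // The runtime resolves which node hosts the target agent.
//!     runtime
//!         .send_to_agent(
//!             uuid::Uuid::new_v4(),
//!             agent,
//!             serde_json::json!({"hello": "world"}),
//!             "greeting".to_string(),
//!         )
//!         .await?;
//!     Ok(())
//! }
//! ```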

use crate::{ZoeyError, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, RwLock};
use std::time::Duration;
use tokio::sync::mpsc;
use tracing::{debug, info, warn};

/// Node information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeInfo {
    /// Node ID
    pub id: uuid::Uuid,

    /// Node name/hostname
    pub name: String,

    /// Node address (IP:port)
    pub address: String,

    /// Node status
    pub status: NodeStatus,

    /// Agents running on this node
    pub agents: Vec<uuid::Uuid>,

    /// CPU usage (0.0 - 1.0)
    pub cpu_usage: f32,

    /// Memory usage (0.0 - 1.0)
    pub memory_usage: f32,

    /// Last heartbeat timestamp (Unix seconds)
    pub last_heartbeat: i64,
}

/// Node status
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum NodeStatus {
    /// Node is healthy and operational
    Healthy,

    /// Node is degraded but functional
    Degraded,

    /// Node is unhealthy
    Unhealthy,

    /// Node is offline
    Offline,
}

/// Distributed message for cross-node communication
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedMessage {
    /// Message ID
    pub id: uuid::Uuid,

    /// Source node ID
    pub from_node: uuid::Uuid,

    /// Target node ID
    pub to_node: uuid::Uuid,

    /// Source agent ID
    pub from_agent: uuid::Uuid,

    /// Target agent ID
    pub to_agent: uuid::Uuid,

    /// Message payload
    pub payload: serde_json::Value,

    /// Message type
    pub message_type: String,

    /// Timestamp (Unix seconds)
    pub timestamp: i64,
}

/// Distributed runtime coordinator
pub struct DistributedRuntime {
    /// This node's ID
    node_id: uuid::Uuid,

    /// Registered nodes
    nodes: Arc<RwLock<HashMap<uuid::Uuid, NodeInfo>>>,

    /// Message sender
    message_tx: mpsc::UnboundedSender<DistributedMessage>,

    /// Message receiver for processing incoming messages
    message_rx: Arc<RwLock<mpsc::UnboundedReceiver<DistributedMessage>>>,

    /// Agent-to-node mapping
    agent_locations: Arc<RwLock<HashMap<uuid::Uuid, uuid::Uuid>>>,

    /// Count of pending messages in the queue
    pending_count: Arc<AtomicUsize>,

    /// Total messages sent
    messages_sent: Arc<AtomicUsize>,

    /// Total messages received
    messages_received: Arc<AtomicUsize>,
}

impl DistributedRuntime {
    /// Create a new distributed runtime
    pub fn new(node_id: uuid::Uuid) -> Self {
        let (tx, rx) = mpsc::unbounded_channel();

        Self {
            node_id,
            nodes: Arc::new(RwLock::new(HashMap::new())),
            message_tx: tx,
            message_rx: Arc::new(RwLock::new(rx)),
            agent_locations: Arc::new(RwLock::new(HashMap::new())),
            pending_count: Arc::new(AtomicUsize::new(0)),
            messages_sent: Arc::new(AtomicUsize::new(0)),
            messages_received: Arc::new(AtomicUsize::new(0)),
        }
    }

    /// Register a node in the cluster
    pub fn register_node(&self, node: NodeInfo) -> Result<()> {
        info!("Registering node {} at {}", node.name, node.address);
        debug!("Node {} has {} agents", node.id, node.agents.len());

        // Update agent locations
        for agent_id in &node.agents {
            debug!("Mapping agent {} to node {}", agent_id, node.id);
            self.agent_locations
                .write()
                .unwrap()
                .insert(*agent_id, node.id);
        }

        self.nodes.write().unwrap().insert(node.id, node);
        debug!(
            "Total nodes in cluster: {}",
            self.nodes.read().unwrap().len()
        );

        Ok(())
    }

    /// Unregister a node
    pub fn unregister_node(&self, node_id: uuid::Uuid) -> Result<()> {
        info!("Unregistering node {}", node_id);

        if let Some(node) = self.nodes.write().unwrap().remove(&node_id) {
            // Remove agent locations
            for agent_id in &node.agents {
                self.agent_locations.write().unwrap().remove(agent_id);
            }
        }

        Ok(())
    }

    /// Send a message to an agent on any node
    pub async fn send_to_agent(
        &self,
        from_agent: uuid::Uuid,
        to_agent: uuid::Uuid,
        payload: serde_json::Value,
        message_type: String,
    ) -> Result<()> {
        debug!(
            "Sending {} message from agent {} to agent {}",
            message_type, from_agent, to_agent
        );

        // Find the target node
        let to_node = self
            .agent_locations
            .read()
            .unwrap()
            .get(&to_agent)
            .copied()
            .ok_or_else(|| {
                ZoeyError::not_found(format!("Agent {} not found in cluster", to_agent))
            })?;

        debug!("Target agent {} is on node {}", to_agent, to_node);

        let message = DistributedMessage {
            id: uuid::Uuid::new_v4(),
            from_node: self.node_id,
            to_node,
            from_agent,
            to_agent,
            payload,
            message_type,
            timestamp: chrono::Utc::now().timestamp(),
        };

        // Send via the message queue
        self.message_tx
            .send(message)
            .map_err(|e| ZoeyError::other(format!("Failed to send message: {}", e)))?;

        // Update counters
        self.pending_count.fetch_add(1, Ordering::SeqCst);
        self.messages_sent.fetch_add(1, Ordering::SeqCst);

        debug!(
            "Message queued successfully (pending: {})",
            self.pending_count.load(Ordering::SeqCst)
        );
        Ok(())
    }

    /// Get node for agent
    pub fn get_agent_node(&self, agent_id: uuid::Uuid) -> Option<uuid::Uuid> {
        self.agent_locations.read().unwrap().get(&agent_id).copied()
    }

    /// Get all nodes
    pub fn get_nodes(&self) -> Vec<NodeInfo> {
        self.nodes.read().unwrap().values().cloned().collect()
    }

    /// Get healthy nodes
    pub fn get_healthy_nodes(&self) -> Vec<NodeInfo> {
        self.nodes
            .read()
            .unwrap()
            .values()
            .filter(|n| n.status == NodeStatus::Healthy)
            .cloned()
            .collect()
    }

    /// Find best node for new agent (load balancing)
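    ///
    /// Selection picks the healthy node with the lowest combined
    /// `cpu_usage + memory_usage`. Illustrative usage (assumes a populated
    /// `runtime`):
    ///
    /// ```ignore
    /// if let Some(node_id) = runtime.find_best_node() {
    ///     // Spawn or migrate the new agent onto `node_id`.
    /// }
    /// ```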
    pub fn find_best_node(&self) -> Option<uuid::Uuid> {
        let nodes = self.get_healthy_nodes();

        if nodes.is_empty() {
            warn!("No healthy nodes available for load balancing");
            return None;
        }

        debug!("Finding best node among {} healthy nodes", nodes.len());

        // Find the node with the lowest combined load
        let best = nodes
            .iter()
            .min_by(|a, b| {
                let load_a = a.cpu_usage + a.memory_usage;
                let load_b = b.cpu_usage + b.memory_usage;
                load_a
                    .partial_cmp(&load_b)
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .map(|n| {
                let load = n.cpu_usage + n.memory_usage;
                debug!("Selected node {} with load {:.2}", n.name, load);
                n.id
            });

        best
    }

    /// Record a heartbeat, updating the node's usage metrics and status
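    ///
    /// A reading above 0.95 CPU or memory marks the node `Unhealthy`, above
    /// 0.9 `Degraded`, otherwise `Healthy`. Illustrative usage (assumes a
    /// registered `node_id`):
    ///
    /// ```ignore
    /// // Report 50% CPU and 60% memory utilization.
    /// runtime.heartbeat(node_id, 0.5, 0.6)?;
    /// ```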
    pub fn heartbeat(&self, node_id: uuid::Uuid, cpu_usage: f32, memory_usage: f32) -> Result<()> {
        if let Some(node) = self.nodes.write().unwrap().get_mut(&node_id) {
            let old_status = node.status;
            node.cpu_usage = cpu_usage;
            node.memory_usage = memory_usage;
            node.last_heartbeat = chrono::Utc::now().timestamp();

            // Update status based on health; check the stricter threshold
            // first, otherwise the Unhealthy branch is unreachable.
            node.status = if cpu_usage > 0.95 || memory_usage > 0.95 {
                NodeStatus::Unhealthy
            } else if cpu_usage > 0.9 || memory_usage > 0.9 {
                NodeStatus::Degraded
            } else {
                NodeStatus::Healthy
            };

            // Log status changes
            if old_status != node.status {
                info!(
                    "Node {} status changed: {:?} -> {:?}",
                    node.name, old_status, node.status
                );
            }
            debug!(
                "Node {} heartbeat: CPU {:.1}%, Memory {:.1}%",
                node.name,
                cpu_usage * 100.0,
                memory_usage * 100.0
            );
        } else {
            warn!("Received heartbeat from unknown node {}", node_id);
        }

        Ok(())
    }

    /// Check for dead nodes (no heartbeat)
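    ///
    /// Returns the IDs of nodes whose last heartbeat is older than
    /// `timeout_seconds`. One possible cleanup policy (illustrative):
    ///
    /// ```ignore
    /// for dead in runtime.check_node_health(30) {
    ///     runtime.unregister_node(dead)?;
    /// }
    /// ```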
    pub fn check_node_health(&self, timeout_seconds: i64) -> Vec<uuid::Uuid> {
        let now = chrono::Utc::now().timestamp();
        let mut dead_nodes = Vec::new();

        for (node_id, node) in self.nodes.read().unwrap().iter() {
            if now - node.last_heartbeat > timeout_seconds {
                warn!(
                    "Node {} hasn't sent heartbeat for {} seconds",
                    node.name,
                    now - node.last_heartbeat
                );
                dead_nodes.push(*node_id);
            }
        }

        dead_nodes
    }

    /// Try to receive a message from the queue (non-blocking)
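    ///
    /// A common polling idiom (illustrative):
    ///
    /// ```ignore
    /// while let Some(msg) = runtime.try_recv_message() {
    ///     // Handle `msg` without blocking.
    /// }
    /// ```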
    pub fn try_recv_message(&self) -> Option<DistributedMessage> {
        let mut rx = self.message_rx.write().unwrap();
        match rx.try_recv() {
            Ok(msg) => {
                // Decrement pending count and increment received count
                self.pending_count.fetch_sub(1, Ordering::SeqCst);
                self.messages_received.fetch_add(1, Ordering::SeqCst);
                Some(msg)
            }
            Err(_) => None,
        }
    }

    /// Continuously receive messages, invoking the handler on each one.
    /// Runs until the message channel disconnects, sleeping 10 ms between
    /// polls when the queue is empty.
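    ///
    /// The handler returns a pinned, boxed future. A minimal sketch (assumes
    /// this module is exported as `zoey_core::distributed`):
    ///
    /// ```no_run
    /// # use zoey_core::distributed::DistributedRuntime;
    /// # async fn demo(runtime: &DistributedRuntime) -> zoey_core::Result<()> {
    /// runtime
    ///     .receive_messages(|msg| {
    ///         Box::pin(async move {
    ///             println!("got {} from node {}", msg.id, msg.from_node);
    ///             Ok(())
    ///         })
    ///     })
    ///     .await
    /// # }
    /// ```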
    pub async fn receive_messages<F>(&self, mut handler: F) -> Result<()>
    where
        F: FnMut(
                DistributedMessage,
            )
                -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>
            + Send,
    {
        loop {
            // Take the lock only long enough to poll; never hold it across
            // an await point.
            let message = {
                let mut rx = self.message_rx.write().unwrap();
                match rx.try_recv() {
                    Ok(msg) => {
                        // Keep the counters consistent with try_recv_message()
                        self.pending_count.fetch_sub(1, Ordering::SeqCst);
                        self.messages_received.fetch_add(1, Ordering::SeqCst);
                        Some(msg)
                    }
                    Err(mpsc::error::TryRecvError::Empty) => None,
                    Err(mpsc::error::TryRecvError::Disconnected) => {
                        warn!("Message channel disconnected");
                        return Err(ZoeyError::other("Message channel disconnected"));
                    }
                }
            };

            if let Some(msg) = message {
                debug!("Received message {} from node {}", msg.id, msg.from_node);
                handler(msg).await?;
            } else {
                // No messages available, yield
                tokio::time::sleep(Duration::from_millis(10)).await;
            }
        }
    }

    /// Drain and process all currently pending messages with a handler
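    ///
    /// Returns the number of messages handled successfully. A minimal sketch
    /// (assumes this module is exported as `zoey_core::distributed`):
    ///
    /// ```no_run
    /// # use zoey_core::distributed::DistributedRuntime;
    /// # async fn demo(runtime: &DistributedRuntime) -> zoey_core::Result<()> {
    /// let processed = runtime
    ///     .process_pending_messages(|msg| {
    ///         println!("handling {} ({})", msg.id, msg.message_type);
    ///         Ok(())
    ///     })
    ///     .await?;
    /// println!("processed {processed} message(s)");
    /// # Ok(())
    /// # }
    /// ```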
    pub async fn process_pending_messages<F>(&self, handler: F) -> Result<usize>
    where
        F: Fn(&DistributedMessage) -> Result<()>,
    {
        let mut processed = 0;

        // Drain the queue until it is empty.
        while let Some(msg) = self.try_recv_message() {
            debug!("Processing message {} type={}", msg.id, msg.message_type);

            // Validate message is for this node
            if msg.to_node != self.node_id {
                warn!(
                    "Received message for wrong node: expected {}, got {}",
                    self.node_id, msg.to_node
                );
                continue;
            }

            match handler(&msg) {
                Ok(_) => {
                    processed += 1;
                    debug!("Successfully processed message {}", msg.id);
                }
                Err(e) => {
                    warn!("Failed to process message {}: {}", msg.id, e);
                }
            }
        }

        if processed > 0 {
            info!("Processed {} distributed message(s)", processed);
        }

        Ok(processed)
    }

    /// Get the number of pending messages in the queue
    pub fn pending_message_count(&self) -> usize {
        self.pending_count.load(Ordering::SeqCst)
    }

    /// Get the total number of messages sent through this node
    pub fn total_messages_sent(&self) -> usize {
        self.messages_sent.load(Ordering::SeqCst)
    }

    /// Get the total number of messages received by this node
    pub fn total_messages_received(&self) -> usize {
        self.messages_received.load(Ordering::SeqCst)
    }

    /// Get message processing statistics
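    ///
    /// Illustrative usage:
    ///
    /// ```ignore
    /// let stats = runtime.get_message_stats();
    /// println!("pending={} sent={} received={}", stats.pending, stats.sent, stats.received);
    /// ```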
    pub fn get_message_stats(&self) -> MessageStats {
        MessageStats {
            pending: self.pending_count.load(Ordering::SeqCst),
            sent: self.messages_sent.load(Ordering::SeqCst),
            received: self.messages_received.load(Ordering::SeqCst),
        }
    }

    /// Reset message statistics
    pub fn reset_message_stats(&self) {
        self.messages_sent.store(0, Ordering::SeqCst);
        self.messages_received.store(0, Ordering::SeqCst);
        info!("Message statistics reset for node {}", self.node_id);
    }
}

/// Message processing statistics
#[derive(Debug, Clone, Copy)]
pub struct MessageStats {
    /// Number of pending messages
    pub pending: usize,
    /// Total messages sent
    pub sent: usize,
    /// Total messages received
    pub received: usize,
}

/// Cluster configuration
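///
/// Overriding individual defaults with struct-update syntax (illustrative):
///
/// ```ignore
/// use std::time::Duration;
///
/// let cfg = ClusterConfig {
///     node_timeout: Duration::from_secs(60),
///     ..ClusterConfig::default()
/// };
/// ```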
#[derive(Debug, Clone)]
pub struct ClusterConfig {
    /// Heartbeat interval
    pub heartbeat_interval: Duration,

    /// Node timeout before considering dead
    pub node_timeout: Duration,

    /// Enable automatic rebalancing
    pub auto_rebalance: bool,

    /// Replication factor
    pub replication_factor: usize,
}

impl Default for ClusterConfig {
    fn default() -> Self {
        Self {
            heartbeat_interval: Duration::from_secs(5),
            node_timeout: Duration::from_secs(30),
            auto_rebalance: true,
            replication_factor: 1,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_distributed_runtime() {
        let runtime = DistributedRuntime::new(uuid::Uuid::new_v4());
        assert_eq!(runtime.get_nodes().len(), 0);
    }

    #[test]
    fn test_node_registration() {
        let runtime = DistributedRuntime::new(uuid::Uuid::new_v4());

        let node = NodeInfo {
            id: uuid::Uuid::new_v4(),
            name: "node1".to_string(),
            address: "127.0.0.1:8080".to_string(),
            status: NodeStatus::Healthy,
            agents: vec![],
            cpu_usage: 0.5,
            memory_usage: 0.6,
            last_heartbeat: chrono::Utc::now().timestamp(),
        };

        runtime.register_node(node.clone()).unwrap();

        assert_eq!(runtime.get_nodes().len(), 1);
        assert_eq!(runtime.get_healthy_nodes().len(), 1);
    }

    #[test]
    fn test_load_balancing() {
        let runtime = DistributedRuntime::new(uuid::Uuid::new_v4());

        let node1 = NodeInfo {
            id: uuid::Uuid::new_v4(),
            name: "node1".to_string(),
            address: "127.0.0.1:8080".to_string(),
            status: NodeStatus::Healthy,
            agents: vec![],
            cpu_usage: 0.8, // High load
            memory_usage: 0.7,
            last_heartbeat: chrono::Utc::now().timestamp(),
        };

        let node2 = NodeInfo {
            id: uuid::Uuid::new_v4(),
            name: "node2".to_string(),
            address: "127.0.0.1:8081".to_string(),
            status: NodeStatus::Healthy,
            agents: vec![],
            cpu_usage: 0.3, // Low load
            memory_usage: 0.4,
            last_heartbeat: chrono::Utc::now().timestamp(),
        };

        runtime.register_node(node1).unwrap();
        runtime.register_node(node2.clone()).unwrap();

        // Should select node2 (lower load)
        let best = runtime.find_best_node().unwrap();
        assert_eq!(best, node2.id);
    }

    #[tokio::test]
    async fn test_cross_node_messaging() {
        let runtime = DistributedRuntime::new(uuid::Uuid::new_v4());

        let agent1 = uuid::Uuid::new_v4();
        let agent2 = uuid::Uuid::new_v4();

        let node1 = NodeInfo {
            id: uuid::Uuid::new_v4(),
            name: "node1".to_string(),
            address: "127.0.0.1:8080".to_string(),
            status: NodeStatus::Healthy,
            agents: vec![agent1],
            cpu_usage: 0.5,
            memory_usage: 0.5,
            last_heartbeat: chrono::Utc::now().timestamp(),
        };

        let node2 = NodeInfo {
            id: uuid::Uuid::new_v4(),
            name: "node2".to_string(),
            address: "127.0.0.1:8081".to_string(),
            status: NodeStatus::Healthy,
            agents: vec![agent2],
            cpu_usage: 0.5,
            memory_usage: 0.5,
            last_heartbeat: chrono::Utc::now().timestamp(),
        };

        runtime.register_node(node1).unwrap();
        runtime.register_node(node2).unwrap();

        // Send message from agent1 to agent2 (cross-node)
        let result = runtime
            .send_to_agent(
                agent1,
                agent2,
                serde_json::json!({"message": "Hello from another node!"}),
                "greeting".to_string(),
            )
            .await;

        assert!(result.is_ok());
    }
}