// scirs2_integrate/distributed/node.rs
1//! Node management for distributed computing
2//!
3//! This module provides abstractions for compute nodes in the distributed
4//! integration system, including node lifecycle management, resource tracking,
5//! and health monitoring.
6
7use crate::common::IntegrateFloat;
8use crate::distributed::types::{
9    ChunkResult, DistributedError, DistributedResult, NodeCapabilities, NodeId, NodeInfo,
10    NodeStatus, SimdCapability, WorkChunk,
11};
12use crate::error::IntegrateResult;
13use scirs2_core::ndarray::Array1;
14use std::collections::HashMap;
15use std::net::SocketAddr;
16use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
17use std::sync::{Arc, Mutex, RwLock};
18use std::thread;
19use std::time::{Duration, Instant};
20
21/// Manager for compute nodes in the distributed system
22pub struct NodeManager {
23    /// Registered nodes
24    nodes: RwLock<HashMap<NodeId, NodeInfo>>,
25    /// Next node ID to assign
26    next_node_id: AtomicU64,
27    /// Node health check timeout
28    health_check_timeout: Duration,
29    /// Shutdown flag
30    shutdown: AtomicBool,
31    /// Health monitor handle
32    health_monitor: Mutex<Option<thread::JoinHandle<()>>>,
33    /// Node failure callbacks
34    failure_callbacks: RwLock<Vec<Arc<dyn Fn(NodeId) + Send + Sync>>>,
35}
36
37impl NodeManager {
38    /// Create a new node manager
39    pub fn new(health_check_timeout: Duration) -> Self {
40        Self {
41            nodes: RwLock::new(HashMap::new()),
42            next_node_id: AtomicU64::new(1),
43            health_check_timeout,
44            shutdown: AtomicBool::new(false),
45            health_monitor: Mutex::new(None),
46            failure_callbacks: RwLock::new(Vec::new()),
47        }
48    }
49
50    /// Start the health monitoring background thread
51    pub fn start_health_monitoring(&self) -> IntegrateResult<()> {
52        let nodes = unsafe { &*(&self.nodes as *const RwLock<HashMap<NodeId, NodeInfo>>) };
53        let timeout = self.health_check_timeout;
54        let shutdown = unsafe { &*(&self.shutdown as *const AtomicBool) };
55        let callbacks = unsafe {
56            &*(&self.failure_callbacks as *const RwLock<Vec<Arc<dyn Fn(NodeId) + Send + Sync>>>)
57        };
58
59        // Create references for the thread
60        let nodes_ptr = nodes as *const RwLock<HashMap<NodeId, NodeInfo>> as usize;
61        let shutdown_ptr = shutdown as *const AtomicBool as usize;
62        let callbacks_ptr =
63            callbacks as *const RwLock<Vec<Arc<dyn Fn(NodeId) + Send + Sync>>> as usize;
64
65        let handle = thread::spawn(move || {
66            let nodes = unsafe { &*(nodes_ptr as *const RwLock<HashMap<NodeId, NodeInfo>>) };
67            let shutdown = unsafe { &*(shutdown_ptr as *const AtomicBool) };
68            let callbacks = unsafe {
69                &*(callbacks_ptr as *const RwLock<Vec<Arc<dyn Fn(NodeId) + Send + Sync>>>)
70            };
71
72            while !shutdown.load(Ordering::Relaxed) {
73                // Check node health
74                let failed_nodes = {
75                    let mut nodes_write = match nodes.write() {
76                        Ok(guard) => guard,
77                        Err(_) => continue,
78                    };
79
80                    let mut failed = Vec::new();
81                    for (id, info) in nodes_write.iter_mut() {
82                        if !info.is_healthy(timeout) && info.status != NodeStatus::Failed {
83                            info.status = NodeStatus::Failed;
84                            failed.push(*id);
85                        }
86                    }
87                    failed
88                };
89
90                // Invoke failure callbacks
91                if !failed_nodes.is_empty() {
92                    if let Ok(cbs) = callbacks.read() {
93                        for node_id in &failed_nodes {
94                            for cb in cbs.iter() {
95                                cb(*node_id);
96                            }
97                        }
98                    }
99                }
100
101                thread::sleep(Duration::from_secs(1));
102            }
103        });
104
105        if let Ok(mut monitor) = self.health_monitor.lock() {
106            *monitor = Some(handle);
107        }
108
109        Ok(())
110    }
111
112    /// Stop health monitoring
113    pub fn stop_health_monitoring(&self) {
114        self.shutdown.store(true, Ordering::Relaxed);
115        if let Ok(mut monitor) = self.health_monitor.lock() {
116            if let Some(handle) = monitor.take() {
117                let _ = handle.join();
118            }
119        }
120    }
121
122    /// Register a new node
123    pub fn register_node(
124        &self,
125        address: SocketAddr,
126        capabilities: NodeCapabilities,
127    ) -> DistributedResult<NodeId> {
128        let node_id = NodeId::new(self.next_node_id.fetch_add(1, Ordering::SeqCst));
129
130        let mut node_info = NodeInfo::new(node_id, address);
131        node_info.capabilities = capabilities;
132        node_info.status = NodeStatus::Available;
133
134        match self.nodes.write() {
135            Ok(mut nodes) => {
136                nodes.insert(node_id, node_info);
137                Ok(node_id)
138            }
139            Err(_) => Err(DistributedError::CommunicationError(
140                "Failed to acquire nodes lock".to_string(),
141            )),
142        }
143    }
144
145    /// Deregister a node
146    pub fn deregister_node(&self, node_id: NodeId) -> DistributedResult<()> {
147        match self.nodes.write() {
148            Ok(mut nodes) => {
149                nodes.remove(&node_id);
150                Ok(())
151            }
152            Err(_) => Err(DistributedError::CommunicationError(
153                "Failed to acquire nodes lock".to_string(),
154            )),
155        }
156    }
157
158    /// Update node heartbeat
159    pub fn update_heartbeat(&self, node_id: NodeId) -> DistributedResult<()> {
160        match self.nodes.write() {
161            Ok(mut nodes) => {
162                if let Some(node) = nodes.get_mut(&node_id) {
163                    node.last_heartbeat = Instant::now();
164                    if node.status == NodeStatus::Failed {
165                        node.status = NodeStatus::Available;
166                    }
167                    Ok(())
168                } else {
169                    Err(DistributedError::NodeFailure(
170                        node_id,
171                        "Node not found".to_string(),
172                    ))
173                }
174            }
175            Err(_) => Err(DistributedError::CommunicationError(
176                "Failed to acquire nodes lock".to_string(),
177            )),
178        }
179    }
180
181    /// Update node status
182    pub fn update_status(&self, node_id: NodeId, status: NodeStatus) -> DistributedResult<()> {
183        match self.nodes.write() {
184            Ok(mut nodes) => {
185                if let Some(node) = nodes.get_mut(&node_id) {
186                    node.status = status;
187                    Ok(())
188                } else {
189                    Err(DistributedError::NodeFailure(
190                        node_id,
191                        "Node not found".to_string(),
192                    ))
193                }
194            }
195            Err(_) => Err(DistributedError::CommunicationError(
196                "Failed to acquire nodes lock".to_string(),
197            )),
198        }
199    }
200
201    /// Get list of available nodes
202    pub fn get_available_nodes(&self) -> Vec<NodeInfo> {
203        match self.nodes.read() {
204            Ok(nodes) => nodes
205                .values()
206                .filter(|n| n.status == NodeStatus::Available)
207                .cloned()
208                .collect(),
209            Err(_) => Vec::new(),
210        }
211    }
212
213    /// Get all registered nodes
214    pub fn get_all_nodes(&self) -> Vec<NodeInfo> {
215        match self.nodes.read() {
216            Ok(nodes) => nodes.values().cloned().collect(),
217            Err(_) => Vec::new(),
218        }
219    }
220
221    /// Get node by ID
222    pub fn get_node(&self, node_id: NodeId) -> Option<NodeInfo> {
223        match self.nodes.read() {
224            Ok(nodes) => nodes.get(&node_id).cloned(),
225            Err(_) => None,
226        }
227    }
228
229    /// Get number of available nodes
230    pub fn available_node_count(&self) -> usize {
231        self.get_available_nodes().len()
232    }
233
234    /// Register a failure callback
235    pub fn on_node_failure<F>(&self, callback: F)
236    where
237        F: Fn(NodeId) + Send + Sync + 'static,
238    {
239        if let Ok(mut callbacks) = self.failure_callbacks.write() {
240            callbacks.push(Arc::new(callback));
241        }
242    }
243
244    /// Record job completion for a node
245    pub fn record_job_completion(
246        &self,
247        node_id: NodeId,
248        duration: Duration,
249    ) -> DistributedResult<()> {
250        match self.nodes.write() {
251            Ok(mut nodes) => {
252                if let Some(node) = nodes.get_mut(&node_id) {
253                    let total_time = node.average_job_duration * node.jobs_completed as u32;
254                    node.jobs_completed += 1;
255                    node.average_job_duration =
256                        (total_time + duration) / node.jobs_completed as u32;
257                    Ok(())
258                } else {
259                    Err(DistributedError::NodeFailure(
260                        node_id,
261                        "Node not found".to_string(),
262                    ))
263                }
264            }
265            Err(_) => Err(DistributedError::CommunicationError(
266                "Failed to acquire nodes lock".to_string(),
267            )),
268        }
269    }
270
271    /// Select best node for a given workload
272    pub fn select_best_node(&self, estimated_cost: f64) -> Option<NodeId> {
273        match self.nodes.read() {
274            Ok(nodes) => nodes
275                .values()
276                .filter(|n| n.status == NodeStatus::Available)
277                .max_by(|a, b| {
278                    a.processing_score()
279                        .partial_cmp(&b.processing_score())
280                        .unwrap_or(std::cmp::Ordering::Equal)
281                })
282                .map(|n| n.id),
283            Err(_) => None,
284        }
285    }
286}
287
288impl Drop for NodeManager {
289    fn drop(&mut self) {
290        self.stop_health_monitoring();
291    }
292}
293
/// A compute node that can process work chunks
///
/// Chunks are queued via `submit_work` and run synchronously through the
/// supplied solver function by `process_all`; results can also be buffered
/// in `results` until collected.
pub struct ComputeNode<F: IntegrateFloat> {
    /// Node information (ID, address, capabilities, status)
    info: NodeInfo,
    /// Current work queue of chunks awaiting processing
    work_queue: Mutex<Vec<WorkChunk<F>>>,
    /// Results buffer holding completed results until `collect_results`
    results: Mutex<Vec<ChunkResult<F>>>,
    /// Processing thread handles
    /// NOTE(review): no method in this file's visible impl spawns workers —
    /// confirm intended use
    workers: Mutex<Vec<thread::JoinHandle<()>>>,
    /// Shutdown flag, set by `shutdown()`; `Arc` so worker threads could
    /// share it
    shutdown: Arc<AtomicBool>,
    /// ODE solver function applied to each work chunk; must be `Send + Sync`
    solver_fn: Arc<dyn Fn(&WorkChunk<F>) -> IntegrateResult<ChunkResult<F>> + Send + Sync>,
}
309
310impl<F: IntegrateFloat> ComputeNode<F> {
311    /// Create a new compute node
312    pub fn new<S>(info: NodeInfo, solver_fn: S) -> Self
313    where
314        S: Fn(&WorkChunk<F>) -> IntegrateResult<ChunkResult<F>> + Send + Sync + 'static,
315    {
316        Self {
317            info,
318            work_queue: Mutex::new(Vec::new()),
319            results: Mutex::new(Vec::new()),
320            workers: Mutex::new(Vec::new()),
321            shutdown: Arc::new(AtomicBool::new(false)),
322            solver_fn: Arc::new(solver_fn),
323        }
324    }
325
326    /// Get node ID
327    pub fn id(&self) -> NodeId {
328        self.info.id
329    }
330
331    /// Get node status
332    pub fn status(&self) -> NodeStatus {
333        self.info.status
334    }
335
336    /// Submit a work chunk
337    pub fn submit_work(&self, chunk: WorkChunk<F>) -> DistributedResult<()> {
338        match self.work_queue.lock() {
339            Ok(mut queue) => {
340                queue.push(chunk);
341                Ok(())
342            }
343            Err(_) => Err(DistributedError::ResourceExhausted(
344                "Failed to acquire work queue lock".to_string(),
345            )),
346        }
347    }
348
349    /// Process all queued work
350    pub fn process_all(&self) -> DistributedResult<Vec<ChunkResult<F>>> {
351        let chunks = {
352            match self.work_queue.lock() {
353                Ok(mut queue) => std::mem::take(&mut *queue),
354                Err(_) => {
355                    return Err(DistributedError::ResourceExhausted(
356                        "Failed to acquire work queue lock".to_string(),
357                    ))
358                }
359            }
360        };
361
362        let mut results = Vec::with_capacity(chunks.len());
363        for chunk in chunks {
364            match (self.solver_fn)(&chunk) {
365                Ok(result) => results.push(result),
366                Err(e) => {
367                    return Err(DistributedError::ChunkError(
368                        chunk.id,
369                        format!("Solver error: {}", e),
370                    ))
371                }
372            }
373        }
374
375        Ok(results)
376    }
377
378    /// Get pending work count
379    pub fn pending_work_count(&self) -> usize {
380        match self.work_queue.lock() {
381            Ok(queue) => queue.len(),
382            Err(_) => 0,
383        }
384    }
385
386    /// Collect completed results
387    pub fn collect_results(&self) -> Vec<ChunkResult<F>> {
388        match self.results.lock() {
389            Ok(mut results) => std::mem::take(&mut *results),
390            Err(_) => Vec::new(),
391        }
392    }
393
394    /// Shutdown the node
395    pub fn shutdown(&self) {
396        self.shutdown.store(true, Ordering::Relaxed);
397    }
398}
399
/// Builder for creating compute nodes with detected capabilities
///
/// Capabilities may be supplied explicitly via `with_capabilities` or
/// auto-detected; if neither is done, `build` falls back to detection.
pub struct NodeBuilder {
    /// Network address the node will be reachable at
    address: SocketAddr,
    /// Explicit capabilities, if set; `None` means detect at build time
    capabilities: Option<NodeCapabilities>,
}
405
406impl NodeBuilder {
407    /// Create a new node builder
408    pub fn new(address: SocketAddr) -> Self {
409        Self {
410            address,
411            capabilities: None,
412        }
413    }
414
415    /// Set custom capabilities
416    pub fn with_capabilities(mut self, capabilities: NodeCapabilities) -> Self {
417        self.capabilities = Some(capabilities);
418        self
419    }
420
421    /// Auto-detect capabilities
422    pub fn detect_capabilities(mut self) -> Self {
423        self.capabilities = Some(Self::detect_system_capabilities());
424        self
425    }
426
427    /// Detect system capabilities
428    fn detect_system_capabilities() -> NodeCapabilities {
429        let cpu_cores = thread::available_parallelism()
430            .map(|n| n.get())
431            .unwrap_or(1);
432
433        // Estimate available memory (simplified)
434        #[cfg(target_pointer_width = "32")]
435        let memory_bytes = 512 * 1024 * 1024; // 512MB default for 32-bit
436        #[cfg(target_pointer_width = "64")]
437        let memory_bytes = 8usize * 1024 * 1024 * 1024; // 8GB default for 64-bit
438
439        // Detect SIMD capabilities
440        let simd_capabilities = Self::detect_simd();
441
442        NodeCapabilities {
443            cpu_cores,
444            memory_bytes,
445            has_gpu: false, // Would need GPU detection library
446            gpu_memory_bytes: None,
447            network_bandwidth: 1024 * 1024 * 1024, // 1 Gbps
448            latency_us: 100,
449            supported_precisions: vec![
450                crate::distributed::types::FloatPrecision::F32,
451                crate::distributed::types::FloatPrecision::F64,
452            ],
453            simd_capabilities,
454        }
455    }
456
457    /// Detect SIMD capabilities
458    fn detect_simd() -> SimdCapability {
459        SimdCapability {
460            has_sse: cfg!(target_feature = "sse"),
461            has_sse2: cfg!(target_feature = "sse2"),
462            has_avx: cfg!(target_feature = "avx"),
463            has_avx2: cfg!(target_feature = "avx2"),
464            has_avx512: cfg!(target_feature = "avx512f"),
465            has_neon: cfg!(target_feature = "neon"),
466        }
467    }
468
469    /// Build the node info
470    pub fn build(self, node_id: NodeId) -> NodeInfo {
471        let capabilities = self
472            .capabilities
473            .unwrap_or_else(Self::detect_system_capabilities);
474        let mut info = NodeInfo::new(node_id, self.address);
475        info.capabilities = capabilities;
476        info.status = NodeStatus::Available;
477        info
478    }
479}
480
/// Resource monitor for tracking node resource usage
///
/// Values are plain snapshot fields; `update` currently only refreshes the
/// timestamp (no live system probing).
#[derive(Debug, Clone)]
pub struct ResourceMonitor {
    /// CPU usage fraction (0.0 to 1.0)
    pub cpu_usage: f64,
    /// Memory usage fraction (0.0 to 1.0)
    pub memory_usage: f64,
    /// Network usage (bytes/sec)
    pub network_usage: usize,
    /// GPU usage fraction (0.0 to 1.0), if a GPU is available
    pub gpu_usage: Option<f64>,
    /// Time of the most recent `update` call
    pub last_update: Instant,
}
495
496impl Default for ResourceMonitor {
497    fn default() -> Self {
498        Self {
499            cpu_usage: 0.0,
500            memory_usage: 0.0,
501            network_usage: 0,
502            gpu_usage: None,
503            last_update: Instant::now(),
504        }
505    }
506}
507
508impl ResourceMonitor {
509    /// Update resource usage (simplified implementation)
510    pub fn update(&mut self) {
511        // In a real implementation, this would query system resources
512        self.last_update = Instant::now();
513    }
514
515    /// Check if resources are available
516    pub fn has_available_resources(&self, required_memory_fraction: f64) -> bool {
517        self.memory_usage + required_memory_fraction <= 1.0
518    }
519}
520
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};

    /// Loopback address shared by every test case.
    fn test_address() -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080)
    }

    #[test]
    fn test_node_manager_registration() {
        let manager = NodeManager::new(Duration::from_secs(30));

        let id = manager
            .register_node(test_address(), NodeCapabilities::default())
            .expect("Failed to register node");

        assert_eq!(manager.available_node_count(), 1);

        let fetched = manager.get_node(id);
        assert!(fetched.is_some());
        assert_eq!(fetched.map(|n| n.id), Some(id));
    }

    #[test]
    fn test_node_manager_deregistration() {
        let manager = NodeManager::new(Duration::from_secs(30));

        let id = manager
            .register_node(test_address(), NodeCapabilities::default())
            .expect("Failed to register node");
        assert_eq!(manager.available_node_count(), 1);

        manager
            .deregister_node(id)
            .expect("Failed to deregister node");
        assert_eq!(manager.available_node_count(), 0);
    }

    #[test]
    fn test_node_manager_heartbeat() {
        let manager = NodeManager::new(Duration::from_secs(30));

        let id = manager
            .register_node(test_address(), NodeCapabilities::default())
            .expect("Failed to register node");

        manager
            .update_heartbeat(id)
            .expect("Failed to update heartbeat");

        let node = manager.get_node(id).expect("Node not found");
        assert!(node.is_healthy(Duration::from_secs(60)));
    }

    #[test]
    fn test_node_builder() {
        let addr = test_address();
        let built = NodeBuilder::new(addr)
            .detect_capabilities()
            .build(NodeId::new(1));

        assert_eq!(built.id, NodeId::new(1));
        assert_eq!(built.address, addr);
        assert!(built.capabilities.cpu_cores > 0);
    }

    #[test]
    fn test_resource_monitor() {
        let mut monitor = ResourceMonitor::default();
        assert!(monitor.has_available_resources(0.5));

        monitor.memory_usage = 0.8;
        assert!(!monitor.has_available_resources(0.3));
    }
}