// scirs2_integrate/distributed/node.rs

use crate::common::IntegrateFloat;
use crate::distributed::types::{
    ChunkResult, DistributedError, DistributedResult, NodeCapabilities, NodeId, NodeInfo,
    NodeStatus, SimdCapability, WorkChunk,
};
use crate::error::IntegrateResult;
use scirs2_core::ndarray::Array1;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Mutex, RwLock};
use std::thread;
use std::time::{Duration, Instant};
20
21pub struct NodeManager {
23 nodes: RwLock<HashMap<NodeId, NodeInfo>>,
25 next_node_id: AtomicU64,
27 health_check_timeout: Duration,
29 shutdown: AtomicBool,
31 health_monitor: Mutex<Option<thread::JoinHandle<()>>>,
33 failure_callbacks: RwLock<Vec<Arc<dyn Fn(NodeId) + Send + Sync>>>,
35}
36
37impl NodeManager {
38 pub fn new(health_check_timeout: Duration) -> Self {
40 Self {
41 nodes: RwLock::new(HashMap::new()),
42 next_node_id: AtomicU64::new(1),
43 health_check_timeout,
44 shutdown: AtomicBool::new(false),
45 health_monitor: Mutex::new(None),
46 failure_callbacks: RwLock::new(Vec::new()),
47 }
48 }
49
50 pub fn start_health_monitoring(&self) -> IntegrateResult<()> {
52 let nodes = unsafe { &*(&self.nodes as *const RwLock<HashMap<NodeId, NodeInfo>>) };
53 let timeout = self.health_check_timeout;
54 let shutdown = unsafe { &*(&self.shutdown as *const AtomicBool) };
55 let callbacks = unsafe {
56 &*(&self.failure_callbacks as *const RwLock<Vec<Arc<dyn Fn(NodeId) + Send + Sync>>>)
57 };
58
59 let nodes_ptr = nodes as *const RwLock<HashMap<NodeId, NodeInfo>> as usize;
61 let shutdown_ptr = shutdown as *const AtomicBool as usize;
62 let callbacks_ptr =
63 callbacks as *const RwLock<Vec<Arc<dyn Fn(NodeId) + Send + Sync>>> as usize;
64
65 let handle = thread::spawn(move || {
66 let nodes = unsafe { &*(nodes_ptr as *const RwLock<HashMap<NodeId, NodeInfo>>) };
67 let shutdown = unsafe { &*(shutdown_ptr as *const AtomicBool) };
68 let callbacks = unsafe {
69 &*(callbacks_ptr as *const RwLock<Vec<Arc<dyn Fn(NodeId) + Send + Sync>>>)
70 };
71
72 while !shutdown.load(Ordering::Relaxed) {
73 let failed_nodes = {
75 let mut nodes_write = match nodes.write() {
76 Ok(guard) => guard,
77 Err(_) => continue,
78 };
79
80 let mut failed = Vec::new();
81 for (id, info) in nodes_write.iter_mut() {
82 if !info.is_healthy(timeout) && info.status != NodeStatus::Failed {
83 info.status = NodeStatus::Failed;
84 failed.push(*id);
85 }
86 }
87 failed
88 };
89
90 if !failed_nodes.is_empty() {
92 if let Ok(cbs) = callbacks.read() {
93 for node_id in &failed_nodes {
94 for cb in cbs.iter() {
95 cb(*node_id);
96 }
97 }
98 }
99 }
100
101 thread::sleep(Duration::from_secs(1));
102 }
103 });
104
105 if let Ok(mut monitor) = self.health_monitor.lock() {
106 *monitor = Some(handle);
107 }
108
109 Ok(())
110 }
111
112 pub fn stop_health_monitoring(&self) {
114 self.shutdown.store(true, Ordering::Relaxed);
115 if let Ok(mut monitor) = self.health_monitor.lock() {
116 if let Some(handle) = monitor.take() {
117 let _ = handle.join();
118 }
119 }
120 }
121
122 pub fn register_node(
124 &self,
125 address: SocketAddr,
126 capabilities: NodeCapabilities,
127 ) -> DistributedResult<NodeId> {
128 let node_id = NodeId::new(self.next_node_id.fetch_add(1, Ordering::SeqCst));
129
130 let mut node_info = NodeInfo::new(node_id, address);
131 node_info.capabilities = capabilities;
132 node_info.status = NodeStatus::Available;
133
134 match self.nodes.write() {
135 Ok(mut nodes) => {
136 nodes.insert(node_id, node_info);
137 Ok(node_id)
138 }
139 Err(_) => Err(DistributedError::CommunicationError(
140 "Failed to acquire nodes lock".to_string(),
141 )),
142 }
143 }
144
145 pub fn deregister_node(&self, node_id: NodeId) -> DistributedResult<()> {
147 match self.nodes.write() {
148 Ok(mut nodes) => {
149 nodes.remove(&node_id);
150 Ok(())
151 }
152 Err(_) => Err(DistributedError::CommunicationError(
153 "Failed to acquire nodes lock".to_string(),
154 )),
155 }
156 }
157
158 pub fn update_heartbeat(&self, node_id: NodeId) -> DistributedResult<()> {
160 match self.nodes.write() {
161 Ok(mut nodes) => {
162 if let Some(node) = nodes.get_mut(&node_id) {
163 node.last_heartbeat = Instant::now();
164 if node.status == NodeStatus::Failed {
165 node.status = NodeStatus::Available;
166 }
167 Ok(())
168 } else {
169 Err(DistributedError::NodeFailure(
170 node_id,
171 "Node not found".to_string(),
172 ))
173 }
174 }
175 Err(_) => Err(DistributedError::CommunicationError(
176 "Failed to acquire nodes lock".to_string(),
177 )),
178 }
179 }
180
181 pub fn update_status(&self, node_id: NodeId, status: NodeStatus) -> DistributedResult<()> {
183 match self.nodes.write() {
184 Ok(mut nodes) => {
185 if let Some(node) = nodes.get_mut(&node_id) {
186 node.status = status;
187 Ok(())
188 } else {
189 Err(DistributedError::NodeFailure(
190 node_id,
191 "Node not found".to_string(),
192 ))
193 }
194 }
195 Err(_) => Err(DistributedError::CommunicationError(
196 "Failed to acquire nodes lock".to_string(),
197 )),
198 }
199 }
200
201 pub fn get_available_nodes(&self) -> Vec<NodeInfo> {
203 match self.nodes.read() {
204 Ok(nodes) => nodes
205 .values()
206 .filter(|n| n.status == NodeStatus::Available)
207 .cloned()
208 .collect(),
209 Err(_) => Vec::new(),
210 }
211 }
212
213 pub fn get_all_nodes(&self) -> Vec<NodeInfo> {
215 match self.nodes.read() {
216 Ok(nodes) => nodes.values().cloned().collect(),
217 Err(_) => Vec::new(),
218 }
219 }
220
221 pub fn get_node(&self, node_id: NodeId) -> Option<NodeInfo> {
223 match self.nodes.read() {
224 Ok(nodes) => nodes.get(&node_id).cloned(),
225 Err(_) => None,
226 }
227 }
228
229 pub fn available_node_count(&self) -> usize {
231 self.get_available_nodes().len()
232 }
233
234 pub fn on_node_failure<F>(&self, callback: F)
236 where
237 F: Fn(NodeId) + Send + Sync + 'static,
238 {
239 if let Ok(mut callbacks) = self.failure_callbacks.write() {
240 callbacks.push(Arc::new(callback));
241 }
242 }
243
244 pub fn record_job_completion(
246 &self,
247 node_id: NodeId,
248 duration: Duration,
249 ) -> DistributedResult<()> {
250 match self.nodes.write() {
251 Ok(mut nodes) => {
252 if let Some(node) = nodes.get_mut(&node_id) {
253 let total_time = node.average_job_duration * node.jobs_completed as u32;
254 node.jobs_completed += 1;
255 node.average_job_duration =
256 (total_time + duration) / node.jobs_completed as u32;
257 Ok(())
258 } else {
259 Err(DistributedError::NodeFailure(
260 node_id,
261 "Node not found".to_string(),
262 ))
263 }
264 }
265 Err(_) => Err(DistributedError::CommunicationError(
266 "Failed to acquire nodes lock".to_string(),
267 )),
268 }
269 }
270
271 pub fn select_best_node(&self, estimated_cost: f64) -> Option<NodeId> {
273 match self.nodes.read() {
274 Ok(nodes) => nodes
275 .values()
276 .filter(|n| n.status == NodeStatus::Available)
277 .max_by(|a, b| {
278 a.processing_score()
279 .partial_cmp(&b.processing_score())
280 .unwrap_or(std::cmp::Ordering::Equal)
281 })
282 .map(|n| n.id),
283 Err(_) => None,
284 }
285 }
286}
287
288impl Drop for NodeManager {
289 fn drop(&mut self) {
290 self.stop_health_monitoring();
291 }
292}
293
/// A worker that executes integration work chunks locally.
///
/// Chunks are queued via `submit_work` and drained synchronously by
/// `process_all`; `solver_fn` performs the actual per-chunk computation.
pub struct ComputeNode<F: IntegrateFloat> {
    /// Identity, address, status, and capabilities of this node.
    info: NodeInfo,
    /// Chunks waiting to be processed.
    work_queue: Mutex<Vec<WorkChunk<F>>>,
    /// Completed results awaiting collection via `collect_results`.
    /// NOTE(review): nothing in this file pushes into it — `process_all`
    /// returns results directly; confirm whether a writer exists elsewhere.
    results: Mutex<Vec<ChunkResult<F>>>,
    /// Worker thread handles. NOTE(review): never populated in this file —
    /// presumably reserved for future asynchronous processing; confirm.
    workers: Mutex<Vec<thread::JoinHandle<()>>>,
    /// Cooperative stop flag set by `shutdown()`; no reader in this file.
    shutdown: Arc<AtomicBool>,
    /// User-supplied solver invoked once per chunk.
    solver_fn: Arc<dyn Fn(&WorkChunk<F>) -> IntegrateResult<ChunkResult<F>> + Send + Sync>,
}
309
310impl<F: IntegrateFloat> ComputeNode<F> {
311 pub fn new<S>(info: NodeInfo, solver_fn: S) -> Self
313 where
314 S: Fn(&WorkChunk<F>) -> IntegrateResult<ChunkResult<F>> + Send + Sync + 'static,
315 {
316 Self {
317 info,
318 work_queue: Mutex::new(Vec::new()),
319 results: Mutex::new(Vec::new()),
320 workers: Mutex::new(Vec::new()),
321 shutdown: Arc::new(AtomicBool::new(false)),
322 solver_fn: Arc::new(solver_fn),
323 }
324 }
325
326 pub fn id(&self) -> NodeId {
328 self.info.id
329 }
330
331 pub fn status(&self) -> NodeStatus {
333 self.info.status
334 }
335
336 pub fn submit_work(&self, chunk: WorkChunk<F>) -> DistributedResult<()> {
338 match self.work_queue.lock() {
339 Ok(mut queue) => {
340 queue.push(chunk);
341 Ok(())
342 }
343 Err(_) => Err(DistributedError::ResourceExhausted(
344 "Failed to acquire work queue lock".to_string(),
345 )),
346 }
347 }
348
349 pub fn process_all(&self) -> DistributedResult<Vec<ChunkResult<F>>> {
351 let chunks = {
352 match self.work_queue.lock() {
353 Ok(mut queue) => std::mem::take(&mut *queue),
354 Err(_) => {
355 return Err(DistributedError::ResourceExhausted(
356 "Failed to acquire work queue lock".to_string(),
357 ))
358 }
359 }
360 };
361
362 let mut results = Vec::with_capacity(chunks.len());
363 for chunk in chunks {
364 match (self.solver_fn)(&chunk) {
365 Ok(result) => results.push(result),
366 Err(e) => {
367 return Err(DistributedError::ChunkError(
368 chunk.id,
369 format!("Solver error: {}", e),
370 ))
371 }
372 }
373 }
374
375 Ok(results)
376 }
377
378 pub fn pending_work_count(&self) -> usize {
380 match self.work_queue.lock() {
381 Ok(queue) => queue.len(),
382 Err(_) => 0,
383 }
384 }
385
386 pub fn collect_results(&self) -> Vec<ChunkResult<F>> {
388 match self.results.lock() {
389 Ok(mut results) => std::mem::take(&mut *results),
390 Err(_) => Vec::new(),
391 }
392 }
393
394 pub fn shutdown(&self) {
396 self.shutdown.store(true, Ordering::Relaxed);
397 }
398}
399
/// Fluent builder producing a `NodeInfo` description of a compute node.
pub struct NodeBuilder {
    /// Network address the node is reachable at.
    address: SocketAddr,
    /// Explicit capabilities; when `None`, `build` auto-detects them.
    capabilities: Option<NodeCapabilities>,
}
405
406impl NodeBuilder {
407 pub fn new(address: SocketAddr) -> Self {
409 Self {
410 address,
411 capabilities: None,
412 }
413 }
414
415 pub fn with_capabilities(mut self, capabilities: NodeCapabilities) -> Self {
417 self.capabilities = Some(capabilities);
418 self
419 }
420
421 pub fn detect_capabilities(mut self) -> Self {
423 self.capabilities = Some(Self::detect_system_capabilities());
424 self
425 }
426
427 fn detect_system_capabilities() -> NodeCapabilities {
429 let cpu_cores = thread::available_parallelism()
430 .map(|n| n.get())
431 .unwrap_or(1);
432
433 #[cfg(target_pointer_width = "32")]
435 let memory_bytes = 512 * 1024 * 1024; #[cfg(target_pointer_width = "64")]
437 let memory_bytes = 8usize * 1024 * 1024 * 1024; let simd_capabilities = Self::detect_simd();
441
442 NodeCapabilities {
443 cpu_cores,
444 memory_bytes,
445 has_gpu: false, gpu_memory_bytes: None,
447 network_bandwidth: 1024 * 1024 * 1024, latency_us: 100,
449 supported_precisions: vec![
450 crate::distributed::types::FloatPrecision::F32,
451 crate::distributed::types::FloatPrecision::F64,
452 ],
453 simd_capabilities,
454 }
455 }
456
457 fn detect_simd() -> SimdCapability {
459 SimdCapability {
460 has_sse: cfg!(target_feature = "sse"),
461 has_sse2: cfg!(target_feature = "sse2"),
462 has_avx: cfg!(target_feature = "avx"),
463 has_avx2: cfg!(target_feature = "avx2"),
464 has_avx512: cfg!(target_feature = "avx512f"),
465 has_neon: cfg!(target_feature = "neon"),
466 }
467 }
468
469 pub fn build(self, node_id: NodeId) -> NodeInfo {
471 let capabilities = self
472 .capabilities
473 .unwrap_or_else(Self::detect_system_capabilities);
474 let mut info = NodeInfo::new(node_id, self.address);
475 info.capabilities = capabilities;
476 info.status = NodeStatus::Available;
477 info
478 }
479}
480
/// Point-in-time snapshot of a node's resource utilization.
#[derive(Debug, Clone)]
pub struct ResourceMonitor {
    /// CPU utilization as a fraction (0.0 = idle, 1.0 = saturated).
    pub cpu_usage: f64,
    /// Memory utilization as a fraction (0.0 = idle, 1.0 = full).
    pub memory_usage: f64,
    /// Network throughput currently in use — presumably bytes/sec; confirm units.
    pub network_usage: usize,
    /// GPU utilization fraction, when a GPU is present.
    pub gpu_usage: Option<f64>,
    /// When these figures were last refreshed.
    pub last_update: Instant,
}

impl Default for ResourceMonitor {
    /// A fully idle monitor stamped with the current time.
    fn default() -> Self {
        Self {
            last_update: Instant::now(),
            cpu_usage: 0.0,
            memory_usage: 0.0,
            network_usage: 0,
            gpu_usage: None,
        }
    }
}

impl ResourceMonitor {
    /// Refreshes the sample timestamp.
    ///
    /// NOTE(review): this does not re-measure any usage figure — it only
    /// stamps `last_update`; sampling presumably happens elsewhere. Confirm.
    pub fn update(&mut self) {
        self.last_update = Instant::now();
    }

    /// Whether `required_memory_fraction` more memory still fits under 100%.
    pub fn has_available_resources(&self, required_memory_fraction: f64) -> bool {
        self.memory_usage + required_memory_fraction <= 1.0
    }
}
520
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};

    /// Loopback socket address shared by every test below.
    fn test_address() -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080)
    }

    #[test]
    fn test_node_manager_registration() {
        let manager = NodeManager::new(Duration::from_secs(30));

        let node_id = manager
            .register_node(test_address(), NodeCapabilities::default())
            .expect("Failed to register node");

        // A freshly registered node is immediately available and retrievable.
        assert_eq!(manager.available_node_count(), 1);
        let fetched = manager.get_node(node_id);
        assert!(fetched.is_some());
        assert_eq!(fetched.map(|n| n.id), Some(node_id));
    }

    #[test]
    fn test_node_manager_deregistration() {
        let manager = NodeManager::new(Duration::from_secs(30));

        let node_id = manager
            .register_node(test_address(), NodeCapabilities::default())
            .expect("Failed to register node");
        assert_eq!(manager.available_node_count(), 1);

        // Deregistering removes the node from the available pool.
        manager
            .deregister_node(node_id)
            .expect("Failed to deregister node");
        assert_eq!(manager.available_node_count(), 0);
    }

    #[test]
    fn test_node_manager_heartbeat() {
        let manager = NodeManager::new(Duration::from_secs(30));

        let node_id = manager
            .register_node(test_address(), NodeCapabilities::default())
            .expect("Failed to register node");

        manager
            .update_heartbeat(node_id)
            .expect("Failed to update heartbeat");

        // A just-heartbeaten node is healthy under a generous timeout.
        let node = manager.get_node(node_id).expect("Node not found");
        assert!(node.is_healthy(Duration::from_secs(60)));
    }

    #[test]
    fn test_node_builder() {
        let addr = test_address();
        let built = NodeBuilder::new(addr)
            .detect_capabilities()
            .build(NodeId::new(1));

        assert_eq!(built.id, NodeId::new(1));
        assert_eq!(built.address, addr);
        // Detection must find at least one CPU core.
        assert!(built.capabilities.cpu_cores > 0);
    }

    #[test]
    fn test_resource_monitor() {
        let mut monitor = ResourceMonitor::default();
        assert!(monitor.has_available_resources(0.5));

        // 0.8 used + 0.3 requested exceeds capacity.
        monitor.memory_usage = 0.8;
        assert!(!monitor.has_available_resources(0.3));
    }
}