use super::cluster::{ComputeCapacity, DistributedTask, NodeInfo, ResourceRequirements, TaskId};
use crate::error::{CoreError, CoreResult, ErrorContext};
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap, VecDeque};
use std::sync::{Arc, Mutex, RwLock};
use std::time::{Duration, Instant};

static GLOBAL_SCHEDULER: std::sync::OnceLock<Arc<DistributedScheduler>> =
    std::sync::OnceLock::new();

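/// Distributed task scheduler: queues submitted tasks, assigns them to cluster
/// nodes according to the configured scheduling policy, and tracks active,
/// completed, and failed executions.
///
/// Illustrative sketch (assumes a `DistributedTask` value and a slice of
/// available `NodeInfo` built elsewhere):
///
/// ```ignore
/// let scheduler = DistributedScheduler::global()?;
/// let task_id = scheduler.submit_task(task)?;
/// let assignments = scheduler.schedule_next(&available_nodes)?;
/// ```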
#[derive(Debug)]
pub struct DistributedScheduler {
    task_queue: Arc<Mutex<TaskQueue>>,
    execution_tracker: Arc<RwLock<ExecutionTracker>>,
    schedulingpolicies: Arc<RwLock<SchedulingPolicies>>,
    load_balancer: Arc<RwLock<LoadBalancer>>,
}

impl DistributedScheduler {
    pub fn new() -> CoreResult<Self> {
        Ok(Self {
            task_queue: Arc::new(Mutex::new(TaskQueue::new())),
            execution_tracker: Arc::new(RwLock::new(ExecutionTracker::new())),
            schedulingpolicies: Arc::new(RwLock::new(SchedulingPolicies::default())),
            load_balancer: Arc::new(RwLock::new(LoadBalancer::new())),
        })
    }

    pub fn global() -> CoreResult<Arc<Self>> {
        // `new()` is currently infallible, so initialization cannot panic in practice.
        Ok(GLOBAL_SCHEDULER
            .get_or_init(|| Arc::new(Self::new().expect("scheduler construction failed")))
            .clone())
    }

    pub fn submit_task(&self, task: DistributedTask) -> CoreResult<TaskId> {
        let mut queue = self.task_queue.lock().map_err(|_| {
            CoreError::InvalidState(ErrorContext::new("Failed to acquire task queue lock"))
        })?;

        let taskid = task.taskid.clone();
        queue.enqueue(task)?;

        Ok(taskid)
    }

    pub fn get_pending_task_count(&self) -> CoreResult<usize> {
        let queue = self.task_queue.lock().map_err(|_| {
            CoreError::InvalidState(ErrorContext::new("Failed to acquire task queue lock"))
        })?;

        Ok(queue.size())
    }

    pub fn schedule_next(&self, availablenodes: &[NodeInfo]) -> CoreResult<Vec<TaskAssignment>> {
        let mut queue = self.task_queue.lock().map_err(|_| {
            CoreError::InvalidState(ErrorContext::new("Failed to acquire task queue lock"))
        })?;

        let policies = self.schedulingpolicies.read().map_err(|_| {
            CoreError::InvalidState(ErrorContext::new("Failed to acquire policies lock"))
        })?;

        let mut load_balancer = self.load_balancer.write().map_err(|_| {
            CoreError::InvalidState(ErrorContext::new("Failed to acquire load balancer lock"))
        })?;

        let assignments = match policies.scheduling_algorithm {
            SchedulingAlgorithm::FirstComeFirstServe => {
                self.schedule_fcfs(&mut queue, availablenodes, &mut load_balancer)?
            }
            SchedulingAlgorithm::PriorityBased => {
                self.schedule_priority(&mut queue, availablenodes, &mut load_balancer)?
            }
            SchedulingAlgorithm::LoadBalanced => {
                self.schedule_load_balanced(&mut queue, availablenodes, &mut load_balancer)?
            }
            SchedulingAlgorithm::ResourceAware => {
                self.schedule_resource_aware(&mut queue, availablenodes, &mut load_balancer)?
            }
        };

        let mut tracker = self.execution_tracker.write().map_err(|_| {
            CoreError::InvalidState(ErrorContext::new(
                "Failed to acquire execution tracker lock",
            ))
        })?;

        for assignment in &assignments {
            tracker.track_assignment(assignment.clone())?;
        }

        Ok(assignments)
    }

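    /// First-come-first-serve: drain tasks in arrival order and assign each to
    /// the node chosen by the load balancer, stopping once a batch of 10 has
    /// been assigned or no node can accept the next task (which is re-queued).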
    fn schedule_fcfs(
        &self,
        queue: &mut TaskQueue,
        availablenodes: &[NodeInfo],
        load_balancer: &mut LoadBalancer,
    ) -> CoreResult<Vec<TaskAssignment>> {
        let mut assignments = Vec::new();

        while let Some(task) = queue.dequeue_next() {
            if let Some(node) = load_balancer.select_node_for_task(&task, availablenodes)? {
                assignments.push(TaskAssignment {
                    taskid: task.taskid.clone(),
                    nodeid: node.id.clone(),
                    assigned_at: Instant::now(),
                    estimated_duration: self.estimate_task_duration(&task, &node)?,
                });

                if assignments.len() >= 10 {
                    break;
                }
            } else {
                queue.enqueue(task)?;
                break;
            }
        }

        Ok(assignments)
    }

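    /// Priority-based scheduling: repeatedly take the highest-priority task and
    /// assign it via the load balancer; tasks that cannot be placed are set
    /// aside and re-queued after the pass so the rest of the heap is still considered.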
    fn schedule_priority(
        &self,
        queue: &mut TaskQueue,
        availablenodes: &[NodeInfo],
        load_balancer: &mut LoadBalancer,
    ) -> CoreResult<Vec<TaskAssignment>> {
        let mut assignments = Vec::new();
        let mut unscheduled_tasks = Vec::new();

        while let Some(task) = queue.dequeue_highest_priority() {
            if let Some(node) = load_balancer.select_node_for_task(&task, availablenodes)? {
                assignments.push(TaskAssignment {
                    taskid: task.taskid.clone(),
                    nodeid: node.id.clone(),
                    assigned_at: Instant::now(),
                    estimated_duration: self.estimate_task_duration(&task, &node)?,
                });

                if assignments.len() >= 10 {
                    break;
                }
            } else {
                unscheduled_tasks.push(task);
            }
        }

        for task in unscheduled_tasks {
            queue.enqueue(task)?;
        }

        Ok(assignments)
    }

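    /// Load-balanced scheduling: refresh per-node load entries, then place each
    /// task on the least-loaded node, recording the assignment so later picks
    /// in the same pass see the updated load.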
    fn schedule_load_balanced(
        &self,
        queue: &mut TaskQueue,
        availablenodes: &[NodeInfo],
        load_balancer: &mut LoadBalancer,
    ) -> CoreResult<Vec<TaskAssignment>> {
        let mut assignments = Vec::new();

        load_balancer.update_nodeloads(availablenodes)?;

        while let Some(task) = queue.dequeue_next() {
            if let Some(node) = load_balancer.select_least_loaded_node(&task, availablenodes)? {
                assignments.push(TaskAssignment {
                    taskid: task.taskid.clone(),
                    nodeid: node.id.clone(),
                    assigned_at: Instant::now(),
                    estimated_duration: self.estimate_task_duration(&task, &node)?,
                });

                load_balancer.record_assignment(&node.id, &task)?;

                if assignments.len() >= 10 {
                    break;
                }
            } else {
                queue.enqueue(task)?;
                break;
            }
        }

        Ok(assignments)
    }

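    /// Resource-aware scheduling: place each task on the node whose available
    /// capacity satisfies its requirements with the least leftover (best fit).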
    fn schedule_resource_aware(
        &self,
        queue: &mut TaskQueue,
        availablenodes: &[NodeInfo],
        load_balancer: &mut LoadBalancer,
    ) -> CoreResult<Vec<TaskAssignment>> {
        let mut assignments = Vec::new();

        while let Some(task) = queue.dequeue_next() {
            if let Some(node) = load_balancer.select_best_fit_node(&task, availablenodes)? {
                assignments.push(TaskAssignment {
                    taskid: task.taskid.clone(),
                    nodeid: node.id.clone(),
                    assigned_at: Instant::now(),
                    estimated_duration: self.estimate_task_duration(&task, &node)?,
                });

                if assignments.len() >= 10 {
                    break;
                }
            } else {
                queue.enqueue(task)?;
                break;
            }
        }

        Ok(assignments)
    }

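    /// Rough duration estimate: scale a 60-second baseline by the dominant
    /// resource ratio, i.e. `max(task_cpu / node_cpu, task_mem / node_mem)`.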
    fn estimate_task_duration(
        &self,
        task: &DistributedTask,
        node: &NodeInfo,
    ) -> CoreResult<Duration> {
        let cpu_factor =
            task.resource_requirements.cpu_cores as f64 / node.capabilities.cpu_cores as f64;
        let memory_factor =
            task.resource_requirements.memory_gb as f64 / node.capabilities.memory_gb as f64;

        let complexity_factor = cpu_factor.max(memory_factor);
        let base_duration = Duration::from_secs(60);

        Ok(Duration::from_secs(
            (base_duration.as_secs() as f64 * complexity_factor) as u64,
        ))
    }
}

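/// Two-tier task queue: High/Critical tasks go into a priority heap, Low/Normal
/// tasks into a FIFO queue. `dequeue_next` drains the heap before the FIFO
/// queue; `dequeue_highest_priority` only drains the heap.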
#[derive(Debug)]
pub struct TaskQueue {
    priority_queue: BinaryHeap<PriorityTask>,
    fifo_queue: VecDeque<DistributedTask>,
    task_count: usize,
}

impl Default for TaskQueue {
    fn default() -> Self {
        Self::new()
    }
}

impl TaskQueue {
    pub fn new() -> Self {
        Self {
            priority_queue: BinaryHeap::new(),
            fifo_queue: VecDeque::new(),
            task_count: 0,
        }
    }

    pub fn enqueue(&mut self, task: DistributedTask) -> CoreResult<()> {
        match task.priority {
            super::cluster::TaskPriority::Low | super::cluster::TaskPriority::Normal => {
                self.fifo_queue.push_back(task);
            }
            super::cluster::TaskPriority::High | super::cluster::TaskPriority::Critical => {
                self.priority_queue.push(PriorityTask {
                    task,
                    submitted_at: Instant::now(),
                });
            }
        }

        self.task_count += 1;
        Ok(())
    }

    pub fn dequeue_next(&mut self) -> Option<DistributedTask> {
        if let Some(priority_task) = self.priority_queue.pop() {
            self.task_count -= 1;
            return Some(priority_task.task);
        }

        if let Some(task) = self.fifo_queue.pop_front() {
            self.task_count -= 1;
            return Some(task);
        }

        None
    }

    pub fn dequeue_highest_priority(&mut self) -> Option<DistributedTask> {
        if let Some(priority_task) = self.priority_queue.pop() {
            self.task_count -= 1;
            Some(priority_task.task)
        } else {
            None
        }
    }

    pub fn size(&self) -> usize {
        self.task_count
    }
}

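/// Heap entry that orders tasks by priority, with earlier submission times
/// winning ties so equally prioritized tasks are popped in FIFO order.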
#[derive(Debug)]
struct PriorityTask {
    task: DistributedTask,
    submitted_at: Instant,
}

impl PartialEq for PriorityTask {
    fn eq(&self, other: &Self) -> bool {
        // Must agree with `Ord` below, which also compares submission time.
        self.task.priority == other.task.priority && self.submitted_at == other.submitted_at
    }
}

impl Eq for PriorityTask {}

impl PartialOrd for PriorityTask {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for PriorityTask {
    fn cmp(&self, other: &Self) -> Ordering {
        // Higher priority first; among equal priorities, earlier submissions
        // compare greater so the binary heap pops them first (FIFO tie-break).
        match self.task.priority.cmp(&other.task.priority) {
            Ordering::Equal => other.submitted_at.cmp(&self.submitted_at),
            other => other,
        }
    }
}

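/// Tracks in-flight task assignments and keeps bounded histories (the most
/// recent 1000 entries) of completed and failed tasks.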
#[derive(Debug)]
pub struct ExecutionTracker {
    active_assignments: HashMap<TaskId, TaskAssignment>,
    completed_tasks: VecDeque<CompletedTask>,
    failed_tasks: VecDeque<FailedTask>,
}

impl Default for ExecutionTracker {
    fn default() -> Self {
        Self::new()
    }
}

impl ExecutionTracker {
    pub fn new() -> Self {
        Self {
            active_assignments: HashMap::new(),
            completed_tasks: VecDeque::with_capacity(1000),
            failed_tasks: VecDeque::with_capacity(1000),
        }
    }

    pub fn track_assignment(&mut self, assignment: TaskAssignment) -> CoreResult<()> {
        self.active_assignments
            .insert(assignment.taskid.clone(), assignment);
        Ok(())
    }

    pub fn mark_task_complete(
        &mut self,
        taskid: &TaskId,
        execution_time: Duration,
    ) -> CoreResult<()> {
        if let Some(assignment) = self.active_assignments.remove(taskid) {
            let completed_task = CompletedTask {
                taskid: taskid.clone(),
                nodeid: assignment.nodeid,
                execution_time,
                completed_at: Instant::now(),
            };

            self.completed_tasks.push_back(completed_task);

            while self.completed_tasks.len() > 1000 {
                self.completed_tasks.pop_front();
            }
        }

        Ok(())
    }

    pub fn mark_task_failed(&mut self, taskid: &TaskId, error: String) -> CoreResult<()> {
        if let Some(assignment) = self.active_assignments.remove(taskid) {
            let failed_task = FailedTask {
                taskid: taskid.clone(),
                nodeid: assignment.nodeid,
                error,
                failed_at: Instant::now(),
            };

            self.failed_tasks.push_back(failed_task);

            while self.failed_tasks.len() > 1000 {
                self.failed_tasks.pop_front();
            }
        }

        Ok(())
    }

    pub fn get_active_count(&self) -> usize {
        self.active_assignments.len()
    }
}

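/// Selects target nodes for tasks and tracks per-node load (task count and
/// used capacity) so selections reflect assignments made in the current pass.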
#[derive(Debug)]
pub struct LoadBalancer {
    nodeloads: HashMap<String, NodeLoad>,
    balancing_strategy: LoadBalancingStrategy,
}

impl Default for LoadBalancer {
    fn default() -> Self {
        Self::new()
    }
}

impl LoadBalancer {
    pub fn new() -> Self {
        Self {
            nodeloads: HashMap::new(),
            balancing_strategy: LoadBalancingStrategy::LeastLoaded,
        }
    }

    pub fn select_node_for_task(
        &self,
        task: &DistributedTask,
        nodes: &[NodeInfo],
    ) -> CoreResult<Option<NodeInfo>> {
        match self.balancing_strategy {
            LoadBalancingStrategy::RoundRobin => self.select_round_robin(nodes),
            LoadBalancingStrategy::LeastLoaded => self.select_least_loaded(task, nodes),
            LoadBalancingStrategy::ResourceBased => self.select_resourcebased(task, nodes),
        }
    }

    pub fn select_least_loaded_node(
        &self,
        task: &DistributedTask,
        nodes: &[NodeInfo],
    ) -> CoreResult<Option<NodeInfo>> {
        self.select_least_loaded(task, nodes)
    }

    pub fn select_best_fit_node(
        &self,
        task: &DistributedTask,
        nodes: &[NodeInfo],
    ) -> CoreResult<Option<NodeInfo>> {
        self.select_resourcebased(task, nodes)
    }

    fn select_round_robin(&self, nodes: &[NodeInfo]) -> CoreResult<Option<NodeInfo>> {
        if nodes.is_empty() {
            return Ok(None);
        }

        // Stateless approximation of round-robin: the index rotates with the
        // number of tracked node loads rather than a per-call counter.
        let index = self.nodeloads.len() % nodes.len();
        Ok(Some(nodes[index].clone()))
    }

    fn select_least_loaded(
        &self,
        _task: &DistributedTask,
        nodes: &[NodeInfo],
    ) -> CoreResult<Option<NodeInfo>> {
        if nodes.is_empty() {
            return Ok(None);
        }

        let least_loaded = nodes.iter().min_by_key(|node| {
            self.nodeloads
                .get(&node.id)
                .map(|load| load.current_tasks)
                .unwrap_or(0)
        });

        Ok(least_loaded.cloned())
    }

    fn select_resourcebased(
        &self,
        task: &DistributedTask,
        nodes: &[NodeInfo],
    ) -> CoreResult<Option<NodeInfo>> {
        if nodes.is_empty() {
            return Ok(None);
        }

        let best_fit = nodes
            .iter()
            .filter(|node| self.can_satisfy_requirements(node, &task.resource_requirements))
            .min_by_key(|node| self.calculate_resource_waste(node, &task.resource_requirements));

        Ok(best_fit.cloned())
    }

    fn can_satisfy_requirements(
        &self,
        node: &NodeInfo,
        requirements: &ResourceRequirements,
    ) -> bool {
        let available = self.available_capacity(&node.id, &node.capabilities);

        available.cpu_cores >= requirements.cpu_cores
            && available.memory_gb >= requirements.memory_gb
            && available.gpu_count >= requirements.gpu_count
            && available.disk_space_gb >= requirements.disk_space_gb
    }

    fn calculate_resource_waste(
        &self,
        node: &NodeInfo,
        requirements: &ResourceRequirements,
    ) -> usize {
        let available = self.available_capacity(&node.id, &node.capabilities);

        let cpu_waste = available.cpu_cores.saturating_sub(requirements.cpu_cores);
        let memory_waste = available.memory_gb.saturating_sub(requirements.memory_gb);
        let gpu_waste = available.gpu_count.saturating_sub(requirements.gpu_count);
        let disk_waste = available
            .disk_space_gb
            .saturating_sub(requirements.disk_space_gb);

        // Disk is down-weighted because its raw GB numbers dwarf the other resources.
        cpu_waste + memory_waste + gpu_waste + (disk_waste / 10)
    }

    fn available_capacity(
        &self,
        nodeid: &str,
        total_capacity: &super::cluster::NodeCapabilities,
    ) -> ComputeCapacity {
        let used = self
            .nodeloads
            .get(nodeid)
            .map(|load| &load.used_capacity)
            .cloned()
            .unwrap_or_default();

        ComputeCapacity {
            cpu_cores: total_capacity.cpu_cores.saturating_sub(used.cpu_cores),
            memory_gb: total_capacity.memory_gb.saturating_sub(used.memory_gb),
            gpu_count: total_capacity.gpu_count.saturating_sub(used.gpu_count),
            disk_space_gb: total_capacity
                .disk_space_gb
                .saturating_sub(used.disk_space_gb),
        }
    }

    pub fn update_nodeloads(&mut self, nodes: &[NodeInfo]) -> CoreResult<()> {
        for node in nodes {
            self.nodeloads
                .entry(node.id.clone())
                .or_insert_with(|| NodeLoad {
                    nodeid: node.id.clone(),
                    current_tasks: 0,
                    used_capacity: ComputeCapacity::default(),
                    last_updated: Instant::now(),
                });
        }

        Ok(())
    }

    pub fn record_assignment(&mut self, nodeid: &str, task: &DistributedTask) -> CoreResult<()> {
        if let Some(load) = self.nodeloads.get_mut(nodeid) {
            load.current_tasks += 1;
            load.used_capacity.cpu_cores += task.resource_requirements.cpu_cores;
            load.used_capacity.memory_gb += task.resource_requirements.memory_gb;
            load.used_capacity.gpu_count += task.resource_requirements.gpu_count;
            load.used_capacity.disk_space_gb += task.resource_requirements.disk_space_gb;
            load.last_updated = Instant::now();
        }

        Ok(())
    }
}

#[derive(Debug, Clone)]
pub struct NodeLoad {
    pub nodeid: String,
    pub current_tasks: usize,
    pub used_capacity: ComputeCapacity,
    pub last_updated: Instant,
}

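/// Tunable scheduling configuration; the defaults select priority-based
/// scheduling with least-loaded balancing, batches of 10 tasks, a 5-second
/// scheduling interval, and a 5-minute priority-boost threshold.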
#[derive(Debug, Clone)]
pub struct SchedulingPolicies {
    pub scheduling_algorithm: SchedulingAlgorithm,
    pub load_balancing_strategy: LoadBalancingStrategy,
    pub batch_size: usize,
    pub scheduling_interval: Duration,
    pub priority_boost_threshold: Duration,
}

impl Default for SchedulingPolicies {
    fn default() -> Self {
        Self {
            scheduling_algorithm: SchedulingAlgorithm::PriorityBased,
            load_balancing_strategy: LoadBalancingStrategy::LeastLoaded,
            batch_size: 10,
            scheduling_interval: Duration::from_secs(5),
            priority_boost_threshold: Duration::from_secs(300),
        }
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SchedulingAlgorithm {
    FirstComeFirstServe,
    PriorityBased,
    LoadBalanced,
    ResourceAware,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LoadBalancingStrategy {
    RoundRobin,
    LeastLoaded,
    ResourceBased,
}

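/// A scheduling decision: which node a task was assigned to, when, and the
/// estimated execution duration.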
#[derive(Debug, Clone)]
pub struct TaskAssignment {
    pub taskid: TaskId,
    pub nodeid: String,
    pub assigned_at: Instant,
    pub estimated_duration: Duration,
}

#[derive(Debug, Clone)]
pub struct CompletedTask {
    pub taskid: TaskId,
    pub nodeid: String,
    pub execution_time: Duration,
    pub completed_at: Instant,
}

#[derive(Debug, Clone)]
pub struct FailedTask {
    pub taskid: TaskId,
    pub nodeid: String,
    pub error: String,
    pub failed_at: Instant,
}

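/// Eagerly constructs the global scheduler so later calls to
/// `DistributedScheduler::global()` return the already-initialized instance.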
#[allow(dead_code)]
pub fn initialize_distributed_scheduler() -> CoreResult<()> {
    let _scheduler = DistributedScheduler::global()?;
    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::distributed::{
        BackoffStrategy, ClusterTaskPriority, NodeCapabilities, NodeMetadata, NodeStatus, NodeType,
        RetryPolicy, TaskParameters, TaskType,
    };
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};

    #[test]
    fn test_scheduler_creation() {
        let _scheduler = DistributedScheduler::new().unwrap();
    }

    #[test]
    fn test_task_queue() {
        let mut queue = TaskQueue::new();
        assert_eq!(queue.size(), 0);

        let task = create_test_task(ClusterTaskPriority::Normal);
        queue.enqueue(task).unwrap();
        assert_eq!(queue.size(), 1);

        let dequeued = queue.dequeue_next();
        assert!(dequeued.is_some());
        assert_eq!(queue.size(), 0);
    }

    #[test]
    fn test_priority_scheduling() {
        let mut queue = TaskQueue::new();

        let low_task = create_test_task(ClusterTaskPriority::Low);
        let high_task = create_test_task(ClusterTaskPriority::High);

        queue.enqueue(low_task).unwrap();
        queue.enqueue(high_task).unwrap();

        let first = queue.dequeue_next().unwrap();
        assert_eq!(first.priority, ClusterTaskPriority::High);

        let second = queue.dequeue_next().unwrap();
        assert_eq!(second.priority, ClusterTaskPriority::Low);
    }

    #[test]
    fn test_load_balancer() {
        let balancer = LoadBalancer::new();
        let nodes = vec![create_test_node("node1"), create_test_node("node2")];
        let task = create_test_task(ClusterTaskPriority::Normal);

        let selected = balancer.select_node_for_task(&task, &nodes).unwrap();
        assert!(selected.is_some());
    }

    fn create_test_task(priority: ClusterTaskPriority) -> DistributedTask {
        DistributedTask {
            taskid: TaskId::generate(),
            task_type: TaskType::Computation,
            resource_requirements: ResourceRequirements {
                cpu_cores: 2,
                memory_gb: 4,
                gpu_count: 0,
                disk_space_gb: 10,
                specialized_requirements: Vec::new(),
            },
            data_dependencies: Vec::new(),
            execution_parameters: TaskParameters {
                environment_variables: HashMap::new(),
                command_arguments: Vec::new(),
                timeout: None,
                retrypolicy: RetryPolicy {
                    max_attempts: 3,
                    backoff_strategy: BackoffStrategy::Fixed(Duration::from_secs(1)),
                },
            },
            priority,
        }
    }

    fn create_test_node(id: &str) -> NodeInfo {
        NodeInfo {
            id: id.to_string(),
            address: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080),
            node_type: NodeType::Worker,
            capabilities: NodeCapabilities {
                cpu_cores: 8,
                memory_gb: 16,
                gpu_count: 1,
                disk_space_gb: 100,
                networkbandwidth_gbps: 1.0,
                specialized_units: Vec::new(),
            },
            status: NodeStatus::Healthy,
            last_seen: Instant::now(),
            metadata: NodeMetadata::default(),
        }
    }
}