Skip to main content

ternary_priority_queue/
lib.rs

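//! # ternary-priority-queue
//!
//! Priority queue for GPU kernel scheduling with ternary scoring: every job is
//! placed in a high (`+1`), normal (`0`), or low (`-1`) bucket and is popped by
//! bucket first, then by exact priority, then by submission order (earliest first).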
4
5use std::collections::BinaryHeap;
6use std::cmp::Ordering;
7
/// One kernel launch request stored in a [`TernaryPriorityQueue`].
#[derive(Clone, Debug)]
pub struct KernelJob {
    /// Identifier handed back by the push call that created the job.
    pub id: u64,
    /// Kernel name supplied at push time.
    pub name: String,
    /// Coarse bucket (`-1`, `0` or `+1`); the primary ordering key.
    pub ternary_score: i8,
    /// Fine-grained priority; breaks ties between jobs in the same bucket.
    pub exact_priority: i32,
    /// Submission stamp; earlier stamps win the remaining ties.
    /// NOTE: the queue fills this from a per-push counter, not a wall clock.
    pub submitted_us: u64,
}
16
17impl PartialEq for KernelJob { fn eq(&self, other: &Self) -> bool { self.id == other.id } }
18impl Eq for KernelJob {}
19impl PartialOrd for KernelJob {
20    fn partial_cmp(&self, other: &Self) -> Option<Ordering> { Some(self.cmp(other)) }
21}
22impl Ord for KernelJob {
23    fn cmp(&self, other: &Self) -> Ordering {
24        // Higher ternary score first, then higher exact priority, then earlier submission
25        self.ternary_score.cmp(&other.ternary_score)
26            .then_with(|| self.exact_priority.cmp(&other.exact_priority))
27            .then_with(|| other.submitted_us.cmp(&self.submitted_us))
28    }
29}
30
31pub struct TernaryPriorityQueue {
32    heap: BinaryHeap<KernelJob>,
33    high_count: u64,
34    normal_count: u64,
35    low_count: u64,
36    time_us: u64,
37}
38
39impl TernaryPriorityQueue {
40    pub fn new() -> Self { Self { heap: BinaryHeap::new(), high_count: 0, normal_count: 0, low_count: 0, time_us: 0 } }
41
42    pub fn push(&mut self, name: &str, exact_priority: i32) -> u64 {
43        self.time_us += 1;
44        let score = if exact_priority > 50 { 1i8 } else if exact_priority < -50 { -1i8 } else { 0i8 };
45        match score { 1 => self.high_count += 1, 0 => self.normal_count += 1, _ => self.low_count += 1 }
46        let id = self.time_us;
47        self.heap.push(KernelJob { id, name: name.into(), ternary_score: score, exact_priority, submitted_us: self.time_us });
48        id
49    }
50
51    pub fn push_with_score(&mut self, name: &str, exact: i32, score: i8) -> u64 {
52        self.time_us += 1;
53        match score { 1 => self.high_count += 1, 0 => self.normal_count += 1, _ => self.low_count += 1 }
54        let id = self.time_us;
55        self.heap.push(KernelJob { id, name: name.into(), ternary_score: score, exact_priority: exact, submitted_us: self.time_us });
56        id
57    }
58
59    pub fn pop(&mut self) -> Option<KernelJob> {
60        let job = self.heap.pop()?;
61        match job.ternary_score { 1 => self.high_count -= 1, 0 => self.normal_count -= 1, _ => self.low_count -= 1 }
62        Some(job)
63    }
64
65    pub fn peek(&self) -> Option<&KernelJob> { self.heap.peek() }
66
67    pub fn drain_high_priority(&mut self) -> Vec<KernelJob> {
68        let mut high = Vec::new();
69        let mut rest = Vec::new();
70        while let Some(job) = self.heap.pop() {
71            if job.ternary_score == 1 { high.push(job); } else { rest.push(job); }
72        }
73        for job in rest { self.heap.push(job); }
74        self.high_count = 0;
75        high
76    }
77
78    pub fn len(&self) -> usize { self.heap.len() }
79    pub fn is_empty(&self) -> bool { self.heap.is_empty() }
80    pub fn high_count(&self) -> u64 { self.high_count }
81    pub fn normal_count(&self) -> u64 { self.normal_count }
82    pub fn low_count(&self) -> u64 { self.low_count }
83}
84
85impl Default for TernaryPriorityQueue { fn default() -> Self { Self::new() } }
86
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a queue by `push`ing each `(name, exact_priority)` pair in order.
    fn queue_of(jobs: &[(&str, i32)]) -> TernaryPriorityQueue {
        let mut queue = TernaryPriorityQueue::new();
        for &(name, priority) in jobs {
            queue.push(name, priority);
        }
        queue
    }

    #[test]
    fn test_push_pop() {
        let mut queue = queue_of(&[("kernel_a", 50), ("kernel_b", 80)]);
        // kernel_b lands in the high bucket, kernel_a only in normal.
        let first = queue.pop().expect("queue holds two jobs");
        assert_eq!(first.name, "kernel_b");
    }

    #[test]
    fn test_ternary_classify() {
        let queue = queue_of(&[("low", -100), ("normal", 0), ("high", 100)]);
        assert_eq!(
            (queue.high_count(), queue.normal_count(), queue.low_count()),
            (1, 1, 1)
        );
    }

    #[test]
    fn test_ordering() {
        let mut queue = TernaryPriorityQueue::new();
        queue.push("normal", 10);
        // A manual +1 score must beat a higher exact priority in a lower bucket.
        queue.push_with_score("high", 5, 1);
        let first = queue.pop().expect("queue holds two jobs");
        assert_eq!(first.name, "high");
    }

    #[test]
    fn test_drain_high() {
        let mut queue = queue_of(&[("a", 100), ("b", -10), ("c", 80)]);
        let drained = queue.drain_high_priority();
        assert_eq!(drained.len(), 2);
        assert_eq!(queue.len(), 1);
    }

    #[test]
    fn test_peek() {
        let queue = queue_of(&[("top", 90)]);
        assert_eq!(queue.peek().map(|job| job.name.as_str()), Some("top"));
        // Peeking must leave the job in place.
        assert_eq!(queue.len(), 1);
    }

    #[test]
    fn test_empty() {
        let mut queue = TernaryPriorityQueue::new();
        assert!(queue.is_empty());
        assert!(queue.pop().is_none());
    }

    #[test]
    fn test_manual_score() {
        let mut queue = TernaryPriorityQueue::new();
        queue.push_with_score("forced_high", 10, 1);
        assert_eq!(queue.peek().map(|job| job.ternary_score), Some(1));
    }
}