1use std::fmt;
9use std::panic::{AssertUnwindSafe, catch_unwind};
10use std::sync::atomic::{AtomicU64, Ordering};
11use std::sync::mpsc;
12
/// Opaque identifier assigned to every task, unique within the process.
///
/// Values come from a global atomic counter that starts at 1, so the first
/// ID handed out is `TaskId(1)`.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct TaskId(u64);

/// The next raw value that [`TaskId::new`] will hand out.
static NEXT_TASK_ID: AtomicU64 = AtomicU64::new(1);

impl TaskId {
    /// Allocates a fresh, previously unused ID.
    ///
    /// `Relaxed` ordering is sufficient: `fetch_add` is atomic under every
    /// ordering, so concurrent callers still receive distinct values, and
    /// nothing else is published through this counter. (The counter wraps on
    /// overflow, which is unreachable in practice for a `u64`.)
    pub fn new() -> Self {
        let raw = NEXT_TASK_ID.fetch_add(1, Ordering::Relaxed);
        Self(raw)
    }

    /// Returns the raw numeric value of this ID.
    pub fn value(&self) -> u64 {
        let Self(raw) = *self;
        raw
    }
}

impl Default for TaskId {
    /// Same as [`TaskId::new`]: every call allocates a *fresh* ID, so two
    /// defaulted IDs are distinct.
    fn default() -> Self {
        TaskId::new()
    }
}

impl fmt::Display for TaskId {
    /// Renders the ID as `Task(<n>)`, e.g. `Task(42)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let raw = self.value();
        write!(f, "Task({raw})")
    }
}
43
/// Relative importance of a task.
///
/// Variants are declared from least to most important on purpose: the derived
/// `Ord` follows declaration order, giving `Low < Normal < High`.
/// `Normal` is the default.
#[derive(Clone, Copy, Debug, Default, Eq, Ord, PartialEq, PartialOrd)]
pub enum Priority {
    /// Least important.
    Low,
    /// The default priority.
    #[default]
    Normal,
    /// Most important.
    High,
}
55
/// Ways in which a task can fail to produce its output.
#[derive(Debug)]
pub enum TaskError {
    /// The task's work function panicked; carries the extracted panic message.
    Panicked(String),
    /// The work function reported its own error. The wrapped error is exposed
    /// through [`std::error::Error::source`].
    ExecutionFailed(Box<dyn std::error::Error + Send + Sync + 'static>),
    /// An operation did not complete in time.
    /// NOTE(review): not constructed by the task wrapper itself; presumably
    /// produced by work functions or caller-side timeouts — confirm.
    TimedOut,
}

impl fmt::Display for TaskError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Panicked(message) => write!(f, "Task panicked: {message}"),
            Self::ExecutionFailed(cause) => write!(f, "Task execution failed: {cause}"),
            Self::TimedOut => f.write_str("Operation timed out"),
        }
    }
}

impl std::error::Error for TaskError {
    /// Only `ExecutionFailed` wraps an underlying error; the other variants are leaves.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::ExecutionFailed(cause) => Some(&**cause),
            Self::Panicked(_) | Self::TimedOut => None,
        }
    }
}
85
/// Extracts a human-readable message from a panic payload returned by `catch_unwind`.
///
/// The standard `panic!` macro typically produces a `&'static str` payload for a
/// plain string literal and a `String` payload when format arguments are used.
/// Any other payload type (e.g. from `std::panic::panic_any`) yields the fixed
/// placeholder `"Unknown panic payload type"`.
fn panic_payload_to_string(payload: Box<dyn std::any::Any + Send>) -> String {
    if let Some(message) = payload.downcast_ref::<&'static str>() {
        return (*message).to_owned();
    }
    // We own the box, so move a `String` payload out instead of cloning it.
    match payload.downcast::<String>() {
        Ok(message) => *message,
        Err(_) => "Unknown panic payload type".to_string(),
    }
}
96
97#[derive(Debug)]
102pub struct TaskHandle<Output> {
103 pub task_id: TaskId,
104 result_receiver: mpsc::Receiver<Result<Output, TaskError>>,
105}
106
107impl<Output> TaskHandle<Output> {
108 pub(crate) fn new(
110 task_id: TaskId,
111 result_receiver: mpsc::Receiver<Result<Output, TaskError>>,
112 ) -> Self {
113 Self {
114 task_id,
115 result_receiver,
116 }
117 }
118
119 pub fn get_task_id(&self) -> TaskId {
121 self.task_id
122 }
123
124 pub fn recv_result(&self) -> Result<Result<Output, TaskError>, mpsc::RecvError> {
126 self.result_receiver.recv()
127 }
128
129 pub fn try_recv_result(&self) -> Result<Result<Output, TaskError>, mpsc::TryRecvError> {
131 self.result_receiver.try_recv()
132 }
133}
134
/// How a single run of a task ended, as returned by `Task::run`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TaskExecutionOutcome {
    /// The work function returned `Ok`, and the result was delivered.
    Success,
    /// The work function returned `Err(ExecutionFailed)` or `Err(TimedOut)`,
    /// and that error was delivered.
    LogicError,
    /// The result could not be delivered because the receiver was gone.
    /// This takes precedence over whatever the work function produced.
    ResultSendFailed,
    /// The work function panicked. The panic was caught, and
    /// `TaskError::Panicked` was delivered.
    Panicked,
}
150
151pub struct Task {
156 pub id: TaskId,
157 pub priority: Priority,
158 pub estimated_cost: u32,
159 runnable: Box<dyn FnOnce() -> TaskExecutionOutcome + Send + 'static>,
160}
161
162impl Task {
163 pub fn new_for_cpu<F, Output>(
168 priority: Priority,
169 estimated_cost: u32,
170 work_fn: F,
171 result_tx: mpsc::Sender<Result<Output, TaskError>>,
172 ) -> Self
173 where
174 F: FnOnce() -> Result<Output, TaskError> + Send + 'static,
175 Output: Send + 'static,
176 {
177 let task_id = TaskId::new();
178
179 let runnable = Box::new(move || {
180 let task_result = match catch_unwind(AssertUnwindSafe(work_fn)) {
182 Ok(result) => result,
183 Err(panic_payload) => {
184 Err(TaskError::Panicked(panic_payload_to_string(panic_payload)))
185 }
186 };
187
188 let outcome_before_send = match &task_result {
190 Ok(_) => TaskExecutionOutcome::Success,
191 Err(TaskError::Panicked(_)) => TaskExecutionOutcome::Panicked,
192 Err(TaskError::ExecutionFailed(_)) => TaskExecutionOutcome::LogicError,
193 Err(TaskError::TimedOut) => TaskExecutionOutcome::LogicError,
194 };
195
196 match result_tx.send(task_result) {
198 Ok(_) => outcome_before_send,
199 Err(_) => TaskExecutionOutcome::ResultSendFailed,
200 }
201 });
202
203 Task {
204 id: task_id,
205 priority,
206 estimated_cost: estimated_cost.clamp(1, 100),
207 runnable,
208 }
209 }
210
211 pub fn run(self) -> TaskExecutionOutcome {
215 (self.runnable)()
216 }
217}
218
219impl fmt::Debug for Task {
220 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
221 f.debug_struct("Task")
222 .field("id", &self.id)
223 .field("priority", &self.priority)
224 .field("estimated_cost", &self.estimated_cost)
225 .field("runnable", &"<Box<dyn FnOnce() -> TaskExecutionOutcome>>")
226 .finish()
227 }
228}