1#![warn(missing_docs)]
2#![warn(clippy::missing_docs_in_private_items)]
3
4use std::{collections::HashMap, fmt::Debug, future::Future, io::Read};
7
8use error_stack::Report;
9#[cfg(feature = "opentelemetry")]
10use opentelemetry::sdk::propagation::TraceContextPropagator;
11use serde::{de::DeserializeOwned, Deserialize, Serialize};
12use thiserror::Error;
13#[cfg(feature = "tracing")]
14use tracing::{event, Level};
15use uuid::Uuid;
16
/// System statistics tracking for worker runs.
///
/// Only compiled with the `stats` feature. [`run_worker`] starts tracking with
/// `stats::track_system_stats()` and stores the finished result in the `stats` field of
/// [`WorkerOutput`].
#[cfg(feature = "stats")]
pub mod stats;
19
/// Identifies a single attempt at running one task of one stage of a job.
///
/// Its `Display` form is `job-stage-task-try`, with the numeric parts zero-padded.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct SubtaskId {
    /// The job this subtask belongs to.
    pub job: Uuid,
    /// The stage within the job (zero-padded to at least 3 digits when displayed).
    pub stage: u16,
    /// The task within the stage (zero-padded to at least 5 digits when displayed).
    pub task: u32,
    /// The attempt number for this task (zero-padded to at least 2 digits when displayed);
    /// presumably incremented on each retry — confirm with the scheduler.
    pub try_num: u16,
}
32
33impl std::fmt::Display for SubtaskId {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 write!(
36 f,
37 "{}-{:03}-{:05}-{:02}",
38 self.job, self.stage, self.task, self.try_num
39 )
40 }
41}
42
/// Captures the current tracing span's OpenTelemetry context as text-map fields.
///
/// The fields are produced by `TraceContextPropagator`, so they can travel alongside a payload
/// and be restored on the receiving side with [`propagate_tracing_context`].
#[cfg(feature = "opentelemetry")]
pub fn get_trace_context() -> HashMap<String, String> {
    use opentelemetry::propagation::TextMapPropagator;
    use tracing_opentelemetry::OpenTelemetrySpanExt;

    let current = tracing::Span::current().context();
    let mut carrier = HashMap::new();
    TraceContextPropagator::new().inject_context(&current, &mut carrier);
    carrier
}
56
/// Makes the current tracing span a child of the remote context carried in `trace_context`.
///
/// `trace_context` holds text-map fields such as those produced by `get_trace_context` on the
/// sending side. Without the `opentelemetry` feature this function does nothing.
pub fn propagate_tracing_context(trace_context: &HashMap<String, String>) {
    #[cfg(feature = "opentelemetry")]
    {
        use opentelemetry::propagation::TextMapPropagator;
        use tracing_opentelemetry::OpenTelemetrySpanExt;

        let parent = TraceContextPropagator::new().extract(trace_context);
        tracing::Span::current().set_parent(parent);
    }

    // The argument is only consumed when the `opentelemetry` feature is enabled.
    #[cfg(not(feature = "opentelemetry"))]
    let _ = trace_context;
}
71
/// Input handed to the worker function by [`run_worker`].
#[derive(Debug)]
pub struct WorkerInput<T> {
    /// ID of the subtask being run.
    pub task_id: SubtaskId,
    /// Trace-context fields that arrived with the input payload.
    #[cfg(feature = "opentelemetry")]
    pub trace_context: std::collections::HashMap<String, String>,
    /// The task-specific input.
    pub input: T,
    /// Set to `true` by [`run_worker`] when the process receives Ctrl-C (SIGINT), so the
    /// worker can stop cooperatively.
    pub cancelled: tokio::sync::watch::Receiver<bool>,
}
88
89#[derive(Debug)]
90#[cfg_attr(feature = "worker-side", derive(Serialize))]
91#[cfg_attr(feature = "spawner-side", derive(Deserialize))]
92pub struct WorkerInputPayload<T> {
94 pub task_id: SubtaskId,
96 #[cfg(feature = "opentelemetry")]
97 #[serde(default)]
98 pub trace_context: std::collections::HashMap<String, String>,
101 pub input: T,
103}
104
#[cfg(feature = "spawner-side")]
impl<T> WorkerInputPayload<T> {
    /// Wraps `input` for subtask `task_id`.
    ///
    /// With the `opentelemetry` feature, the current span's trace context is captured as well.
    /// This lets the worker attach its spans to the caller's trace when it parses the payload.
    pub fn new(task_id: SubtaskId, input: T) -> Self {
        WorkerInputPayload {
            task_id,
            #[cfg(feature = "opentelemetry")]
            trace_context: get_trace_context(),
            input,
        }
    }
}
120
#[cfg(feature = "worker-side")]
impl<T: DeserializeOwned + 'static> WorkerInputPayload<T> {
    /// Parents the current tracing span on the trace context carried by this payload.
    ///
    /// This is a no-op unless the `opentelemetry` feature is enabled.
    pub fn propagate_tracing_context(&self) {
        #[cfg(feature = "opentelemetry")]
        crate::propagate_tracing_context(&self.trace_context);
    }

    /// Reads a JSON-encoded payload from `input` and adopts its trace context.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json` error if `input` does not contain a valid payload.
    pub fn parse(input: impl Read) -> Result<Self, serde_json::Error> {
        let payload: Self = serde_json::from_reader(input)?;
        payload.propagate_tracing_context();
        Ok(payload)
    }
}
140
/// An error reported by a worker.
#[cfg_attr(feature = "worker-side", derive(Serialize))]
#[cfg_attr(feature = "spawner-side", derive(Deserialize))]
#[derive(Debug)]
pub struct WorkerError {
    /// Whether retrying the task might succeed.
    pub retryable: bool,
    /// Description of the failure. The conversions in this crate fill it with the source
    /// error's `Debug` output.
    pub error: String,
}
151
152impl WorkerError {
153 #[cfg(feature = "worker-side")]
155 pub fn from_error(retryable: bool, error: impl Debug) -> Self {
156 Self {
157 retryable,
158 error: format!("{:?}", error),
159 }
160 }
161}
162
#[cfg(feature = "worker-side")]
impl<E: std::error::Error> From<E> for WorkerError {
    /// Converts any standard error into a retryable [`WorkerError`].
    ///
    /// The error is recorded via its `Debug` output.
    fn from(e: E) -> Self {
        // NOTE(review): treating arbitrary errors as retryable is a policy choice; errors known
        // to be permanent should go through `WorkerError::from_error(false, ..)` instead.
        let error = format!("{:?}", e);
        WorkerError {
            error,
            retryable: true,
        }
    }
}
174
/// Everything a worker reports back when it finishes.
#[cfg_attr(feature = "worker-side", derive(Serialize))]
#[cfg_attr(feature = "spawner-side", derive(Deserialize))]
#[derive(Debug)]
pub struct WorkerOutput<T: Debug> {
    /// The worker's output on success, or its error.
    pub result: WorkerResult<T>,
    /// Statistics gathered while the worker ran, if any were collected.
    #[cfg(feature = "stats")]
    pub stats: Option<crate::stats::Statistics>,
}
186
#[cfg(feature = "spawner-side")]
impl<T: Debug + DeserializeOwned> WorkerOutput<T> {
    /// Decodes a JSON-encoded [`WorkerOutput`] payload.
    ///
    /// This never fails. If `data` cannot be parsed, the parse error is returned as a
    /// non-retryable [`WorkerResult::Err`], and `stats` (when enabled) is `None`.
    pub fn from_output_payload(data: &[u8]) -> WorkerOutput<T> {
        serde_json::from_slice::<WorkerOutput<T>>(data).unwrap_or_else(|e| WorkerOutput {
            // Constructed directly so this spawner-side path needs nothing from the
            // `worker-side` feature.
            result: WorkerResult::Err(WorkerError {
                retryable: false,
                error: format!("{:?}", e),
            }),
            #[cfg(feature = "stats")]
            stats: None,
        })
    }
}
201
202#[cfg_attr(feature = "worker-side", derive(Serialize))]
203#[cfg_attr(feature = "spawner-side", derive(Deserialize))]
204#[derive(Debug)]
205#[serde(tag = "type", content = "data", rename_all = "snake_case")]
206pub enum WorkerResult<T: Debug> {
208 Ok(T),
210 Err(WorkerError),
212}
213
214impl<T: Debug> WorkerResult<T> {
215 pub fn into_result(self) -> Result<T, WorkerError> {
217 match self {
218 WorkerResult::Ok(r) => Ok(r),
219 WorkerResult::Err(e) => Err(e),
220 }
221 }
222}
223
224impl<T: Debug, E: std::error::Error> From<Result<T, E>> for WorkerResult<T> {
225 fn from(res: Result<T, E>) -> Self {
226 match res {
227 Ok(r) => WorkerResult::Ok(r),
228 Err(e) => WorkerResult::Err(WorkerError::from(e)),
229 }
230 }
231}
232
233impl<T: Debug> From<WorkerResult<T>> for Result<T, WorkerError> {
234 fn from(r: WorkerResult<T>) -> Result<T, WorkerError> {
235 match r {
236 WorkerResult::Ok(r) => Ok(r),
237 WorkerResult::Err(e) => Err(e),
238 }
239 }
240}
241
/// Errors raised by the worker wrapper itself, as opposed to errors from the worker function.
#[derive(Debug, Error)]
pub enum WrapperError {
    /// The worker could not be initialized. Not retryable.
    #[error("Error initializing worker")]
    Initializing,
    /// The input payload could not be read. Retryable.
    #[error("Failed to read input payload")]
    ReadInput,
    /// The input payload was not in the expected format. Not retryable.
    #[error("Unexpected input payload format")]
    UnexpectedInput,
    /// The output payload could not be written. Retryable.
    #[error("Failed to write output payload")]
    WriteOutput,
}
258
259impl WrapperError {
260 pub fn retryable(&self) -> bool {
262 match self {
263 WrapperError::Initializing => false,
264 WrapperError::ReadInput => true,
265 WrapperError::UnexpectedInput => false,
266 WrapperError::WriteOutput => true,
267 }
268 }
269}
270
271pub async fn run_worker<INPUT, F, FUT, OUTPUT, ERR>(
275 input: WorkerInputPayload<INPUT>,
276 f: F,
277) -> Result<WorkerOutput<OUTPUT>, Report<WrapperError>>
278where
279 F: FnOnce(WorkerInput<INPUT>) -> FUT,
280 FUT: Future<Output = Result<OUTPUT, ERR>>,
281 ERR: std::error::Error,
282 INPUT: Debug + DeserializeOwned + Send + 'static,
283 OUTPUT: Debug + Serialize + Send + 'static,
284{
285 #[cfg(feature = "stats")]
286 let stats = stats::track_system_stats();
287
288 let (cancel_tx, cancelled_rx) = tokio::sync::watch::channel(false);
289
290 tokio::task::spawn(async move {
291 let sig = tokio::signal::ctrl_c().await;
292 match sig {
293 Ok(_) => {
294 cancel_tx.send(true).ok();
295 }
296 Err(_) => {
297 #[cfg(feature = "tracing")]
298 event!(Level::WARN, "Failed to listen for SIGINT");
299 }
300 };
301 });
302
303 let input = WorkerInput {
304 task_id: input.task_id,
305 #[cfg(feature = "opentelemetry")]
306 trace_context: input.trace_context,
307 input: input.input,
308 cancelled: cancelled_rx,
309 };
310
311 let job_result = trace_result("running worker", f(input).await);
312
313 let result = WorkerOutput {
314 result: job_result.into(),
315 #[cfg(feature = "stats")]
316 stats: stats.finish().await,
317 };
318
319 Ok(result)
320}
321
/// Logs `result` under `message` and returns it unchanged.
///
/// With the `tracing` feature, an `Ok` value is logged at `INFO` in an `output` field, and an
/// `Err` value at `ERROR` in an `error` field. Without the feature this is the identity function.
pub fn trace_result<R: Debug, E: Debug>(message: &str, result: Result<R, E>) -> Result<R, E> {
    #[cfg(feature = "tracing")]
    {
        match &result {
            Ok(output) => {
                event!(Level::INFO, output = ?output, "{}", message);
            }
            Err(error) => {
                event!(Level::ERROR, error = ?error, "{}", message);
            }
        }
    }

    // `message` is only consumed when the `tracing` feature is enabled.
    #[cfg(not(feature = "tracing"))]
    let _ = message;

    result
}
336}