smelter_worker/
lib.rs

1#![warn(missing_docs)]
2#![warn(clippy::missing_docs_in_private_items)]
3
4//! Common code used for communicating between the job manager and worker tasks.
5
6use std::{collections::HashMap, fmt::Debug, future::Future, io::Read};
7
8use error_stack::Report;
9#[cfg(feature = "opentelemetry")]
10use opentelemetry::sdk::propagation::TraceContextPropagator;
11use serde::{de::DeserializeOwned, Deserialize, Serialize};
12use thiserror::Error;
13#[cfg(feature = "tracing")]
14use tracing::{event, Level};
15use uuid::Uuid;
16
17#[cfg(feature = "stats")]
18pub mod stats;
19
20/// The ID for a subtask, which uniquely identifies it within a [Job].
21#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
22pub struct SubtaskId {
23    /// The ID of the job.
24    pub job: Uuid,
25    /// Which stage the subtask is running on.
26    pub stage: u16,
27    /// The index of the task within that stage.
28    pub task: u32,
29    /// Which retry of this task is being executed.
30    pub try_num: u16,
31}
32
33impl std::fmt::Display for SubtaskId {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        write!(
36            f,
37            "{}-{:03}-{:05}-{:02}",
38            self.job, self.stage, self.task, self.try_num
39        )
40    }
41}
42
#[cfg(feature = "opentelemetry")]
/// Capture the trace context of the current span as string key/value pairs, so that it can
/// be handed to another process.
pub fn get_trace_context() -> HashMap<String, String> {
    use opentelemetry::propagation::TextMapPropagator;
    use tracing_opentelemetry::OpenTelemetrySpanExt;

    let current_context = tracing::Span::current().context();
    let mut carrier = HashMap::new();
    TraceContextPropagator::new().inject_context(&current_context, &mut carrier);
    carrier
}
56
/// Attach the trace context received from the spawner as the parent of the current span.
///
/// This is a no-op unless the `opentelemetry` feature is enabled.
pub fn propagate_tracing_context(trace_context: &HashMap<String, String>) {
    // The argument is only read when the `opentelemetry` feature is compiled in.
    #![allow(unused_variables)]
    #[cfg(feature = "opentelemetry")]
    {
        use opentelemetry::propagation::TextMapPropagator;
        use tracing_opentelemetry::OpenTelemetrySpanExt;

        let parent = TraceContextPropagator::new().extract(trace_context);
        tracing::Span::current().set_parent(parent);
    }
}
71
72#[derive(Debug)]
73/// The input payload that a worker will read when starting
74pub struct WorkerInput<T> {
75    /// The ID for this task
76    pub task_id: SubtaskId,
77    #[cfg(feature = "opentelemetry")]
78    /// Propagated trace context so that this worker can show up as a child of the parent span
79    /// from the spawner.
80    pub trace_context: std::collections::HashMap<String, String>,
81    /// Worker-specific input data
82    pub input: T,
83    /// A signal that will change to true if the task has cancelled. It is safe to skip reading the
84    /// value if you are waiting for a change notification; the channel will never be changed to
85    /// `false`.
86    pub cancelled: tokio::sync::watch::Receiver<bool>,
87}
88
89#[derive(Debug)]
90#[cfg_attr(feature = "worker-side", derive(Serialize))]
91#[cfg_attr(feature = "spawner-side", derive(Deserialize))]
92/// The input payload that a worker will read when starting
93pub struct WorkerInputPayload<T> {
94    /// The ID for this task
95    pub task_id: SubtaskId,
96    #[cfg(feature = "opentelemetry")]
97    #[serde(default)]
98    /// Propagated trace context so that this worker can show up as a child of the parent span
99    /// from the spawner.
100    pub trace_context: std::collections::HashMap<String, String>,
101    /// Worker-specific input data
102    pub input: T,
103}
104
#[cfg(feature = "spawner-side")]
impl<T> WorkerInputPayload<T> {
    /// Build a [WorkerInputPayload] for the given task and input data.
    ///
    /// When the `opentelemetry` feature is enabled, the trace context of the current span is
    /// captured and stored in the payload.
    pub fn new(task_id: SubtaskId, input: T) -> Self {
        Self {
            task_id,
            #[cfg(feature = "opentelemetry")]
            trace_context: get_trace_context(),
            input,
        }
    }
}
120
#[cfg(feature = "worker-side")]
impl<T: DeserializeOwned + 'static> WorkerInputPayload<T> {
    /// Make the trace context carried by this payload the parent of the current span.
    /// Does nothing unless the `opentelemetry` feature is enabled.
    pub fn propagate_tracing_context(&self) {
        #[cfg(feature = "opentelemetry")]
        {
            let carrier = &self.trace_context;
            propagate_tracing_context(carrier);
        }
    }

    /// Decode a JSON-encoded [WorkerInputPayload] from `input` and propagate its trace
    /// context, if present.
    ///
    /// Returns the `serde_json` error unchanged if decoding fails; in that case no trace
    /// context is propagated.
    pub fn parse(input: impl Read) -> Result<Self, serde_json::Error> {
        serde_json::from_reader::<_, Self>(input).map(|payload| {
            payload.propagate_tracing_context();
            payload
        })
    }
}
140
#[cfg_attr(feature = "worker-side", derive(Serialize))]
#[cfg_attr(feature = "spawner-side", derive(Deserialize))]
#[derive(Debug)]
/// A serializable error returned from a worker
pub struct WorkerError {
    /// Whether the task can be retried after this error, or not
    pub retryable: bool,
    /// The serialized error
    pub error: String,
}

impl WorkerError {
    /// Create a [WorkerError] from any [Error](std::error::Error), or more generally from any
    /// value implementing [Debug].
    ///
    /// The error is stored as its `{:?}` rendering, and `retryable` is stored as given.
    ///
    /// This constructor is intentionally not feature-gated: it needs nothing beyond `format!`,
    /// and the spawner-side `WorkerOutput::from_output_payload` calls it to report payloads
    /// that fail to deserialize, so it must exist when only `spawner-side` is enabled.
    pub fn from_error(retryable: bool, error: impl Debug) -> Self {
        Self {
            retryable,
            error: format!("{:?}", error),
        }
    }
}
162
#[cfg(feature = "worker-side")]
impl<E: std::error::Error> From<E> for WorkerError {
    /// Wrap any [Error](std::error::Error) in a [WorkerError], storing its `{:?}` rendering.
    /// The resulting error always has [WorkerError::retryable] set to `true`.
    fn from(e: E) -> Self {
        Self::from_error(true, e)
    }
}
174
175#[cfg_attr(feature = "worker-side", derive(Serialize))]
176#[cfg_attr(feature = "spawner-side", derive(Deserialize))]
177#[derive(Debug)]
178/// The output payload that the worker writes when a task finishes.
179pub struct WorkerOutput<T: Debug> {
180    /// The result of the task code
181    pub result: WorkerResult<T>,
182    #[cfg(feature = "stats")]
183    /// OS-level statistics about the task
184    pub stats: Option<crate::stats::Statistics>,
185}
186
#[cfg(feature = "spawner-side")]
impl<T: Debug + DeserializeOwned> WorkerOutput<T> {
    /// Decode a [WorkerOutput] from the JSON bytes of a worker's output payload.
    ///
    /// This never fails: if the bytes cannot be decoded, the returned output carries a
    /// non-retryable [WorkerError] describing the decoding error (and no statistics).
    pub fn from_output_payload(data: &[u8]) -> WorkerOutput<T> {
        serde_json::from_slice::<Self>(data).unwrap_or_else(|decode_error| Self {
            result: WorkerResult::Err(WorkerError::from_error(false, decode_error)),
            #[cfg(feature = "stats")]
            stats: None,
        })
    }
}
201
202#[cfg_attr(feature = "worker-side", derive(Serialize))]
203#[cfg_attr(feature = "spawner-side", derive(Deserialize))]
204#[derive(Debug)]
205#[serde(tag = "type", content = "data", rename_all = "snake_case")]
206/// The result of a worker task
207pub enum WorkerResult<T: Debug> {
208    /// The result data from the worker when it succeeded
209    Ok(T),
210    /// The error from the worker when it failed
211    Err(WorkerError),
212}
213
214impl<T: Debug> WorkerResult<T> {
215    /// Convert a [WorkerResult] into a [std::result::Result]
216    pub fn into_result(self) -> Result<T, WorkerError> {
217        match self {
218            WorkerResult::Ok(r) => Ok(r),
219            WorkerResult::Err(e) => Err(e),
220        }
221    }
222}
223
224impl<T: Debug, E: std::error::Error> From<Result<T, E>> for WorkerResult<T> {
225    fn from(res: Result<T, E>) -> Self {
226        match res {
227            Ok(r) => WorkerResult::Ok(r),
228            Err(e) => WorkerResult::Err(WorkerError::from(e)),
229        }
230    }
231}
232
233impl<T: Debug> From<WorkerResult<T>> for Result<T, WorkerError> {
234    fn from(r: WorkerResult<T>) -> Result<T, WorkerError> {
235        match r {
236            WorkerResult::Ok(r) => Ok(r),
237            WorkerResult::Err(e) => Err(e),
238        }
239    }
240}
241
242/// An error that the worker wrapper framework may encounter
243#[derive(Debug, Error)]
244pub enum WrapperError {
245    /// Failed when initializing the worker environment
246    #[error("Error initializing worker")]
247    Initializing,
248    /// Failed to read the input payload
249    #[error("Failed to read input payload")]
250    ReadInput,
251    /// The input payload could not be serialized into the structure that the worker expected
252    #[error("Unexpected input payload format")]
253    UnexpectedInput,
254    /// Failed to write the output payload
255    #[error("Failed to write output payload")]
256    WriteOutput,
257}
258
259impl WrapperError {
260    /// Whether the error indicates a failure that could possibly be retried, or not
261    pub fn retryable(&self) -> bool {
262        match self {
263            WrapperError::Initializing => false,
264            WrapperError::ReadInput => true,
265            WrapperError::UnexpectedInput => false,
266            WrapperError::WriteOutput => true,
267        }
268    }
269}
270
271/// Run the worker and return its output, ready for writing back to the output payload location.
272/// Usually you will want to call the equivalent `run_worker` function in the platform-specific
273/// module instead.
274pub async fn run_worker<INPUT, F, FUT, OUTPUT, ERR>(
275    input: WorkerInputPayload<INPUT>,
276    f: F,
277) -> Result<WorkerOutput<OUTPUT>, Report<WrapperError>>
278where
279    F: FnOnce(WorkerInput<INPUT>) -> FUT,
280    FUT: Future<Output = Result<OUTPUT, ERR>>,
281    ERR: std::error::Error,
282    INPUT: Debug + DeserializeOwned + Send + 'static,
283    OUTPUT: Debug + Serialize + Send + 'static,
284{
285    #[cfg(feature = "stats")]
286    let stats = stats::track_system_stats();
287
288    let (cancel_tx, cancelled_rx) = tokio::sync::watch::channel(false);
289
290    tokio::task::spawn(async move {
291        let sig = tokio::signal::ctrl_c().await;
292        match sig {
293            Ok(_) => {
294                cancel_tx.send(true).ok();
295            }
296            Err(_) => {
297                #[cfg(feature = "tracing")]
298                event!(Level::WARN, "Failed to listen for SIGINT");
299            }
300        };
301    });
302
303    let input = WorkerInput {
304        task_id: input.task_id,
305        #[cfg(feature = "opentelemetry")]
306        trace_context: input.trace_context,
307        input: input.input,
308        cancelled: cancelled_rx,
309    };
310
311    let job_result = trace_result("running worker", f(input).await);
312
313    let result = WorkerOutput {
314        result: job_result.into(),
315        #[cfg(feature = "stats")]
316        stats: stats.finish().await,
317    };
318
319    Ok(result)
320}
321
/// Log the outcome of an operation and hand the result back unchanged.
///
/// With the `tracing` feature enabled, an `Ok` is recorded as an `INFO` event with the value
/// in the `output` field, and an `Err` as an `ERROR` event with the value in the `error`
/// field; `message` is the event message in both cases. Without the feature nothing is logged.
pub fn trace_result<R: Debug, E: Debug>(message: &str, result: Result<R, E>) -> Result<R, E> {
    #[cfg(feature = "tracing")]
    {
        match &result {
            Ok(output) => {
                event!(Level::INFO, output=?output, "{}", message);
            }
            Err(error) => {
                event!(Level::ERROR, error=?error, "{}", message);
            }
        }
    }

    result
}
336}