simple_kube_controller/
dag.rs

1use std::{fmt::Debug, future::Future, marker::PhantomData, sync::Arc, time::Duration};
2
3use kube::{runtime::controller::Action, Resource};
4use serde::de::DeserializeOwned;
5use tokio::task::JoinSet;
6use tracing::{debug, error, instrument};
7
8use crate::{record_resource_metadata, Reconciler};
9
/// Outcome of a single [`Task`] run.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TaskStatus {
    /// Nothing is left to do for this task.
    Completed,
    /// The task is not finished yet; the DAG will be requeued and the task run again later.
    InProgress,
}
18
19/// A task.
20pub trait Task<
21    CONTEXT: Send + Sync,
22    ERROR: std::error::Error + Send + Sync,
23    RESOURCE: Clone + Debug + DeserializeOwned + Resource<DynamicType = ()> + Send + Sync,
24>: Send + Sync
25{
26    /// Name of the task.
27    fn name(&self) -> &str;
28
29    /// Run the task.
30    /// Arguments:
31    ///   - `res`: the resource
32    ///   - `ctx`: the context of the application
33    ///
34    /// It should return [`TaskStatus::Completed`] if the task is done and [`TaskStatus::InProgress`] if it is not.
35    fn run(
36        &self,
37        res: Arc<RESOURCE>,
38        ctx: Arc<CONTEXT>,
39    ) -> impl Future<Output = Result<TaskStatus, ERROR>> + Send;
40}
41
42/// A DAG.
43pub struct Dag<
44    CONTEXT: Send + Sync,
45    ERROR: std::error::Error + Send + Sync,
46    RESOURCE: Clone + Debug + DeserializeOwned + Resource<DynamicType = ()> + Send + Sync,
47    TASK: Task<CONTEXT, ERROR, RESOURCE>,
48> {
49    action: Action,
50    requeue_delay: Duration,
51    tasks: Vec<Vec<Arc<TASK>>>,
52    _ctx: PhantomData<CONTEXT>,
53    _err: PhantomData<ERROR>,
54    _res: PhantomData<RESOURCE>,
55}
56
57impl<
58        CONTEXT: Send + Sync + 'static,
59        ERROR: std::error::Error + Send + Sync + 'static,
60        RESOURCE: Clone + Debug + DeserializeOwned + Resource<DynamicType = ()> + Send + Sync + 'static,
61        TASK: Task<CONTEXT, ERROR, RESOURCE> + 'static,
62    > Dag<CONTEXT, ERROR, RESOURCE, TASK>
63{
64    /// Run a DAG.
65    /// Arguments:
66    ///   - `res`: the resource
67    ///   - `ctx`: the context of the application
68    #[instrument(
69        fields(
70            resource.api_version = %RESOURCE::api_version(&()),
71            resource.name,
72            resource.namespace,
73        ),
74        skip(self, res, ctx)
75    )]
76    pub async fn run(&self, res: Arc<RESOURCE>, ctx: Arc<CONTEXT>) -> Result<Action, ERROR> {
77        record_resource_metadata!(res.meta());
78        let mut idx = 0;
79        while idx < self.tasks.len() {
80            let mut completed = 0;
81            let tasks = &self.tasks[idx];
82            let mut handles = JoinSet::new();
83            for task in tasks {
84                let task = task.clone();
85                let res = res.clone();
86                let ctx = ctx.clone();
87                let name = task.name().to_string();
88                debug!("starting task `{name}`");
89                handles.spawn(async move { task.run(res, ctx).await.map(|status| (name, status)) });
90            }
91            while let Some(res) = handles.join_next().await {
92                match res {
93                    Ok(Ok((name, TaskStatus::Completed))) => {
94                        debug!("task `{name}` completed");
95                        completed += 1;
96                    }
97                    Ok(Ok((task, TaskStatus::InProgress))) => {
98                        debug!("task `{task}` still in progress");
99                    }
100                    Ok(Err(err)) => return Err(err),
101                    Err(err) => {
102                        error!("failed to wait for task run: {err}");
103                    }
104                }
105            }
106            if completed == tasks.len() {
107                idx += 1;
108            } else {
109                break;
110            }
111        }
112        if idx == self.tasks.len() {
113            Ok(self.action.clone())
114        } else {
115            Ok(Action::requeue(self.requeue_delay))
116        }
117    }
118}
119
120impl<
121        CONTEXT: Send + Sync,
122        ERROR: std::error::Error + Send + Sync,
123        RESOURCE: Clone + Debug + DeserializeOwned + Resource<DynamicType = ()> + Send + Sync,
124        TASK: Task<CONTEXT, ERROR, RESOURCE>,
125    > Default for Dag<CONTEXT, ERROR, RESOURCE, TASK>
126{
127    fn default() -> Self {
128        Self {
129            action: Action::await_change(),
130            requeue_delay: Duration::from_secs(15),
131            tasks: vec![],
132            _ctx: PhantomData,
133            _err: PhantomData,
134            _res: PhantomData,
135        }
136    }
137}
138
139/// A DAG builder.
140#[derive(Default)]
141pub struct DagBuilder<
142    CONTEXT: Send + Sync,
143    ERROR: std::error::Error + Send + Sync,
144    RESOURCE: Clone + Debug + DeserializeOwned + Resource<DynamicType = ()> + Send + Sync,
145    TASK: Task<CONTEXT, ERROR, RESOURCE>,
146>(Dag<CONTEXT, ERROR, RESOURCE, TASK>);
147
148impl<
149        CONTEXT: Send + Sync,
150        ERROR: std::error::Error + Send + Sync,
151        RESOURCE: Clone + Debug + DeserializeOwned + Resource<DynamicType = ()> + Send + Sync,
152        TASK: Task<CONTEXT, ERROR, RESOURCE>,
153    > DagBuilder<CONTEXT, ERROR, RESOURCE, TASK>
154{
155    /// Create a DAG builder.
156    pub fn new() -> Self {
157        Self(Default::default())
158    }
159
160    /// Define the final action of the DAG (when all tasks are completed).
161    pub fn action(mut self, action: Action) -> Self {
162        self.0.action = action;
163        self
164    }
165
166    /// Define the delay between two reconciliations.
167    pub fn requeue_delay(mut self, delay: Duration) -> Self {
168        self.0.requeue_delay = delay;
169        self
170    }
171
172    /// Define the first group tasks to run. The tasks will be run asynchronously.
173    pub fn start_with<TASKS: IntoIterator<Item = TASK>>(
174        mut self,
175        tasks: TASKS,
176    ) -> DagBuilderThen<CONTEXT, ERROR, RESOURCE, TASK> {
177        self.0.tasks = vec![tasks.into_iter().map(Arc::new).collect()];
178        DagBuilderThen(self.0)
179    }
180}
181
182/// A DAG builder with the first group of tasks already configured.
183pub struct DagBuilderThen<
184    CONTEXT: Send + Sync,
185    ERROR: std::error::Error + Send + Sync,
186    RESOURCE: Clone + Debug + DeserializeOwned + Resource<DynamicType = ()> + Send + Sync,
187    TASK: Task<CONTEXT, ERROR, RESOURCE>,
188>(Dag<CONTEXT, ERROR, RESOURCE, TASK>);
189
190impl<
191        CONTEXT: Send + Sync,
192        ERROR: std::error::Error + Send + Sync,
193        RESOURCE: Clone + Debug + DeserializeOwned + Resource<DynamicType = ()> + Send + Sync,
194        TASK: Task<CONTEXT, ERROR, RESOURCE>,
195    > DagBuilderThen<CONTEXT, ERROR, RESOURCE, TASK>
196{
197    /// Define the final action of the DAG (when all tasks are completed).
198    pub fn action(mut self, action: Action) -> Self {
199        self.0.action = action;
200        self
201    }
202
203    /// Build the DAG.
204    pub fn build(self) -> Dag<CONTEXT, ERROR, RESOURCE, TASK> {
205        self.0
206    }
207
208    /// Define the delay between two reconciliations.
209    pub fn requeue_delay(mut self, delay: Duration) -> Self {
210        self.0.requeue_delay = delay;
211        self
212    }
213
214    /// Define the next group of tasks to run. It will be started when all tasks of the current group are completed.
215    pub fn then<TASKS: IntoIterator<Item = TASK>>(mut self, tasks: TASKS) -> Self {
216        self.0.tasks.push(tasks.into_iter().map(Arc::new).collect());
217        self
218    }
219}
220
221/// A DAG reconciler.
222#[derive(Default)]
223pub struct DagReconciler<
224    CONTEXT: Send + Sync,
225    ERROR: std::error::Error + Send + Sync,
226    RESOURCE: Clone + Debug + DeserializeOwned + Resource<DynamicType = ()> + Send + Sync,
227    TASK: Task<CONTEXT, ERROR, RESOURCE>,
228> {
229    on_create_or_update: Option<Dag<CONTEXT, ERROR, RESOURCE, TASK>>,
230    on_delete: Option<Dag<CONTEXT, ERROR, RESOURCE, TASK>>,
231}
232
233impl<
234        CONTEXT: Send + Sync + 'static,
235        ERROR: std::error::Error + Send + Sync + 'static,
236        RESOURCE: Clone + Debug + DeserializeOwned + Resource<DynamicType = ()> + Send + Sync + 'static,
237        TASK: Task<CONTEXT, ERROR, RESOURCE> + 'static,
238    > DagReconciler<CONTEXT, ERROR, RESOURCE, TASK>
239{
240    /// Create a new DAG reconciler.
241    pub fn new() -> Self {
242        Self {
243            on_create_or_update: None,
244            on_delete: None,
245        }
246    }
247
248    /// Define the DAG to run when a resource is created/updated.
249    pub fn on_create_or_update(mut self, dag: Dag<CONTEXT, ERROR, RESOURCE, TASK>) -> Self {
250        self.on_create_or_update = Some(dag);
251        self
252    }
253
254    /// Define the DAG to run when a resource is deleted.
255    pub fn on_delete(mut self, dag: Dag<CONTEXT, ERROR, RESOURCE, TASK>) -> Self {
256        self.on_delete = Some(dag);
257        self
258    }
259}
260
261impl<
262        CONTEXT: Send + Sync + 'static,
263        ERROR: std::error::Error + Send + Sync + 'static,
264        RESOURCE: Clone + Debug + DeserializeOwned + Resource<DynamicType = ()> + Send + Sync + 'static,
265        TASK: Task<CONTEXT, ERROR, RESOURCE> + 'static,
266    > Reconciler<CONTEXT, ERROR, RESOURCE> for DagReconciler<CONTEXT, ERROR, RESOURCE, TASK>
267{
268    #[instrument(
269        fields(
270            resource.api_version = %RESOURCE::api_version(&()),
271            resource.name,
272            resource.namespace,
273        ),
274        skip(self, res, ctx)
275    )]
276    async fn reconcile_creation_or_update(
277        &self,
278        res: Arc<RESOURCE>,
279        ctx: Arc<CONTEXT>,
280    ) -> Result<Action, ERROR> {
281        record_resource_metadata!(res.meta());
282        if let Some(dag) = &self.on_create_or_update {
283            dag.run(res, ctx).await
284        } else {
285            Ok(Action::await_change())
286        }
287    }
288
289    #[instrument(
290        fields(
291            resource.api_version = %RESOURCE::api_version(&()),
292            resource.name,
293            resource.namespace,
294        ),
295        skip(self, res, ctx)
296    )]
297    async fn reconcile_deletion(
298        &self,
299        res: Arc<RESOURCE>,
300        ctx: Arc<CONTEXT>,
301    ) -> Result<Action, ERROR> {
302        record_resource_metadata!(res.meta());
303        if let Some(dag) = &self.on_delete {
304            dag.run(res, ctx).await
305        } else {
306            Ok(Action::await_change())
307        }
308    }
309}
310
#[cfg(test)]
mod test {
    use std::convert::Infallible;

    use k8s_openapi::api::core::v1::Namespace;

    use super::*;

    /// A task that always reports the status it was built with.
    struct DummyTask(TaskStatus);

    impl Task<(), Infallible, Namespace> for DummyTask {
        fn name(&self) -> &str {
            "dummy"
        }

        async fn run(&self, _res: Arc<Namespace>, _ctx: Arc<()>) -> Result<TaskStatus, Infallible> {
            Ok(self.0)
        }
    }

    mod dag {
        use super::*;

        mod run {
            use super::*;

            #[tokio::test]
            async fn when_empty() {
                let dag = Dag::<(), Infallible, Namespace, DummyTask>::default();
                let action = dag
                    .run(Arc::new(Namespace::default()), Arc::new(()))
                    .await
                    .unwrap();
                assert_eq!(action, Action::await_change());
            }

            #[tokio::test]
            async fn when_first_group_in_progress() {
                let delay = Duration::from_secs(10);
                let dag = DagBuilder::new()
                    .requeue_delay(delay)
                    .start_with([
                        DummyTask(TaskStatus::InProgress),
                        DummyTask(TaskStatus::Completed),
                    ])
                    .then([DummyTask(TaskStatus::InProgress)])
                    .build();
                let action = dag
                    .run(Arc::new(Namespace::default()), Arc::new(()))
                    .await
                    .unwrap();
                assert_eq!(action, Action::requeue(delay));
            }

            #[tokio::test]
            async fn when_second_group_in_progress() {
                let delay = Duration::from_secs(10);
                let dag = DagBuilder::new()
                    .requeue_delay(delay)
                    .start_with([
                        DummyTask(TaskStatus::Completed),
                        DummyTask(TaskStatus::Completed),
                    ])
                    .then([DummyTask(TaskStatus::InProgress)])
                    .build();
                let action = dag
                    .run(Arc::new(Namespace::default()), Arc::new(()))
                    .await
                    .unwrap();
                assert_eq!(action, Action::requeue(delay));
            }

            #[tokio::test]
            async fn when_completed() {
                // Use an action different from the default one so that the test actually checks
                // that the configured action is returned.
                let expected = Action::requeue(Duration::from_secs(300));
                let dag = DagBuilder::new()
                    .action(expected.clone())
                    .start_with([
                        DummyTask(TaskStatus::Completed),
                        DummyTask(TaskStatus::Completed),
                    ])
                    .then([DummyTask(TaskStatus::Completed)])
                    .build();
                let action = dag
                    .run(Arc::new(Namespace::default()), Arc::new(()))
                    .await
                    .unwrap();
                assert_eq!(action, expected);
            }
        }
    }

    mod dag_reconciler {
        use super::*;

        #[tokio::test]
        async fn when_no_dag_configured() {
            let reconciler = DagReconciler::<(), Infallible, Namespace, DummyTask>::new();
            let res = Arc::new(Namespace::default());
            let ctx = Arc::new(());
            let action = reconciler
                .reconcile_creation_or_update(res.clone(), ctx.clone())
                .await
                .unwrap();
            assert_eq!(action, Action::await_change());
            let action = reconciler.reconcile_deletion(res, ctx).await.unwrap();
            assert_eq!(action, Action::await_change());
        }

        #[tokio::test]
        async fn when_dags_configured() {
            let delay = Duration::from_secs(10);
            let on_delete_action = Action::requeue(Duration::from_secs(300));
            let reconciler = DagReconciler::<(), Infallible, Namespace, DummyTask>::new()
                .on_create_or_update(
                    DagBuilder::new()
                        .requeue_delay(delay)
                        .start_with([DummyTask(TaskStatus::InProgress)])
                        .build(),
                )
                .on_delete(
                    DagBuilder::new()
                        .action(on_delete_action.clone())
                        .start_with([DummyTask(TaskStatus::Completed)])
                        .build(),
                );
            let res = Arc::new(Namespace::default());
            let ctx = Arc::new(());
            let action = reconciler
                .reconcile_creation_or_update(res.clone(), ctx.clone())
                .await
                .unwrap();
            assert_eq!(action, Action::requeue(delay));
            let action = reconciler.reconcile_deletion(res, ctx).await.unwrap();
            assert_eq!(action, on_delete_action);
        }
    }
}