use std::{fmt::Debug, future::Future, marker::PhantomData, sync::Arc, time::Duration};

use kube::{runtime::controller::Action, Resource};
use serde::de::DeserializeOwned;
use tokio::task::JoinSet;
use tracing::{debug, error, instrument, Instrument};

use crate::{record_resource_metadata, Reconciler};
9
/// Outcome of a single [`Task`] run.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TaskStatus {
    /// The task finished its work; the DAG may move on once its whole group completed.
    Completed,
    /// The task has not finished yet; the DAG stops at the current group and requeues.
    InProgress,
}
18
19pub trait Task<
21 CONTEXT: Send + Sync,
22 ERROR: std::error::Error + Send + Sync,
23 RESOURCE: Clone + Debug + DeserializeOwned + Resource<DynamicType = ()> + Send + Sync,
24>: Send + Sync
25{
26 fn name(&self) -> &str;
28
29 fn run(
36 &self,
37 res: Arc<RESOURCE>,
38 ctx: Arc<CONTEXT>,
39 ) -> impl Future<Output = Result<TaskStatus, ERROR>> + Send;
40}
41
42pub struct Dag<
44 CONTEXT: Send + Sync,
45 ERROR: std::error::Error + Send + Sync,
46 RESOURCE: Clone + Debug + DeserializeOwned + Resource<DynamicType = ()> + Send + Sync,
47 TASK: Task<CONTEXT, ERROR, RESOURCE>,
48> {
49 action: Action,
50 requeue_delay: Duration,
51 tasks: Vec<Vec<Arc<TASK>>>,
52 _ctx: PhantomData<CONTEXT>,
53 _err: PhantomData<ERROR>,
54 _res: PhantomData<RESOURCE>,
55}
56
57impl<
58 CONTEXT: Send + Sync + 'static,
59 ERROR: std::error::Error + Send + Sync + 'static,
60 RESOURCE: Clone + Debug + DeserializeOwned + Resource<DynamicType = ()> + Send + Sync + 'static,
61 TASK: Task<CONTEXT, ERROR, RESOURCE> + 'static,
62 > Dag<CONTEXT, ERROR, RESOURCE, TASK>
63{
64 #[instrument(
69 fields(
70 resource.api_version = %RESOURCE::api_version(&()),
71 resource.name,
72 resource.namespace,
73 ),
74 skip(self, res, ctx)
75 )]
76 pub async fn run(&self, res: Arc<RESOURCE>, ctx: Arc<CONTEXT>) -> Result<Action, ERROR> {
77 record_resource_metadata!(res.meta());
78 let mut idx = 0;
79 while idx < self.tasks.len() {
80 let mut completed = 0;
81 let tasks = &self.tasks[idx];
82 let mut handles = JoinSet::new();
83 for task in tasks {
84 let task = task.clone();
85 let res = res.clone();
86 let ctx = ctx.clone();
87 let name = task.name().to_string();
88 debug!("starting task `{name}`");
89 handles.spawn(async move { task.run(res, ctx).await.map(|status| (name, status)) });
90 }
91 while let Some(res) = handles.join_next().await {
92 match res {
93 Ok(Ok((name, TaskStatus::Completed))) => {
94 debug!("task `{name}` completed");
95 completed += 1;
96 }
97 Ok(Ok((task, TaskStatus::InProgress))) => {
98 debug!("task `{task}` still in progress");
99 }
100 Ok(Err(err)) => return Err(err),
101 Err(err) => {
102 error!("failed to wait for task run: {err}");
103 }
104 }
105 }
106 if completed == tasks.len() {
107 idx += 1;
108 } else {
109 break;
110 }
111 }
112 if idx == self.tasks.len() {
113 Ok(self.action.clone())
114 } else {
115 Ok(Action::requeue(self.requeue_delay))
116 }
117 }
118}
119
120impl<
121 CONTEXT: Send + Sync,
122 ERROR: std::error::Error + Send + Sync,
123 RESOURCE: Clone + Debug + DeserializeOwned + Resource<DynamicType = ()> + Send + Sync,
124 TASK: Task<CONTEXT, ERROR, RESOURCE>,
125 > Default for Dag<CONTEXT, ERROR, RESOURCE, TASK>
126{
127 fn default() -> Self {
128 Self {
129 action: Action::await_change(),
130 requeue_delay: Duration::from_secs(15),
131 tasks: vec![],
132 _ctx: PhantomData,
133 _err: PhantomData,
134 _res: PhantomData,
135 }
136 }
137}
138
139#[derive(Default)]
141pub struct DagBuilder<
142 CONTEXT: Send + Sync,
143 ERROR: std::error::Error + Send + Sync,
144 RESOURCE: Clone + Debug + DeserializeOwned + Resource<DynamicType = ()> + Send + Sync,
145 TASK: Task<CONTEXT, ERROR, RESOURCE>,
146>(Dag<CONTEXT, ERROR, RESOURCE, TASK>);
147
148impl<
149 CONTEXT: Send + Sync,
150 ERROR: std::error::Error + Send + Sync,
151 RESOURCE: Clone + Debug + DeserializeOwned + Resource<DynamicType = ()> + Send + Sync,
152 TASK: Task<CONTEXT, ERROR, RESOURCE>,
153 > DagBuilder<CONTEXT, ERROR, RESOURCE, TASK>
154{
155 pub fn new() -> Self {
157 Self(Default::default())
158 }
159
160 pub fn action(mut self, action: Action) -> Self {
162 self.0.action = action;
163 self
164 }
165
166 pub fn requeue_delay(mut self, delay: Duration) -> Self {
168 self.0.requeue_delay = delay;
169 self
170 }
171
172 pub fn start_with<TASKS: IntoIterator<Item = TASK>>(
174 mut self,
175 tasks: TASKS,
176 ) -> DagBuilderThen<CONTEXT, ERROR, RESOURCE, TASK> {
177 self.0.tasks = vec![tasks.into_iter().map(Arc::new).collect()];
178 DagBuilderThen(self.0)
179 }
180}
181
182pub struct DagBuilderThen<
184 CONTEXT: Send + Sync,
185 ERROR: std::error::Error + Send + Sync,
186 RESOURCE: Clone + Debug + DeserializeOwned + Resource<DynamicType = ()> + Send + Sync,
187 TASK: Task<CONTEXT, ERROR, RESOURCE>,
188>(Dag<CONTEXT, ERROR, RESOURCE, TASK>);
189
190impl<
191 CONTEXT: Send + Sync,
192 ERROR: std::error::Error + Send + Sync,
193 RESOURCE: Clone + Debug + DeserializeOwned + Resource<DynamicType = ()> + Send + Sync,
194 TASK: Task<CONTEXT, ERROR, RESOURCE>,
195 > DagBuilderThen<CONTEXT, ERROR, RESOURCE, TASK>
196{
197 pub fn action(mut self, action: Action) -> Self {
199 self.0.action = action;
200 self
201 }
202
203 pub fn build(self) -> Dag<CONTEXT, ERROR, RESOURCE, TASK> {
205 self.0
206 }
207
208 pub fn requeue_delay(mut self, delay: Duration) -> Self {
210 self.0.requeue_delay = delay;
211 self
212 }
213
214 pub fn then<TASKS: IntoIterator<Item = TASK>>(mut self, tasks: TASKS) -> Self {
216 self.0.tasks.push(tasks.into_iter().map(Arc::new).collect());
217 self
218 }
219}
220
221#[derive(Default)]
223pub struct DagReconciler<
224 CONTEXT: Send + Sync,
225 ERROR: std::error::Error + Send + Sync,
226 RESOURCE: Clone + Debug + DeserializeOwned + Resource<DynamicType = ()> + Send + Sync,
227 TASK: Task<CONTEXT, ERROR, RESOURCE>,
228> {
229 on_create_or_update: Option<Dag<CONTEXT, ERROR, RESOURCE, TASK>>,
230 on_delete: Option<Dag<CONTEXT, ERROR, RESOURCE, TASK>>,
231}
232
233impl<
234 CONTEXT: Send + Sync + 'static,
235 ERROR: std::error::Error + Send + Sync + 'static,
236 RESOURCE: Clone + Debug + DeserializeOwned + Resource<DynamicType = ()> + Send + Sync + 'static,
237 TASK: Task<CONTEXT, ERROR, RESOURCE> + 'static,
238 > DagReconciler<CONTEXT, ERROR, RESOURCE, TASK>
239{
240 pub fn new() -> Self {
242 Self {
243 on_create_or_update: None,
244 on_delete: None,
245 }
246 }
247
248 pub fn on_create_or_update(mut self, dag: Dag<CONTEXT, ERROR, RESOURCE, TASK>) -> Self {
250 self.on_create_or_update = Some(dag);
251 self
252 }
253
254 pub fn on_delete(mut self, dag: Dag<CONTEXT, ERROR, RESOURCE, TASK>) -> Self {
256 self.on_delete = Some(dag);
257 self
258 }
259}
260
261impl<
262 CONTEXT: Send + Sync + 'static,
263 ERROR: std::error::Error + Send + Sync + 'static,
264 RESOURCE: Clone + Debug + DeserializeOwned + Resource<DynamicType = ()> + Send + Sync + 'static,
265 TASK: Task<CONTEXT, ERROR, RESOURCE> + 'static,
266 > Reconciler<CONTEXT, ERROR, RESOURCE> for DagReconciler<CONTEXT, ERROR, RESOURCE, TASK>
267{
268 #[instrument(
269 fields(
270 resource.api_version = %RESOURCE::api_version(&()),
271 resource.name,
272 resource.namespace,
273 ),
274 skip(self, res, ctx)
275 )]
276 async fn reconcile_creation_or_update(
277 &self,
278 res: Arc<RESOURCE>,
279 ctx: Arc<CONTEXT>,
280 ) -> Result<Action, ERROR> {
281 record_resource_metadata!(res.meta());
282 if let Some(dag) = &self.on_create_or_update {
283 dag.run(res, ctx).await
284 } else {
285 Ok(Action::await_change())
286 }
287 }
288
289 #[instrument(
290 fields(
291 resource.api_version = %RESOURCE::api_version(&()),
292 resource.name,
293 resource.namespace,
294 ),
295 skip(self, res, ctx)
296 )]
297 async fn reconcile_deletion(
298 &self,
299 res: Arc<RESOURCE>,
300 ctx: Arc<CONTEXT>,
301 ) -> Result<Action, ERROR> {
302 record_resource_metadata!(res.meta());
303 if let Some(dag) = &self.on_delete {
304 dag.run(res, ctx).await
305 } else {
306 Ok(Action::await_change())
307 }
308 }
309}
310
#[cfg(test)]
mod test {
    use std::{convert::Infallible, fmt};

    use k8s_openapi::api::core::v1::Namespace;

    use super::*;

    /// Task that always reports the status it wraps.
    struct DummyTask(TaskStatus);

    impl Task<(), Infallible, Namespace> for DummyTask {
        fn name(&self) -> &str {
            "dummy"
        }

        async fn run(&self, _res: Arc<Namespace>, _ctx: Arc<()>) -> Result<TaskStatus, Infallible> {
            Ok(self.0)
        }
    }

    /// Error reported by [`FailingTask`].
    #[derive(Debug)]
    struct DummyError;

    impl fmt::Display for DummyError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("dummy error")
        }
    }

    impl std::error::Error for DummyError {}

    /// Task that always fails with [`DummyError`].
    struct FailingTask;

    impl Task<(), DummyError, Namespace> for FailingTask {
        fn name(&self) -> &str {
            "failing"
        }

        async fn run(&self, _res: Arc<Namespace>, _ctx: Arc<()>) -> Result<TaskStatus, DummyError> {
            Err(DummyError)
        }
    }

    mod dag {
        use super::*;

        mod run {
            use super::*;

            #[tokio::test]
            async fn when_empty() {
                let dag: Dag<(), Infallible, Namespace, DummyTask> = Dag::default();
                let action = dag
                    .run(Arc::new(Namespace::default()), Arc::new(()))
                    .await
                    .unwrap();
                assert_eq!(action, Action::await_change());
            }

            #[tokio::test]
            async fn when_first_group_in_progress() {
                let delay = Duration::from_secs(10);
                let dag = DagBuilder::new()
                    .requeue_delay(delay)
                    .start_with([
                        DummyTask(TaskStatus::InProgress),
                        DummyTask(TaskStatus::Completed),
                    ])
                    .then([DummyTask(TaskStatus::InProgress)])
                    .build();
                let action = dag
                    .run(Arc::new(Namespace::default()), Arc::new(()))
                    .await
                    .unwrap();
                assert_eq!(action, Action::requeue(delay));
            }

            #[tokio::test]
            async fn when_second_group_in_progress() {
                let delay = Duration::from_secs(10);
                let dag = DagBuilder::new()
                    .requeue_delay(delay)
                    .start_with([
                        DummyTask(TaskStatus::Completed),
                        DummyTask(TaskStatus::Completed),
                    ])
                    .then([DummyTask(TaskStatus::InProgress)])
                    .build();
                let action = dag
                    .run(Arc::new(Namespace::default()), Arc::new(()))
                    .await
                    .unwrap();
                assert_eq!(action, Action::requeue(delay));
            }

            #[tokio::test]
            async fn when_completed() {
                let expected = Action::await_change();
                let dag = DagBuilder::new()
                    .action(expected.clone())
                    .start_with([
                        DummyTask(TaskStatus::Completed),
                        DummyTask(TaskStatus::Completed),
                    ])
                    .then([DummyTask(TaskStatus::Completed)])
                    .build();
                let action = dag
                    .run(Arc::new(Namespace::default()), Arc::new(()))
                    .await
                    .unwrap();
                assert_eq!(action, expected);
            }

            #[tokio::test]
            async fn when_task_fails() {
                let dag = DagBuilder::new()
                    .start_with([FailingTask])
                    .then([FailingTask])
                    .build();
                let result = dag
                    .run(Arc::new(Namespace::default()), Arc::new(()))
                    .await;
                assert!(matches!(result, Err(DummyError)));
            }
        }
    }

    mod dag_reconciler {
        use super::*;

        #[tokio::test]
        async fn when_no_dag_configured() {
            let reconciler: DagReconciler<(), Infallible, Namespace, DummyTask> =
                DagReconciler::new();
            let res = Arc::new(Namespace::default());
            let ctx = Arc::new(());
            let action = reconciler
                .reconcile_creation_or_update(res.clone(), ctx.clone())
                .await
                .unwrap();
            assert_eq!(action, Action::await_change());
            let action = reconciler.reconcile_deletion(res, ctx).await.unwrap();
            assert_eq!(action, Action::await_change());
        }

        #[tokio::test]
        async fn when_delete_dag_configured() {
            let delay = Duration::from_secs(10);
            let dag = DagBuilder::new()
                .requeue_delay(delay)
                .start_with([DummyTask(TaskStatus::InProgress)])
                .build();
            let reconciler: DagReconciler<(), Infallible, Namespace, DummyTask> =
                DagReconciler::new().on_delete(dag);
            let res = Arc::new(Namespace::default());
            let ctx = Arc::new(());
            let action = reconciler
                .reconcile_deletion(res.clone(), ctx.clone())
                .await
                .unwrap();
            assert_eq!(action, Action::requeue(delay));
            let action = reconciler
                .reconcile_creation_or_update(res, ctx)
                .await
                .unwrap();
            assert_eq!(action, Action::await_change());
        }
    }
}