Skip to main content

worldinterface_coordinator/
handler.rs

//! Multiplexed handler — routes Coordinator and Step tasks.
//!
//! ActionQueue takes a single `ExecutorHandler`. This module provides
//! `FlowHandler`, which inspects the task payload to determine the task
//! type and routes to the appropriate execution path.
6
7use std::sync::Arc;
8
9use actionqueue_executor_local::handler::{ExecutorContext, ExecutorHandler, HandlerOutput};
10use serde::Deserialize;
11use worldinterface_connector::ConnectorRegistry;
12use worldinterface_contextstore::ContextStore;
13use worldinterface_core::metrics::MetricsRecorder;
14use worldinterface_flowspec::payload::{CoordinatorPayload, StepPayload, TaskType};
15
16/// Minimal struct for routing — only deserializes the task_type discriminator.
17#[derive(Deserialize)]
18struct PayloadPeek {
19    task_type: TaskType,
20}
21
22/// The top-level handler registered with ActionQueue.
23///
24/// Routes execution based on the `TaskType` discriminator in the payload.
25pub struct FlowHandler<S: ContextStore> {
26    registry: Arc<ConnectorRegistry>,
27    store: Arc<S>,
28    metrics: Arc<dyn MetricsRecorder>,
29}
30
31impl<S: ContextStore> FlowHandler<S> {
32    pub fn new(
33        registry: Arc<ConnectorRegistry>,
34        store: Arc<S>,
35        metrics: Arc<dyn MetricsRecorder>,
36    ) -> Self {
37        Self { registry, store, metrics }
38    }
39}
40
41impl<S: ContextStore + 'static> ExecutorHandler for FlowHandler<S> {
42    fn execute(&self, ctx: ExecutorContext) -> HandlerOutput {
43        // 1. Peek at the task_type discriminator
44        let peek: PayloadPeek = match serde_json::from_slice(&ctx.input.payload) {
45            Ok(p) => p,
46            Err(e) => {
47                return HandlerOutput::terminal_failure(format!(
48                    "payload deserialization failed: {e}"
49                ));
50            }
51        };
52
53        // 2. Route based on task_type
54        match peek.task_type {
55            TaskType::Coordinator => {
56                let payload: CoordinatorPayload = match serde_json::from_slice(&ctx.input.payload) {
57                    Ok(p) => p,
58                    Err(e) => {
59                        return HandlerOutput::terminal_failure(format!(
60                            "coordinator payload deserialization failed: {e}"
61                        ));
62                    }
63                };
64                crate::coordinator::execute_coordinator(&ctx, &payload, self.store.as_ref())
65            }
66            TaskType::Step => {
67                let payload: StepPayload = match serde_json::from_slice(&ctx.input.payload) {
68                    Ok(p) => p,
69                    Err(e) => {
70                        return HandlerOutput::terminal_failure(format!(
71                            "step payload deserialization failed: {e}"
72                        ));
73                    }
74                };
75                crate::step::execute_step(
76                    &ctx,
77                    &payload,
78                    &self.registry,
79                    &self.store,
80                    &*self.metrics,
81                )
82            }
83        }
84    }
85}
86
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::Mutex;

    use actionqueue_core::ids::{AttemptId, RunId, TaskId};
    use actionqueue_core::task::safety::SafetyLevel;
    use actionqueue_core::task::task_spec::TaskSpec;
    use actionqueue_executor_local::handler::{
        AttemptMetadata, CancellationContext, HandlerInput, TaskSubmissionPort,
    };
    use serde_json::json;
    use worldinterface_connector::connectors::default_registry;
    use worldinterface_contextstore::SqliteContextStore;
    use worldinterface_core::flowspec::{ConnectorNode, FlowSpec, Node, NodeType};
    use worldinterface_core::id::{FlowRunId, NodeId};
    use worldinterface_core::metrics::NoopMetricsRecorder;
    use worldinterface_flowspec::id::derive_task_id;
    use worldinterface_flowspec::payload::{StepPayload, TaskType};

    use super::*;

    /// Submission port double that records every submitted task.
    struct TestSubmissionPort {
        submissions: Mutex<Vec<(TaskSpec, Vec<TaskId>)>>,
    }

    impl TestSubmissionPort {
        /// Creates a port with no recorded submissions.
        fn new() -> Self {
            Self { submissions: Mutex::default() }
        }

        /// Number of `submit` calls recorded so far.
        fn submitted_count(&self) -> usize {
            let recorded = self.submissions.lock().unwrap();
            recorded.len()
        }
    }

    impl TaskSubmissionPort for TestSubmissionPort {
        fn submit(&self, task_spec: TaskSpec, dependencies: Vec<TaskId>) {
            let mut recorded = self.submissions.lock().unwrap();
            recorded.push((task_spec, dependencies));
        }
    }

    /// Handler over an in-memory store, the default registry and no-op
    /// metrics; the store is returned as well so tests can inspect it.
    fn make_handler() -> (FlowHandler<SqliteContextStore>, Arc<SqliteContextStore>) {
        let store = Arc::new(SqliteContextStore::in_memory().unwrap());
        let registry = Arc::new(default_registry());
        let recorder: Arc<dyn MetricsRecorder> = Arc::new(NoopMetricsRecorder);
        (FlowHandler::new(registry, Arc::clone(&store), recorder), store)
    }

    /// First-attempt context around `payload_bytes`, with no submission port
    /// and no children snapshot.
    fn make_ctx_with_payload(payload_bytes: Vec<u8>) -> ExecutorContext {
        let input = HandlerInput {
            run_id: RunId::new(),
            attempt_id: AttemptId::new(),
            payload: payload_bytes,
            metadata: AttemptMetadata {
                max_attempts: 3,
                attempt_number: 1,
                timeout_secs: None,
                safety_level: SafetyLevel::Idempotent,
            },
            cancellation_context: CancellationContext::new(),
        };
        ExecutorContext { input, submission: None, children: None }
    }

    /// The `delay` connector node (10 ms) used by the routing tests.
    fn delay_connector() -> NodeType {
        NodeType::Connector(ConnectorNode {
            connector: "delay".into(),
            params: json!({"duration_ms": 10}),
            idempotency_config: None,
        })
    }

    // T-1: Multiplexed Handler Routing

    #[test]
    fn routes_step_payload() {
        let (handler, store) = make_handler();
        let flow_run_id = FlowRunId::new();
        let node_id = NodeId::new();

        let step = StepPayload {
            task_type: TaskType::Step,
            flow_run_id,
            node_id,
            node_type: delay_connector(),
            flow_params: None,
        };
        let ctx = make_ctx_with_payload(serde_json::to_vec(&step).unwrap());

        let outcome = handler.execute(ctx);

        assert!(matches!(outcome, HandlerOutput::Success { .. }));
        assert!(store.get(flow_run_id, node_id).unwrap().is_some());
    }

    #[test]
    fn routes_coordinator_payload() {
        let (handler, _store) = make_handler();

        // A coordinator with no children snapshot suspends on first dispatch.
        let flow_run_id = FlowRunId::new();
        let node_id = NodeId::new();
        let flow_spec = FlowSpec {
            id: None,
            name: None,
            nodes: vec![Node { id: node_id, label: None, node_type: delay_connector() }],
            edges: vec![],
            params: None,
        };

        let task_id = derive_task_id(flow_run_id, node_id);
        let coordinator = CoordinatorPayload {
            task_type: TaskType::Coordinator,
            flow_spec,
            flow_run_id,
            node_task_map: HashMap::from([(node_id, task_id)]),
            dependencies: HashMap::new(),
        };
        let bytes = serde_json::to_vec(&coordinator).unwrap();

        // The coordinator needs a submission port to hand its step task to.
        let submission = Arc::new(TestSubmissionPort::new());
        let input = HandlerInput {
            run_id: RunId::new(),
            attempt_id: AttemptId::new(),
            payload: bytes,
            metadata: AttemptMetadata {
                max_attempts: 1,
                attempt_number: 1,
                timeout_secs: None,
                safety_level: SafetyLevel::Pure,
            },
            cancellation_context: CancellationContext::new(),
        };
        let ctx = ExecutorContext { input, submission: Some(submission.clone()), children: None };

        let outcome = handler.execute(ctx);

        // Coordinator should submit the step and suspend.
        assert!(matches!(outcome, HandlerOutput::Suspended { .. }));
        assert_eq!(submission.submitted_count(), 1);
    }

    #[test]
    fn rejects_corrupt_payload() {
        let (handler, _) = make_handler();
        let garbage = b"not valid json".to_vec();

        let outcome = handler.execute(make_ctx_with_payload(garbage));

        assert!(matches!(outcome, HandlerOutput::TerminalFailure { .. }));
    }

    #[test]
    fn rejects_unknown_task_type() {
        let (handler, _) = make_handler();
        let bytes = serde_json::to_vec(&json!({"task_type": "unknown"})).unwrap();

        let outcome = handler.execute(make_ctx_with_payload(bytes));

        assert!(matches!(outcome, HandlerOutput::TerminalFailure { .. }));
    }
}