worldinterface_coordinator/
handler.rs1use std::sync::Arc;
8
9use actionqueue_executor_local::handler::{ExecutorContext, ExecutorHandler, HandlerOutput};
10use serde::Deserialize;
11use worldinterface_connector::ConnectorRegistry;
12use worldinterface_contextstore::ContextStore;
13use worldinterface_core::metrics::MetricsRecorder;
14use worldinterface_flowspec::payload::{CoordinatorPayload, StepPayload, TaskType};
15
16#[derive(Deserialize)]
18struct PayloadPeek {
19 task_type: TaskType,
20}
21
22pub struct FlowHandler<S: ContextStore> {
26 registry: Arc<ConnectorRegistry>,
27 store: Arc<S>,
28 metrics: Arc<dyn MetricsRecorder>,
29}
30
31impl<S: ContextStore> FlowHandler<S> {
32 pub fn new(
33 registry: Arc<ConnectorRegistry>,
34 store: Arc<S>,
35 metrics: Arc<dyn MetricsRecorder>,
36 ) -> Self {
37 Self { registry, store, metrics }
38 }
39}
40
41impl<S: ContextStore + 'static> ExecutorHandler for FlowHandler<S> {
42 fn execute(&self, ctx: ExecutorContext) -> HandlerOutput {
43 let peek: PayloadPeek = match serde_json::from_slice(&ctx.input.payload) {
45 Ok(p) => p,
46 Err(e) => {
47 return HandlerOutput::terminal_failure(format!(
48 "payload deserialization failed: {e}"
49 ));
50 }
51 };
52
53 match peek.task_type {
55 TaskType::Coordinator => {
56 let payload: CoordinatorPayload = match serde_json::from_slice(&ctx.input.payload) {
57 Ok(p) => p,
58 Err(e) => {
59 return HandlerOutput::terminal_failure(format!(
60 "coordinator payload deserialization failed: {e}"
61 ));
62 }
63 };
64 crate::coordinator::execute_coordinator(&ctx, &payload, self.store.as_ref())
65 }
66 TaskType::Step => {
67 let payload: StepPayload = match serde_json::from_slice(&ctx.input.payload) {
68 Ok(p) => p,
69 Err(e) => {
70 return HandlerOutput::terminal_failure(format!(
71 "step payload deserialization failed: {e}"
72 ));
73 }
74 };
75 crate::step::execute_step(
76 &ctx,
77 &payload,
78 &self.registry,
79 &self.store,
80 &*self.metrics,
81 )
82 }
83 }
84 }
85}
86
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::Mutex;

    use actionqueue_core::ids::{AttemptId, RunId, TaskId};
    use actionqueue_core::task::safety::SafetyLevel;
    use actionqueue_core::task::task_spec::TaskSpec;
    use actionqueue_executor_local::handler::{
        AttemptMetadata, CancellationContext, HandlerInput, TaskSubmissionPort,
    };
    use serde_json::json;
    use worldinterface_connector::connectors::default_registry;
    use worldinterface_contextstore::SqliteContextStore;
    use worldinterface_core::flowspec::{ConnectorNode, FlowSpec, Node, NodeType};
    use worldinterface_core::id::{FlowRunId, NodeId};
    use worldinterface_core::metrics::NoopMetricsRecorder;
    use worldinterface_flowspec::id::derive_task_id;
    use worldinterface_flowspec::payload::{CoordinatorPayload, StepPayload, TaskType};

    use super::*;

    /// Builds a handler backed by an in-memory SQLite store, the default
    /// connector registry, and a no-op metrics recorder. The store is returned
    /// as well, so tests can inspect what the handler wrote.
    fn make_handler() -> (FlowHandler<SqliteContextStore>, Arc<SqliteContextStore>) {
        let shared_store = Arc::new(SqliteContextStore::in_memory().unwrap());
        let recorder: Arc<dyn MetricsRecorder> = Arc::new(NoopMetricsRecorder);
        let handler =
            FlowHandler::new(Arc::new(default_registry()), Arc::clone(&shared_store), recorder);
        (handler, shared_store)
    }

    /// Wraps raw payload bytes in an executor context for attempt 1 of 3.
    ///
    /// The context has idempotent safety, no timeout, no submission port, and
    /// no children.
    fn make_ctx_with_payload(payload_bytes: Vec<u8>) -> ExecutorContext {
        let metadata = AttemptMetadata {
            max_attempts: 3,
            attempt_number: 1,
            timeout_secs: None,
            safety_level: SafetyLevel::Idempotent,
        };
        let input = HandlerInput {
            run_id: RunId::new(),
            attempt_id: AttemptId::new(),
            payload: payload_bytes,
            metadata,
            cancellation_context: CancellationContext::new(),
        };
        ExecutorContext { input, submission: None, children: None }
    }

    /// Connector node type running the built-in `"delay"` connector with
    /// `duration_ms: 10`.
    fn delay_node() -> NodeType {
        NodeType::Connector(ConnectorNode {
            connector: "delay".into(),
            params: json!({"duration_ms": 10}),
            idempotency_config: None,
        })
    }

    /// A step payload is executed directly: the handler succeeds and the
    /// node's output is present in the context store afterwards.
    #[test]
    fn routes_step_payload() {
        let (handler, store) = make_handler();
        let (flow_run, node) = (FlowRunId::new(), NodeId::new());

        let step = StepPayload {
            task_type: TaskType::Step,
            flow_run_id: flow_run,
            node_id: node,
            node_type: delay_node(),
            flow_params: None,
        };
        let ctx = make_ctx_with_payload(serde_json::to_vec(&step).unwrap());

        let outcome = handler.execute(ctx);
        assert!(matches!(outcome, HandlerOutput::Success { .. }));
        assert!(store.get(flow_run, node).unwrap().is_some());
    }

    /// A coordinator payload for a one-node flow suspends the coordinator and
    /// submits exactly one task through the submission port.
    #[test]
    fn routes_coordinator_payload() {
        let (handler, _store) = make_handler();

        let flow_run = FlowRunId::new();
        let node = NodeId::new();
        let flow_spec = FlowSpec {
            id: None,
            name: None,
            nodes: vec![Node { id: node, label: None, node_type: delay_node() }],
            edges: vec![],
            params: None,
        };

        let coordinator = CoordinatorPayload {
            task_type: TaskType::Coordinator,
            flow_spec,
            flow_run_id: flow_run,
            node_task_map: HashMap::from([(node, derive_task_id(flow_run, node))]),
            dependencies: HashMap::new(),
        };

        let port = Arc::new(TestSubmissionPort::new());
        let ctx = ExecutorContext {
            input: HandlerInput {
                run_id: RunId::new(),
                attempt_id: AttemptId::new(),
                payload: serde_json::to_vec(&coordinator).unwrap(),
                metadata: AttemptMetadata {
                    max_attempts: 1,
                    attempt_number: 1,
                    timeout_secs: None,
                    safety_level: SafetyLevel::Pure,
                },
                cancellation_context: CancellationContext::new(),
            },
            // Method-call clone keeps `Arc<TestSubmissionPort>` as the source
            // type so it can coerce to the trait-object `Arc` the field expects.
            submission: Some(port.clone()),
            children: None,
        };

        let outcome = handler.execute(ctx);
        assert!(matches!(outcome, HandlerOutput::Suspended { .. }));
        assert_eq!(port.submitted_count(), 1);
    }

    /// Bytes that are not JSON at all fail terminally.
    #[test]
    fn rejects_corrupt_payload() {
        let (handler, _) = make_handler();
        let outcome = handler.execute(make_ctx_with_payload(b"not valid json".to_vec()));
        assert!(matches!(outcome, HandlerOutput::TerminalFailure { .. }));
    }

    /// Well-formed JSON with an unrecognized `task_type` fails terminally.
    #[test]
    fn rejects_unknown_task_type() {
        let (handler, _) = make_handler();
        let bogus = serde_json::to_vec(&json!({"task_type": "unknown"})).unwrap();
        let outcome = handler.execute(make_ctx_with_payload(bogus));
        assert!(matches!(outcome, HandlerOutput::TerminalFailure { .. }));
    }

    /// Submission port that records every submitted task so tests can count them.
    struct TestSubmissionPort {
        /// One entry per `submit` call: the task spec and its dependency ids.
        submissions: Mutex<Vec<(TaskSpec, Vec<TaskId>)>>,
    }

    impl TestSubmissionPort {
        fn new() -> Self {
            TestSubmissionPort { submissions: Mutex::default() }
        }

        fn submitted_count(&self) -> usize {
            let recorded = self.submissions.lock().unwrap();
            recorded.len()
        }
    }

    impl TaskSubmissionPort for TestSubmissionPort {
        fn submit(&self, task_spec: TaskSpec, dependencies: Vec<TaskId>) {
            self.submissions.lock().unwrap().push((task_spec, dependencies));
        }
    }
}