1use std::collections::HashMap;
40use std::time::Duration;
41
42use reqwest::{header::HeaderMap, Client, Method, StatusCode};
43use serde::{Deserialize, Serialize};
44use tokio::time::{sleep, Instant};
45
46use crate::agents::{bounded_read, AgentsError, AgentsResult, MAX_RESPONSE_SIZE};
47
/// Status strings after which a workflow run is treated as finished by this module.
pub const TERMINAL_WORKFLOW_STATUSES: &[&str] = &["succeeded", "failed", "timed_out"];

/// Returns `true` when `status` is one of [`TERMINAL_WORKFLOW_STATUSES`].
///
/// The comparison is exact: it is case-sensitive and performs no trimming.
pub fn is_workflow_run_terminal(status: &str) -> bool {
    TERMINAL_WORKFLOW_STATUSES
        .iter()
        .any(|terminal| *terminal == status)
}
56
57#[derive(Debug, Clone, Serialize, Deserialize)]
64pub struct Workflow {
65 pub id: String,
66 pub tenant_id: String,
67 pub name: String,
68 #[serde(default)]
69 pub description: Option<String>,
70 pub start_at: String,
71 pub states: HashMap<String, serde_json::Value>,
72 pub version: String,
73 pub created_at: u64,
74 pub updated_at: u64,
75}
76
77#[derive(Debug, Clone, Serialize, Default)]
78pub struct CreateWorkflowRequest {
79 pub name: String,
80 pub start_at: String,
81 pub states: HashMap<String, serde_json::Value>,
82 #[serde(skip_serializing_if = "Option::is_none")]
83 pub description: Option<String>,
84}
85
86#[derive(Debug, Clone, Deserialize)]
87pub struct WorkflowListResponse {
88 #[serde(default)]
89 pub workflows: Vec<Workflow>,
90}
91
92#[derive(Debug, Clone, Deserialize)]
94pub struct StartRunResponse {
95 pub execution_id: String,
96 pub workflow_id: String,
97 pub status: String,
98}
99
100#[derive(Debug, Clone, Deserialize)]
104pub struct WorkflowExecution {
105 pub id: String,
106 pub workflow_id: String,
107 pub tenant_id: String,
108 pub status: String,
109 #[serde(default)]
110 pub current_state: Option<String>,
111 #[serde(default)]
112 pub input: serde_json::Value,
113 #[serde(default)]
114 pub output: Option<serde_json::Value>,
115 pub started_at: u64,
116 #[serde(default)]
117 pub ended_at: Option<u64>,
118 #[serde(default)]
119 pub error: Option<String>,
120}
121
122impl WorkflowExecution {
123 pub fn is_terminal(&self) -> bool {
124 is_workflow_run_terminal(&self.status)
125 }
126
127 pub fn succeeded(&self) -> bool {
128 self.status == "succeeded"
129 }
130}
131
132#[derive(Debug, Deserialize)]
135struct WorkflowEnvelope {
136 workflow: Workflow,
137}
138
139#[derive(Debug, Deserialize)]
140struct ExecutionEnvelope {
141 execution: WorkflowExecution,
142}
143
144pub struct WorkflowsClient {
147 base_url: String,
148 api_key: Option<String>,
149 tenant: Option<String>,
150 client: Client,
151}
152
153impl WorkflowsClient {
154 pub fn new(base_url: impl Into<String>) -> Self {
155 Self {
156 base_url: base_url.into().trim_end_matches('/').to_string(),
157 api_key: None,
158 tenant: None,
159 client: Client::builder()
160 .timeout(Duration::from_secs(120))
161 .build()
162 .expect("reqwest client"),
163 }
164 }
165
166 pub fn with_api_key(mut self, key: impl Into<String>) -> Self {
167 self.api_key = Some(key.into());
168 self
169 }
170
171 pub fn with_tenant(mut self, tenant: impl Into<String>) -> Self {
172 self.tenant = Some(tenant.into());
173 self
174 }
175
176 fn headers(&self) -> HeaderMap {
177 let mut h = HeaderMap::new();
178 h.insert("Content-Type", "application/json".parse().unwrap());
179 if let Some(key) = &self.api_key {
180 if let Ok(val) = format!("Bearer {key}").parse() {
181 h.insert("Authorization", val);
182 }
183 }
184 if let Some(t) = &self.tenant {
185 if let Ok(val) = t.parse() {
186 h.insert("x-rapidapi-user", val);
187 }
188 }
189 h
190 }
191
192 async fn request<T: for<'de> Deserialize<'de>>(
193 &self,
194 method: Method,
195 path: &str,
196 body: Option<&impl Serialize>,
197 ) -> AgentsResult<Option<T>> {
198 let url = format!("{}{}", self.base_url, path);
199 let mut req = self.client.request(method, &url).headers(self.headers());
200 if let Some(b) = body {
201 req = req.json(b);
202 }
203 let resp = req.send().await?;
204 let status = resp.status();
205 if status == StatusCode::NO_CONTENT {
206 return Ok(None);
207 }
208 let bytes = bounded_read(resp, MAX_RESPONSE_SIZE).await?;
209 if !status.is_success() {
210 let body = String::from_utf8_lossy(&bytes).into_owned();
211 return Err(AgentsError::Status {
212 status: status.as_u16(),
213 body,
214 });
215 }
216 if bytes.is_empty() {
217 return Ok(None);
218 }
219 Ok(Some(serde_json::from_slice(&bytes)?))
220 }
221
222 pub async fn create(&self, req: CreateWorkflowRequest) -> AgentsResult<Workflow> {
229 let env: WorkflowEnvelope = self
230 .request::<WorkflowEnvelope>(Method::POST, "/v1/workflows", Some(&req))
231 .await?
232 .ok_or_else(|| {
233 AgentsError::InvalidInput("server returned empty body for create".into())
234 })?;
235 Ok(env.workflow)
236 }
237
238 pub async fn list(&self) -> AgentsResult<WorkflowListResponse> {
240 self.request::<WorkflowListResponse>(Method::GET, "/v1/workflows", Option::<&()>::None)
241 .await
242 .map(|o| o.unwrap_or(WorkflowListResponse { workflows: vec![] }))
243 }
244
245 pub async fn get(&self, workflow_id: &str) -> AgentsResult<Workflow> {
247 let env: WorkflowEnvelope = self
248 .request::<WorkflowEnvelope>(
249 Method::GET,
250 &format!("/v1/workflows/{workflow_id}"),
251 Option::<&()>::None,
252 )
253 .await?
254 .ok_or_else(|| {
255 AgentsError::InvalidInput("server returned empty body for get".into())
256 })?;
257 Ok(env.workflow)
258 }
259
260 pub async fn delete(&self, workflow_id: &str) -> AgentsResult<()> {
262 let _: Option<serde_json::Value> = self
263 .request(
264 Method::DELETE,
265 &format!("/v1/workflows/{workflow_id}"),
266 Option::<&()>::None,
267 )
268 .await?;
269 Ok(())
270 }
271
272 pub async fn start_run(
279 &self,
280 workflow_id: &str,
281 input: Option<serde_json::Value>,
282 ) -> AgentsResult<StartRunResponse> {
283 #[derive(Serialize)]
284 struct Body {
285 #[serde(skip_serializing_if = "Option::is_none")]
286 input: Option<serde_json::Value>,
287 }
288 self.request::<StartRunResponse>(
289 Method::POST,
290 &format!("/v1/workflows/{workflow_id}/runs"),
291 Some(&Body { input }),
292 )
293 .await?
294 .ok_or_else(|| AgentsError::InvalidInput("server returned empty body for start_run".into()))
295 }
296
297 pub async fn get_run(
299 &self,
300 workflow_id: &str,
301 execution_id: &str,
302 ) -> AgentsResult<WorkflowExecution> {
303 let env: ExecutionEnvelope = self
304 .request::<ExecutionEnvelope>(
305 Method::GET,
306 &format!("/v1/workflows/{workflow_id}/runs/{execution_id}"),
307 Option::<&()>::None,
308 )
309 .await?
310 .ok_or_else(|| {
311 AgentsError::InvalidInput("server returned empty body for get_run".into())
312 })?;
313 Ok(env.execution)
314 }
315
316 pub async fn wait_for_run(
324 &self,
325 workflow_id: &str,
326 execution_id: &str,
327 timeout: Option<Duration>,
328 poll_interval: Option<Duration>,
329 ) -> AgentsResult<WorkflowExecution> {
330 let timeout = timeout.unwrap_or_else(|| Duration::from_secs(90));
331 let poll_interval = poll_interval.unwrap_or_else(|| Duration::from_secs(1));
332 let deadline = Instant::now() + timeout;
333 loop {
334 let execution = self.get_run(workflow_id, execution_id).await?;
335 if execution.is_terminal() {
336 return Ok(execution);
337 }
338 if Instant::now() >= deadline {
339 return Err(AgentsError::InvalidInput(format!(
340 "workflow run {execution_id} did not terminate within {timeout:?} \
341 (last status: {})",
342 execution.status
343 )));
344 }
345 sleep(poll_interval).await;
346 }
347 }
348}
349
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal execution for workflow `wf_1` / tenant `t1`; callers
    /// override the optional fields they care about with struct-update syntax.
    fn execution(id: &str, status: &str) -> WorkflowExecution {
        WorkflowExecution {
            id: id.to_owned(),
            workflow_id: "wf_1".to_owned(),
            tenant_id: "t1".to_owned(),
            status: status.to_owned(),
            current_state: None,
            input: serde_json::Value::Null,
            output: None,
            started_at: 100,
            ended_at: None,
            error: None,
        }
    }

    // The three terminal statuses are recognised; in-flight statuses are not.
    #[test]
    fn test_is_terminal() {
        for status in ["succeeded", "failed", "timed_out"] {
            assert!(is_workflow_run_terminal(status));
        }
        for status in ["running", "queued"] {
            assert!(!is_workflow_run_terminal(status));
        }
    }

    // `is_terminal` / `succeeded` follow the `status` field only.
    #[test]
    fn test_execution_terminal_helpers() {
        let succeeded = WorkflowExecution {
            output: Some(serde_json::json!({"ok": true})),
            ended_at: Some(110),
            ..execution("wfr_1", "succeeded")
        };
        assert!(succeeded.is_terminal());
        assert!(succeeded.succeeded());

        let running = WorkflowExecution {
            current_state: Some("Compute".to_owned()),
            ..execution("wfr_2", "running")
        };
        assert!(!running.is_terminal());
        assert!(!running.succeeded());
    }

    // A workflow without `description` deserializes, leaving the field `None`.
    #[test]
    fn test_workflow_deserialize() {
        let raw = r#"{
            "id": "wf_1",
            "tenant_id": "t1",
            "name": "triage",
            "start_at": "Compute",
            "states": { "Compute": { "type": "Succeed" } },
            "version": "1.0",
            "created_at": 100,
            "updated_at": 200
        }"#;
        let wf: Workflow = serde_json::from_str(raw).unwrap();
        assert_eq!(wf.id, "wf_1");
        assert_eq!(wf.start_at, "Compute");
        assert!(wf.states.contains_key("Compute"));
        assert!(wf.description.is_none());
    }

    // The `{"workflow": ...}` wrapper exposes the inner workflow.
    #[test]
    fn test_envelope_unwrap() {
        let raw = r#"{
            "workflow": {
                "id": "wf_1",
                "tenant_id": "t1",
                "name": "triage",
                "start_at": "S",
                "states": { "S": { "type": "Succeed" } },
                "version": "1.0",
                "created_at": 1,
                "updated_at": 2
            }
        }"#;
        let env: WorkflowEnvelope = serde_json::from_str(raw).unwrap();
        assert_eq!(env.workflow.id, "wf_1");
    }

    // The builder trims the trailing slash and emits both auth-related headers.
    #[test]
    fn test_client_construction() {
        let client = WorkflowsClient::new("http://localhost:3000/")
            .with_api_key("k")
            .with_tenant("t");
        assert_eq!(client.base_url, "http://localhost:3000");

        let headers = client.headers();
        assert_eq!(headers.get("Authorization").unwrap(), "Bearer k");
        assert_eq!(headers.get("x-rapidapi-user").unwrap(), "t");
    }
}