workflow_graph_worker_sdk/
lib.rs1pub mod executor;
2
3use std::time::Duration;
4
5use serde::{Deserialize, Serialize};
6use workflow_graph_queue::traits::*;
7
/// Settings that control how a [`Worker`] reaches the server and paces its
/// polling, heartbeats, cancellation checks and log batching.
#[derive(Debug, Clone)]
pub struct WorkerConfig {
    /// Base URL of the server; API paths such as `/api/jobs/claim` are appended to it.
    pub server_url: String,
    /// Identifier the worker registers and claims jobs under.
    pub worker_id: String,
    /// Labels sent with registration and claim requests
    /// (presumably matched against job requirements server-side — confirm).
    pub labels: Vec<String>,
    /// Lease length requested on each claim; transmitted in whole seconds.
    pub lease_ttl: Duration,
    /// Wait between claim attempts after an empty claim or an error.
    pub poll_interval: Duration,
    /// Wait between lease heartbeats while a job is running.
    pub heartbeat_interval: Duration,
    /// Wait between checks of whether the running job was cancelled.
    pub cancellation_check_interval: Duration,
    /// Log batching interval handed to the job executor.
    pub log_batch_interval: Duration,
}
20
21impl Default for WorkerConfig {
22 fn default() -> Self {
23 Self {
24 server_url: "http://localhost:3000".into(),
25 worker_id: uuid::Uuid::new_v4().to_string(),
26 labels: vec![],
27 lease_ttl: Duration::from_secs(30),
28 poll_interval: Duration::from_secs(2),
29 heartbeat_interval: Duration::from_secs(10),
30 cancellation_check_interval: Duration::from_secs(2),
31 log_batch_interval: Duration::from_millis(500),
32 }
33 }
34}
35
36#[derive(Serialize)]
37struct RegisterRequest {
38 worker_id: String,
39 labels: Vec<String>,
40}
41
42#[derive(Serialize)]
43struct ClaimRequest {
44 worker_id: String,
45 labels: Vec<String>,
46 lease_ttl_secs: u64,
47}
48
49#[derive(Deserialize)]
50struct ClaimResponse {
51 job: QueuedJob,
52 lease: Lease,
53}
54
55#[derive(Serialize)]
56struct CompleteRequest {
57 outputs: std::collections::HashMap<String, String>,
58}
59
60#[derive(Serialize)]
61struct FailRequest {
62 error: String,
63 retryable: bool,
64}
65
66pub struct Worker {
68 config: WorkerConfig,
69 client: reqwest::Client,
70}
71
72impl Worker {
73 pub fn new(config: WorkerConfig) -> Self {
74 Self {
75 config,
76 client: reqwest::Client::new(),
77 }
78 }
79
80 pub async fn run(&self) -> Result<(), Box<dyn std::error::Error>> {
83 self.register().await?;
84 println!(
85 "Worker {} registered with labels {:?}",
86 self.config.worker_id, self.config.labels
87 );
88
89 let shutdown = tokio::signal::ctrl_c();
90 tokio::pin!(shutdown);
91
92 loop {
93 tokio::select! {
94 biased;
95 _ = &mut shutdown => {
96 println!("Received shutdown signal, finishing current work...");
97 break;
98 }
99 result = self.poll_and_execute() => {
100 match result {
101 Ok(true) => {} Ok(false) => {
103 tokio::time::sleep(self.config.poll_interval).await;
105 }
106 Err(e) => {
107 eprintln!("Worker error: {e}");
108 tokio::time::sleep(self.config.poll_interval).await;
109 }
110 }
111 }
112 }
113 }
114
115 println!("Worker {} shutting down gracefully", self.config.worker_id);
116 Ok(())
117 }
118
119 async fn register(&self) -> Result<(), Box<dyn std::error::Error>> {
120 self.client
121 .post(format!("{}/api/workers/register", self.config.server_url))
122 .json(&RegisterRequest {
123 worker_id: self.config.worker_id.clone(),
124 labels: self.config.labels.clone(),
125 })
126 .send()
127 .await?;
128 Ok(())
129 }
130
131 async fn poll_and_execute(&self) -> Result<bool, Box<dyn std::error::Error>> {
133 let response = self
135 .client
136 .post(format!("{}/api/jobs/claim", self.config.server_url))
137 .json(&ClaimRequest {
138 worker_id: self.config.worker_id.clone(),
139 labels: self.config.labels.clone(),
140 lease_ttl_secs: self.config.lease_ttl.as_secs(),
141 })
142 .send()
143 .await?;
144
145 let claim: Option<ClaimResponse> = response.json().await?;
146 let Some(claim) = claim else {
147 return Ok(false);
148 };
149
150 println!(
151 "Claimed job {} (workflow {})",
152 claim.job.job_id, claim.job.workflow_id
153 );
154
155 self.execute_job(&claim.job, &claim.lease).await;
157
158 Ok(true)
159 }
160
161 async fn execute_job(&self, job: &QueuedJob, lease: &Lease) {
162 let lease_id = lease.lease_id.clone();
163 let workflow_id = job.workflow_id.clone();
164 let job_id = job.job_id.clone();
165
166 let hb_client = self.client.clone();
168 let hb_url = format!("{}/api/jobs/{}/heartbeat", self.config.server_url, lease_id);
169 let hb_interval = self.config.heartbeat_interval;
170 let hb_handle = tokio::spawn(async move {
171 loop {
172 tokio::time::sleep(hb_interval).await;
173 let res = hb_client.post(&hb_url).send().await;
174 if let Ok(resp) = res
175 && resp.status() == reqwest::StatusCode::CONFLICT
176 {
177 eprintln!("Lease expired, aborting heartbeat");
178 break;
179 }
180 }
181 });
182
183 let cancel_client = self.client.clone();
185 let cancel_url = format!(
186 "{}/api/jobs/{}/{}/cancelled",
187 self.config.server_url, workflow_id, job_id
188 );
189 let cancel_interval = self.config.cancellation_check_interval;
190 let cancel_token = tokio_util::sync::CancellationToken::new();
191 let cancel_token_clone = cancel_token.clone();
192 let cancel_handle = tokio::spawn(async move {
193 loop {
194 tokio::time::sleep(cancel_interval).await;
195 if let Ok(resp) = cancel_client.get(&cancel_url).send().await
196 && let Ok(cancelled) = resp.json::<bool>().await
197 && cancelled
198 {
199 cancel_token_clone.cancel();
200 break;
201 }
202 }
203 });
204
205 let result = executor::execute_job_streaming(
207 &job.command,
208 &self.client,
209 &format!("{}/api/jobs/{}/logs", self.config.server_url, lease_id),
210 &workflow_id,
211 &job_id,
212 self.config.log_batch_interval,
213 cancel_token,
214 )
215 .await;
216
217 hb_handle.abort();
219 cancel_handle.abort();
220
221 match result {
223 Ok(output) => {
224 let url = format!("{}/api/jobs/{}/complete", self.config.server_url, lease_id);
225 self.client
226 .post(&url)
227 .json(&CompleteRequest {
228 outputs: output.outputs,
229 })
230 .send()
231 .await
232 .ok();
233 println!("Job {} completed successfully", job.job_id);
234 }
235 Err(e) => {
236 let url = format!("{}/api/jobs/{}/fail", self.config.server_url, lease_id);
237 self.client
238 .post(&url)
239 .json(&FailRequest {
240 error: e.to_string(),
241 retryable: true,
242 })
243 .send()
244 .await
245 .ok();
246 eprintln!("Job {} failed: {e}", job.job_id);
247 }
248 }
249 }
250}