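//! `workflow_graph_worker_sdk` — `lib.rs`
//!
//! Worker-side SDK: a [`Worker`] registers with the server, polls it for jobs,
//! executes them, and reports the results.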

1pub mod executor;
2
3use std::time::Duration;
4
5use serde::{Deserialize, Serialize};
6use workflow_graph_queue::traits::*;
7
/// Configuration for a worker instance.
///
/// All intervals are independent of each other; see [`Default`] for the
/// values used when nothing is overridden.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct WorkerConfig {
    /// Base URL of the server; API paths such as `/api/jobs/claim` are
    /// appended to it verbatim, so it should not end with a slash.
    pub server_url: String,
    /// Identifier sent to the server when registering and claiming jobs.
    pub worker_id: String,
    /// Labels sent to the server alongside the worker id when registering
    /// and claiming jobs.
    pub labels: Vec<String>,
    /// Lease duration requested when claiming a job. Only whole seconds are
    /// transmitted; any sub-second part is dropped.
    pub lease_ttl: Duration,
    /// How long to wait before polling again after an empty or failed poll.
    pub poll_interval: Duration,
    /// Delay between heartbeat requests while a job is running.
    pub heartbeat_interval: Duration,
    /// Delay between checks for whether the running job was cancelled.
    pub cancellation_check_interval: Duration,
    /// Interval handed to the executor for log streaming
    /// (presumably how often buffered log lines are flushed — confirm in
    /// `executor`).
    pub log_batch_interval: Duration,
}
20
21impl Default for WorkerConfig {
22    fn default() -> Self {
23        Self {
24            server_url: "http://localhost:3000".into(),
25            worker_id: uuid::Uuid::new_v4().to_string(),
26            labels: vec![],
27            lease_ttl: Duration::from_secs(30),
28            poll_interval: Duration::from_secs(2),
29            heartbeat_interval: Duration::from_secs(10),
30            cancellation_check_interval: Duration::from_secs(2),
31            log_batch_interval: Duration::from_millis(500),
32        }
33    }
34}
35
36#[derive(Serialize)]
37struct RegisterRequest {
38    worker_id: String,
39    labels: Vec<String>,
40}
41
42#[derive(Serialize)]
43struct ClaimRequest {
44    worker_id: String,
45    labels: Vec<String>,
46    lease_ttl_secs: u64,
47}
48
49#[derive(Deserialize)]
50struct ClaimResponse {
51    job: QueuedJob,
52    lease: Lease,
53}
54
55#[derive(Serialize)]
56struct CompleteRequest {
57    outputs: std::collections::HashMap<String, String>,
58}
59
60#[derive(Serialize)]
61struct FailRequest {
62    error: String,
63    retryable: bool,
64}
65
66/// A worker that polls the server for jobs and executes them.
67pub struct Worker {
68    config: WorkerConfig,
69    client: reqwest::Client,
70}
71
72impl Worker {
73    pub fn new(config: WorkerConfig) -> Self {
74        Self {
75            config,
76            client: reqwest::Client::new(),
77        }
78    }
79
80    /// Run the worker loop: register, poll for jobs, execute, report results.
81    /// Handles SIGTERM/SIGINT for graceful shutdown — finishes the current job before exiting.
82    pub async fn run(&self) -> Result<(), Box<dyn std::error::Error>> {
83        self.register().await?;
84        println!(
85            "Worker {} registered with labels {:?}",
86            self.config.worker_id, self.config.labels
87        );
88
89        let shutdown = tokio::signal::ctrl_c();
90        tokio::pin!(shutdown);
91
92        loop {
93            tokio::select! {
94                biased;
95                _ = &mut shutdown => {
96                    println!("Received shutdown signal, finishing current work...");
97                    break;
98                }
99                result = self.poll_and_execute() => {
100                    match result {
101                        Ok(true) => {} // executed a job, poll again immediately
102                        Ok(false) => {
103                            // no job available, wait before polling again
104                            tokio::time::sleep(self.config.poll_interval).await;
105                        }
106                        Err(e) => {
107                            eprintln!("Worker error: {e}");
108                            tokio::time::sleep(self.config.poll_interval).await;
109                        }
110                    }
111                }
112            }
113        }
114
115        println!("Worker {} shutting down gracefully", self.config.worker_id);
116        Ok(())
117    }
118
119    async fn register(&self) -> Result<(), Box<dyn std::error::Error>> {
120        self.client
121            .post(format!("{}/api/workers/register", self.config.server_url))
122            .json(&RegisterRequest {
123                worker_id: self.config.worker_id.clone(),
124                labels: self.config.labels.clone(),
125            })
126            .send()
127            .await?;
128        Ok(())
129    }
130
131    /// Poll for a job, execute it if available. Returns true if a job was executed.
132    async fn poll_and_execute(&self) -> Result<bool, Box<dyn std::error::Error>> {
133        // Claim a job
134        let response = self
135            .client
136            .post(format!("{}/api/jobs/claim", self.config.server_url))
137            .json(&ClaimRequest {
138                worker_id: self.config.worker_id.clone(),
139                labels: self.config.labels.clone(),
140                lease_ttl_secs: self.config.lease_ttl.as_secs(),
141            })
142            .send()
143            .await?;
144
145        let claim: Option<ClaimResponse> = response.json().await?;
146        let Some(claim) = claim else {
147            return Ok(false);
148        };
149
150        println!(
151            "Claimed job {} (workflow {})",
152            claim.job.job_id, claim.job.workflow_id
153        );
154
155        // Execute the job with concurrent heartbeat, log streaming, and cancellation checking
156        self.execute_job(&claim.job, &claim.lease).await;
157
158        Ok(true)
159    }
160
161    async fn execute_job(&self, job: &QueuedJob, lease: &Lease) {
162        let lease_id = lease.lease_id.clone();
163        let workflow_id = job.workflow_id.clone();
164        let job_id = job.job_id.clone();
165
166        // Spawn heartbeat task
167        let hb_client = self.client.clone();
168        let hb_url = format!("{}/api/jobs/{}/heartbeat", self.config.server_url, lease_id);
169        let hb_interval = self.config.heartbeat_interval;
170        let hb_handle = tokio::spawn(async move {
171            loop {
172                tokio::time::sleep(hb_interval).await;
173                let res = hb_client.post(&hb_url).send().await;
174                if let Ok(resp) = res
175                    && resp.status() == reqwest::StatusCode::CONFLICT
176                {
177                    eprintln!("Lease expired, aborting heartbeat");
178                    break;
179                }
180            }
181        });
182
183        // Spawn cancellation checker
184        let cancel_client = self.client.clone();
185        let cancel_url = format!(
186            "{}/api/jobs/{}/{}/cancelled",
187            self.config.server_url, workflow_id, job_id
188        );
189        let cancel_interval = self.config.cancellation_check_interval;
190        let cancel_token = tokio_util::sync::CancellationToken::new();
191        let cancel_token_clone = cancel_token.clone();
192        let cancel_handle = tokio::spawn(async move {
193            loop {
194                tokio::time::sleep(cancel_interval).await;
195                if let Ok(resp) = cancel_client.get(&cancel_url).send().await
196                    && let Ok(cancelled) = resp.json::<bool>().await
197                    && cancelled
198                {
199                    cancel_token_clone.cancel();
200                    break;
201                }
202            }
203        });
204
205        // Execute the command
206        let result = executor::execute_job_streaming(
207            &job.command,
208            &self.client,
209            &format!("{}/api/jobs/{}/logs", self.config.server_url, lease_id),
210            &workflow_id,
211            &job_id,
212            self.config.log_batch_interval,
213            cancel_token,
214        )
215        .await;
216
217        // Cancel background tasks
218        hb_handle.abort();
219        cancel_handle.abort();
220
221        // Report result
222        match result {
223            Ok(output) => {
224                let url = format!("{}/api/jobs/{}/complete", self.config.server_url, lease_id);
225                self.client
226                    .post(&url)
227                    .json(&CompleteRequest {
228                        outputs: output.outputs,
229                    })
230                    .send()
231                    .await
232                    .ok();
233                println!("Job {} completed successfully", job.job_id);
234            }
235            Err(e) => {
236                let url = format!("{}/api/jobs/{}/fail", self.config.server_url, lease_id);
237                self.client
238                    .post(&url)
239                    .json(&FailRequest {
240                        error: e.to_string(),
241                        retryable: true,
242                    })
243                    .send()
244                    .await
245                    .ok();
246                eprintln!("Job {} failed: {e}", job.job_id);
247            }
248        }
249    }
250}