Skip to main content

rustvello_postgres/orchestrator/
recovery.rs

1use async_trait::async_trait;
2use chrono::{DateTime, Utc};
3
4use rustvello_core::error::RustvelloResult;
5use rustvello_core::orchestrator::{
6    ActiveRunnerInfo, AtomicServiceExecution, OrchestratorRecovery,
7};
8use rustvello_proto::identifiers::{InvocationId, RunnerId};
9
10use super::PostgresOrchestrator;
11use crate::db::pg_err;
12
13#[async_trait]
14impl OrchestratorRecovery for PostgresOrchestrator {
15    async fn register_heartbeat(
16        &self,
17        runner_id: &RunnerId,
18        _can_run_atomic_service: bool,
19    ) -> RustvelloResult<()> {
20        let client = self.db.conn().await?;
21        let now = Utc::now();
22
23        client
24            .execute(
25                "INSERT INTO runner_heartbeats (runner_id, last_heartbeat) VALUES ($1, $2)
26                 ON CONFLICT (runner_id) DO UPDATE SET last_heartbeat = $2",
27                &[&runner_id.as_str(), &now],
28            )
29            .await
30            .map_err(pg_err)?;
31
32        Ok(())
33    }
34
35    async fn get_stale_pending_invocations(
36        &self,
37        max_pending_seconds: u64,
38    ) -> RustvelloResult<Vec<InvocationId>> {
39        let client = self.db.conn().await?;
40        let threshold = Utc::now()
41            - chrono::Duration::seconds(i64::try_from(max_pending_seconds).unwrap_or(i64::MAX));
42
43        let rows = client
44            .query(
45                "SELECT invocation_id FROM status_records
46                 WHERE status = 'PENDING' AND timestamp < $1",
47                &[&threshold],
48            )
49            .await
50            .map_err(pg_err)?;
51
52        Ok(rows
53            .iter()
54            .map(|r| InvocationId::from_string(r.get::<_, String>(0)))
55            .collect())
56    }
57
58    async fn get_stale_running_invocations(
59        &self,
60        runner_dead_after_seconds: u64,
61    ) -> RustvelloResult<Vec<InvocationId>> {
62        let client = self.db.conn().await?;
63        let threshold = Utc::now()
64            - chrono::Duration::seconds(
65                i64::try_from(runner_dead_after_seconds).unwrap_or(i64::MAX),
66            );
67
68        let rows = client
69            .query(
70                "SELECT sr.invocation_id FROM status_records sr
71                 LEFT JOIN runner_heartbeats rh ON sr.runner_id = rh.runner_id
72                 WHERE sr.status = 'RUNNING'
73                   AND (rh.last_heartbeat IS NULL OR rh.last_heartbeat < $1)",
74                &[&threshold],
75            )
76            .await
77            .map_err(pg_err)?;
78
79        Ok(rows
80            .iter()
81            .map(|r| InvocationId::from_string(r.get::<_, String>(0)))
82            .collect())
83    }
84
85    async fn get_active_runner_ids(&self, timeout_seconds: u64) -> RustvelloResult<Vec<RunnerId>> {
86        let client = self.db.conn().await?;
87        let threshold = Utc::now()
88            - chrono::Duration::seconds(i64::try_from(timeout_seconds).unwrap_or(i64::MAX));
89        let rows = client
90            .query(
91                "SELECT runner_id FROM runner_heartbeats WHERE last_heartbeat >= $1",
92                &[&threshold],
93            )
94            .await
95            .map_err(pg_err)?;
96        Ok(rows
97            .iter()
98            .map(|r| RunnerId::from_string(r.get::<_, String>(0)))
99            .collect())
100    }
101
102    async fn get_active_runners(
103        &self,
104        timeout_seconds: u64,
105        _can_run_atomic_service: Option<bool>,
106    ) -> RustvelloResult<Vec<ActiveRunnerInfo>> {
107        let client = self.db.conn().await?;
108        let threshold = Utc::now()
109            - chrono::Duration::seconds(i64::try_from(timeout_seconds).unwrap_or(i64::MAX));
110        let rows = client
111            .query(
112                "SELECT runner_id, last_heartbeat FROM runner_heartbeats WHERE last_heartbeat >= $1",
113                &[&threshold],
114            )
115            .await
116            .map_err(pg_err)?;
117        Ok(rows
118            .iter()
119            .map(|r| {
120                let ts: DateTime<Utc> = r.get(1);
121                ActiveRunnerInfo {
122                    runner_id: RunnerId::from_string(r.get::<_, String>(0)),
123                    creation_time: ts,
124                    last_heartbeat: ts,
125                    can_run_atomic_service: true,
126                    last_service_start: None,
127                    last_service_end: None,
128                }
129            })
130            .collect())
131    }
132
133    async fn record_atomic_service_execution(
134        &self,
135        _runner_id: &RunnerId,
136        _start: DateTime<Utc>,
137        _end: DateTime<Utc>,
138    ) -> RustvelloResult<()> {
139        Ok(())
140    }
141
142    async fn get_atomic_service_timeline(&self) -> RustvelloResult<Vec<AtomicServiceExecution>> {
143        Ok(Vec::new())
144    }
145}