Skip to main content

rustvello_sqlite/orchestrator/
recovery.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use chrono::{DateTime, Utc};
5
6use rustvello_core::error::RustvelloResult;
7use rustvello_core::orchestrator::OrchestratorRecovery;
8use rustvello_core::orchestrator::{ActiveRunnerInfo, AtomicServiceExecution};
9use rustvello_proto::identifiers::{InvocationId, RunnerId};
10
11use crate::db::{blocking, lock_err, parse_timestamp, sql_err};
12
13use super::SqliteOrchestrator;
14
15#[async_trait]
16impl OrchestratorRecovery for SqliteOrchestrator {
17    async fn register_heartbeat(
18        &self,
19        runner_id: &RunnerId,
20        can_run_atomic_service: bool,
21    ) -> RustvelloResult<()> {
22        let db = Arc::clone(&self.db);
23        let runner_id = runner_id.clone();
24        blocking(move || {
25
26            let conn = db.conn.lock().map_err(lock_err)?;
27            let now = Utc::now().to_rfc3339();
28            let can_run = if can_run_atomic_service { 1i32 } else { 0i32 };
29
30            conn.execute(
31                "INSERT INTO runner_heartbeats (runner_id, creation_time, last_heartbeat, can_run_atomic_service)
32                 VALUES (?1, ?2, ?3, ?4)
33                 ON CONFLICT(runner_id) DO UPDATE SET last_heartbeat = ?3, can_run_atomic_service = ?4",
34                rusqlite::params![runner_id.as_str(), &now, &now, can_run],
35            )
36            .map_err(sql_err)?;
37
38            Ok(())
39
40        })
41        .await
42    }
43
44    async fn get_stale_pending_invocations(
45        &self,
46        max_pending_seconds: u64,
47    ) -> RustvelloResult<Vec<InvocationId>> {
48        let db = Arc::clone(&self.db);
49        blocking(move || {
50            let conn = db.conn.lock().map_err(lock_err)?;
51            let threshold = (Utc::now()
52                - chrono::Duration::seconds(
53                    i64::try_from(max_pending_seconds).unwrap_or(i64::MAX),
54                ))
55            .to_rfc3339();
56
57            let mut stmt = conn
58                .prepare(
59                    "SELECT invocation_id FROM status_records
60                     WHERE status = 'PENDING' AND timestamp < ?1",
61                )
62                .map_err(sql_err)?;
63
64            let ids: Vec<InvocationId> = stmt
65                .query_map([&threshold], |row| {
66                    let id: String = row.get(0)?;
67                    Ok(InvocationId::from_string(id))
68                })
69                .map_err(sql_err)?
70                .collect::<Result<Vec<_>, _>>()
71                .map_err(sql_err)?;
72
73            Ok(ids)
74        })
75        .await
76    }
77
78    async fn get_stale_running_invocations(
79        &self,
80        runner_dead_after_seconds: u64,
81    ) -> RustvelloResult<Vec<InvocationId>> {
82        let db = Arc::clone(&self.db);
83        blocking(move || {
84            let conn = db.conn.lock().map_err(lock_err)?;
85            let threshold = (Utc::now()
86                - chrono::Duration::seconds(
87                    i64::try_from(runner_dead_after_seconds).unwrap_or(i64::MAX),
88                ))
89            .to_rfc3339();
90
91            let mut stmt = conn
92                .prepare(
93                    "SELECT sr.invocation_id FROM status_records sr
94                     LEFT JOIN runner_heartbeats rh ON sr.runner_id = rh.runner_id
95                     WHERE sr.status = 'RUNNING'
96                       AND (rh.last_heartbeat IS NULL OR rh.last_heartbeat < ?1)",
97                )
98                .map_err(sql_err)?;
99
100            let ids: Vec<InvocationId> = stmt
101                .query_map([&threshold], |row| {
102                    let id: String = row.get(0)?;
103                    Ok(InvocationId::from_string(id))
104                })
105                .map_err(sql_err)?
106                .collect::<Result<Vec<_>, _>>()
107                .map_err(sql_err)?;
108
109            Ok(ids)
110        })
111        .await
112    }
113
114    async fn get_active_runner_ids(&self, timeout_seconds: u64) -> RustvelloResult<Vec<RunnerId>> {
115        let db = Arc::clone(&self.db);
116        blocking(move || {
117            let conn = db.conn.lock().map_err(lock_err)?;
118            let threshold = (Utc::now()
119                - chrono::Duration::seconds(i64::try_from(timeout_seconds).unwrap_or(i64::MAX)))
120            .to_rfc3339();
121
122            let mut stmt = conn
123                .prepare("SELECT runner_id FROM runner_heartbeats WHERE last_heartbeat >= ?1")
124                .map_err(sql_err)?;
125
126            let ids: Vec<RunnerId> = stmt
127                .query_map([&threshold], |row| {
128                    let id: String = row.get(0)?;
129                    Ok(RunnerId::from_string(id))
130                })
131                .map_err(sql_err)?
132                .collect::<Result<Vec<_>, _>>()
133                .map_err(sql_err)?;
134
135            Ok(ids)
136        })
137        .await
138    }
139
140    async fn get_active_runners(
141        &self,
142        timeout_seconds: u64,
143        can_run_atomic_service: Option<bool>,
144    ) -> RustvelloResult<Vec<ActiveRunnerInfo>> {
145        let db = Arc::clone(&self.db);
146        blocking(move || {
147
148            let conn = db.conn.lock().map_err(lock_err)?;
149            let threshold = (Utc::now()
150                - chrono::Duration::seconds(i64::try_from(timeout_seconds).unwrap_or(i64::MAX)))
151            .to_rfc3339();
152
153            let sql = match can_run_atomic_service {
154                Some(true) => {
155                    "SELECT runner_id, creation_time, last_heartbeat, can_run_atomic_service, last_service_start, last_service_end
156                     FROM runner_heartbeats WHERE last_heartbeat >= ?1 AND can_run_atomic_service = 1"
157                }
158                Some(false) => {
159                    "SELECT runner_id, creation_time, last_heartbeat, can_run_atomic_service, last_service_start, last_service_end
160                     FROM runner_heartbeats WHERE last_heartbeat >= ?1 AND can_run_atomic_service = 0"
161                }
162                None => {
163                    "SELECT runner_id, creation_time, last_heartbeat, can_run_atomic_service, last_service_start, last_service_end
164                     FROM runner_heartbeats WHERE last_heartbeat >= ?1"
165                }
166            };
167
168            let mut stmt = conn.prepare(sql).map_err(sql_err)?;
169            let runners: Vec<ActiveRunnerInfo> = stmt
170                .query_map([&threshold], |row| {
171                    let runner_id: String = row.get(0)?;
172                    let creation_time_str: String = row.get(1)?;
173                    let last_heartbeat_str: String = row.get(2)?;
174                    let can_run: i32 = row.get(3)?;
175                    let last_service_start_str: Option<String> = row.get(4)?;
176                    let last_service_end_str: Option<String> = row.get(5)?;
177                    Ok((
178                        runner_id,
179                        creation_time_str,
180                        last_heartbeat_str,
181                        can_run,
182                        last_service_start_str,
183                        last_service_end_str,
184                    ))
185                })
186                .map_err(sql_err)?
187                .collect::<Result<Vec<_>, _>>()
188                .map_err(sql_err)?
189                .into_iter()
190                .filter_map(|(rid, ct, lh, can_run, lss, lse)| {
191                    let creation_time = parse_timestamp(&ct).ok()?;
192                    let last_heartbeat = parse_timestamp(&lh).ok()?;
193                    let last_service_start = lss.as_deref().and_then(|s| parse_timestamp(s).ok());
194                    let last_service_end = lse.as_deref().and_then(|s| parse_timestamp(s).ok());
195                    Some(ActiveRunnerInfo {
196                        runner_id: RunnerId::from_string(rid),
197                        creation_time,
198                        last_heartbeat,
199                        can_run_atomic_service: can_run != 0,
200                        last_service_start,
201                        last_service_end,
202                    })
203                })
204                .collect();
205
206            Ok(runners)
207
208        })
209        .await
210    }
211
212    async fn record_atomic_service_execution(
213        &self,
214        runner_id: &RunnerId,
215        start: DateTime<Utc>,
216        end: DateTime<Utc>,
217    ) -> RustvelloResult<()> {
218        let db = Arc::clone(&self.db);
219        let runner_id = runner_id.clone();
220        blocking(move || {
221
222            let conn = db.conn.lock().map_err(lock_err)?;
223            let start_str = start.to_rfc3339();
224            let end_str = end.to_rfc3339();
225
226            conn.execute(
227                "UPDATE runner_heartbeats SET last_service_start = ?1, last_service_end = ?2 WHERE runner_id = ?3",
228                rusqlite::params![&start_str, &end_str, runner_id.as_str()],
229            )
230            .map_err(sql_err)?;
231
232            Ok(())
233
234        })
235        .await
236    }
237
238    async fn get_atomic_service_timeline(&self) -> RustvelloResult<Vec<AtomicServiceExecution>> {
239        Ok(Vec::new())
240    }
241}