rustvello_core/
invocation.rs1use std::marker::PhantomData;
10use std::sync::Arc;
11use std::time::Duration;
12
13use serde::de::DeserializeOwned;
14
15use rustvello_proto::identifiers::InvocationId;
16use rustvello_proto::status::InvocationStatus;
17
18use crate::error::{RustvelloError, RustvelloResult};
19use crate::orchestrator::Orchestrator;
20use crate::state_backend::StateBackend;
21
22pub struct InvocationHandle<R: DeserializeOwned = String> {
40 invocation_id: InvocationId,
41 orchestrator: Arc<dyn Orchestrator>,
42 state_backend: Arc<dyn StateBackend>,
43 _result_type: PhantomData<R>,
44}
45
46impl<R: DeserializeOwned> InvocationHandle<R> {
47 pub fn new(
49 invocation_id: InvocationId,
50 orchestrator: Arc<dyn Orchestrator>,
51 state_backend: Arc<dyn StateBackend>,
52 ) -> Self {
53 Self {
54 invocation_id,
55 orchestrator,
56 state_backend,
57 _result_type: PhantomData,
58 }
59 }
60
61 pub fn invocation_id(&self) -> &InvocationId {
63 &self.invocation_id
64 }
65
66 pub async fn status(&self) -> RustvelloResult<InvocationStatus> {
68 let record = self
69 .orchestrator
70 .get_invocation_status(&self.invocation_id)
71 .await?;
72 Ok(record.status)
73 }
74
75 pub async fn is_done(&self) -> RustvelloResult<bool> {
77 Ok(self.status().await?.is_terminal())
78 }
79
80 pub async fn result(&self) -> RustvelloResult<R> {
84 let status = self.status().await?;
85 match status {
86 InvocationStatus::Success => {
87 let raw = self
88 .state_backend
89 .get_result(&self.invocation_id)
90 .await?
91 .ok_or_else(|| RustvelloError::Internal {
92 message: format!(
93 "invocation {} has SUCCESS status but no stored result",
94 self.invocation_id
95 ),
96 })?;
97 serde_json::from_str(&raw).map_err(|e| RustvelloError::Serialization {
98 message: e.to_string(),
99 })
100 }
101 InvocationStatus::Failed => {
102 let err = self.state_backend.get_error(&self.invocation_id).await?;
103 Err(RustvelloError::runner_err(err.map_or_else(
104 || "unknown error".to_string(),
105 |e| e.to_string(),
106 )))
107 }
108 other => Err(RustvelloError::Internal {
109 message: format!(
110 "invocation {} is not finished (status: {})",
111 self.invocation_id, other
112 ),
113 }),
114 }
115 }
116
117 pub async fn wait(&self, poll_interval: Duration) -> RustvelloResult<R> {
124 loop {
125 if self.is_done().await? {
126 return self.result().await;
127 }
128 tokio::time::sleep(poll_interval).await;
129 }
130 }
131
132 pub async fn wait_timeout(
136 &self,
137 timeout: Duration,
138 poll_interval: Duration,
139 ) -> RustvelloResult<R> {
140 tokio::time::timeout(timeout, self.wait(poll_interval))
141 .await
142 .map_err(|_| {
143 RustvelloError::runner_err(format!(
144 "timeout waiting for invocation {}",
145 self.invocation_id
146 ))
147 })?
148 }
149
150 pub fn into_untyped(self) -> InvocationHandle<String> {
152 InvocationHandle {
153 invocation_id: self.invocation_id,
154 orchestrator: self.orchestrator,
155 state_backend: self.state_backend,
156 _result_type: PhantomData,
157 }
158 }
159}
160
161impl<R: DeserializeOwned> std::fmt::Debug for InvocationHandle<R> {
162 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
163 f.debug_struct("InvocationHandle")
164 .field("invocation_id", &self.invocation_id)
165 .finish()
166 }
167}
168
169pub struct SyncInvocation<R> {
181 invocation_id: InvocationId,
182 status: InvocationStatus,
183 result: Result<R, RustvelloError>,
184}
185
186impl<R> SyncInvocation<R> {
187 pub fn success(invocation_id: InvocationId, result: R) -> Self {
189 Self {
190 invocation_id,
191 status: InvocationStatus::Success,
192 result: Ok(result),
193 }
194 }
195
196 pub fn failed(invocation_id: InvocationId, error: RustvelloError) -> Self {
198 Self {
199 invocation_id,
200 status: InvocationStatus::Failed,
201 result: Err(error),
202 }
203 }
204
205 pub fn invocation_id(&self) -> &InvocationId {
207 &self.invocation_id
208 }
209
210 pub fn status(&self) -> InvocationStatus {
212 self.status
213 }
214
215 pub fn is_done(&self) -> bool {
217 true
218 }
219}
220
221impl<R> std::fmt::Debug for SyncInvocation<R> {
222 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
223 f.debug_struct("SyncInvocation")
224 .field("invocation_id", &self.invocation_id)
225 .field("status", &self.status)
226 .finish()
227 }
228}
229
230#[non_exhaustive]
243pub enum Invocation<R: DeserializeOwned> {
244 Sync(SyncInvocation<R>),
246 Distributed(InvocationHandle<R>),
248}
249
250impl<R: DeserializeOwned> Invocation<R> {
251 pub fn invocation_id(&self) -> &InvocationId {
253 match self {
254 Self::Sync(s) => s.invocation_id(),
255 Self::Distributed(d) => d.invocation_id(),
256 }
257 }
258
259 pub async fn status(&self) -> RustvelloResult<InvocationStatus> {
261 match self {
262 Self::Sync(s) => Ok(s.status()),
263 Self::Distributed(d) => d.status().await,
264 }
265 }
266
267 pub async fn is_done(&self) -> RustvelloResult<bool> {
269 match self {
270 Self::Sync(s) => Ok(s.is_done()),
271 Self::Distributed(d) => d.is_done().await,
272 }
273 }
274
275 pub async fn result(self) -> RustvelloResult<R> {
280 match self {
281 Self::Sync(s) => s.result,
282 Self::Distributed(d) => d.result().await,
283 }
284 }
285
286 pub async fn wait(self, poll_interval: Duration) -> RustvelloResult<R> {
288 match self {
289 Self::Sync(s) => s.result,
290 Self::Distributed(d) => d.wait(poll_interval).await,
291 }
292 }
293
294 pub async fn wait_timeout(
296 self,
297 timeout: Duration,
298 poll_interval: Duration,
299 ) -> RustvelloResult<R> {
300 match self {
301 Self::Sync(s) => s.result,
302 Self::Distributed(d) => d.wait_timeout(timeout, poll_interval).await,
303 }
304 }
305
306 pub fn is_sync(&self) -> bool {
308 matches!(self, Self::Sync(_))
309 }
310
311 pub fn is_distributed(&self) -> bool {
313 matches!(self, Self::Distributed(_))
314 }
315}
316
317impl<R: DeserializeOwned> std::fmt::Debug for Invocation<R> {
318 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
319 match self {
320 Self::Sync(s) => f.debug_tuple("Invocation::Sync").field(s).finish(),
321 Self::Distributed(d) => f.debug_tuple("Invocation::Distributed").field(d).finish(),
322 }
323 }
324}