1use std::future::Future;
2use std::marker::PhantomData;
3
4use futures_util::stream::{self, BoxStream};
5use futures_util::StreamExt;
6use serde::{de::DeserializeOwned, Deserialize, Serialize};
7use serde_json::{json, Value};
8
9use crate::TrellisClientError;
10
/// Lifecycle state of an operation.
///
/// Serialized as the lowercase variant name (e.g. `"pending"`, `"running"`).
/// `Completed`, `Failed`, and `Cancelled` are the states this module treats as
/// terminal.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum OperationState {
    /// Accepted but not yet running; a `pending` snapshot frame is surfaced as
    /// an `Accepted` event.
    Pending,
    /// In progress; a `running` snapshot frame is surfaced as a `Started` event.
    Running,
    /// Finished successfully (terminal).
    Completed,
    /// Finished with a failure (terminal).
    Failed,
    /// Cancelled before completion (terminal).
    Cancelled,
}
20
/// Identity of a started operation, as returned under the `ref` key of the
/// accepted envelope.
///
/// No `rename_all` is applied, so the fields serialize under their Rust names.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct OperationRefData {
    /// Operation identifier; sent as `operationId` in every control request.
    pub id: String,
    /// Service name reported by the server (not interpreted by this module).
    pub service: String,
    /// Operation name reported by the server (not interpreted by this module).
    pub operation: String,
}
27
/// Point-in-time view of an operation.
///
/// The payload type parameters default to [`Value`] so a snapshot can be handled
/// untyped; the typed client APIs instantiate them with the descriptor's
/// `Progress` and `Output` types.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct OperationSnapshot<TProgress = Value, TOutput = Value> {
    /// Snapshot revision. NOTE(review): presumably increases with every state or
    /// progress change — confirm against the server.
    pub revision: u64,
    /// Current lifecycle state.
    pub state: OperationState,
    /// Latest progress payload, if any; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub progress: Option<TProgress>,
    /// Output payload, if any; omitted from serialized output when `None`.
    /// NOTE(review): presumably only populated once the operation has completed.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output: Option<TOutput>,
}
38
/// Reply to the start request sent to the descriptor's subject.
///
/// `kind` is expected to be `"accepted"`. The operation identity arrives under
/// the JSON key `ref`, and `snapshot` carries the initial state.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
struct AcceptedEnvelope<TProgress = Value, TOutput = Value> {
    kind: String,
    // `ref` is a Rust keyword, hence the explicit rename.
    #[serde(rename = "ref")]
    operation_ref: OperationRefData,
    snapshot: OperationSnapshot<TProgress, TOutput>,
}
47
/// Frame carrying a single snapshot; `kind` is expected to be `"snapshot"`.
///
/// Used both for control replies (`get`, `wait`, `cancel`) and as a watch-stream
/// frame.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
struct SnapshotFrame<TProgress = Value, TOutput = Value> {
    kind: String,
    snapshot: OperationSnapshot<TProgress, TOutput>,
}
54
/// Event yielded by a watch stream.
///
/// Internally tagged: serialized as an object whose lowercase `type` field names
/// the variant, next to a `snapshot` field, e.g.
/// `{"type": "progress", "snapshot": { ... }}`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum OperationEvent<TProgress = Value, TOutput = Value> {
    /// Operation accepted; also synthesized from a `pending` snapshot frame.
    Accepted {
        snapshot: OperationSnapshot<TProgress, TOutput>,
    },
    /// Operation started; also synthesized from a `running` snapshot frame.
    Started {
        snapshot: OperationSnapshot<TProgress, TOutput>,
    },
    /// Progress update; only delivered as an explicit `event` frame, never
    /// synthesized from a snapshot frame.
    Progress {
        snapshot: OperationSnapshot<TProgress, TOutput>,
    },
    /// Terminal: completed successfully.
    Completed {
        snapshot: OperationSnapshot<TProgress, TOutput>,
    },
    /// Terminal: failed.
    Failed {
        snapshot: OperationSnapshot<TProgress, TOutput>,
    },
    /// Terminal: cancelled.
    Cancelled {
        snapshot: OperationSnapshot<TProgress, TOutput>,
    },
}
77
/// Watch-stream frame wrapping one [`OperationEvent`]; `kind` is expected to be
/// `"event"`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
struct EventFrame<TProgress = Value, TOutput = Value> {
    kind: String,
    event: OperationEvent<TProgress, TOutput>,
}
84
/// Static description of one operation type: its payload types, request
/// subject, and capability metadata.
///
/// [`OperationInvoker`] and [`OperationRef`] are parameterized by a descriptor
/// type; no descriptor value is ever constructed by this module.
pub trait OperationDescriptor {
    /// Input serialized as the JSON body of the start request.
    type Input: Serialize;
    /// Typed progress payload decoded from snapshots.
    type Progress: DeserializeOwned + Send + 'static;
    /// Typed output payload decoded from snapshots.
    type Output: DeserializeOwned + Send + 'static;

    /// Descriptor key. Not referenced in this module.
    /// NOTE(review): presumably a stable operation identifier — confirm.
    const KEY: &'static str;
    /// Subject the start request is sent to; control requests (`get`, `wait`,
    /// `cancel`, `watch`) go to `"{SUBJECT}.control"`.
    const SUBJECT: &'static str;
    /// Capability names for starting the operation. Not consulted in this
    /// module. NOTE(review): presumably enforced server-side — confirm.
    const CALLER_CAPABILITIES: &'static [&'static str];
    /// Capability names for reading/watching the operation. Not consulted in
    /// this module.
    const READ_CAPABILITIES: &'static [&'static str];
    /// Capability names for cancelling the operation. Not consulted in this
    /// module.
    const CANCEL_CAPABILITIES: &'static [&'static str];
    /// Whether the operation supports cancellation. Note that
    /// `OperationRef::cancel` does not check this flag before sending.
    const CANCELABLE: bool;
}
97
/// JSON transport used by the operation client.
///
/// Hidden from rustdoc; implemented by the client so the typed operation API
/// can be layered on top of untyped JSON messaging.
#[doc(hidden)]
pub trait OperationTransport {
    /// Sends `body` to `subject` and resolves to the single JSON reply.
    fn request_json_value<'a>(
        &'a self,
        subject: String,
        body: Value,
    ) -> impl Future<Output = Result<Value, TrellisClientError>> + Send + 'a;

    /// Sends `body` to `subject` and resolves to a stream of JSON frames.
    ///
    /// The outer `Result` reports failure to open the stream; each item reports
    /// per-frame failures.
    fn watch_json_value<'a>(
        &'a self,
        subject: String,
        body: Value,
    ) -> impl Future<
        Output = Result<BoxStream<'a, Result<Value, TrellisClientError>>, TrellisClientError>,
    > + Send
           + 'a;
}
115
/// Entry point for starting operations described by `D` over transport `T`.
///
/// `D` is a type-level tag only: no `D` value is stored. The marker is
/// `PhantomData<fn() -> D>` rather than `PhantomData<D>`, so the invoker's
/// `Send`/`Sync` status and drop-check depend only on the transport reference,
/// not on the descriptor type. Variance in `D` is unchanged (covariant).
pub struct OperationInvoker<'a, T, D> {
    transport: &'a T,
    _descriptor: PhantomData<fn() -> D>,
}
120
121pub struct OperationRef<'a, T, D> {
122 transport: &'a T,
123 data: OperationRefData,
124 _descriptor: PhantomData<D>,
125}
126
127fn is_terminal_state(state: &OperationState) -> bool {
128 matches!(
129 state,
130 OperationState::Completed | OperationState::Failed | OperationState::Cancelled
131 )
132}
133
134impl<'a, T, D> OperationInvoker<'a, T, D> {
135 pub fn new(transport: &'a T) -> Self {
136 Self {
137 transport,
138 _descriptor: PhantomData,
139 }
140 }
141}
142
143impl<'a, T, D> OperationInvoker<'a, T, D>
144where
145 T: OperationTransport,
146 D: OperationDescriptor,
147 D::Progress: Send,
148 D::Output: Send,
149{
150 pub async fn start(
151 &self,
152 input: &D::Input,
153 ) -> Result<OperationRef<'a, T, D>, TrellisClientError> {
154 let body = serde_json::to_value(input)?;
155 let response = self
156 .transport
157 .request_json_value(D::SUBJECT.to_string(), body)
158 .await?;
159 let accepted: AcceptedEnvelope<D::Progress, D::Output> = serde_json::from_value(response)?;
160 if accepted.kind != "accepted" {
161 return Err(TrellisClientError::OperationProtocol(format!(
162 "expected accepted envelope, got '{}'",
163 accepted.kind
164 )));
165 }
166 Ok(OperationRef {
167 transport: self.transport,
168 data: accepted.operation_ref,
169 _descriptor: PhantomData,
170 })
171 }
172}
173
174impl<'a, T, D> OperationRef<'a, T, D> {
175 pub fn id(&self) -> &str {
176 &self.data.id
177 }
178
179 pub fn service(&self) -> &str {
180 &self.data.service
181 }
182
183 pub fn operation(&self) -> &str {
184 &self.data.operation
185 }
186}
187
188impl<'a, T, D> OperationRef<'a, T, D>
189where
190 T: OperationTransport,
191 D: OperationDescriptor,
192{
193 pub async fn get(
194 &self,
195 ) -> Result<OperationSnapshot<D::Progress, D::Output>, TrellisClientError> {
196 let body = json!({
197 "action": "get",
198 "operationId": self.id(),
199 });
200 let response = self
201 .transport
202 .request_json_value(control_subject(D::SUBJECT), body)
203 .await?;
204 let frame: SnapshotFrame<D::Progress, D::Output> = serde_json::from_value(response)?;
205 if frame.kind != "snapshot" {
206 return Err(TrellisClientError::OperationProtocol(format!(
207 "expected snapshot frame, got '{}'",
208 frame.kind
209 )));
210 }
211 Ok(frame.snapshot)
212 }
213
214 pub async fn wait(
215 &self,
216 ) -> Result<OperationSnapshot<D::Progress, D::Output>, TrellisClientError> {
217 let body = json!({
218 "action": "wait",
219 "operationId": self.id(),
220 });
221 let response = self
222 .transport
223 .request_json_value(control_subject(D::SUBJECT), body)
224 .await?;
225 let frame: SnapshotFrame<D::Progress, D::Output> = serde_json::from_value(response)?;
226 if frame.kind != "snapshot" {
227 return Err(TrellisClientError::OperationProtocol(format!(
228 "expected snapshot frame, got '{}'",
229 frame.kind
230 )));
231 }
232 if !is_terminal_state(&frame.snapshot.state) {
233 return Err(TrellisClientError::OperationProtocol(
234 "wait returned non-terminal snapshot".to_string(),
235 ));
236 }
237 Ok(frame.snapshot)
238 }
239
240 pub async fn cancel(
241 &self,
242 ) -> Result<OperationSnapshot<D::Progress, D::Output>, TrellisClientError> {
243 let body = json!({
244 "action": "cancel",
245 "operationId": self.id(),
246 });
247 let response = self
248 .transport
249 .request_json_value(control_subject(D::SUBJECT), body)
250 .await?;
251 let frame: SnapshotFrame<D::Progress, D::Output> = serde_json::from_value(response)?;
252 if frame.kind != "snapshot" {
253 return Err(TrellisClientError::OperationProtocol(format!(
254 "expected snapshot frame, got '{}'",
255 frame.kind
256 )));
257 }
258 Ok(frame.snapshot)
259 }
260
261 pub async fn watch(
262 &self,
263 ) -> Result<
264 BoxStream<'a, Result<OperationEvent<D::Progress, D::Output>, TrellisClientError>>,
265 TrellisClientError,
266 > {
267 let control = control_subject(D::SUBJECT);
268 let body = json!({
269 "action": "watch",
270 "operationId": self.id(),
271 });
272 let response = self.transport.watch_json_value(control, body).await?;
273 Ok(Box::pin(stream::try_unfold(
274 (response, false),
275 |(mut response, done)| async move {
276 if done {
277 return Ok(None);
278 }
279
280 loop {
281 match response.next().await {
282 Some(frame) => {
283 let event = match frame {
284 Ok(value) => {
285 match decode_watch_frame::<D::Progress, D::Output>(value) {
286 Ok(Some(event)) => event,
287 Ok(None) => continue,
288 Err(error) => return Err(error),
289 }
290 }
291 Err(error) => return Err(error),
292 };
293
294 let terminal = is_terminal_event(&event);
295 return Ok(Some((event, (response, terminal))));
296 }
297 None => return Ok(None),
298 }
299 }
300 },
301 )))
302 }
303}
304
305fn decode_watch_frame<TProgress: DeserializeOwned, TOutput: DeserializeOwned>(
306 value: Value,
307) -> Result<Option<OperationEvent<TProgress, TOutput>>, TrellisClientError> {
308 if value.get("kind").and_then(Value::as_str) == Some("keepalive") {
309 return Ok(None);
310 }
311
312 let kind = value.get("kind").and_then(Value::as_str).ok_or_else(|| {
313 TrellisClientError::OperationProtocol("expected watch frame kind".to_string())
314 })?;
315
316 match kind {
317 "snapshot" => {
318 let frame: SnapshotFrame<TProgress, TOutput> = serde_json::from_value(value)?;
319 Ok(Some(snapshot_to_event(frame.snapshot)))
320 }
321 "event" => {
322 let frame: EventFrame<TProgress, TOutput> = serde_json::from_value(value)?;
323 Ok(Some(frame.event))
324 }
325 _ => Err(TrellisClientError::OperationProtocol(
326 "expected snapshot/event/keepalive frame".to_string(),
327 )),
328 }
329}
330
331fn snapshot_to_event<TProgress, TOutput>(
332 snapshot: OperationSnapshot<TProgress, TOutput>,
333) -> OperationEvent<TProgress, TOutput> {
334 match snapshot.state {
335 OperationState::Pending => OperationEvent::Accepted { snapshot },
336 OperationState::Running => OperationEvent::Started { snapshot },
337 OperationState::Completed => OperationEvent::Completed { snapshot },
338 OperationState::Failed => OperationEvent::Failed { snapshot },
339 OperationState::Cancelled => OperationEvent::Cancelled { snapshot },
340 }
341}
342
343fn is_terminal_event<TProgress, TOutput>(event: &OperationEvent<TProgress, TOutput>) -> bool {
344 matches!(
345 event,
346 OperationEvent::Completed { .. }
347 | OperationEvent::Failed { .. }
348 | OperationEvent::Cancelled { .. }
349 )
350}
351
/// Returns the control subject for an operation subject: `subject` followed by
/// `.control`.
///
/// This is where the `get`, `wait`, `cancel` and `watch` requests are sent.
pub fn control_subject(subject: &str) -> String {
    [subject, ".control"].concat()
}