Skip to main content

trellis_client/
client.rs

1use async_nats::header::HeaderMap;
2use async_nats::ConnectOptions;
3use bytes::Bytes;
4use futures_util::stream::{self, BoxStream};
5use futures_util::StreamExt;
6use nkeys::KeyPair;
7use serde_json::Value;
8use tokio::time::timeout;
9
10use crate::operations::{OperationDescriptor, OperationInvoker, OperationTransport};
11use crate::proof::now_iat_seconds;
12use crate::{EventDescriptor, RpcDescriptor, SessionAuth, TrellisClientError};
13
/// Settings for connecting as a Trellis service (session-key) principal.
///
/// Consumed by `TrellisClient::connect_service`.
pub struct ServiceConnectOptions<'a> {
    /// NATS server address(es) handed to the connector.
    pub servers: &'a str,
    /// Path to the sentinel credentials used to authenticate the NATS connection.
    pub sentinel_creds_path: &'a str,
    /// Base64url-encoded seed from which the session key is derived.
    pub session_key_seed_base64url: &'a str,
    /// Per-request timeout, in milliseconds.
    pub timeout_ms: u64,
}
21
/// Settings for connecting as a user principal that holds a binding token.
///
/// Consumed by `TrellisClient::connect_user`.
pub struct UserConnectOptions<'a> {
    /// NATS server address(es) handed to the connector.
    pub servers: &'a str,
    /// Sentinel user JWT presented during the NATS handshake.
    pub sentinel_jwt: &'a str,
    /// Seed of the sentinel NKey pair that signs the server nonce.
    pub sentinel_seed: &'a str,
    /// Base64url-encoded seed from which the session key is derived.
    pub session_key_seed_base64url: &'a str,
    /// Previously issued binding token, folded into the NATS connect token.
    pub binding_token: &'a str,
    /// Per-request timeout, in milliseconds.
    pub timeout_ms: u64,
}
31
32/// A low-level Trellis client over NATS request/reply and publish primitives.
33pub struct TrellisClient {
34    nats: async_nats::Client,
35    auth: SessionAuth,
36    timeout_ms: u64,
37}
38
39impl TrellisClient {
40    #[cfg(test)]
41    fn new(nats: async_nats::Client, auth: SessionAuth, timeout_ms: u64) -> Self {
42        Self {
43            nats,
44            auth,
45            timeout_ms,
46        }
47    }
48
49    /// Expose the underlying NATS client for advanced use.
50    pub fn nats(&self) -> &async_nats::Client {
51        &self.nats
52    }
53
54    /// Return the session auth helper used by this client.
55    pub fn auth(&self) -> &SessionAuth {
56        &self.auth
57    }
58
59    /// Connect using sentinel credentials plus an `iat`-based service token.
60    pub async fn connect_service(
61        opts: ServiceConnectOptions<'_>,
62    ) -> Result<Self, TrellisClientError> {
63        let auth = SessionAuth::from_seed_base64url(opts.session_key_seed_base64url)?;
64        let token = auth.nats_connect_token(now_iat_seconds());
65        let inbox_prefix = auth.inbox_prefix();
66
67        let nats = ConnectOptions::new()
68            .credentials(opts.sentinel_creds_path)?
69            .token(token)
70            .custom_inbox_prefix(inbox_prefix)
71            .connect(opts.servers)
72            .await
73            .map_err(|error| TrellisClientError::NatsConnect(error.to_string()))?;
74
75        Ok(Self {
76            nats,
77            auth,
78            timeout_ms: opts.timeout_ms,
79        })
80    }
81
82    /// Connect using a previously issued binding token.
83    pub async fn connect_user(opts: UserConnectOptions<'_>) -> Result<Self, TrellisClientError> {
84        let auth = SessionAuth::from_seed_base64url(opts.session_key_seed_base64url)?;
85        let token = auth.nats_connect_binding_token(opts.binding_token);
86        let inbox_prefix = auth.inbox_prefix();
87        let key_pair = std::sync::Arc::new(
88            KeyPair::from_seed(opts.sentinel_seed)
89                .map_err(|error| TrellisClientError::NatsConnect(error.to_string()))?,
90        );
91
92        let nats = ConnectOptions::with_jwt(opts.sentinel_jwt.to_string(), move |nonce| {
93            let key_pair = key_pair.clone();
94            async move { key_pair.sign(&nonce).map_err(async_nats::AuthError::new) }
95        })
96        .token(token)
97        .custom_inbox_prefix(inbox_prefix)
98        .connect(opts.servers)
99        .await
100        .map_err(|error| TrellisClientError::NatsConnect(error.to_string()))?;
101
102        Ok(Self {
103            nats,
104            auth,
105            timeout_ms: opts.timeout_ms,
106        })
107    }
108
109    async fn request(
110        &self,
111        subject: &str,
112        payload: Bytes,
113    ) -> Result<async_nats::Message, TrellisClientError> {
114        let proof = self.auth.create_proof(subject, &payload);
115
116        let mut headers = HeaderMap::new();
117        headers.insert("session-key", self.auth.session_key.as_str());
118        headers.insert("proof", proof.as_str());
119
120        let future = self
121            .nats
122            .request_with_headers(subject.to_string(), headers, payload);
123        let message = timeout(std::time::Duration::from_millis(self.timeout_ms), future)
124            .await
125            .map_err(|_| TrellisClientError::Timeout)?
126            .map_err(|error| TrellisClientError::NatsRequest(error.to_string()))?;
127        Ok(message)
128    }
129
130    async fn request_json(&self, subject: &str, body: Value) -> Result<Value, TrellisClientError> {
131        let payload = Bytes::from(serde_json::to_vec(&body)?);
132        let message = self.request(subject, payload).await?;
133
134        decode_json_message(message)
135    }
136
137    /// Call a raw subject with a JSON value payload.
138    pub async fn request_json_value(
139        &self,
140        subject: &str,
141        body: &Value,
142    ) -> Result<Value, TrellisClientError> {
143        self.request_json(subject, body.clone()).await
144    }
145
146    /// Call one descriptor-backed RPC.
147    pub async fn call<D>(&self, input: &D::Input) -> Result<D::Output, TrellisClientError>
148    where
149        D: RpcDescriptor,
150    {
151        let value = serde_json::to_value(input)?;
152        let response = self.request_json(D::SUBJECT, value).await?;
153        Ok(serde_json::from_value(response)?)
154    }
155
156    /// Publish one descriptor-backed event.
157    pub async fn publish<D>(&self, event: &D::Event) -> Result<(), TrellisClientError>
158    where
159        D: EventDescriptor,
160    {
161        let payload = Bytes::from(serde_json::to_vec(event)?);
162        self.nats
163            .publish(D::SUBJECT.to_string(), payload)
164            .await
165            .map_err(|error| TrellisClientError::NatsRequest(error.to_string()))?;
166        Ok(())
167    }
168
169    /// Start or control one descriptor-backed operation.
170    pub fn operation<D>(&self) -> OperationInvoker<'_, Self, D>
171    where
172        D: OperationDescriptor,
173    {
174        OperationInvoker::new(self)
175    }
176}
177
178impl OperationTransport for TrellisClient {
179    async fn request_json_value(
180        &self,
181        subject: String,
182        body: Value,
183    ) -> Result<Value, TrellisClientError> {
184        TrellisClient::request_json_value(self, &subject, &body).await
185    }
186
187    async fn watch_json_value<'a>(
188        &'a self,
189        subject: String,
190        body: Value,
191    ) -> Result<BoxStream<'a, Result<Value, TrellisClientError>>, TrellisClientError> {
192        let payload = Bytes::from(serde_json::to_vec(&body)?);
193        let proof = self.auth.create_proof(&subject, &payload);
194
195        let mut headers = HeaderMap::new();
196        headers.insert("session-key", self.auth.session_key.as_str());
197        headers.insert("proof", proof.as_str());
198
199        let inbox = self.nats.new_inbox();
200        let subscriber = timeout(
201            std::time::Duration::from_millis(self.timeout_ms),
202            self.nats.subscribe(inbox.clone()),
203        )
204        .await
205        .map_err(|_| TrellisClientError::Timeout)?
206        .map_err(|error| TrellisClientError::NatsRequest(error.to_string()))?;
207
208        timeout(
209            std::time::Duration::from_millis(self.timeout_ms),
210            self.nats
211                .publish_with_reply_and_headers(subject, inbox, headers, payload),
212        )
213        .await
214        .map_err(|_| TrellisClientError::Timeout)?
215        .map_err(|error| TrellisClientError::NatsRequest(error.to_string()))?;
216
217        let stream = stream::try_unfold((subscriber, false), |(mut subscriber, done)| async move {
218            if done {
219                return Ok(None);
220            }
221
222            match subscriber.next().await {
223                Some(message) => {
224                    let event = decode_watch_message(message)?;
225                    let terminal = is_terminal_event(&event);
226                    Ok(Some((event, (subscriber, terminal))))
227                }
228                None => Ok(None),
229            }
230        });
231
232        Ok(Box::pin(stream) as BoxStream<'a, Result<Value, TrellisClientError>>)
233    }
234}
235
236fn decode_json_message(message: async_nats::Message) -> Result<Value, TrellisClientError> {
237    if let Some(headers) = &message.headers {
238        if headers
239            .get("status")
240            .is_some_and(|status| status.as_str() == "error")
241        {
242            let value: Value = serde_json::from_slice(&message.payload)?;
243            return Err(TrellisClientError::RpcError(value.to_string()));
244        }
245    }
246
247    Ok(serde_json::from_slice(&message.payload)?)
248}
249
250fn decode_watch_message(message: async_nats::Message) -> Result<Value, TrellisClientError> {
251    decode_json_message(message)
252}
253
254fn is_terminal_event(event: &Value) -> bool {
255    matches!(
256        event.get("type").and_then(Value::as_str),
257        Some("completed" | "failed" | "cancelled")
258    )
259}
260
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use futures_util::StreamExt;
    use serde::{Deserialize, Serialize};
    use serde_json::{json, Value};
    use std::process::Command;
    use std::time::{Duration, SystemTime, UNIX_EPOCH};

    use crate::control_subject;
    use crate::operations::OperationEvent;

    /// Start-request payload for the fake refund operation.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
    struct RefundInput {
        charge_id: String,
    }

    /// Progress payload for the fake refund operation.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
    struct RefundProgress {
        message: String,
    }

    /// Final output of the fake refund operation.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
    struct RefundOutput {
        refund_id: String,
    }

    /// Descriptor for the fake `Billing.Refund` operation served by the test.
    struct RefundOperation;

    impl OperationDescriptor for RefundOperation {
        type Input = RefundInput;
        type Progress = RefundProgress;
        type Output = RefundOutput;

        const KEY: &'static str = "Billing.Refund";
        const SUBJECT: &'static str = "operations.v1.Billing.Refund";
        const CALLER_CAPABILITIES: &'static [&'static str] = &["billing.refund"];
        const READ_CAPABILITIES: &'static [&'static str] = &["billing.read"];
        const CANCEL_CAPABILITIES: &'static [&'static str] = &["billing.cancel"];
        const CANCELABLE: bool = true;
    }

    /// Handle for a running container that is force-removed when dropped.
    struct RuntimeContainer {
        runtime: String,
        name: String,
    }

    impl Drop for RuntimeContainer {
        fn drop(&mut self) {
            // Best-effort cleanup: a failure here must not mask the test's own outcome.
            let _ = Command::new(&self.runtime)
                .args(["rm", "-f", &self.name])
                .output();
        }
    }

    /// Return the first usable container runtime, trying `podman` and then `docker`.
    ///
    /// A runtime that is not installed (spawn error) or whose `--version` exits non-zero is
    /// skipped rather than ending the search. Previously, a missing `podman` returned
    /// `None` before `docker` was ever tried.
    fn detect_runtime() -> Option<&'static str> {
        ["podman", "docker"].into_iter().find(|runtime| {
            Command::new(*runtime)
                .arg("--version")
                .status()
                .is_ok_and(|status| status.success())
        })
    }

    /// Run `runtime args...` and return its trimmed stdout.
    ///
    /// # Panics
    ///
    /// Panics if the command cannot be spawned, exits unsuccessfully, or prints non-UTF-8.
    fn run_command(runtime: &str, args: &[&str]) -> String {
        let output = Command::new(runtime)
            .args(args)
            .output()
            .expect("runtime command should execute");
        if !output.status.success() {
            panic!(
                "runtime command failed: {} {}\nstdout: {}\nstderr: {}",
                runtime,
                args.join(" "),
                String::from_utf8_lossy(&output.stdout),
                String::from_utf8_lossy(&output.stderr),
            );
        }
        String::from_utf8(output.stdout)
            .expect("stdout should be utf-8")
            .trim()
            .to_string()
    }

    /// Start a throwaway NATS container bound to a random loopback port.
    ///
    /// Returns the cleanup guard and the `host:port` to connect to.
    fn start_nats_container() -> (RuntimeContainer, String) {
        let runtime = detect_runtime().expect("podman or docker runtime is required");
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system clock should be after unix epoch")
            .as_nanos();
        let name = format!("trellis-client-watch-it-{}-{}", std::process::id(), now);

        run_command(
            runtime,
            &[
                "run",
                "-d",
                "--rm",
                "--name",
                &name,
                "-p",
                "127.0.0.1::4222",
                "docker.io/library/nats:2.10-alpine",
            ],
        );

        // Arm the cleanup guard right away, so the container is removed even if the port
        // lookup below panics.
        let container = RuntimeContainer {
            runtime: runtime.to_string(),
            name,
        };

        let mapping = run_command(runtime, &["port", &container.name, "4222/tcp"]);
        // `port` may print one line per bound address; the host port follows the last ':'.
        let host_port = mapping
            .lines()
            .next()
            .and_then(|line| line.rsplit(':').next())
            .map(str::trim)
            .filter(|port| !port.is_empty())
            .expect("port mapping should include a host port");
        let server = format!("127.0.0.1:{}", host_port);

        (container, server)
    }

    /// Connect to `server`, retrying for about three seconds while the container boots.
    async fn connect_with_retry(server: &str) -> async_nats::Client {
        let mut last_error = None;
        for _ in 0..30 {
            match async_nats::connect(server).await {
                Ok(client) => return client,
                Err(error) => {
                    last_error = Some(error.to_string());
                    tokio::time::sleep(Duration::from_millis(100)).await;
                }
            }
        }

        panic!(
            "failed to connect to nats server {}: {}",
            server,
            last_error.unwrap_or_else(|| "unknown error".to_string())
        );
    }

    /// Deterministic session auth built from a fixed all-sevens seed.
    fn test_auth() -> SessionAuth {
        SessionAuth::from_seed_base64url(&crate::proof::base64url_encode(&[7u8; 32]))
            .expect("session auth")
    }

    #[tokio::test]
    #[ignore = "needs podman/docker runtime"]
    async fn watch_stream_uses_reply_subject_and_stops_after_terminal_event() {
        let (_container, server) = start_nats_container();

        let service_client = connect_with_retry(&server).await;
        let requester_client = connect_with_retry(&server).await;
        let auth = test_auth();
        let client = TrellisClient::new(requester_client, auth, 2_000);

        let mut start_sub = service_client
            .subscribe(RefundOperation::SUBJECT.to_string())
            .await
            .expect("subscribe start subject");
        let mut control_sub = service_client
            .subscribe(control_subject(RefundOperation::SUBJECT))
            .await
            .expect("subscribe control subject");

        // Fake service: accept the start request.
        let service_for_start = service_client.clone();
        let start_task = tokio::spawn(async move {
            if let Some(msg) = start_sub.next().await {
                let body: Value = serde_json::from_slice(&msg.payload).expect("start request json");
                assert_eq!(body["charge_id"], "ch_123");
                let accepted = json!({
                    "kind": "accepted",
                    "ref": {
                        "id": "op_123",
                        "service": "billing",
                        "operation": "Billing.Refund"
                    },
                    "snapshot": {
                        "revision": 1,
                        "state": "pending"
                    }
                });
                let reply = msg.reply.as_ref().expect("start reply subject").clone();
                service_for_start
                    .publish(
                        reply,
                        Bytes::from(serde_json::to_vec(&accepted).expect("serialize accepted")),
                    )
                    .await
                    .expect("publish accepted reply");
            }
        });

        // Fake service: answer the watch with frames, including one after the terminal event.
        let service_for_control = service_client.clone();
        let control_task = tokio::spawn(async move {
            if let Some(msg) = control_sub.next().await {
                let body: Value =
                    serde_json::from_slice(&msg.payload).expect("control request json");
                assert_eq!(body["action"], "watch");
                assert_eq!(body["operationId"], "op_123");

                let reply = msg.reply.as_ref().expect("watch reply subject").clone();
                let frames = [
                    json!({
                        "kind": "snapshot",
                        "snapshot": {
                            "revision": 2,
                            "state": "running",
                            "progress": {
                                "message": "working"
                            }
                        }
                    }),
                    json!({
                        "kind": "event",
                        "event": {
                            "type": "progress",
                            "snapshot": {
                                "revision": 3,
                                "state": "running",
                                "progress": {
                                    "message": "almost there"
                                }
                            }
                        }
                    }),
                    json!({"kind": "keepalive"}),
                    json!({
                        "kind": "event",
                        "event": {
                            "type": "completed",
                            "snapshot": {
                                "revision": 4,
                                "state": "completed",
                                "output": {
                                    "refund_id": "rf_123"
                                }
                            }
                        }
                    }),
                    json!({
                        "kind": "event",
                        "event": {
                            "type": "progress",
                            "snapshot": {
                                "revision": 5,
                                "state": "running",
                                "progress": {
                                    "message": "ignored"
                                }
                            }
                        }
                    }),
                ];

                for frame in frames {
                    service_for_control
                        .publish(
                            reply.clone(),
                            Bytes::from(serde_json::to_vec(&frame).expect("serialize frame")),
                        )
                        .await
                        .expect("publish watch frame");
                }
            }
        });

        // Give the server time to register the service subscriptions before requesting.
        tokio::time::sleep(Duration::from_millis(100)).await;

        let operation = client
            .operation::<RefundOperation>()
            .start(&RefundInput {
                charge_id: "ch_123".to_string(),
            })
            .await
            .expect("start should succeed");
        let stream = operation.watch().await.expect("watch should succeed");
        // Bound the collection, so a stream that fails to stop on the terminal event fails
        // the test instead of hanging it.
        let events: Vec<_> =
            tokio::time::timeout(Duration::from_secs(5), stream.collect::<Vec<_>>())
                .await
                .expect("watch stream should end after the terminal event");

        assert_eq!(events.len(), 3);
        assert!(matches!(events[0], Ok(OperationEvent::Started { .. })));
        assert!(matches!(events[1], Ok(OperationEvent::Progress { .. })));
        assert!(matches!(events[2], Ok(OperationEvent::Completed { .. })));

        start_task.await.expect("start task should complete");
        control_task.await.expect("control task should complete");
    }
}