1use async_nats::header::HeaderMap;
2use async_nats::ConnectOptions;
3use bytes::Bytes;
4use futures_util::stream::{self, BoxStream};
5use futures_util::StreamExt;
6use nkeys::KeyPair;
7use serde_json::Value;
8use tokio::time::timeout;
9
10use crate::operations::{OperationDescriptor, OperationInvoker, OperationTransport};
11use crate::proof::now_iat_seconds;
12use crate::{EventDescriptor, RpcDescriptor, SessionAuth, TrellisClientError};
13
/// Connection settings for [`TrellisClient::connect_service`].
pub struct ServiceConnectOptions<'a> {
    /// Server address string handed to `ConnectOptions::connect`.
    pub servers: &'a str,
    /// Sentinel NATS credentials for the service connection; the name indicates a `.creds`
    /// file path. NOTE(review): confirm how `connect_service` consumes it (path vs contents).
    pub sentinel_creds_path: &'a str,
    /// Session-key seed, decoded by `SessionAuth::from_seed_base64url`.
    pub session_key_seed_base64url: &'a str,
    /// Timeout in milliseconds, stored on the resulting client.
    pub timeout_ms: u64,
}
21
/// Connection settings for [`TrellisClient::connect_user`].
pub struct UserConnectOptions<'a> {
    /// Server address string handed to `ConnectOptions::connect`.
    pub servers: &'a str,
    /// Sentinel JWT presented via `ConnectOptions::with_jwt`.
    pub sentinel_jwt: &'a str,
    /// Sentinel nkey seed; parsed with `KeyPair::from_seed` and used to sign the server nonce.
    pub sentinel_seed: &'a str,
    /// Session-key seed, decoded by `SessionAuth::from_seed_base64url`.
    pub session_key_seed_base64url: &'a str,
    /// Token combined with the session key by `SessionAuth::nats_connect_binding_token`.
    pub binding_token: &'a str,
    /// Timeout in milliseconds, stored on the resulting client.
    pub timeout_ms: u64,
}
31
/// A NATS connection paired with the session credentials used to sign requests.
pub struct TrellisClient {
    /// Underlying NATS connection.
    nats: async_nats::Client,
    /// Session key material; supplies the `session-key` header and per-request proofs.
    auth: SessionAuth,
    /// Deadline in milliseconds applied to requests and to the subscribe/publish steps of a
    /// watch.
    timeout_ms: u64,
}
38
39impl TrellisClient {
40 #[cfg(test)]
41 fn new(nats: async_nats::Client, auth: SessionAuth, timeout_ms: u64) -> Self {
42 Self {
43 nats,
44 auth,
45 timeout_ms,
46 }
47 }
48
49 pub fn nats(&self) -> &async_nats::Client {
51 &self.nats
52 }
53
54 pub fn auth(&self) -> &SessionAuth {
56 &self.auth
57 }
58
59 pub async fn connect_service(
61 opts: ServiceConnectOptions<'_>,
62 ) -> Result<Self, TrellisClientError> {
63 let auth = SessionAuth::from_seed_base64url(opts.session_key_seed_base64url)?;
64 let token = auth.nats_connect_token(now_iat_seconds());
65 let inbox_prefix = auth.inbox_prefix();
66
67 let nats = ConnectOptions::new()
68 .credentials(opts.sentinel_creds_path)?
69 .token(token)
70 .custom_inbox_prefix(inbox_prefix)
71 .connect(opts.servers)
72 .await
73 .map_err(|error| TrellisClientError::NatsConnect(error.to_string()))?;
74
75 Ok(Self {
76 nats,
77 auth,
78 timeout_ms: opts.timeout_ms,
79 })
80 }
81
82 pub async fn connect_user(opts: UserConnectOptions<'_>) -> Result<Self, TrellisClientError> {
84 let auth = SessionAuth::from_seed_base64url(opts.session_key_seed_base64url)?;
85 let token = auth.nats_connect_binding_token(opts.binding_token);
86 let inbox_prefix = auth.inbox_prefix();
87 let key_pair = std::sync::Arc::new(
88 KeyPair::from_seed(opts.sentinel_seed)
89 .map_err(|error| TrellisClientError::NatsConnect(error.to_string()))?,
90 );
91
92 let nats = ConnectOptions::with_jwt(opts.sentinel_jwt.to_string(), move |nonce| {
93 let key_pair = key_pair.clone();
94 async move { key_pair.sign(&nonce).map_err(async_nats::AuthError::new) }
95 })
96 .token(token)
97 .custom_inbox_prefix(inbox_prefix)
98 .connect(opts.servers)
99 .await
100 .map_err(|error| TrellisClientError::NatsConnect(error.to_string()))?;
101
102 Ok(Self {
103 nats,
104 auth,
105 timeout_ms: opts.timeout_ms,
106 })
107 }
108
109 async fn request(
110 &self,
111 subject: &str,
112 payload: Bytes,
113 ) -> Result<async_nats::Message, TrellisClientError> {
114 let proof = self.auth.create_proof(subject, &payload);
115
116 let mut headers = HeaderMap::new();
117 headers.insert("session-key", self.auth.session_key.as_str());
118 headers.insert("proof", proof.as_str());
119
120 let future = self
121 .nats
122 .request_with_headers(subject.to_string(), headers, payload);
123 let message = timeout(std::time::Duration::from_millis(self.timeout_ms), future)
124 .await
125 .map_err(|_| TrellisClientError::Timeout)?
126 .map_err(|error| TrellisClientError::NatsRequest(error.to_string()))?;
127 Ok(message)
128 }
129
130 async fn request_json(&self, subject: &str, body: Value) -> Result<Value, TrellisClientError> {
131 let payload = Bytes::from(serde_json::to_vec(&body)?);
132 let message = self.request(subject, payload).await?;
133
134 decode_json_message(message)
135 }
136
137 pub async fn request_json_value(
139 &self,
140 subject: &str,
141 body: &Value,
142 ) -> Result<Value, TrellisClientError> {
143 self.request_json(subject, body.clone()).await
144 }
145
146 pub async fn call<D>(&self, input: &D::Input) -> Result<D::Output, TrellisClientError>
148 where
149 D: RpcDescriptor,
150 {
151 let value = serde_json::to_value(input)?;
152 let response = self.request_json(D::SUBJECT, value).await?;
153 Ok(serde_json::from_value(response)?)
154 }
155
156 pub async fn publish<D>(&self, event: &D::Event) -> Result<(), TrellisClientError>
158 where
159 D: EventDescriptor,
160 {
161 let payload = Bytes::from(serde_json::to_vec(event)?);
162 self.nats
163 .publish(D::SUBJECT.to_string(), payload)
164 .await
165 .map_err(|error| TrellisClientError::NatsRequest(error.to_string()))?;
166 Ok(())
167 }
168
169 pub fn operation<D>(&self) -> OperationInvoker<'_, Self, D>
171 where
172 D: OperationDescriptor,
173 {
174 OperationInvoker::new(self)
175 }
176}
177
178impl OperationTransport for TrellisClient {
179 async fn request_json_value(
180 &self,
181 subject: String,
182 body: Value,
183 ) -> Result<Value, TrellisClientError> {
184 TrellisClient::request_json_value(self, &subject, &body).await
185 }
186
187 async fn watch_json_value<'a>(
188 &'a self,
189 subject: String,
190 body: Value,
191 ) -> Result<BoxStream<'a, Result<Value, TrellisClientError>>, TrellisClientError> {
192 let payload = Bytes::from(serde_json::to_vec(&body)?);
193 let proof = self.auth.create_proof(&subject, &payload);
194
195 let mut headers = HeaderMap::new();
196 headers.insert("session-key", self.auth.session_key.as_str());
197 headers.insert("proof", proof.as_str());
198
199 let inbox = self.nats.new_inbox();
200 let subscriber = timeout(
201 std::time::Duration::from_millis(self.timeout_ms),
202 self.nats.subscribe(inbox.clone()),
203 )
204 .await
205 .map_err(|_| TrellisClientError::Timeout)?
206 .map_err(|error| TrellisClientError::NatsRequest(error.to_string()))?;
207
208 timeout(
209 std::time::Duration::from_millis(self.timeout_ms),
210 self.nats
211 .publish_with_reply_and_headers(subject, inbox, headers, payload),
212 )
213 .await
214 .map_err(|_| TrellisClientError::Timeout)?
215 .map_err(|error| TrellisClientError::NatsRequest(error.to_string()))?;
216
217 let stream = stream::try_unfold((subscriber, false), |(mut subscriber, done)| async move {
218 if done {
219 return Ok(None);
220 }
221
222 match subscriber.next().await {
223 Some(message) => {
224 let event = decode_watch_message(message)?;
225 let terminal = is_terminal_event(&event);
226 Ok(Some((event, (subscriber, terminal))))
227 }
228 None => Ok(None),
229 }
230 });
231
232 Ok(Box::pin(stream) as BoxStream<'a, Result<Value, TrellisClientError>>)
233 }
234}
235
236fn decode_json_message(message: async_nats::Message) -> Result<Value, TrellisClientError> {
237 if let Some(headers) = &message.headers {
238 if headers
239 .get("status")
240 .is_some_and(|status| status.as_str() == "error")
241 {
242 let value: Value = serde_json::from_slice(&message.payload)?;
243 return Err(TrellisClientError::RpcError(value.to_string()));
244 }
245 }
246
247 Ok(serde_json::from_slice(&message.payload)?)
248}
249
250fn decode_watch_message(message: async_nats::Message) -> Result<Value, TrellisClientError> {
251 decode_json_message(message)
252}
253
254fn is_terminal_event(event: &Value) -> bool {
255 matches!(
256 event.get("type").and_then(Value::as_str),
257 Some("completed" | "failed" | "cancelled")
258 )
259}
260
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use futures_util::StreamExt;
    use serde::{Deserialize, Serialize};
    use serde_json::{json, Value};
    use std::process::Command;
    use std::time::{Duration, SystemTime, UNIX_EPOCH};

    use crate::control_subject;
    use crate::operations::OperationEvent;

    /// Upper bound on waiting for the watch stream to finish. A stream that does not stop
    /// after its terminal event fails the test instead of hanging it.
    const WATCH_DEADLINE: Duration = Duration::from_secs(10);

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
    struct RefundInput {
        charge_id: String,
    }

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
    struct RefundProgress {
        message: String,
    }

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
    struct RefundOutput {
        refund_id: String,
    }

    struct RefundOperation;

    impl OperationDescriptor for RefundOperation {
        type Input = RefundInput;
        type Progress = RefundProgress;
        type Output = RefundOutput;

        const KEY: &'static str = "Billing.Refund";
        const SUBJECT: &'static str = "operations.v1.Billing.Refund";
        const CALLER_CAPABILITIES: &'static [&'static str] = &["billing.refund"];
        const READ_CAPABILITIES: &'static [&'static str] = &["billing.read"];
        const CANCEL_CAPABILITIES: &'static [&'static str] = &["billing.cancel"];
        const CANCELABLE: bool = true;
    }

    /// Guard that force-removes the named container when dropped (best effort).
    struct RuntimeContainer {
        runtime: String,
        name: String,
    }

    impl Drop for RuntimeContainer {
        fn drop(&mut self) {
            let _ = Command::new(&self.runtime)
                .args(["rm", "-f", &self.name])
                .output();
        }
    }

    /// Returns the first usable container runtime, trying `podman` then `docker`.
    ///
    /// A runtime that cannot be launched (e.g. not installed) is skipped rather than ending
    /// the search, so a docker-only machine is still detected.
    fn detect_runtime() -> Option<&'static str> {
        for runtime in ["podman", "docker"] {
            let available = Command::new(runtime)
                .arg("--version")
                .output()
                .is_ok_and(|output| output.status.success());
            if available {
                return Some(runtime);
            }
        }
        None
    }

    /// Runs `runtime args...` and returns trimmed stdout; panics with full output on failure.
    fn run_command(runtime: &str, args: &[&str]) -> String {
        let output = Command::new(runtime)
            .args(args)
            .output()
            .expect("runtime command should execute");
        if !output.status.success() {
            panic!(
                "runtime command failed: {} {}\nstdout: {}\nstderr: {}",
                runtime,
                args.join(" "),
                String::from_utf8_lossy(&output.stdout),
                String::from_utf8_lossy(&output.stderr),
            );
        }
        String::from_utf8(output.stdout)
            .expect("stdout should be utf-8")
            .trim()
            .to_string()
    }

    /// Starts a throwaway NATS container on a random loopback port.
    ///
    /// Returns the cleanup guard and the `host:port` address.
    fn start_nats_container() -> (RuntimeContainer, String) {
        let runtime = detect_runtime().expect("podman or docker runtime is required");
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system clock should be after unix epoch")
            .as_nanos();
        let name = format!("trellis-client-watch-it-{}-{}", std::process::id(), now);

        run_command(
            runtime,
            &[
                "run",
                "-d",
                "--rm",
                "--name",
                &name,
                "-p",
                "127.0.0.1::4222",
                "docker.io/library/nats:2.10-alpine",
            ],
        );

        // Create the cleanup guard as soon as the container exists, so a failure while
        // reading the port mapping below still removes it.
        let container = RuntimeContainer {
            runtime: runtime.to_string(),
            name,
        };

        let mapping = run_command(runtime, &["port", &container.name, "4222/tcp"]);
        let host_port = mapping
            .split(':')
            .next_back()
            .expect("port mapping should include ':'")
            .trim()
            .to_string();
        let server = format!("127.0.0.1:{}", host_port);

        (container, server)
    }

    /// Connects to `server`, retrying for about three seconds while the container boots.
    async fn connect_with_retry(server: &str) -> async_nats::Client {
        let mut last_error = None;
        for _ in 0..30 {
            match async_nats::connect(server).await {
                Ok(client) => return client,
                Err(error) => {
                    last_error = Some(error.to_string());
                    tokio::time::sleep(Duration::from_millis(100)).await;
                }
            }
        }

        panic!(
            "failed to connect to nats server {}: {}",
            server,
            last_error.unwrap_or_else(|| "unknown error".to_string())
        );
    }

    /// Deterministic session credentials for tests.
    fn test_auth() -> SessionAuth {
        SessionAuth::from_seed_base64url(&crate::proof::base64url_encode(&[7u8; 32]))
            .expect("session auth")
    }

    #[tokio::test]
    #[ignore = "needs podman/docker runtime"]
    async fn watch_stream_uses_reply_subject_and_stops_after_terminal_event() {
        let (_container, server) = start_nats_container();

        let service_client = connect_with_retry(&server).await;
        let requester_client = connect_with_retry(&server).await;
        let auth = test_auth();
        let client = TrellisClient::new(requester_client, auth, 2_000);

        let mut start_sub = service_client
            .subscribe(RefundOperation::SUBJECT.to_string())
            .await
            .expect("subscribe start subject");
        let mut control_sub = service_client
            .subscribe(control_subject(RefundOperation::SUBJECT))
            .await
            .expect("subscribe control subject");

        let service_for_start = service_client.clone();
        let start_task = tokio::spawn(async move {
            if let Some(msg) = start_sub.next().await {
                let body: Value = serde_json::from_slice(&msg.payload).expect("start request json");
                assert_eq!(body["charge_id"], "ch_123");
                let accepted = json!({
                    "kind": "accepted",
                    "ref": {
                        "id": "op_123",
                        "service": "billing",
                        "operation": "Billing.Refund"
                    },
                    "snapshot": {
                        "revision": 1,
                        "state": "pending"
                    }
                });
                let reply = msg.reply.as_ref().expect("start reply subject").clone();
                service_for_start
                    .publish(
                        reply,
                        Bytes::from(serde_json::to_vec(&accepted).expect("serialize accepted")),
                    )
                    .await
                    .expect("publish accepted reply");
            }
        });

        let service_for_control = service_client.clone();
        let control_task = tokio::spawn(async move {
            if let Some(msg) = control_sub.next().await {
                let body: Value =
                    serde_json::from_slice(&msg.payload).expect("control request json");
                assert_eq!(body["action"], "watch");
                assert_eq!(body["operationId"], "op_123");

                let reply = msg.reply.as_ref().expect("watch reply subject").clone();
                // The last frame follows the terminal event and must never reach the caller.
                let frames = [
                    json!({
                        "kind": "snapshot",
                        "snapshot": {
                            "revision": 2,
                            "state": "running",
                            "progress": {
                                "message": "working"
                            }
                        }
                    }),
                    json!({
                        "kind": "event",
                        "event": {
                            "type": "progress",
                            "snapshot": {
                                "revision": 3,
                                "state": "running",
                                "progress": {
                                    "message": "almost there"
                                }
                            }
                        }
                    }),
                    json!({"kind": "keepalive"}),
                    json!({
                        "kind": "event",
                        "event": {
                            "type": "completed",
                            "snapshot": {
                                "revision": 4,
                                "state": "completed",
                                "output": {
                                    "refund_id": "rf_123"
                                }
                            }
                        }
                    }),
                    json!({
                        "kind": "event",
                        "event": {
                            "type": "progress",
                            "snapshot": {
                                "revision": 5,
                                "state": "running",
                                "progress": {
                                    "message": "ignored"
                                }
                            }
                        }
                    }),
                ];

                for frame in frames {
                    service_for_control
                        .publish(
                            reply.clone(),
                            Bytes::from(serde_json::to_vec(&frame).expect("serialize frame")),
                        )
                        .await
                        .expect("publish watch frame");
                }
            }
        });

        // Give the service subscriptions time to register before the requester publishes.
        tokio::time::sleep(Duration::from_millis(100)).await;

        let operation = client
            .operation::<RefundOperation>()
            .start(&RefundInput {
                charge_id: "ch_123".to_string(),
            })
            .await
            .expect("start should succeed");
        let stream = operation.watch().await.expect("watch should succeed");
        // Bounded wait: a stream that ignores the terminal event fails here instead of hanging.
        let events: Vec<_> = tokio::time::timeout(WATCH_DEADLINE, stream.collect::<Vec<_>>())
            .await
            .expect("watch stream should end after the terminal event");

        assert_eq!(events.len(), 3);
        assert!(matches!(events[0], Ok(OperationEvent::Started { .. })));
        assert!(matches!(events[1], Ok(OperationEvent::Progress { .. })));
        assert!(matches!(events[2], Ok(OperationEvent::Completed { .. })));

        start_task.await.expect("start task should complete");
        control_task.await.expect("control task should complete");
    }
}