1use async_nats::header::HeaderMap;
2use async_nats::ConnectOptions;
3use bytes::Bytes;
4use nkeys::KeyPair;
5use serde_json::Value;
6use tokio::time::timeout;
7
8use crate::proof::now_iat_seconds;
9use crate::{EventDescriptor, RpcDescriptor, SessionAuth, TrellisClientError};
10
/// Connection settings consumed by `TrellisClient::connect_service`.
///
/// All string fields are borrowed; the struct only needs to live for the
/// duration of the connect call.
pub struct ServiceConnectOptions<'a> {
    /// NATS server address(es), handed unchanged to `ConnectOptions::connect`.
    pub servers: &'a str,
    /// Location of the sentinel credentials file.
    pub sentinel_creds_path: &'a str,
    /// Session key seed, base64url-encoded.
    pub session_key_seed_base64url: &'a str,
    /// Timeout applied to each request, in milliseconds.
    pub timeout_ms: u64,
}
18
/// Connection settings consumed by `TrellisClient::connect_user`.
///
/// All string fields are borrowed; the struct only needs to live for the
/// duration of the connect call.
pub struct UserConnectOptions<'a> {
    /// NATS server address(es), handed unchanged to `ConnectOptions::connect`.
    pub servers: &'a str,
    /// Sentinel user JWT presented during the NATS handshake.
    pub sentinel_jwt: &'a str,
    /// Sentinel nkey seed used to sign the server-provided nonce.
    pub sentinel_seed: &'a str,
    /// Session key seed, base64url-encoded.
    pub session_key_seed_base64url: &'a str,
    /// Binding token from which the NATS connect token is derived.
    pub binding_token: &'a str,
    /// Timeout applied to each request, in milliseconds.
    pub timeout_ms: u64,
}
28
29pub struct TrellisClient {
31 nats: async_nats::Client,
32 auth: SessionAuth,
33 timeout_ms: u64,
34}
35
36impl TrellisClient {
37 pub fn nats(&self) -> &async_nats::Client {
39 &self.nats
40 }
41
42 pub fn auth(&self) -> &SessionAuth {
44 &self.auth
45 }
46
47 pub async fn connect_service(
49 opts: ServiceConnectOptions<'_>,
50 ) -> Result<Self, TrellisClientError> {
51 let auth = SessionAuth::from_seed_base64url(opts.session_key_seed_base64url)?;
52 let token = auth.nats_connect_token(now_iat_seconds());
53 let inbox_prefix = auth.inbox_prefix();
54
55 let nats = ConnectOptions::new()
56 .credentials(opts.sentinel_creds_path)?
57 .token(token)
58 .custom_inbox_prefix(inbox_prefix)
59 .connect(opts.servers)
60 .await
61 .map_err(|error| TrellisClientError::NatsConnect(error.to_string()))?;
62
63 Ok(Self {
64 nats,
65 auth,
66 timeout_ms: opts.timeout_ms,
67 })
68 }
69
70 pub async fn connect_user(opts: UserConnectOptions<'_>) -> Result<Self, TrellisClientError> {
72 let auth = SessionAuth::from_seed_base64url(opts.session_key_seed_base64url)?;
73 let token = auth.nats_connect_binding_token(opts.binding_token);
74 let inbox_prefix = auth.inbox_prefix();
75 let key_pair = std::sync::Arc::new(
76 KeyPair::from_seed(opts.sentinel_seed)
77 .map_err(|error| TrellisClientError::NatsConnect(error.to_string()))?,
78 );
79
80 let nats = ConnectOptions::with_jwt(opts.sentinel_jwt.to_string(), move |nonce| {
81 let key_pair = key_pair.clone();
82 async move { key_pair.sign(&nonce).map_err(async_nats::AuthError::new) }
83 })
84 .token(token)
85 .custom_inbox_prefix(inbox_prefix)
86 .connect(opts.servers)
87 .await
88 .map_err(|error| TrellisClientError::NatsConnect(error.to_string()))?;
89
90 Ok(Self {
91 nats,
92 auth,
93 timeout_ms: opts.timeout_ms,
94 })
95 }
96
97 async fn request(
98 &self,
99 subject: &str,
100 payload: Bytes,
101 ) -> Result<async_nats::Message, TrellisClientError> {
102 let proof = self.auth.create_proof(subject, &payload);
103
104 let mut headers = HeaderMap::new();
105 headers.insert("session-key", self.auth.session_key.as_str());
106 headers.insert("proof", proof.as_str());
107
108 let future = self
109 .nats
110 .request_with_headers(subject.to_string(), headers, payload);
111 let message = timeout(std::time::Duration::from_millis(self.timeout_ms), future)
112 .await
113 .map_err(|_| TrellisClientError::Timeout)?
114 .map_err(|error| TrellisClientError::NatsRequest(error.to_string()))?;
115 Ok(message)
116 }
117
118 async fn request_json(&self, subject: &str, body: Value) -> Result<Value, TrellisClientError> {
119 let payload = Bytes::from(serde_json::to_vec(&body)?);
120 let message = self.request(subject, payload).await?;
121
122 if let Some(headers) = &message.headers {
123 if let Some(status) = headers.get("status") {
124 if status.as_str() == "error" {
125 let value: Value = serde_json::from_slice(&message.payload)?;
126 return Err(TrellisClientError::RpcError(value.to_string()));
127 }
128 }
129 }
130
131 Ok(serde_json::from_slice(&message.payload)?)
132 }
133
134 pub async fn request_json_value(
136 &self,
137 subject: &str,
138 body: &Value,
139 ) -> Result<Value, TrellisClientError> {
140 self.request_json(subject, body.clone()).await
141 }
142
143 pub async fn call<D>(&self, input: &D::Input) -> Result<D::Output, TrellisClientError>
145 where
146 D: RpcDescriptor,
147 {
148 let value = serde_json::to_value(input)?;
149 let response = self.request_json(D::SUBJECT, value).await?;
150 Ok(serde_json::from_value(response)?)
151 }
152
153 pub async fn publish<D>(&self, event: &D::Event) -> Result<(), TrellisClientError>
155 where
156 D: EventDescriptor,
157 {
158 let payload = Bytes::from(serde_json::to_vec(event)?);
159 self.nats
160 .publish(D::SUBJECT.to_string(), payload)
161 .await
162 .map_err(|error| TrellisClientError::NatsRequest(error.to_string()))?;
163 Ok(())
164 }
165}