1use std::collections::{BTreeMap, HashMap};
2use std::sync::Arc;
3use std::sync::atomic::{AtomicU64, Ordering};
4use std::time::Duration;
5
6use rmpv::Value;
7use rpc_runtime_activation::{
8 CREATE_INSTANCE_METHOD_ID, CreateInstanceRequest, LIST_INSTANCES_METHOD_ID,
9 ListInstancesRequest, RELEASE_INSTANCE_METHOD_ID, RESOLVE_INSTANCE_IDS_METHOD_ID,
10 ReleaseInstanceRequest, ResolveInstanceIdsRequest, activation_instance_id,
11 decode_create_instance_response, decode_list_instances_response,
12 decode_release_instance_response, decode_resolve_instance_ids_response,
13 encode_create_instance_request, encode_list_instances_request, encode_release_instance_request,
14 encode_resolve_instance_ids_request,
15};
16use rpc_runtime_core::{
17 CapabilityFlags, Envelope, Hello, InstanceId, MethodId, Notification, Options,
18 RUNTIME_PROTOCOL_VERSION, Request, RequestId, Role, ServiceGuid,
19};
20use rpc_runtime_errors::{ErrorKind, RuntimeError, RuntimeErrorCode};
21use rpc_runtime_transport::{RpcConnection, RpcReceiver, RpcSender};
22use rpc_runtime_transport_ipc::{FrameConfig, IpcConnection, IpcEndpoint};
23use tokio::sync::{Mutex, broadcast, mpsc, oneshot};
24
/// Client handle for a runtime RPC connection.
///
/// Cloning is cheap: every clone shares the same connection state behind an [`Arc`]
/// (outbound sender, pending-request table and notification fan-out).
#[derive(Clone)]
pub struct RpcClient {
    // Shared connection state; the background receive loop holds another reference.
    inner: Arc<ClientInner>,
}
29
/// Default HELLO option key under which the client's auth token is sent
/// (the default value of `RpcClientHandshakeConfig::auth_option_key`).
pub const DEFAULT_AUTH_TOKEN_OPTION_KEY: &str = "tripley.auth.token";
31
/// Handshake settings applied when an [`RpcClient`] connects.
///
/// The `Debug` output redacts `auth_token` so the secret never ends up in logs or
/// panic messages; it only shows whether a token is configured.
#[derive(Clone, PartialEq, Eq)]
pub struct RpcClientHandshakeConfig {
    /// Token sent in the HELLO options; when `None` no auth option is sent at all.
    pub auth_token: Option<String>,
    /// HELLO option key the token is sent under
    /// (defaults to [`DEFAULT_AUTH_TOKEN_OPTION_KEY`]).
    pub auth_option_key: String,
}

impl std::fmt::Debug for RpcClientHandshakeConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RpcClientHandshakeConfig")
            // Reveal only the presence of a token, never its value.
            .field("auth_token", &self.auth_token.as_ref().map(|_| "<redacted>"))
            .field("auth_option_key", &self.auth_option_key)
            .finish()
    }
}
37
38impl RpcClientHandshakeConfig {
39 pub fn with_auth_token(mut self, token: impl Into<String>) -> Self {
40 self.auth_token = Some(token.into());
41 self
42 }
43
44 pub fn with_auth_option_key(mut self, key: impl Into<String>) -> Self {
45 self.auth_option_key = key.into();
46 self
47 }
48
49 fn hello_options(&self) -> Options {
50 self.auth_token
51 .as_ref()
52 .map(|token| vec![(self.auth_option_key.clone(), Value::from(token.as_str()))])
53 .unwrap_or_default()
54 }
55}
56
57impl Default for RpcClientHandshakeConfig {
58 fn default() -> Self {
59 Self {
60 auth_token: None,
61 auth_option_key: DEFAULT_AUTH_TOKEN_OPTION_KEY.to_string(),
62 }
63 }
64}
65
/// Connection state shared by all [`RpcClient`] clones and the background receive loop.
struct ClientInner {
    // Outbound half of the connection; used for requests and GOODBYE.
    sender: RpcSender,
    // Source of request ids; starts at 1 and is incremented per call.
    next_request_id: AtomicU64,
    // In-flight requests keyed by request id; resolved (and removed) by the receive loop.
    pending: Mutex<HashMap<u64, oneshot::Sender<Result<Value, RuntimeError>>>>,
    // Fan-out of server notifications to `subscribe_notifications` subscribers.
    notifications: broadcast::Sender<Notification>,
}
72
73impl RpcClient {
74 pub async fn connect(endpoint: IpcEndpoint, config: FrameConfig) -> Result<Self, RuntimeError> {
75 Self::connect_with_handshake_config(endpoint, config, RpcClientHandshakeConfig::default())
76 .await
77 }
78
79 pub async fn connect_with_handshake_config(
80 endpoint: IpcEndpoint,
81 config: FrameConfig,
82 handshake: RpcClientHandshakeConfig,
83 ) -> Result<Self, RuntimeError> {
84 let connection = IpcConnection::connect(endpoint, config)
85 .await
86 .map_err(|err| {
87 RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string())
88 })?;
89 Self::from_connection_with_handshake_config(connection, handshake).await
90 }
91
92 pub async fn from_connection<C>(connection: C) -> Result<Self, RuntimeError>
93 where
94 C: Into<RpcConnection>,
95 {
96 Self::from_connection_with_handshake_config(connection, RpcClientHandshakeConfig::default())
97 .await
98 }
99
100 pub async fn from_connection_with_handshake_config<C>(
101 connection: C,
102 handshake: RpcClientHandshakeConfig,
103 ) -> Result<Self, RuntimeError>
104 where
105 C: Into<RpcConnection>,
106 {
107 let (sender, mut receiver) = connection.into().split();
108 sender
109 .send_envelope(&Envelope::Hello(Hello {
110 protocol_version: RUNTIME_PROTOCOL_VERSION,
111 role: Role::Client,
112 capability_bits: client_capabilities(),
113 max_message_size: rpc_runtime_codec_msgpack::DEFAULT_MAX_MESSAGE_SIZE as u64,
114 options: handshake.hello_options(),
115 }))
116 .await
117 .map_err(|err| {
118 RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string())
119 })?;
120
121 let Some(envelope) = receiver.recv_envelope().await.map_err(|err| {
122 RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string())
123 })?
124 else {
125 return Err(RuntimeError::transport(
126 RuntimeErrorCode::InternalRuntimeError,
127 "server disconnected during handshake",
128 ));
129 };
130 let Envelope::HelloAck(ack) = envelope else {
131 return Err(RuntimeError::protocol(
132 RuntimeErrorCode::InvalidEnvelope,
133 "expected HELLO_ACK during handshake",
134 ));
135 };
136 if ack.protocol_version != RUNTIME_PROTOCOL_VERSION {
137 return Err(RuntimeError::protocol(
138 RuntimeErrorCode::UnsupportedProtocolVersion,
139 "server returned unsupported protocol version",
140 ));
141 }
142
143 let (notifications, _) = broadcast::channel(128);
144 let inner = Arc::new(ClientInner {
145 sender,
146 next_request_id: AtomicU64::new(1),
147 pending: Mutex::new(HashMap::new()),
148 notifications,
149 });
150 spawn_receive_loop(Arc::clone(&inner), receiver);
151 Ok(Self { inner })
152 }
153
154 pub async fn call(
155 &self,
156 instance_id: InstanceId,
157 method_id: MethodId,
158 payload: Value,
159 ) -> Result<Value, RuntimeError> {
160 let request_id = self.inner.next_request_id.fetch_add(1, Ordering::Relaxed);
161 let (tx, rx) = oneshot::channel();
162 self.inner.pending.lock().await.insert(request_id, tx);
163
164 let send_result = self
165 .inner
166 .sender
167 .send_envelope(&Envelope::Request(Request {
168 request_id: RequestId::new(request_id),
169 instance_id,
170 method_id,
171 payload,
172 }))
173 .await;
174 if let Err(err) = send_result {
175 self.inner.pending.lock().await.remove(&request_id);
176 return Err(RuntimeError::transport(
177 RuntimeErrorCode::InternalRuntimeError,
178 err.to_string(),
179 ));
180 }
181
182 rx.await.map_err(|_| {
183 RuntimeError::transport(
184 RuntimeErrorCode::InternalRuntimeError,
185 "response channel closed before request completed",
186 )
187 })?
188 }
189
190 pub async fn call_timeout(
191 &self,
192 instance_id: InstanceId,
193 method_id: MethodId,
194 payload: Value,
195 timeout: Duration,
196 ) -> Result<Value, RuntimeError> {
197 tokio::time::timeout(timeout, self.call(instance_id, method_id, payload))
198 .await
199 .map_err(|_| {
200 RuntimeError::runtime(RuntimeErrorCode::RequestTimeout, "request timed out")
201 })?
202 }
203
204 pub async fn resolve_instance_ids(&self, names: Vec<String>) -> Result<Vec<u64>, RuntimeError> {
205 let response = self
206 .call(
207 activation_instance_id(),
208 MethodId::new(RESOLVE_INSTANCE_IDS_METHOD_ID),
209 encode_resolve_instance_ids_request(&ResolveInstanceIdsRequest {
210 instance_names: names,
211 }),
212 )
213 .await?;
214 Ok(decode_resolve_instance_ids_response(&response)?.instance_ids)
215 }
216
217 pub async fn create_instance(
218 &self,
219 service_guid: ServiceGuid,
220 create_payload: Option<Vec<u8>>,
221 options: BTreeMap<String, String>,
222 ) -> Result<InstanceId, RuntimeError> {
223 let response = self
224 .call(
225 activation_instance_id(),
226 MethodId::new(CREATE_INSTANCE_METHOD_ID),
227 encode_create_instance_request(&CreateInstanceRequest {
228 service_guid,
229 create_payload,
230 options,
231 }),
232 )
233 .await?;
234 Ok(decode_create_instance_response(&response)?.instance_id)
235 }
236
237 pub async fn release_instance(&self, instance_id: InstanceId) -> Result<(), RuntimeError> {
238 let response = self
239 .call(
240 activation_instance_id(),
241 MethodId::new(RELEASE_INSTANCE_METHOD_ID),
242 encode_release_instance_request(&ReleaseInstanceRequest { instance_id }),
243 )
244 .await?;
245 decode_release_instance_response(&response)?;
246 Ok(())
247 }
248
249 pub async fn list_instances(
250 &self,
251 service_guid: Option<ServiceGuid>,
252 ) -> Result<Vec<rpc_runtime_activation::InstanceDescriptor>, RuntimeError> {
253 let response = self
254 .call(
255 activation_instance_id(),
256 MethodId::new(LIST_INSTANCES_METHOD_ID),
257 encode_list_instances_request(&ListInstancesRequest { service_guid }),
258 )
259 .await?;
260 Ok(decode_list_instances_response(&response)?.instances)
261 }
262
263 pub fn subscribe_notifications(
264 &self,
265 instance_id_filter: Option<InstanceId>,
266 notification_id_filter: Option<u32>,
267 ) -> mpsc::UnboundedReceiver<Notification> {
268 let mut source = self.inner.notifications.subscribe();
269 let (tx, rx) = mpsc::unbounded_channel();
270 tokio::spawn(async move {
271 loop {
272 let Ok(notification) = source.recv().await else {
273 break;
274 };
275 let instance_matches = instance_id_filter
276 .is_none_or(|expected| notification.instance_id == Some(expected));
277 let notification_matches = notification_id_filter
278 .is_none_or(|expected| notification.notification_id.get() == expected);
279 if instance_matches && notification_matches && tx.send(notification).is_err() {
280 break;
281 }
282 }
283 });
284 rx
285 }
286
287 pub async fn goodbye(&self, message: impl Into<String>) -> Result<(), RuntimeError> {
288 self.inner
289 .sender
290 .send_envelope(&Envelope::Goodbye(rpc_runtime_core::Goodbye {
291 reason_code: 0,
292 message: Some(message.into()),
293 }))
294 .await
295 .map_err(|err| {
296 RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string())
297 })
298 }
299}
300
301fn spawn_receive_loop(inner: Arc<ClientInner>, mut receiver: RpcReceiver) {
302 tokio::spawn(async move {
303 loop {
304 let envelope = match receiver.recv_envelope().await {
305 Ok(Some(envelope)) => envelope,
306 Ok(None) => {
307 fail_pending(
308 &inner,
309 RuntimeError::transport(
310 RuntimeErrorCode::InternalRuntimeError,
311 "server disconnected",
312 ),
313 )
314 .await;
315 break;
316 }
317 Err(err) => {
318 fail_pending(
319 &inner,
320 RuntimeError::transport(
321 RuntimeErrorCode::InternalRuntimeError,
322 err.to_string(),
323 ),
324 )
325 .await;
326 break;
327 }
328 };
329 match envelope {
330 Envelope::ResponseOk(response) => {
331 complete_pending(&inner, response.request_id.get(), Ok(response.payload)).await;
332 }
333 Envelope::ResponseError(response) => {
334 complete_pending(
335 &inner,
336 response.request_id.get(),
337 Err(RuntimeError::new(
338 runtime_error_code(response.error_code),
339 error_kind(response.error_kind),
340 response.error_message.unwrap_or_default(),
341 )),
342 )
343 .await;
344 }
345 Envelope::Notification(notification) => {
346 let _ = inner.notifications.send(notification);
347 }
348 _ => {
349 fail_pending(
350 &inner,
351 RuntimeError::protocol(
352 RuntimeErrorCode::InvalidEnvelope,
353 "client received invalid envelope kind",
354 ),
355 )
356 .await;
357 break;
358 }
359 }
360 }
361 });
362}
363
364async fn complete_pending(
365 inner: &ClientInner,
366 request_id: u64,
367 result: Result<Value, RuntimeError>,
368) {
369 if let Some(sender) = inner.pending.lock().await.remove(&request_id) {
370 let _ = sender.send(result);
371 }
372}
373
374async fn fail_pending(inner: &ClientInner, error: RuntimeError) {
375 let pending = std::mem::take(&mut *inner.pending.lock().await);
376 for (_, sender) in pending {
377 let _ = sender.send(Err(error.clone()));
378 }
379}
380
381fn client_capabilities() -> CapabilityFlags {
382 CapabilityFlags::SERVER_TO_CLIENT_NOTIFICATION
383 | CapabilityFlags::NAMED_INSTANCE_RESOLUTION
384 | CapabilityFlags::SERVICE_ACTIVATION
385 | CapabilityFlags::GOODBYE
386}
387
388fn runtime_error_code(value: i32) -> RuntimeErrorCode {
389 match value {
390 1001 => RuntimeErrorCode::UnknownMessageKind,
391 1002 => RuntimeErrorCode::UnsupportedProtocolVersion,
392 1003 => RuntimeErrorCode::InvalidEnvelope,
393 1004 => RuntimeErrorCode::InvalidRequestId,
394 1005 => RuntimeErrorCode::InvalidInstanceId,
395 1006 => RuntimeErrorCode::InstanceNotFound,
396 1007 => RuntimeErrorCode::MethodNotFound,
397 1008 => RuntimeErrorCode::NotificationNotFound,
398 1009 => RuntimeErrorCode::PayloadDecodeFailed,
399 1010 => RuntimeErrorCode::PayloadEncodeFailed,
400 1011 => RuntimeErrorCode::ServiceActivationNotSupported,
401 1012 => RuntimeErrorCode::ServiceGuidNotFound,
402 1013 => RuntimeErrorCode::InstanceReleaseNotAllowed,
403 1014 => RuntimeErrorCode::RequestTimeout,
404 1015 => RuntimeErrorCode::UnsupportedCapability,
405 1016 => RuntimeErrorCode::BusinessErrorDeclared,
406 1017 => RuntimeErrorCode::DuplicateRequestId,
407 1018 => RuntimeErrorCode::RequestCancelUnsupported,
408 1019 => RuntimeErrorCode::AccessDenied,
409 _ => RuntimeErrorCode::InternalRuntimeError,
410 }
411}
412
413fn error_kind(value: u8) -> ErrorKind {
414 match value {
415 1 => ErrorKind::Transport,
416 2 => ErrorKind::Protocol,
417 3 => ErrorKind::Runtime,
418 4 => ErrorKind::Business,
419 5 => ErrorKind::Timeout,
420 6 => ErrorKind::Cancelled,
421 _ => ErrorKind::Runtime,
422 }
423}