1use std::collections::{BTreeMap, HashMap};
2use std::sync::Arc;
3use std::sync::atomic::{AtomicU64, Ordering};
4use std::time::Duration;
5
6use rmpv::Value;
7use rpc_runtime_activation::{
8 CREATE_INSTANCE_METHOD_ID, CreateInstanceRequest, LIST_INSTANCES_METHOD_ID,
9 ListInstancesRequest, RELEASE_INSTANCE_METHOD_ID, RESOLVE_INSTANCE_IDS_METHOD_ID,
10 ReleaseInstanceRequest, ResolveInstanceIdsRequest, activation_instance_id,
11 decode_create_instance_response, decode_list_instances_response,
12 decode_release_instance_response, decode_resolve_instance_ids_response,
13 encode_create_instance_request, encode_list_instances_request, encode_release_instance_request,
14 encode_resolve_instance_ids_request,
15};
16use rpc_runtime_core::{
17 CapabilityFlags, Envelope, Hello, InstanceId, MethodId, Notification, RUNTIME_PROTOCOL_VERSION,
18 Request, RequestId, Role, ServiceGuid,
19};
20use rpc_runtime_errors::{ErrorKind, RuntimeError, RuntimeErrorCode};
21use rpc_runtime_transport::{RpcConnection, RpcReceiver, RpcSender};
22use rpc_runtime_transport_ipc::{FrameConfig, IpcConnection, IpcEndpoint};
23use tokio::sync::{Mutex, broadcast, mpsc, oneshot};
24
25#[derive(Clone)]
26pub struct RpcClient {
27 inner: Arc<ClientInner>,
28}
29
30struct ClientInner {
31 sender: RpcSender,
32 next_request_id: AtomicU64,
33 pending: Mutex<HashMap<u64, oneshot::Sender<Result<Value, RuntimeError>>>>,
34 notifications: broadcast::Sender<Notification>,
35}
36
37impl RpcClient {
38 pub async fn connect(endpoint: IpcEndpoint, config: FrameConfig) -> Result<Self, RuntimeError> {
39 let connection = IpcConnection::connect(endpoint, config)
40 .await
41 .map_err(|err| {
42 RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string())
43 })?;
44 Self::from_connection(connection).await
45 }
46
47 pub async fn from_connection<C>(connection: C) -> Result<Self, RuntimeError>
48 where
49 C: Into<RpcConnection>,
50 {
51 let (sender, mut receiver) = connection.into().split();
52 sender
53 .send_envelope(&Envelope::Hello(Hello {
54 protocol_version: RUNTIME_PROTOCOL_VERSION,
55 role: Role::Client,
56 capability_bits: client_capabilities(),
57 max_message_size: rpc_runtime_codec_msgpack::DEFAULT_MAX_MESSAGE_SIZE as u64,
58 options: Vec::new(),
59 }))
60 .await
61 .map_err(|err| {
62 RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string())
63 })?;
64
65 let Some(envelope) = receiver.recv_envelope().await.map_err(|err| {
66 RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string())
67 })?
68 else {
69 return Err(RuntimeError::transport(
70 RuntimeErrorCode::InternalRuntimeError,
71 "server disconnected during handshake",
72 ));
73 };
74 let Envelope::HelloAck(ack) = envelope else {
75 return Err(RuntimeError::protocol(
76 RuntimeErrorCode::InvalidEnvelope,
77 "expected HELLO_ACK during handshake",
78 ));
79 };
80 if ack.protocol_version != RUNTIME_PROTOCOL_VERSION {
81 return Err(RuntimeError::protocol(
82 RuntimeErrorCode::UnsupportedProtocolVersion,
83 "server returned unsupported protocol version",
84 ));
85 }
86
87 let (notifications, _) = broadcast::channel(128);
88 let inner = Arc::new(ClientInner {
89 sender,
90 next_request_id: AtomicU64::new(1),
91 pending: Mutex::new(HashMap::new()),
92 notifications,
93 });
94 spawn_receive_loop(Arc::clone(&inner), receiver);
95 Ok(Self { inner })
96 }
97
98 pub async fn call(
99 &self,
100 instance_id: InstanceId,
101 method_id: MethodId,
102 payload: Value,
103 ) -> Result<Value, RuntimeError> {
104 let request_id = self.inner.next_request_id.fetch_add(1, Ordering::Relaxed);
105 let (tx, rx) = oneshot::channel();
106 self.inner.pending.lock().await.insert(request_id, tx);
107
108 let send_result = self
109 .inner
110 .sender
111 .send_envelope(&Envelope::Request(Request {
112 request_id: RequestId::new(request_id),
113 instance_id,
114 method_id,
115 payload,
116 }))
117 .await;
118 if let Err(err) = send_result {
119 self.inner.pending.lock().await.remove(&request_id);
120 return Err(RuntimeError::transport(
121 RuntimeErrorCode::InternalRuntimeError,
122 err.to_string(),
123 ));
124 }
125
126 rx.await.map_err(|_| {
127 RuntimeError::transport(
128 RuntimeErrorCode::InternalRuntimeError,
129 "response channel closed before request completed",
130 )
131 })?
132 }
133
134 pub async fn call_timeout(
135 &self,
136 instance_id: InstanceId,
137 method_id: MethodId,
138 payload: Value,
139 timeout: Duration,
140 ) -> Result<Value, RuntimeError> {
141 tokio::time::timeout(timeout, self.call(instance_id, method_id, payload))
142 .await
143 .map_err(|_| {
144 RuntimeError::runtime(RuntimeErrorCode::RequestTimeout, "request timed out")
145 })?
146 }
147
148 pub async fn resolve_instance_ids(&self, names: Vec<String>) -> Result<Vec<u64>, RuntimeError> {
149 let response = self
150 .call(
151 activation_instance_id(),
152 MethodId::new(RESOLVE_INSTANCE_IDS_METHOD_ID),
153 encode_resolve_instance_ids_request(&ResolveInstanceIdsRequest {
154 instance_names: names,
155 }),
156 )
157 .await?;
158 Ok(decode_resolve_instance_ids_response(&response)?.instance_ids)
159 }
160
161 pub async fn create_instance(
162 &self,
163 service_guid: ServiceGuid,
164 create_payload: Option<Vec<u8>>,
165 options: BTreeMap<String, String>,
166 ) -> Result<InstanceId, RuntimeError> {
167 let response = self
168 .call(
169 activation_instance_id(),
170 MethodId::new(CREATE_INSTANCE_METHOD_ID),
171 encode_create_instance_request(&CreateInstanceRequest {
172 service_guid,
173 create_payload,
174 options,
175 }),
176 )
177 .await?;
178 Ok(decode_create_instance_response(&response)?.instance_id)
179 }
180
181 pub async fn release_instance(&self, instance_id: InstanceId) -> Result<(), RuntimeError> {
182 let response = self
183 .call(
184 activation_instance_id(),
185 MethodId::new(RELEASE_INSTANCE_METHOD_ID),
186 encode_release_instance_request(&ReleaseInstanceRequest { instance_id }),
187 )
188 .await?;
189 decode_release_instance_response(&response)?;
190 Ok(())
191 }
192
193 pub async fn list_instances(
194 &self,
195 service_guid: Option<ServiceGuid>,
196 ) -> Result<Vec<rpc_runtime_activation::InstanceDescriptor>, RuntimeError> {
197 let response = self
198 .call(
199 activation_instance_id(),
200 MethodId::new(LIST_INSTANCES_METHOD_ID),
201 encode_list_instances_request(&ListInstancesRequest { service_guid }),
202 )
203 .await?;
204 Ok(decode_list_instances_response(&response)?.instances)
205 }
206
207 pub fn subscribe_notifications(
208 &self,
209 instance_id_filter: Option<InstanceId>,
210 notification_id_filter: Option<u32>,
211 ) -> mpsc::UnboundedReceiver<Notification> {
212 let mut source = self.inner.notifications.subscribe();
213 let (tx, rx) = mpsc::unbounded_channel();
214 tokio::spawn(async move {
215 loop {
216 let Ok(notification) = source.recv().await else {
217 break;
218 };
219 let instance_matches = instance_id_filter
220 .is_none_or(|expected| notification.instance_id == Some(expected));
221 let notification_matches = notification_id_filter
222 .is_none_or(|expected| notification.notification_id.get() == expected);
223 if instance_matches && notification_matches && tx.send(notification).is_err() {
224 break;
225 }
226 }
227 });
228 rx
229 }
230
231 pub async fn goodbye(&self, message: impl Into<String>) -> Result<(), RuntimeError> {
232 self.inner
233 .sender
234 .send_envelope(&Envelope::Goodbye(rpc_runtime_core::Goodbye {
235 reason_code: 0,
236 message: Some(message.into()),
237 }))
238 .await
239 .map_err(|err| {
240 RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string())
241 })
242 }
243}
244
245fn spawn_receive_loop(inner: Arc<ClientInner>, mut receiver: RpcReceiver) {
246 tokio::spawn(async move {
247 loop {
248 let envelope = match receiver.recv_envelope().await {
249 Ok(Some(envelope)) => envelope,
250 Ok(None) => {
251 fail_pending(
252 &inner,
253 RuntimeError::transport(
254 RuntimeErrorCode::InternalRuntimeError,
255 "server disconnected",
256 ),
257 )
258 .await;
259 break;
260 }
261 Err(err) => {
262 fail_pending(
263 &inner,
264 RuntimeError::transport(
265 RuntimeErrorCode::InternalRuntimeError,
266 err.to_string(),
267 ),
268 )
269 .await;
270 break;
271 }
272 };
273 match envelope {
274 Envelope::ResponseOk(response) => {
275 complete_pending(&inner, response.request_id.get(), Ok(response.payload)).await;
276 }
277 Envelope::ResponseError(response) => {
278 complete_pending(
279 &inner,
280 response.request_id.get(),
281 Err(RuntimeError::new(
282 runtime_error_code(response.error_code),
283 error_kind(response.error_kind),
284 response.error_message.unwrap_or_default(),
285 )),
286 )
287 .await;
288 }
289 Envelope::Notification(notification) => {
290 let _ = inner.notifications.send(notification);
291 }
292 _ => {
293 fail_pending(
294 &inner,
295 RuntimeError::protocol(
296 RuntimeErrorCode::InvalidEnvelope,
297 "client received invalid envelope kind",
298 ),
299 )
300 .await;
301 break;
302 }
303 }
304 }
305 });
306}
307
308async fn complete_pending(
309 inner: &ClientInner,
310 request_id: u64,
311 result: Result<Value, RuntimeError>,
312) {
313 if let Some(sender) = inner.pending.lock().await.remove(&request_id) {
314 let _ = sender.send(result);
315 }
316}
317
318async fn fail_pending(inner: &ClientInner, error: RuntimeError) {
319 let pending = std::mem::take(&mut *inner.pending.lock().await);
320 for (_, sender) in pending {
321 let _ = sender.send(Err(error.clone()));
322 }
323}
324
325fn client_capabilities() -> CapabilityFlags {
326 CapabilityFlags::SERVER_TO_CLIENT_NOTIFICATION
327 | CapabilityFlags::NAMED_INSTANCE_RESOLUTION
328 | CapabilityFlags::SERVICE_ACTIVATION
329 | CapabilityFlags::GOODBYE
330}
331
332fn runtime_error_code(value: i32) -> RuntimeErrorCode {
333 match value {
334 1001 => RuntimeErrorCode::UnknownMessageKind,
335 1002 => RuntimeErrorCode::UnsupportedProtocolVersion,
336 1003 => RuntimeErrorCode::InvalidEnvelope,
337 1004 => RuntimeErrorCode::InvalidRequestId,
338 1005 => RuntimeErrorCode::InvalidInstanceId,
339 1006 => RuntimeErrorCode::InstanceNotFound,
340 1007 => RuntimeErrorCode::MethodNotFound,
341 1008 => RuntimeErrorCode::NotificationNotFound,
342 1009 => RuntimeErrorCode::PayloadDecodeFailed,
343 1010 => RuntimeErrorCode::PayloadEncodeFailed,
344 1011 => RuntimeErrorCode::ServiceActivationNotSupported,
345 1012 => RuntimeErrorCode::ServiceGuidNotFound,
346 1013 => RuntimeErrorCode::InstanceReleaseNotAllowed,
347 1014 => RuntimeErrorCode::RequestTimeout,
348 1015 => RuntimeErrorCode::UnsupportedCapability,
349 1016 => RuntimeErrorCode::BusinessErrorDeclared,
350 1017 => RuntimeErrorCode::DuplicateRequestId,
351 1018 => RuntimeErrorCode::RequestCancelUnsupported,
352 1019 => RuntimeErrorCode::AccessDenied,
353 _ => RuntimeErrorCode::InternalRuntimeError,
354 }
355}
356
357fn error_kind(value: u8) -> ErrorKind {
358 match value {
359 1 => ErrorKind::Transport,
360 2 => ErrorKind::Protocol,
361 3 => ErrorKind::Runtime,
362 4 => ErrorKind::Business,
363 5 => ErrorKind::Timeout,
364 6 => ErrorKind::Cancelled,
365 _ => ErrorKind::Runtime,
366 }
367}