1use std::collections::{BTreeMap, HashMap};
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicU64, Ordering};
6
7use rmpv::Value;
8use rpc_runtime_activation::{
9 ACTIVATION_INSTANCE_ID_VALUE, ActivationMode, CREATE_INSTANCE_METHOD_ID,
10 CreateInstanceResponse, InstanceDescriptor, LIST_INSTANCES_METHOD_ID, ListInstancesResponse,
11 RELEASE_INSTANCE_METHOD_ID, RESOLVE_INSTANCE_IDS_METHOD_ID, ReleaseInstanceResponse,
12 ResolveInstanceIdsResponse, activation_instance_id, activation_service_guid,
13 decode_create_instance_request, decode_list_instances_request, decode_release_instance_request,
14 decode_resolve_instance_ids_request, encode_create_instance_response,
15 encode_list_instances_response, encode_release_instance_response,
16 encode_resolve_instance_ids_response,
17};
18use rpc_runtime_core::{
19 CapabilityFlags, Envelope, HelloAck, InstanceId, MethodId, Notification,
20 RUNTIME_PROTOCOL_VERSION, Request, RequestId, ResponseError, ResponseOk, Role, ServiceGuid,
21};
22use rpc_runtime_errors::{RuntimeError, RuntimeErrorCode};
23use rpc_runtime_transport::{RpcConnection, RpcListener, RpcReceiver, RpcSender};
24use tokio::sync::RwLock;
25
26pub type HandlerFuture = Pin<Box<dyn Future<Output = Result<Value, RuntimeError>> + Send>>;
27
28pub trait RpcServiceHandler: Send + Sync {
29 fn call(&self, ctx: RpcCallContext, method_id: MethodId, payload: Value) -> HandlerFuture;
30}
31
32impl<F> RpcServiceHandler for F
33where
34 F: Send + Sync + 'static,
35 F: Fn(RpcCallContext, MethodId, Value) -> HandlerFuture,
36{
37 fn call(&self, ctx: RpcCallContext, method_id: MethodId, payload: Value) -> HandlerFuture {
38 self(ctx, method_id, payload)
39 }
40}
41
42pub type FactoryFuture =
43 Pin<Box<dyn Future<Output = Result<Arc<dyn RpcServiceHandler>, RuntimeError>> + Send>>;
44
45pub trait RpcServiceFactory: Send + Sync {
46 fn create(
47 &self,
48 ctx: RpcCallContext,
49 create_payload: Option<Vec<u8>>,
50 options: BTreeMap<String, String>,
51 ) -> FactoryFuture;
52}
53
54impl<F> RpcServiceFactory for F
55where
56 F: Send + Sync + 'static,
57 F: Fn(RpcCallContext, Option<Vec<u8>>, BTreeMap<String, String>) -> FactoryFuture,
58{
59 fn create<'a>(
60 &self,
61 ctx: RpcCallContext,
62 create_payload: Option<Vec<u8>>,
63 options: BTreeMap<String, String>,
64 ) -> FactoryFuture {
65 self(ctx, create_payload, options)
66 }
67}
68
69#[derive(Clone)]
70pub struct RpcCallContext {
71 connection_id: u64,
72 instance_id: InstanceId,
73 sender: RpcSender,
74}
75
76impl RpcCallContext {
77 pub fn connection_id(&self) -> u64 {
78 self.connection_id
79 }
80
81 pub fn instance_id(&self) -> InstanceId {
82 self.instance_id
83 }
84
85 pub async fn notify(
86 &self,
87 instance_id: Option<InstanceId>,
88 notification_id: u32,
89 payload: Value,
90 ) -> Result<(), RuntimeError> {
91 self.sender
92 .send_envelope(&Envelope::Notification(Notification {
93 instance_id,
94 notification_id: rpc_runtime_core::NotificationId::new(notification_id),
95 payload,
96 }))
97 .await
98 .map_err(|err| {
99 RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string())
100 })
101 }
102
103 pub async fn notify_bound(
104 &self,
105 notification_id: u32,
106 payload: Value,
107 ) -> Result<(), RuntimeError> {
108 self.notify(Some(self.instance_id), notification_id, payload)
109 .await
110 }
111}
112
113#[derive(Clone)]
114pub struct RpcServer {
115 state: Arc<ServerState>,
116}
117
118pub struct RpcServerBuilder {
119 state: ServerState,
120}
121
122impl RpcServerBuilder {
123 pub fn new() -> Self {
124 let mut state = ServerState::new();
125 state.insert_activation_instance();
126 Self { state }
127 }
128
129 pub fn register_named_instance(
130 &mut self,
131 name: impl Into<String>,
132 service_guid: ServiceGuid,
133 methods: impl IntoIterator<Item = u32>,
134 handler: Arc<dyn RpcServiceHandler>,
135 ) -> InstanceId {
136 self.state.insert_instance(NewInstance {
137 service_guid,
138 name: Some(name.into()),
139 activation_mode: ActivationMode::NamedPrecreated,
140 releasable: false,
141 owner_connection_id: None,
142 methods: methods.into_iter().collect(),
143 handler,
144 })
145 }
146
147 pub fn register_singleton(
148 &mut self,
149 service_guid: ServiceGuid,
150 methods: impl IntoIterator<Item = u32>,
151 handler: Arc<dyn RpcServiceHandler>,
152 ) -> InstanceId {
153 self.state.insert_instance(NewInstance {
154 service_guid,
155 name: None,
156 activation_mode: ActivationMode::Singleton,
157 releasable: false,
158 owner_connection_id: None,
159 methods: methods.into_iter().collect(),
160 handler,
161 })
162 }
163
164 pub fn register_factory(
165 &mut self,
166 service_guid: ServiceGuid,
167 methods: impl IntoIterator<Item = u32>,
168 factory: Arc<dyn RpcServiceFactory>,
169 ) {
170 self.state.factories.insert(
171 service_guid.get(),
172 FactoryEntry {
173 methods: methods.into_iter().collect(),
174 factory,
175 },
176 );
177 }
178
179 pub fn build(self) -> RpcServer {
180 RpcServer {
181 state: Arc::new(self.state),
182 }
183 }
184}
185
186impl Default for RpcServerBuilder {
187 fn default() -> Self {
188 Self::new()
189 }
190}
191
192impl RpcServer {
193 pub async fn serve_connection<C>(&self, connection: C) -> Result<(), RuntimeError>
194 where
195 C: Into<RpcConnection>,
196 {
197 let connection_id = self
198 .state
199 .next_connection_id
200 .fetch_add(1, Ordering::Relaxed);
201 let (sender, mut receiver) = connection.into().split();
202 self.perform_handshake(&sender, &mut receiver).await?;
203
204 while let Some(envelope) = receiver.recv_envelope().await.map_err(|err| {
205 RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string())
206 })? {
207 match envelope {
208 Envelope::Request(request) => {
209 let state = Arc::clone(&self.state);
210 let sender = sender.clone();
211 tokio::spawn(async move {
212 let request_id = request.request_id;
213 let response =
214 dispatch_request(state, sender.clone(), connection_id, request).await;
215 let envelope = match response {
216 Ok(payload) => Envelope::ResponseOk(ResponseOk {
217 request_id,
218 payload,
219 }),
220 Err(error) => runtime_error_response(request_id, error),
221 };
222 let _ = sender.send_envelope(&envelope).await;
223 });
224 }
225 Envelope::Goodbye(_) => break,
226 _ => {
227 return Err(RuntimeError::protocol(
228 RuntimeErrorCode::InvalidEnvelope,
229 "server expected request envelope",
230 ));
231 }
232 }
233 }
234
235 self.state.cleanup_connection(connection_id).await;
236 Ok(())
237 }
238
239 pub async fn serve_listener<L>(&self, mut listener: L) -> Result<(), RuntimeError>
240 where
241 L: RpcListener + Send,
242 {
243 loop {
244 let connection = listener.accept().await.map_err(|err| {
245 RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string())
246 })?;
247 let server = self.clone();
248 tokio::spawn(async move {
249 let _ = server.serve_connection(connection).await;
250 });
251 }
252 }
253
254 pub fn spawn_listener<L>(
255 &self,
256 listener: L,
257 ) -> tokio::task::JoinHandle<Result<(), RuntimeError>>
258 where
259 L: RpcListener + Send + 'static,
260 {
261 let server = self.clone();
262 tokio::spawn(async move { server.serve_listener(listener).await })
263 }
264
265 async fn perform_handshake(
266 &self,
267 sender: &RpcSender,
268 receiver: &mut RpcReceiver,
269 ) -> Result<(), RuntimeError> {
270 let Some(envelope) = receiver.recv_envelope().await.map_err(|err| {
271 RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string())
272 })?
273 else {
274 return Err(RuntimeError::transport(
275 RuntimeErrorCode::InternalRuntimeError,
276 "client disconnected during handshake",
277 ));
278 };
279 let Envelope::Hello(hello) = envelope else {
280 return Err(RuntimeError::protocol(
281 RuntimeErrorCode::InvalidEnvelope,
282 "expected HELLO during handshake",
283 ));
284 };
285 if hello.protocol_version != RUNTIME_PROTOCOL_VERSION || hello.role != Role::Client {
286 return Err(RuntimeError::protocol(
287 RuntimeErrorCode::UnsupportedProtocolVersion,
288 "unsupported client handshake",
289 ));
290 }
291 sender
292 .send_envelope(&Envelope::HelloAck(HelloAck {
293 protocol_version: RUNTIME_PROTOCOL_VERSION,
294 accepted_capability_bits: server_capabilities() & hello.capability_bits,
295 max_message_size: hello.max_message_size,
296 options: Vec::new(),
297 }))
298 .await
299 .map_err(|err| {
300 RuntimeError::transport(RuntimeErrorCode::InternalRuntimeError, err.to_string())
301 })
302 }
303
304 pub async fn list_instances(&self) -> Vec<InstanceDescriptor> {
305 self.state.list_instances(None).await
306 }
307}
308
309async fn dispatch_request(
310 state: Arc<ServerState>,
311 sender: RpcSender,
312 connection_id: u64,
313 request: Request,
314) -> Result<Value, RuntimeError> {
315 if request.instance_id.get() == ACTIVATION_INSTANCE_ID_VALUE {
316 return dispatch_activation(state, sender, connection_id, request).await;
317 }
318
319 let instance = state.get_instance(request.instance_id).await?;
320 if !instance.methods.contains(&request.method_id.get()) {
321 return Err(RuntimeError::runtime(
322 RuntimeErrorCode::MethodNotFound,
323 format!("method id `{}` was not found", request.method_id.get()),
324 ));
325 }
326 let ctx = RpcCallContext {
327 connection_id,
328 instance_id: request.instance_id,
329 sender,
330 };
331 instance
332 .handler
333 .call(ctx, request.method_id, request.payload)
334 .await
335}
336
337async fn dispatch_activation(
338 state: Arc<ServerState>,
339 sender: RpcSender,
340 connection_id: u64,
341 request: Request,
342) -> Result<Value, RuntimeError> {
343 let ctx = RpcCallContext {
344 connection_id,
345 instance_id: request.instance_id,
346 sender,
347 };
348 match request.method_id.get() {
349 RESOLVE_INSTANCE_IDS_METHOD_ID => {
350 let request = decode_resolve_instance_ids_request(&request.payload)?;
351 let ids = state.resolve_instance_ids(&request.instance_names).await;
352 Ok(encode_resolve_instance_ids_response(
353 &ResolveInstanceIdsResponse { instance_ids: ids },
354 ))
355 }
356 CREATE_INSTANCE_METHOD_ID => {
357 let request = decode_create_instance_request(&request.payload)?;
358 let factory = state.get_factory(request.service_guid).ok_or_else(|| {
359 RuntimeError::runtime(
360 RuntimeErrorCode::ServiceGuidNotFound,
361 "service factory was not found",
362 )
363 })?;
364 let handler = factory
365 .factory
366 .create(ctx, request.create_payload, request.options)
367 .await?;
368 let instance_id = state
369 .insert_client_instance(
370 request.service_guid,
371 connection_id,
372 factory.methods.clone(),
373 handler,
374 )
375 .await;
376 Ok(encode_create_instance_response(&CreateInstanceResponse {
377 instance_id,
378 }))
379 }
380 RELEASE_INSTANCE_METHOD_ID => {
381 let request = decode_release_instance_request(&request.payload)?;
382 state
383 .release_instance(connection_id, request.instance_id)
384 .await?;
385 Ok(encode_release_instance_response(&ReleaseInstanceResponse))
386 }
387 LIST_INSTANCES_METHOD_ID => {
388 let request = decode_list_instances_request(&request.payload)?;
389 let instances = state.list_instances(request.service_guid).await;
390 Ok(encode_list_instances_response(&ListInstancesResponse {
391 instances,
392 }))
393 }
394 _ => Err(RuntimeError::runtime(
395 RuntimeErrorCode::MethodNotFound,
396 "activation method was not found",
397 )),
398 }
399}
400
401fn runtime_error_response(request_id: RequestId, error: RuntimeError) -> Envelope {
402 Envelope::ResponseError(ResponseError {
403 request_id,
404 error_code: error.code.as_i32(),
405 error_kind: error.kind.as_u8(),
406 error_message: Some(error.message),
407 error_details: Value::Nil,
408 })
409}
410
411fn server_capabilities() -> CapabilityFlags {
412 CapabilityFlags::SERVER_TO_CLIENT_NOTIFICATION
413 | CapabilityFlags::NAMED_INSTANCE_RESOLUTION
414 | CapabilityFlags::SERVICE_ACTIVATION
415 | CapabilityFlags::GOODBYE
416}
417
418struct ServerState {
419 next_connection_id: AtomicU64,
420 next_instance_id: AtomicU64,
421 instances: RwLock<HashMap<u64, InstanceEntry>>,
422 names: RwLock<HashMap<String, u64>>,
423 factories: HashMap<uuid::Uuid, FactoryEntry>,
424}
425
426impl ServerState {
427 fn new() -> Self {
428 Self {
429 next_connection_id: AtomicU64::new(1),
430 next_instance_id: AtomicU64::new(2),
431 instances: RwLock::new(HashMap::new()),
432 names: RwLock::new(HashMap::new()),
433 factories: HashMap::new(),
434 }
435 }
436
437 fn insert_activation_instance(&mut self) {
438 self.instances.get_mut().insert(
439 ACTIVATION_INSTANCE_ID_VALUE,
440 InstanceEntry {
441 instance_id: activation_instance_id(),
442 service_guid: activation_service_guid(),
443 instance_name: Some("rpc.runtime.Activation".to_string()),
444 activation_mode: ActivationMode::Singleton,
445 releasable: false,
446 owner_connection_id: None,
447 methods: vec![
448 RESOLVE_INSTANCE_IDS_METHOD_ID,
449 CREATE_INSTANCE_METHOD_ID,
450 RELEASE_INSTANCE_METHOD_ID,
451 LIST_INSTANCES_METHOD_ID,
452 ],
453 handler: Arc::new(ActivationMarker),
454 },
455 );
456 }
457
458 fn insert_instance(&mut self, instance: NewInstance) -> InstanceId {
459 let id = self.next_instance_id.fetch_add(1, Ordering::Relaxed);
460 let instance_id = InstanceId::new(id).expect("generated instance id is non-zero");
461 if let Some(name) = &instance.name {
462 self.names.get_mut().insert(name.clone(), id);
463 }
464 self.instances.get_mut().insert(
465 id,
466 InstanceEntry {
467 instance_id,
468 service_guid: instance.service_guid,
469 instance_name: instance.name,
470 activation_mode: instance.activation_mode,
471 releasable: instance.releasable,
472 owner_connection_id: instance.owner_connection_id,
473 methods: instance.methods,
474 handler: instance.handler,
475 },
476 );
477 instance_id
478 }
479
480 async fn insert_client_instance(
481 &self,
482 service_guid: ServiceGuid,
483 connection_id: u64,
484 methods: Vec<u32>,
485 handler: Arc<dyn RpcServiceHandler>,
486 ) -> InstanceId {
487 let id = self.next_instance_id.fetch_add(1, Ordering::Relaxed);
488 let instance_id = InstanceId::new(id).expect("generated instance id is non-zero");
489 self.instances.write().await.insert(
490 id,
491 InstanceEntry {
492 instance_id,
493 service_guid,
494 instance_name: None,
495 activation_mode: ActivationMode::Instantiable,
496 releasable: true,
497 owner_connection_id: Some(connection_id),
498 methods,
499 handler,
500 },
501 );
502 instance_id
503 }
504
505 async fn get_instance(&self, instance_id: InstanceId) -> Result<InstanceEntry, RuntimeError> {
506 self.instances
507 .read()
508 .await
509 .get(&instance_id.get())
510 .cloned()
511 .ok_or_else(|| {
512 RuntimeError::runtime(RuntimeErrorCode::InstanceNotFound, "instance was not found")
513 })
514 }
515
516 fn get_factory(&self, service_guid: ServiceGuid) -> Option<FactoryEntry> {
517 self.factories.get(&service_guid.get()).cloned()
518 }
519
520 async fn resolve_instance_ids(&self, names: &[String]) -> Vec<u64> {
521 let index = self.names.read().await;
522 names
523 .iter()
524 .map(|name| index.get(name).copied().unwrap_or(0))
525 .collect()
526 }
527
528 async fn release_instance(
529 &self,
530 connection_id: u64,
531 instance_id: InstanceId,
532 ) -> Result<(), RuntimeError> {
533 let mut instances = self.instances.write().await;
534 let entry = instances.get(&instance_id.get()).ok_or_else(|| {
535 RuntimeError::runtime(RuntimeErrorCode::InstanceNotFound, "instance was not found")
536 })?;
537 if !entry.releasable {
538 return Err(RuntimeError::runtime(
539 RuntimeErrorCode::InstanceReleaseNotAllowed,
540 "instance is not releasable",
541 ));
542 }
543 if entry.owner_connection_id != Some(connection_id) {
544 return Err(RuntimeError::runtime(
545 RuntimeErrorCode::AccessDenied,
546 "instance is owned by another connection",
547 ));
548 }
549 instances.remove(&instance_id.get());
550 Ok(())
551 }
552
553 async fn cleanup_connection(&self, connection_id: u64) {
554 self.instances
555 .write()
556 .await
557 .retain(|_, entry| entry.owner_connection_id != Some(connection_id));
558 }
559
560 async fn list_instances(&self, service_guid: Option<ServiceGuid>) -> Vec<InstanceDescriptor> {
561 let mut values = self
562 .instances
563 .read()
564 .await
565 .values()
566 .filter(|entry| service_guid.is_none_or(|guid| guid == entry.service_guid))
567 .map(InstanceEntry::descriptor)
568 .collect::<Vec<_>>();
569 values.sort_by_key(|entry| entry.instance_id.get());
570 values
571 }
572}
573
574struct NewInstance {
575 service_guid: ServiceGuid,
576 name: Option<String>,
577 activation_mode: ActivationMode,
578 releasable: bool,
579 owner_connection_id: Option<u64>,
580 methods: Vec<u32>,
581 handler: Arc<dyn RpcServiceHandler>,
582}
583
584#[derive(Clone)]
585struct InstanceEntry {
586 instance_id: InstanceId,
587 service_guid: ServiceGuid,
588 instance_name: Option<String>,
589 activation_mode: ActivationMode,
590 releasable: bool,
591 owner_connection_id: Option<u64>,
592 methods: Vec<u32>,
593 handler: Arc<dyn RpcServiceHandler>,
594}
595
596impl InstanceEntry {
597 fn descriptor(&self) -> InstanceDescriptor {
598 InstanceDescriptor {
599 instance_id: self.instance_id,
600 instance_name: self.instance_name.clone(),
601 service_guid: self.service_guid,
602 activation_mode: self.activation_mode,
603 releasable: self.releasable,
604 }
605 }
606}
607
608#[derive(Clone)]
609struct FactoryEntry {
610 methods: Vec<u32>,
611 factory: Arc<dyn RpcServiceFactory>,
612}
613
614struct ActivationMarker;
615
616impl RpcServiceHandler for ActivationMarker {
617 fn call(&self, _: RpcCallContext, _: MethodId, _: Value) -> HandlerFuture {
618 Box::pin(async {
619 Err(RuntimeError::runtime(
620 RuntimeErrorCode::InternalRuntimeError,
621 "activation marker should not be dispatched directly",
622 ))
623 })
624 }
625}