std_modrpc/role_impls/
request_server.rs

1use crate::proto::{
2    RequestLazy,
3    RequestInitState,
4    RequestServerConfig,
5    Response,
6    ResponseGen,
7};
8use modrpc::RoleSetup;
9
10use crate::request_tracker::{RequestTracker, get_request_tracker};
11
/// Handle for the server side of a request role.
///
/// Created by `RequestServerBuilder::create_handle`, which takes the worker id
/// and the request tracker from the `RoleSetup` it is given. The handle is
/// cloneable through a manual `Clone` impl that places no bounds on `Req` or
/// `Resp`.
pub struct RequestServer<Req, Resp> {
    /// Static name copied from the builder that created this handle.
    name: &'static str,
    /// Value returned by `RoleSetup::worker_id()` when the handle was created.
    worker_id: u16,
    /// Clone of the builder's hooks.
    hooks: crate::RequestServerHooks<Req, Resp>,
    /// Tracker returned by `get_request_tracker` when the handle was created.
    tracker: RequestTracker,
    /// Set to `true` by `create_handle`.
    /// NOTE(review): nothing in the visible part of this file reads it - confirm
    /// whether it is consumed elsewhere.
    tracing_enabled: bool,
}
19
/// Holds everything needed to wire up the server side of a request role.
///
/// The `build*` methods consume the builder and attach a handler (or a proxy)
/// to the request stub; `create_handle` borrows it to produce a
/// `RequestServer` handle.
pub struct RequestServerBuilder<Req, Resp> {
    /// Static name copied into every handle created from this builder.
    pub name: &'static str,
    /// Hooks for this role; `hooks.response` is the `modrpc::EventTx` that
    /// responses are sent on.
    pub hooks: crate::RequestServerHooks<Req, Resp>,
    /// Stubs for this role; `stubs.request` is what request handlers are
    /// attached to.
    pub stubs: crate::RequestServerStubs<Req, Resp>,
    /// Initial state handed to `new`.
    /// NOTE(review): stored, but not read by any builder method visible here.
    pub init: RequestInitState,
}
26
27impl<
28    Req: mproto::Owned,
29    Resp: mproto::Owned,
30> RequestServerBuilder<Req, Resp> {
31    pub fn new(
32        name: &'static str,
33        hooks: crate::RequestServerHooks<Req, Resp>,
34        stubs: crate::RequestServerStubs<Req, Resp>,
35        _config: &RequestServerConfig,
36        init: RequestInitState,
37    ) -> Self {
38        Self { name, hooks, stubs, init }
39    }
40
41    pub fn create_handle(&self, setup: &RoleSetup) -> RequestServer<Req, Resp> {
42        let worker_id = setup.worker_id();
43        let tracker = get_request_tracker(setup);
44
45        RequestServer {
46            name: self.name,
47            worker_id,
48            hooks: self.hooks.clone(),
49            tracker,
50            tracing_enabled: true,
51        }
52    }
53
54    pub fn build_shared(self, /* todo */) {
55        // TODO insert self.hooks.response into the shared state's map of plane_id -> response_tx
56        // we'll need to change modrpc::PacketProcessor to lookup handlers based on
57        // (infra_id, topic) instead of (plane_id, topic) (and add infra_id to TransmitPacket).
58        // In PacketProcessor, if infra_id is non-zero, we'll look up the handler based on
59        // infra_id. Otherwise we'll look up based on plane_id.
60    }
61
62    pub fn build_replier(
63        self,
64        setup: &RoleSetup,
65        mut handler: impl AsyncFnMut(RequestContext<Resp>, Req::Lazy<'_>) + 'static,
66    ) {
67        let mut response_tx: modrpc::EventTx<Response<Resp>> = self.hooks.response;
68        self.stubs.request
69            .queued(setup, async move |source: modrpc::EndpointAddr, request: RequestLazy<Req>| {
70                let Ok(request_id) = request.request_id() else { return; };
71                let Ok(requester_worker) = request.worker() else { return; };
72                let Ok(payload) = request.payload() else { return; };
73
74                handler(
75                    RequestContext {
76                        source,
77                        reply: ResponseSender {
78                            response_event_sender: &mut response_tx,
79                            request_id,
80                            source: source,
81                            requester_worker,
82                        },
83                    },
84                    payload,
85                )
86                .await;
87            })
88            .load_balance();
89    }
90
91    pub fn build(
92        self,
93        setup: &RoleSetup,
94        mut handler: impl AsyncFnMut(modrpc::EndpointAddr, Req::Lazy<'_>) -> Resp + 'static,
95    ) {
96        let response_tx: modrpc::EventTx<Response<Resp>> = self.hooks.response;
97        self.stubs.request.queued(
98            setup,
99            async move |source: modrpc::EndpointAddr, request: RequestLazy<Req>| {
100                let Ok(request_id) = request.request_id() else { return; };
101                let Ok(requester_worker) = request.worker() else { return; };
102                let Ok(request_payload) = request.payload() else { return; };
103
104                let response = handler(source, request_payload).await;
105                response_tx.send(Response {
106                    request_id,
107                    requester: source.endpoint,
108                    requester_worker,
109                    payload: response,
110                })
111                .await;
112            },
113        )
114        .load_balance();
115    }
116
117    pub fn build_proxied(self, setup: &RoleSetup) {
118        self.stubs.request.proxy_load_balance(setup);
119    }
120}
121
122impl<Req, Resp> Clone for RequestServer<Req, Resp> {
123    fn clone(&self) -> Self {
124        Self {
125            name: self.name,
126            worker_id: self.worker_id,
127            hooks: self.hooks.clone(),
128            tracker: self.tracker.clone(),
129            tracing_enabled: self.tracing_enabled,
130        }
131    }
132}
133
/// Per-request context passed to handlers registered with
/// `RequestServerBuilder::build_replier`.
pub struct RequestContext<'a, R> {
    /// Address that was passed as `source` to the request callback.
    pub source: modrpc::EndpointAddr,
    /// Sender bound to this request; use it to send the response.
    pub reply: ResponseSender<'a, R>,
}
138
/// Sends a response for one particular request.
///
/// Pairs a mutable borrow of the response event sender with the identifiers
/// that are copied into every response built by the `send*` methods.
pub struct ResponseSender<'a, T> {
    /// Event sender the response is sent on.
    pub response_event_sender: &'a mut modrpc::EventTx<Response<T>>,
    /// Request id, written to the response's `request_id` field.
    pub request_id: u32,
    /// Source address of the request; its `endpoint` is written to the
    /// response's `requester` field.
    pub source: modrpc::EndpointAddr,
    /// Worker value of the requester, written to the response's
    /// `requester_worker` field.
    pub requester_worker: u16,
}
145
146impl<T: mproto::Owned> ResponseSender<'_, T> {
147    #[inline]
148    pub async fn send(&mut self, response: impl mproto::Encode + mproto::Compatible<T>) {
149        self.response_event_sender.send(ResponseGen {
150            request_id: self.request_id,
151            requester: self.source.endpoint,
152            requester_worker: self.requester_worker,
153            payload: response,
154        }).await;
155    }
156}
157
158// Helpers to play nice with type inference for the very common situation where the response type
159// is a `Result`.
160impl<O: mproto::Owned, E: mproto::Owned> ResponseSender<'_, Result<O, E>> {
161    #[inline]
162    pub async fn send_ok(&mut self, response: impl mproto::Encode + mproto::Compatible<O>) {
163        self.response_event_sender.send(ResponseGen {
164            request_id: self.request_id,
165            requester: self.source.endpoint,
166            requester_worker: self.requester_worker,
167            payload: Ok::<_, E>(response),
168        }).await;
169    }
170
171    #[inline]
172    pub async fn send_err(&mut self, response: impl mproto::Encode + mproto::Compatible<E>) {
173        self.response_event_sender.send(ResponseGen {
174            request_id: self.request_id,
175            requester: self.source.endpoint,
176            requester_worker: self.requester_worker,
177            payload: Err::<O, _>(response),
178        }).await;
179    }
180}