std_modrpc/role_impls/request_client.rs

1use modrpc::RoleSetup;
2
3use crate::{
4    proto::{
5        RequestClientConfig, RequestGen, RequestInitState, RequestLazy, Response, ResponseLazy,
6    },
7    request_tracker::{RequestTracker, get_request_tracker},
8};
9
10pub use sealed::ResponseWaiter;
11
12mod sealed {
13    use crate::{proto::Response, request_tracker::PendingRequestSubscription};
14
15    pub struct ResponseWaiter<Resp> {
16        pending_request: PendingRequestSubscription,
17        _phantom: core::marker::PhantomData<Resp>,
18    }
19
20    impl<Resp: for<'d> mproto::Decode<'d>> ResponseWaiter<Resp> {
21        pub async fn wait(self) -> mproto::DecodeResult<Resp> {
22            let packet_header_len = <modrpc::TransmitPacket as mproto::BaseLen>::BASE_LEN;
23            let response_packet = self.pending_request.wait().await;
24
25            let response = mproto::decode_value(&response_packet.as_ref()[packet_header_len..])?;
26            let response: Response<Resp> = response;
27
28            Ok(response.payload)
29        }
30    }
31
32    pub fn wait_response_then_decode<Resp>(
33        pending_request: PendingRequestSubscription,
34    ) -> ResponseWaiter<Resp>
35    where
36        Resp: for<'d> mproto::Decode<'d>,
37    {
38        ResponseWaiter {
39            pending_request,
40            _phantom: core::marker::PhantomData,
41        }
42    }
43
44    // One day
45    /*pub type ResponseWaiter<Resp: for<'d> mproto::Decode<'d>>
46        = impl Future<Output = mproto::DecodeResult<Resp>>;
47
48    #[define_opaque(ResponseWaiter)]
49    pub fn wait_response_then_decode<Resp>(pending_request: PendingRequestSubscription)
50        -> ResponseWaiter<Resp>
51        where Resp: for<'d> mproto::Decode<'d>
52    {
53        async move {
54            let packet_header_len = <modrpc::TransmitPacket as mproto::BaseLen>::BASE_LEN;
55            let response_packet = pending_request.wait().await;
56
57            let response = mproto::decode_value(&response_packet.as_ref()[packet_header_len..])?;
58            let response: Response<Resp> = response;
59
60            Ok(response.payload)
61        }
62    }*/
63}
64
65pub struct RequestClient<Req, Resp> {
66    name: &'static str,
67    rt: modrpc::RuntimeHandle,
68    worker_id: u16,
69    hooks: crate::RequestClientHooks<Req, Resp>,
70    tracker: RequestTracker,
71    spawner: modrpc::RoleSpawner,
72}
73
74impl<Req: mproto::Owned, Resp: mproto::Owned> RequestClient<Req, Resp> {
75    pub async fn call<LikeReq>(&self, payload: LikeReq) -> Resp
76    where
77        LikeReq: mproto::Compatible<Req>,
78    {
79        let pending_request = self.tracker.client_start_request();
80        self.hooks
81            .request
82            .send(RequestGen::<LikeReq> {
83                worker: self.worker_id,
84                request_id: pending_request.request_id(),
85                payload,
86            })
87            .await;
88
89        let response_buf = pending_request.await;
90        let header_len = <modrpc::TransmitPacket as mproto::BaseLen>::BASE_LEN;
91        let response: Response<Resp> =
92            mproto::decode_value(&response_buf.as_ref()[header_len..]).unwrap();
93
94        response.payload
95    }
96
97    pub fn subscribe(
98        &self,
99        request_subscription: impl AsyncFnMut(modrpc::EndpointAddr, Req::Lazy<'_>, ResponseWaiter<Resp>)
100        + Clone
101        + 'static,
102    ) {
103        // Lazily create the inter-worker topic subscription
104        let local_worker_context = self
105            .rt
106            .local_worker_context()
107            .expect("modrpc::RequestClient::subscribe local worker context");
108        modrpc::add_topic_subscription(
109            local_worker_context,
110            "todo",
111            self.hooks.request.plane_id(),
112            self.hooks.request.topic(),
113        );
114
115        let spawner = self.spawner.clone();
116        self.tracker.subscribe(
117            self.hooks.request.plane_id(),
118            self.hooks.request.topic(),
119            Box::new(move |request_packet: modrpc::Packet, pending_request| {
120                let mut request_subscription = request_subscription.clone();
121                spawner.spawn(async move {
122                    let header_len = <modrpc::TransmitPacket as mproto::BaseLen>::BASE_LEN;
123
124                    let Ok(header) = mproto::decode_value(&request_packet.as_ref()[..header_len])
125                    else {
126                        return;
127                    };
128                    let header: modrpc::TransmitPacket = header;
129
130                    let Ok(request) = mproto::decode_value(&request_packet.as_ref()[header_len..])
131                    else {
132                        return;
133                    };
134                    let request: RequestLazy<Req> = request;
135
136                    let Ok(request_payload) = request.payload() else {
137                        return;
138                    };
139
140                    request_subscription(
141                        header.source,
142                        request_payload,
143                        sealed::wait_response_then_decode(pending_request),
144                    )
145                    .await;
146                });
147            }),
148        );
149    }
150}
151
152pub struct RequestClientBuilder<Req, Resp> {
153    pub name: &'static str,
154    pub hooks: crate::RequestClientHooks<Req, Resp>,
155    pub stubs: crate::RequestClientStubs<Req, Resp>,
156    pub init: RequestInitState,
157}
158
159impl<Req: mproto::Owned, Resp: mproto::Owned> RequestClientBuilder<Req, Resp> {
160    pub fn new(
161        name: &'static str,
162        hooks: crate::RequestClientHooks<Req, Resp>,
163        stubs: crate::RequestClientStubs<Req, Resp>,
164        _config: &RequestClientConfig,
165        init: RequestInitState,
166    ) -> Self {
167        Self {
168            name,
169            hooks,
170            stubs,
171            init,
172        }
173    }
174
175    pub fn create_handle(&self, setup: &RoleSetup) -> RequestClient<Req, Resp> {
176        let worker_id = setup.worker_id();
177        let tracker = get_request_tracker(setup);
178
179        RequestClient {
180            name: self.name,
181            rt: setup.worker_context().rt().clone(),
182            worker_id,
183            hooks: self.hooks.clone(),
184            tracker,
185            spawner: setup.role_spawner().clone(),
186        }
187    }
188
189    pub fn build(self, setup: &RoleSetup) {
190        let local_addr = setup.endpoint_addr().endpoint;
191        let local_worker_id = setup.worker_id();
192        let plane_id = setup.plane_id();
193
194        // The .local() request/response handlers below are actually subscriptions, but we defer
195        // creating the inter-worker topic subscription until the user actually creates a
196        // subscription on this request because the inter-worker subscription is relatively
197        // expensive.
198
199        let tracker = get_request_tracker(setup);
200        let request_topic = self.hooks.request.topic();
201        self.stubs
202            .request
203            .inline_untyped(setup, move |source, packet| {
204                tracker.handle_request::<Req>(
205                    plane_id,
206                    request_topic,
207                    source.endpoint,
208                    packet.clone(),
209                );
210            })
211            .local();
212
213        let tracker = get_request_tracker(setup);
214        self.stubs
215            .response
216            .clone()
217            .inline_untyped(setup, move |_source, packet| {
218                tracker.handle_response::<Resp>(
219                    plane_id,
220                    local_addr,
221                    local_worker_id,
222                    packet.clone(),
223                );
224            })
225            .local();
226
227        self.stubs
228            .response
229            .clone()
230            .route_to_worker(setup, move |_source, packet| {
231                let packet_header_len = <modrpc::TransmitPacket as mproto::BaseLen>::BASE_LEN;
232                let Ok(response) = mproto::decode_value::<ResponseLazy<Resp>>(
233                    &packet.as_ref()[packet_header_len..],
234                ) else {
235                    return None;
236                };
237                let Ok(requester) = response.requester() else {
238                    return None;
239                };
240                let Ok(requester_worker) = response.requester_worker() else {
241                    return None;
242                };
243
244                probius::trace_label("route_to_worker");
245                probius::trace_branch(|| {
246                    if requester == local_addr && requester_worker != local_worker_id {
247                        probius::trace_metric("redirect", 1);
248                        // This response is for a locally generated request at a different worker -
249                        // route to the correct worker.
250                        Some(modrpc::WorkerId(requester_worker))
251                    } else {
252                        probius::trace_metric("no-redirect", 1);
253                        None
254                    }
255                })
256            });
257    }
258}
259
260impl<Req, Resp> Clone for RequestClient<Req, Resp> {
261    fn clone(&self) -> Self {
262        Self {
263            name: self.name,
264            rt: self.rt.clone(),
265            worker_id: self.worker_id,
266            hooks: self.hooks.clone(),
267            tracker: self.tracker.clone(),
268            spawner: self.spawner.clone(),
269        }
270    }
271}