// std_modrpc/role_impls/request_client.rs

1use modrpc::RoleSetup;
2
3use crate::{
4    proto::{
5        RequestClientConfig,
6        RequestGen,
7        RequestLazy,
8        RequestInitState,
9        Response,
10        ResponseLazy,
11    },
12    request_tracker::{RequestTracker, get_request_tracker},
13};
14
15pub use sealed::ResponseWaiter;
16
17mod sealed {
18    use crate::{
19        proto::Response,
20        request_tracker::PendingRequestSubscription,
21    };
22
23    pub struct ResponseWaiter<Resp> {
24        pending_request: PendingRequestSubscription,
25        _phantom: core::marker::PhantomData<Resp>,
26    }
27
28    impl<Resp: for<'d> mproto::Decode<'d>> ResponseWaiter<Resp> {
29        pub async fn wait(self) -> mproto::DecodeResult<Resp> {
30            let packet_header_len = <modrpc::TransmitPacket as mproto::BaseLen>::BASE_LEN;
31            let response_packet = self.pending_request.wait().await;
32
33            let response = mproto::decode_value(&response_packet.as_ref()[packet_header_len..])?;
34            let response: Response<Resp> = response;
35
36            Ok(response.payload)
37        }
38    }
39
40    pub fn wait_response_then_decode<Resp>(pending_request: PendingRequestSubscription)
41        -> ResponseWaiter<Resp>
42        where Resp: for<'d> mproto::Decode<'d>
43    {
44        ResponseWaiter {
45            pending_request,
46            _phantom: core::marker::PhantomData,
47        }
48    }
49
50    // One day
51    /*pub type ResponseWaiter<Resp: for<'d> mproto::Decode<'d>>
52        = impl Future<Output = mproto::DecodeResult<Resp>>;
53
54    #[define_opaque(ResponseWaiter)]
55    pub fn wait_response_then_decode<Resp>(pending_request: PendingRequestSubscription)
56        -> ResponseWaiter<Resp>
57        where Resp: for<'d> mproto::Decode<'d>
58    {
59        async move {
60            let packet_header_len = <modrpc::TransmitPacket as mproto::BaseLen>::BASE_LEN;
61            let response_packet = pending_request.wait().await;
62
63            let response = mproto::decode_value(&response_packet.as_ref()[packet_header_len..])?;
64            let response: Response<Resp> = response;
65
66            Ok(response.payload)
67        }
68    }*/
69}
70
71pub struct RequestClient<Req, Resp> {
72    name: &'static str,
73    rt: modrpc::RuntimeHandle,
74    worker_id: u16,
75    hooks: crate::RequestClientHooks<Req, Resp>,
76    tracker: RequestTracker,
77    spawner: modrpc::RoleSpawner,
78}
79
80impl<
81    Req: mproto::Owned,
82    Resp: mproto::Owned,
83> RequestClient<Req, Resp> {
84    pub async fn call<LikeReq>(&self, payload: LikeReq) -> Resp
85        where LikeReq: mproto::Compatible<Req>
86    {
87        let pending_request = self.tracker.client_start_request();
88        self.hooks.request.send(RequestGen::<LikeReq> {
89            worker: self.worker_id,
90            request_id: pending_request.request_id(),
91            payload,
92        })
93        .await;
94
95        let response_buf = pending_request.await;
96        let header_len = <modrpc::TransmitPacket as mproto::BaseLen>::BASE_LEN;
97        let response: Response<Resp>
98            = mproto::decode_value(&response_buf.as_ref()[header_len..]).unwrap();
99
100        response.payload
101    }
102
103    pub fn subscribe(
104        &self,
105        request_subscription:
106            impl AsyncFnMut(
107                modrpc::EndpointAddr,
108                Req::Lazy<'_>,
109                ResponseWaiter<Resp>,
110            ) + Clone + 'static,
111    )
112    where
113    {
114        // Lazily create the inter-worker topic subscription
115        let local_worker_context = self.rt.local_worker_context()
116            .expect("modrpc::RequestClient::subscribe local worker context");
117        modrpc::add_topic_subscription(
118            local_worker_context,
119            "todo",
120            self.hooks.request.plane_id(),
121            self.hooks.request.topic(),
122        );
123
124        let spawner = self.spawner.clone();
125        self.tracker.subscribe(
126            self.hooks.request.plane_id(),
127            self.hooks.request.topic(),
128            Box::new(move |request_packet: modrpc::Packet, pending_request| {
129                let mut request_subscription = request_subscription.clone();
130                spawner.spawn(async move {
131                    let header_len = <modrpc::TransmitPacket as mproto::BaseLen>::BASE_LEN;
132
133                    let Ok(header) =
134                        mproto::decode_value(&request_packet.as_ref()[..header_len])
135                    else { return; };
136                    let header: modrpc::TransmitPacket = header;
137
138                    let Ok(request) =
139                        mproto::decode_value(&request_packet.as_ref()[header_len..])
140                    else { return; };
141                    let request: RequestLazy<Req> = request;
142
143                    let Ok(request_payload) = request.payload() else { return; };
144
145                    request_subscription(
146                        header.source,
147                        request_payload,
148                        sealed::wait_response_then_decode(pending_request),
149                    )
150                    .await;
151                });
152            }),
153        );
154    }
155}
156
157pub struct RequestClientBuilder<Req, Resp> {
158    pub name: &'static str,
159    pub hooks: crate::RequestClientHooks<Req, Resp>,
160    pub stubs: crate::RequestClientStubs<Req, Resp>,
161    pub init: RequestInitState,
162}
163
164impl<
165    Req: mproto::Owned,
166    Resp: mproto::Owned,
167> RequestClientBuilder<Req, Resp> {
168    pub fn new(
169        name: &'static str,
170        hooks: crate::RequestClientHooks<Req, Resp>,
171        stubs: crate::RequestClientStubs<Req, Resp>,
172        _config: &RequestClientConfig,
173        init: RequestInitState,
174    ) -> Self {
175        Self { name, hooks, stubs, init }
176    }
177
178    pub fn create_handle(&self, setup: &RoleSetup) -> RequestClient<Req, Resp> {
179        let worker_id = setup.worker_id();
180        let tracker = get_request_tracker(setup);
181
182        RequestClient {
183            name: self.name,
184            rt: setup.worker_context().rt().clone(),
185            worker_id,
186            hooks: self.hooks.clone(),
187            tracker,
188            spawner: setup.role_spawner().clone(),
189        }
190    }
191
192    pub fn build(
193        self,
194        setup: &RoleSetup,
195    ) {
196        let local_addr = setup.endpoint_addr().endpoint;
197        let local_worker_id = setup.worker_id();
198        let plane_id = setup.plane_id();
199
200        // The .local() request/response handlers below are actually subscriptions, but we defer
201        // creating the inter-worker topic subscription until the user actually creates a
202        // subscription on this request because the inter-worker subscription is relatively
203        // expensive.
204
205        let tracker = get_request_tracker(setup);
206        let request_topic = self.hooks.request.topic();
207        self.stubs.request
208            .inline_untyped(setup, move |source, packet| {
209                tracker.handle_request::<Req>(plane_id, request_topic, source.endpoint, packet.clone());
210            })
211            .local();
212
213        let tracker = get_request_tracker(setup);
214        self.stubs.response.clone()
215            .inline_untyped(setup, move |_source, packet| {
216                tracker.handle_response::<Resp>(plane_id, local_addr, local_worker_id, packet.clone());
217            })
218            .local();
219
220        self.stubs.response.clone()
221            .route_to_worker(setup, move |_source, packet| {
222                let packet_header_len = <modrpc::TransmitPacket as mproto::BaseLen>::BASE_LEN;
223                let Ok(response) =
224                    mproto::decode_value::<ResponseLazy<Resp>>(&packet.as_ref()[packet_header_len..])
225                else {
226                    return None;
227                };
228                let Ok(requester) = response.requester() else {
229                    return None;
230                };
231                let Ok(requester_worker) = response.requester_worker() else {
232                    return None;
233                };
234
235                probius::trace_label("route_to_worker");
236                probius::trace_branch(|| {
237                    if requester == local_addr && requester_worker != local_worker_id {
238                        probius::trace_metric("redirect", 1);
239                        // This response is for a locally generated request at a different worker -
240                        // route to the correct worker.
241                        Some(modrpc::WorkerId(requester_worker))
242                    } else {
243                        probius::trace_metric("no-redirect", 1);
244                        None
245                    }
246                })
247            });
248    }
249}
250
251impl<Req, Resp> Clone for RequestClient<Req, Resp> {
252    fn clone(&self) -> Self {
253        Self {
254            name: self.name,
255            rt: self.rt.clone(),
256            worker_id: self.worker_id,
257            hooks: self.hooks.clone(),
258            tracker: self.tracker.clone(),
259            spawner: self.spawner.clone(),
260        }
261    }
262}
263