1extern crate alloc;
17
18use alloc::string::String;
19use alloc::sync::Arc;
20use alloc::vec::Vec;
21use core::marker::PhantomData;
22
23use zerodds_dcps::dds_type::{DdsType, RawBytes};
24use zerodds_dcps::participant::DomainParticipant;
25use zerodds_dcps::publisher::DataWriter;
26use zerodds_dcps::qos::{PublisherQos, SubscriberQos, TopicQos};
27use zerodds_dcps::subscriber::DataReader;
28
29use crate::common_types::{RemoteExceptionCode, ReplyHeader};
30use crate::error::{RpcError, RpcResult};
31use crate::qos_profile::RpcQos;
32use crate::requester::{InstanceClaim, InstanceRole, try_claim_instance};
33use crate::topic_naming::ServiceTopicNames;
34use crate::wire_codec::{decode_request_frame, encode_reply_frame};
35
36pub trait ReplierHandler<TIn, TOut>: Send + Sync {
45 fn handle(&self, request: TIn) -> Result<TOut, RemoteExceptionCode>;
49}
50
/// Adapter that lets a plain closure act as a [`ReplierHandler`].
///
/// The struct itself carries no trait bounds (bounds live on the `impl`
/// blocks that need them, per the Rust API guidelines); the handler impl
/// requires `F: Fn(TIn) -> Result<TOut, RemoteExceptionCode> + Send + Sync`.
pub struct FnHandler<F, TIn, TOut> {
    /// The wrapped closure, invoked once per request.
    f: F,
    // `fn() -> (TIn, TOut)` ties the type parameters to the struct without
    // storing values of those types, so `FnHandler`'s auto traits depend on
    // `F` alone.
    _phantom: PhantomData<fn() -> (TIn, TOut)>,
}
60
61impl<F, TIn, TOut> FnHandler<F, TIn, TOut>
62where
63 F: Fn(TIn) -> Result<TOut, RemoteExceptionCode> + Send + Sync,
64{
65 pub fn new(f: F) -> Self {
67 Self {
68 f,
69 _phantom: PhantomData,
70 }
71 }
72}
73
74impl<F, TIn, TOut> ReplierHandler<TIn, TOut> for FnHandler<F, TIn, TOut>
75where
76 F: Fn(TIn) -> Result<TOut, RemoteExceptionCode> + Send + Sync,
77{
78 fn handle(&self, request: TIn) -> Result<TOut, RemoteExceptionCode> {
79 (self.f)(request)
80 }
81}
82
83pub struct Replier<TIn: DdsType, TOut: DdsType> {
89 service_name: String,
90 instance_name: String,
91 request_reader: DataReader<RawBytes>,
92 reply_writer: DataWriter<RawBytes>,
93 handler: Arc<dyn ReplierHandler<TIn, TOut>>,
94 handled_count: std::sync::atomic::AtomicU64,
95 error_count: std::sync::atomic::AtomicU64,
96 _claim: InstanceClaim,
97 _phantom: PhantomData<fn() -> (TIn, TOut)>,
98}
99
100impl<TIn: DdsType, TOut: DdsType> core::fmt::Debug for Replier<TIn, TOut> {
101 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
102 f.debug_struct("Replier")
103 .field("service", &self.service_name)
104 .field("instance", &self.instance_name)
105 .finish_non_exhaustive()
106 }
107}
108
109impl<TIn: DdsType + Send + 'static, TOut: DdsType + Send + 'static> Replier<TIn, TOut> {
110 pub fn new(
115 participant: &DomainParticipant,
116 service_name: &str,
117 qos: &RpcQos,
118 handler: Arc<dyn ReplierHandler<TIn, TOut>>,
119 ) -> RpcResult<Self> {
120 Self::with_instance(participant, service_name, "", qos, handler)
121 }
122
123 pub fn with_instance(
128 participant: &DomainParticipant,
129 service_name: &str,
130 instance_name: &str,
131 qos: &RpcQos,
132 handler: Arc<dyn ReplierHandler<TIn, TOut>>,
133 ) -> RpcResult<Self> {
134 let topics = ServiceTopicNames::new(service_name)?;
135 let claim = try_claim_instance(
136 participant,
137 InstanceRole::Replier,
138 service_name,
139 instance_name,
140 )?;
141 let request_topic = participant
142 .create_topic::<RawBytes>(&topics.request, TopicQos::default())
143 .map_err(|e| RpcError::Dcps(alloc::format!("create_topic request: {e:?}")))?;
144 let reply_topic = participant
145 .create_topic::<RawBytes>(&topics.reply, TopicQos::default())
146 .map_err(|e| RpcError::Dcps(alloc::format!("create_topic reply: {e:?}")))?;
147 let publisher = participant.create_publisher(PublisherQos::default());
148 let subscriber = participant.create_subscriber(SubscriberQos::default());
149 let request_reader = subscriber
150 .create_datareader::<RawBytes>(&request_topic, qos.request_reader_qos())
151 .map_err(|e| RpcError::Dcps(alloc::format!("create_datareader: {e:?}")))?;
152 let reply_writer = publisher
153 .create_datawriter::<RawBytes>(&reply_topic, qos.reply_writer_qos())
154 .map_err(|e| RpcError::Dcps(alloc::format!("create_datawriter: {e:?}")))?;
155 Ok(Self {
156 service_name: service_name.into(),
157 instance_name: instance_name.into(),
158 request_reader,
159 reply_writer,
160 handler,
161 handled_count: std::sync::atomic::AtomicU64::new(0),
162 error_count: std::sync::atomic::AtomicU64::new(0),
163 _claim: claim,
164 _phantom: PhantomData,
165 })
166 }
167
168 #[must_use]
170 pub fn service_name(&self) -> &str {
171 &self.service_name
172 }
173
174 #[must_use]
176 pub fn instance_name(&self) -> &str {
177 &self.instance_name
178 }
179
180 #[must_use]
182 pub fn handled_count(&self) -> u64 {
183 self.handled_count
184 .load(std::sync::atomic::Ordering::Acquire)
185 }
186
187 #[must_use]
189 pub fn error_count(&self) -> u64 {
190 self.error_count.load(std::sync::atomic::Ordering::Acquire)
191 }
192
193 pub fn tick(&self) -> usize {
196 let samples = match self.request_reader.take() {
197 Ok(s) => s,
198 Err(_) => return 0,
199 };
200 let mut processed = 0;
201 for raw in samples {
202 let bytes = raw.data;
203 let (header, payload) = match decode_request_frame(&bytes) {
204 Ok(t) => t,
205 Err(_) => continue,
206 };
207 if !self.instance_name.is_empty()
212 && !header.instance_name.is_empty()
213 && header.instance_name != self.instance_name
214 {
215 continue;
216 }
217 let request_id = header.request_id;
218 let req = match TIn::decode(payload) {
220 Ok(v) => v,
221 Err(_) => {
222 self.send_error_reply(request_id, RemoteExceptionCode::InvalidArgument);
223 continue;
224 }
225 };
226 match self.handler.handle(req) {
227 Ok(reply) => {
228 let mut user_buf = Vec::new();
229 if reply.encode(&mut user_buf).is_err() {
230 self.send_error_reply(request_id, RemoteExceptionCode::OutOfResources);
231 continue;
232 }
233 let header = ReplyHeader::new(request_id, RemoteExceptionCode::Ok);
234 let frame = encode_reply_frame(&header, &user_buf);
235 if self.reply_writer.write(&RawBytes::new(frame)).is_err() {
236 continue;
240 }
241 self.handled_count
242 .fetch_add(1, std::sync::atomic::Ordering::AcqRel);
243 }
244 Err(code) => {
245 self.send_error_reply(request_id, code);
246 self.error_count
247 .fetch_add(1, std::sync::atomic::Ordering::AcqRel);
248 }
249 }
250 processed += 1;
251 }
252 processed
253 }
254
255 fn send_error_reply(
256 &self,
257 request_id: crate::common_types::SampleIdentity,
258 code: RemoteExceptionCode,
259 ) {
260 let header = ReplyHeader::new(request_id, code);
261 let frame = encode_reply_frame(&header, &[]);
262 let _ = self.reply_writer.write(&RawBytes::new(frame));
263 }
264
265 #[doc(hidden)]
267 #[must_use]
268 pub fn __drain_reply_writer(&self) -> Vec<Vec<u8>> {
269 self.reply_writer.__drain_pending()
270 }
271
272 #[doc(hidden)]
274 pub fn __push_request_raw(&self, bytes: Vec<u8>) -> RpcResult<()> {
275 self.request_reader
276 .__push_raw(bytes)
277 .map_err(|e| RpcError::Dcps(alloc::format!("push raw: {e:?}")))
278 }
279}
280
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)] // panicking on failure is the point of a test
mod tests {
    use super::*;
    use crate::common_types::{RequestHeader, SampleIdentity};
    use crate::wire_codec::{decode_reply_frame, encode_request_frame};
    use zerodds_dcps::dds_type::RawBytes;
    use zerodds_dcps::factory::DomainParticipantFactory;
    use zerodds_dcps::qos::DomainParticipantQos;

    // Every test uses its own domain id (201, 202, ...); keep new ids unique.
    fn participant(domain: i32) -> DomainParticipant {
        DomainParticipantFactory::instance()
            .create_participant_offline(domain, DomainParticipantQos::default())
    }

    /// Handler that returns the request unchanged.
    fn echo_handler() -> Arc<dyn ReplierHandler<RawBytes, RawBytes>> {
        Arc::new(FnHandler::new(|req: RawBytes| -> Result<RawBytes, _> {
            Ok(req)
        }))
    }

    /// Handler that always fails with `code`.
    fn err_handler(code: RemoteExceptionCode) -> Arc<dyn ReplierHandler<RawBytes, RawBytes>> {
        Arc::new(FnHandler::new(move |_req: RawBytes| Err(code)))
    }

    #[test]
    fn replier_new_creates_endpoints() {
        let p = participant(201);
        let q = RpcQos::default_basic();
        let r = Replier::<RawBytes, RawBytes>::new(&p, "Calc", &q, echo_handler()).unwrap();
        assert_eq!(r.service_name(), "Calc");
        assert_eq!(r.instance_name(), "");
        assert_eq!(r.handled_count(), 0);
    }

    #[test]
    fn replier_invalid_service_name_rejected() {
        let p = participant(202);
        let q = RpcQos::default_basic();
        let err = Replier::<RawBytes, RawBytes>::new(&p, "", &q, echo_handler()).unwrap_err();
        assert!(matches!(err, RpcError::InvalidServiceName(_)));
    }

    #[test]
    fn replier_duplicate_instance_name_rejected() {
        let p = participant(203);
        let q = RpcQos::default_basic();
        let _r1 =
            Replier::<RawBytes, RawBytes>::with_instance(&p, "Calc", "calc-A", &q, echo_handler())
                .unwrap();
        let err =
            Replier::<RawBytes, RawBytes>::with_instance(&p, "Calc", "calc-A", &q, echo_handler())
                .unwrap_err();
        assert!(matches!(err, RpcError::DuplicateInstanceName(_)));
    }

    #[test]
    fn tick_with_no_requests_is_noop() {
        let p = participant(204);
        let q = RpcQos::default_basic();
        let r = Replier::<RawBytes, RawBytes>::new(&p, "Calc", &q, echo_handler()).unwrap();
        assert_eq!(r.tick(), 0);
        assert_eq!(r.handled_count(), 0);
    }

    #[test]
    fn tick_processes_request_and_writes_reply() {
        let p = participant(205);
        let q = RpcQos::default_basic();
        let r = Replier::<RawBytes, RawBytes>::new(&p, "Calc", &q, echo_handler()).unwrap();
        let id = SampleIdentity::new([1u8; 16], 42);
        let req_header = RequestHeader::new(id, "");
        let req_frame = encode_request_frame(&req_header, &[7u8, 8, 9]);
        r.__push_request_raw(req_frame).unwrap();
        assert_eq!(r.tick(), 1);
        assert_eq!(r.handled_count(), 1);
        assert_eq!(r.error_count(), 0);
        let frames = r.__drain_reply_writer();
        assert_eq!(frames.len(), 1);
        let (reply_header, payload) = decode_reply_frame(&frames[0]).unwrap();
        assert_eq!(reply_header.related_request_id, id);
        assert_eq!(reply_header.remote_ex, RemoteExceptionCode::Ok);
        assert_eq!(payload, &[7u8, 8, 9]);
    }

    #[test]
    fn tick_propagates_handler_error_into_reply() {
        let p = participant(206);
        let q = RpcQos::default_basic();
        let r = Replier::<RawBytes, RawBytes>::new(
            &p,
            "Calc",
            &q,
            err_handler(RemoteExceptionCode::InvalidArgument),
        )
        .unwrap();
        let id = SampleIdentity::new([2u8; 16], 7);
        let frame = encode_request_frame(&RequestHeader::new(id, ""), &[1, 2]);
        r.__push_request_raw(frame).unwrap();
        assert_eq!(r.tick(), 1);
        assert_eq!(r.error_count(), 1);
        assert_eq!(r.handled_count(), 0);
        let replies = r.__drain_reply_writer();
        assert_eq!(replies.len(), 1);
        let (h, payload) = decode_reply_frame(&replies[0]).unwrap();
        assert_eq!(h.related_request_id, id);
        assert_eq!(h.remote_ex, RemoteExceptionCode::InvalidArgument);
        assert!(payload.is_empty());
    }

    #[test]
    fn tick_drops_malformed_request_silently() {
        let p = participant(207);
        let q = RpcQos::default_basic();
        let r = Replier::<RawBytes, RawBytes>::new(&p, "Calc", &q, echo_handler()).unwrap();
        r.__push_request_raw(alloc::vec![0u8; 5]).unwrap();
        assert_eq!(r.tick(), 0);
        assert_eq!(r.handled_count(), 0);
        assert_eq!(r.error_count(), 0);
        assert!(r.__drain_reply_writer().is_empty());
    }

    #[test]
    fn tick_filters_requests_for_other_instance_name() {
        let p = participant(208);
        let q = RpcQos::default_basic();
        let r =
            Replier::<RawBytes, RawBytes>::with_instance(&p, "Calc", "calc-A", &q, echo_handler())
                .unwrap();
        let id = SampleIdentity::new([3u8; 16], 1);
        let frame = encode_request_frame(&RequestHeader::new(id, "calc-B"), &[1]);
        r.__push_request_raw(frame).unwrap();
        assert_eq!(r.tick(), 0);
        assert_eq!(r.handled_count(), 0);
        assert!(r.__drain_reply_writer().is_empty());
    }

    #[test]
    fn tick_handles_multiple_requests_in_one_call() {
        let p = participant(209);
        let q = RpcQos::default_basic();
        let r = Replier::<RawBytes, RawBytes>::new(&p, "Calc", &q, echo_handler()).unwrap();
        let ids: Vec<SampleIdentity> = (1..=5u64)
            .map(|i| SampleIdentity::new([0xAB; 16], i))
            .collect();
        for (i, id) in (1..=5u64).zip(ids.iter().copied()) {
            let frame =
                encode_request_frame(&RequestHeader::new(id, ""), &[u8::try_from(i).unwrap()]);
            r.__push_request_raw(frame).unwrap();
        }
        assert_eq!(r.tick(), 5);
        assert_eq!(r.handled_count(), 5);
        assert_eq!(r.error_count(), 0);
        let replies = r.__drain_reply_writer();
        assert_eq!(replies.len(), 5);
        for frame in &replies {
            let (h, _payload) = decode_reply_frame(frame).unwrap();
            assert!(ids.contains(&h.related_request_id));
        }
    }

    #[test]
    fn tick_accepts_matching_and_untargeted_instance_names() {
        let p = participant(210);
        let q = RpcQos::default_basic();
        let r =
            Replier::<RawBytes, RawBytes>::with_instance(&p, "Calc", "calc-C", &q, echo_handler())
                .unwrap();
        let targeted = SampleIdentity::new([4u8; 16], 1);
        let untargeted = SampleIdentity::new([4u8; 16], 2);
        r.__push_request_raw(encode_request_frame(
            &RequestHeader::new(targeted, "calc-C"),
            &[1],
        ))
        .unwrap();
        r.__push_request_raw(encode_request_frame(&RequestHeader::new(untargeted, ""), &[2]))
            .unwrap();
        assert_eq!(r.tick(), 2);
        assert_eq!(r.handled_count(), 2);
        assert_eq!(r.__drain_reply_writer().len(), 2);
    }

    #[test]
    fn fn_handler_passthrough_works() {
        let h = FnHandler::new(|x: RawBytes| Ok::<RawBytes, RemoteExceptionCode>(x));
        let res = h.handle(RawBytes::new(alloc::vec![1, 2])).unwrap();
        assert_eq!(res.data, alloc::vec![1, 2]);
    }
}