1use crate::protos::temporal::api::{common::v1::Payload, failure::v1::Failure};
5use futures::{FutureExt, future::BoxFuture};
6use std::{collections::HashMap, sync::Arc};
7
8#[derive(Clone)]
11pub struct DataConverter {
12 payload_converter: PayloadConverter,
13 #[allow(dead_code)] failure_converter: Arc<dyn FailureConverter + Send + Sync>,
15 codec: Arc<dyn PayloadCodec + Send + Sync>,
16}
17
18impl std::fmt::Debug for DataConverter {
19 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20 f.debug_struct("DataConverter")
21 .field("payload_converter", &self.payload_converter)
22 .finish_non_exhaustive()
23 }
24}
25impl DataConverter {
26 pub fn new(
28 payload_converter: PayloadConverter,
29 failure_converter: impl FailureConverter + Send + Sync + 'static,
30 codec: impl PayloadCodec + Send + Sync + 'static,
31 ) -> Self {
32 Self {
33 payload_converter,
34 failure_converter: Arc::new(failure_converter),
35 codec: Arc::new(codec),
36 }
37 }
38
39 pub async fn to_payload<T: TemporalSerializable + 'static>(
41 &self,
42 data: &SerializationContextData,
43 val: &T,
44 ) -> Result<Payload, PayloadConversionError> {
45 let context = SerializationContext {
46 data,
47 converter: &self.payload_converter,
48 };
49 let payload = self.payload_converter.to_payload(&context, val)?;
50 let encoded = self.codec.encode(data, vec![payload]).await;
51 encoded
52 .into_iter()
53 .next()
54 .ok_or(PayloadConversionError::WrongEncoding)
55 }
56
57 pub async fn from_payload<T: TemporalDeserializable + 'static>(
59 &self,
60 data: &SerializationContextData,
61 payload: Payload,
62 ) -> Result<T, PayloadConversionError> {
63 let context = SerializationContext {
64 data,
65 converter: &self.payload_converter,
66 };
67 let decoded = self.codec.decode(data, vec![payload]).await;
68 let payload = decoded
69 .into_iter()
70 .next()
71 .ok_or(PayloadConversionError::WrongEncoding)?;
72 self.payload_converter.from_payload(&context, payload)
73 }
74
75 pub async fn to_payloads<T: TemporalSerializable + 'static>(
77 &self,
78 data: &SerializationContextData,
79 val: &T,
80 ) -> Result<Vec<Payload>, PayloadConversionError> {
81 let context = SerializationContext {
82 data,
83 converter: &self.payload_converter,
84 };
85 let payloads = self.payload_converter.to_payloads(&context, val)?;
86 Ok(self.codec.encode(data, payloads).await)
87 }
88
89 pub async fn from_payloads<T: TemporalDeserializable + 'static>(
91 &self,
92 data: &SerializationContextData,
93 payloads: Vec<Payload>,
94 ) -> Result<T, PayloadConversionError> {
95 let context = SerializationContext {
96 data,
97 converter: &self.payload_converter,
98 };
99 let decoded = self.codec.decode(data, payloads).await;
100 self.payload_converter.from_payloads(&context, decoded)
101 }
102
103 pub fn payload_converter(&self) -> &PayloadConverter {
105 &self.payload_converter
106 }
107
108 pub fn codec(&self) -> &(dyn PayloadCodec + Send + Sync) {
110 self.codec.as_ref()
111 }
112}
113
/// Tags which kind of Temporal entity a conversion is being performed for.
///
/// It is passed through to converters and codecs; the built-in ones in this module ignore it.
/// `None` is used when no particular context applies (for example by `RawValue`'s helpers).
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum SerializationContextData {
    /// Conversion on behalf of a workflow.
    Workflow,
    /// Conversion on behalf of an activity.
    Activity,
    /// Conversion on behalf of Nexus.
    Nexus,
    /// No specific context.
    None,
}
126
127#[derive(Clone, Copy)]
130pub struct SerializationContext<'a> {
131 pub data: &'a SerializationContextData,
133 pub converter: &'a PayloadConverter,
135}
136#[derive(Clone)]
138pub enum PayloadConverter {
139 Serde(Arc<dyn ErasedSerdePayloadConverter>),
141 UseWrappers,
143 Composite(Arc<CompositePayloadConverter>),
145}
146
147impl std::fmt::Debug for PayloadConverter {
148 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
149 match self {
150 PayloadConverter::Serde(_) => write!(f, "PayloadConverter::Serde(...)"),
151 PayloadConverter::UseWrappers => write!(f, "PayloadConverter::UseWrappers"),
152 PayloadConverter::Composite(_) => write!(f, "PayloadConverter::Composite(...)"),
153 }
154 }
155}
156impl PayloadConverter {
157 pub fn serde_json() -> Self {
159 Self::Serde(Arc::new(SerdeJsonPayloadConverter))
160 }
161 }
163
164impl Default for PayloadConverter {
165 fn default() -> Self {
166 Self::Composite(Arc::new(CompositePayloadConverter {
167 converters: vec![Self::UseWrappers, Self::serde_json()],
168 }))
169 }
170}
171
/// Errors produced while converting between Rust values and payloads.
#[derive(Debug)]
pub enum PayloadConversionError {
    /// The value or payload is not in a form this converter handles.
    ///
    /// Composite converters treat this as "try the next converter".
    WrongEncoding,
    /// The format matched, but (de)serialization itself failed.
    EncodingError(Box<dyn std::error::Error + Send + Sync>),
}

impl std::fmt::Display for PayloadConversionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::WrongEncoding => f.write_str("Wrong encoding"),
            Self::EncodingError(inner) => write!(f, "Encoding error: {inner}"),
        }
    }
}

impl std::error::Error for PayloadConversionError {
    /// Exposes the wrapped (de)serialization error, if any.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::EncodingError(inner) => Some(&**inner),
            Self::WrongEncoding => None,
        }
    }
}
198
199pub trait FailureConverter {
201 fn to_failure(
203 &self,
204 error: Box<dyn std::error::Error>,
205 payload_converter: &PayloadConverter,
206 context: &SerializationContextData,
207 ) -> Result<Failure, PayloadConversionError>;
208
209 fn to_error(
211 &self,
212 failure: Failure,
213 payload_converter: &PayloadConverter,
214 context: &SerializationContextData,
215 ) -> Result<Box<dyn std::error::Error>, PayloadConversionError>;
216}
/// The failure converter installed by `DataConverter::default`.
///
/// NOTE(review): its `FailureConverter` methods are currently `todo!()` stubs.
// Common traits derived so this public unit type can be logged, copied and default-constructed.
#[derive(Debug, Clone, Copy, Default)]
pub struct DefaultFailureConverter;
219pub trait PayloadCodec {
221 fn encode(
223 &self,
224 context: &SerializationContextData,
225 payloads: Vec<Payload>,
226 ) -> BoxFuture<'static, Vec<Payload>>;
227 fn decode(
229 &self,
230 context: &SerializationContextData,
231 payloads: Vec<Payload>,
232 ) -> BoxFuture<'static, Vec<Payload>>;
233}
234
235impl<T: PayloadCodec> PayloadCodec for Arc<T> {
236 fn encode(
237 &self,
238 context: &SerializationContextData,
239 payloads: Vec<Payload>,
240 ) -> BoxFuture<'static, Vec<Payload>> {
241 (**self).encode(context, payloads)
242 }
243 fn decode(
244 &self,
245 context: &SerializationContextData,
246 payloads: Vec<Payload>,
247 ) -> BoxFuture<'static, Vec<Payload>> {
248 (**self).decode(context, payloads)
249 }
250}
251
/// Pass-through codec: `encode` and `decode` return their input unchanged.
// Common traits derived so this public unit type can be logged, copied and default-constructed.
#[derive(Debug, Clone, Copy, Default)]
pub struct DefaultPayloadCodec;
254
255pub trait TemporalSerializable {
260 fn as_serde(&self) -> Result<&dyn erased_serde::Serialize, PayloadConversionError> {
262 Err(PayloadConversionError::WrongEncoding)
263 }
264 fn to_payload(&self, _: &SerializationContext<'_>) -> Result<Payload, PayloadConversionError> {
266 Err(PayloadConversionError::WrongEncoding)
267 }
268 fn to_payloads(
270 &self,
271 ctx: &SerializationContext<'_>,
272 ) -> Result<Vec<Payload>, PayloadConversionError> {
273 Ok(vec![self.to_payload(ctx)?])
274 }
275}
276
277pub trait TemporalDeserializable: Sized {
282 fn from_serde(
284 _: &dyn ErasedSerdePayloadConverter,
285 _ctx: &SerializationContext<'_>,
286 _: Payload,
287 ) -> Result<Self, PayloadConversionError> {
288 Err(PayloadConversionError::WrongEncoding)
289 }
290 fn from_payload(
292 ctx: &SerializationContext<'_>,
293 payload: Payload,
294 ) -> Result<Self, PayloadConversionError> {
295 let _ = (ctx, payload);
296 Err(PayloadConversionError::WrongEncoding)
297 }
298 fn from_payloads(
300 ctx: &SerializationContext<'_>,
301 payloads: Vec<Payload>,
302 ) -> Result<Self, PayloadConversionError> {
303 if payloads.len() != 1 {
304 return Err(PayloadConversionError::WrongEncoding);
305 }
306 Self::from_payload(ctx, payloads.into_iter().next().unwrap())
307 }
308}
309
310#[derive(Clone, Debug, Default)]
312pub struct RawValue {
313 pub payloads: Vec<Payload>,
315}
316impl RawValue {
317 pub fn empty() -> Self {
320 Self {
321 payloads: vec![Payload::default()],
322 }
323 }
324
325 pub fn new(payloads: Vec<Payload>) -> Self {
327 Self { payloads }
328 }
329
330 pub fn from_value<T: TemporalSerializable + 'static>(
332 value: &T,
333 converter: &PayloadConverter,
334 ) -> RawValue {
335 RawValue::new(vec![
336 converter
337 .to_payload(
338 &SerializationContext {
339 data: &SerializationContextData::None,
340 converter,
341 },
342 value,
343 )
344 .unwrap(),
345 ])
346 }
347
348 pub fn to_value<T: TemporalDeserializable + 'static>(self, converter: &PayloadConverter) -> T {
350 converter
351 .from_payload(
352 &SerializationContext {
353 data: &SerializationContextData::None,
354 converter,
355 },
356 self.payloads.into_iter().next().unwrap(),
357 )
358 .unwrap()
359 }
360}
361
362impl TemporalSerializable for RawValue {
363 fn to_payload(&self, _: &SerializationContext<'_>) -> Result<Payload, PayloadConversionError> {
364 Ok(self.payloads.first().cloned().unwrap_or_default())
365 }
366 fn to_payloads(
367 &self,
368 _: &SerializationContext<'_>,
369 ) -> Result<Vec<Payload>, PayloadConversionError> {
370 Ok(self.payloads.clone())
371 }
372}
373
374impl TemporalDeserializable for RawValue {
375 fn from_payload(
376 _: &SerializationContext<'_>,
377 p: Payload,
378 ) -> Result<Self, PayloadConversionError> {
379 Ok(RawValue { payloads: vec![p] })
380 }
381 fn from_payloads(
382 _: &SerializationContext<'_>,
383 payloads: Vec<Payload>,
384 ) -> Result<Self, PayloadConversionError> {
385 Ok(RawValue { payloads })
386 }
387}
388
389pub trait GenericPayloadConverter {
391 fn to_payload<T: TemporalSerializable + 'static>(
393 &self,
394 context: &SerializationContext<'_>,
395 val: &T,
396 ) -> Result<Payload, PayloadConversionError>;
397 #[allow(clippy::wrong_self_convention)]
399 fn from_payload<T: TemporalDeserializable + 'static>(
400 &self,
401 context: &SerializationContext<'_>,
402 payload: Payload,
403 ) -> Result<T, PayloadConversionError>;
404 fn to_payloads<T: TemporalSerializable + 'static>(
406 &self,
407 context: &SerializationContext<'_>,
408 val: &T,
409 ) -> Result<Vec<Payload>, PayloadConversionError> {
410 Ok(vec![self.to_payload(context, val)?])
411 }
412 #[allow(clippy::wrong_self_convention)]
414 fn from_payloads<T: TemporalDeserializable + 'static>(
415 &self,
416 context: &SerializationContext<'_>,
417 payloads: Vec<Payload>,
418 ) -> Result<T, PayloadConversionError> {
419 if payloads.len() != 1 {
420 return Err(PayloadConversionError::WrongEncoding);
421 }
422 self.from_payload(context, payloads.into_iter().next().unwrap())
423 }
424}
425
426impl GenericPayloadConverter for PayloadConverter {
427 fn to_payload<T: TemporalSerializable + 'static>(
428 &self,
429 context: &SerializationContext<'_>,
430 val: &T,
431 ) -> Result<Payload, PayloadConversionError> {
432 let mut payloads = self.to_payloads(context, val)?;
433 if payloads.len() != 1 {
434 return Err(PayloadConversionError::WrongEncoding);
435 }
436 Ok(payloads.pop().unwrap())
437 }
438
439 fn from_payload<T: TemporalDeserializable + 'static>(
440 &self,
441 context: &SerializationContext<'_>,
442 payload: Payload,
443 ) -> Result<T, PayloadConversionError> {
444 self.from_payloads(context, vec![payload])
445 }
446
447 fn to_payloads<T: TemporalSerializable + 'static>(
448 &self,
449 context: &SerializationContext<'_>,
450 val: &T,
451 ) -> Result<Vec<Payload>, PayloadConversionError> {
452 match self {
453 PayloadConverter::Serde(pc) => Ok(vec![pc.to_payload(context.data, val.as_serde()?)?]),
454 PayloadConverter::UseWrappers => T::to_payloads(val, context),
455 PayloadConverter::Composite(composite) => {
456 for converter in &composite.converters {
457 match converter.to_payloads(context, val) {
458 Ok(payloads) => return Ok(payloads),
459 Err(PayloadConversionError::WrongEncoding) => continue,
460 Err(e) => return Err(e),
461 }
462 }
463 Err(PayloadConversionError::WrongEncoding)
464 }
465 }
466 }
467
468 fn from_payloads<T: TemporalDeserializable + 'static>(
469 &self,
470 context: &SerializationContext<'_>,
471 payloads: Vec<Payload>,
472 ) -> Result<T, PayloadConversionError> {
473 if payloads.is_empty() && std::any::TypeId::of::<T>() == std::any::TypeId::of::<()>() {
475 let boxed: Box<dyn std::any::Any> = Box::new(());
476 return Ok(*boxed.downcast::<T>().unwrap());
477 }
478
479 match self {
480 PayloadConverter::Serde(pc) => {
481 if payloads.len() != 1 {
482 return Err(PayloadConversionError::WrongEncoding);
483 }
484 T::from_serde(pc.as_ref(), context, payloads.into_iter().next().unwrap())
485 }
486 PayloadConverter::UseWrappers => T::from_payloads(context, payloads),
487 PayloadConverter::Composite(composite) => {
488 for converter in &composite.converters {
489 match converter.from_payloads(context, payloads.clone()) {
490 Ok(val) => return Ok(val),
491 Err(PayloadConversionError::WrongEncoding) => continue,
492 Err(e) => return Err(e),
493 }
494 }
495 Err(PayloadConversionError::WrongEncoding)
496 }
497 }
498 }
499}
500
501impl<T> TemporalSerializable for T
503where
504 T: serde::Serialize,
505{
506 fn as_serde(&self) -> Result<&dyn erased_serde::Serialize, PayloadConversionError> {
507 Ok(self)
508 }
509}
510impl<T> TemporalDeserializable for T
511where
512 T: serde::de::DeserializeOwned,
513{
514 fn from_serde(
515 pc: &dyn ErasedSerdePayloadConverter,
516 context: &SerializationContext<'_>,
517 payload: Payload,
518 ) -> Result<Self, PayloadConversionError>
519 where
520 Self: Sized,
521 {
522 let mut de = pc.from_payload(context.data, payload)?;
523 erased_serde::deserialize(&mut de)
524 .map_err(|e| PayloadConversionError::EncodingError(Box::new(e)))
525 }
526}
527
528struct SerdeJsonPayloadConverter;
529impl ErasedSerdePayloadConverter for SerdeJsonPayloadConverter {
530 fn to_payload(
531 &self,
532 _: &SerializationContextData,
533 value: &dyn erased_serde::Serialize,
534 ) -> Result<Payload, PayloadConversionError> {
535 let as_json = serde_json::to_vec(value)
536 .map_err(|e| PayloadConversionError::EncodingError(e.into()))?;
537 Ok(Payload {
538 metadata: {
539 let mut hm = HashMap::new();
540 hm.insert("encoding".to_string(), b"json/plain".to_vec());
541 hm
542 },
543 data: as_json,
544 external_payloads: vec![],
545 })
546 }
547
548 fn from_payload(
549 &self,
550 _: &SerializationContextData,
551 payload: Payload,
552 ) -> Result<Box<dyn erased_serde::Deserializer<'static>>, PayloadConversionError> {
553 let encoding = payload.metadata.get("encoding").map(|v| v.as_slice());
554 if encoding != Some(b"json/plain".as_slice()) {
555 return Err(PayloadConversionError::WrongEncoding);
556 }
557 let json_v: serde_json::Value = serde_json::from_slice(&payload.data)
558 .map_err(|e| PayloadConversionError::EncodingError(Box::new(e)))?;
559 Ok(Box::new(<dyn erased_serde::Deserializer>::erase(json_v)))
560 }
561}
562pub trait ErasedSerdePayloadConverter: Send + Sync {
564 fn to_payload(
566 &self,
567 context: &SerializationContextData,
568 value: &dyn erased_serde::Serialize,
569 ) -> Result<Payload, PayloadConversionError>;
570 #[allow(clippy::wrong_self_convention)]
572 fn from_payload(
573 &self,
574 context: &SerializationContextData,
575 payload: Payload,
576 ) -> Result<Box<dyn erased_serde::Deserializer<'static>>, PayloadConversionError>;
577}
578
579pub struct ProstSerializable<T: prost::Message>(pub T);
584impl<T> TemporalSerializable for ProstSerializable<T>
585where
586 T: prost::Message + Default + 'static,
587{
588 fn to_payload(&self, _: &SerializationContext<'_>) -> Result<Payload, PayloadConversionError> {
589 let as_proto = prost::Message::encode_to_vec(&self.0);
590 Ok(Payload {
591 metadata: {
592 let mut hm = HashMap::new();
593 hm.insert("encoding".to_string(), b"binary/protobuf".to_vec());
594 hm
595 },
596 data: as_proto,
597 external_payloads: vec![],
598 })
599 }
600}
601impl<T> TemporalDeserializable for ProstSerializable<T>
602where
603 T: prost::Message + Default + 'static,
604{
605 fn from_payload(
606 _: &SerializationContext<'_>,
607 p: Payload,
608 ) -> Result<Self, PayloadConversionError>
609 where
610 Self: Sized,
611 {
612 let encoding = p.metadata.get("encoding").map(|v| v.as_slice());
613 if encoding != Some(b"binary/protobuf".as_slice()) {
614 return Err(PayloadConversionError::WrongEncoding);
615 }
616 T::decode(p.data.as_slice())
617 .map(ProstSerializable)
618 .map_err(|e| PayloadConversionError::EncodingError(Box::new(e)))
619 }
620}
621
622#[derive(Clone)]
624pub struct CompositePayloadConverter {
625 converters: Vec<PayloadConverter>,
626}
627
628impl Default for DataConverter {
629 fn default() -> Self {
630 Self::new(
631 PayloadConverter::default(),
632 DefaultFailureConverter,
633 DefaultPayloadCodec,
634 )
635 }
636}
637impl FailureConverter for DefaultFailureConverter {
638 fn to_failure(
639 &self,
640 _: Box<dyn std::error::Error>,
641 _: &PayloadConverter,
642 _: &SerializationContextData,
643 ) -> Result<Failure, PayloadConversionError> {
644 todo!()
645 }
646 fn to_error(
647 &self,
648 _: Failure,
649 _: &PayloadConverter,
650 _: &SerializationContextData,
651 ) -> Result<Box<dyn std::error::Error>, PayloadConversionError> {
652 todo!()
653 }
654}
655impl PayloadCodec for DefaultPayloadCodec {
656 fn encode(
657 &self,
658 _: &SerializationContextData,
659 payloads: Vec<Payload>,
660 ) -> BoxFuture<'static, Vec<Payload>> {
661 async move { payloads }.boxed()
662 }
663 fn decode(
664 &self,
665 _: &SerializationContextData,
666 payloads: Vec<Payload>,
667 ) -> BoxFuture<'static, Vec<Payload>> {
668 async move { payloads }.boxed()
669 }
670}
671
/// Generates a `MultiArgsN` tuple-struct wrapper whose fields each travel as their own payload.
///
/// An N-field value therefore maps to an N-element payload list. Invocation shape:
/// `impl_multi_args!(Name; N; 0: A, 1: B, ...)`. `N` must equal the number of `idx: Type`
/// pairs, because it is the payload count `from_payloads` checks for.
macro_rules! impl_multi_args {
    ($name:ident; $count:expr; $($idx:tt: $ty:ident),+) => {
        #[doc = concat!("Wrapper for ", stringify!($count), " typed arguments, enabling multi-arg serialization.")]
        #[derive(Clone, Debug, PartialEq, Eq)]
        pub struct $name<$($ty),+>($(pub $ty),+);

        impl<$($ty),+> TemporalSerializable for $name<$($ty),+>
        where
            $($ty: TemporalSerializable + 'static),+
        {
            // A multi-arg wrapper never fits in a single payload.
            fn to_payload(&self, _: &SerializationContext<'_>) -> Result<Payload, PayloadConversionError> {
                Err(PayloadConversionError::WrongEncoding)
            }
            // One payload per field, in field order, each produced by the context's converter.
            fn to_payloads(
                &self,
                ctx: &SerializationContext<'_>,
            ) -> Result<Vec<Payload>, PayloadConversionError> {
                Ok(vec![$(ctx.converter.to_payload(ctx, &self.$idx)?),+])
            }
        }

        // Lets callers build the wrapper from a plain tuple, e.g. `("a".to_string(), 1).into()`.
        #[allow(non_snake_case)]
        impl<$($ty),+> From<($($ty),+,)> for $name<$($ty),+> {
            fn from(t: ($($ty),+,)) -> Self {
                $name($(t.$idx),+)
            }
        }

        impl<$($ty),+> TemporalDeserializable for $name<$($ty),+>
        where
            $($ty: TemporalDeserializable + 'static),+
        {
            // Only the list form is supported; a single payload cannot hold several arguments.
            fn from_payload(_: &SerializationContext<'_>, _: Payload) -> Result<Self, PayloadConversionError> {
                Err(PayloadConversionError::WrongEncoding)
            }
            // Requires exactly `$count` payloads and decodes them field by field, in order.
            fn from_payloads(
                ctx: &SerializationContext<'_>,
                payloads: Vec<Payload>,
            ) -> Result<Self, PayloadConversionError> {
                if payloads.len() != $count {
                    return Err(PayloadConversionError::WrongEncoding);
                }
                let mut iter = payloads.into_iter();
                // The length check above guarantees one payload per field, so `unwrap` cannot fail
                // as long as `$count` matches the field count.
                Ok($name(
                    $(ctx.converter.from_payload::<$ty>(ctx, iter.next().unwrap())?),+
                ))
            }
        }
    };
}
724
// Multi-argument wrappers for two through six arguments (`MultiArgs2` .. `MultiArgs6`).
impl_multi_args!(MultiArgs2; 2; 0: A, 1: B);
impl_multi_args!(MultiArgs3; 3; 0: A, 1: B, 2: C);
impl_multi_args!(MultiArgs4; 4; 0: A, 1: B, 2: C, 3: D);
impl_multi_args!(MultiArgs5; 5; 0: A, 1: B, 2: C, 3: D, 4: E);
impl_multi_args!(MultiArgs6; 6; 0: A, 1: B, 2: C, 3: D, 4: E, 5: F);
730
#[cfg(test)]
mod tests {
    use super::*;

    /// A workflow-scoped serialization context over `converter`.
    fn workflow_ctx(converter: &PayloadConverter) -> SerializationContext<'_> {
        SerializationContext {
            data: &SerializationContextData::Workflow,
            converter,
        }
    }

    #[test]
    fn test_empty_payloads_as_unit_type() {
        let converter = PayloadConverter::default();
        let outcome: Result<(), _> =
            converter.from_payloads(&workflow_ctx(&converter), Vec::new());
        assert!(outcome.is_ok(), "Empty payloads should deserialize as ()");
    }

    #[test]
    fn multi_args_round_trip() {
        let converter = PayloadConverter::default();
        let ctx = workflow_ctx(&converter);
        let original = MultiArgs2("hello".to_string(), 42i32);

        let encoded = converter.to_payloads(&ctx, &original).unwrap();
        assert_eq!(encoded.len(), 2);

        let decoded: MultiArgs2<String, i32> = converter.from_payloads(&ctx, encoded).unwrap();
        assert_eq!(decoded, original);
    }

    #[test]
    fn multi_args_from_tuple() {
        let expected = MultiArgs2("hello".to_string(), 42);
        let converted: MultiArgs2<String, i32> = ("hello".to_string(), 42i32).into();
        assert_eq!(converted, expected);
    }
}