1use crate::{
8 data_converters::{PayloadCodec, SerializationContextData},
9 protos::temporal::api::common::v1::{Payload, Payloads},
10};
11use futures::future::BoxFuture;
12
/// One payload-carrying field handed to an [`AsyncPayloadVisitor`].
pub struct PayloadField<'a> {
    /// Identifier of the field being visited, e.g. `"temporal.api.common.v1.Payload"`.
    ///
    /// Visitors may inspect this string to decide how to treat the field; the encode/decode
    /// visitors in this module skip paths containing `SearchAttributes.indexed_fields`.
    pub path: &'static str,
    /// Mutable access to the payload storage backing the field.
    pub data: PayloadFieldData<'a>,
}
22
/// The shape of the payload storage exposed by a [`PayloadField`].
pub enum PayloadFieldData<'a> {
    /// A field holding exactly one [`Payload`].
    Single(&'a mut Payload),
    /// A repeated field holding a list of [`Payload`]s.
    Repeated(&'a mut Vec<Payload>),
    /// A field holding a [`Payloads`] message; the encode/decode visitors in this module
    /// rewrite its inner `payloads` list.
    Payloads(&'a mut Payloads),
}
32
/// Callback that [`PayloadVisitable::visit_payloads_mut`] invokes for each payload-carrying
/// field it walks.
pub trait AsyncPayloadVisitor {
    /// Visits one field. The returned future may mutate the field's payloads in place through
    /// the mutable references held in `field.data`.
    fn visit<'a>(&'a mut self, field: PayloadField<'a>) -> BoxFuture<'a, ()>;
}
38
/// A value whose payload fields can be walked, and mutated, by an [`AsyncPayloadVisitor`].
///
/// This file implements it directly for [`Payload`] and [`Payloads`]. Impls for other protocol
/// messages are presumably supplied by the generated `payload_visitor_impl.rs` that is
/// `include!`d into this module (NOTE(review): confirm against the generator).
pub trait PayloadVisitable: Send {
    /// Calls `visitor` for the payload-carrying fields of `self`.
    ///
    /// Both `self` and `visitor` are borrowed for the lifetime of the returned future, which
    /// must be awaited for the visits to run.
    fn visit_payloads_mut<'a>(
        &'a mut self,
        visitor: &'a mut (dyn AsyncPayloadVisitor + Send),
    ) -> BoxFuture<'a, ()>;
}
49
/// Reports whether `path` refers to the `indexed_fields` map of a `SearchAttributes` message.
///
/// This is a plain substring test: any path that contains `SearchAttributes.indexed_fields`
/// anywhere matches, regardless of the message it is nested in.
fn is_search_attributes_path(path: &str) -> bool {
    const SEARCH_ATTRIBUTE_FIELDS: &str = "SearchAttributes.indexed_fields";
    path.contains(SEARCH_ATTRIBUTE_FIELDS)
}
56
57fn should_encode(path: &str) -> bool {
58 !is_search_attributes_path(path)
59}
60
61pub struct EncodeVisitor<'a> {
63 codec: &'a (dyn PayloadCodec + Send + Sync),
64 context: &'a SerializationContextData,
65}
66
67impl AsyncPayloadVisitor for EncodeVisitor<'_> {
68 fn visit<'a>(&'a mut self, field: PayloadField<'a>) -> BoxFuture<'a, ()> {
69 Box::pin(async move {
70 if !should_encode(field.path) {
71 return;
72 }
73 match field.data {
74 PayloadFieldData::Single(payload) => {
75 let encoded = self
76 .codec
77 .encode(self.context, vec![std::mem::take(payload)])
78 .await;
79 if let Some(p) = encoded.into_iter().next() {
80 *payload = p;
81 }
82 }
83 PayloadFieldData::Repeated(payloads) => {
84 *payloads = self
85 .codec
86 .encode(self.context, std::mem::take(payloads))
87 .await;
88 }
89 PayloadFieldData::Payloads(payloads_msg) => {
90 payloads_msg.payloads = self
91 .codec
92 .encode(self.context, std::mem::take(&mut payloads_msg.payloads))
93 .await;
94 }
95 }
96 })
97 }
98}
99
100pub struct DecodeVisitor<'a> {
102 codec: &'a (dyn PayloadCodec + Send + Sync),
103 context: &'a SerializationContextData,
104}
105
106impl AsyncPayloadVisitor for DecodeVisitor<'_> {
107 fn visit<'a>(&'a mut self, field: PayloadField<'a>) -> BoxFuture<'a, ()> {
108 Box::pin(async move {
109 if !should_encode(field.path) {
110 return;
111 }
112 match field.data {
113 PayloadFieldData::Single(payload) => {
114 let decoded = self
115 .codec
116 .decode(self.context, vec![std::mem::take(payload)])
117 .await;
118 if let Some(p) = decoded.into_iter().next() {
119 *payload = p;
120 }
121 }
122 PayloadFieldData::Repeated(payloads) => {
123 *payloads = self
124 .codec
125 .decode(self.context, std::mem::take(payloads))
126 .await;
127 }
128 PayloadFieldData::Payloads(payloads_msg) => {
129 payloads_msg.payloads = self
130 .codec
131 .decode(self.context, std::mem::take(&mut payloads_msg.payloads))
132 .await;
133 }
134 }
135 })
136 }
137}
138
139pub async fn encode_payloads<M: PayloadVisitable + Send>(
141 msg: &mut M,
142 codec: &(dyn PayloadCodec + Send + Sync),
143 context: &SerializationContextData,
144) {
145 let mut visitor = EncodeVisitor { codec, context };
146 msg.visit_payloads_mut(&mut visitor).await;
147}
148
149pub async fn decode_payloads<M: PayloadVisitable + Send>(
151 msg: &mut M,
152 codec: &(dyn PayloadCodec + Send + Sync),
153 context: &SerializationContextData,
154) {
155 let mut visitor = DecodeVisitor { codec, context };
156 msg.visit_payloads_mut(&mut visitor).await;
157}
158
159impl PayloadVisitable for Payload {
161 fn visit_payloads_mut<'a>(
162 &'a mut self,
163 visitor: &'a mut (dyn AsyncPayloadVisitor + Send),
164 ) -> BoxFuture<'a, ()> {
165 Box::pin(async move {
166 visitor
167 .visit(PayloadField {
168 path: "temporal.api.common.v1.Payload",
169 data: PayloadFieldData::Single(self),
170 })
171 .await;
172 })
173 }
174}
175
176impl PayloadVisitable for Payloads {
178 fn visit_payloads_mut<'a>(
179 &'a mut self,
180 visitor: &'a mut (dyn AsyncPayloadVisitor + Send),
181 ) -> BoxFuture<'a, ()> {
182 Box::pin(async move {
183 visitor
184 .visit(PayloadField {
185 path: "temporal.api.common.v1.Payloads",
186 data: PayloadFieldData::Payloads(self),
187 })
188 .await;
189 })
190 }
191}
192
// Splices in generated code from `$OUT_DIR/payload_visitor_impl.rs`. Presumably this holds the
// `PayloadVisitable` impls for the protocol message types (e.g. `WorkflowActivation`, which the
// tests visit but which has no impl in this file). NOTE(review): confirm against the build
// script that generates it.
include!(concat!(env!("OUT_DIR"), "/payload_visitor_impl.rs"));
195
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        data_converters::{DefaultFailureConverter, FailureConverter, PayloadConverter},
        error::{ApplicationFailure, OutgoingError, OutgoingWorkflowError},
        protos::{
            coresdk::{
                activity_result::{
                    ActivityResolution, Success, activity_resolution::Status as ActivityStatus,
                },
                workflow_activation::{
                    InitializeWorkflow, ResolveActivity, WorkflowActivation, WorkflowActivationJob,
                    workflow_activation_job::Variant,
                },
                workflow_commands::{
                    ContinueAsNewWorkflowExecution, ScheduleActivity, StartChildWorkflowExecution,
                    UpsertWorkflowSearchAttributes, WorkflowCommand,
                    workflow_command::Variant as CmdVariant,
                },
                workflow_completion::{
                    WorkflowActivationCompletion, workflow_activation_completion::Status,
                },
            },
            temporal::api::{
                common::v1::{Memo, SearchAttributes},
                failure::v1::failure::FailureInfo,
                workflow::v1::WorkflowExecutionInfo,
                workflowservice::v1::DescribeWorkflowExecutionResponse,
            },
        },
    };
    use futures::FutureExt;
    use std::collections::HashMap;

    /// Test codec that leaves payload data alone. It only adds an `encoded` or `decoded`
    /// metadata key, so tests can see exactly which payloads reached the codec.
    struct MarkingCodec;
    impl PayloadCodec for MarkingCodec {
        fn encode(
            &self,
            _: &SerializationContextData,
            payloads: Vec<Payload>,
        ) -> BoxFuture<'static, Vec<Payload>> {
            async move {
                payloads
                    .into_iter()
                    .map(|mut p| {
                        p.metadata.insert("encoded".to_string(), b"true".to_vec());
                        p
                    })
                    .collect()
            }
            .boxed()
        }

        fn decode(
            &self,
            _: &SerializationContextData,
            payloads: Vec<Payload>,
        ) -> BoxFuture<'static, Vec<Payload>> {
            async move {
                payloads
                    .into_iter()
                    .map(|mut p| {
                        p.metadata.insert("decoded".to_string(), b"true".to_vec());
                        p
                    })
                    .collect()
            }
            .boxed()
        }
    }

    /// Visitor that records the `path` of every field it is handed and mutates nothing.
    struct PathRecordingVisitor {
        visited_paths: Vec<String>,
    }
    impl PathRecordingVisitor {
        fn new() -> Self {
            Self {
                visited_paths: Vec::new(),
            }
        }

        /// Returns a copy of the paths recorded so far, in visit order.
        fn paths(&self) -> Vec<String> {
            self.visited_paths.clone()
        }
    }

    impl AsyncPayloadVisitor for PathRecordingVisitor {
        fn visit<'a>(&'a mut self, field: PayloadField<'a>) -> BoxFuture<'a, ()> {
            // Recorded synchronously in `visit`, before the returned (no-op) future is polled.
            let path = field.path.to_string();
            self.visited_paths.push(path);
            async move {}.boxed()
        }
    }

    /// Builds a payload with `data` as its bytes and no metadata.
    fn make_payload(data: &str) -> Payload {
        Payload {
            metadata: HashMap::new(),
            data: data.as_bytes().to_vec(),
            external_payloads: vec![],
        }
    }

    /// True if `MarkingCodec::encode` has processed `p`.
    fn is_encoded(p: &Payload) -> bool {
        p.metadata.contains_key("encoded")
    }

    /// True if `MarkingCodec::decode` has processed `p`.
    fn is_decoded(p: &Payload) -> bool {
        p.metadata.contains_key("decoded")
    }

    // Walking an activation should report the arguments, headers and memo fields of an
    // `InitializeWorkflow` job.
    #[tokio::test]
    async fn test_direct_visitor_records_paths() {
        let mut activation = WorkflowActivation {
            run_id: "test-run".to_string(),
            jobs: vec![WorkflowActivationJob {
                variant: Some(Variant::InitializeWorkflow(InitializeWorkflow {
                    workflow_type: "test-workflow".to_string(),
                    arguments: vec![make_payload("input1")],
                    headers: {
                        let mut h = HashMap::new();
                        h.insert("header-key".to_string(), make_payload("header-value"));
                        h
                    },
                    memo: Some(Memo {
                        fields: {
                            let mut m = HashMap::new();
                            m.insert("memo-key".to_string(), make_payload("memo-value"));
                            m
                        },
                    }),
                    ..Default::default()
                })),
            }],
            ..Default::default()
        };

        let mut visitor = PathRecordingVisitor::new();
        activation.visit_payloads_mut(&mut visitor).await;

        let paths = visitor.paths();
        assert!(
            paths
                .iter()
                .any(|p| p.contains("InitializeWorkflow.arguments")),
            "should visit arguments, got: {:?}",
            paths
        );
        assert!(
            paths
                .iter()
                .any(|p| p.contains("InitializeWorkflow.headers")),
            "should visit headers, got: {:?}",
            paths
        );
        assert!(
            paths.iter().any(|p| p.contains("Memo.fields")),
            "should visit memo fields, got: {:?}",
            paths
        );
    }

    // Encoding a completion must reach payloads nested in a command's arguments and headers.
    #[tokio::test]
    async fn test_encode_workflow_activation_completion_with_schedule_activity() {
        let mut completion = WorkflowActivationCompletion {
            run_id: "test-run".to_string(),
            status: Some(Status::Successful(
                crate::protos::coresdk::workflow_completion::Success {
                    commands: vec![WorkflowCommand {
                        variant: Some(CmdVariant::ScheduleActivity(ScheduleActivity {
                            activity_id: "act-1".to_string(),
                            activity_type: "test-activity".to_string(),
                            arguments: vec![make_payload("arg1"), make_payload("arg2")],
                            headers: {
                                let mut h = HashMap::new();
                                h.insert("header-key".to_string(), make_payload("header-value"));
                                h
                            },
                            ..Default::default()
                        })),
                        user_metadata: None,
                    }],
                    ..Default::default()
                },
            )),
        };

        encode_payloads(
            &mut completion,
            &MarkingCodec,
            &SerializationContextData::Workflow,
        )
        .await;

        let status = completion.status.as_ref().unwrap();
        let Status::Successful(success) = status else {
            panic!("Expected successful status")
        };
        let cmd = &success.commands[0];
        let CmdVariant::ScheduleActivity(schedule) = cmd.variant.as_ref().unwrap() else {
            panic!("Expected ScheduleActivity")
        };

        assert!(is_encoded(&schedule.arguments[0]), "arg1 should be encoded");
        assert!(is_encoded(&schedule.arguments[1]), "arg2 should be encoded");
        assert!(
            is_encoded(schedule.headers.get("header-key").unwrap()),
            "header should be encoded"
        );
    }

    // Decoding an activation must reach an `InitializeWorkflow` job's arguments and headers.
    #[tokio::test]
    async fn test_decode_workflow_activation_with_initialize() {
        let mut activation = WorkflowActivation {
            run_id: "test-run".to_string(),
            jobs: vec![WorkflowActivationJob {
                variant: Some(Variant::InitializeWorkflow(InitializeWorkflow {
                    workflow_type: "test-workflow".to_string(),
                    arguments: vec![make_payload("input1"), make_payload("input2")],
                    headers: {
                        let mut h = HashMap::new();
                        h.insert("header-key".to_string(), make_payload("header-value"));
                        h
                    },
                    ..Default::default()
                })),
            }],
            ..Default::default()
        };

        decode_payloads(
            &mut activation,
            &MarkingCodec,
            &SerializationContextData::Workflow,
        )
        .await;

        let job = &activation.jobs[0];
        let Variant::InitializeWorkflow(init) = job.variant.as_ref().unwrap() else {
            panic!("Expected InitializeWorkflow")
        };

        assert!(is_decoded(&init.arguments[0]), "arg1 should be decoded");
        assert!(is_decoded(&init.arguments[1]), "arg2 should be decoded");
        assert!(
            is_decoded(init.headers.get("header-key").unwrap()),
            "header should be decoded"
        );
    }

    // Decoding must reach a single optional payload nested several messages deep.
    #[tokio::test]
    async fn test_decode_workflow_activation_with_resolve_activity() {
        let mut activation = WorkflowActivation {
            run_id: "test-run".to_string(),
            jobs: vec![WorkflowActivationJob {
                variant: Some(Variant::ResolveActivity(ResolveActivity {
                    seq: 1,
                    result: Some(ActivityResolution {
                        status: Some(ActivityStatus::Completed(Success {
                            result: Some(make_payload("activity-result")),
                        })),
                    }),
                    ..Default::default()
                })),
            }],
            ..Default::default()
        };

        decode_payloads(
            &mut activation,
            &MarkingCodec,
            &SerializationContextData::Workflow,
        )
        .await;

        let job = &activation.jobs[0];
        let Variant::ResolveActivity(resolve) = job.variant.as_ref().unwrap() else {
            panic!("Expected ResolveActivity")
        };
        let ActivityStatus::Completed(success) =
            resolve.result.as_ref().unwrap().status.as_ref().unwrap()
        else {
            panic!("Expected Completed status")
        };

        assert!(
            is_decoded(success.result.as_ref().unwrap()),
            "activity result should be decoded"
        );
    }

    // Search attributes must pass through encoding unchanged, in every command that carries
    // them. Sibling payloads in the same command must still be encoded.
    #[tokio::test]
    async fn test_search_attributes_skipped_on_encode() {
        let mut completion = WorkflowActivationCompletion {
            run_id: "test-run".to_string(),
            status: Some(Status::Successful(
                crate::protos::coresdk::workflow_completion::Success {
                    commands: vec![
                        // Command carrying only search attributes.
                        WorkflowCommand {
                            variant: Some(CmdVariant::UpsertWorkflowSearchAttributes(
                                UpsertWorkflowSearchAttributes {
                                    search_attributes: Some(SearchAttributes {
                                        indexed_fields: {
                                            let mut sa = HashMap::new();
                                            sa.insert(
                                                "CustomField".to_string(),
                                                make_payload("search-value"),
                                            );
                                            sa
                                        },
                                    }),
                                },
                            )),
                            user_metadata: None,
                        },
                        // Command mixing encodable arguments with search attributes.
                        WorkflowCommand {
                            variant: Some(CmdVariant::ContinueAsNewWorkflowExecution(
                                ContinueAsNewWorkflowExecution {
                                    arguments: vec![make_payload("continue-arg")],
                                    search_attributes: Some(SearchAttributes {
                                        indexed_fields: {
                                            let mut sa = HashMap::new();
                                            sa.insert(
                                                "CustomField".to_string(),
                                                make_payload("continue-search-value"),
                                            );
                                            sa
                                        },
                                    }),
                                    ..Default::default()
                                },
                            )),
                            user_metadata: None,
                        },
                        // Command mixing encodable input with search attributes.
                        WorkflowCommand {
                            variant: Some(CmdVariant::StartChildWorkflowExecution(
                                StartChildWorkflowExecution {
                                    seq: 1,
                                    workflow_type: "child-workflow".to_string(),
                                    input: vec![make_payload("child-arg")],
                                    search_attributes: Some(SearchAttributes {
                                        indexed_fields: {
                                            let mut sa = HashMap::new();
                                            sa.insert(
                                                "CustomField".to_string(),
                                                make_payload("child-search-value"),
                                            );
                                            sa
                                        },
                                    }),
                                    ..Default::default()
                                },
                            )),
                            user_metadata: None,
                        },
                    ],
                    ..Default::default()
                },
            )),
        };

        encode_payloads(
            &mut completion,
            &MarkingCodec,
            &SerializationContextData::Workflow,
        )
        .await;

        let status = completion.status.as_ref().unwrap();
        let Status::Successful(success) = status else {
            panic!("Expected successful status")
        };

        // Upsert: the only payload is a search attribute, so nothing is encoded.
        let CmdVariant::UpsertWorkflowSearchAttributes(upsert) =
            success.commands[0].variant.as_ref().unwrap()
        else {
            panic!("Expected UpsertWorkflowSearchAttributes")
        };
        let sa = upsert.search_attributes.as_ref().unwrap();
        assert!(
            !is_encoded(sa.indexed_fields.get("CustomField").unwrap()),
            "search attributes should NOT be encoded"
        );

        // Continue-as-new: arguments encoded, search attributes not.
        let CmdVariant::ContinueAsNewWorkflowExecution(continue_as_new) =
            success.commands[1].variant.as_ref().unwrap()
        else {
            panic!("Expected ContinueAsNewWorkflowExecution")
        };
        assert!(
            is_encoded(&continue_as_new.arguments[0]),
            "arguments should be encoded"
        );
        let sa = continue_as_new.search_attributes.as_ref().unwrap();
        assert!(
            !is_encoded(sa.indexed_fields.get("CustomField").unwrap()),
            "search attributes should NOT be encoded"
        );

        // Start child: input encoded, search attributes not.
        let CmdVariant::StartChildWorkflowExecution(start_child) =
            success.commands[2].variant.as_ref().unwrap()
        else {
            panic!("Expected StartChildWorkflowExecution")
        };
        assert!(is_encoded(&start_child.input[0]), "input should be encoded");
        let sa = start_child.search_attributes.as_ref().unwrap();
        assert!(
            !is_encoded(sa.indexed_fields.get("CustomField").unwrap()),
            "search attributes should NOT be encoded"
        );
    }

    // Decoding must skip search attributes too, while still decoding sibling memo fields.
    #[tokio::test]
    async fn test_search_attributes_skipped_on_decode() {
        let mut response = DescribeWorkflowExecutionResponse {
            workflow_execution_info: Some(WorkflowExecutionInfo {
                memo: Some(Memo {
                    fields: {
                        let mut memo = HashMap::new();
                        memo.insert("tracked".to_string(), make_payload("memo-value"));
                        memo
                    },
                }),
                search_attributes: Some(SearchAttributes {
                    indexed_fields: {
                        let mut sa = HashMap::new();
                        sa.insert("CustomField".to_string(), make_payload("search-value"));
                        sa
                    },
                }),
                ..Default::default()
            }),
            ..Default::default()
        };

        decode_payloads(
            &mut response,
            &MarkingCodec,
            &SerializationContextData::Workflow,
        )
        .await;

        let info = response.workflow_execution_info.as_ref().unwrap();
        assert!(
            is_decoded(info.memo.as_ref().unwrap().fields.get("tracked").unwrap()),
            "memo should be decoded"
        );
        assert!(
            !is_decoded(
                info.search_attributes
                    .as_ref()
                    .unwrap()
                    .indexed_fields
                    .get("CustomField")
                    .unwrap()
            ),
            "search attributes should NOT be decoded"
        );
    }

    // A bare `Payload` goes through the `Single` path via its hand-written impl.
    #[tokio::test]
    async fn test_encode_single_payload() {
        let mut payload = make_payload("test-data");

        encode_payloads(
            &mut payload,
            &MarkingCodec,
            &SerializationContextData::Workflow,
        )
        .await;

        assert!(is_encoded(&payload), "single payload should be encoded");
    }

    // Decode counterpart of `test_encode_single_payload`.
    #[tokio::test]
    async fn test_decode_single_payload() {
        let mut payload = make_payload("test-data");

        decode_payloads(
            &mut payload,
            &MarkingCodec,
            &SerializationContextData::Workflow,
        )
        .await;

        assert!(is_decoded(&payload), "single payload should be decoded");
    }

    // A bare `Payloads` message goes through the `Payloads` path; every element must be encoded.
    #[tokio::test]
    async fn test_encode_payloads_message() {
        let mut payloads = Payloads {
            payloads: vec![make_payload("p1"), make_payload("p2"), make_payload("p3")],
        };

        encode_payloads(
            &mut payloads,
            &MarkingCodec,
            &SerializationContextData::Workflow,
        )
        .await;

        for (i, p) in payloads.payloads.iter().enumerate() {
            assert!(is_encoded(p), "payload {} should be encoded", i);
        }
    }

    // Application-failure details, which end up in `ApplicationFailureInfo.details` after
    // conversion, must be reached when encoding a `Failure`.
    #[tokio::test]
    async fn test_encode_failure_encodes_application_failure_details() {
        let mut failure = DefaultFailureConverter.to_failure(
            OutgoingError::Workflow(OutgoingWorkflowError::Application(Box::new(
                ApplicationFailure::builder(anyhow::anyhow!("app boom"))
                    .details(crate::data_converters::RawValue::new(vec![make_payload(
                        "detail",
                    )]))
                    .build(),
            ))),
            &PayloadConverter::default(),
            &SerializationContextData::Workflow,
        );

        encode_payloads(
            &mut failure,
            &MarkingCodec,
            &SerializationContextData::Workflow,
        )
        .await;

        let Some(FailureInfo::ApplicationFailureInfo(info)) = failure.failure_info else {
            panic!("expected application failure info")
        };
        assert!(is_encoded(&info.details.unwrap().payloads[0]));
    }
}