Skip to main content

temporalio_common/
payload_visitor.rs

1//! Payload visitor infrastructure for applying codecs to proto messages.
2//!
3//! This module provides a visitor pattern for traversing proto messages and transforming
4//! payload fields. It's used to apply `PayloadCodec` encode/decode operations at the
5//! boundary between the SDK and Core.
6
7use crate::{
8    data_converters::{PayloadCodec, SerializationContextData},
9    protos::temporal::api::common::v1::{Payload, Payloads},
10};
11use futures::future::BoxFuture;
12
13/// Represents a payload field in a proto message.
14/// Payloads within the same field may be processed together by the codec.
15pub struct PayloadField<'a> {
16    /// The fully-qualified field path (e.g.,
17    /// `coresdk.workflow_commands.ScheduleActivity.arguments`)
18    pub path: &'static str,
19    /// The payload data
20    pub data: PayloadFieldData<'a>,
21}
22
23/// The payload data within a field, varying by field type.
24pub enum PayloadFieldData<'a> {
25    /// A singular [Payload] field
26    Single(&'a mut Payload),
27    /// A repeated [Payload] field
28    Repeated(&'a mut Vec<Payload>),
29    /// A [Payloads] message field
30    Payloads(&'a mut Payloads),
31}
32
/// Async visitor for transforming payload fields.
///
/// A visitor receives one [`PayloadField`] per payload-bearing field of a message. Because the
/// field's [`PayloadFieldData`] holds `&mut` references, the visitor can replace the payloads in
/// place.
pub trait AsyncPayloadVisitor {
    /// Visit a payload field, potentially transforming it.
    ///
    /// The returned future may hold the mutable borrows of both the visitor and the field's data
    /// for `'a`.
    fn visit<'a>(&'a mut self, field: PayloadField<'a>) -> BoxFuture<'a, ()>;
}
38
/// Trait for messages that contain Payload fields (directly or transitively).
///
/// Generated via codegen for all relevant proto message types. [`Payload`] and [`Payloads`]
/// themselves are implemented by hand in this module.
pub trait PayloadVisitable: Send {
    /// Visit all payload fields in this message.
    ///
    /// The visitor is called once per field, receiving the field's payload(s). The visitor must
    /// be `Send` because it is held inside the returned (`Send`) [`BoxFuture`].
    fn visit_payloads_mut<'a>(
        &'a mut self,
        visitor: &'a mut (dyn AsyncPayloadVisitor + Send),
    ) -> BoxFuture<'a, ()>;
}
49
/// Returns `true` if `path` names the payload map inside a `SearchAttributes` message.
///
/// Search attributes must stay server-readable so the server can index them, so the codec
/// must never touch them.
fn is_search_attributes_path(path: &str) -> bool {
    // Every search attribute payload lives in `SearchAttributes.indexed_fields`, whatever
    // message embeds it, so a substring match covers all of them.
    const SEARCH_ATTRIBUTES_MARKER: &str = "SearchAttributes.indexed_fields";
    path.contains(SEARCH_ATTRIBUTES_MARKER)
}
56
57fn should_encode(path: &str) -> bool {
58    !is_search_attributes_path(path)
59}
60
61/// Visitor that encodes payloads using a codec.
62pub struct EncodeVisitor<'a> {
63    codec: &'a (dyn PayloadCodec + Send + Sync),
64    context: &'a SerializationContextData,
65}
66
67impl AsyncPayloadVisitor for EncodeVisitor<'_> {
68    fn visit<'a>(&'a mut self, field: PayloadField<'a>) -> BoxFuture<'a, ()> {
69        Box::pin(async move {
70            if !should_encode(field.path) {
71                return;
72            }
73            match field.data {
74                PayloadFieldData::Single(payload) => {
75                    let encoded = self
76                        .codec
77                        .encode(self.context, vec![std::mem::take(payload)])
78                        .await;
79                    if let Some(p) = encoded.into_iter().next() {
80                        *payload = p;
81                    }
82                }
83                PayloadFieldData::Repeated(payloads) => {
84                    *payloads = self
85                        .codec
86                        .encode(self.context, std::mem::take(payloads))
87                        .await;
88                }
89                PayloadFieldData::Payloads(payloads_msg) => {
90                    payloads_msg.payloads = self
91                        .codec
92                        .encode(self.context, std::mem::take(&mut payloads_msg.payloads))
93                        .await;
94                }
95            }
96        })
97    }
98}
99
100/// Visitor that decodes payloads using a codec.
101pub struct DecodeVisitor<'a> {
102    codec: &'a (dyn PayloadCodec + Send + Sync),
103    context: &'a SerializationContextData,
104}
105
106impl AsyncPayloadVisitor for DecodeVisitor<'_> {
107    fn visit<'a>(&'a mut self, field: PayloadField<'a>) -> BoxFuture<'a, ()> {
108        Box::pin(async move {
109            if !should_encode(field.path) {
110                return;
111            }
112            match field.data {
113                PayloadFieldData::Single(payload) => {
114                    let decoded = self
115                        .codec
116                        .decode(self.context, vec![std::mem::take(payload)])
117                        .await;
118                    if let Some(p) = decoded.into_iter().next() {
119                        *payload = p;
120                    }
121                }
122                PayloadFieldData::Repeated(payloads) => {
123                    *payloads = self
124                        .codec
125                        .decode(self.context, std::mem::take(payloads))
126                        .await;
127                }
128                PayloadFieldData::Payloads(payloads_msg) => {
129                    payloads_msg.payloads = self
130                        .codec
131                        .decode(self.context, std::mem::take(&mut payloads_msg.payloads))
132                        .await;
133                }
134            }
135        })
136    }
137}
138
139/// Encode all payloads in a message using the given codec.
140pub async fn encode_payloads<M: PayloadVisitable + Send>(
141    msg: &mut M,
142    codec: &(dyn PayloadCodec + Send + Sync),
143    context: &SerializationContextData,
144) {
145    let mut visitor = EncodeVisitor { codec, context };
146    msg.visit_payloads_mut(&mut visitor).await;
147}
148
149/// Decode all payloads in a message using the given codec.
150pub async fn decode_payloads<M: PayloadVisitable + Send>(
151    msg: &mut M,
152    codec: &(dyn PayloadCodec + Send + Sync),
153    context: &SerializationContextData,
154) {
155    let mut visitor = DecodeVisitor { codec, context };
156    msg.visit_payloads_mut(&mut visitor).await;
157}
158
159// Manual impl for Payload - visits itself as a single payload
160impl PayloadVisitable for Payload {
161    fn visit_payloads_mut<'a>(
162        &'a mut self,
163        visitor: &'a mut (dyn AsyncPayloadVisitor + Send),
164    ) -> BoxFuture<'a, ()> {
165        Box::pin(async move {
166            visitor
167                .visit(PayloadField {
168                    path: "temporal.api.common.v1.Payload",
169                    data: PayloadFieldData::Single(self),
170                })
171                .await;
172        })
173    }
174}
175
176// Manual impl for Payloads - visits itself as a Payloads field
177impl PayloadVisitable for Payloads {
178    fn visit_payloads_mut<'a>(
179        &'a mut self,
180        visitor: &'a mut (dyn AsyncPayloadVisitor + Send),
181    ) -> BoxFuture<'a, ()> {
182        Box::pin(async move {
183            visitor
184                .visit(PayloadField {
185                    path: "temporal.api.common.v1.Payloads",
186                    data: PayloadFieldData::Payloads(self),
187                })
188                .await;
189        })
190    }
191}
192
// Include the generated PayloadVisitable implementations for the proto message types.
// The file is read from Cargo's `OUT_DIR` at compile time (presumably written there by this
// crate's build script — confirm against build.rs).
include!(concat!(env!("OUT_DIR"), "/payload_visitor_impl.rs"));
195
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        data_converters::{DefaultFailureConverter, FailureConverter, PayloadConverter},
        error::{ApplicationFailure, OutgoingError, OutgoingWorkflowError},
        protos::{
            coresdk::{
                activity_result::{
                    ActivityResolution, Success, activity_resolution::Status as ActivityStatus,
                },
                workflow_activation::{
                    InitializeWorkflow, ResolveActivity, WorkflowActivation, WorkflowActivationJob,
                    workflow_activation_job::Variant,
                },
                workflow_commands::{
                    ContinueAsNewWorkflowExecution, ScheduleActivity, StartChildWorkflowExecution,
                    UpsertWorkflowSearchAttributes, WorkflowCommand,
                    workflow_command::Variant as CmdVariant,
                },
                workflow_completion::{
                    WorkflowActivationCompletion, workflow_activation_completion::Status,
                },
            },
            temporal::api::{
                common::v1::{Memo, SearchAttributes},
                failure::v1::failure::FailureInfo,
                workflow::v1::WorkflowExecutionInfo,
                workflowservice::v1::DescribeWorkflowExecutionResponse,
            },
        },
    };
    use futures::FutureExt;
    use std::collections::HashMap;

    /// Codec that tags every payload it sees with an "encoded" / "decoded" metadata key.
    struct MarkingCodec;
    impl PayloadCodec for MarkingCodec {
        fn encode(
            &self,
            _: &SerializationContextData,
            payloads: Vec<Payload>,
        ) -> BoxFuture<'static, Vec<Payload>> {
            async move {
                payloads
                    .into_iter()
                    .map(|mut p| {
                        p.metadata.insert("encoded".to_string(), b"true".to_vec());
                        p
                    })
                    .collect()
            }
            .boxed()
        }

        fn decode(
            &self,
            _: &SerializationContextData,
            payloads: Vec<Payload>,
        ) -> BoxFuture<'static, Vec<Payload>> {
            async move {
                payloads
                    .into_iter()
                    .map(|mut p| {
                        p.metadata.insert("decoded".to_string(), b"true".to_vec());
                        p
                    })
                    .collect()
            }
            .boxed()
        }
    }

    /// Visitor that only records the path of every field it is shown.
    struct PathRecordingVisitor {
        visited_paths: Vec<String>,
    }
    impl PathRecordingVisitor {
        fn new() -> Self {
            Self {
                visited_paths: Vec::new(),
            }
        }

        fn paths(&self) -> Vec<String> {
            self.visited_paths.clone()
        }
    }

    impl AsyncPayloadVisitor for PathRecordingVisitor {
        fn visit<'a>(&'a mut self, field: PayloadField<'a>) -> BoxFuture<'a, ()> {
            let path = field.path.to_string();
            self.visited_paths.push(path);
            async move {}.boxed()
        }
    }

    fn make_payload(data: &str) -> Payload {
        Payload {
            metadata: HashMap::new(),
            data: data.as_bytes().to_vec(),
            external_payloads: vec![],
        }
    }

    fn is_encoded(p: &Payload) -> bool {
        p.metadata.contains_key("encoded")
    }

    fn is_decoded(p: &Payload) -> bool {
        p.metadata.contains_key("decoded")
    }

    #[tokio::test]
    async fn test_direct_visitor_records_paths() {
        let mut activation = WorkflowActivation {
            run_id: "test-run".to_string(),
            jobs: vec![WorkflowActivationJob {
                variant: Some(Variant::InitializeWorkflow(InitializeWorkflow {
                    workflow_type: "test-workflow".to_string(),
                    arguments: vec![make_payload("input1")],
                    headers: {
                        let mut h = HashMap::new();
                        h.insert("header-key".to_string(), make_payload("header-value"));
                        h
                    },
                    memo: Some(Memo {
                        fields: {
                            let mut m = HashMap::new();
                            m.insert("memo-key".to_string(), make_payload("memo-value"));
                            m
                        },
                    }),
                    ..Default::default()
                })),
            }],
            ..Default::default()
        };

        let mut visitor = PathRecordingVisitor::new();
        activation.visit_payloads_mut(&mut visitor).await;

        let paths = visitor.paths();
        assert!(
            paths
                .iter()
                .any(|p| p.contains("InitializeWorkflow.arguments")),
            "should visit arguments, got: {:?}",
            paths
        );
        assert!(
            paths
                .iter()
                .any(|p| p.contains("InitializeWorkflow.headers")),
            "should visit headers, got: {:?}",
            paths
        );
        assert!(
            paths.iter().any(|p| p.contains("Memo.fields")),
            "should visit memo fields, got: {:?}",
            paths
        );
    }

    #[tokio::test]
    async fn test_encode_workflow_activation_completion_with_schedule_activity() {
        let mut completion = WorkflowActivationCompletion {
            run_id: "test-run".to_string(),
            status: Some(Status::Successful(
                crate::protos::coresdk::workflow_completion::Success {
                    commands: vec![WorkflowCommand {
                        variant: Some(CmdVariant::ScheduleActivity(ScheduleActivity {
                            activity_id: "act-1".to_string(),
                            activity_type: "test-activity".to_string(),
                            arguments: vec![make_payload("arg1"), make_payload("arg2")],
                            headers: {
                                let mut h = HashMap::new();
                                h.insert("header-key".to_string(), make_payload("header-value"));
                                h
                            },
                            ..Default::default()
                        })),
                        user_metadata: None,
                    }],
                    ..Default::default()
                },
            )),
        };

        encode_payloads(
            &mut completion,
            &MarkingCodec,
            &SerializationContextData::Workflow,
        )
        .await;

        let status = completion.status.as_ref().unwrap();
        let Status::Successful(success) = status else {
            panic!("Expected successful status")
        };
        let cmd = &success.commands[0];
        let CmdVariant::ScheduleActivity(schedule) = cmd.variant.as_ref().unwrap() else {
            panic!("Expected ScheduleActivity")
        };

        assert!(is_encoded(&schedule.arguments[0]), "arg1 should be encoded");
        assert!(is_encoded(&schedule.arguments[1]), "arg2 should be encoded");
        assert!(
            is_encoded(schedule.headers.get("header-key").unwrap()),
            "header should be encoded"
        );
    }

    #[tokio::test]
    async fn test_decode_workflow_activation_with_initialize() {
        let mut activation = WorkflowActivation {
            run_id: "test-run".to_string(),
            jobs: vec![WorkflowActivationJob {
                variant: Some(Variant::InitializeWorkflow(InitializeWorkflow {
                    workflow_type: "test-workflow".to_string(),
                    arguments: vec![make_payload("input1"), make_payload("input2")],
                    headers: {
                        let mut h = HashMap::new();
                        h.insert("header-key".to_string(), make_payload("header-value"));
                        h
                    },
                    ..Default::default()
                })),
            }],
            ..Default::default()
        };

        decode_payloads(
            &mut activation,
            &MarkingCodec,
            &SerializationContextData::Workflow,
        )
        .await;

        let job = &activation.jobs[0];
        let Variant::InitializeWorkflow(init) = job.variant.as_ref().unwrap() else {
            panic!("Expected InitializeWorkflow")
        };

        assert!(is_decoded(&init.arguments[0]), "arg1 should be decoded");
        assert!(is_decoded(&init.arguments[1]), "arg2 should be decoded");
        assert!(
            is_decoded(init.headers.get("header-key").unwrap()),
            "header should be decoded"
        );
    }

    #[tokio::test]
    async fn test_decode_workflow_activation_with_resolve_activity() {
        let mut activation = WorkflowActivation {
            run_id: "test-run".to_string(),
            jobs: vec![WorkflowActivationJob {
                variant: Some(Variant::ResolveActivity(ResolveActivity {
                    seq: 1,
                    result: Some(ActivityResolution {
                        status: Some(ActivityStatus::Completed(Success {
                            result: Some(make_payload("activity-result")),
                        })),
                    }),
                    ..Default::default()
                })),
            }],
            ..Default::default()
        };

        decode_payloads(
            &mut activation,
            &MarkingCodec,
            &SerializationContextData::Workflow,
        )
        .await;

        let job = &activation.jobs[0];
        let Variant::ResolveActivity(resolve) = job.variant.as_ref().unwrap() else {
            panic!("Expected ResolveActivity")
        };
        let ActivityStatus::Completed(success) =
            resolve.result.as_ref().unwrap().status.as_ref().unwrap()
        else {
            panic!("Expected Completed status")
        };

        assert!(
            is_decoded(success.result.as_ref().unwrap()),
            "activity result should be decoded"
        );
    }

    #[tokio::test]
    async fn test_search_attributes_skipped_on_encode() {
        // Test that search attributes are NOT encoded (they must remain server-readable)
        let mut completion = WorkflowActivationCompletion {
            run_id: "test-run".to_string(),
            status: Some(Status::Successful(
                crate::protos::coresdk::workflow_completion::Success {
                    commands: vec![
                        // UpsertWorkflowSearchAttributes command
                        WorkflowCommand {
                            variant: Some(CmdVariant::UpsertWorkflowSearchAttributes(
                                UpsertWorkflowSearchAttributes {
                                    search_attributes: Some(SearchAttributes {
                                        indexed_fields: {
                                            let mut sa = HashMap::new();
                                            sa.insert(
                                                "CustomField".to_string(),
                                                make_payload("search-value"),
                                            );
                                            sa
                                        },
                                    }),
                                },
                            )),
                            user_metadata: None,
                        },
                        // ContinueAsNewWorkflowExecution command
                        WorkflowCommand {
                            variant: Some(CmdVariant::ContinueAsNewWorkflowExecution(
                                ContinueAsNewWorkflowExecution {
                                    arguments: vec![make_payload("continue-arg")],
                                    search_attributes: Some(SearchAttributes {
                                        indexed_fields: {
                                            let mut sa = HashMap::new();
                                            sa.insert(
                                                "CustomField".to_string(),
                                                make_payload("continue-search-value"),
                                            );
                                            sa
                                        },
                                    }),
                                    ..Default::default()
                                },
                            )),
                            user_metadata: None,
                        },
                        // StartChildWorkflowExecution command
                        WorkflowCommand {
                            variant: Some(CmdVariant::StartChildWorkflowExecution(
                                StartChildWorkflowExecution {
                                    seq: 1,
                                    workflow_type: "child-workflow".to_string(),
                                    input: vec![make_payload("child-arg")],
                                    search_attributes: Some(SearchAttributes {
                                        indexed_fields: {
                                            let mut sa = HashMap::new();
                                            sa.insert(
                                                "CustomField".to_string(),
                                                make_payload("child-search-value"),
                                            );
                                            sa
                                        },
                                    }),
                                    ..Default::default()
                                },
                            )),
                            user_metadata: None,
                        },
                    ],
                    ..Default::default()
                },
            )),
        };

        encode_payloads(
            &mut completion,
            &MarkingCodec,
            &SerializationContextData::Workflow,
        )
        .await;

        let status = completion.status.as_ref().unwrap();
        let Status::Successful(success) = status else {
            panic!("Expected successful status")
        };

        // UpsertWorkflowSearchAttributes - search attributes should NOT be encoded
        let CmdVariant::UpsertWorkflowSearchAttributes(upsert) =
            success.commands[0].variant.as_ref().unwrap()
        else {
            panic!("Expected UpsertWorkflowSearchAttributes")
        };
        let sa = upsert.search_attributes.as_ref().unwrap();
        assert!(
            !is_encoded(sa.indexed_fields.get("CustomField").unwrap()),
            "search attributes should NOT be encoded"
        );

        // ContinueAsNewWorkflowExecution - arguments encoded, search attributes NOT
        let CmdVariant::ContinueAsNewWorkflowExecution(continue_as_new) =
            success.commands[1].variant.as_ref().unwrap()
        else {
            panic!("Expected ContinueAsNewWorkflowExecution")
        };
        assert!(
            is_encoded(&continue_as_new.arguments[0]),
            "arguments should be encoded"
        );
        let sa = continue_as_new.search_attributes.as_ref().unwrap();
        assert!(
            !is_encoded(sa.indexed_fields.get("CustomField").unwrap()),
            "search attributes should NOT be encoded"
        );

        // StartChildWorkflowExecution - input encoded, search attributes NOT
        let CmdVariant::StartChildWorkflowExecution(start_child) =
            success.commands[2].variant.as_ref().unwrap()
        else {
            panic!("Expected StartChildWorkflowExecution")
        };
        assert!(is_encoded(&start_child.input[0]), "input should be encoded");
        let sa = start_child.search_attributes.as_ref().unwrap();
        assert!(
            !is_encoded(sa.indexed_fields.get("CustomField").unwrap()),
            "search attributes should NOT be encoded"
        );
    }

    #[tokio::test]
    async fn test_search_attributes_skipped_on_decode() {
        let mut response = DescribeWorkflowExecutionResponse {
            workflow_execution_info: Some(WorkflowExecutionInfo {
                memo: Some(Memo {
                    fields: {
                        let mut memo = HashMap::new();
                        memo.insert("tracked".to_string(), make_payload("memo-value"));
                        memo
                    },
                }),
                search_attributes: Some(SearchAttributes {
                    indexed_fields: {
                        let mut sa = HashMap::new();
                        sa.insert("CustomField".to_string(), make_payload("search-value"));
                        sa
                    },
                }),
                ..Default::default()
            }),
            ..Default::default()
        };

        decode_payloads(
            &mut response,
            &MarkingCodec,
            &SerializationContextData::Workflow,
        )
        .await;

        let info = response.workflow_execution_info.as_ref().unwrap();
        assert!(
            is_decoded(info.memo.as_ref().unwrap().fields.get("tracked").unwrap()),
            "memo should be decoded"
        );
        assert!(
            !is_decoded(
                info.search_attributes
                    .as_ref()
                    .unwrap()
                    .indexed_fields
                    .get("CustomField")
                    .unwrap()
            ),
            "search attributes should NOT be decoded"
        );
    }

    #[tokio::test]
    async fn test_encode_single_payload() {
        let mut payload = make_payload("test-data");

        encode_payloads(
            &mut payload,
            &MarkingCodec,
            &SerializationContextData::Workflow,
        )
        .await;

        assert!(is_encoded(&payload), "single payload should be encoded");
    }

    #[tokio::test]
    async fn test_decode_single_payload() {
        let mut payload = make_payload("test-data");

        decode_payloads(
            &mut payload,
            &MarkingCodec,
            &SerializationContextData::Workflow,
        )
        .await;

        assert!(is_decoded(&payload), "single payload should be decoded");
    }

    #[tokio::test]
    async fn test_encode_payloads_message() {
        let mut payloads = Payloads {
            payloads: vec![make_payload("p1"), make_payload("p2"), make_payload("p3")],
        };

        encode_payloads(
            &mut payloads,
            &MarkingCodec,
            &SerializationContextData::Workflow,
        )
        .await;

        for (i, p) in payloads.payloads.iter().enumerate() {
            assert!(is_encoded(p), "payload {} should be encoded", i);
        }
    }

    #[tokio::test]
    async fn test_decode_payloads_message() {
        let mut payloads = Payloads {
            payloads: vec![make_payload("p1"), make_payload("p2")],
        };

        decode_payloads(
            &mut payloads,
            &MarkingCodec,
            &SerializationContextData::Workflow,
        )
        .await;

        assert_eq!(payloads.payloads.len(), 2, "payload count must be preserved");
        assert_eq!(payloads.payloads[0].data, b"p1".to_vec(), "order must be preserved");
        assert_eq!(payloads.payloads[1].data, b"p2".to_vec(), "order must be preserved");
        for (i, p) in payloads.payloads.iter().enumerate() {
            assert!(is_decoded(p), "payload {} should be decoded", i);
            assert!(!is_encoded(p), "payload {} should not be encoded", i);
        }
    }

    #[tokio::test]
    async fn test_encode_empty_payloads_message() {
        let mut payloads = Payloads { payloads: vec![] };

        encode_payloads(
            &mut payloads,
            &MarkingCodec,
            &SerializationContextData::Workflow,
        )
        .await;

        assert!(payloads.payloads.is_empty(), "empty message should stay empty");
    }

    #[test]
    fn test_search_attributes_path_gate() {
        let sa_path = "temporal.api.common.v1.SearchAttributes.indexed_fields";
        assert!(is_search_attributes_path(sa_path));
        assert!(!should_encode(sa_path));
        assert!(should_encode("temporal.api.common.v1.Memo.fields"));
        assert!(should_encode("temporal.api.common.v1.Payload"));
        assert!(should_encode("temporal.api.common.v1.Payloads"));
        assert!(should_encode(
            "coresdk.workflow_commands.ScheduleActivity.arguments"
        ));
    }

    #[tokio::test]
    async fn test_encode_failure_encodes_application_failure_details() {
        let mut failure = DefaultFailureConverter.to_failure(
            OutgoingError::Workflow(OutgoingWorkflowError::Application(Box::new(
                ApplicationFailure::builder(anyhow::anyhow!("app boom"))
                    .details(crate::data_converters::RawValue::new(vec![make_payload(
                        "detail",
                    )]))
                    .build(),
            ))),
            &PayloadConverter::default(),
            &SerializationContextData::Workflow,
        );

        encode_payloads(
            &mut failure,
            &MarkingCodec,
            &SerializationContextData::Workflow,
        )
        .await;

        let Some(FailureInfo::ApplicationFailureInfo(info)) = failure.failure_info else {
            panic!("expected application failure info")
        };
        assert!(is_encoded(&info.details.unwrap().payloads[0]));
    }
}