// turul_a2a_types/message.rs

1use serde::{Deserialize, Serialize};
2use turul_a2a_proto as pb;
3
4/// Ergonomic wrapper over proto `Part`.
5#[derive(Debug, Clone)]
6#[non_exhaustive]
7pub struct Part {
8    pub(crate) inner: pb::Part,
9}
10
11impl Part {
12    pub fn text(text: impl Into<String>) -> Self {
13        Self {
14            inner: pb::Part {
15                content: Some(pb::part::Content::Text(text.into())),
16                metadata: None,
17                filename: String::new(),
18                media_type: "text/plain".to_string(),
19            },
20        }
21    }
22
23    pub fn url(url: impl Into<String>, media_type: impl Into<String>) -> Self {
24        Self {
25            inner: pb::Part {
26                content: Some(pb::part::Content::Url(url.into())),
27                metadata: None,
28                filename: String::new(),
29                media_type: media_type.into(),
30            },
31        }
32    }
33
34    pub fn raw(data: Vec<u8>, media_type: impl Into<String>) -> Self {
35        Self {
36            inner: pb::Part {
37                content: Some(pb::part::Content::Raw(data)),
38                metadata: None,
39                filename: String::new(),
40                media_type: media_type.into(),
41            },
42        }
43    }
44
45    pub fn data(value: serde_json::Value) -> Self {
46        Self {
47            inner: pb::Part {
48                content: Some(pb::part::Content::Data(json_to_proto_value(value))),
49                metadata: None,
50                filename: String::new(),
51                media_type: "application/json".to_string(),
52            },
53        }
54    }
55
56    pub fn with_filename(mut self, filename: impl Into<String>) -> Self {
57        self.inner.filename = filename.into();
58        self
59    }
60
61    pub fn with_media_type(mut self, media_type: impl Into<String>) -> Self {
62        self.inner.media_type = media_type.into();
63        self
64    }
65
66    /// Returns the text content if this is a text part.
67    pub fn as_text(&self) -> Option<&str> {
68        match &self.inner.content {
69            Some(pb::part::Content::Text(t)) => Some(t.as_str()),
70            _ => None,
71        }
72    }
73
74    /// Returns the URL if this is a URL part.
75    pub fn as_url(&self) -> Option<&str> {
76        match &self.inner.content {
77            Some(pb::part::Content::Url(u)) => Some(u.as_str()),
78            _ => None,
79        }
80    }
81
82    /// Returns the raw bytes if this is a raw/binary part.
83    pub fn as_raw(&self) -> Option<&[u8]> {
84        match &self.inner.content {
85            Some(pb::part::Content::Raw(r)) => Some(r.as_slice()),
86            _ => None,
87        }
88    }
89
90    /// Raw JSON view of a Data part. No number normalization.
91    /// Returns `None` if this is not a Data part.
92    pub fn as_data(&self) -> Option<serde_json::Value> {
93        match &self.inner.content {
94            Some(pb::part::Content::Data(proto_struct)) => serde_json::to_value(proto_struct).ok(),
95            _ => None,
96        }
97    }
98
99    /// Deserialize a Data part into `T`, normalizing proto f64 integers first.
100    ///
101    /// Protobuf `Value` uses f64 for all numbers, so `25544` becomes `25544.0`.
102    /// This normalizes whole-number f64s back to integers before deserializing,
103    /// so `u32`/`i32`/`u8` fields work correctly.
104    ///
105    /// Returns `None` if not a Data part. Returns `Err` if deserialization fails.
106    pub fn parse_data<T: serde::de::DeserializeOwned>(
107        &self,
108    ) -> Option<Result<T, crate::error::A2aTypeError>> {
109        let json = self.as_data()?;
110        let normalized = normalize_proto_numbers_for_deser(json);
111        Some(
112            serde_json::from_value(normalized)
113                .map_err(|e| crate::error::A2aTypeError::Deserialization(e.to_string())),
114        )
115    }
116
117    pub fn as_proto(&self) -> &pb::Part {
118        &self.inner
119    }
120
121    pub fn into_proto(self) -> pb::Part {
122        self.inner
123    }
124}
125
126/// Wrap a proto Part without validation (content may be None).
127/// Use `TryFrom` when you need to validate content is present.
128impl From<pb::Part> for Part {
129    fn from(inner: pb::Part) -> Self {
130        Self { inner }
131    }
132}
133
134impl From<Part> for pb::Part {
135    fn from(part: Part) -> Self {
136        part.inner
137    }
138}
139
140impl Serialize for Part {
141    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
142        self.inner.serialize(serializer)
143    }
144}
145
146impl<'de> Deserialize<'de> for Part {
147    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
148        pb::Part::deserialize(deserializer).map(Self::from)
149    }
150}
151
152/// Ergonomic wrapper over proto `Message`.
153#[derive(Debug, Clone)]
154#[non_exhaustive]
155pub struct Message {
156    pub(crate) inner: pb::Message,
157}
158
/// Who sent a message.
///
/// Mirrors the `User` and `Agent` values of the proto role enum; the proto
/// `Unspecified` value has no counterpart here.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum Role {
    /// Converts to and from the proto `User` role.
    User,
    /// Converts to and from the proto `Agent` role.
    Agent,
}
166
167impl From<Role> for pb::Role {
168    fn from(role: Role) -> Self {
169        match role {
170            Role::User => pb::Role::User,
171            Role::Agent => pb::Role::Agent,
172        }
173    }
174}
175
176impl TryFrom<pb::Role> for Role {
177    type Error = crate::error::A2aTypeError;
178
179    fn try_from(value: pb::Role) -> Result<Self, Self::Error> {
180        match value {
181            pb::Role::User => Ok(Self::User),
182            pb::Role::Agent => Ok(Self::Agent),
183            pb::Role::Unspecified => Err(crate::error::A2aTypeError::InvalidState),
184        }
185    }
186}
187
188impl Message {
189    pub fn new(message_id: impl Into<String>, role: Role, parts: Vec<Part>) -> Self {
190        Self {
191            inner: pb::Message {
192                message_id: message_id.into(),
193                role: pb::Role::from(role).into(),
194                parts: parts.into_iter().map(|p| p.inner).collect(),
195                context_id: String::new(),
196                task_id: String::new(),
197                metadata: None,
198                extensions: vec![],
199                reference_task_ids: vec![],
200            },
201        }
202    }
203
204    pub fn with_context_id(mut self, context_id: impl Into<String>) -> Self {
205        self.inner.context_id = context_id.into();
206        self
207    }
208
209    pub fn with_task_id(mut self, task_id: impl Into<String>) -> Self {
210        self.inner.task_id = task_id.into();
211        self
212    }
213
214    pub fn message_id(&self) -> &str {
215        &self.inner.message_id
216    }
217
218    /// Context id, or empty string if unset. Proto default for the
219    /// `context_id` field is the empty string, not `null`.
220    pub fn context_id(&self) -> &str {
221        &self.inner.context_id
222    }
223
224    /// Task id the message is bound to, or empty string if unset.
225    /// Callers making a fresh request leave this empty; continuations
226    /// set it to the existing task's id.
227    pub fn task_id(&self) -> &str {
228        &self.inner.task_id
229    }
230
231    /// Borrow the raw `Message.metadata` proto struct. Prefer
232    /// [`Self::metadata_keys`] for the common "which correlation fields
233    /// arrived" check; drop to this accessor when you need to read
234    /// specific values.
235    pub fn metadata(&self) -> Option<&pb::pbjson_types::Struct> {
236        self.inner.metadata.as_ref()
237    }
238
239    // `pb::pbjson_types` is re-exported from `turul_a2a_proto` so
240    // adopters don't need to depend on `pbjson_types` directly.
241
242    /// Sorted list of keys present on `Message.metadata`. Returns an
243    /// empty vec if `metadata` is unset. Convenient for log lines and
244    /// demo-level "which correlation fields did the caller supply"
245    /// inspection without exposing the values themselves.
246    pub fn metadata_keys(&self) -> Vec<String> {
247        let Some(s) = self.inner.metadata.as_ref() else {
248            return Vec::new();
249        };
250        let mut keys: Vec<String> = s.fields.keys().cloned().collect();
251        keys.sort();
252        keys
253    }
254
255    /// Returns individual text parts, preserving part boundaries.
256    /// This is the primary safe accessor for message text content.
257    /// Callers decide how to combine parts.
258    pub fn text_parts(&self) -> Vec<&str> {
259        self.inner
260            .parts
261            .iter()
262            .filter_map(|p| match &p.content {
263                Some(pb::part::Content::Text(t)) => Some(t.as_str()),
264                _ => None,
265            })
266            .collect()
267    }
268
269    /// Convenience: joins all text parts with a single space.
270    /// Use `text_parts()` when part boundaries matter (e.g., multi-part prompts).
271    pub fn joined_text(&self) -> String {
272        self.text_parts().join(" ")
273    }
274
275    /// Raw JSON data parts (no number normalization).
276    pub fn data_parts(&self) -> Vec<serde_json::Value> {
277        self.inner
278            .parts
279            .iter()
280            .filter_map(|p| match &p.content {
281                Some(pb::part::Content::Data(proto_struct)) => {
282                    serde_json::to_value(proto_struct).ok()
283                }
284                _ => None,
285            })
286            .collect()
287    }
288
289    /// Deserialize the first Data part into `T`, normalizing proto f64 integers.
290    ///
291    /// Returns `None` if no Data part exists. Returns `Err` if deserialization fails.
292    pub fn parse_first_data<T: serde::de::DeserializeOwned>(
293        &self,
294    ) -> Option<Result<T, crate::error::A2aTypeError>> {
295        for part in &self.inner.parts {
296            if let Some(pb::part::Content::Data(proto_struct)) = &part.content {
297                if let Ok(json) = serde_json::to_value(proto_struct) {
298                    let normalized = normalize_proto_numbers_for_deser(json);
299                    return Some(
300                        serde_json::from_value(normalized).map_err(|e| {
301                            crate::error::A2aTypeError::Deserialization(e.to_string())
302                        }),
303                    );
304                }
305            }
306        }
307        None
308    }
309
310    /// Deserialize from the first Data part, falling back to parsing the first
311    /// Text part as JSON.
312    ///
313    /// A2A protocol v0.3 clients (e.g., Strands `A2AAgent` via a2a-sdk) send
314    /// typed JSON as a text part (`{"kind":"text","text":"{...}"}`), while
315    /// Turul clients send it as a data part. This method handles both.
316    ///
317    /// Preference order: Data part (proto struct) → Text part (JSON string).
318    /// Returns `None` if no Data or parseable Text part exists.
319    pub fn parse_first_data_or_text<T: serde::de::DeserializeOwned>(
320        &self,
321    ) -> Option<Result<T, crate::error::A2aTypeError>> {
322        // Try Data part first (Turul clients)
323        if let Some(result) = self.parse_first_data() {
324            return Some(result);
325        }
326
327        // Fall back to first Text part parsed as JSON (a2a-sdk / Strands clients)
328        for part in &self.inner.parts {
329            if let Some(pb::part::Content::Text(text)) = &part.content {
330                if text.trim_start().starts_with('{') {
331                    return Some(
332                        serde_json::from_str(text).map_err(|e| {
333                            crate::error::A2aTypeError::Deserialization(e.to_string())
334                        }),
335                    );
336                }
337            }
338        }
339
340        None
341    }
342
343    pub fn as_proto(&self) -> &pb::Message {
344        &self.inner
345    }
346
347    pub fn into_proto(self) -> pb::Message {
348        self.inner
349    }
350}
351
352impl TryFrom<pb::Message> for Message {
353    type Error = crate::error::A2aTypeError;
354
355    fn try_from(inner: pb::Message) -> Result<Self, Self::Error> {
356        // Validate role is not UNSPECIFIED
357        let role_val = pb::Role::try_from(inner.role).unwrap_or(pb::Role::Unspecified);
358        if role_val == pb::Role::Unspecified {
359            return Err(crate::error::A2aTypeError::MissingField("role"));
360        }
361        Ok(Self { inner })
362    }
363}
364
365impl From<Message> for pb::Message {
366    fn from(msg: Message) -> Self {
367        msg.inner
368    }
369}
370
371impl Serialize for Message {
372    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
373        self.inner.serialize(serializer)
374    }
375}
376
377impl<'de> Deserialize<'de> for Message {
378    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
379        let proto = pb::Message::deserialize(deserializer)?;
380        Message::try_from(proto).map_err(serde::de::Error::custom)
381    }
382}
383
384/// Normalize proto f64 whole numbers to integers for typed deserialization.
385///
386/// Protobuf `Value.number_value` is always f64. This converts finite whole-number
387/// f64s back to JSON integers so serde can deserialize into u32/i32/u8/etc.
388/// Fractional values and out-of-range numbers are left unchanged.
389///
390/// Internal — callers use `Part::parse_data::<T>()` or `Message::parse_first_data::<T>()`.
391fn normalize_proto_numbers_for_deser(value: serde_json::Value) -> serde_json::Value {
392    match value {
393        serde_json::Value::Number(n) => {
394            if let Some(f) = n.as_f64() {
395                if f.is_finite() && f.fract() == 0.0 {
396                    if f >= 0.0 && f <= u64::MAX as f64 {
397                        return serde_json::Value::Number((f as u64).into());
398                    } else if f >= i64::MIN as f64 && f <= i64::MAX as f64 {
399                        return serde_json::Value::Number((f as i64).into());
400                    }
401                }
402            }
403            serde_json::Value::Number(n)
404        }
405        serde_json::Value::Array(arr) => serde_json::Value::Array(
406            arr.into_iter()
407                .map(normalize_proto_numbers_for_deser)
408                .collect(),
409        ),
410        serde_json::Value::Object(map) => serde_json::Value::Object(
411            map.into_iter()
412                .map(|(k, v)| (k, normalize_proto_numbers_for_deser(v)))
413                .collect(),
414        ),
415        other => other,
416    }
417}
418
419/// Convert serde_json::Value to pbjson_types::Value.
420fn json_to_proto_value(value: serde_json::Value) -> pbjson_types::Value {
421    match value {
422        serde_json::Value::Null => pbjson_types::Value {
423            kind: Some(pbjson_types::value::Kind::NullValue(0)),
424        },
425        serde_json::Value::Bool(b) => pbjson_types::Value {
426            kind: Some(pbjson_types::value::Kind::BoolValue(b)),
427        },
428        serde_json::Value::Number(n) => pbjson_types::Value {
429            kind: Some(pbjson_types::value::Kind::NumberValue(
430                n.as_f64().unwrap_or(0.0),
431            )),
432        },
433        serde_json::Value::String(s) => pbjson_types::Value {
434            kind: Some(pbjson_types::value::Kind::StringValue(s)),
435        },
436        serde_json::Value::Array(arr) => pbjson_types::Value {
437            kind: Some(pbjson_types::value::Kind::ListValue(
438                pbjson_types::ListValue {
439                    values: arr.into_iter().map(json_to_proto_value).collect(),
440                },
441            )),
442        },
443        serde_json::Value::Object(map) => pbjson_types::Value {
444            kind: Some(pbjson_types::value::Kind::StructValue(
445                pbjson_types::Struct {
446                    fields: map
447                        .into_iter()
448                        .map(|(k, v)| (k, json_to_proto_value(v)))
449                        .collect(),
450                },
451            )),
452        },
453    }
454}
455
#[cfg(test)]
mod tests {
    use super::*;

    // =========================================================
    // Part constructors and accessors
    // =========================================================

    #[test]
    fn part_text_constructor() {
        let part = Part::text("hello");
        let proto = part.as_proto();
        assert!(matches!(proto.content, Some(pb::part::Content::Text(ref s)) if s == "hello"));
        assert_eq!(proto.media_type, "text/plain");
    }

    #[test]
    fn part_url_constructor() {
        let part =
            Part::url("https://example.com/file.pdf", "application/pdf").with_filename("file.pdf");
        let proto = part.as_proto();
        assert!(
            matches!(proto.content, Some(pb::part::Content::Url(ref u)) if u == "https://example.com/file.pdf")
        );
        assert_eq!(proto.filename, "file.pdf");
        assert_eq!(proto.media_type, "application/pdf");
    }

    #[test]
    fn part_raw_constructor() {
        let part = Part::raw(vec![0x48, 0x65], "image/png");
        let proto = part.as_proto();
        assert!(matches!(proto.content, Some(pb::part::Content::Raw(ref b)) if b == &[0x48, 0x65]));
        assert_eq!(proto.media_type, "image/png");
    }

    #[test]
    fn part_data_constructor() {
        let part = Part::data(serde_json::json!({"key": "val"}));
        let proto = part.as_proto();
        assert!(matches!(proto.content, Some(pb::part::Content::Data(_))));
        assert_eq!(proto.media_type, "application/json");
    }

    #[test]
    fn part_with_media_type_overrides_default() {
        let part = Part::text("# title").with_media_type("text/markdown");
        assert_eq!(part.as_proto().media_type, "text/markdown");
    }

    #[test]
    fn part_accessors_match_only_their_own_variant() {
        let text = Part::text("hello");
        assert_eq!(text.as_text(), Some("hello"));
        assert_eq!(text.as_url(), None);
        assert_eq!(text.as_raw(), None);
        assert!(text.as_data().is_none());

        let url = Part::url("https://example.com", "text/html");
        assert_eq!(url.as_url(), Some("https://example.com"));
        assert_eq!(url.as_text(), None);

        let raw = Part::raw(vec![1, 2, 3], "application/octet-stream");
        assert_eq!(raw.as_raw(), Some(&[1u8, 2, 3][..]));
        assert_eq!(raw.as_text(), None);
    }

    #[test]
    fn part_serde_round_trip() {
        let part = Part::text("round-trip");
        let json = serde_json::to_string(&part).unwrap();
        let back: Part = serde_json::from_str(&json).unwrap();
        assert!(matches!(
            back.as_proto().content,
            Some(pb::part::Content::Text(ref s)) if s == "round-trip"
        ));
    }

    // =========================================================
    // Message construction, accessors and conversions
    // =========================================================

    #[test]
    fn message_constructor() {
        let msg = Message::new("msg-1", Role::User, vec![Part::text("hello")]);
        assert_eq!(msg.message_id(), "msg-1");
        assert_eq!(msg.as_proto().role, i32::from(pb::Role::User));
        assert_eq!(msg.as_proto().parts.len(), 1);
        // Ids default to the empty string and metadata to unset.
        assert_eq!(msg.context_id(), "");
        assert_eq!(msg.task_id(), "");
        assert!(msg.metadata().is_none());
        assert!(msg.metadata_keys().is_empty());
    }

    #[test]
    fn message_with_context_and_task() {
        let msg = Message::new("msg-2", Role::Agent, vec![])
            .with_context_id("ctx-1")
            .with_task_id("task-1");
        assert_eq!(msg.as_proto().context_id, "ctx-1");
        assert_eq!(msg.as_proto().task_id, "task-1");
        assert_eq!(msg.context_id(), "ctx-1");
        assert_eq!(msg.task_id(), "task-1");
    }

    #[test]
    fn message_text_parts_and_joined_text() {
        let msg = Message::new(
            "m-1",
            Role::User,
            vec![
                Part::text("a"),
                Part::raw(vec![1], "application/octet-stream"),
                Part::text("b"),
            ],
        );
        // Non-text parts are skipped; order is preserved.
        assert_eq!(msg.text_parts(), vec!["a", "b"]);
        assert_eq!(msg.joined_text(), "a b");
    }

    #[test]
    fn message_data_parts_collects_only_data() {
        let msg = Message::new(
            "m-1",
            Role::User,
            vec![
                Part::text("not data"),
                Part::data(serde_json::json!({"k": "v"})),
            ],
        );
        assert_eq!(msg.data_parts().len(), 1);
    }

    #[test]
    fn message_serde_round_trip() {
        let msg = Message::new("msg-rt", Role::User, vec![Part::text("hi")]);
        let json = serde_json::to_string(&msg).unwrap();
        let back: Message = serde_json::from_str(&json).unwrap();
        assert_eq!(back.message_id(), "msg-rt");
    }

    #[test]
    fn role_conversions() {
        assert_eq!(pb::Role::from(Role::User), pb::Role::User);
        assert_eq!(pb::Role::from(Role::Agent), pb::Role::Agent);
        assert_eq!(Role::try_from(pb::Role::User).unwrap(), Role::User);
        assert_eq!(Role::try_from(pb::Role::Agent).unwrap(), Role::Agent);
        assert!(Role::try_from(pb::Role::Unspecified).is_err());
    }

    #[test]
    fn message_try_from_proto_rejects_unspecified_role() {
        let proto_msg = pb::Message {
            message_id: "m-1".to_string(),
            role: pb::Role::Unspecified.into(),
            parts: vec![],
            context_id: String::new(),
            task_id: String::new(),
            metadata: None,
            extensions: vec![],
            reference_task_ids: vec![],
        };
        assert!(Message::try_from(proto_msg).is_err());
    }

    #[test]
    fn message_try_from_proto_accepts_valid_role() {
        let proto_msg = pb::Message {
            message_id: "m-2".to_string(),
            role: pb::Role::User.into(),
            parts: vec![],
            context_id: String::new(),
            task_id: String::new(),
            metadata: None,
            extensions: vec![],
            reference_task_ids: vec![],
        };
        let msg = Message::try_from(proto_msg).unwrap();
        assert_eq!(msg.message_id(), "m-2");
    }

    #[test]
    fn message_json_deserialization_rejects_unspecified_role() {
        let json = r#"{"messageId":"m-bad","role":"ROLE_UNSPECIFIED","parts":[]}"#;
        let result: Result<Message, _> = serde_json::from_str(json);
        assert!(result.is_err());
    }

    // =========================================================
    // Proto number normalization tests
    // =========================================================

    #[test]
    fn as_data_returns_raw_json_without_normalization() {
        let part = Part::data(serde_json::json!({"count": 25544}));
        let json = part.as_data().unwrap();
        // Proto f64: 25544 → 25544.0 in raw JSON
        let count = json.get("count").unwrap();
        assert!(
            count.is_f64() || count.is_u64(),
            "Raw JSON may be f64 from proto: {count}"
        );
        // Whatever the representation, the numeric value must survive.
        assert_eq!(count.as_f64(), Some(25544.0));
    }

    #[test]
    fn parse_data_normalizes_integers_for_typed_deser() {
        #[derive(serde::Deserialize)]
        struct MyData {
            count: u32,
            name: String,
        }

        let part = Part::data(serde_json::json!({"count": 25544, "name": "test"}));
        let result: MyData = part.parse_data().unwrap().unwrap();
        assert_eq!(result.count, 25544);
        assert_eq!(result.name, "test");
    }

    #[test]
    fn parse_data_preserves_fractional_numbers() {
        #[derive(serde::Deserialize)]
        struct MyData {
            ratio: f64,
        }

        let part = Part::data(serde_json::json!({"ratio": 1.5}));
        let result: MyData = part.parse_data().unwrap().unwrap();
        assert!((result.ratio - 1.5).abs() < f64::EPSILON);
    }

    #[test]
    fn parse_data_handles_nested_structures() {
        #[derive(serde::Deserialize)]
        struct Inner {
            value: u16,
        }
        #[derive(serde::Deserialize)]
        struct Outer {
            items: Vec<Inner>,
        }

        let part = Part::data(serde_json::json!({
            "items": [{"value": 42}, {"value": 100}]
        }));
        let result: Outer = part.parse_data().unwrap().unwrap();
        assert_eq!(result.items.len(), 2);
        assert_eq!(result.items[0].value, 42);
        assert_eq!(result.items[1].value, 100);
    }

    #[test]
    fn parse_data_returns_none_for_non_data_part() {
        let part = Part::text("hello");
        assert!(part.parse_data::<serde_json::Value>().is_none());
    }

    #[test]
    fn message_parse_first_data_works() {
        #[derive(serde::Deserialize)]
        struct Req {
            id: u32,
        }

        let msg = Message::new(
            "m-1",
            Role::User,
            vec![
                Part::text("some text"),
                Part::data(serde_json::json!({"id": 12345})),
            ],
        );

        let result: Req = msg.parse_first_data().unwrap().unwrap();
        assert_eq!(result.id, 12345);
    }

    #[test]
    fn message_parse_first_data_returns_none_without_data_part() {
        let msg = Message::new("m-1", Role::User, vec![Part::text("only text")]);
        assert!(msg.parse_first_data::<serde_json::Value>().is_none());
    }

    #[test]
    fn normalize_whole_numbers_to_integers() {
        let input = serde_json::json!({"a": 25544.0, "b": 1.5, "c": -10.0});
        let output = normalize_proto_numbers_for_deser(input);
        assert!(output["a"].is_u64(), "25544.0 should become integer");
        assert!(output["b"].is_f64(), "1.5 should stay f64");
        assert!(output["c"].is_i64(), "-10.0 should become negative integer");
        // Check the values too, not just the representation.
        assert_eq!(output["a"].as_u64(), Some(25544));
        assert_eq!(output["b"].as_f64(), Some(1.5));
        assert_eq!(output["c"].as_i64(), Some(-10));
    }

    #[test]
    fn normalize_leaves_non_numbers_untouched() {
        let input = serde_json::json!({"s": "text", "b": true, "n": null, "l": ["x", 2.0]});
        let output = normalize_proto_numbers_for_deser(input);
        assert_eq!(output["s"], serde_json::json!("text"));
        assert_eq!(output["b"], serde_json::json!(true));
        assert!(output["n"].is_null());
        assert_eq!(output["l"][0], serde_json::json!("x"));
        assert_eq!(output["l"][1].as_u64(), Some(2));
    }

    // =========================================================
    // Data-or-text fallback tests
    // =========================================================

    #[test]
    fn parse_first_data_or_text_prefers_data() {
        #[derive(serde::Deserialize)]
        struct Req {
            id: u32,
        }

        let msg = Message::new(
            "m-1",
            Role::User,
            vec![
                Part::text(r#"{"id": 99}"#),
                Part::data(serde_json::json!({"id": 42})),
            ],
        );

        // Data part (id=42) should be preferred over text part (id=99)
        let result: Req = msg.parse_first_data_or_text().unwrap().unwrap();
        assert_eq!(result.id, 42);
    }

    #[test]
    fn parse_first_data_or_text_falls_back_to_text() {
        // Simulates Strands A2AAgent sending JSON as a text part
        #[derive(serde::Deserialize)]
        struct Req {
            skill: String,
            version: String,
        }

        let msg = Message::new(
            "m-1",
            Role::User,
            vec![Part::text(
                r#"{"skill": "solar_elevation", "version": "1.0"}"#,
            )],
        );

        let result: Req = msg.parse_first_data_or_text().unwrap().unwrap();
        assert_eq!(result.skill, "solar_elevation");
        assert_eq!(result.version, "1.0");
    }

    #[test]
    fn parse_first_data_or_text_ignores_non_json_text() {
        let msg = Message::new("m-1", Role::User, vec![Part::text("hello world")]);

        assert!(
            msg.parse_first_data_or_text::<serde_json::Value>()
                .is_none()
        );
    }

    #[test]
    fn parse_first_data_or_text_reports_malformed_json_text() {
        // Text that looks like a JSON object but is not valid JSON is an
        // error, not an absent value.
        let msg = Message::new("m-1", Role::User, vec![Part::text("{not json")]);

        assert!(matches!(
            msg.parse_first_data_or_text::<serde_json::Value>(),
            Some(Err(_))
        ));
    }
}