1use serde::{Deserialize, Serialize};
2use turul_a2a_proto as pb;
3
4#[derive(Debug, Clone)]
6#[non_exhaustive]
7pub struct Part {
8 pub(crate) inner: pb::Part,
9}
10
11impl Part {
12 pub fn text(text: impl Into<String>) -> Self {
13 Self {
14 inner: pb::Part {
15 content: Some(pb::part::Content::Text(text.into())),
16 metadata: None,
17 filename: String::new(),
18 media_type: "text/plain".to_string(),
19 },
20 }
21 }
22
23 pub fn url(url: impl Into<String>, media_type: impl Into<String>) -> Self {
24 Self {
25 inner: pb::Part {
26 content: Some(pb::part::Content::Url(url.into())),
27 metadata: None,
28 filename: String::new(),
29 media_type: media_type.into(),
30 },
31 }
32 }
33
34 pub fn raw(data: Vec<u8>, media_type: impl Into<String>) -> Self {
35 Self {
36 inner: pb::Part {
37 content: Some(pb::part::Content::Raw(data)),
38 metadata: None,
39 filename: String::new(),
40 media_type: media_type.into(),
41 },
42 }
43 }
44
45 pub fn data(value: serde_json::Value) -> Self {
46 Self {
47 inner: pb::Part {
48 content: Some(pb::part::Content::Data(json_to_proto_value(value))),
49 metadata: None,
50 filename: String::new(),
51 media_type: "application/json".to_string(),
52 },
53 }
54 }
55
56 pub fn with_filename(mut self, filename: impl Into<String>) -> Self {
57 self.inner.filename = filename.into();
58 self
59 }
60
61 pub fn with_media_type(mut self, media_type: impl Into<String>) -> Self {
62 self.inner.media_type = media_type.into();
63 self
64 }
65
66 pub fn as_text(&self) -> Option<&str> {
68 match &self.inner.content {
69 Some(pb::part::Content::Text(t)) => Some(t.as_str()),
70 _ => None,
71 }
72 }
73
74 pub fn as_url(&self) -> Option<&str> {
76 match &self.inner.content {
77 Some(pb::part::Content::Url(u)) => Some(u.as_str()),
78 _ => None,
79 }
80 }
81
82 pub fn as_raw(&self) -> Option<&[u8]> {
84 match &self.inner.content {
85 Some(pb::part::Content::Raw(r)) => Some(r.as_slice()),
86 _ => None,
87 }
88 }
89
90 pub fn as_data(&self) -> Option<serde_json::Value> {
93 match &self.inner.content {
94 Some(pb::part::Content::Data(proto_struct)) => serde_json::to_value(proto_struct).ok(),
95 _ => None,
96 }
97 }
98
99 pub fn parse_data<T: serde::de::DeserializeOwned>(
107 &self,
108 ) -> Option<Result<T, crate::error::A2aTypeError>> {
109 let json = self.as_data()?;
110 let normalized = normalize_proto_numbers_for_deser(json);
111 Some(
112 serde_json::from_value(normalized)
113 .map_err(|e| crate::error::A2aTypeError::Deserialization(e.to_string())),
114 )
115 }
116
117 pub fn as_proto(&self) -> &pb::Part {
118 &self.inner
119 }
120
121 pub fn into_proto(self) -> pb::Part {
122 self.inner
123 }
124}
125
126impl From<pb::Part> for Part {
129 fn from(inner: pb::Part) -> Self {
130 Self { inner }
131 }
132}
133
134impl From<Part> for pb::Part {
135 fn from(part: Part) -> Self {
136 part.inner
137 }
138}
139
140impl Serialize for Part {
141 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
142 self.inner.serialize(serializer)
143 }
144}
145
146impl<'de> Deserialize<'de> for Part {
147 fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
148 pb::Part::deserialize(deserializer).map(Self::from)
149 }
150}
151
152#[derive(Debug, Clone)]
154#[non_exhaustive]
155pub struct Message {
156 pub(crate) inner: pb::Message,
157}
158
/// Who authored a [`Message`].
///
/// Unlike the protobuf enum, this type has no "unspecified" value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum Role {
    /// Sent by the user (client) side.
    User,
    /// Sent by the agent side.
    Agent,
}
166
167impl From<Role> for pb::Role {
168 fn from(role: Role) -> Self {
169 match role {
170 Role::User => pb::Role::User,
171 Role::Agent => pb::Role::Agent,
172 }
173 }
174}
175
176impl TryFrom<pb::Role> for Role {
177 type Error = crate::error::A2aTypeError;
178
179 fn try_from(value: pb::Role) -> Result<Self, Self::Error> {
180 match value {
181 pb::Role::User => Ok(Self::User),
182 pb::Role::Agent => Ok(Self::Agent),
183 pb::Role::Unspecified => Err(crate::error::A2aTypeError::InvalidState),
184 }
185 }
186}
187
188impl Message {
189 pub fn new(message_id: impl Into<String>, role: Role, parts: Vec<Part>) -> Self {
190 Self {
191 inner: pb::Message {
192 message_id: message_id.into(),
193 role: pb::Role::from(role).into(),
194 parts: parts.into_iter().map(|p| p.inner).collect(),
195 context_id: String::new(),
196 task_id: String::new(),
197 metadata: None,
198 extensions: vec![],
199 reference_task_ids: vec![],
200 },
201 }
202 }
203
204 pub fn with_context_id(mut self, context_id: impl Into<String>) -> Self {
205 self.inner.context_id = context_id.into();
206 self
207 }
208
209 pub fn with_task_id(mut self, task_id: impl Into<String>) -> Self {
210 self.inner.task_id = task_id.into();
211 self
212 }
213
214 pub fn message_id(&self) -> &str {
215 &self.inner.message_id
216 }
217
218 pub fn context_id(&self) -> &str {
221 &self.inner.context_id
222 }
223
224 pub fn task_id(&self) -> &str {
228 &self.inner.task_id
229 }
230
231 pub fn metadata(&self) -> Option<&pb::pbjson_types::Struct> {
236 self.inner.metadata.as_ref()
237 }
238
239 pub fn metadata_keys(&self) -> Vec<String> {
247 let Some(s) = self.inner.metadata.as_ref() else {
248 return Vec::new();
249 };
250 let mut keys: Vec<String> = s.fields.keys().cloned().collect();
251 keys.sort();
252 keys
253 }
254
255 pub fn text_parts(&self) -> Vec<&str> {
259 self.inner
260 .parts
261 .iter()
262 .filter_map(|p| match &p.content {
263 Some(pb::part::Content::Text(t)) => Some(t.as_str()),
264 _ => None,
265 })
266 .collect()
267 }
268
269 pub fn joined_text(&self) -> String {
272 self.text_parts().join(" ")
273 }
274
275 pub fn data_parts(&self) -> Vec<serde_json::Value> {
277 self.inner
278 .parts
279 .iter()
280 .filter_map(|p| match &p.content {
281 Some(pb::part::Content::Data(proto_struct)) => {
282 serde_json::to_value(proto_struct).ok()
283 }
284 _ => None,
285 })
286 .collect()
287 }
288
289 pub fn parse_first_data<T: serde::de::DeserializeOwned>(
293 &self,
294 ) -> Option<Result<T, crate::error::A2aTypeError>> {
295 for part in &self.inner.parts {
296 if let Some(pb::part::Content::Data(proto_struct)) = &part.content {
297 if let Ok(json) = serde_json::to_value(proto_struct) {
298 let normalized = normalize_proto_numbers_for_deser(json);
299 return Some(
300 serde_json::from_value(normalized).map_err(|e| {
301 crate::error::A2aTypeError::Deserialization(e.to_string())
302 }),
303 );
304 }
305 }
306 }
307 None
308 }
309
310 pub fn parse_first_data_or_text<T: serde::de::DeserializeOwned>(
320 &self,
321 ) -> Option<Result<T, crate::error::A2aTypeError>> {
322 if let Some(result) = self.parse_first_data() {
324 return Some(result);
325 }
326
327 for part in &self.inner.parts {
329 if let Some(pb::part::Content::Text(text)) = &part.content {
330 if text.trim_start().starts_with('{') {
331 return Some(
332 serde_json::from_str(text).map_err(|e| {
333 crate::error::A2aTypeError::Deserialization(e.to_string())
334 }),
335 );
336 }
337 }
338 }
339
340 None
341 }
342
343 pub fn as_proto(&self) -> &pb::Message {
344 &self.inner
345 }
346
347 pub fn into_proto(self) -> pb::Message {
348 self.inner
349 }
350}
351
352impl TryFrom<pb::Message> for Message {
353 type Error = crate::error::A2aTypeError;
354
355 fn try_from(inner: pb::Message) -> Result<Self, Self::Error> {
356 let role_val = pb::Role::try_from(inner.role).unwrap_or(pb::Role::Unspecified);
358 if role_val == pb::Role::Unspecified {
359 return Err(crate::error::A2aTypeError::MissingField("role"));
360 }
361 Ok(Self { inner })
362 }
363}
364
365impl From<Message> for pb::Message {
366 fn from(msg: Message) -> Self {
367 msg.inner
368 }
369}
370
371impl Serialize for Message {
372 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
373 self.inner.serialize(serializer)
374 }
375}
376
377impl<'de> Deserialize<'de> for Message {
378 fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
379 let proto = pb::Message::deserialize(deserializer)?;
380 Message::try_from(proto).map_err(serde::de::Error::custom)
381 }
382}
383
384fn normalize_proto_numbers_for_deser(value: serde_json::Value) -> serde_json::Value {
392 match value {
393 serde_json::Value::Number(n) => {
394 if let Some(f) = n.as_f64() {
395 if f.is_finite() && f.fract() == 0.0 {
396 if f >= 0.0 && f <= u64::MAX as f64 {
397 return serde_json::Value::Number((f as u64).into());
398 } else if f >= i64::MIN as f64 && f <= i64::MAX as f64 {
399 return serde_json::Value::Number((f as i64).into());
400 }
401 }
402 }
403 serde_json::Value::Number(n)
404 }
405 serde_json::Value::Array(arr) => serde_json::Value::Array(
406 arr.into_iter()
407 .map(normalize_proto_numbers_for_deser)
408 .collect(),
409 ),
410 serde_json::Value::Object(map) => serde_json::Value::Object(
411 map.into_iter()
412 .map(|(k, v)| (k, normalize_proto_numbers_for_deser(v)))
413 .collect(),
414 ),
415 other => other,
416 }
417}
418
419fn json_to_proto_value(value: serde_json::Value) -> pbjson_types::Value {
421 match value {
422 serde_json::Value::Null => pbjson_types::Value {
423 kind: Some(pbjson_types::value::Kind::NullValue(0)),
424 },
425 serde_json::Value::Bool(b) => pbjson_types::Value {
426 kind: Some(pbjson_types::value::Kind::BoolValue(b)),
427 },
428 serde_json::Value::Number(n) => pbjson_types::Value {
429 kind: Some(pbjson_types::value::Kind::NumberValue(
430 n.as_f64().unwrap_or(0.0),
431 )),
432 },
433 serde_json::Value::String(s) => pbjson_types::Value {
434 kind: Some(pbjson_types::value::Kind::StringValue(s)),
435 },
436 serde_json::Value::Array(arr) => pbjson_types::Value {
437 kind: Some(pbjson_types::value::Kind::ListValue(
438 pbjson_types::ListValue {
439 values: arr.into_iter().map(json_to_proto_value).collect(),
440 },
441 )),
442 },
443 serde_json::Value::Object(map) => pbjson_types::Value {
444 kind: Some(pbjson_types::value::Kind::StructValue(
445 pbjson_types::Struct {
446 fields: map
447 .into_iter()
448 .map(|(k, v)| (k, json_to_proto_value(v)))
449 .collect(),
450 },
451 )),
452 },
453 }
454}
455
#[cfg(test)]
mod tests {
    use super::*;

    // ---------------------------------------------------------------------------
    // Part constructors, builders and accessors
    // ---------------------------------------------------------------------------

    #[test]
    fn part_text_constructor() {
        let part = Part::text("hello");
        let proto = part.as_proto();
        assert!(matches!(proto.content, Some(pb::part::Content::Text(ref s)) if s == "hello"));
        assert_eq!(proto.media_type, "text/plain");
    }

    #[test]
    fn part_url_constructor() {
        let part =
            Part::url("https://example.com/file.pdf", "application/pdf").with_filename("file.pdf");
        let proto = part.as_proto();
        assert!(
            matches!(proto.content, Some(pb::part::Content::Url(ref u)) if u == "https://example.com/file.pdf")
        );
        assert_eq!(proto.filename, "file.pdf");
        assert_eq!(proto.media_type, "application/pdf");
    }

    #[test]
    fn part_raw_constructor() {
        let part = Part::raw(vec![0x48, 0x65], "image/png");
        let proto = part.as_proto();
        assert!(matches!(proto.content, Some(pb::part::Content::Raw(ref b)) if b == &[0x48, 0x65]));
    }

    #[test]
    fn part_data_constructor() {
        let part = Part::data(serde_json::json!({"key": "val"}));
        let proto = part.as_proto();
        assert!(matches!(proto.content, Some(pb::part::Content::Data(_))));
    }

    #[test]
    fn part_serde_round_trip() {
        let part = Part::text("round-trip");
        let json = serde_json::to_string(&part).unwrap();
        let back: Part = serde_json::from_str(&json).unwrap();
        assert!(matches!(
            back.as_proto().content,
            Some(pb::part::Content::Text(ref s)) if s == "round-trip"
        ));
    }

    // Each accessor must only answer for its own content variant.
    #[test]
    fn part_accessors_only_match_their_own_variant() {
        let text_part = Part::text("t");
        assert_eq!(text_part.as_text(), Some("t"));
        assert!(text_part.as_url().is_none());
        assert!(text_part.as_raw().is_none());
        assert!(text_part.as_data().is_none());

        let url_part = Part::url("https://example.com", "text/html");
        assert_eq!(url_part.as_url(), Some("https://example.com"));
        assert!(url_part.as_text().is_none());

        let raw_part = Part::raw(vec![1, 2, 3], "application/octet-stream");
        assert_eq!(raw_part.as_raw(), Some(&[1u8, 2, 3][..]));
        assert!(raw_part.as_data().is_none());
    }

    #[test]
    fn part_builders_override_defaults() {
        let part = Part::text("x")
            .with_media_type("text/markdown")
            .with_filename("notes.md");
        assert_eq!(part.as_proto().media_type, "text/markdown");
        assert_eq!(part.as_proto().filename, "notes.md");
        assert_eq!(
            Part::data(serde_json::json!(null)).as_proto().media_type,
            "application/json"
        );
    }

    // Every JSON kind (list, nested object, null, bool, string, positive and negative
    // integers) must survive the proto round trip once numbers are normalized.
    #[test]
    fn part_data_round_trips_all_json_kinds() {
        let original = serde_json::json!({
            "list": [1, "two", null, true],
            "nested": {"k": "v"},
            "neg": -3,
        });
        let back: serde_json::Value = Part::data(original.clone())
            .parse_data()
            .unwrap()
            .unwrap();
        assert_eq!(back, original);
    }

    // ---------------------------------------------------------------------------
    // Message construction, conversion and serde
    // ---------------------------------------------------------------------------

    #[test]
    fn message_constructor() {
        let msg = Message::new("msg-1", Role::User, vec![Part::text("hello")]);
        assert_eq!(msg.message_id(), "msg-1");
        assert_eq!(msg.as_proto().role, i32::from(pb::Role::User));
        assert_eq!(msg.as_proto().parts.len(), 1);
    }

    #[test]
    fn message_with_context_and_task() {
        let msg = Message::new("msg-2", Role::Agent, vec![])
            .with_context_id("ctx-1")
            .with_task_id("task-1");
        assert_eq!(msg.as_proto().context_id, "ctx-1");
        assert_eq!(msg.as_proto().task_id, "task-1");
    }

    #[test]
    fn message_serde_round_trip() {
        let msg = Message::new("msg-rt", Role::User, vec![Part::text("hi")]);
        let json = serde_json::to_string(&msg).unwrap();
        let back: Message = serde_json::from_str(&json).unwrap();
        assert_eq!(back.message_id(), "msg-rt");
    }

    #[test]
    fn role_conversions() {
        assert_eq!(pb::Role::from(Role::User), pb::Role::User);
        assert_eq!(pb::Role::from(Role::Agent), pb::Role::Agent);
        assert_eq!(Role::try_from(pb::Role::User).unwrap(), Role::User);
        assert_eq!(Role::try_from(pb::Role::Agent).unwrap(), Role::Agent);
        assert!(Role::try_from(pb::Role::Unspecified).is_err());
    }

    #[test]
    fn message_try_from_proto_rejects_unspecified_role() {
        let proto_msg = pb::Message {
            message_id: "m-1".to_string(),
            role: pb::Role::Unspecified.into(),
            parts: vec![],
            context_id: String::new(),
            task_id: String::new(),
            metadata: None,
            extensions: vec![],
            reference_task_ids: vec![],
        };
        assert!(Message::try_from(proto_msg).is_err());
    }

    // A role integer that matches no known variant must be rejected like Unspecified.
    #[test]
    fn message_try_from_proto_rejects_unknown_role_value() {
        let proto_msg = pb::Message {
            message_id: "m-unknown".to_string(),
            role: 99,
            parts: vec![],
            context_id: String::new(),
            task_id: String::new(),
            metadata: None,
            extensions: vec![],
            reference_task_ids: vec![],
        };
        assert!(Message::try_from(proto_msg).is_err());
    }

    #[test]
    fn message_try_from_proto_accepts_valid_role() {
        let proto_msg = pb::Message {
            message_id: "m-2".to_string(),
            role: pb::Role::User.into(),
            parts: vec![],
            context_id: String::new(),
            task_id: String::new(),
            metadata: None,
            extensions: vec![],
            reference_task_ids: vec![],
        };
        let msg = Message::try_from(proto_msg).unwrap();
        assert_eq!(msg.message_id(), "m-2");
    }

    #[test]
    fn message_json_deserialization_rejects_unspecified_role() {
        let json = r#"{"messageId":"m-bad","role":"ROLE_UNSPECIFIED","parts":[]}"#;
        let result: Result<Message, _> = serde_json::from_str(json);
        assert!(result.is_err());
    }

    // ---------------------------------------------------------------------------
    // Message text / data / metadata helpers
    // ---------------------------------------------------------------------------

    #[test]
    fn message_text_and_data_helpers() {
        let msg = Message::new(
            "m-1",
            Role::User,
            vec![
                Part::text("hello"),
                Part::data(serde_json::json!({"k": 1})),
                Part::text("world"),
            ],
        );
        assert_eq!(msg.text_parts(), vec!["hello", "world"]);
        assert_eq!(msg.joined_text(), "hello world");
        assert_eq!(msg.data_parts().len(), 1);
    }

    #[test]
    fn message_without_metadata_has_no_keys() {
        let msg = Message::new("m-1", Role::User, vec![]);
        assert!(msg.metadata().is_none());
        assert!(msg.metadata_keys().is_empty());
    }

    #[test]
    fn message_metadata_keys_are_sorted() {
        let json = r#"{"messageId":"m-meta","role":"ROLE_USER","parts":[],"metadata":{"zeta":1,"alpha":"a"}}"#;
        let msg: Message = serde_json::from_str(json).unwrap();
        assert_eq!(
            msg.metadata_keys(),
            vec!["alpha".to_string(), "zeta".to_string()]
        );
    }

    // ---------------------------------------------------------------------------
    // Typed data parsing and number normalization
    // ---------------------------------------------------------------------------

    // as_data exposes the protobuf value unnormalized, so the number may be a float.
    #[test]
    fn as_data_returns_raw_json_without_normalization() {
        let part = Part::data(serde_json::json!({"count": 25544}));
        let json = part.as_data().unwrap();
        let count = json.get("count").unwrap();
        assert!(
            count.is_f64() || count.is_u64(),
            "Raw JSON may be f64 from proto: {count}"
        );
    }

    #[test]
    fn parse_data_normalizes_integers_for_typed_deser() {
        #[derive(serde::Deserialize)]
        struct MyData {
            count: u32,
            name: String,
        }

        let part = Part::data(serde_json::json!({"count": 25544, "name": "test"}));
        let result: MyData = part.parse_data().unwrap().unwrap();
        assert_eq!(result.count, 25544);
        assert_eq!(result.name, "test");
    }

    #[test]
    fn parse_data_preserves_fractional_numbers() {
        #[derive(serde::Deserialize)]
        struct MyData {
            ratio: f64,
        }

        let part = Part::data(serde_json::json!({"ratio": 1.5}));
        let result: MyData = part.parse_data().unwrap().unwrap();
        assert!((result.ratio - 1.5).abs() < f64::EPSILON);
    }

    #[test]
    fn parse_data_handles_nested_structures() {
        #[derive(serde::Deserialize)]
        struct Inner {
            value: u16,
        }
        #[derive(serde::Deserialize)]
        struct Outer {
            items: Vec<Inner>,
        }

        let part = Part::data(serde_json::json!({
            "items": [{"value": 42}, {"value": 100}]
        }));
        let result: Outer = part.parse_data().unwrap().unwrap();
        assert_eq!(result.items.len(), 2);
        assert_eq!(result.items[0].value, 42);
        assert_eq!(result.items[1].value, 100);
    }

    #[test]
    fn parse_data_returns_none_for_non_data_part() {
        let part = Part::text("hello");
        assert!(part.parse_data::<serde_json::Value>().is_none());
    }

    #[test]
    fn message_parse_first_data_works() {
        #[derive(serde::Deserialize)]
        struct Req {
            id: u32,
        }

        let msg = Message::new(
            "m-1",
            Role::User,
            vec![
                Part::text("some text"),
                Part::data(serde_json::json!({"id": 12345})),
            ],
        );

        let result: Req = msg.parse_first_data().unwrap().unwrap();
        assert_eq!(result.id, 12345);
    }

    // JSON-looking text is not a data part; parse_first_data must ignore it.
    #[test]
    fn parse_first_data_returns_none_without_data_parts() {
        let msg = Message::new("m-1", Role::User, vec![Part::text(r#"{"id": 1}"#)]);
        assert!(msg.parse_first_data::<serde_json::Value>().is_none());
    }

    #[test]
    fn normalize_whole_numbers_to_integers() {
        let input = serde_json::json!({"a": 25544.0, "b": 1.5, "c": -10.0});
        let output = normalize_proto_numbers_for_deser(input);
        assert!(output["a"].is_u64(), "25544.0 should become integer");
        assert!(output["b"].is_f64(), "1.5 should stay f64");
        assert!(output["c"].is_i64(), "-10.0 should become negative integer");
    }

    #[test]
    fn normalize_leaves_non_numbers_untouched() {
        let input = serde_json::json!({"s": "x", "b": true, "n": null, "arr": ["y"]});
        assert_eq!(normalize_proto_numbers_for_deser(input.clone()), input);
    }

    // ---------------------------------------------------------------------------
    // Data-or-text fallback parsing
    // ---------------------------------------------------------------------------

    #[test]
    fn parse_first_data_or_text_prefers_data() {
        #[derive(serde::Deserialize)]
        struct Req {
            id: u32,
        }

        let msg = Message::new(
            "m-1",
            Role::User,
            vec![
                Part::text(r#"{"id": 99}"#),
                Part::data(serde_json::json!({"id": 42})),
            ],
        );

        let result: Req = msg.parse_first_data_or_text().unwrap().unwrap();
        assert_eq!(result.id, 42);
    }

    #[test]
    fn parse_first_data_or_text_falls_back_to_text() {
        #[derive(serde::Deserialize)]
        struct Req {
            skill: String,
            version: String,
        }

        let msg = Message::new(
            "m-1",
            Role::User,
            vec![Part::text(
                r#"{"skill": "solar_elevation", "version": "1.0"}"#,
            )],
        );

        let result: Req = msg.parse_first_data_or_text().unwrap().unwrap();
        assert_eq!(result.skill, "solar_elevation");
        assert_eq!(result.version, "1.0");
    }

    // Leading whitespace before the opening brace must not defeat the JSON detection.
    #[test]
    fn parse_first_data_or_text_accepts_leading_whitespace() {
        #[derive(serde::Deserialize)]
        struct Req {
            id: u32,
        }

        let msg = Message::new("m-1", Role::User, vec![Part::text("  \n{\"id\": 7}")]);
        let result: Req = msg.parse_first_data_or_text().unwrap().unwrap();
        assert_eq!(result.id, 7);
    }

    // Text that looks like JSON but is malformed is reported as an error, not skipped.
    #[test]
    fn parse_first_data_or_text_reports_malformed_json_text() {
        let msg = Message::new("m-1", Role::User, vec![Part::text("{not json")]);
        assert!(matches!(
            msg.parse_first_data_or_text::<serde_json::Value>(),
            Some(Err(_))
        ));
    }

    #[test]
    fn parse_first_data_or_text_ignores_non_json_text() {
        let msg = Message::new("m-1", Role::User, vec![Part::text("hello world")]);

        assert!(
            msg.parse_first_data_or_text::<serde_json::Value>()
                .is_none()
        );
    }
}
726}