1use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use std::collections::HashMap;
9
10use crate::content::{Role, SamplingContent, SamplingContentBlock};
11use crate::definitions::Tool;
12
13#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
19pub struct TaskMetadata {
20 #[serde(skip_serializing_if = "Option::is_none")]
22 pub ttl: Option<u64>,
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
27pub struct Task {
28 #[serde(rename = "taskId")]
30 pub task_id: String,
31 pub status: TaskStatus,
33 #[serde(rename = "statusMessage", skip_serializing_if = "Option::is_none")]
35 pub status_message: Option<String>,
36 #[serde(rename = "createdAt")]
38 pub created_at: String,
39 #[serde(rename = "lastUpdatedAt")]
41 pub last_updated_at: String,
42 pub ttl: Option<u64>,
44 #[serde(rename = "pollInterval", skip_serializing_if = "Option::is_none")]
46 pub poll_interval: Option<u64>,
47}
48
49#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
51#[serde(rename_all = "snake_case")]
52pub enum TaskStatus {
53 Cancelled,
55 Completed,
57 Failed,
59 InputRequired,
61 Working,
63}
64
65impl std::fmt::Display for TaskStatus {
66 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67 match self {
68 Self::Cancelled => f.write_str("cancelled"),
69 Self::Completed => f.write_str("completed"),
70 Self::Failed => f.write_str("failed"),
71 Self::InputRequired => f.write_str("input_required"),
72 Self::Working => f.write_str("working"),
73 }
74 }
75}
76
77#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
79pub struct CreateTaskResult {
80 pub task: Task,
82 #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
84 pub meta: Option<HashMap<String, Value>>,
85}
86
87#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
89pub struct ListTasksResult {
90 pub tasks: Vec<Task>,
92 #[serde(rename = "nextCursor", skip_serializing_if = "Option::is_none")]
94 pub next_cursor: Option<String>,
95 #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
97 pub meta: Option<HashMap<String, Value>>,
98}
99
100#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
104pub struct RelatedTaskMetadata {
105 #[serde(rename = "taskId")]
107 pub task_id: String,
108}
109
110#[derive(Debug, Clone, PartialEq)]
120pub enum ElicitRequestParams {
121 Form(ElicitRequestFormParams),
123 Url(ElicitRequestURLParams),
125}
126
127impl Serialize for ElicitRequestParams {
128 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
129 match self {
130 Self::Form(params) => {
131 let mut value = serde_json::to_value(params).map_err(serde::ser::Error::custom)?;
133 if let Some(obj) = value.as_object_mut() {
134 obj.insert("mode".into(), Value::String("form".into()));
135 }
136 value.serialize(serializer)
137 }
138 Self::Url(params) => {
139 let mut value = serde_json::to_value(params).map_err(serde::ser::Error::custom)?;
141 if let Some(obj) = value.as_object_mut() {
142 obj.insert("mode".into(), Value::String("url".into()));
143 }
144 value.serialize(serializer)
145 }
146 }
147 }
148}
149
150impl<'de> Deserialize<'de> for ElicitRequestParams {
151 fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
152 let value = Value::deserialize(deserializer)?;
153 let mode = value.get("mode").and_then(|v| v.as_str()).unwrap_or("form");
154
155 match mode {
156 "url" => {
157 let params: ElicitRequestURLParams =
158 serde_json::from_value(value).map_err(serde::de::Error::custom)?;
159 Ok(Self::Url(params))
160 }
161 _ => {
162 let params: ElicitRequestFormParams =
164 serde_json::from_value(value).map_err(serde::de::Error::custom)?;
165 Ok(Self::Form(params))
166 }
167 }
168 }
169}
170
171#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
173pub struct ElicitRequestFormParams {
174 pub message: String,
176 #[serde(rename = "requestedSchema")]
178 pub requested_schema: Value,
179 #[serde(skip_serializing_if = "Option::is_none")]
181 pub task: Option<TaskMetadata>,
182 #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
184 pub meta: Option<HashMap<String, Value>>,
185}
186
187#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
189pub struct ElicitRequestURLParams {
190 pub message: String,
192 pub url: String,
194 #[serde(rename = "elicitationId")]
196 pub elicitation_id: String,
197 #[serde(skip_serializing_if = "Option::is_none")]
199 pub task: Option<TaskMetadata>,
200 #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
202 pub meta: Option<HashMap<String, Value>>,
203}
204
205#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
207pub struct ElicitResult {
208 pub action: ElicitAction,
210 #[serde(skip_serializing_if = "Option::is_none")]
213 pub content: Option<Value>,
214 #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
216 pub meta: Option<HashMap<String, Value>>,
217}
218
219#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
221#[serde(rename_all = "lowercase")]
222pub enum ElicitAction {
223 Accept,
225 Decline,
227 Cancel,
229}
230
231impl std::fmt::Display for ElicitAction {
232 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
233 match self {
234 Self::Accept => f.write_str("accept"),
235 Self::Decline => f.write_str("decline"),
236 Self::Cancel => f.write_str("cancel"),
237 }
238 }
239}
240
241#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
245pub struct ElicitationCompleteNotification {
246 #[serde(rename = "elicitationId")]
248 pub elicitation_id: String,
249}
250
251#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
257pub struct CreateMessageRequest {
258 #[serde(default)]
260 pub messages: Vec<SamplingMessage>,
261 #[serde(rename = "maxTokens")]
263 pub max_tokens: u32,
264 #[serde(rename = "modelPreferences", skip_serializing_if = "Option::is_none")]
266 pub model_preferences: Option<ModelPreferences>,
267 #[serde(rename = "systemPrompt", skip_serializing_if = "Option::is_none")]
269 pub system_prompt: Option<String>,
270 #[serde(rename = "includeContext", skip_serializing_if = "Option::is_none")]
272 pub include_context: Option<IncludeContext>,
273 #[serde(skip_serializing_if = "Option::is_none")]
275 pub temperature: Option<f64>,
276 #[serde(rename = "stopSequences", skip_serializing_if = "Option::is_none")]
278 pub stop_sequences: Option<Vec<String>>,
279 #[serde(skip_serializing_if = "Option::is_none")]
281 pub task: Option<TaskMetadata>,
282 #[serde(skip_serializing_if = "Option::is_none")]
284 pub tools: Option<Vec<Tool>>,
285 #[serde(rename = "toolChoice", skip_serializing_if = "Option::is_none")]
287 pub tool_choice: Option<ToolChoice>,
288 #[serde(skip_serializing_if = "Option::is_none")]
290 pub metadata: Option<Value>,
291 #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
293 pub meta: Option<HashMap<String, Value>>,
294}
295
296#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
301pub struct SamplingMessage {
302 pub role: Role,
304 pub content: SamplingContentBlock,
306 #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
308 pub meta: Option<HashMap<String, Value>>,
309}
310
311impl SamplingMessage {
312 #[must_use]
314 pub fn user(text: impl Into<String>) -> Self {
315 Self {
316 role: Role::User,
317 content: SamplingContent::text(text).into(),
318 meta: None,
319 }
320 }
321
322 #[must_use]
324 pub fn assistant(text: impl Into<String>) -> Self {
325 Self {
326 role: Role::Assistant,
327 content: SamplingContent::text(text).into(),
328 meta: None,
329 }
330 }
331}
332
333#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
335pub struct ModelPreferences {
336 #[serde(skip_serializing_if = "Option::is_none")]
338 pub hints: Option<Vec<ModelHint>>,
339 #[serde(rename = "costPriority", skip_serializing_if = "Option::is_none")]
341 pub cost_priority: Option<f64>,
342 #[serde(rename = "speedPriority", skip_serializing_if = "Option::is_none")]
344 pub speed_priority: Option<f64>,
345 #[serde(
347 rename = "intelligencePriority",
348 skip_serializing_if = "Option::is_none"
349 )]
350 pub intelligence_priority: Option<f64>,
351}
352
353#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
357pub struct ModelHint {
358 #[serde(skip_serializing_if = "Option::is_none")]
360 pub name: Option<String>,
361}
362
363impl std::fmt::Display for IncludeContext {
364 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
365 match self {
366 Self::AllServers => f.write_str("allServers"),
367 Self::ThisServer => f.write_str("thisServer"),
368 Self::None => f.write_str("none"),
369 }
370 }
371}
372
373#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
377pub enum IncludeContext {
378 #[serde(rename = "allServers")]
380 AllServers,
381 #[serde(rename = "thisServer")]
383 ThisServer,
384 #[serde(rename = "none")]
386 None,
387}
388
389#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
393pub struct ToolChoice {
394 #[serde(skip_serializing_if = "Option::is_none")]
396 pub mode: Option<ToolChoiceMode>,
397}
398
399#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
401#[serde(rename_all = "lowercase")]
402pub enum ToolChoiceMode {
403 Auto,
405 None,
407 Required,
409}
410
411impl std::fmt::Display for ToolChoiceMode {
412 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
413 match self {
414 Self::Auto => f.write_str("auto"),
415 Self::None => f.write_str("none"),
416 Self::Required => f.write_str("required"),
417 }
418 }
419}
420
421#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
426pub struct CreateMessageResult {
427 pub role: Role,
429 pub content: SamplingContentBlock,
431 pub model: String,
433 #[serde(rename = "stopReason", skip_serializing_if = "Option::is_none")]
435 pub stop_reason: Option<String>,
436 #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
438 pub meta: Option<HashMap<String, Value>>,
439}
440
441#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
447pub struct ClientCapabilities {
448 #[serde(skip_serializing_if = "Option::is_none")]
450 pub elicitation: Option<ElicitationCapabilities>,
451 #[serde(skip_serializing_if = "Option::is_none")]
453 pub sampling: Option<SamplingCapabilities>,
454 #[serde(skip_serializing_if = "Option::is_none")]
456 pub roots: Option<RootsCapabilities>,
457 #[serde(skip_serializing_if = "Option::is_none")]
459 pub tasks: Option<ClientTaskCapabilities>,
460 #[serde(skip_serializing_if = "Option::is_none")]
462 pub experimental: Option<HashMap<String, Value>>,
463}
464
465#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
467pub struct ElicitationCapabilities {
468 #[serde(skip_serializing_if = "Option::is_none")]
470 pub form: Option<HashMap<String, Value>>,
471 #[serde(skip_serializing_if = "Option::is_none")]
473 pub url: Option<HashMap<String, Value>>,
474}
475
476#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
478pub struct SamplingCapabilities {
479 #[serde(skip_serializing_if = "Option::is_none")]
481 pub context: Option<HashMap<String, Value>>,
482 #[serde(skip_serializing_if = "Option::is_none")]
484 pub tools: Option<HashMap<String, Value>>,
485}
486
487#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
489pub struct RootsCapabilities {
490 #[serde(rename = "listChanged", skip_serializing_if = "Option::is_none")]
492 pub list_changed: Option<bool>,
493}
494
495#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
497pub struct ClientTaskCapabilities {
498 #[serde(skip_serializing_if = "Option::is_none")]
500 pub list: Option<HashMap<String, Value>>,
501 #[serde(skip_serializing_if = "Option::is_none")]
503 pub cancel: Option<HashMap<String, Value>>,
504 #[serde(skip_serializing_if = "Option::is_none")]
506 pub requests: Option<ClientTaskRequests>,
507}
508
509#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
511pub struct ClientTaskRequests {
512 #[serde(skip_serializing_if = "Option::is_none")]
514 pub sampling: Option<ClientTaskSamplingRequests>,
515 #[serde(skip_serializing_if = "Option::is_none")]
517 pub elicitation: Option<ClientTaskElicitationRequests>,
518}
519
520#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
522pub struct ClientTaskSamplingRequests {
523 #[serde(rename = "createMessage", skip_serializing_if = "Option::is_none")]
525 pub create_message: Option<HashMap<String, Value>>,
526}
527
528#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
530pub struct ClientTaskElicitationRequests {
531 #[serde(skip_serializing_if = "Option::is_none")]
533 pub create: Option<HashMap<String, Value>>,
534}
535
536#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
541pub struct ServerCapabilities {
542 #[serde(skip_serializing_if = "Option::is_none")]
544 pub tools: Option<ToolCapabilities>,
545 #[serde(skip_serializing_if = "Option::is_none")]
547 pub resources: Option<ResourceCapabilities>,
548 #[serde(skip_serializing_if = "Option::is_none")]
550 pub prompts: Option<PromptCapabilities>,
551 #[serde(skip_serializing_if = "Option::is_none")]
553 pub logging: Option<HashMap<String, Value>>,
554 #[serde(skip_serializing_if = "Option::is_none")]
556 pub completions: Option<HashMap<String, Value>>,
557 #[serde(skip_serializing_if = "Option::is_none")]
559 pub tasks: Option<ServerTaskCapabilities>,
560 #[serde(skip_serializing_if = "Option::is_none")]
562 pub experimental: Option<HashMap<String, Value>>,
563}
564
565#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
567pub struct ToolCapabilities {
568 #[serde(rename = "listChanged", skip_serializing_if = "Option::is_none")]
570 pub list_changed: Option<bool>,
571}
572
573#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
575pub struct ResourceCapabilities {
576 #[serde(skip_serializing_if = "Option::is_none")]
578 pub subscribe: Option<bool>,
579 #[serde(rename = "listChanged", skip_serializing_if = "Option::is_none")]
581 pub list_changed: Option<bool>,
582}
583
584#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
586pub struct PromptCapabilities {
587 #[serde(rename = "listChanged", skip_serializing_if = "Option::is_none")]
589 pub list_changed: Option<bool>,
590}
591
592#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
596pub struct ServerTaskCapabilities {
597 #[serde(skip_serializing_if = "Option::is_none")]
599 pub list: Option<HashMap<String, Value>>,
600 #[serde(skip_serializing_if = "Option::is_none")]
602 pub cancel: Option<HashMap<String, Value>>,
603 #[serde(skip_serializing_if = "Option::is_none")]
605 pub requests: Option<ServerTaskRequests>,
606}
607
608#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
612pub struct ServerTaskRequests {
613 #[serde(skip_serializing_if = "Option::is_none")]
615 pub tools: Option<ServerTaskToolRequests>,
616}
617
618#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
620pub struct ServerTaskToolRequests {
621 #[serde(skip_serializing_if = "Option::is_none")]
623 pub call: Option<HashMap<String, Value>>,
624}
625
/// Serialization tests for the wire types defined in this file.
#[cfg(test)]
mod tests {
    use super::*;

    /// `IncludeContext` uses camelCase strings on the wire.
    #[test]
    fn test_include_context_serde() {
        let json = serde_json::to_string(&IncludeContext::ThisServer).unwrap();
        assert_eq!(json, "\"thisServer\"");

        let json = serde_json::to_string(&IncludeContext::AllServers).unwrap();
        assert_eq!(json, "\"allServers\"");

        let json = serde_json::to_string(&IncludeContext::None).unwrap();
        assert_eq!(json, "\"none\"");

        let parsed: IncludeContext = serde_json::from_str("\"thisServer\"").unwrap();
        assert_eq!(parsed, IncludeContext::ThisServer);
    }

    /// An unset `mode` is omitted; a set one is written as its wire string.
    #[test]
    fn test_tool_choice_mode_optional() {
        let tc = ToolChoice { mode: None };
        let json = serde_json::to_string(&tc).unwrap();
        assert_eq!(json, "{}");

        let tc = ToolChoice {
            mode: Some(ToolChoiceMode::Required),
        };
        let json = serde_json::to_string(&tc).unwrap();
        assert!(json.contains("\"required\""));
    }

    /// An unset `name` is omitted; a set one appears in the output.
    #[test]
    fn test_model_hint_name_optional() {
        let hint = ModelHint { name: None };
        let json = serde_json::to_string(&hint).unwrap();
        assert_eq!(json, "{}");

        let hint = ModelHint {
            name: Some("claude".into()),
        };
        let json = serde_json::to_string(&hint).unwrap();
        assert!(json.contains("\"claude\""));
    }

    /// `TaskStatus` uses snake_case strings on the wire.
    #[test]
    fn test_task_status_serde() {
        let json = serde_json::to_string(&TaskStatus::InputRequired).unwrap();
        assert_eq!(json, "\"input_required\"");

        let json = serde_json::to_string(&TaskStatus::Working).unwrap();
        assert_eq!(json, "\"working\"");
    }

    /// Struct-update syntax with `Default` leaves the optional fields unset.
    #[test]
    fn test_create_message_request_default() {
        let req = CreateMessageRequest {
            messages: vec![SamplingMessage::user("hello")],
            max_tokens: 100,
            ..Default::default()
        };
        assert_eq!(req.messages.len(), 1);
        assert_eq!(req.max_tokens, 100);
        assert!(req.tools.is_none());
    }

    /// Message content parses both as a single block and as an array.
    #[test]
    fn test_sampling_message_content_single_or_array() {
        let msg = SamplingMessage::user("hello");
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("\"text\":\"hello\""));

        let parsed: SamplingMessage = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.content.as_text(), Some("hello"));

        let json_array = r#"{"role":"user","content":[{"type":"text","text":"hello"},{"type":"text","text":"world"}]}"#;
        let parsed: SamplingMessage = serde_json::from_str(json_array).unwrap();
        match &parsed.content {
            SamplingContentBlock::Multiple(v) => assert_eq!(v.len(), 2),
            _ => panic!("Expected multiple content blocks"),
        }
    }

    /// Nested task capabilities serialize down to `tasks.requests.tools.call`.
    #[test]
    fn test_server_capabilities_structure() {
        let caps = ServerCapabilities {
            tasks: Some(ServerTaskCapabilities {
                list: Some(HashMap::new()),
                cancel: Some(HashMap::new()),
                requests: Some(ServerTaskRequests {
                    tools: Some(ServerTaskToolRequests {
                        call: Some(HashMap::new()),
                    }),
                }),
            }),
            ..Default::default()
        };
        let json = serde_json::to_string(&caps).unwrap();
        let v: Value = serde_json::from_str(&json).unwrap();
        assert!(v["tasks"]["requests"]["tools"]["call"].is_object());
    }

    /// Every `ElicitAction` variant round-trips through its lowercase string.
    #[test]
    fn test_elicit_action_serde() {
        let cases = [
            (ElicitAction::Accept, "\"accept\""),
            (ElicitAction::Decline, "\"decline\""),
            (ElicitAction::Cancel, "\"cancel\""),
        ];
        for (action, expected) in cases {
            let json = serde_json::to_string(&action).unwrap();
            assert_eq!(json, expected);
            let parsed: ElicitAction = serde_json::from_str(expected).unwrap();
            assert_eq!(parsed, action);
        }
    }

    /// `ElicitResult` round-trips with content and omits `content` when unset.
    #[test]
    fn test_elicit_result_round_trip() {
        let result = ElicitResult {
            action: ElicitAction::Accept,
            content: Some(serde_json::json!({"name": "test"})),
            meta: None,
        };
        let json = serde_json::to_string(&result).unwrap();
        let parsed: ElicitResult = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.action, ElicitAction::Accept);
        assert!(parsed.content.is_some());

        let decline = ElicitResult {
            action: ElicitAction::Decline,
            content: None,
            meta: None,
        };
        let json = serde_json::to_string(&decline).unwrap();
        assert!(!json.contains("\"content\""));
        let parsed: ElicitResult = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.action, ElicitAction::Decline);
        assert!(parsed.content.is_none());
    }

    /// Server capabilities never emit `elicitation` or `sampling` keys,
    /// whether empty or fully populated.
    #[test]
    fn test_server_capabilities_no_elicitation_or_sampling() {
        let caps = ServerCapabilities::default();
        let json = serde_json::to_string(&caps).unwrap();
        assert!(!json.contains("elicitation"));
        assert!(!json.contains("sampling"));

        let caps = ServerCapabilities {
            tools: Some(ToolCapabilities {
                list_changed: Some(true),
            }),
            resources: Some(ResourceCapabilities {
                subscribe: Some(true),
                list_changed: Some(true),
            }),
            prompts: Some(PromptCapabilities {
                list_changed: Some(true),
            }),
            logging: Some(HashMap::new()),
            completions: Some(HashMap::new()),
            tasks: Some(ServerTaskCapabilities::default()),
            experimental: Some(HashMap::new()),
        };
        let json = serde_json::to_string(&caps).unwrap();
        assert!(!json.contains("elicitation"));
        assert!(!json.contains("sampling"));
    }

    /// Array-shaped content stays an array after a parse/serialize cycle.
    #[test]
    fn test_sampling_message_array_content_round_trip() {
        let json_array =
            r#"{"role":"user","content":[{"type":"text","text":"a"},{"type":"text","text":"b"}]}"#;
        let parsed: SamplingMessage = serde_json::from_str(json_array).unwrap();
        let re_serialized = serde_json::to_string(&parsed).unwrap();
        let re_parsed: Value = serde_json::from_str(&re_serialized).unwrap();
        assert!(re_parsed["content"].is_array());
        assert_eq!(re_parsed["content"].as_array().unwrap().len(), 2);
    }

    /// Every `ToolChoiceMode` variant round-trips through its lowercase string.
    #[test]
    fn test_tool_choice_mode_all_variants() {
        let cases = [
            (ToolChoiceMode::Auto, "\"auto\""),
            (ToolChoiceMode::None, "\"none\""),
            (ToolChoiceMode::Required, "\"required\""),
        ];
        for (mode, expected) in cases {
            let json = serde_json::to_string(&mode).unwrap();
            assert_eq!(json, expected);
            let parsed: ToolChoiceMode = serde_json::from_str(expected).unwrap();
            assert_eq!(parsed, mode);
        }
    }

    /// A payload without `mode` is read as the form variant.
    #[test]
    fn test_elicit_request_params_form_without_mode() {
        let json = r#"{"message":"Enter name","requestedSchema":{"type":"object"}}"#;
        let parsed: ElicitRequestParams = serde_json::from_str(json).unwrap();
        match &parsed {
            ElicitRequestParams::Form(params) => {
                assert_eq!(params.message, "Enter name");
            }
            ElicitRequestParams::Url(_) => panic!("expected Form variant"),
        }
    }

    /// A payload with `"mode": "form"` is read as the form variant.
    #[test]
    fn test_elicit_request_params_form_with_explicit_mode() {
        let json = r#"{"mode":"form","message":"Enter name","requestedSchema":{"type":"object"}}"#;
        let parsed: ElicitRequestParams = serde_json::from_str(json).unwrap();
        match &parsed {
            ElicitRequestParams::Form(params) => {
                assert_eq!(params.message, "Enter name");
            }
            ElicitRequestParams::Url(_) => panic!("expected Form variant"),
        }
    }

    /// A payload with `"mode": "url"` is read as the URL variant.
    #[test]
    fn test_elicit_request_params_url_mode() {
        let json = r#"{"mode":"url","message":"Authenticate","url":"https://example.com/auth","elicitationId":"e-123"}"#;
        let parsed: ElicitRequestParams = serde_json::from_str(json).unwrap();
        match &parsed {
            ElicitRequestParams::Url(params) => {
                assert_eq!(params.message, "Authenticate");
                assert_eq!(params.url, "https://example.com/auth");
                assert_eq!(params.elicitation_id, "e-123");
            }
            ElicitRequestParams::Form(_) => panic!("expected Url variant"),
        }
    }

    /// The form variant writes `"mode": "form"` and parses back to an equal
    /// value.
    #[test]
    fn test_elicit_request_params_form_round_trip() {
        let params = ElicitRequestParams::Form(ElicitRequestFormParams {
            message: "Enter details".into(),
            requested_schema: serde_json::json!({"type": "object", "properties": {"name": {"type": "string"}}}),
            task: None,
            meta: None,
        });
        let json = serde_json::to_string(&params).unwrap();
        let v: Value = serde_json::from_str(&json).unwrap();
        assert_eq!(v["mode"], "form");
        let parsed: ElicitRequestParams = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, params);
    }

    /// The URL variant writes `"mode": "url"` and parses back to an equal
    /// value.
    #[test]
    fn test_elicit_request_params_url_round_trip() {
        let params = ElicitRequestParams::Url(ElicitRequestURLParams {
            message: "Please authenticate".into(),
            url: "https://example.com/oauth".into(),
            elicitation_id: "elicit-456".into(),
            task: None,
            meta: None,
        });
        let json = serde_json::to_string(&params).unwrap();
        let v: Value = serde_json::from_str(&json).unwrap();
        assert_eq!(v["mode"], "url");
        let parsed: ElicitRequestParams = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, params);
    }

    /// Every `TaskStatus` variant round-trips through its snake_case string.
    #[test]
    fn test_task_status_all_variants() {
        let cases = [
            (TaskStatus::Cancelled, "\"cancelled\""),
            (TaskStatus::Completed, "\"completed\""),
            (TaskStatus::Failed, "\"failed\""),
            (TaskStatus::InputRequired, "\"input_required\""),
            (TaskStatus::Working, "\"working\""),
        ];
        for (status, expected) in cases {
            let json = serde_json::to_string(&status).unwrap();
            assert_eq!(json, expected);
            let parsed: TaskStatus = serde_json::from_str(expected).unwrap();
            assert_eq!(parsed, status);
        }
    }
}