1use serde::{Deserialize, Serialize};
7use serde_json::Value;
8
9#[cfg(not(feature = "std"))]
10use alloc::{collections::BTreeMap as HashMap, format, string::String, vec::Vec};
11#[cfg(feature = "std")]
12use std::collections::HashMap;
13
14use crate::content::{Role, SamplingContent, SamplingContentBlock};
15use crate::definitions::Tool;
16
17#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
23pub struct TaskMetadata {
24 #[serde(skip_serializing_if = "Option::is_none")]
26 pub ttl: Option<u64>,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
31pub struct Task {
32 #[serde(rename = "taskId")]
34 pub task_id: String,
35 pub status: TaskStatus,
37 #[serde(rename = "statusMessage", skip_serializing_if = "Option::is_none")]
39 pub status_message: Option<String>,
40 #[serde(rename = "createdAt")]
42 pub created_at: String,
43 #[serde(rename = "lastUpdatedAt")]
45 pub last_updated_at: String,
46 pub ttl: Option<u64>,
48 #[serde(rename = "pollInterval", skip_serializing_if = "Option::is_none")]
50 pub poll_interval: Option<u64>,
51}
52
53#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
55#[serde(rename_all = "snake_case")]
56pub enum TaskStatus {
57 Cancelled,
59 Completed,
61 Failed,
63 InputRequired,
65 Working,
67}
68
69impl core::fmt::Display for TaskStatus {
70 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
71 match self {
72 Self::Cancelled => f.write_str("cancelled"),
73 Self::Completed => f.write_str("completed"),
74 Self::Failed => f.write_str("failed"),
75 Self::InputRequired => f.write_str("input_required"),
76 Self::Working => f.write_str("working"),
77 }
78 }
79}
80
81#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
83pub struct CreateTaskResult {
84 pub task: Task,
86 #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
88 pub meta: Option<HashMap<String, Value>>,
89}
90
91#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
93pub struct ListTasksResult {
94 pub tasks: Vec<Task>,
96 #[serde(rename = "nextCursor", skip_serializing_if = "Option::is_none")]
98 pub next_cursor: Option<String>,
99 #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
101 pub meta: Option<HashMap<String, Value>>,
102}
103
104#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
108pub struct RelatedTaskMetadata {
109 #[serde(rename = "taskId")]
111 pub task_id: String,
112}
113
114#[derive(Debug, Clone, PartialEq)]
124pub enum ElicitRequestParams {
125 Form(ElicitRequestFormParams),
127 Url(ElicitRequestURLParams),
129}
130
131impl ElicitRequestParams {
132 #[must_use]
134 pub fn form(message: impl Into<String>, requested_schema: Value) -> Self {
135 Self::Form(ElicitRequestFormParams {
136 message: message.into(),
137 requested_schema,
138 task: None,
139 meta: None,
140 })
141 }
142
143 #[must_use]
145 pub fn url(
146 message: impl Into<String>,
147 url: impl Into<String>,
148 elicitation_id: impl Into<String>,
149 ) -> Self {
150 Self::Url(ElicitRequestURLParams {
151 message: message.into(),
152 url: url.into(),
153 elicitation_id: elicitation_id.into(),
154 task: None,
155 meta: None,
156 })
157 }
158
159 #[must_use]
161 pub fn message(&self) -> &str {
162 match self {
163 Self::Form(p) => &p.message,
164 Self::Url(p) => &p.message,
165 }
166 }
167
168 #[must_use]
170 pub fn task(&self) -> Option<&TaskMetadata> {
171 match self {
172 Self::Form(p) => p.task.as_ref(),
173 Self::Url(p) => p.task.as_ref(),
174 }
175 }
176
177 #[must_use]
179 pub fn meta(&self) -> Option<&HashMap<String, Value>> {
180 match self {
181 Self::Form(p) => p.meta.as_ref(),
182 Self::Url(p) => p.meta.as_ref(),
183 }
184 }
185}
186
187impl Serialize for ElicitRequestParams {
188 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
189 match self {
190 Self::Form(params) => {
191 let mut value = serde_json::to_value(params).map_err(serde::ser::Error::custom)?;
193 if let Some(obj) = value.as_object_mut() {
194 obj.insert("mode".into(), Value::String("form".into()));
195 }
196 value.serialize(serializer)
197 }
198 Self::Url(params) => {
199 let mut value = serde_json::to_value(params).map_err(serde::ser::Error::custom)?;
201 if let Some(obj) = value.as_object_mut() {
202 obj.insert("mode".into(), Value::String("url".into()));
203 }
204 value.serialize(serializer)
205 }
206 }
207 }
208}
209
210impl<'de> Deserialize<'de> for ElicitRequestParams {
211 fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
212 let value = Value::deserialize(deserializer)?;
213
214 match value.get("mode") {
215 None => {
216 let params: ElicitRequestFormParams =
217 serde_json::from_value(value).map_err(serde::de::Error::custom)?;
218 Ok(Self::Form(params))
219 }
220 Some(Value::String(mode)) if mode == "form" => {
221 let params: ElicitRequestFormParams =
222 serde_json::from_value(value).map_err(serde::de::Error::custom)?;
223 Ok(Self::Form(params))
224 }
225 Some(Value::String(mode)) if mode == "url" => {
226 let params: ElicitRequestURLParams =
227 serde_json::from_value(value).map_err(serde::de::Error::custom)?;
228 Ok(Self::Url(params))
229 }
230 Some(Value::String(mode)) => Err(serde::de::Error::custom(format!(
231 "unsupported elicitation mode `{mode}`"
232 ))),
233 Some(_) => Err(serde::de::Error::custom(
234 "elicitation mode must be a string when present",
235 )),
236 }
237 }
238}
239
240#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
242pub struct ElicitRequestFormParams {
243 pub message: String,
245 #[serde(rename = "requestedSchema")]
247 pub requested_schema: Value,
248 #[serde(skip_serializing_if = "Option::is_none")]
250 pub task: Option<TaskMetadata>,
251 #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
253 pub meta: Option<HashMap<String, Value>>,
254}
255
256#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
258pub struct ElicitRequestURLParams {
259 pub message: String,
261 pub url: String,
263 #[serde(rename = "elicitationId")]
265 pub elicitation_id: String,
266 #[serde(skip_serializing_if = "Option::is_none")]
268 pub task: Option<TaskMetadata>,
269 #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
271 pub meta: Option<HashMap<String, Value>>,
272}
273
274#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
276pub struct ElicitResult {
277 pub action: ElicitAction,
279 #[serde(skip_serializing_if = "Option::is_none")]
282 pub content: Option<Value>,
283 #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
285 pub meta: Option<HashMap<String, Value>>,
286}
287
288#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
290#[serde(rename_all = "lowercase")]
291pub enum ElicitAction {
292 Accept,
294 Decline,
296 Cancel,
298}
299
300impl core::fmt::Display for ElicitAction {
301 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
302 match self {
303 Self::Accept => f.write_str("accept"),
304 Self::Decline => f.write_str("decline"),
305 Self::Cancel => f.write_str("cancel"),
306 }
307 }
308}
309
310#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
314pub struct ElicitationCompleteNotification {
315 #[serde(rename = "elicitationId")]
317 pub elicitation_id: String,
318}
319
320#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
326pub struct CreateMessageRequest {
327 #[serde(default)]
329 pub messages: Vec<SamplingMessage>,
330 #[serde(rename = "maxTokens")]
332 pub max_tokens: u32,
333 #[serde(rename = "modelPreferences", skip_serializing_if = "Option::is_none")]
335 pub model_preferences: Option<ModelPreferences>,
336 #[serde(rename = "systemPrompt", skip_serializing_if = "Option::is_none")]
338 pub system_prompt: Option<String>,
339 #[serde(rename = "includeContext", skip_serializing_if = "Option::is_none")]
341 pub include_context: Option<IncludeContext>,
342 #[serde(skip_serializing_if = "Option::is_none")]
344 pub temperature: Option<f64>,
345 #[serde(rename = "stopSequences", skip_serializing_if = "Option::is_none")]
347 pub stop_sequences: Option<Vec<String>>,
348 #[serde(skip_serializing_if = "Option::is_none")]
350 pub task: Option<TaskMetadata>,
351 #[serde(skip_serializing_if = "Option::is_none")]
353 pub tools: Option<Vec<Tool>>,
354 #[serde(rename = "toolChoice", skip_serializing_if = "Option::is_none")]
356 pub tool_choice: Option<ToolChoice>,
357 #[serde(skip_serializing_if = "Option::is_none")]
359 pub metadata: Option<Value>,
360 #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
362 pub meta: Option<HashMap<String, Value>>,
363}
364
365#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
370pub struct SamplingMessage {
371 pub role: Role,
373 pub content: SamplingContentBlock,
375 #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
377 pub meta: Option<HashMap<String, Value>>,
378}
379
380impl SamplingMessage {
381 #[must_use]
383 pub fn user(text: impl Into<String>) -> Self {
384 Self {
385 role: Role::User,
386 content: SamplingContent::text(text).into(),
387 meta: None,
388 }
389 }
390
391 #[must_use]
393 pub fn assistant(text: impl Into<String>) -> Self {
394 Self {
395 role: Role::Assistant,
396 content: SamplingContent::text(text).into(),
397 meta: None,
398 }
399 }
400}
401
402#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
404pub struct ModelPreferences {
405 #[serde(skip_serializing_if = "Option::is_none")]
407 pub hints: Option<Vec<ModelHint>>,
408 #[serde(rename = "costPriority", skip_serializing_if = "Option::is_none")]
410 pub cost_priority: Option<f64>,
411 #[serde(rename = "speedPriority", skip_serializing_if = "Option::is_none")]
413 pub speed_priority: Option<f64>,
414 #[serde(
416 rename = "intelligencePriority",
417 skip_serializing_if = "Option::is_none"
418 )]
419 pub intelligence_priority: Option<f64>,
420}
421
422#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
426pub struct ModelHint {
427 #[serde(skip_serializing_if = "Option::is_none")]
429 pub name: Option<String>,
430}
431
432impl core::fmt::Display for IncludeContext {
433 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
434 match self {
435 Self::AllServers => f.write_str("allServers"),
436 Self::ThisServer => f.write_str("thisServer"),
437 Self::None => f.write_str("none"),
438 }
439 }
440}
441
442#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
446pub enum IncludeContext {
447 #[serde(rename = "allServers")]
449 AllServers,
450 #[serde(rename = "thisServer")]
452 ThisServer,
453 #[serde(rename = "none")]
455 None,
456}
457
458#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
462pub struct ToolChoice {
463 #[serde(skip_serializing_if = "Option::is_none")]
465 pub mode: Option<ToolChoiceMode>,
466}
467
468#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
470#[serde(rename_all = "lowercase")]
471pub enum ToolChoiceMode {
472 Auto,
474 None,
476 Required,
478}
479
480impl core::fmt::Display for ToolChoiceMode {
481 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
482 match self {
483 Self::Auto => f.write_str("auto"),
484 Self::None => f.write_str("none"),
485 Self::Required => f.write_str("required"),
486 }
487 }
488}
489
490#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
495pub struct CreateMessageResult {
496 pub role: Role,
498 pub content: SamplingContentBlock,
500 pub model: String,
502 #[serde(rename = "stopReason", skip_serializing_if = "Option::is_none")]
504 pub stop_reason: Option<String>,
505 #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
507 pub meta: Option<HashMap<String, Value>>,
508}
509
510#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
518pub struct ClientCapabilities {
519 #[serde(skip_serializing_if = "Option::is_none")]
521 pub roots: Option<RootsCapabilities>,
522 #[serde(skip_serializing_if = "Option::is_none")]
524 pub sampling: Option<SamplingCapabilities>,
525 #[serde(skip_serializing_if = "Option::is_none")]
527 pub elicitation: Option<ElicitationCapabilities>,
528 #[serde(skip_serializing_if = "Option::is_none")]
533 pub tasks: Option<ClientTasksCapabilities>,
534 #[serde(skip_serializing_if = "Option::is_none")]
536 pub extensions: Option<HashMap<String, Value>>,
537 #[serde(skip_serializing_if = "Option::is_none")]
539 pub experimental: Option<HashMap<String, Value>>,
540}
541
542#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
551pub struct ElicitationCapabilities {
552 #[serde(skip_serializing_if = "Option::is_none")]
556 pub form: Option<ElicitationFormCapabilities>,
557 #[serde(skip_serializing_if = "Option::is_none")]
561 pub url: Option<ElicitationUrlCapabilities>,
562 #[serde(rename = "schemaValidation", skip_serializing_if = "Option::is_none")]
567 pub schema_validation: Option<bool>,
568}
569
570impl ElicitationCapabilities {
571 #[must_use]
573 pub fn full() -> Self {
574 Self {
575 form: Some(ElicitationFormCapabilities {}),
576 url: Some(ElicitationUrlCapabilities {}),
577 schema_validation: None,
578 }
579 }
580
581 #[must_use]
583 pub fn form_only() -> Self {
584 Self {
585 form: Some(ElicitationFormCapabilities {}),
586 url: None,
587 schema_validation: None,
588 }
589 }
590
591 #[must_use]
595 pub fn supports_form(&self) -> bool {
596 self.form.is_some() || (self.form.is_none() && self.url.is_none())
597 }
598
599 #[must_use]
601 pub fn supports_url(&self) -> bool {
602 self.url.is_some()
603 }
604
605 #[must_use]
607 pub fn with_schema_validation(mut self) -> Self {
608 self.schema_validation = Some(true);
609 self
610 }
611
612 #[must_use]
614 pub fn without_schema_validation(mut self) -> Self {
615 self.schema_validation = Some(false);
616 self
617 }
618}
619
620#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
622pub struct ElicitationFormCapabilities {}
623
624#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
626pub struct ElicitationUrlCapabilities {}
627
628#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
632pub struct SamplingCapabilities {
633 #[serde(skip_serializing_if = "Option::is_none")]
635 pub context: Option<HashMap<String, Value>>,
636 #[serde(skip_serializing_if = "Option::is_none")]
638 pub tools: Option<HashMap<String, Value>>,
639}
640
641#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
643pub struct RootsCapabilities {
644 #[serde(rename = "listChanged", skip_serializing_if = "Option::is_none")]
646 pub list_changed: Option<bool>,
647}
648
649#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
653pub struct ClientTasksCapabilities {
654 #[serde(skip_serializing_if = "Option::is_none")]
656 pub list: Option<TasksListCapabilities>,
657 #[serde(skip_serializing_if = "Option::is_none")]
659 pub cancel: Option<TasksCancelCapabilities>,
660 #[serde(skip_serializing_if = "Option::is_none")]
662 pub requests: Option<ClientTasksRequestsCapabilities>,
663}
664
665#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
667pub struct ClientTasksRequestsCapabilities {
668 #[serde(skip_serializing_if = "Option::is_none")]
670 pub sampling: Option<TasksSamplingCapabilities>,
671 #[serde(skip_serializing_if = "Option::is_none")]
673 pub elicitation: Option<TasksElicitationCapabilities>,
674}
675
676#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
678pub struct TasksSamplingCapabilities {
679 #[serde(rename = "createMessage", skip_serializing_if = "Option::is_none")]
681 pub create_message: Option<TasksSamplingCreateMessageCapabilities>,
682}
683
684#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
686pub struct TasksSamplingCreateMessageCapabilities {}
687
688#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
690pub struct TasksElicitationCapabilities {
691 #[serde(skip_serializing_if = "Option::is_none")]
693 pub create: Option<TasksElicitationCreateCapabilities>,
694}
695
696#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
698pub struct TasksElicitationCreateCapabilities {}
699
700#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
705pub struct ServerCapabilities {
706 #[serde(skip_serializing_if = "Option::is_none")]
708 pub tools: Option<ToolsCapabilities>,
709 #[serde(skip_serializing_if = "Option::is_none")]
711 pub resources: Option<ResourcesCapabilities>,
712 #[serde(skip_serializing_if = "Option::is_none")]
714 pub prompts: Option<PromptsCapabilities>,
715 #[serde(skip_serializing_if = "Option::is_none")]
717 pub logging: Option<LoggingCapabilities>,
718 #[serde(skip_serializing_if = "Option::is_none")]
720 pub completions: Option<CompletionCapabilities>,
721 #[serde(skip_serializing_if = "Option::is_none")]
726 pub tasks: Option<ServerTasksCapabilities>,
727 #[serde(skip_serializing_if = "Option::is_none")]
729 pub extensions: Option<HashMap<String, Value>>,
730 #[serde(skip_serializing_if = "Option::is_none")]
732 pub experimental: Option<HashMap<String, Value>>,
733}
734
735#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
737pub struct ToolsCapabilities {
738 #[serde(rename = "listChanged", skip_serializing_if = "Option::is_none")]
740 pub list_changed: Option<bool>,
741}
742
743#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
745pub struct ResourcesCapabilities {
746 #[serde(skip_serializing_if = "Option::is_none")]
748 pub subscribe: Option<bool>,
749 #[serde(rename = "listChanged", skip_serializing_if = "Option::is_none")]
751 pub list_changed: Option<bool>,
752}
753
754#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
756pub struct PromptsCapabilities {
757 #[serde(rename = "listChanged", skip_serializing_if = "Option::is_none")]
759 pub list_changed: Option<bool>,
760}
761
762#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
764pub struct LoggingCapabilities {}
765
766#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
768pub struct CompletionCapabilities {}
769
770#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
774pub struct ServerTasksCapabilities {
775 #[serde(skip_serializing_if = "Option::is_none")]
777 pub list: Option<TasksListCapabilities>,
778 #[serde(skip_serializing_if = "Option::is_none")]
780 pub cancel: Option<TasksCancelCapabilities>,
781 #[serde(skip_serializing_if = "Option::is_none")]
783 pub requests: Option<ServerTasksRequestsCapabilities>,
784}
785
786#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
788pub struct ServerTasksRequestsCapabilities {
789 #[serde(skip_serializing_if = "Option::is_none")]
791 pub tools: Option<TasksToolsCapabilities>,
792}
793
794#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
796pub struct TasksToolsCapabilities {
797 #[serde(skip_serializing_if = "Option::is_none")]
799 pub call: Option<TasksToolsCallCapabilities>,
800}
801
802#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
804pub struct TasksToolsCallCapabilities {}
805
806#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
808pub struct TasksListCapabilities {}
809
810#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
812pub struct TasksCancelCapabilities {}
813
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_include_context_serde() {
        let json = serde_json::to_string(&IncludeContext::ThisServer).unwrap();
        assert_eq!(json, "\"thisServer\"");

        let json = serde_json::to_string(&IncludeContext::AllServers).unwrap();
        assert_eq!(json, "\"allServers\"");

        let json = serde_json::to_string(&IncludeContext::None).unwrap();
        assert_eq!(json, "\"none\"");

        let parsed: IncludeContext = serde_json::from_str("\"thisServer\"").unwrap();
        assert_eq!(parsed, IncludeContext::ThisServer);
    }

    #[test]
    fn test_tool_choice_mode_optional() {
        let tc = ToolChoice { mode: None };
        let json = serde_json::to_string(&tc).unwrap();
        assert_eq!(json, "{}");

        let tc = ToolChoice {
            mode: Some(ToolChoiceMode::Required),
        };
        let json = serde_json::to_string(&tc).unwrap();
        assert!(json.contains("\"required\""));
    }

    #[test]
    fn test_model_hint_name_optional() {
        let hint = ModelHint { name: None };
        let json = serde_json::to_string(&hint).unwrap();
        assert_eq!(json, "{}");

        let hint = ModelHint {
            name: Some("claude".into()),
        };
        let json = serde_json::to_string(&hint).unwrap();
        assert!(json.contains("\"claude\""));
    }

    #[test]
    fn test_task_status_serde() {
        let json = serde_json::to_string(&TaskStatus::InputRequired).unwrap();
        assert_eq!(json, "\"input_required\"");

        let json = serde_json::to_string(&TaskStatus::Working).unwrap();
        assert_eq!(json, "\"working\"");
    }

    #[test]
    fn test_create_message_request_default() {
        let req = CreateMessageRequest {
            messages: vec![SamplingMessage::user("hello")],
            max_tokens: 100,
            ..Default::default()
        };
        assert_eq!(req.messages.len(), 1);
        assert_eq!(req.max_tokens, 100);
        assert!(req.tools.is_none());
    }

    #[test]
    fn test_sampling_message_content_single_or_array() {
        let msg = SamplingMessage::user("hello");
        let json = serde_json::to_string(&msg).unwrap();
        assert!(json.contains("\"text\":\"hello\""));

        let parsed: SamplingMessage = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.content.as_text(), Some("hello"));

        let json_array = r#"{"role":"user","content":[{"type":"text","text":"hello"},{"type":"text","text":"world"}]}"#;
        let parsed: SamplingMessage = serde_json::from_str(json_array).unwrap();
        match &parsed.content {
            SamplingContentBlock::Multiple(v) => assert_eq!(v.len(), 2),
            _ => panic!("Expected multiple content blocks"),
        }
    }

    #[test]
    fn test_server_capabilities_structure() {
        let caps = ServerCapabilities {
            tasks: Some(ServerTasksCapabilities {
                list: Some(TasksListCapabilities {}),
                cancel: Some(TasksCancelCapabilities {}),
                requests: Some(ServerTasksRequestsCapabilities {
                    tools: Some(TasksToolsCapabilities {
                        call: Some(TasksToolsCallCapabilities {}),
                    }),
                }),
            }),
            extensions: Some(HashMap::from([(
                "trace".to_string(),
                serde_json::json!({"version": "1"}),
            )])),
            ..Default::default()
        };
        let json = serde_json::to_string(&caps).unwrap();
        let v: Value = serde_json::from_str(&json).unwrap();
        assert!(v["tasks"]["requests"]["tools"]["call"].is_object());
        assert!(v["extensions"]["trace"].is_object());
    }

    #[test]
    fn test_elicit_action_serde() {
        let cases = [
            (ElicitAction::Accept, "\"accept\""),
            (ElicitAction::Decline, "\"decline\""),
            (ElicitAction::Cancel, "\"cancel\""),
        ];
        for (action, expected) in cases {
            let json = serde_json::to_string(&action).unwrap();
            assert_eq!(json, expected);
            let parsed: ElicitAction = serde_json::from_str(expected).unwrap();
            assert_eq!(parsed, action);
        }
    }

    #[test]
    fn test_elicit_result_round_trip() {
        let result = ElicitResult {
            action: ElicitAction::Accept,
            content: Some(serde_json::json!({"name": "test"})),
            meta: None,
        };
        let json = serde_json::to_string(&result).unwrap();
        let parsed: ElicitResult = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.action, ElicitAction::Accept);
        assert!(parsed.content.is_some());

        let decline = ElicitResult {
            action: ElicitAction::Decline,
            content: None,
            meta: None,
        };
        let json = serde_json::to_string(&decline).unwrap();
        assert!(!json.contains("\"content\""));
        let parsed: ElicitResult = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.action, ElicitAction::Decline);
        assert!(parsed.content.is_none());
    }

    #[test]
    fn test_server_capabilities_no_elicitation_or_sampling() {
        let caps = ServerCapabilities::default();
        let json = serde_json::to_string(&caps).unwrap();
        assert!(!json.contains("elicitation"));
        assert!(!json.contains("sampling"));

        let caps = ServerCapabilities {
            tools: Some(ToolsCapabilities {
                list_changed: Some(true),
            }),
            resources: Some(ResourcesCapabilities {
                subscribe: Some(true),
                list_changed: Some(true),
            }),
            prompts: Some(PromptsCapabilities {
                list_changed: Some(true),
            }),
            logging: Some(LoggingCapabilities {}),
            completions: Some(CompletionCapabilities {}),
            tasks: Some(ServerTasksCapabilities::default()),
            extensions: Some(HashMap::from([(
                "trace".to_string(),
                serde_json::json!({"version": "1"}),
            )])),
            experimental: Some(HashMap::new()),
        };
        let json = serde_json::to_string(&caps).unwrap();
        assert!(!json.contains("elicitation"));
        assert!(!json.contains("sampling"));
        assert!(json.contains("extensions"));
    }

    #[test]
    fn test_sampling_message_array_content_round_trip() {
        let json_array =
            r#"{"role":"user","content":[{"type":"text","text":"a"},{"type":"text","text":"b"}]}"#;
        let parsed: SamplingMessage = serde_json::from_str(json_array).unwrap();
        let re_serialized = serde_json::to_string(&parsed).unwrap();
        let re_parsed: Value = serde_json::from_str(&re_serialized).unwrap();
        assert!(re_parsed["content"].is_array());
        assert_eq!(re_parsed["content"].as_array().unwrap().len(), 2);
    }

    #[test]
    fn test_tool_choice_mode_all_variants() {
        let cases = [
            (ToolChoiceMode::Auto, "\"auto\""),
            (ToolChoiceMode::None, "\"none\""),
            (ToolChoiceMode::Required, "\"required\""),
        ];
        for (mode, expected) in cases {
            let json = serde_json::to_string(&mode).unwrap();
            assert_eq!(json, expected);
            let parsed: ToolChoiceMode = serde_json::from_str(expected).unwrap();
            assert_eq!(parsed, mode);
        }
    }

    #[test]
    fn test_elicit_request_params_form_without_mode() {
        let json = r#"{"message":"Enter name","requestedSchema":{"type":"object"}}"#;
        let parsed: ElicitRequestParams = serde_json::from_str(json).unwrap();
        match &parsed {
            ElicitRequestParams::Form(params) => {
                assert_eq!(params.message, "Enter name");
            }
            ElicitRequestParams::Url(_) => panic!("expected Form variant"),
        }
    }

    #[test]
    fn test_elicit_request_params_form_with_explicit_mode() {
        let json = r#"{"mode":"form","message":"Enter name","requestedSchema":{"type":"object"}}"#;
        let parsed: ElicitRequestParams = serde_json::from_str(json).unwrap();
        match &parsed {
            ElicitRequestParams::Form(params) => {
                assert_eq!(params.message, "Enter name");
            }
            ElicitRequestParams::Url(_) => panic!("expected Form variant"),
        }
    }

    #[test]
    fn test_elicit_request_params_url_mode() {
        let json = r#"{"mode":"url","message":"Authenticate","url":"https://example.com/auth","elicitationId":"e-123"}"#;
        let parsed: ElicitRequestParams = serde_json::from_str(json).unwrap();
        match &parsed {
            ElicitRequestParams::Url(params) => {
                assert_eq!(params.message, "Authenticate");
                assert_eq!(params.url, "https://example.com/auth");
                assert_eq!(params.elicitation_id, "e-123");
            }
            ElicitRequestParams::Form(_) => panic!("expected Url variant"),
        }
    }

    #[test]
    fn test_elicit_request_params_rejects_unknown_mode() {
        let json =
            r#"{"mode":"unknown","message":"Enter name","requestedSchema":{"type":"object"}}"#;
        let err = serde_json::from_str::<ElicitRequestParams>(json).unwrap_err();
        assert!(err.to_string().contains("unsupported elicitation mode"));
    }

    #[test]
    fn test_elicit_request_params_rejects_non_string_mode() {
        let json = r#"{"mode":true,"message":"Enter name","requestedSchema":{"type":"object"}}"#;
        let err = serde_json::from_str::<ElicitRequestParams>(json).unwrap_err();
        assert!(err.to_string().contains("mode must be a string"));
    }

    #[test]
    fn test_elicit_request_params_form_round_trip() {
        let params = ElicitRequestParams::Form(ElicitRequestFormParams {
            message: "Enter details".into(),
            requested_schema: serde_json::json!({"type": "object", "properties": {"name": {"type": "string"}}}),
            task: None,
            meta: None,
        });
        let json = serde_json::to_string(&params).unwrap();
        let v: Value = serde_json::from_str(&json).unwrap();
        assert_eq!(v["mode"], "form");
        let parsed: ElicitRequestParams = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, params);
    }

    #[test]
    fn test_elicit_request_params_url_round_trip() {
        let params = ElicitRequestParams::Url(ElicitRequestURLParams {
            message: "Please authenticate".into(),
            url: "https://example.com/oauth".into(),
            elicitation_id: "elicit-456".into(),
            task: None,
            meta: None,
        });
        let json = serde_json::to_string(&params).unwrap();
        let v: Value = serde_json::from_str(&json).unwrap();
        assert_eq!(v["mode"], "url");
        let parsed: ElicitRequestParams = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, params);
    }

    #[test]
    fn test_elicit_request_params_accessors() {
        let form = ElicitRequestParams::form("Name?", serde_json::json!({"type": "object"}));
        assert_eq!(form.message(), "Name?");
        assert!(form.task().is_none());
        assert!(form.meta().is_none());

        let url = ElicitRequestParams::url("Sign in", "https://example.com", "e-1");
        assert_eq!(url.message(), "Sign in");
        match url {
            ElicitRequestParams::Url(p) => assert_eq!(p.elicitation_id, "e-1"),
            ElicitRequestParams::Form(_) => panic!("expected Url variant"),
        }
    }

    #[test]
    fn test_task_status_all_variants() {
        let cases = [
            (TaskStatus::Cancelled, "\"cancelled\""),
            (TaskStatus::Completed, "\"completed\""),
            (TaskStatus::Failed, "\"failed\""),
            (TaskStatus::InputRequired, "\"input_required\""),
            (TaskStatus::Working, "\"working\""),
        ];
        for (status, expected) in cases {
            let json = serde_json::to_string(&status).unwrap();
            assert_eq!(json, expected);
            let parsed: TaskStatus = serde_json::from_str(expected).unwrap();
            assert_eq!(parsed, status);
        }
    }

    #[test]
    fn test_task_serializes_null_ttl_and_omits_absent_optionals() {
        let task = Task {
            task_id: "task-1".into(),
            status: TaskStatus::Working,
            status_message: None,
            created_at: "2025-01-01T00:00:00Z".into(),
            last_updated_at: "2025-01-01T00:00:00Z".into(),
            ttl: None,
            poll_interval: None,
        };
        let v: Value = serde_json::to_value(&task).unwrap();
        let obj = v.as_object().unwrap();
        // `ttl` is always emitted (as null when unset); the other optionals are skipped.
        assert!(obj.contains_key("ttl"));
        assert!(v["ttl"].is_null());
        assert!(!obj.contains_key("statusMessage"));
        assert!(!obj.contains_key("pollInterval"));
        assert_eq!(v["taskId"], "task-1");
        let parsed: Task = serde_json::from_value(v.clone()).unwrap();
        assert_eq!(parsed, task);
    }

    #[test]
    fn test_display_matches_serde_wire_format() {
        fn wire<T: Serialize>(value: &T) -> String {
            serde_json::to_value(value)
                .unwrap()
                .as_str()
                .expect("unit variants serialize to strings")
                .to_string()
        }

        for status in [
            TaskStatus::Cancelled,
            TaskStatus::Completed,
            TaskStatus::Failed,
            TaskStatus::InputRequired,
            TaskStatus::Working,
        ] {
            assert_eq!(status.to_string(), wire(&status));
        }
        for action in [ElicitAction::Accept, ElicitAction::Decline, ElicitAction::Cancel] {
            assert_eq!(action.to_string(), wire(&action));
        }
        for ctx in [
            IncludeContext::AllServers,
            IncludeContext::ThisServer,
            IncludeContext::None,
        ] {
            assert_eq!(ctx.to_string(), wire(&ctx));
        }
        for mode in [
            ToolChoiceMode::Auto,
            ToolChoiceMode::None,
            ToolChoiceMode::Required,
        ] {
            assert_eq!(mode.to_string(), wire(&mode));
        }
    }

    #[test]
    fn test_elicitation_capabilities_mode_support() {
        let empty = ElicitationCapabilities::default();
        assert!(empty.supports_form());
        assert!(!empty.supports_url());

        let url_only = ElicitationCapabilities {
            url: Some(ElicitationUrlCapabilities {}),
            ..Default::default()
        };
        assert!(!url_only.supports_form());
        assert!(url_only.supports_url());

        let form_only = ElicitationCapabilities::form_only();
        assert!(form_only.supports_form());
        assert!(!form_only.supports_url());

        let full = ElicitationCapabilities::full().with_schema_validation();
        assert!(full.supports_form() && full.supports_url());
        assert_eq!(full.schema_validation, Some(true));
        let json = serde_json::to_value(&full).unwrap();
        assert_eq!(json["schemaValidation"], true);
    }
}