1use serde::{Deserialize, Serialize};
11use std::fmt;
12use thiserror::Error;
13
14#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
16#[serde(untagged)]
17pub enum InputValue {
18 String(String),
20 Bool(bool),
22 Number(i64),
24 Float(f64),
26}
27
28impl fmt::Display for InputValue {
29 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
30 match self {
31 InputValue::String(s) => write!(f, "{}", s),
32 InputValue::Bool(b) => write!(f, "{}", b),
33 InputValue::Number(n) => write!(f, "{}", n),
34 InputValue::Float(fl) => write!(f, "{}", fl),
35 }
36 }
37}
38
39impl From<String> for InputValue {
40 fn from(s: String) -> Self {
41 InputValue::String(s)
42 }
43}
44
45impl From<&str> for InputValue {
46 fn from(s: &str) -> Self {
47 InputValue::String(s.to_string())
48 }
49}
50
51impl From<bool> for InputValue {
52 fn from(b: bool) -> Self {
53 InputValue::Bool(b)
54 }
55}
56
57impl From<i64> for InputValue {
58 fn from(n: i64) -> Self {
59 InputValue::Number(n)
60 }
61}
62
63impl From<f64> for InputValue {
64 fn from(f: f64) -> Self {
65 InputValue::Float(f)
66 }
67}
68
69#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
71pub struct InputRequest {
72 pub id: String,
74 pub input_type: InputType,
76 pub title: String,
78 pub description: String,
80 #[serde(skip_serializing_if = "Option::is_none")]
82 pub default: Option<InputValue>,
83 #[serde(default)]
85 pub required: bool,
86 #[serde(skip_serializing_if = "Option::is_none")]
88 pub validation: Option<ValidationRule>,
89}
90
91#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
93#[serde(tag = "type")]
94pub enum InputType {
95 String {
97 #[serde(skip_serializing_if = "Option::is_none")]
99 password: Option<bool>,
100 #[serde(skip_serializing_if = "Option::is_none")]
102 min_length: Option<usize>,
103 #[serde(skip_serializing_if = "Option::is_none")]
105 max_length: Option<usize>,
106 },
107 PickString {
109 options: Vec<String>,
111 #[serde(default)]
113 multiple: bool,
114 },
115 Number {
117 #[serde(skip_serializing_if = "Option::is_none")]
119 min: Option<i64>,
120 #[serde(skip_serializing_if = "Option::is_none")]
122 max: Option<i64>,
123 },
124 Bool {
126 #[serde(skip_serializing_if = "Option::is_none")]
128 true_label: Option<String>,
129 #[serde(skip_serializing_if = "Option::is_none")]
131 false_label: Option<String>,
132 },
133 FilePath {
135 #[serde(default)]
137 must_exist: bool,
138 #[serde(skip_serializing_if = "Option::is_none")]
140 filter: Option<String>,
141 },
142 Command {
144 command: String,
146 #[serde(default)]
148 args: Vec<String>,
149 },
150}
151
152#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
154#[serde(tag = "type")]
155pub enum ValidationRule {
156 Regex {
158 pattern: String,
160 #[serde(skip_serializing_if = "Option::is_none")]
162 message: Option<String>,
163 },
164 Custom {
166 function: String,
168 #[serde(default)]
170 params: std::collections::HashMap<String, serde_json::Value>,
171 },
172}
173
174#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
176pub struct InputResponse {
177 pub id: String,
179 pub value: InputValue,
181 #[serde(default)]
183 pub cancelled: bool,
184}
185
186#[derive(Debug, Error)]
188pub enum InputError {
189 #[error("Invalid input type: {0}")]
191 InvalidType(String),
192 #[error("Validation failed: {0}")]
194 ValidationFailed(String),
195 #[error("Input cancelled")]
197 Cancelled,
198 #[error("IO error: {0}")]
200 IoError(#[from] std::io::Error),
201 #[error("Input timeout")]
203 Timeout,
204 #[error("Other error: {0}")]
206 Other(String),
207}
208
209pub type InputResult<T> = Result<T, InputError>;
211
/// Contextual information attached to input handling.
///
/// Every field is public. The `with_*` builder methods on this type are an
/// alternative way to fill it in.
#[derive(Debug, Clone)]
pub struct InputContext {
    /// Optional server name.
    pub server_name: Option<String>,
    /// Optional tool name.
    pub tool_name: Option<String>,
    /// Free-form string key/value pairs.
    pub metadata: std::collections::HashMap<String, String>,
}
222
223impl InputContext {
224 pub fn new() -> Self {
226 Self {
227 server_name: None,
228 tool_name: None,
229 metadata: std::collections::HashMap::new(),
230 }
231 }
232
233 pub fn with_server_name(mut self, name: String) -> Self {
235 self.server_name = Some(name);
236 self
237 }
238
239 pub fn with_tool_name(mut self, name: String) -> Self {
241 self.tool_name = Some(name);
242 self
243 }
244
245 pub fn with_metadata(mut self, key: String, value: String) -> Self {
247 self.metadata.insert(key, value);
248 self
249 }
250}
251
252impl Default for InputContext {
253 fn default() -> Self {
254 Self::new()
255 }
256}
257
#[cfg(test)]
mod tests {
    use super::*;

    /// Each variant's `to_string()` matches the inner value's text.
    #[test]
    fn test_input_value_display() {
        let cases = [
            (InputValue::String("test".to_string()), "test"),
            (InputValue::Bool(true), "true"),
            (InputValue::Number(42), "42"),
            (InputValue::Float(3.15), "3.15"),
        ];
        for (value, expected) in cases.iter() {
            assert_eq!(value.to_string(), *expected);
        }
    }

    /// Every `From` impl lands in the matching variant.
    #[test]
    fn test_input_value_conversions() {
        let converted: [InputValue; 4] = [
            "test".into(),
            true.into(),
            42i64.into(),
            std::f64::consts::PI.into(),
        ];
        let expected = [
            InputValue::String("test".to_string()),
            InputValue::Bool(true),
            InputValue::Number(42),
            InputValue::Float(std::f64::consts::PI),
        ];
        assert_eq!(converted, expected);
    }

    /// The builder methods populate every field and accumulate metadata.
    #[test]
    fn test_input_context() {
        let ctx = InputContext::new()
            .with_server_name("test_server".to_string())
            .with_tool_name("test_tool".to_string())
            .with_metadata("key1".to_string(), "value1".to_string())
            .with_metadata("key2".to_string(), "value2".to_string());

        assert_eq!(ctx.server_name.as_deref(), Some("test_server"));
        assert_eq!(ctx.tool_name.as_deref(), Some("test_tool"));
        assert_eq!(ctx.metadata.len(), 2);
        for (key, value) in [("key1", "value1"), ("key2", "value2")].iter() {
            assert_eq!(ctx.metadata.get(*key).map(String::as_str), Some(*value));
        }
    }
}