// stakpak_shared/models/tools/ask_user.rs

//! Ask User tool types — single source of truth for MCP schema, CLI, and TUI.
//!
//! These types carry both `serde` and `schemars` annotations so they can be
//! used directly in MCP tool definitions (schema generation) **and** for
//! runtime (de)serialization in the TUI / CLI.
6
7use rmcp::schemars;
8use serde::{Deserialize, Serialize};
9
// ---------------------------------------------------------------------------
// Request (LLM → tool)
// ---------------------------------------------------------------------------
13
14/// Request payload for the `ask_user` tool.
15#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, schemars::JsonSchema)]
16pub struct AskUserRequest {
17    #[schemars(
18        description = "List of questions to ask the user. Each question has a label, question text, and options."
19    )]
20    pub questions: Vec<AskUserQuestion>,
21}
22
23/// A single question presented to the user.
24#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, schemars::JsonSchema)]
25pub struct AskUserQuestion {
26    #[schemars(description = "Short unique label for tab display (max ~15 chars recommended)")]
27    pub label: String,
28    #[schemars(description = "Full question text to display")]
29    pub question: String,
30    #[schemars(description = "Predefined answer options")]
31    pub options: Vec<AskUserOption>,
32    /// Whether to allow custom text input (default: true)
33    #[serde(default = "default_true")]
34    #[schemars(description = "Whether to allow custom text input (default: true)")]
35    pub allow_custom: bool,
36    /// Whether this question must be answered (default: true)
37    #[serde(default = "default_true")]
38    #[schemars(description = "Whether this question must be answered (default: true)")]
39    pub required: bool,
40}
41
42/// A predefined answer option for a question.
43#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, schemars::JsonSchema)]
44pub struct AskUserOption {
45    #[schemars(description = "Value to return to LLM when selected")]
46    pub value: String,
47    #[schemars(description = "Display label for the option")]
48    pub label: String,
49    /// Optional description shown below the label.
50    #[serde(skip_serializing_if = "Option::is_none")]
51    #[schemars(description = "Optional description shown below the label")]
52    pub description: Option<String>,
53}
54
// ---------------------------------------------------------------------------
// Response (tool → LLM)
// ---------------------------------------------------------------------------
58
59/// User's answer to a single question.
60#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, schemars::JsonSchema)]
61pub struct AskUserAnswer {
62    /// Question label this answers.
63    pub question_label: String,
64    /// Selected option value OR custom text.
65    pub answer: String,
66    /// Whether this was a custom answer (typed by user).
67    pub is_custom: bool,
68}
69
70/// Aggregated result of the `ask_user` tool.
71#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, schemars::JsonSchema)]
72pub struct AskUserResult {
73    /// All answers provided by the user.
74    pub answers: Vec<AskUserAnswer>,
75    /// Whether the user completed all questions (false if cancelled).
76    pub completed: bool,
77    /// Reason for incompletion (if cancelled).
78    #[serde(skip_serializing_if = "Option::is_none")]
79    pub reason: Option<String>,
80}
81
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
85
/// Serde default provider for boolean fields that should be `true` when the
/// key is omitted from the input.
///
/// Referenced by path from `#[serde(default = "default_true")]`, which is why it
/// is a free function rather than a constant.
fn default_true() -> bool {
    true
}
89
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
93
#[cfg(test)]
mod tests {
    use super::*;

    /// A question serializes its fields and nested options; a `None`
    /// description never shows up as `null`.
    #[test]
    fn test_question_serialization() {
        let question = AskUserQuestion {
            label: "Environment".to_string(),
            question: "Which environment should I deploy to?".to_string(),
            options: vec![
                AskUserOption {
                    value: "dev".to_string(),
                    label: "Development".to_string(),
                    description: Some("For testing changes".to_string()),
                },
                AskUserOption {
                    value: "prod".to_string(),
                    label: "Production".to_string(),
                    description: None,
                },
            ],
            allow_custom: true,
            required: true,
        };

        let json = serde_json::to_string(&question).unwrap();
        assert!(json.contains("\"label\":\"Environment\""));
        assert!(json.contains("\"value\":\"dev\""));
        assert!(json.contains("\"description\":\"For testing changes\""));
        // description: None should be skipped
        assert!(!json.contains("\"description\":null"));
    }

    /// Omitting `allow_custom` and `required` yields `true` for both.
    #[test]
    fn test_question_deserialization_with_defaults() {
        let json = r#"{
            "label": "Test",
            "question": "Is this a test?",
            "options": []
        }"#;

        let question: AskUserQuestion = serde_json::from_str(json).unwrap();
        assert_eq!(question.label, "Test");
        assert!(question.allow_custom, "allow_custom should default to true");
        assert!(question.required, "required should default to true");
    }

    /// Explicit `false` values are honored rather than overridden by the defaults.
    #[test]
    fn test_question_deserialization_explicit_false() {
        let json = r#"{
            "label": "Test",
            "question": "Is this a test?",
            "options": [],
            "allow_custom": false,
            "required": false
        }"#;

        let question: AskUserQuestion = serde_json::from_str(json).unwrap();
        assert!(!question.allow_custom);
        assert!(!question.required);
    }

    /// An answer picked from the predefined options serializes all three fields.
    #[test]
    fn test_answer_serialization() {
        let answer = AskUserAnswer {
            question_label: "Environment".to_string(),
            answer: "production".to_string(),
            is_custom: false,
        };

        let json = serde_json::to_string(&answer).unwrap();
        assert!(json.contains("\"question_label\":\"Environment\""));
        assert!(json.contains("\"answer\":\"production\""));
        assert!(json.contains("\"is_custom\":false"));
    }

    /// A free-text answer carries `is_custom: true` and the typed text.
    #[test]
    fn test_answer_custom_input() {
        let answer = AskUserAnswer {
            question_label: "Feedback".to_string(),
            answer: "User typed this custom response".to_string(),
            is_custom: true,
        };

        let json = serde_json::to_string(&answer).unwrap();
        assert!(json.contains("\"is_custom\":true"));
        assert!(json.contains("User typed this custom response"));
    }

    /// A completed result lists every answer and omits the `reason` key.
    #[test]
    fn test_result_completed() {
        let result = AskUserResult {
            answers: vec![
                AskUserAnswer {
                    question_label: "q1".to_string(),
                    answer: "a1".to_string(),
                    is_custom: false,
                },
                AskUserAnswer {
                    question_label: "q2".to_string(),
                    answer: "custom answer".to_string(),
                    is_custom: true,
                },
            ],
            completed: true,
            reason: None,
        };

        let json = serde_json::to_string(&result).unwrap();
        assert!(json.contains("\"completed\":true"));
        // reason: None should be skipped
        assert!(!json.contains("\"reason\""));
        assert!(json.contains("\"question_label\":\"q1\""));
        assert!(json.contains("\"question_label\":\"q2\""));
    }

    /// A cancelled result serializes the reason and an empty answer list.
    #[test]
    fn test_result_cancelled() {
        let result = AskUserResult {
            answers: vec![],
            completed: false,
            reason: Some("User cancelled the question prompt.".to_string()),
        };

        let json = serde_json::to_string(&result).unwrap();
        assert!(json.contains("\"completed\":false"));
        assert!(json.contains("\"reason\":\"User cancelled the question prompt.\""));
        assert!(json.contains("\"answers\":[]"));
    }

    /// A result without a `reason` key deserializes with `reason == None`.
    #[test]
    fn test_result_deserialization() {
        let json = r#"{
            "answers": [
                {"question_label": "env", "answer": "dev", "is_custom": false}
            ],
            "completed": true
        }"#;

        let result: AskUserResult = serde_json::from_str(json).unwrap();
        assert!(result.completed);
        assert!(result.reason.is_none());
        assert_eq!(result.answers.len(), 1);
        assert_eq!(result.answers[0].question_label, "env");
        assert_eq!(result.answers[0].answer, "dev");
        assert!(!result.answers[0].is_custom);
    }

    /// An option with no description drops the key from its JSON entirely.
    #[test]
    fn test_option_without_description() {
        let option = AskUserOption {
            value: "yes".to_string(),
            label: "Yes".to_string(),
            description: None,
        };

        let json = serde_json::to_string(&option).unwrap();
        // description should be omitted entirely when None
        assert!(!json.contains("description"));
        assert!(json.contains("\"value\":\"yes\""));
        assert!(json.contains("\"label\":\"Yes\""));
    }

    /// Non-ASCII text (CJK, emoji) survives a serialize/deserialize round trip.
    #[test]
    fn test_unicode_handling() {
        let question = AskUserQuestion {
            label: "言語".to_string(),
            question: "どの言語を使用しますか?".to_string(),
            options: vec![
                AskUserOption {
                    value: "ja".to_string(),
                    label: "日本語".to_string(),
                    description: Some("Japanese language".to_string()),
                },
                AskUserOption {
                    value: "emoji".to_string(),
                    label: "🚀 Rocket".to_string(),
                    description: Some("With emoji 🎉".to_string()),
                },
            ],
            allow_custom: true,
            required: true,
        };

        let json = serde_json::to_string(&question).unwrap();
        let parsed: AskUserQuestion = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.label, "言語");
        assert_eq!(parsed.question, "どの言語を使用しますか?");
        assert_eq!(parsed.options[0].label, "日本語");
        assert_eq!(parsed.options[1].label, "🚀 Rocket");
    }

    /// `Clone` produces values that compare equal under the derived `PartialEq`.
    #[test]
    fn test_types_equality() {
        let q1 = AskUserQuestion {
            label: "Test".to_string(),
            question: "Question?".to_string(),
            options: vec![],
            allow_custom: true,
            required: true,
        };

        let q2 = q1.clone();
        assert_eq!(q1, q2);

        let a1 = AskUserAnswer {
            question_label: "Test".to_string(),
            answer: "answer".to_string(),
            is_custom: false,
        };

        let a2 = a1.clone();
        assert_eq!(a1, a2);

        let r1 = AskUserResult {
            answers: vec![a1],
            completed: true,
            reason: None,
        };

        let r2 = r1.clone();
        assert_eq!(r1, r2);
    }

    /// A full request survives a JSON round trip unchanged.
    #[test]
    fn test_request_round_trip() {
        let request = AskUserRequest {
            questions: vec![AskUserQuestion {
                label: "Env".to_string(),
                question: "Which env?".to_string(),
                options: vec![AskUserOption {
                    value: "dev".to_string(),
                    label: "Dev".to_string(),
                    description: None,
                }],
                allow_custom: false,
                required: true,
            }],
        };

        let json = serde_json::to_string(&request).unwrap();
        let parsed: AskUserRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(request, parsed);
    }
}