// stakpak_shared/models/tools/ask_user.rs

//! Ask User tool types — single source of truth for MCP schema, CLI, and TUI.
//!
//! These types carry both `serde` and `schemars` annotations so they can be
//! used directly in MCP tool definitions (schema generation) **and** for
//! runtime (de)serialization in the TUI / CLI.

7use rmcp::schemars;
8use serde::{Deserialize, Serialize};

// ---------------------------------------------------------------------------
// Request (LLM → tool)
// ---------------------------------------------------------------------------

14/// Request payload for the `ask_user` tool.
15#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, schemars::JsonSchema)]
16pub struct AskUserRequest {
17    #[schemars(
18        description = "List of questions to ask the user. Each question has a label, question text, and options."
19    )]
20    pub questions: Vec<AskUserQuestion>,
21}
22
23/// A single question presented to the user.
24#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, schemars::JsonSchema)]
25pub struct AskUserQuestion {
26    #[schemars(description = "Short unique label for tab display (max ~15 chars recommended)")]
27    pub label: String,
28    #[schemars(description = "Full question text to display")]
29    pub question: String,
30    #[schemars(description = "Predefined answer options")]
31    pub options: Vec<AskUserOption>,
32    /// Whether to allow custom text input (default: true)
33    #[serde(default = "default_true")]
34    #[schemars(description = "Whether to allow custom text input (default: true)")]
35    pub allow_custom: bool,
36    /// Whether this question must be answered (default: true)
37    #[serde(default = "default_true")]
38    #[schemars(description = "Whether this question must be answered (default: true)")]
39    pub required: bool,
40    /// When true, user can select multiple options (checkbox list). Default: false (single-select).
41    #[serde(default)]
42    #[schemars(
43        description = "When true, user can select/deselect multiple options (checkbox list). Default: false (single-select radio behavior)."
44    )]
45    pub multi_select: bool,
46}
47
48/// A predefined answer option for a question.
49#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, schemars::JsonSchema)]
50pub struct AskUserOption {
51    #[schemars(description = "Value to return to LLM when selected")]
52    pub value: String,
53    #[schemars(description = "Display label for the option")]
54    pub label: String,
55    /// Optional description shown below the label.
56    #[serde(skip_serializing_if = "Option::is_none")]
57    #[schemars(description = "Optional description shown below the label")]
58    pub description: Option<String>,
59    /// Default selection state for multi_select questions. Ignored for single-select.
60    #[serde(default)]
61    #[schemars(
62        description = "Default selection state when multi_select is true. Pre-marks this option as selected. Ignored for single-select questions."
63    )]
64    pub selected: bool,
65}
66
// ---------------------------------------------------------------------------
// Response (tool → LLM)
// ---------------------------------------------------------------------------

71/// User's answer to a single question.
72#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, schemars::JsonSchema)]
73pub struct AskUserAnswer {
74    /// Question label this answers.
75    pub question_label: String,
76    /// Selected option value OR custom text (for single-select questions).
77    /// For multi-select questions this is a JSON array string of selected values.
78    pub answer: String,
79    /// Whether this was a custom answer (typed by user).
80    pub is_custom: bool,
81    /// Selected values for multi-select questions. Empty/absent for single-select.
82    #[serde(default, skip_serializing_if = "Vec::is_empty")]
83    pub selected_values: Vec<String>,
84}
85
86/// Aggregated result of the `ask_user` tool.
87#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, schemars::JsonSchema)]
88pub struct AskUserResult {
89    /// All answers provided by the user.
90    pub answers: Vec<AskUserAnswer>,
91    /// Whether the user completed all questions (false if cancelled).
92    pub completed: bool,
93    /// Reason for incompletion (if cancelled).
94    #[serde(skip_serializing_if = "Option::is_none")]
95    pub reason: Option<String>,
96}
97
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------

/// Serde `default` provider for `bool` fields that must be `true` when their
/// key is absent; a bare `#[serde(default)]` would fall back to `false`.
const fn default_true() -> bool {
    true
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    // Serialization-contract tests for the `ask_user` types: field defaults,
    // keys omitted when empty/None, backward-compatible parsing, full round
    // trips, and the generated JSON schema.
    use super::*;

    #[test]
    fn test_question_serialization() {
        let question = AskUserQuestion {
            label: "Environment".to_string(),
            question: "Which environment should I deploy to?".to_string(),
            options: vec![
                AskUserOption {
                    value: "dev".to_string(),
                    label: "Development".to_string(),
                    description: Some("For testing changes".to_string()),
                    selected: false,
                },
                AskUserOption {
                    value: "prod".to_string(),
                    label: "Production".to_string(),
                    description: None,
                    selected: false,
                },
            ],
            allow_custom: true,
            required: true,
            multi_select: false,
        };

        let json = serde_json::to_string(&question).unwrap();
        assert!(json.contains("\"label\":\"Environment\""));
        assert!(json.contains("\"value\":\"dev\""));
        assert!(json.contains("\"description\":\"For testing changes\""));
        // description: None should be skipped entirely, not written as null.
        assert!(!json.contains("\"description\":null"));
        // Only the option that has a description may emit the key.
        assert_eq!(json.matches("\"description\"").count(), 1);
    }

    #[test]
    fn test_question_deserialization_with_defaults() {
        let json = r#"{
            "label": "Test",
            "question": "Is this a test?",
            "options": []
        }"#;

        let question: AskUserQuestion = serde_json::from_str(json).unwrap();
        assert_eq!(question.label, "Test");
        assert!(question.options.is_empty());
        assert!(question.allow_custom, "allow_custom should default to true");
        assert!(question.required, "required should default to true");
        assert!(!question.multi_select, "multi_select should default to false");
    }

    #[test]
    fn test_question_deserialization_explicit_false() {
        let json = r#"{
            "label": "Test",
            "question": "Is this a test?",
            "options": [],
            "allow_custom": false,
            "required": false
        }"#;

        let question: AskUserQuestion = serde_json::from_str(json).unwrap();
        assert!(!question.allow_custom);
        assert!(!question.required);
    }

    #[test]
    fn test_answer_serialization() {
        let answer = AskUserAnswer {
            question_label: "Environment".to_string(),
            answer: "production".to_string(),
            is_custom: false,
            selected_values: vec![],
        };

        let json = serde_json::to_string(&answer).unwrap();
        assert!(json.contains("\"question_label\":\"Environment\""));
        assert!(json.contains("\"answer\":\"production\""));
        assert!(json.contains("\"is_custom\":false"));
    }

    #[test]
    fn test_answer_custom_input() {
        let answer = AskUserAnswer {
            question_label: "Feedback".to_string(),
            answer: "User typed this custom response".to_string(),
            is_custom: true,
            selected_values: vec![],
        };

        let json = serde_json::to_string(&answer).unwrap();
        assert!(json.contains("\"is_custom\":true"));
        assert!(json.contains("User typed this custom response"));
    }

    #[test]
    fn test_result_completed() {
        let result = AskUserResult {
            answers: vec![
                AskUserAnswer {
                    question_label: "q1".to_string(),
                    answer: "a1".to_string(),
                    is_custom: false,
                    selected_values: vec![],
                },
                AskUserAnswer {
                    question_label: "q2".to_string(),
                    answer: "custom answer".to_string(),
                    is_custom: true,
                    selected_values: vec![],
                },
            ],
            completed: true,
            reason: None,
        };

        let json = serde_json::to_string(&result).unwrap();
        assert!(json.contains("\"completed\":true"));
        // reason: None should be skipped
        assert!(!json.contains("\"reason\""));
        assert!(json.contains("\"question_label\":\"q1\""));
        assert!(json.contains("\"question_label\":\"q2\""));
    }

    #[test]
    fn test_result_cancelled() {
        let result = AskUserResult {
            answers: vec![],
            completed: false,
            reason: Some("User cancelled the question prompt.".to_string()),
        };

        let json = serde_json::to_string(&result).unwrap();
        assert!(json.contains("\"completed\":false"));
        assert!(json.contains("\"reason\":\"User cancelled the question prompt.\""));
        assert!(json.contains("\"answers\":[]"));
    }

    #[test]
    fn test_result_deserialization() {
        let json = r#"{
            "answers": [
                {"question_label": "env", "answer": "dev", "is_custom": false}
            ],
            "completed": true
        }"#;

        let result: AskUserResult = serde_json::from_str(json).unwrap();
        assert!(result.completed);
        assert!(result.reason.is_none());
        assert_eq!(result.answers.len(), 1);
        assert_eq!(result.answers[0].question_label, "env");
        assert_eq!(result.answers[0].answer, "dev");
        assert!(!result.answers[0].is_custom);
        assert!(result.answers[0].selected_values.is_empty());
    }

    #[test]
    fn test_option_without_description() {
        let option = AskUserOption {
            value: "yes".to_string(),
            label: "Yes".to_string(),
            description: None,
            selected: false,
        };

        let json = serde_json::to_string(&option).unwrap();
        // description should be omitted entirely when None
        assert!(!json.contains("description"));
        assert!(json.contains("\"value\":\"yes\""));
        assert!(json.contains("\"label\":\"Yes\""));
    }

    #[test]
    fn test_unicode_handling() {
        let question = AskUserQuestion {
            label: "言語".to_string(),
            question: "どの言語を使用しますか?".to_string(),
            options: vec![
                AskUserOption {
                    value: "ja".to_string(),
                    label: "日本語".to_string(),
                    description: Some("Japanese language".to_string()),
                    selected: false,
                },
                AskUserOption {
                    value: "emoji".to_string(),
                    label: "🚀 Rocket".to_string(),
                    description: Some("With emoji 🎉".to_string()),
                    selected: false,
                },
            ],
            allow_custom: true,
            required: true,
            multi_select: false,
        };

        let json = serde_json::to_string(&question).unwrap();
        let parsed: AskUserQuestion = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.label, "言語");
        assert_eq!(parsed.question, "どの言語を使用しますか?");
        assert_eq!(parsed.options[0].label, "日本語");
        assert_eq!(parsed.options[1].label, "🚀 Rocket");
        assert_eq!(parsed.options[1].description.as_deref(), Some("With emoji 🎉"));
        // Every field, not just the sampled ones, must survive the round trip.
        assert_eq!(question, parsed);
    }

    #[test]
    fn test_types_equality() {
        let q1 = AskUserQuestion {
            label: "Test".to_string(),
            question: "Question?".to_string(),
            options: vec![],
            allow_custom: true,
            required: true,
            multi_select: false,
        };

        let q2 = q1.clone();
        assert_eq!(q1, q2);

        let a1 = AskUserAnswer {
            question_label: "Test".to_string(),
            answer: "answer".to_string(),
            is_custom: false,
            selected_values: vec![],
        };

        let a2 = a1.clone();
        assert_eq!(a1, a2);

        let r1 = AskUserResult {
            answers: vec![a1],
            completed: true,
            reason: None,
        };

        let r2 = r1.clone();
        assert_eq!(r1, r2);
    }

    #[test]
    fn test_request_round_trip() {
        let request = AskUserRequest {
            questions: vec![AskUserQuestion {
                label: "Env".to_string(),
                question: "Which env?".to_string(),
                options: vec![AskUserOption {
                    value: "dev".to_string(),
                    label: "Dev".to_string(),
                    description: None,
                    selected: false,
                }],
                allow_custom: false,
                required: true,
                multi_select: false,
            }],
        };

        let json = serde_json::to_string(&request).unwrap();
        let parsed: AskUserRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(request, parsed);
    }

    #[test]
    fn test_request_schema_requires_questions() {
        // The MCP tool schema is generated from these types. `questions` has
        // no serde default, so the schema must list it as required.
        let schema = serde_json::to_value(schemars::schema_for!(AskUserRequest)).unwrap();
        let required = schema["required"]
            .as_array()
            .expect("root schema should list required properties");
        assert!(required.iter().any(|field| field.as_str() == Some("questions")));
    }

    #[test]
    fn test_multi_select_defaults() {
        let json = r#"{
            "label": "Scope",
            "question": "Which repos?",
            "options": [
                {"value": "a", "label": "Repo A"},
                {"value": "b", "label": "Repo B", "selected": true}
            ]
        }"#;

        let question: AskUserQuestion = serde_json::from_str(json).unwrap();
        assert!(
            !question.multi_select,
            "multi_select should default to false"
        );
        assert!(
            !question.options[0].selected,
            "selected should default to false"
        );
        assert!(
            question.options[1].selected,
            "selected should be true when set"
        );
    }

    #[test]
    fn test_multi_select_question_round_trip() {
        let question = AskUserQuestion {
            label: "Scope".to_string(),
            question: "Which repos should I include?".to_string(),
            options: vec![
                AskUserOption {
                    value: "repo:api".to_string(),
                    label: "~/projects/api".to_string(),
                    description: None,
                    selected: true,
                },
                AskUserOption {
                    value: "repo:web".to_string(),
                    label: "~/projects/web".to_string(),
                    description: None,
                    selected: false,
                },
            ],
            allow_custom: false,
            required: true,
            multi_select: true,
        };

        let json = serde_json::to_string(&question).unwrap();
        assert!(json.contains("\"multi_select\":true"));
        assert!(json.contains("\"selected\":true"));

        let parsed: AskUserQuestion = serde_json::from_str(&json).unwrap();
        assert_eq!(question, parsed);
    }

    #[test]
    fn test_multi_select_answer_with_selected_values() {
        let answer = AskUserAnswer {
            question_label: "Scope".to_string(),
            answer: "[\"repo:api\",\"repo:web\"]".to_string(),
            is_custom: false,
            selected_values: vec!["repo:api".to_string(), "repo:web".to_string()],
        };

        let json = serde_json::to_string(&answer).unwrap();
        assert!(json.contains("\"selected_values\""));
        assert!(json.contains("repo:api"));
        assert!(json.contains("repo:web"));

        let parsed: AskUserAnswer = serde_json::from_str(&json).unwrap();
        // Order and contents of the selection must be preserved, not just the count.
        assert_eq!(parsed.selected_values, ["repo:api", "repo:web"]);
        assert_eq!(parsed, answer);
    }

    #[test]
    fn test_selected_values_omitted_when_empty() {
        let answer = AskUserAnswer {
            question_label: "Env".to_string(),
            answer: "dev".to_string(),
            is_custom: false,
            selected_values: vec![],
        };

        let json = serde_json::to_string(&answer).unwrap();
        assert!(
            !json.contains("selected_values"),
            "selected_values should be omitted when empty"
        );
    }

    #[test]
    fn test_answer_deserialization_without_selected_values() {
        // Backward compatibility: old answers without selected_values should still parse
        let json = r#"{"question_label": "env", "answer": "dev", "is_custom": false}"#;
        let answer: AskUserAnswer = serde_json::from_str(json).unwrap();
        assert!(answer.selected_values.is_empty());
    }
}