Skip to main content

stakpak_shared/
utils.rs

1use crate::local_store::LocalStore;
2use async_trait::async_trait;
3use rand::Rng;
4use std::fs;
5use std::path::{Path, PathBuf};
6use walkdir::DirEntry;
7
/// Collect the ignore patterns that apply to `base_dir`.
///
/// The returned list always starts with `.git`. It is followed by every line of
/// `<base_dir>/.gitignore`, trimmed of surrounding whitespace, skipping blank
/// lines and lines that start with `#`. A missing or unreadable `.gitignore`
/// is treated as an empty file.
pub fn read_gitignore_patterns(base_dir: &str) -> Vec<String> {
    let gitignore_file = PathBuf::from(base_dir).join(".gitignore");
    // Any read failure simply contributes no patterns.
    let contents = std::fs::read_to_string(&gitignore_file).unwrap_or_default();

    let from_file = contents
        .lines()
        .map(str::trim)
        .filter(|entry| !(entry.is_empty() || entry.starts_with('#')));

    // `.git` is ignored unconditionally, ahead of anything the file lists.
    std::iter::once(".git")
        .chain(from_file)
        .map(String::from)
        .collect()
}
25
26/// Check if a directory entry should be included based on gitignore patterns and file type support
27pub fn should_include_entry(entry: &DirEntry, base_dir: &str, ignore_patterns: &[String]) -> bool {
28    let path = entry.path();
29    let is_file = entry.file_type().is_file();
30
31    // Get relative path from base directory
32    let base_path = PathBuf::from(base_dir);
33    let relative_path = match path.strip_prefix(&base_path) {
34        Ok(rel_path) => rel_path,
35        Err(_) => path,
36    };
37
38    let path_str = relative_path.to_string_lossy();
39
40    // Check if path matches any ignore pattern
41    for pattern in ignore_patterns {
42        if matches_gitignore_pattern(pattern, &path_str) {
43            return false;
44        }
45    }
46
47    // For files, also check if they are supported file types
48    if is_file {
49        is_supported_file(entry.path())
50    } else {
51        true // Allow directories to be traversed
52    }
53}
54
/// Check whether `path` (relative to the base directory, `/`-separated) is
/// covered by a single `.gitignore`-style `pattern`.
///
/// This implements a small subset of gitignore matching:
///
/// * Trailing slashes on the pattern are ignored.
/// * A leading `/` anchors the pattern to the base directory: the pattern is
///   matched component-for-component against the start of `path`, so
///   `/target` covers `target` and `target/debug/app` but not `crates/target`,
///   and `/*.log` covers `debug.log` but not `logs/debug.log`.
/// * A pattern without `*` matches the identical path and everything
///   underneath it (`node_modules` covers `node_modules/package.json`), but
///   only at the root: `src/node_modules` is not covered.
/// * A pattern with `*` is matched against the whole path by
///   [`pattern_matches_glob`], where `*` matches any run of characters
///   including `/`. `**` therefore behaves like `*`.
///
/// Negation (`!`), character classes and `?` are not interpreted.
#[allow(clippy::string_slice)] // `cut` is the byte offset of a '/' found by match_indices, always a char boundary
pub fn matches_gitignore_pattern(pattern: &str, path: &str) -> bool {
    // The trailing slash only marks "directory"; it takes no part in matching.
    let pattern = pattern.trim_end_matches('/');

    if let Some(anchored) = pattern.strip_prefix('/') {
        // Compare against exactly as many leading components as the pattern
        // has. With equal '/' counts on both sides a `*` cannot swallow a
        // separator, and anything deeper than the pattern counts as contained.
        let depth = anchored.matches('/').count() + 1;
        let cut = path
            .match_indices('/')
            .nth(depth - 1)
            .map_or(path.len(), |(idx, _)| idx);
        return pattern_matches_glob(anchored, &path[..cut]);
    }

    if pattern.contains('*') {
        // `*x`, `x*`, `*x*` and `*` are plain special cases of the glob, and
        // routing everything through it keeps further `*` inside the pattern
        // (`**/x`, `x/**`, `*.a*.b`) from being compared as literal characters.
        pattern_matches_glob(pattern, path)
    } else {
        // Exact match, or `path` lives underneath the named directory.
        path.strip_prefix(pattern)
            .is_some_and(|rest| rest.is_empty() || rest.starts_with('/'))
    }
}

/// Match `text` against `pattern`, where `*` stands for any (possibly empty)
/// run of characters and every other character must match literally.
///
/// The whole of `text` has to be consumed: the piece before the first `*` must
/// be a prefix, the piece after the last `*` must be a suffix of what remains,
/// and the pieces in between must occur in order without overlapping. A
/// pattern without `*` matches only the identical string.
pub fn pattern_matches_glob(pattern: &str, text: &str) -> bool {
    let mut pieces = pattern.split('*');

    // `split` always yields at least one piece: the literal head of the pattern.
    let head = pieces.next().unwrap_or("");
    let Some(mut remaining) = text.strip_prefix(head) else {
        return false;
    };

    // No further piece means the pattern had no `*`, so nothing may be left over.
    let Some(tail) = pieces.next_back() else {
        return remaining.is_empty();
    };

    // Interior pieces are located left to right, each search starting after
    // the previous hit.
    for piece in pieces {
        match remaining.split_once(piece) {
            Some((_, after)) => remaining = after,
            None => return false,
        }
    }

    remaining.ends_with(tail)
}
111
/// Report whether `file_path` should be accepted by the file-type filter.
///
/// * Paths without a UTF-8 final component (for example `..`) are rejected.
/// * Anything that is not an existing regular file is accepted, so that
///   directories can still be traversed.
/// * Regular files are accepted when their name ends in `.tf`, `.tfvars`,
///   `.yaml` or `.yml` (case-sensitive), or contains `dockerfile` in any
///   letter case.
pub fn is_supported_file(file_path: &Path) -> bool {
    const SUPPORTED_SUFFIXES: [&str; 4] = [".tf", ".tfvars", ".yaml", ".yml"];

    let Some(file_name) = file_path.file_name().and_then(|os_name| os_name.to_str()) else {
        return false;
    };

    // Directories (and anything else that is not a file) are never filtered out here.
    if !file_path.is_file() {
        return true;
    }

    SUPPORTED_SUFFIXES
        .iter()
        .any(|suffix| file_name.ends_with(*suffix))
        || file_name.to_lowercase().contains("dockerfile")
}
130
131/// Generate a secure password with alphanumeric characters and optional symbols
132pub fn generate_password(length: usize, no_symbols: bool) -> String {
133    let mut rng = rand::rng();
134
135    // Define character sets
136    let lowercase = "abcdefghijklmnopqrstuvwxyz";
137    let uppercase = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
138    let digits = "0123456789";
139    let symbols = "!@#$%^&*()_+-=[]{}|;:,.<>?";
140
141    // Build the character set based on options
142    let mut charset = String::new();
143    charset.push_str(lowercase);
144    charset.push_str(uppercase);
145    charset.push_str(digits);
146
147    if !no_symbols {
148        charset.push_str(symbols);
149    }
150
151    let charset_chars: Vec<char> = charset.chars().collect();
152
153    // Generate password ensuring at least one character from each required category
154    let mut password = String::new();
155
156    // Ensure at least one character from each category
157    password.push(
158        lowercase
159            .chars()
160            .nth(rng.random_range(0..lowercase.len()))
161            .unwrap(),
162    );
163    password.push(
164        uppercase
165            .chars()
166            .nth(rng.random_range(0..uppercase.len()))
167            .unwrap(),
168    );
169    password.push(
170        digits
171            .chars()
172            .nth(rng.random_range(0..digits.len()))
173            .unwrap(),
174    );
175
176    if !no_symbols {
177        password.push(
178            symbols
179                .chars()
180                .nth(rng.random_range(0..symbols.len()))
181                .unwrap(),
182        );
183    }
184
185    // Fill the rest with random characters from the full charset
186    let remaining_length = if length > password.len() {
187        length - password.len()
188    } else {
189        0
190    };
191
192    for _ in 0..remaining_length {
193        let random_char = charset_chars[rng.random_range(0..charset_chars.len())];
194        password.push(random_char);
195    }
196
197    // Shuffle the password to randomize the order
198    let mut password_chars: Vec<char> = password.chars().collect();
199    for i in 0..password_chars.len() {
200        let j = rng.random_range(0..password_chars.len());
201        password_chars.swap(i, j);
202    }
203
204    // Take only the requested length
205    password_chars.into_iter().take(length).collect()
206}
207
/// Return `text` with unprintable characters removed.
///
/// Dropped: the Unicode replacement character (U+FFFD) and every control
/// character except newline, tab and carriage return. Everything else,
/// including spaces and non-ASCII text, is kept in its original order.
pub fn sanitize_text_output(text: &str) -> String {
    let mut cleaned = String::with_capacity(text.len());

    for ch in text.chars() {
        let keep = match ch {
            // Left behind by lossy decoding; carries no information.
            '\u{FFFD}' => false,
            // Layout whitespace is kept even though most of it is "control".
            '\n' | '\t' | '\r' | ' ' => true,
            other => !other.is_control(),
        };
        if keep {
            cleaned.push(ch);
        }
    }

    cleaned
}
225
226/// Handle large output: if the output has >= `max_lines`, save the full content to session
227/// storage and return a string showing only the first or last `max_lines` lines with a pointer
228/// to the saved file. Returns `Ok(final_string)` or `Err(error_string)` on failure.
229pub fn handle_large_output(
230    output: &str,
231    file_prefix: &str,
232    max_lines: usize,
233    show_head: bool,
234) -> Result<String, String> {
235    let output_lines = output.lines().collect::<Vec<_>>();
236    if output_lines.len() >= max_lines {
237        let mut __rng__ = rand::rng();
238        let output_file = format!(
239            "{}.{:06x}.txt",
240            file_prefix,
241            __rng__.random_range(0..=0xFFFFFF)
242        );
243        let output_file_path = match LocalStore::write_session_data(&output_file, output) {
244            Ok(path) => path,
245            Err(e) => {
246                return Err(format!("Failed to write session data: {}", e));
247            }
248        };
249
250        let excerpt = if show_head {
251            let head_lines: Vec<&str> = output_lines.iter().take(max_lines).copied().collect();
252            head_lines.join("\n")
253        } else {
254            let mut tail_lines: Vec<&str> =
255                output_lines.iter().rev().take(max_lines).copied().collect();
256            tail_lines.reverse();
257            tail_lines.join("\n")
258        };
259
260        let position = if show_head { "first" } else { "last" };
261        Ok(format!(
262            "Showing the {} {} / {} output lines. Full output saved to {}\n{}\n{}",
263            position,
264            max_lines,
265            output_lines.len(),
266            output_file_path,
267            if show_head { "" } else { "...\n" },
268            excerpt
269        ))
270    } else {
271        Ok(output.to_string())
272    }
273}
274
#[cfg(test)]
mod password_tests {
    use super::*;

    /// The symbol alphabet `generate_password` uses when symbols are enabled.
    const SYMBOLS: &str = "!@#$%^&*()_+-=[]{}|;:,.<>?";

    /// The result has exactly the requested length, with or without symbols.
    #[test]
    fn test_generate_password_length() {
        for (length, no_symbols) in [(10, false), (20, true)] {
            assert_eq!(generate_password(length, no_symbols).len(), length);
        }
    }

    /// With `no_symbols` set, no symbol character may appear.
    #[test]
    fn test_generate_password_no_symbols() {
        let password = generate_password(50, true);

        if let Some(symbol) = password.chars().find(|c| SYMBOLS.contains(*c)) {
            panic!("Password should not contain symbol: {}", symbol);
        }
    }

    /// With symbols enabled, the generator guarantees at least one of them.
    #[test]
    fn test_generate_password_with_symbols() {
        let password = generate_password(50, false);

        assert!(
            password.chars().any(|c| SYMBOLS.contains(c)),
            "Password should contain at least one symbol"
        );
    }

    /// Lowercase, uppercase and digit classes are always represented.
    #[test]
    fn test_generate_password_contains_required_chars() {
        let password = generate_password(50, false);
        let contains = |class: fn(&char) -> bool| password.chars().any(|c| class(&c));

        assert!(
            contains(char::is_ascii_lowercase),
            "Password should contain lowercase letters"
        );
        assert!(
            contains(char::is_ascii_uppercase),
            "Password should contain uppercase letters"
        );
        assert!(
            contains(char::is_ascii_digit),
            "Password should contain digits"
        );
    }

    /// Two independent calls are overwhelmingly unlikely to collide.
    #[test]
    fn test_generate_password_uniqueness() {
        let first = generate_password(20, false);
        let second = generate_password(20, false);

        assert_ne!(first, second);
    }
}
334
/// One item of a directory listing, as consumed by the tree renderer.
#[derive(Clone, Debug)]
pub struct DirectoryEntry {
    /// Name of the entry itself, without parent components.
    pub name: String,
    /// Path of the entry; passed back to the provider when descending into it.
    pub path: String,
    /// `true` when the entry is a directory.
    pub is_directory: bool,
}
342
343/// Trait for abstracting file system operations for tree generation
344#[async_trait]
345pub trait FileSystemProvider {
346    type Error: std::fmt::Display;
347
348    /// List directory contents
349    async fn list_directory(&self, path: &str) -> Result<Vec<DirectoryEntry>, Self::Error>;
350}
351
352/// Generate a tree view of a directory structure using a generic file system provider
353pub async fn generate_directory_tree<P: FileSystemProvider>(
354    provider: &P,
355    path: &str,
356    prefix: &str,
357    max_depth: usize,
358    current_depth: usize,
359) -> Result<String, P::Error> {
360    let mut result = String::new();
361
362    if current_depth >= max_depth || current_depth >= 10 {
363        return Ok(result);
364    }
365
366    let entries = provider.list_directory(path).await?;
367    let mut file_entries = Vec::new();
368    let mut dir_entries = Vec::new();
369    for entry in entries.iter() {
370        if entry.is_directory {
371            if entry.name == "."
372                || entry.name == ".."
373                || entry.name == ".git"
374                || entry.name == "node_modules"
375            {
376                continue;
377            }
378            dir_entries.push(entry.clone());
379        } else {
380            file_entries.push(entry.clone());
381        }
382    }
383
384    dir_entries.sort_by(|a, b| a.name.cmp(&b.name));
385    file_entries.sort_by(|a, b| a.name.cmp(&b.name));
386
387    const MAX_ITEMS: usize = 5;
388    let total_items = dir_entries.len() + file_entries.len();
389    let should_limit = current_depth > 0 && total_items > MAX_ITEMS;
390
391    if should_limit {
392        if dir_entries.len() > MAX_ITEMS {
393            dir_entries.truncate(MAX_ITEMS);
394            file_entries.clear();
395        } else {
396            let remaining_items = MAX_ITEMS - dir_entries.len();
397            file_entries.truncate(remaining_items);
398        }
399    }
400
401    let mut dir_headers = Vec::new();
402    let mut dir_futures = Vec::new();
403    for (i, entry) in dir_entries.iter().enumerate() {
404        let is_last_dir = i == dir_entries.len() - 1;
405        let is_last_overall = is_last_dir && file_entries.is_empty() && !should_limit;
406        let current_prefix = if is_last_overall {
407            "└── "
408        } else {
409            "├── "
410        };
411        let next_prefix = format!(
412            "{}{}",
413            prefix,
414            if is_last_overall { "    " } else { "│   " }
415        );
416
417        let header = format!("{}{}{}/\n", prefix, current_prefix, entry.name);
418        dir_headers.push(header);
419
420        let entry_path = entry.path.clone();
421        let next_prefix_clone = next_prefix.clone();
422        let future = async move {
423            generate_directory_tree(
424                provider,
425                &entry_path,
426                &next_prefix_clone,
427                max_depth,
428                current_depth + 1,
429            )
430            .await
431        };
432        dir_futures.push(future);
433    }
434    if !dir_futures.is_empty() {
435        let subtree_results = futures::future::join_all(dir_futures).await;
436
437        for (i, header) in dir_headers.iter().enumerate() {
438            result.push_str(header);
439            if let Some(Ok(subtree)) = subtree_results.get(i) {
440                result.push_str(subtree);
441            }
442        }
443    }
444
445    for (i, entry) in file_entries.iter().enumerate() {
446        let is_last_file = i == file_entries.len() - 1;
447        let is_last_overall = is_last_file && !should_limit;
448        let current_prefix = if is_last_overall {
449            "└── "
450        } else {
451            "├── "
452        };
453        result.push_str(&format!("{}{}{}\n", prefix, current_prefix, entry.name));
454    }
455
456    if should_limit {
457        let remaining_count = total_items - MAX_ITEMS;
458        result.push_str(&format!(
459            "{}└── ... {} more item{}\n",
460            prefix,
461            remaining_count,
462            if remaining_count == 1 { "" } else { "s" }
463        ));
464    }
465
466    Ok(result)
467}
468
469/// Strip the MCP server prefix and any trailing "()" from a tool name.
470/// Example: "stakpak__run_command" -> "run_command"
471/// Example: "run_command" -> "run_command"
472/// Example: "str_replace()" -> "str_replace"
473pub fn strip_tool_name(name: &str) -> &str {
474    let mut result = name;
475
476    // Strip the MCP server prefix (e.g., "stakpak__")
477    if let Some((_, suffix)) = result.split_once("__") {
478        result = suffix;
479    }
480
481    // Strip trailing "()" if present
482    if let Some(stripped) = result.strip_suffix("()") {
483        result = stripped;
484    }
485
486    backward_compatibility_mapping(result)
487}
488
/// Translate a legacy tool name into its current name.
///
/// "read_rulebook" and "read_rulebooks" both become "load_skill"; every other
/// name is returned unchanged.
pub fn backward_compatibility_mapping(name: &str) -> &str {
    if matches!(name, "read_rulebook" | "read_rulebooks") {
        "load_skill"
    } else {
        name
    }
}
497
498/// Local file system provider implementation
499pub struct LocalFileSystemProvider;
500
501#[async_trait]
502impl FileSystemProvider for LocalFileSystemProvider {
503    type Error = std::io::Error;
504
505    async fn list_directory(&self, path: &str) -> Result<Vec<DirectoryEntry>, Self::Error> {
506        let entries = fs::read_dir(path)?;
507        let mut result = Vec::new();
508
509        for entry in entries {
510            let entry = entry?;
511            let file_name = entry.file_name().to_string_lossy().to_string();
512            let file_path = entry.path().to_string_lossy().to_string();
513            let is_directory = entry.file_type()?.is_dir();
514
515            result.push(DirectoryEntry {
516                name: file_name,
517                path: file_path,
518                is_directory,
519            });
520        }
521
522        Ok(result)
523    }
524}
525
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// Check `matches_gitignore_pattern` against `(pattern, path, expected)` rows.
    fn check_gitignore_cases(cases: &[(&str, &str, bool)]) {
        for &(pattern, path, expected) in cases {
            assert_eq!(
                matches_gitignore_pattern(pattern, path),
                expected,
                "pattern {:?} against path {:?}",
                pattern,
                path
            );
        }
    }

    /// True when at least one of `patterns` covers `path`.
    fn any_pattern_matches(patterns: &[String], path: &str) -> bool {
        patterns
            .iter()
            .any(|pattern| matches_gitignore_pattern(pattern, path))
    }

    #[test]
    fn test_matches_gitignore_pattern_exact() {
        check_gitignore_cases(&[
            ("node_modules", "node_modules", true),
            ("node_modules", "node_modules/package.json", true),
            ("node_modules", "src/node_modules", false),
        ]);
    }

    #[test]
    fn test_matches_gitignore_pattern_wildcard_prefix() {
        check_gitignore_cases(&[
            ("*.log", "debug.log", true),
            ("*.log", "error.log", true),
            ("*.log", "log.txt", false),
        ]);
    }

    #[test]
    fn test_matches_gitignore_pattern_wildcard_suffix() {
        check_gitignore_cases(&[
            ("temp*", "temp", true),
            ("temp*", "temp.txt", true),
            ("temp*", "temporary", true),
            ("temp*", "mytemp", false),
        ]);
    }

    #[test]
    fn test_matches_gitignore_pattern_wildcard_middle() {
        check_gitignore_cases(&[
            ("*temp*", "temp", true),
            ("*temp*", "mytemp", true),
            ("*temp*", "temporary", true),
            ("*temp*", "mytemporary", true),
            ("*temp*", "example", false),
        ]);
    }

    #[test]
    fn test_pattern_matches_glob() {
        let cases = [
            ("test*.txt", "test.txt", true),
            ("test*.txt", "test123.txt", true),
            ("*test*.txt", "mytest.txt", true),
            ("*test*.txt", "mytestfile.txt", true),
            ("test*.txt", "test.log", false),
            ("*test*.txt", "example.txt", false),
        ];

        for (pattern, text, expected) in cases {
            assert_eq!(
                pattern_matches_glob(pattern, text),
                expected,
                "pattern {:?} against text {:?}",
                pattern,
                text
            );
        }
    }

    #[test]
    fn test_read_gitignore_patterns() -> Result<(), Box<dyn std::error::Error>> {
        let temp_dir = TempDir::new()?;
        let base = temp_dir.path();

        // Comments and blank lines are mixed in on purpose.
        let gitignore_content = r#"
# This is a comment
node_modules
*.log
dist/
.env

# Another comment
temp*
"#;
        fs::write(base.join(".gitignore"), gitignore_content)?;

        let patterns = read_gitignore_patterns(base.to_str().unwrap());

        // `.git` is always present, followed by the entries from the file.
        for expected in [".git", "node_modules", "*.log", "dist/", ".env", "temp*"] {
            assert!(
                patterns.iter().any(|p| p.as_str() == expected),
                "missing pattern {:?}",
                expected
            );
        }

        // Comments and empty lines never become patterns.
        assert!(patterns.iter().all(|p| !p.starts_with('#')));
        assert!(patterns.iter().all(|p| !p.is_empty()));

        Ok(())
    }

    #[test]
    fn test_read_gitignore_patterns_no_file() {
        let temp_dir = TempDir::new().unwrap();

        let patterns = read_gitignore_patterns(temp_dir.path().to_str().unwrap());

        // Without a .gitignore only the built-in `.git` entry remains.
        assert_eq!(patterns, vec![".git".to_string()]);
    }

    #[test]
    fn test_strip_tool_name() {
        let cases = [
            ("stakpak__run_command", "run_command"),
            ("run_command", "run_command"),
            ("str_replace()", "str_replace"),
            ("stakpak__read_rulebook", "load_skill"),
            ("read_rulebook()", "load_skill"),
            ("read_rulebooks", "load_skill"),
            // Additional edge cases
            ("just_name", "just_name"),
            ("prefix__name()", "name"),
            ("nested__prefix__tool", "prefix__tool"),
            ("empty_suffix()", "empty_suffix"),
        ];

        for (input, expected) in cases {
            assert_eq!(strip_tool_name(input), expected, "input {:?}", input);
        }
    }

    #[test]
    fn test_backward_compatibility_mapping() {
        let cases = [
            ("read_rulebook", "load_skill"),
            ("read_rulebooks", "load_skill"),
            ("run_command", "run_command"),
        ];

        for (input, expected) in cases {
            assert_eq!(
                backward_compatibility_mapping(input),
                expected,
                "input {:?}",
                input
            );
        }
    }

    #[test]
    fn test_gitignore_integration() -> Result<(), Box<dyn std::error::Error>> {
        let temp_dir = TempDir::new()?;
        let base = temp_dir.path();

        fs::write(base.join(".gitignore"), "node_modules\n*.log\ndist/\n")?;

        let patterns = read_gitignore_patterns(base.to_str().unwrap());

        // Paths that some pattern must cover.
        for ignored in [
            "node_modules",
            "node_modules/package.json",
            "debug.log",
            "dist/bundle.js",
            ".git",
        ] {
            assert!(
                any_pattern_matches(&patterns, ignored),
                "{:?} should be ignored",
                ignored
            );
        }

        // Paths that no pattern may cover.
        for kept in ["src/main.js", "README.md"] {
            assert!(
                !any_pattern_matches(&patterns, kept),
                "{:?} should not be ignored",
                kept
            );
        }

        Ok(())
    }
}
711}