Skip to main content

stakpak_shared/
utils.rs

1use crate::local_store::LocalStore;
2use async_trait::async_trait;
3use rand::Rng;
4use std::fs;
5use std::path::{Path, PathBuf};
6use walkdir::DirEntry;
7
/// Load ignore patterns for the tree rooted at `base_dir`.
///
/// The result always starts with `".git"`. It is followed, in file order, by every line of
/// `<base_dir>/.gitignore` with surrounding whitespace trimmed. Blank lines and lines
/// starting with `#` are skipped. If the file is missing or cannot be read as UTF-8 text,
/// only `".git"` is returned. Remaining lines are kept verbatim (e.g. `dist/` keeps its
/// trailing slash).
pub fn read_gitignore_patterns(base_dir: &str) -> Vec<String> {
    // `.git` is never listed in a .gitignore but must always be skipped.
    let mut patterns = vec![".git".to_string()];

    let gitignore = std::path::Path::new(base_dir).join(".gitignore");
    let Ok(content) = std::fs::read_to_string(&gitignore) else {
        return patterns;
    };

    patterns.extend(
        content
            .lines()
            .map(str::trim)
            .filter(|line| !line.is_empty() && !line.starts_with('#'))
            .map(str::to_string),
    );
    patterns
}
25
26/// Check if a directory entry should be included based on gitignore patterns and file type support
27pub fn should_include_entry(entry: &DirEntry, base_dir: &str, ignore_patterns: &[String]) -> bool {
28    let path = entry.path();
29    let is_file = entry.file_type().is_file();
30
31    // Get relative path from base directory
32    let base_path = PathBuf::from(base_dir);
33    let relative_path = match path.strip_prefix(&base_path) {
34        Ok(rel_path) => rel_path,
35        Err(_) => path,
36    };
37
38    let path_str = relative_path.to_string_lossy();
39
40    // Check if path matches any ignore pattern
41    for pattern in ignore_patterns {
42        if matches_gitignore_pattern(pattern, &path_str) {
43            return false;
44        }
45    }
46
47    // For files, also check if they are supported file types
48    if is_file {
49        is_supported_file(entry.path())
50    } else {
51        true // Allow directories to be traversed
52    }
53}
54
/// Check whether `path` matches a single `.gitignore`-style `pattern`.
///
/// `path` is expected to be relative to the directory the patterns came from, with `/`
/// as its separator. Only a small subset of gitignore syntax is supported:
///
/// * A trailing `/` is ignored, so `dist/` behaves like `dist`.
/// * A leading `/` (gitignore's "anchor to the root") is stripped. Non-wildcard patterns
///   are already matched from the root here, so `/target` matches `target` and
///   `target/debug` but not `src/target`.
/// * Without `*`, the pattern matches the path itself or anything beneath it.
/// * `*` alone matches everything, `*x*` matches paths containing `x`, `*x` matches paths
///   ending in `x`, and `x*` matches paths starting with `x`. Any other placement of `*`
///   is handled by [`pattern_matches_glob`]. `*` may match across `/` here.
///
/// Negation (`!`), `**/` path wildcards and character classes are not interpreted.
/// A pattern that is empty once slashes are stripped (e.g. `/`) matches nothing.
pub fn matches_gitignore_pattern(pattern: &str, path: &str) -> bool {
    let pattern = pattern.trim_end_matches('/');
    let pattern = pattern.strip_prefix('/').unwrap_or(pattern);

    if pattern.is_empty() {
        return false;
    }

    if !pattern.contains('*') {
        // Exact match, or `path` lies inside the directory named by `pattern`.
        return path == pattern
            || matches!(path.strip_prefix(pattern), Some(rest) if rest.starts_with('/'));
    }

    if pattern == "*" {
        return true;
    }

    match (pattern.strip_prefix('*'), pattern.strip_suffix('*')) {
        // `*middle*`: the pattern is at least two bytes here, so after dropping the
        // leading star the remainder still ends with the trailing one.
        (Some(after_star), Some(_)) => {
            let middle = after_star.strip_suffix('*').unwrap_or(after_star);
            path.contains(middle)
        }
        (Some(suffix), None) => path.ends_with(suffix),
        (None, Some(prefix)) => path.starts_with(prefix),
        // `*` only in the interior of the pattern.
        (None, None) => pattern_matches_glob(pattern, path),
    }
}

/// Match `text` against a glob `pattern` in which `*` stands for any (possibly empty) run
/// of characters, including `/`. All other characters match literally, and a pattern
/// without `*` must equal `text` exactly.
pub fn pattern_matches_glob(pattern: &str, text: &str) -> bool {
    let Some((body, tail)) = pattern.rsplit_once('*') else {
        return text == pattern;
    };

    let mut pieces = body.split('*');
    // `split` always yields at least one item; the first piece must be a prefix of `text`.
    let head = pieces.next().unwrap_or("");
    let Some(mut rest) = text.strip_prefix(head) else {
        return false;
    };

    // Middle pieces must appear in order. Taking the leftmost occurrence of each leaves
    // the most text for the pieces that follow.
    for piece in pieces {
        match rest.split_once(piece) {
            Some((_, after)) => rest = after,
            None => return false,
        }
    }

    // The final piece must end what is left, so it cannot overlap earlier matches.
    rest.ends_with(tail)
}
111
/// Decide whether `file_path` is something the scanner should look at.
///
/// A regular file is accepted only when its name ends in `.tf`, `.tfvars`, `.yaml` or
/// `.yml` (case-sensitive), or contains "dockerfile" in any letter case. Anything that is
/// not a regular file on disk is accepted so traversal can continue; this includes
/// directories and also paths that do not exist. Paths with no final component (e.g. `/`)
/// or a non-UTF-8 one are rejected.
pub fn is_supported_file(file_path: &Path) -> bool {
    const EXTENSIONS: [&str; 4] = [".tf", ".tfvars", ".yaml", ".yml"];

    let Some(name) = file_path.file_name().and_then(|raw| raw.to_str()) else {
        return false;
    };

    if !file_path.is_file() {
        return true;
    }

    EXTENSIONS.iter().any(|ext| name.ends_with(ext))
        || name.to_lowercase().contains("dockerfile")
}
130
131/// Generate a secure password with alphanumeric characters and optional symbols
132pub fn generate_password(length: usize, no_symbols: bool) -> String {
133    let mut rng = rand::rng();
134
135    // Define character sets
136    let lowercase = "abcdefghijklmnopqrstuvwxyz";
137    let uppercase = "ABCDEFGHIJKLMNOPQRSTUVWXYZ";
138    let digits = "0123456789";
139    let symbols = "!@#$%^&*()_+-=[]{}|;:,.<>?";
140
141    // Build the character set based on options
142    let mut charset = String::new();
143    charset.push_str(lowercase);
144    charset.push_str(uppercase);
145    charset.push_str(digits);
146
147    if !no_symbols {
148        charset.push_str(symbols);
149    }
150
151    let charset_chars: Vec<char> = charset.chars().collect();
152
153    // Generate password ensuring at least one character from each required category
154    let mut password = String::new();
155
156    // Ensure at least one character from each category
157    password.push(
158        lowercase
159            .chars()
160            .nth(rng.random_range(0..lowercase.len()))
161            .unwrap(),
162    );
163    password.push(
164        uppercase
165            .chars()
166            .nth(rng.random_range(0..uppercase.len()))
167            .unwrap(),
168    );
169    password.push(
170        digits
171            .chars()
172            .nth(rng.random_range(0..digits.len()))
173            .unwrap(),
174    );
175
176    if !no_symbols {
177        password.push(
178            symbols
179                .chars()
180                .nth(rng.random_range(0..symbols.len()))
181                .unwrap(),
182        );
183    }
184
185    // Fill the rest with random characters from the full charset
186    let remaining_length = if length > password.len() {
187        length - password.len()
188    } else {
189        0
190    };
191
192    for _ in 0..remaining_length {
193        let random_char = charset_chars[rng.random_range(0..charset_chars.len())];
194        password.push(random_char);
195    }
196
197    // Shuffle the password to randomize the order
198    let mut password_chars: Vec<char> = password.chars().collect();
199    for i in 0..password_chars.len() {
200        let j = rng.random_range(0..password_chars.len());
201        password_chars.swap(i, j);
202    }
203
204    // Take only the requested length
205    password_chars.into_iter().take(length).collect()
206}
207
/// Strip characters that would garble terminal or log output.
///
/// Every character for which [`char::is_control`] is true is removed, except newline,
/// tab and carriage return. U+FFFD REPLACEMENT CHARACTER is also removed. All other
/// characters, including plain spaces, are kept in their original order.
pub fn sanitize_text_output(text: &str) -> String {
    let mut cleaned = text.to_owned();
    cleaned.retain(|ch| match ch {
        '\u{FFFD}' => false,
        // Whitespace that must survive even though some of it is classed as "control".
        '\n' | '\t' | '\r' | ' ' => true,
        other => !other.is_control(),
    });
    cleaned
}
225
/// Limit `text` to its first `max_chars` characters (Unicode scalar values) and append
/// `...` if anything was cut off.
///
/// Counting is by `char`, so multi-byte UTF-8 text is never split mid-character. A
/// truncated result is `max_chars + 3` characters long.
pub fn truncate_chars_with_ellipsis(text: &str, max_chars: usize) -> String {
    let mut chars = text.chars();
    let mut kept: String = chars.by_ref().take(max_chars).collect();
    // Anything left over means the input was longer than the limit.
    if chars.next().is_some() {
        kept.push_str("...");
    }
    kept
}
238
239/// Handle large output: if the output has >= `max_lines`, save the full content to session
240/// storage and return a string showing only the first or last `max_lines` lines with a pointer
241/// to the saved file. Returns `Ok(final_string)` or `Err(error_string)` on failure.
242pub fn handle_large_output(
243    output: &str,
244    file_prefix: &str,
245    max_lines: usize,
246    show_head: bool,
247) -> Result<String, String> {
248    let output_lines = output.lines().collect::<Vec<_>>();
249    if output_lines.len() >= max_lines {
250        let mut __rng__ = rand::rng();
251        let output_file = format!(
252            "{}.{:06x}.txt",
253            file_prefix,
254            __rng__.random_range(0..=0xFFFFFF)
255        );
256        let output_file_path = match LocalStore::write_session_data(&output_file, output) {
257            Ok(path) => path,
258            Err(e) => {
259                return Err(format!("Failed to write session data: {}", e));
260            }
261        };
262
263        let excerpt = if show_head {
264            let head_lines: Vec<&str> = output_lines.iter().take(max_lines).copied().collect();
265            head_lines.join("\n")
266        } else {
267            let mut tail_lines: Vec<&str> =
268                output_lines.iter().rev().take(max_lines).copied().collect();
269            tail_lines.reverse();
270            tail_lines.join("\n")
271        };
272
273        let position = if show_head { "first" } else { "last" };
274        Ok(format!(
275            "Showing the {} {} / {} output lines. Full output saved to {}\n{}\n{}",
276            position,
277            max_lines,
278            output_lines.len(),
279            output_file_path,
280            if show_head { "" } else { "...\n" },
281            excerpt
282        ))
283    } else {
284        Ok(output.to_string())
285    }
286}
287
#[cfg(test)]
mod password_tests {
    use super::*;

    /// The symbol alphabet `generate_password` draws from when symbols are enabled.
    const SYMBOLS: &str = "!@#$%^&*()_+-=[]{}|;:,.<>?";

    #[test]
    fn test_generate_password_length() {
        for (length, no_symbols) in [(10, false), (20, true)] {
            assert_eq!(generate_password(length, no_symbols).len(), length);
        }
    }

    #[test]
    fn test_generate_password_no_symbols() {
        let password = generate_password(50, true);
        if let Some(symbol) = password.chars().find(|c| SYMBOLS.contains(*c)) {
            panic!("Password should not contain symbol: {}", symbol);
        }
    }

    #[test]
    fn test_generate_password_with_symbols() {
        let password = generate_password(50, false);
        // The generator always seeds one symbol when symbols are allowed.
        assert!(
            password.chars().any(|c| SYMBOLS.contains(c)),
            "Password should contain at least one symbol"
        );
    }

    #[test]
    fn test_generate_password_contains_required_chars() {
        let password = generate_password(50, false);
        let requirements: [(fn(&char) -> bool, &str); 3] = [
            (char::is_ascii_lowercase, "Password should contain lowercase letters"),
            (char::is_ascii_uppercase, "Password should contain uppercase letters"),
            (char::is_ascii_digit, "Password should contain digits"),
        ];
        for (is_in_class, message) in requirements {
            assert!(password.chars().any(|c| is_in_class(&c)), "{}", message);
        }
    }

    #[test]
    fn test_generate_password_uniqueness() {
        // Two independent 20-character passwords colliding is vanishingly unlikely.
        assert_ne!(generate_password(20, false), generate_password(20, false));
    }
}
347
#[cfg(test)]
mod truncate_tests {
    use super::*;

    #[test]
    fn truncate_chars_with_ellipsis_exact_boundary_keeps_value() {
        let input = "a".repeat(20);
        assert_eq!(truncate_chars_with_ellipsis(&input, 20), input);
    }

    #[test]
    fn truncate_chars_with_ellipsis_appends_suffix_when_truncated() {
        // Multi-byte input: the limit counts characters, not bytes.
        let input = "é".repeat(10);
        let expected = format!("{}...", "é".repeat(5));
        assert_eq!(truncate_chars_with_ellipsis(&input, 5), expected);
    }
}
366
/// One item returned by `FileSystemProvider::list_directory`, used to build directory trees.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DirectoryEntry {
    /// Entry name (the file name, for the local provider); used for display and sorting.
    pub name: String,
    /// Location of the entry. For directories, this is passed back to `list_directory`
    /// to descend into them.
    pub path: String,
    /// `true` for directories. Everything else is rendered as a file.
    pub is_directory: bool,
}
374
375/// Trait for abstracting file system operations for tree generation
376#[async_trait]
377pub trait FileSystemProvider {
378    type Error: std::fmt::Display;
379
380    /// List directory contents
381    async fn list_directory(&self, path: &str) -> Result<Vec<DirectoryEntry>, Self::Error>;
382}
383
384/// Generate a tree view of a directory structure using a generic file system provider
385pub async fn generate_directory_tree<P: FileSystemProvider>(
386    provider: &P,
387    path: &str,
388    prefix: &str,
389    max_depth: usize,
390    current_depth: usize,
391) -> Result<String, P::Error> {
392    let mut result = String::new();
393
394    if current_depth >= max_depth || current_depth >= 10 {
395        return Ok(result);
396    }
397
398    let entries = provider.list_directory(path).await?;
399    let mut file_entries = Vec::new();
400    let mut dir_entries = Vec::new();
401    for entry in entries.iter() {
402        if entry.is_directory {
403            if entry.name == "."
404                || entry.name == ".."
405                || entry.name == ".git"
406                || entry.name == "node_modules"
407            {
408                continue;
409            }
410            dir_entries.push(entry.clone());
411        } else {
412            file_entries.push(entry.clone());
413        }
414    }
415
416    dir_entries.sort_by(|a, b| a.name.cmp(&b.name));
417    file_entries.sort_by(|a, b| a.name.cmp(&b.name));
418
419    const MAX_ITEMS: usize = 5;
420    let total_items = dir_entries.len() + file_entries.len();
421    let should_limit = current_depth > 0 && total_items > MAX_ITEMS;
422
423    if should_limit {
424        if dir_entries.len() > MAX_ITEMS {
425            dir_entries.truncate(MAX_ITEMS);
426            file_entries.clear();
427        } else {
428            let remaining_items = MAX_ITEMS - dir_entries.len();
429            file_entries.truncate(remaining_items);
430        }
431    }
432
433    let mut dir_headers = Vec::new();
434    let mut dir_futures = Vec::new();
435    for (i, entry) in dir_entries.iter().enumerate() {
436        let is_last_dir = i == dir_entries.len() - 1;
437        let is_last_overall = is_last_dir && file_entries.is_empty() && !should_limit;
438        let current_prefix = if is_last_overall {
439            "└── "
440        } else {
441            "├── "
442        };
443        let next_prefix = format!(
444            "{}{}",
445            prefix,
446            if is_last_overall { "    " } else { "│   " }
447        );
448
449        let header = format!("{}{}{}/\n", prefix, current_prefix, entry.name);
450        dir_headers.push(header);
451
452        let entry_path = entry.path.clone();
453        let next_prefix_clone = next_prefix.clone();
454        let future = async move {
455            generate_directory_tree(
456                provider,
457                &entry_path,
458                &next_prefix_clone,
459                max_depth,
460                current_depth + 1,
461            )
462            .await
463        };
464        dir_futures.push(future);
465    }
466    if !dir_futures.is_empty() {
467        let subtree_results = futures::future::join_all(dir_futures).await;
468
469        for (i, header) in dir_headers.iter().enumerate() {
470            result.push_str(header);
471            if let Some(Ok(subtree)) = subtree_results.get(i) {
472                result.push_str(subtree);
473            }
474        }
475    }
476
477    for (i, entry) in file_entries.iter().enumerate() {
478        let is_last_file = i == file_entries.len() - 1;
479        let is_last_overall = is_last_file && !should_limit;
480        let current_prefix = if is_last_overall {
481            "└── "
482        } else {
483            "├── "
484        };
485        result.push_str(&format!("{}{}{}\n", prefix, current_prefix, entry.name));
486    }
487
488    if should_limit {
489        let remaining_count = total_items - MAX_ITEMS;
490        result.push_str(&format!(
491            "{}└── ... {} more item{}\n",
492            prefix,
493            remaining_count,
494            if remaining_count == 1 { "" } else { "s" }
495        ));
496    }
497
498    Ok(result)
499}
500
/// Strip the MCP server prefix and a trailing `()` from a tool name, then map legacy
/// names to their current equivalents.
///
/// The steps, in order:
/// 1. Drop everything up to and including the first `__` (the server prefix).
/// 2. Drop one trailing `()`.
/// 3. Translate via [`backward_compatibility_mapping`].
///
/// Examples:
/// - `"stakpak__run_command"` -> `"run_command"`
/// - `"run_command"` -> `"run_command"`
/// - `"str_replace()"` -> `"str_replace"`
/// - `"nested__prefix__tool"` -> `"prefix__tool"` (only the first `__` splits)
pub fn strip_tool_name(name: &str) -> &str {
    let without_server = match name.split_once("__") {
        Some((_server, tool)) => tool,
        None => name,
    };
    let without_parens = without_server
        .strip_suffix("()")
        .unwrap_or(without_server);
    backward_compatibility_mapping(without_parens)
}

/// Translate retired tool names to the tool that replaced them. Any other name is
/// returned unchanged. Currently `read_rulebook` and `read_rulebooks` both map to
/// `load_skill`.
pub fn backward_compatibility_mapping(name: &str) -> &str {
    if matches!(name, "read_rulebook" | "read_rulebooks") {
        "load_skill"
    } else {
        name
    }
}
529
530/// Local file system provider implementation
531pub struct LocalFileSystemProvider;
532
533#[async_trait]
534impl FileSystemProvider for LocalFileSystemProvider {
535    type Error = std::io::Error;
536
537    async fn list_directory(&self, path: &str) -> Result<Vec<DirectoryEntry>, Self::Error> {
538        let entries = fs::read_dir(path)?;
539        let mut result = Vec::new();
540
541        for entry in entries {
542            let entry = entry?;
543            let file_name = entry.file_name().to_string_lossy().to_string();
544            let file_path = entry.path().to_string_lossy().to_string();
545            let is_directory = entry.file_type()?.is_dir();
546
547            result.push(DirectoryEntry {
548                name: file_name,
549                path: file_path,
550                is_directory,
551            });
552        }
553
554        Ok(result)
555    }
556}
557
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// Write `content` to `<dir>/.gitignore`, replacing any existing file.
    fn write_gitignore(dir: &std::path::Path, content: &str) -> std::io::Result<()> {
        fs::write(dir.join(".gitignore"), content)
    }

    /// Whether at least one of `patterns` matches `path`.
    fn any_pattern_matches(patterns: &[String], path: &str) -> bool {
        patterns
            .iter()
            .any(|pattern| matches_gitignore_pattern(pattern, path))
    }

    #[test]
    fn test_matches_gitignore_pattern_exact() {
        let pattern = "node_modules";
        assert!(matches_gitignore_pattern(pattern, "node_modules"));
        assert!(matches_gitignore_pattern(pattern, "node_modules/package.json"));
        // Non-wildcard patterns match from the root only.
        assert!(!matches_gitignore_pattern(pattern, "src/node_modules"));
    }

    #[test]
    fn test_matches_gitignore_pattern_wildcard_prefix() {
        for path in ["debug.log", "error.log"] {
            assert!(matches_gitignore_pattern("*.log", path));
        }
        assert!(!matches_gitignore_pattern("*.log", "log.txt"));
    }

    #[test]
    fn test_matches_gitignore_pattern_wildcard_suffix() {
        for path in ["temp", "temp.txt", "temporary"] {
            assert!(matches_gitignore_pattern("temp*", path));
        }
        assert!(!matches_gitignore_pattern("temp*", "mytemp"));
    }

    #[test]
    fn test_matches_gitignore_pattern_wildcard_middle() {
        for path in ["temp", "mytemp", "temporary", "mytemporary"] {
            assert!(matches_gitignore_pattern("*temp*", path));
        }
        assert!(!matches_gitignore_pattern("*temp*", "example"));
    }

    #[test]
    fn test_pattern_matches_glob() {
        let cases = [
            ("test*.txt", "test.txt", true),
            ("test*.txt", "test123.txt", true),
            ("*test*.txt", "mytest.txt", true),
            ("*test*.txt", "mytestfile.txt", true),
            ("test*.txt", "test.log", false),
            ("*test*.txt", "example.txt", false),
        ];
        for (pattern, text, expected) in cases {
            assert_eq!(pattern_matches_glob(pattern, text), expected);
        }
    }

    #[test]
    fn test_read_gitignore_patterns() -> Result<(), Box<dyn std::error::Error>> {
        let temp_dir = TempDir::new()?;
        let temp_path = temp_dir.path();

        // A .gitignore mixing patterns, comments and blank lines.
        let gitignore_content = r#"
# This is a comment
node_modules
*.log
dist/
.env

# Another comment
temp*
"#;
        write_gitignore(temp_path, gitignore_content)?;

        let patterns = read_gitignore_patterns(temp_path.to_str().unwrap());

        // `.git` is always present, along with every real pattern from the file.
        for expected in [".git", "node_modules", "*.log", "dist/", ".env", "temp*"] {
            assert!(patterns.iter().any(|p| p == expected));
        }

        // Comments and blank lines are dropped.
        assert!(patterns.iter().all(|p| !p.starts_with('#') && !p.is_empty()));

        Ok(())
    }

    #[test]
    fn test_read_gitignore_patterns_no_file() {
        let temp_dir = TempDir::new().unwrap();

        let patterns = read_gitignore_patterns(temp_dir.path().to_str().unwrap());

        // Without a .gitignore only the implicit `.git` entry remains.
        assert_eq!(patterns, vec![".git".to_string()]);
    }

    #[test]
    fn test_strip_tool_name() {
        let cases = [
            ("stakpak__run_command", "run_command"),
            ("run_command", "run_command"),
            ("str_replace()", "str_replace"),
            ("stakpak__read_rulebook", "load_skill"),
            ("read_rulebook()", "load_skill"),
            ("read_rulebooks", "load_skill"),
            // Additional edge cases
            ("just_name", "just_name"),
            ("prefix__name()", "name"),
            ("nested__prefix__tool", "prefix__tool"),
            ("empty_suffix()", "empty_suffix"),
        ];
        for (input, expected) in cases {
            assert_eq!(strip_tool_name(input), expected);
        }
    }

    #[test]
    fn test_backward_compatibility_mapping() {
        let cases = [
            ("read_rulebook", "load_skill"),
            ("read_rulebooks", "load_skill"),
            ("run_command", "run_command"),
        ];
        for (input, expected) in cases {
            assert_eq!(backward_compatibility_mapping(input), expected);
        }
    }

    #[test]
    fn test_gitignore_integration() -> Result<(), Box<dyn std::error::Error>> {
        let temp_dir = TempDir::new()?;
        let temp_path = temp_dir.path();

        write_gitignore(temp_path, "node_modules\n*.log\ndist/\n")?;

        let patterns = read_gitignore_patterns(temp_path.to_str().unwrap());

        for ignored in [
            "node_modules",
            "node_modules/package.json",
            "debug.log",
            "dist/bundle.js",
            ".git",
        ] {
            assert!(any_pattern_matches(&patterns, ignored));
        }

        for kept in ["src/main.js", "README.md"] {
            assert!(!any_pattern_matches(&patterns, kept));
        }

        Ok(())
    }
}