signal_gateway_code_tool/
lib.rs

1//! Repository source code browsing via GitHub tarball downloads.
2//!
3//! This crate provides tools for browsing repository source code by downloading
4//! tarballs from GitHub and caching them in memory.
5
6mod cached;
7mod config;
8
9pub use config::{CodeToolConfig, GitHubRepo, Source};
10
11use cached::CachedTarball;
12
13use async_trait::async_trait;
14use globset::{Glob, GlobSet, GlobSetBuilder};
15use regex::Regex;
16use serde::Deserialize;
17use signal_gateway_assistant::{Tool, ToolExecutor, ToolResult};
18use std::{error::Error, fmt::Write, future::Future, pin::Pin, sync::Arc};
19use tokio::sync::{Mutex, MutexGuard};
20use tracing::{error, info, warn};
21
/// Shared, thread-safe callback that reports the currently deployed git SHA.
///
/// Every invocation returns a boxed future; awaiting it yields either the SHA
/// string or a boxed error explaining why the SHA could not be determined.
pub type ShaCallback = Arc<dyn Fn() -> ShaFuture + Send + Sync>;

/// Value produced by awaiting the future of a [`ShaCallback`].
type ShaResult = Result<String, Box<dyn Error + Send + Sync>>;

/// Boxed, sendable future handed back by a [`ShaCallback`].
type ShaFuture = Pin<Box<dyn Future<Output = ShaResult> + Send>>;
30
/// Where tarballs come from once the configuration has been resolved.
enum ResolvedSource {
    /// Tarballs are downloaded from the GitHub API.
    GitHub {
        /// Account or organisation that owns the repository.
        owner: String,
        /// Repository name within `owner`.
        repo: String,
        /// Optional API token, sent as a bearer credential when present.
        token: Option<String>,
    },
    /// The tarball is read from a file on the local filesystem.
    File { path: std::path::PathBuf },
}
42
/// Repository source code browser.
///
/// Loads a tarball (downloaded from GitHub or read from a local file),
/// extracts it into an in-memory cache, and answers `ls` / `find` / `read` /
/// `search` queries against that cache.
pub struct CodeTool {
    /// Configuration this tool was built from (name, globs, source, ...).
    config: CodeToolConfig,
    /// Resolved tarball origin, including the GitHub token if one was configured.
    source: ResolvedSource,
    /// Compiled form of `config.glob`; `None` when no patterns were configured.
    /// Passed to tarball extraction.
    glob_filter: Option<GlobSet>,
    /// Contents of `config.summary_file`, read once at construction.
    summary: Option<Box<str>>,
    /// Callback queried on every cache access for the current SHA.
    get_sha: ShaCallback,
    /// HTTP client used for GitHub tarball downloads.
    client: reqwest::Client,
    /// Most recently extracted tarball, or `None` if nothing has loaded yet.
    cache: Mutex<Option<CachedTarball>>,
}
55
56impl CodeTool {
57    /// Create a new CodeTool instance from configuration.
58    ///
59    /// The `get_sha` callback is called to determine which git SHA to download/load.
60    /// For GitHub sources, this is the commit SHA. For file sources, this can be
61    /// used to track file modification (e.g., mtime or a version string).
62    pub fn new(config: CodeToolConfig, get_sha: ShaCallback) -> Result<Self, std::io::Error> {
63        let source = match &config.source {
64            Source::GitHub { repo, token_file } => {
65                let token = token_file
66                    .as_ref()
67                    .map(|path| std::fs::read_to_string(path).map(|s| s.trim().to_string()))
68                    .transpose()?;
69                ResolvedSource::GitHub {
70                    owner: repo.owner.clone(),
71                    repo: repo.repo.clone(),
72                    token,
73                }
74            }
75            Source::File { path } => ResolvedSource::File { path: path.clone() },
76        };
77
78        // Compile glob patterns if any are specified
79        let glob_filter =
80            if config.glob.is_empty() {
81                None
82            } else {
83                let mut builder = GlobSetBuilder::new();
84                for pattern in &config.glob {
85                    let glob = Glob::new(pattern).map_err(|e| {
86                        std::io::Error::other(format!("invalid glob pattern '{}': {}", pattern, e))
87                    })?;
88                    builder.add(glob);
89                }
90                Some(builder.build().map_err(|e| {
91                    std::io::Error::other(format!("failed to build glob set: {}", e))
92                })?)
93            };
94
95        // Load summary from file if configured
96        let summary = config
97            .summary_file
98            .as_ref()
99            .map(|path| std::fs::read_to_string(path).map(|s| s.into_boxed_str()))
100            .transpose()?;
101
102        Ok(Self {
103            config,
104            source,
105            glob_filter,
106            summary,
107            get_sha,
108            client: reqwest::Client::new(),
109            cache: Mutex::new(None),
110        })
111    }
112
    /// Get the repository name.
    ///
    /// This is the `name` field of the configuration the tool was built with.
    pub fn name(&self) -> &str {
        &self.config.name
    }
117
118    /// Check if this code tool has a summary available.
119    pub fn has_summary(&self) -> bool {
120        self.summary.is_some()
121    }
122
    /// Get the summary of the codebase, if available.
    ///
    /// Returns the text read from the configured summary file at construction
    /// time, or `None` when no summary file was configured.
    pub fn summary(&self) -> Option<&str> {
        self.summary.as_deref()
    }
127
128    /// Get the current tarball, downloading or reading from file as needed.
129    ///
130    /// Returns a mutex guard containing the cached tarball. This method never fails;
131    /// instead it logs warnings and returns whatever is currently cached:
132    ///
133    /// - If the SHA callback fails, returns the existing cache (possibly stale or None)
134    /// - If the tarball download/read fails, returns the existing cache
135    /// - If tarball extraction fails, returns the existing cache
136    ///
137    /// This design assumes that stale code is better than no code, since most of the
138    /// codebase is likely unchanged between versions.
139    async fn get_current_tarball(&self) -> MutexGuard<'_, Option<CachedTarball>> {
140        let current_sha = match (self.get_sha)().await {
141            Ok(sha) => sha,
142            Err(e) => {
143                warn!("Failed to get current SHA for {}: {e}", self.config.name);
144                return self.cache.lock().await;
145            }
146        };
147
148        let mut cache = self.cache.lock().await;
149
150        // Check if we need to load/reload the tarball.
151        // For file sources, we only load once (no refresh after initial load).
152        // For GitHub sources, we reload when the SHA changes.
153        let needs_refresh = match (&*cache, &self.source) {
154            (None, _) => true,
155            (Some(_), ResolvedSource::File { .. }) => false,
156            (Some(cached), ResolvedSource::GitHub { .. }) => cached.sha != current_sha,
157        };
158
159        if needs_refresh {
160            info!(
161                "Loading tarball for {} at {}",
162                self.config.name, current_sha
163            );
164
165            let tarball = match self.load_tarball(&current_sha).await {
166                Ok(t) => t,
167                Err(e) => {
168                    error!("Failed to load tarball for {}: {e}", self.config.name);
169                    return cache;
170                }
171            };
172
173            match CachedTarball::extract(
174                current_sha,
175                &tarball,
176                self.glob_filter.as_ref(),
177                self.config.include_non_utf8,
178            ) {
179                Ok(cached_tarball) => {
180                    *cache = Some(cached_tarball);
181                }
182                Err(e) => {
183                    error!("Failed to extract tarball for {}: {e}", self.config.name);
184                }
185            };
186        }
187
188        cache
189    }
190
191    /// Load a tarball either from GitHub or from a local file.
192    async fn load_tarball(&self, sha: &str) -> Result<Vec<u8>, String> {
193        match &self.source {
194            ResolvedSource::GitHub { owner, repo, token } => {
195                self.download_tarball_from_github(owner, repo, token.as_deref(), sha)
196                    .await
197            }
198            ResolvedSource::File { path } => std::fs::read(path)
199                .map_err(|e| format!("Failed to read tarball from {}: {e}", path.display())),
200        }
201    }
202
203    /// Download a tarball from GitHub for the given SHA.
204    async fn download_tarball_from_github(
205        &self,
206        owner: &str,
207        repo: &str,
208        token: Option<&str>,
209        sha: &str,
210    ) -> Result<Vec<u8>, String> {
211        let url = format!(
212            "https://api.github.com/repos/{}/{}/tarball/{}",
213            owner, repo, sha
214        );
215
216        let mut request = self
217            .client
218            .get(&url)
219            .header("Accept", "application/vnd.github+json")
220            .header("User-Agent", "signal-gateway")
221            .header("X-GitHub-Api-Version", "2022-11-28");
222
223        if let Some(token) = token {
224            request = request.header("Authorization", format!("Bearer {}", token));
225        }
226
227        let response = request
228            .send()
229            .await
230            .map_err(|e| format!("HTTP request failed: {e}"))?;
231
232        if !response.status().is_success() {
233            return Err(format!(
234                "GitHub API error: {} {}",
235                response.status(),
236                response.text().await.unwrap_or_default()
237            ));
238        }
239
240        response
241            .bytes()
242            .await
243            .map(|b| b.to_vec())
244            .map_err(|e| format!("Failed to read response body: {e}"))
245    }
246
247    /// List files in a directory (like `ls`).
248    ///
249    /// If `path` is None or empty, lists the root directory.
250    pub async fn ls(&self, path: Option<&str>) -> Result<String, String> {
251        let cache = self.get_current_tarball().await;
252        let cached = cache.as_ref().ok_or("source code not available")?;
253
254        let prefix = path.unwrap_or("").trim_start_matches('/');
255        let prefix = if prefix.is_empty() {
256            String::new()
257        } else if prefix.ends_with('/') {
258            prefix.to_string()
259        } else {
260            format!("{}/", prefix)
261        };
262
263        let mut entries = std::collections::BTreeSet::new();
264
265        for file_path in cached.files.keys() {
266            if prefix.is_empty() || file_path.starts_with(&prefix) {
267                // Get the part after the prefix
268                let remainder = if prefix.is_empty() {
269                    file_path.as_str()
270                } else {
271                    &file_path[prefix.len()..]
272                };
273
274                // Get just the first component (file or directory name)
275                if let Some(first) = remainder.split('/').next()
276                    && !first.is_empty()
277                {
278                    // Check if it's a directory (has more components)
279                    let is_dir = remainder.contains('/');
280                    let entry = if is_dir {
281                        format!("{}/", first)
282                    } else {
283                        first.to_string()
284                    };
285                    entries.insert(entry);
286                }
287            }
288        }
289
290        if entries.is_empty() {
291            Ok(format!(
292                "No files found in '{}'",
293                prefix.trim_end_matches('/')
294            ))
295        } else {
296            Ok(entries.into_iter().collect::<Vec<_>>().join("\n"))
297        }
298    }
299
300    /// Find files matching a glob pattern (like `find`).
301    ///
302    /// Supports glob patterns using the `globset` crate syntax.
303    pub async fn find(&self, pattern: Option<&str>) -> Result<String, String> {
304        let cache = self.get_current_tarball().await;
305        let cached = cache.as_ref().ok_or("source code not available")?;
306
307        let pattern = pattern.unwrap_or("*");
308
309        let glob = Glob::new(pattern)
310            .map_err(|e| format!("Invalid glob pattern: {e}"))?
311            .compile_matcher();
312
313        let matches: Vec<&str> = cached
314            .files
315            .keys()
316            .filter(|path| glob.is_match(path))
317            .map(|s| s.as_str())
318            .collect();
319
320        if matches.is_empty() {
321            Ok(format!("No files matching '{}'", pattern))
322        } else {
323            Ok(matches.join("\n"))
324        }
325    }
326
327    /// Read a file's contents.
328    ///
329    /// If `line_range` is provided, only returns those lines (1-indexed, inclusive).
330    pub async fn read(
331        &self,
332        path: &str,
333        line_start: Option<usize>,
334        line_end: Option<usize>,
335    ) -> Result<String, String> {
336        let cache = self.get_current_tarball().await;
337        let cached = cache.as_ref().ok_or("source code not available")?;
338
339        let path = path.trim_start_matches('/');
340        let file = cached
341            .files
342            .get(path)
343            .ok_or_else(|| format!("File not found: {}", path))?;
344
345        let total_lines = file.line_count();
346        let start = line_start.unwrap_or(1) as u32;
347        let end = line_end.map(|e| e as u32);
348
349        if start > total_lines {
350            return Ok(format!(
351                "Line {} is past end of file ({} lines)",
352                start, total_lines
353            ));
354        }
355
356        let mut output = String::new();
357        for (i, line) in file.line_range(Some(start), end).enumerate() {
358            writeln!(&mut output, "{:>6}\t{}", start as usize + i, line)
359                .map_err(|e| format!("Format error: {e}"))?;
360        }
361
362        Ok(output)
363    }
364
365    /// Search for a regex pattern in all files.
366    ///
367    /// - `pattern`: The regex pattern to search for.
368    /// - `context`: Number of context lines to show (like `grep -C`).
369    /// - `path_prefix`: Optional path prefix to limit search scope.
370    pub async fn search(
371        &self,
372        pattern: &str,
373        context: u32,
374        path_prefix: Option<&str>,
375    ) -> Result<String, String> {
376        let regex = Regex::new(pattern).map_err(|e| format!("Invalid regex: {e}"))?;
377
378        let cache = self.get_current_tarball().await;
379        let cached = cache.as_ref().ok_or("source code not available")?;
380
381        let prefix = path_prefix.map(|p| p.trim_start_matches('/'));
382
383        let mut output = String::new();
384        let mut match_count = 0;
385        let mut file_count = 0;
386        const MAX_MATCHES: usize = 100;
387
388        'outer: for (path, file) in &cached.files {
389            // Skip if path doesn't match prefix
390            if let Some(prefix) = prefix
391                && !path.starts_with(prefix)
392            {
393                continue;
394            }
395
396            // Skip binary-looking files
397            if file.is_binary() {
398                continue;
399            }
400
401            // Find all matches and map byte positions to line numbers
402            let content = file.as_str();
403            let mut file_matches: Vec<u32> = Vec::new();
404
405            for m in regex.find_iter(content) {
406                let line_num = file.idx_to_line(m.start());
407                // Deduplicate: only add if this line isn't already recorded
408                if file_matches.last() != Some(&line_num) {
409                    file_matches.push(line_num);
410                    match_count += 1;
411                    if match_count >= MAX_MATCHES {
412                        break 'outer;
413                    }
414                }
415            }
416
417            if !file_matches.is_empty() {
418                file_count += 1;
419                let total_lines = file.line_count();
420
421                if context == 0 {
422                    // No context, just print matches
423                    for &line_num in &file_matches {
424                        let line = file
425                            .line_range(Some(line_num), Some(line_num))
426                            .next()
427                            .unwrap_or("");
428                        writeln!(&mut output, "{}:{}: {}", path, line_num, line)
429                            .map_err(|e| format!("Format error: {e}"))?;
430                    }
431                } else {
432                    // Print with context
433                    writeln!(&mut output, "=== {} ===", path)
434                        .map_err(|e| format!("Format error: {e}"))?;
435
436                    let mut printed = std::collections::BTreeSet::new();
437
438                    for &match_line in &file_matches {
439                        let start = match_line.saturating_sub(context).max(1);
440                        let end = (match_line + context).min(total_lines);
441
442                        // Add separator if there's a gap
443                        if let Some(&last) = printed.iter().next_back()
444                            && start > last + 1
445                        {
446                            writeln!(&mut output, "---")
447                                .map_err(|e| format!("Format error: {e}"))?;
448                        }
449
450                        for (i, line) in file.line_range(Some(start), Some(end)).enumerate() {
451                            let line_num = start + i as u32;
452                            if printed.insert(line_num) {
453                                let marker = if line_num == match_line { ">" } else { " " };
454                                writeln!(&mut output, "{}{:>5}\t{}", marker, line_num, line)
455                                    .map_err(|e| format!("Format error: {e}"))?;
456                            }
457                        }
458                    }
459                    writeln!(&mut output).map_err(|e| format!("Format error: {e}"))?;
460                }
461            }
462        }
463
464        if match_count == 0 {
465            Ok(format!("No matches for '{}'", pattern))
466        } else {
467            let truncated = if match_count >= MAX_MATCHES {
468                format!(" (truncated at {} matches)", MAX_MATCHES)
469            } else {
470                String::new()
471            };
472            Ok(format!(
473                "{}\n[{} matches in {} files{}]",
474                output.trim_end(),
475                match_count,
476                file_count,
477                truncated
478            ))
479        }
480    }
481}
482
/// Tool executor for multiple repository code browsers.
///
/// Routes each `code_*` tool call to the [`CodeTool`] whose name matches the
/// call's `repo` argument.
pub struct CodeToolTools {
    /// Configured repositories, looked up by [`CodeTool::name`].
    repos: Vec<CodeTool>,
}
487
488impl CodeToolTools {
489    /// Create a new CodeToolTools instance.
490    pub fn new(repos: Vec<CodeTool>) -> Self {
491        Self { repos }
492    }
493
494    /// Find a repo by name.
495    fn find_repo(&self, name: &str) -> Option<&CodeTool> {
496        self.repos.iter().find(|repo| repo.name() == name)
497    }
498
499    /// Get list of repo names for error messages.
500    fn repo_names(&self) -> String {
501        self.repos
502            .iter()
503            .map(|r| r.name())
504            .collect::<Vec<_>>()
505            .join(", ")
506    }
507
508    /// Get list of repo names that have summaries available.
509    fn repos_with_summaries(&self) -> Vec<&str> {
510        self.repos
511            .iter()
512            .filter(|r| r.has_summary())
513            .map(|r| r.name())
514            .collect()
515    }
516}
517
/// Arguments of the `code_ls` tool.
#[derive(Deserialize)]
struct LsInput {
    /// Name of the repository to browse.
    repo: String,
    /// Directory to list; the root is listed when absent.
    path: Option<String>,
}
523
/// Arguments of the `code_find` tool.
#[derive(Deserialize)]
struct FindInput {
    /// Name of the repository to browse.
    repo: String,
    /// Glob pattern to match file paths against; `*` is used when absent.
    pattern: Option<String>,
}
529
/// Arguments of the `code_read` tool.
#[derive(Deserialize)]
struct ReadInput {
    /// Name of the repository to browse.
    repo: String,
    /// Path of the file to read.
    path: String,
    /// First line to return (1-indexed); defaults to the first line.
    line_start: Option<usize>,
    /// Last line to return (inclusive); defaults to the end of the file.
    line_end: Option<usize>,
}
537
/// Arguments of the `code_search` tool.
#[derive(Deserialize)]
struct SearchInput {
    /// Name of the repository to browse.
    repo: String,
    /// Regex pattern to search for.
    pattern: String,
    /// Number of context lines around each match; 0 when omitted.
    #[serde(default)]
    context: u32,
    /// Optional path prefix restricting which files are searched.
    path_prefix: Option<String>,
}
546
/// Arguments of the `code_summary` tool.
#[derive(Deserialize)]
struct SummaryInput {
    /// Name of the repository whose summary is requested.
    repo: String,
}
551
552#[async_trait]
553impl ToolExecutor for CodeToolTools {
554    fn tools(&self) -> Vec<Tool> {
555        let mut tools = vec![
556            Tool {
557                name: "code_ls",
558                description: "List files in a directory of a repository's source code.",
559                input_schema: serde_json::json!({
560                    "type": "object",
561                    "properties": {
562                        "repo": {
563                            "type": "string",
564                            "description": "Name of the repository"
565                        },
566                        "path": {
567                            "type": "string",
568                            "description": "Directory path to list (optional, defaults to root)"
569                        }
570                    },
571                    "required": ["repo"]
572                }),
573            },
574            Tool {
575                name: "code_find",
576                description: "Find files matching a glob pattern in a repository's source code.",
577                input_schema: serde_json::json!({
578                    "type": "object",
579                    "properties": {
580                        "repo": {
581                            "type": "string",
582                            "description": "Name of the repository"
583                        },
584                        "pattern": {
585                            "type": "string",
586                            "description": "Glob pattern to match (e.g., '*.rs', 'src/*.py')"
587                        }
588                    },
589                    "required": ["repo"]
590                }),
591            },
592            Tool {
593                name: "code_read",
594                description: "Read a file from a repository's source code.",
595                input_schema: serde_json::json!({
596                    "type": "object",
597                    "properties": {
598                        "repo": {
599                            "type": "string",
600                            "description": "Name of the repository"
601                        },
602                        "path": {
603                            "type": "string",
604                            "description": "Path to the file to read"
605                        },
606                        "line_start": {
607                            "type": "integer",
608                            "description": "Starting line number (1-indexed, optional)"
609                        },
610                        "line_end": {
611                            "type": "integer",
612                            "description": "Ending line number (inclusive, optional)"
613                        }
614                    },
615                    "required": ["repo", "path"]
616                }),
617            },
618            Tool {
619                name: "code_search",
620                description: "Search for a regex pattern in a repository's source code (like grep).",
621                input_schema: serde_json::json!({
622                    "type": "object",
623                    "properties": {
624                        "repo": {
625                            "type": "string",
626                            "description": "Name of the repository"
627                        },
628                        "pattern": {
629                            "type": "string",
630                            "description": "Regex pattern to search for"
631                        },
632                        "context": {
633                            "type": "integer",
634                            "description": "Number of context lines to show (like grep -C, default 0)"
635                        },
636                        "path_prefix": {
637                            "type": "string",
638                            "description": "Optional path prefix to limit search scope"
639                        }
640                    },
641                    "required": ["repo", "pattern"]
642                }),
643            },
644        ];
645
646        // Only include summary tool if at least one repo has a summary
647        let repos_with_summaries = self.repos_with_summaries();
648        if !repos_with_summaries.is_empty() {
649            tools.push(Tool {
650                name: "code_summary",
651                description: "Get a summary/overview of a repository's codebase.",
652                input_schema: serde_json::json!({
653                    "type": "object",
654                    "properties": {
655                        "repo": {
656                            "type": "string",
657                            "description": format!("Name of the repository. Repos with summaries: {}", repos_with_summaries.join(", "))
658                        }
659                    },
660                    "required": ["repo"]
661                }),
662            });
663        }
664
665        tools
666    }
667
668    fn has_tool(&self, name: &str) -> bool {
669        matches!(
670            name,
671            "code_ls" | "code_find" | "code_read" | "code_search" | "code_summary"
672        )
673    }
674
675    async fn execute(&self, name: &str, input: &serde_json::Value) -> Result<ToolResult, String> {
676        match name {
677            "code_ls" => {
678                let input: LsInput = serde_json::from_value(input.clone())
679                    .map_err(|e| format!("Invalid input: {e}"))?;
680                let repo = self.find_repo(&input.repo).ok_or_else(|| {
681                    format!(
682                        "Unknown repo '{}'. Available: {}",
683                        input.repo,
684                        self.repo_names()
685                    )
686                })?;
687                let result = repo.ls(input.path.as_deref()).await?;
688                Ok(ToolResult::new(result))
689            }
690            "code_find" => {
691                let input: FindInput = serde_json::from_value(input.clone())
692                    .map_err(|e| format!("Invalid input: {e}"))?;
693                let repo = self.find_repo(&input.repo).ok_or_else(|| {
694                    format!(
695                        "Unknown repo '{}'. Available: {}",
696                        input.repo,
697                        self.repo_names()
698                    )
699                })?;
700                let result = repo.find(input.pattern.as_deref()).await?;
701                Ok(ToolResult::new(result))
702            }
703            "code_read" => {
704                let input: ReadInput = serde_json::from_value(input.clone())
705                    .map_err(|e| format!("Invalid input: {e}"))?;
706                let repo = self.find_repo(&input.repo).ok_or_else(|| {
707                    format!(
708                        "Unknown repo '{}'. Available: {}",
709                        input.repo,
710                        self.repo_names()
711                    )
712                })?;
713                let result = repo
714                    .read(&input.path, input.line_start, input.line_end)
715                    .await?;
716                Ok(ToolResult::new(result))
717            }
718            "code_search" => {
719                let input: SearchInput = serde_json::from_value(input.clone())
720                    .map_err(|e| format!("Invalid input: {e}"))?;
721                let repo = self.find_repo(&input.repo).ok_or_else(|| {
722                    format!(
723                        "Unknown repo '{}'. Available: {}",
724                        input.repo,
725                        self.repo_names()
726                    )
727                })?;
728                let result = repo
729                    .search(&input.pattern, input.context, input.path_prefix.as_deref())
730                    .await?;
731                Ok(ToolResult::new(result))
732            }
733            "code_summary" => {
734                let input: SummaryInput = serde_json::from_value(input.clone())
735                    .map_err(|e| format!("Invalid input: {e}"))?;
736                let repo = self.find_repo(&input.repo).ok_or_else(|| {
737                    format!(
738                        "Unknown repo '{}'. Available: {}",
739                        input.repo,
740                        self.repo_names()
741                    )
742                })?;
743                let summary = repo.summary().ok_or_else(|| {
744                    format!(
745                        "No summary available for '{}'. Repos with summaries: {}",
746                        input.repo,
747                        self.repos_with_summaries().join(", ")
748                    )
749                })?;
750                Ok(ToolResult::new(summary.to_string()))
751            }
752            _ => Err(format!("Unknown tool: {name}")),
753        }
754    }
755}