1mod cached;
7mod config;
8
9pub use config::{CodeToolConfig, GitHubRepo, Source};
10
11use cached::CachedTarball;
12
13use async_trait::async_trait;
14use globset::{Glob, GlobSet, GlobSetBuilder};
15use regex::Regex;
16use serde::Deserialize;
17use signal_gateway_assistant::{Tool, ToolExecutor, ToolResult};
18use std::{error::Error, fmt::Write, future::Future, pin::Pin, sync::Arc};
19use tokio::sync::{Mutex, MutexGuard};
20use tracing::{error, info, warn};
21
/// Shared async callback that reports which commit SHA should be served.
///
/// `CodeTool` awaits this before consulting its tarball cache. For GitHub
/// sources a SHA different from the cached one triggers a fresh download; for
/// file sources the value only labels the cached tarball. A returned error is
/// logged by the caller and whatever is already cached is used instead.
pub type ShaCallback = Arc<
    dyn Fn() -> Pin<Box<dyn Future<Output = Result<String, Box<dyn Error + Send + Sync>>> + Send>>
        + Sync
        + Send,
>;
30
/// A tarball source with its configuration already resolved for use.
///
/// Unlike the config-level `Source`, the GitHub variant carries the token
/// contents (read once in `CodeTool::new`) instead of a path to a token file.
enum ResolvedSource {
    /// A tarball read from the local filesystem; it is loaded once and never refreshed.
    File {
        /// Location of the tarball on disk.
        path: std::path::PathBuf,
    },
    /// A GitHub repository whose tarball is downloaded per commit SHA.
    GitHub {
        /// Repository owner (user or organization).
        owner: String,
        /// Repository name.
        repo: String,
        /// Bearer token for the GitHub API, if one was configured.
        token: Option<String>,
    },
}
42
43pub struct CodeTool {
47 config: CodeToolConfig,
48 source: ResolvedSource,
49 glob_filter: Option<GlobSet>,
50 summary: Option<Box<str>>,
51 get_sha: ShaCallback,
52 client: reqwest::Client,
53 cache: Mutex<Option<CachedTarball>>,
54}
55
56impl CodeTool {
57 pub fn new(config: CodeToolConfig, get_sha: ShaCallback) -> Result<Self, std::io::Error> {
63 let source = match &config.source {
64 Source::GitHub { repo, token_file } => {
65 let token = token_file
66 .as_ref()
67 .map(|path| std::fs::read_to_string(path).map(|s| s.trim().to_string()))
68 .transpose()?;
69 ResolvedSource::GitHub {
70 owner: repo.owner.clone(),
71 repo: repo.repo.clone(),
72 token,
73 }
74 }
75 Source::File { path } => ResolvedSource::File { path: path.clone() },
76 };
77
78 let glob_filter =
80 if config.glob.is_empty() {
81 None
82 } else {
83 let mut builder = GlobSetBuilder::new();
84 for pattern in &config.glob {
85 let glob = Glob::new(pattern).map_err(|e| {
86 std::io::Error::other(format!("invalid glob pattern '{}': {}", pattern, e))
87 })?;
88 builder.add(glob);
89 }
90 Some(builder.build().map_err(|e| {
91 std::io::Error::other(format!("failed to build glob set: {}", e))
92 })?)
93 };
94
95 let summary = config
97 .summary_file
98 .as_ref()
99 .map(|path| std::fs::read_to_string(path).map(|s| s.into_boxed_str()))
100 .transpose()?;
101
102 Ok(Self {
103 config,
104 source,
105 glob_filter,
106 summary,
107 get_sha,
108 client: reqwest::Client::new(),
109 cache: Mutex::new(None),
110 })
111 }
112
113 pub fn name(&self) -> &str {
115 &self.config.name
116 }
117
118 pub fn has_summary(&self) -> bool {
120 self.summary.is_some()
121 }
122
123 pub fn summary(&self) -> Option<&str> {
125 self.summary.as_deref()
126 }
127
128 async fn get_current_tarball(&self) -> MutexGuard<'_, Option<CachedTarball>> {
140 let current_sha = match (self.get_sha)().await {
141 Ok(sha) => sha,
142 Err(e) => {
143 warn!("Failed to get current SHA for {}: {e}", self.config.name);
144 return self.cache.lock().await;
145 }
146 };
147
148 let mut cache = self.cache.lock().await;
149
150 let needs_refresh = match (&*cache, &self.source) {
154 (None, _) => true,
155 (Some(_), ResolvedSource::File { .. }) => false,
156 (Some(cached), ResolvedSource::GitHub { .. }) => cached.sha != current_sha,
157 };
158
159 if needs_refresh {
160 info!(
161 "Loading tarball for {} at {}",
162 self.config.name, current_sha
163 );
164
165 let tarball = match self.load_tarball(¤t_sha).await {
166 Ok(t) => t,
167 Err(e) => {
168 error!("Failed to load tarball for {}: {e}", self.config.name);
169 return cache;
170 }
171 };
172
173 match CachedTarball::extract(
174 current_sha,
175 &tarball,
176 self.glob_filter.as_ref(),
177 self.config.include_non_utf8,
178 ) {
179 Ok(cached_tarball) => {
180 *cache = Some(cached_tarball);
181 }
182 Err(e) => {
183 error!("Failed to extract tarball for {}: {e}", self.config.name);
184 }
185 };
186 }
187
188 cache
189 }
190
191 async fn load_tarball(&self, sha: &str) -> Result<Vec<u8>, String> {
193 match &self.source {
194 ResolvedSource::GitHub { owner, repo, token } => {
195 self.download_tarball_from_github(owner, repo, token.as_deref(), sha)
196 .await
197 }
198 ResolvedSource::File { path } => std::fs::read(path)
199 .map_err(|e| format!("Failed to read tarball from {}: {e}", path.display())),
200 }
201 }
202
203 async fn download_tarball_from_github(
205 &self,
206 owner: &str,
207 repo: &str,
208 token: Option<&str>,
209 sha: &str,
210 ) -> Result<Vec<u8>, String> {
211 let url = format!(
212 "https://api.github.com/repos/{}/{}/tarball/{}",
213 owner, repo, sha
214 );
215
216 let mut request = self
217 .client
218 .get(&url)
219 .header("Accept", "application/vnd.github+json")
220 .header("User-Agent", "signal-gateway")
221 .header("X-GitHub-Api-Version", "2022-11-28");
222
223 if let Some(token) = token {
224 request = request.header("Authorization", format!("Bearer {}", token));
225 }
226
227 let response = request
228 .send()
229 .await
230 .map_err(|e| format!("HTTP request failed: {e}"))?;
231
232 if !response.status().is_success() {
233 return Err(format!(
234 "GitHub API error: {} {}",
235 response.status(),
236 response.text().await.unwrap_or_default()
237 ));
238 }
239
240 response
241 .bytes()
242 .await
243 .map(|b| b.to_vec())
244 .map_err(|e| format!("Failed to read response body: {e}"))
245 }
246
247 pub async fn ls(&self, path: Option<&str>) -> Result<String, String> {
251 let cache = self.get_current_tarball().await;
252 let cached = cache.as_ref().ok_or("source code not available")?;
253
254 let prefix = path.unwrap_or("").trim_start_matches('/');
255 let prefix = if prefix.is_empty() {
256 String::new()
257 } else if prefix.ends_with('/') {
258 prefix.to_string()
259 } else {
260 format!("{}/", prefix)
261 };
262
263 let mut entries = std::collections::BTreeSet::new();
264
265 for file_path in cached.files.keys() {
266 if prefix.is_empty() || file_path.starts_with(&prefix) {
267 let remainder = if prefix.is_empty() {
269 file_path.as_str()
270 } else {
271 &file_path[prefix.len()..]
272 };
273
274 if let Some(first) = remainder.split('/').next()
276 && !first.is_empty()
277 {
278 let is_dir = remainder.contains('/');
280 let entry = if is_dir {
281 format!("{}/", first)
282 } else {
283 first.to_string()
284 };
285 entries.insert(entry);
286 }
287 }
288 }
289
290 if entries.is_empty() {
291 Ok(format!(
292 "No files found in '{}'",
293 prefix.trim_end_matches('/')
294 ))
295 } else {
296 Ok(entries.into_iter().collect::<Vec<_>>().join("\n"))
297 }
298 }
299
300 pub async fn find(&self, pattern: Option<&str>) -> Result<String, String> {
304 let cache = self.get_current_tarball().await;
305 let cached = cache.as_ref().ok_or("source code not available")?;
306
307 let pattern = pattern.unwrap_or("*");
308
309 let glob = Glob::new(pattern)
310 .map_err(|e| format!("Invalid glob pattern: {e}"))?
311 .compile_matcher();
312
313 let matches: Vec<&str> = cached
314 .files
315 .keys()
316 .filter(|path| glob.is_match(path))
317 .map(|s| s.as_str())
318 .collect();
319
320 if matches.is_empty() {
321 Ok(format!("No files matching '{}'", pattern))
322 } else {
323 Ok(matches.join("\n"))
324 }
325 }
326
327 pub async fn read(
331 &self,
332 path: &str,
333 line_start: Option<usize>,
334 line_end: Option<usize>,
335 ) -> Result<String, String> {
336 let cache = self.get_current_tarball().await;
337 let cached = cache.as_ref().ok_or("source code not available")?;
338
339 let path = path.trim_start_matches('/');
340 let file = cached
341 .files
342 .get(path)
343 .ok_or_else(|| format!("File not found: {}", path))?;
344
345 let total_lines = file.line_count();
346 let start = line_start.unwrap_or(1) as u32;
347 let end = line_end.map(|e| e as u32);
348
349 if start > total_lines {
350 return Ok(format!(
351 "Line {} is past end of file ({} lines)",
352 start, total_lines
353 ));
354 }
355
356 let mut output = String::new();
357 for (i, line) in file.line_range(Some(start), end).enumerate() {
358 writeln!(&mut output, "{:>6}\t{}", start as usize + i, line)
359 .map_err(|e| format!("Format error: {e}"))?;
360 }
361
362 Ok(output)
363 }
364
365 pub async fn search(
371 &self,
372 pattern: &str,
373 context: u32,
374 path_prefix: Option<&str>,
375 ) -> Result<String, String> {
376 let regex = Regex::new(pattern).map_err(|e| format!("Invalid regex: {e}"))?;
377
378 let cache = self.get_current_tarball().await;
379 let cached = cache.as_ref().ok_or("source code not available")?;
380
381 let prefix = path_prefix.map(|p| p.trim_start_matches('/'));
382
383 let mut output = String::new();
384 let mut match_count = 0;
385 let mut file_count = 0;
386 const MAX_MATCHES: usize = 100;
387
388 'outer: for (path, file) in &cached.files {
389 if let Some(prefix) = prefix
391 && !path.starts_with(prefix)
392 {
393 continue;
394 }
395
396 if file.is_binary() {
398 continue;
399 }
400
401 let content = file.as_str();
403 let mut file_matches: Vec<u32> = Vec::new();
404
405 for m in regex.find_iter(content) {
406 let line_num = file.idx_to_line(m.start());
407 if file_matches.last() != Some(&line_num) {
409 file_matches.push(line_num);
410 match_count += 1;
411 if match_count >= MAX_MATCHES {
412 break 'outer;
413 }
414 }
415 }
416
417 if !file_matches.is_empty() {
418 file_count += 1;
419 let total_lines = file.line_count();
420
421 if context == 0 {
422 for &line_num in &file_matches {
424 let line = file
425 .line_range(Some(line_num), Some(line_num))
426 .next()
427 .unwrap_or("");
428 writeln!(&mut output, "{}:{}: {}", path, line_num, line)
429 .map_err(|e| format!("Format error: {e}"))?;
430 }
431 } else {
432 writeln!(&mut output, "=== {} ===", path)
434 .map_err(|e| format!("Format error: {e}"))?;
435
436 let mut printed = std::collections::BTreeSet::new();
437
438 for &match_line in &file_matches {
439 let start = match_line.saturating_sub(context).max(1);
440 let end = (match_line + context).min(total_lines);
441
442 if let Some(&last) = printed.iter().next_back()
444 && start > last + 1
445 {
446 writeln!(&mut output, "---")
447 .map_err(|e| format!("Format error: {e}"))?;
448 }
449
450 for (i, line) in file.line_range(Some(start), Some(end)).enumerate() {
451 let line_num = start + i as u32;
452 if printed.insert(line_num) {
453 let marker = if line_num == match_line { ">" } else { " " };
454 writeln!(&mut output, "{}{:>5}\t{}", marker, line_num, line)
455 .map_err(|e| format!("Format error: {e}"))?;
456 }
457 }
458 }
459 writeln!(&mut output).map_err(|e| format!("Format error: {e}"))?;
460 }
461 }
462 }
463
464 if match_count == 0 {
465 Ok(format!("No matches for '{}'", pattern))
466 } else {
467 let truncated = if match_count >= MAX_MATCHES {
468 format!(" (truncated at {} matches)", MAX_MATCHES)
469 } else {
470 String::new()
471 };
472 Ok(format!(
473 "{}\n[{} matches in {} files{}]",
474 output.trim_end(),
475 match_count,
476 file_count,
477 truncated
478 ))
479 }
480 }
481}
482
/// Exposes a set of [`CodeTool`] repositories to the assistant as the
/// `code_*` tools. The target repository is selected by the `repo` argument
/// of each call.
pub struct CodeToolTools {
    /// Repositories available to the tools, looked up by `CodeTool::name`.
    repos: Vec<CodeTool>,
}
487
488impl CodeToolTools {
489 pub fn new(repos: Vec<CodeTool>) -> Self {
491 Self { repos }
492 }
493
494 fn find_repo(&self, name: &str) -> Option<&CodeTool> {
496 self.repos.iter().find(|repo| repo.name() == name)
497 }
498
499 fn repo_names(&self) -> String {
501 self.repos
502 .iter()
503 .map(|r| r.name())
504 .collect::<Vec<_>>()
505 .join(", ")
506 }
507
508 fn repos_with_summaries(&self) -> Vec<&str> {
510 self.repos
511 .iter()
512 .filter(|r| r.has_summary())
513 .map(|r| r.name())
514 .collect()
515 }
516}
517
/// Arguments of the `code_ls` tool.
#[derive(Deserialize)]
struct LsInput {
    /// Repository name, matched exactly against the configured repo names.
    repo: String,
    /// Directory to list; `None` lists the repository root.
    path: Option<String>,
}
523
/// Arguments of the `code_find` tool.
#[derive(Deserialize)]
struct FindInput {
    /// Repository name, matched exactly against the configured repo names.
    repo: String,
    /// Glob matched against full file paths; `None` is treated as `*`.
    pattern: Option<String>,
}
529
/// Arguments of the `code_read` tool.
#[derive(Deserialize)]
struct ReadInput {
    /// Repository name, matched exactly against the configured repo names.
    repo: String,
    /// File path within the repository; a leading `/` is ignored.
    path: String,
    /// First line to return (1-based); `None` starts at line 1.
    line_start: Option<usize>,
    /// Last line to return (inclusive per the tool schema); `None` reads to the end.
    line_end: Option<usize>,
}
537
/// Arguments of the `code_search` tool.
#[derive(Deserialize)]
struct SearchInput {
    /// Repository name, matched exactly against the configured repo names.
    repo: String,
    /// Regular expression to search for.
    pattern: String,
    /// Lines of context around each match; 0 (the default when omitted)
    /// selects the compact `path:line: text` output.
    #[serde(default)]
    context: u32,
    /// Only files whose path starts with this prefix are searched.
    path_prefix: Option<String>,
}
546
/// Arguments of the `code_summary` tool.
#[derive(Deserialize)]
struct SummaryInput {
    /// Repository name, matched exactly against the configured repo names.
    repo: String,
}
551
552#[async_trait]
553impl ToolExecutor for CodeToolTools {
554 fn tools(&self) -> Vec<Tool> {
555 let mut tools = vec![
556 Tool {
557 name: "code_ls",
558 description: "List files in a directory of a repository's source code.",
559 input_schema: serde_json::json!({
560 "type": "object",
561 "properties": {
562 "repo": {
563 "type": "string",
564 "description": "Name of the repository"
565 },
566 "path": {
567 "type": "string",
568 "description": "Directory path to list (optional, defaults to root)"
569 }
570 },
571 "required": ["repo"]
572 }),
573 },
574 Tool {
575 name: "code_find",
576 description: "Find files matching a glob pattern in a repository's source code.",
577 input_schema: serde_json::json!({
578 "type": "object",
579 "properties": {
580 "repo": {
581 "type": "string",
582 "description": "Name of the repository"
583 },
584 "pattern": {
585 "type": "string",
586 "description": "Glob pattern to match (e.g., '*.rs', 'src/*.py')"
587 }
588 },
589 "required": ["repo"]
590 }),
591 },
592 Tool {
593 name: "code_read",
594 description: "Read a file from a repository's source code.",
595 input_schema: serde_json::json!({
596 "type": "object",
597 "properties": {
598 "repo": {
599 "type": "string",
600 "description": "Name of the repository"
601 },
602 "path": {
603 "type": "string",
604 "description": "Path to the file to read"
605 },
606 "line_start": {
607 "type": "integer",
608 "description": "Starting line number (1-indexed, optional)"
609 },
610 "line_end": {
611 "type": "integer",
612 "description": "Ending line number (inclusive, optional)"
613 }
614 },
615 "required": ["repo", "path"]
616 }),
617 },
618 Tool {
619 name: "code_search",
620 description: "Search for a regex pattern in a repository's source code (like grep).",
621 input_schema: serde_json::json!({
622 "type": "object",
623 "properties": {
624 "repo": {
625 "type": "string",
626 "description": "Name of the repository"
627 },
628 "pattern": {
629 "type": "string",
630 "description": "Regex pattern to search for"
631 },
632 "context": {
633 "type": "integer",
634 "description": "Number of context lines to show (like grep -C, default 0)"
635 },
636 "path_prefix": {
637 "type": "string",
638 "description": "Optional path prefix to limit search scope"
639 }
640 },
641 "required": ["repo", "pattern"]
642 }),
643 },
644 ];
645
646 let repos_with_summaries = self.repos_with_summaries();
648 if !repos_with_summaries.is_empty() {
649 tools.push(Tool {
650 name: "code_summary",
651 description: "Get a summary/overview of a repository's codebase.",
652 input_schema: serde_json::json!({
653 "type": "object",
654 "properties": {
655 "repo": {
656 "type": "string",
657 "description": format!("Name of the repository. Repos with summaries: {}", repos_with_summaries.join(", "))
658 }
659 },
660 "required": ["repo"]
661 }),
662 });
663 }
664
665 tools
666 }
667
668 fn has_tool(&self, name: &str) -> bool {
669 matches!(
670 name,
671 "code_ls" | "code_find" | "code_read" | "code_search" | "code_summary"
672 )
673 }
674
675 async fn execute(&self, name: &str, input: &serde_json::Value) -> Result<ToolResult, String> {
676 match name {
677 "code_ls" => {
678 let input: LsInput = serde_json::from_value(input.clone())
679 .map_err(|e| format!("Invalid input: {e}"))?;
680 let repo = self.find_repo(&input.repo).ok_or_else(|| {
681 format!(
682 "Unknown repo '{}'. Available: {}",
683 input.repo,
684 self.repo_names()
685 )
686 })?;
687 let result = repo.ls(input.path.as_deref()).await?;
688 Ok(ToolResult::new(result))
689 }
690 "code_find" => {
691 let input: FindInput = serde_json::from_value(input.clone())
692 .map_err(|e| format!("Invalid input: {e}"))?;
693 let repo = self.find_repo(&input.repo).ok_or_else(|| {
694 format!(
695 "Unknown repo '{}'. Available: {}",
696 input.repo,
697 self.repo_names()
698 )
699 })?;
700 let result = repo.find(input.pattern.as_deref()).await?;
701 Ok(ToolResult::new(result))
702 }
703 "code_read" => {
704 let input: ReadInput = serde_json::from_value(input.clone())
705 .map_err(|e| format!("Invalid input: {e}"))?;
706 let repo = self.find_repo(&input.repo).ok_or_else(|| {
707 format!(
708 "Unknown repo '{}'. Available: {}",
709 input.repo,
710 self.repo_names()
711 )
712 })?;
713 let result = repo
714 .read(&input.path, input.line_start, input.line_end)
715 .await?;
716 Ok(ToolResult::new(result))
717 }
718 "code_search" => {
719 let input: SearchInput = serde_json::from_value(input.clone())
720 .map_err(|e| format!("Invalid input: {e}"))?;
721 let repo = self.find_repo(&input.repo).ok_or_else(|| {
722 format!(
723 "Unknown repo '{}'. Available: {}",
724 input.repo,
725 self.repo_names()
726 )
727 })?;
728 let result = repo
729 .search(&input.pattern, input.context, input.path_prefix.as_deref())
730 .await?;
731 Ok(ToolResult::new(result))
732 }
733 "code_summary" => {
734 let input: SummaryInput = serde_json::from_value(input.clone())
735 .map_err(|e| format!("Invalid input: {e}"))?;
736 let repo = self.find_repo(&input.repo).ok_or_else(|| {
737 format!(
738 "Unknown repo '{}'. Available: {}",
739 input.repo,
740 self.repo_names()
741 )
742 })?;
743 let summary = repo.summary().ok_or_else(|| {
744 format!(
745 "No summary available for '{}'. Repos with summaries: {}",
746 input.repo,
747 self.repos_with_summaries().join(", ")
748 )
749 })?;
750 Ok(ToolResult::new(summary.to_string()))
751 }
752 _ => Err(format!("Unknown tool: {name}")),
753 }
754 }
755}