Skip to main content

zag_agent/
search.rs

1use crate::session_log::{
2    AgentLogEvent, LogEventKind, SessionLogIndex, SessionLogIndexEntry, ToolKind,
3};
4use anyhow::{Context, Result, bail};
5use chrono::{DateTime, Duration, NaiveDate, Utc};
6use regex::Regex;
7use serde::{Deserialize, Serialize};
8use std::io::{BufRead, BufReader};
9use std::path::{Path, PathBuf};
10
11/// Query parameters for searching session logs.
12#[derive(Debug, Default)]
13pub struct SearchQuery {
14    /// Full-text pattern (literal substring or regex). None matches all events.
15    pub text: Option<String>,
16    /// Case-insensitive match (default: true).
17    pub case_insensitive: bool,
18    /// Treat `text` as a regular expression (default: false → literal substring).
19    pub use_regex: bool,
20    /// Filter by provider name (case-insensitive).
21    pub provider: Option<String>,
22    /// Filter by message role — only applies to `UserMessage` events.
23    pub role: Option<String>,
24    /// Filter by tool name (case-insensitive substring) — only applies to tool events.
25    pub tool: Option<String>,
26    /// Filter by tool kind — only applies to `ToolCall`/`ToolResult` events.
27    pub tool_kind: Option<ToolKind>,
28    /// Show only events at or after this timestamp.
29    pub from: Option<DateTime<Utc>>,
30    /// Show only events at or before this timestamp.
31    pub to: Option<DateTime<Utc>>,
32    /// Restrict search to a specific session ID (prefix match).
33    pub session_id: Option<String>,
34    /// Filter by session tag (exact match, case-insensitive).
35    pub tag: Option<String>,
36    /// Search all sessions across all projects (default: current project and sub-projects).
37    pub global: bool,
38    /// Stop after collecting this many matches.
39    pub limit: Option<usize>,
40}
41
42impl SearchQuery {
43    pub fn new() -> Self {
44        Self {
45            case_insensitive: true,
46            ..Default::default()
47        }
48    }
49}
50
/// A single event that matched the search query, together with metadata copied from the
/// index entry of the session it came from.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchMatch {
    /// Wrapper session ID of the session the event belongs to.
    pub session_id: String,
    /// Provider name recorded in the session's index entry.
    pub provider: String,
    /// Session start timestamp as stored in the index (the date pre-filter parses it as
    /// RFC 3339).
    pub started_at: String,
    /// Session end timestamp from the index, if one was recorded.
    pub ended_at: Option<String>,
    /// Workspace path recorded for the session, if any.
    pub workspace_path: Option<String>,
    /// Command recorded for the session, if any.
    pub command: Option<String>,
    /// The matching event itself.
    pub event: AgentLogEvent,
    /// Short excerpt (~200 bytes) of the event's searchable text around the first match,
    /// with `[...]` markers wherever text was cut off.
    pub snippet: String,
}
64
/// Aggregate results from a search.
#[derive(Debug, Default)]
pub struct SearchResults {
    /// Number of candidate sessions visited, including those whose log file was missing.
    pub total_sessions_scanned: usize,
    /// Number of log lines that parsed as events across all successfully scanned log files
    /// (blank and malformed lines are not counted).
    pub total_events_scanned: usize,
    /// Number of candidate sessions whose log file did not exist on disk.
    pub total_files_missing: usize,
    /// Matching events, ordered by session `started_at` and then by position in the log file.
    pub matches: Vec<SearchMatch>,
}
73
74// ---------------------------------------------------------------------------
75// Date parsing
76// ---------------------------------------------------------------------------
77
78/// Parse a date/time string for `--from` / `--to` filters.
79///
80/// Accepted formats:
81/// - RFC 3339 (e.g. `2024-01-15T10:30:00Z`)
82/// - Date only (e.g. `2024-01-15`) — interpreted as start of day UTC
83/// - Relative offset from now: `1h`, `2d`, `3w`, `1m` (hours/days/weeks/months)
84pub fn parse_date_arg(s: &str) -> Result<DateTime<Utc>> {
85    // Try RFC 3339 first.
86    if let Ok(dt) = DateTime::parse_from_rfc3339(s) {
87        return Ok(dt.with_timezone(&Utc));
88    }
89
90    // Try date-only (YYYY-MM-DD).
91    if let Ok(date) = NaiveDate::parse_from_str(s, "%Y-%m-%d") {
92        let dt = date
93            .and_hms_opt(0, 0, 0)
94            .expect("midnight is always valid")
95            .and_utc();
96        return Ok(dt);
97    }
98
99    // Try relative offset: leading digits followed by a unit character.
100    let s_trimmed = s.trim();
101    if !s_trimmed.is_empty() {
102        let unit = s_trimmed.chars().last().unwrap();
103        let digits = &s_trimmed[..s_trimmed.len() - unit.len_utf8()];
104        if let Ok(n) = digits.parse::<i64>() {
105            let delta = match unit {
106                'h' => Duration::hours(n),
107                'd' => Duration::days(n),
108                'w' => Duration::weeks(n),
109                'm' => Duration::days(n * 30),
110                _ => bail!(
111                    "Unknown time unit '{unit}'. Use h (hours), d (days), w (weeks), or m (months)."
112                ),
113            };
114            return Ok(Utc::now() - delta);
115        }
116    }
117
118    bail!(
119        "Cannot parse date '{s}'. Use RFC 3339 (2024-01-15T10:30:00Z), date only (2024-01-15), or relative (1h, 2d, 3w, 1m)."
120    )
121}
122
123// ---------------------------------------------------------------------------
124// Text matcher
125// ---------------------------------------------------------------------------
126
127enum TextMatcher {
128    /// No text filter — everything matches.
129    None,
130    /// Case-insensitive literal substring.
131    Literal(String),
132    /// Compiled regex.
133    Pattern(Regex),
134}
135
136impl TextMatcher {
137    fn build(query: &SearchQuery) -> Result<Self> {
138        let Some(ref text) = query.text else {
139            return Ok(Self::None);
140        };
141        if query.use_regex {
142            let pattern = if query.case_insensitive {
143                format!("(?i){text}")
144            } else {
145                text.clone()
146            };
147            let re =
148                Regex::new(&pattern).with_context(|| format!("Invalid regex pattern: '{text}'"))?;
149            Ok(Self::Pattern(re))
150        } else if query.case_insensitive {
151            Ok(Self::Literal(text.to_lowercase()))
152        } else {
153            Ok(Self::Literal(text.clone()))
154        }
155    }
156
157    fn is_match(&self, haystack: &str) -> bool {
158        match self {
159            Self::None => true,
160            Self::Literal(needle) => haystack.to_lowercase().contains(needle.as_str()),
161            Self::Pattern(re) => re.is_match(haystack),
162        }
163    }
164
165    fn find_offset(&self, haystack: &str) -> Option<usize> {
166        match self {
167            Self::None => Some(0),
168            Self::Literal(needle) => haystack.to_lowercase().find(needle.as_str()),
169            Self::Pattern(re) => re.find(haystack).map(|m| m.start()),
170        }
171    }
172
173    fn has_filter(&self) -> bool {
174        !matches!(self, Self::None)
175    }
176}
177
178// ---------------------------------------------------------------------------
179// Content extraction
180// ---------------------------------------------------------------------------
181
182fn extract_searchable_text(event: &AgentLogEvent) -> String {
183    let mut parts: Vec<String> = Vec::new();
184
185    match &event.kind {
186        LogEventKind::SessionStarted {
187            command,
188            model,
189            cwd,
190            ..
191        } => {
192            parts.push(command.clone());
193            if let Some(m) = model {
194                parts.push(m.clone());
195            }
196            if let Some(c) = cwd {
197                parts.push(c.clone());
198            }
199        }
200        LogEventKind::UserMessage { role, content, .. } => {
201            parts.push(role.clone());
202            parts.push(content.clone());
203        }
204        LogEventKind::AssistantMessage { content, .. } => {
205            parts.push(content.clone());
206        }
207        LogEventKind::Reasoning { content, .. } => {
208            parts.push(content.clone());
209        }
210        LogEventKind::ToolCall {
211            tool_name, input, ..
212        } => {
213            parts.push(tool_name.clone());
214            if let Some(v) = input {
215                parts.push(v.to_string());
216            }
217        }
218        LogEventKind::ToolResult {
219            tool_name,
220            output,
221            error,
222            data,
223            ..
224        } => {
225            if let Some(n) = tool_name {
226                parts.push(n.clone());
227            }
228            if let Some(o) = output {
229                parts.push(o.clone());
230            }
231            if let Some(e) = error {
232                parts.push(e.clone());
233            }
234            if let Some(d) = data {
235                parts.push(d.to_string());
236            }
237        }
238        LogEventKind::Permission {
239            tool_name,
240            description,
241            ..
242        } => {
243            parts.push(tool_name.clone());
244            parts.push(description.clone());
245        }
246        LogEventKind::ProviderStatus { message, .. } => {
247            parts.push(message.clone());
248        }
249        LogEventKind::Stderr { message } => {
250            parts.push(message.clone());
251        }
252        LogEventKind::ParseWarning { message, raw } => {
253            parts.push(message.clone());
254            if let Some(r) = raw {
255                parts.push(r.clone());
256            }
257        }
258        LogEventKind::SessionEnded { error, .. } => {
259            if let Some(e) = error {
260                parts.push(e.clone());
261            }
262        }
263        LogEventKind::SessionCleared { .. } => {}
264        LogEventKind::Heartbeat { .. } => {}
265        LogEventKind::Usage { .. } => {}
266        LogEventKind::UserEvent { message, .. } => {
267            parts.push(message.clone());
268        }
269    }
270
271    parts.join(" ")
272}
273
274// ---------------------------------------------------------------------------
275// Snippet builder
276// ---------------------------------------------------------------------------
277
278fn make_snippet(text: &str, matcher: &TextMatcher, max_len: usize) -> String {
279    let offset = matcher.find_offset(text).unwrap_or(0);
280
281    let start = offset.saturating_sub(max_len / 4);
282    let end = (start + max_len).min(text.len());
283
284    // Clamp to char boundaries.
285    let start = text
286        .char_indices()
287        .map(|(i, _)| i)
288        .rfind(|&i| i <= start)
289        .unwrap_or(0);
290    let end = text
291        .char_indices()
292        .map(|(i, _)| i)
293        .find(|&i| i >= end)
294        .unwrap_or(text.len());
295
296    let mut snippet = String::new();
297    if start > 0 {
298        snippet.push_str("[...] ");
299    }
300    snippet.push_str(&text[start..end]);
301    if end < text.len() {
302        snippet.push_str(" [...]");
303    }
304    snippet
305}
306
307// ---------------------------------------------------------------------------
308// Metadata pre-filter
309// ---------------------------------------------------------------------------
310
311fn session_matches_query(entry: &SessionLogIndexEntry, query: &SearchQuery) -> bool {
312    // Provider filter
313    if let Some(ref p) = query.provider
314        && !entry.provider.eq_ignore_ascii_case(p)
315    {
316        return false;
317    }
318
319    // Session ID prefix filter
320    if let Some(ref sid) = query.session_id
321        && !entry.wrapper_session_id.starts_with(sid.as_str())
322    {
323        return false;
324    }
325
326    // Date range: skip sessions that definitely ended before `from`
327    if let Some(from) = query.from
328        && let Some(ref ended) = entry.ended_at
329        && let Ok(ended_dt) = DateTime::parse_from_rfc3339(ended)
330        && ended_dt.with_timezone(&Utc) < from
331    {
332        return false;
333    }
334
335    // Date range: skip sessions that started after `to`
336    if let Some(to) = query.to
337        && let Ok(started_dt) = DateTime::parse_from_rfc3339(&entry.started_at)
338        && started_dt.with_timezone(&Utc) > to
339    {
340        return false;
341    }
342
343    true
344}
345
346// ---------------------------------------------------------------------------
347// Event filter
348// ---------------------------------------------------------------------------
349
350fn event_matches_query(event: &AgentLogEvent, query: &SearchQuery, matcher: &TextMatcher) -> bool {
351    // Provider filter at event level
352    if let Some(ref p) = query.provider
353        && !event.provider.eq_ignore_ascii_case(p)
354    {
355        return false;
356    }
357
358    // Date range filters
359    if (query.from.is_some() || query.to.is_some())
360        && let Ok(event_dt) = DateTime::parse_from_rfc3339(&event.ts)
361    {
362        let event_utc = event_dt.with_timezone(&Utc);
363        if let Some(from) = query.from
364            && event_utc < from
365        {
366            return false;
367        }
368        if let Some(to) = query.to
369            && event_utc > to
370        {
371            return false;
372        }
373    }
374
375    // Tool kind / tool name / role filters
376    let has_tool_filter = query.tool.is_some() || query.tool_kind.is_some();
377    let has_role_filter = query.role.is_some();
378
379    if has_tool_filter {
380        match &event.kind {
381            LogEventKind::ToolCall {
382                tool_name,
383                tool_kind,
384                ..
385            } => {
386                if let Some(ref t) = query.tool
387                    && !tool_name.to_lowercase().contains(&t.to_lowercase())
388                {
389                    return false;
390                }
391                if let Some(ref tk) = query.tool_kind {
392                    let kind = tool_kind.unwrap_or_else(|| ToolKind::infer(tool_name));
393                    if kind != *tk {
394                        return false;
395                    }
396                }
397            }
398            LogEventKind::ToolResult {
399                tool_name,
400                tool_kind,
401                ..
402            } => {
403                if let Some(ref t) = query.tool {
404                    let name = tool_name.as_deref().unwrap_or("");
405                    if !name.to_lowercase().contains(&t.to_lowercase()) {
406                        return false;
407                    }
408                }
409                if let Some(ref tk) = query.tool_kind {
410                    let kind = tool_kind.unwrap_or_else(|| {
411                        tool_name
412                            .as_deref()
413                            .map(ToolKind::infer)
414                            .unwrap_or(ToolKind::Other)
415                    });
416                    if kind != *tk {
417                        return false;
418                    }
419                }
420            }
421            // Non-tool events are excluded when a tool filter is active
422            _ => return false,
423        }
424    }
425
426    if has_role_filter {
427        match &event.kind {
428            LogEventKind::UserMessage { role, .. } => {
429                if let Some(ref r) = query.role
430                    && !role.eq_ignore_ascii_case(r)
431                {
432                    return false;
433                }
434            }
435            // Non-message events are excluded when a role filter is active
436            // (unless combined with a tool filter, which we already handled above)
437            _ if !has_tool_filter => return false,
438            _ => {}
439        }
440    }
441
442    // Text filter
443    if matcher.has_filter() {
444        let text = extract_searchable_text(event);
445        if !matcher.is_match(&text) {
446            return false;
447        }
448    }
449
450    true
451}
452
453// ---------------------------------------------------------------------------
454// JSONL scanner
455// ---------------------------------------------------------------------------
456
/// Outcome of scanning a single session log file.
struct ScanResult {
    /// Number of non-blank lines that deserialized into an `AgentLogEvent`.
    events_scanned: usize,
    /// Parsed events that satisfied the query, in file order.
    matching_events: Vec<AgentLogEvent>,
}
461
462fn scan_session(log_path: &Path, query: &SearchQuery, matcher: &TextMatcher) -> Result<ScanResult> {
463    let file = std::fs::File::open(log_path)
464        .with_context(|| format!("Failed to open log file: {}", log_path.display()))?;
465    let reader = BufReader::new(file);
466
467    let mut result = ScanResult {
468        events_scanned: 0,
469        matching_events: Vec::new(),
470    };
471
472    for line in reader.lines() {
473        let line =
474            line.with_context(|| format!("Failed to read line in {}", log_path.display()))?;
475        let line = line.trim();
476        if line.is_empty() {
477            continue;
478        }
479
480        let event: AgentLogEvent = match serde_json::from_str(line) {
481            Ok(e) => e,
482            Err(e) => {
483                log::debug!(
484                    "Skipping malformed JSONL line in {}: {}",
485                    log_path.display(),
486                    e
487                );
488                continue;
489            }
490        };
491
492        result.events_scanned += 1;
493
494        if event_matches_query(&event, query, matcher) {
495            result.matching_events.push(event);
496        }
497    }
498
499    Ok(result)
500}
501
502// ---------------------------------------------------------------------------
503// Session discovery
504// ---------------------------------------------------------------------------
505
506fn collect_candidate_sessions(
507    query: &SearchQuery,
508    zag_home: &Path,
509    cwd: &Path,
510) -> Result<Vec<(SessionLogIndexEntry, PathBuf)>> {
511    let projects_dir = zag_home.join("projects");
512    if !projects_dir.exists() {
513        return Ok(Vec::new());
514    }
515
516    // If tag filter is set, collect matching session IDs from session stores.
517    let tag_session_ids: Option<std::collections::HashSet<String>> = if query.tag.is_some() {
518        let store = if query.global {
519            crate::session::SessionStore::load_all().unwrap_or_default()
520        } else {
521            crate::session::SessionStore::load(Some(&cwd.to_string_lossy())).unwrap_or_default()
522        };
523        let tag = query.tag.as_deref().unwrap();
524        let matching = store.find_by_tag(tag);
525        Some(matching.into_iter().map(|e| e.session_id.clone()).collect())
526    } else {
527        None
528    };
529
530    let cwd_str = cwd.to_string_lossy().to_string();
531    let mut candidates: Vec<(SessionLogIndexEntry, PathBuf)> = Vec::new();
532    let mut seen_ids = std::collections::HashSet::new();
533
534    let read_dir = std::fs::read_dir(&projects_dir)
535        .with_context(|| format!("Failed to read {}", projects_dir.display()))?;
536
537    for entry in read_dir {
538        let project_dir = match entry {
539            Ok(e) => e.path(),
540            Err(_) => continue,
541        };
542        if !project_dir.is_dir() {
543            continue;
544        }
545
546        let index_path = project_dir.join("logs").join("index.json");
547        if !index_path.exists() {
548            continue;
549        }
550
551        let content = match std::fs::read_to_string(&index_path) {
552            Ok(c) => c,
553            Err(e) => {
554                log::warn!("Failed to read index {}: {}", index_path.display(), e);
555                continue;
556            }
557        };
558
559        let index: SessionLogIndex = match serde_json::from_str(&content) {
560            Ok(i) => i,
561            Err(e) => {
562                log::warn!("Malformed index {}: {}", index_path.display(), e);
563                continue;
564            }
565        };
566
567        for session_entry in index.sessions {
568            // Scope filter: in non-global mode, only include sessions whose workspace_path
569            // is within the current directory tree.
570            if !query.global {
571                let in_scope = match &session_entry.workspace_path {
572                    Some(wp) => {
573                        // Match if workspace is the cwd or a subdirectory of cwd
574                        wp == &cwd_str
575                            || wp.starts_with(&format!("{cwd_str}/"))
576                            || wp.starts_with(&format!("{cwd_str}\\"))
577                    }
578                    None => false,
579                };
580                if !in_scope {
581                    continue;
582                }
583            }
584
585            // Metadata pre-filter (provider, session ID, dates)
586            if !session_matches_query(&session_entry, query) {
587                continue;
588            }
589
590            // Tag filter: only include sessions matching the tag
591            if let Some(ref allowed) = tag_session_ids {
592                if !allowed.contains(&session_entry.wrapper_session_id) {
593                    continue;
594                }
595            }
596
597            // Deduplicate by session ID
598            if !seen_ids.insert(session_entry.wrapper_session_id.clone()) {
599                continue;
600            }
601
602            let log_path = PathBuf::from(&session_entry.log_path);
603            candidates.push((session_entry, log_path));
604        }
605    }
606
607    // Sort by started_at so results are in chronological order
608    candidates.sort_by(|a, b| a.0.started_at.cmp(&b.0.started_at));
609
610    Ok(candidates)
611}
612
613// ---------------------------------------------------------------------------
614// Main entry point
615// ---------------------------------------------------------------------------
616
617/// Search through session logs matching the given query.
618pub fn search(query: &SearchQuery, zag_home: &Path, cwd: &Path) -> Result<SearchResults> {
619    let matcher = TextMatcher::build(query)?;
620
621    let candidates = collect_candidate_sessions(query, zag_home, cwd)?;
622
623    let mut results = SearchResults::default();
624
625    'outer: for (entry, log_path) in candidates {
626        results.total_sessions_scanned += 1;
627
628        if !log_path.exists() {
629            results.total_files_missing += 1;
630            log::debug!("Log file missing: {}", log_path.display());
631            continue;
632        }
633
634        let scan = match scan_session(&log_path, query, &matcher) {
635            Ok(s) => s,
636            Err(e) => {
637                log::warn!("Failed to scan {}: {}", log_path.display(), e);
638                continue;
639            }
640        };
641
642        results.total_events_scanned += scan.events_scanned;
643
644        for event in scan.matching_events {
645            let text = extract_searchable_text(&event);
646            let snippet = make_snippet(&text, &matcher, 200);
647
648            results.matches.push(SearchMatch {
649                session_id: entry.wrapper_session_id.clone(),
650                provider: entry.provider.clone(),
651                started_at: entry.started_at.clone(),
652                ended_at: entry.ended_at.clone(),
653                workspace_path: entry.workspace_path.clone(),
654                command: entry.command.clone(),
655                event,
656                snippet,
657            });
658
659            if let Some(limit) = query.limit
660                && results.matches.len() >= limit
661            {
662                break 'outer;
663            }
664        }
665    }
666
667    Ok(results)
668}
669
// Unit tests for this module are kept in the separate file `search_tests.rs`, wired in via
// the `#[path]` attribute and compiled only for test builds.
#[cfg(test)]
#[path = "search_tests.rs"]
mod tests;