1use rusqlite::{params, Connection};
10use std::path::PathBuf;
11
12pub struct CompletionEngine {
13 conn: Connection,
14}
15
16#[derive(Debug, Clone)]
17pub struct Completion {
18 pub name: String,
19 pub kind: CompletionKind,
20 pub description: Option<String>,
21 pub frequency: u32,
22}
23
/// Category of a completion candidate.
///
/// Each variant is persisted in the database as a lowercase text tag.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompletionKind {
    /// External command (e.g. an executable found on `PATH`).
    Command,
    /// Shell builtin command.
    Builtin,
    /// Shell function.
    Function,
    /// Shell alias.
    Alias,
    /// Regular file.
    File,
    /// Directory.
    Directory,
    /// Shell or environment variable.
    Variable,
    /// Command-line option / flag.
    Option,
}
35
36impl CompletionKind {
37 pub fn as_str(&self) -> &'static str {
38 match self {
39 Self::Command => "command",
40 Self::Builtin => "builtin",
41 Self::Function => "function",
42 Self::Alias => "alias",
43 Self::File => "file",
44 Self::Directory => "directory",
45 Self::Variable => "variable",
46 Self::Option => "option",
47 }
48 }
49
50 fn from_str(s: &str) -> Self {
51 match s {
52 "command" => Self::Command,
53 "builtin" => Self::Builtin,
54 "function" => Self::Function,
55 "alias" => Self::Alias,
56 "file" => Self::File,
57 "directory" => Self::Directory,
58 "variable" => Self::Variable,
59 "option" => Self::Option,
60 _ => Self::Command,
61 }
62 }
63}
64
65impl CompletionEngine {
66 pub fn new() -> rusqlite::Result<Self> {
67 let db_path = Self::db_path();
68 std::fs::create_dir_all(db_path.parent().unwrap()).ok();
69 let conn = Connection::open(&db_path)?;
70
71 let engine = Self { conn };
72 engine.init_schema()?;
73 Ok(engine)
74 }
75
76 pub fn in_memory() -> rusqlite::Result<Self> {
77 let conn = Connection::open_in_memory()?;
78 let engine = Self { conn };
79 engine.init_schema()?;
80 Ok(engine)
81 }
82
83 fn db_path() -> PathBuf {
84 dirs::cache_dir()
85 .unwrap_or_else(|| PathBuf::from("."))
86 .join("zshrs")
87 .join("completions.db")
88 }
89
90 fn init_schema(&self) -> rusqlite::Result<()> {
91 self.conn.execute_batch(
92 r#"
93 CREATE TABLE IF NOT EXISTS completions (
94 id INTEGER PRIMARY KEY,
95 name TEXT NOT NULL,
96 kind TEXT NOT NULL,
97 description TEXT,
98 frequency INTEGER DEFAULT 0,
99 UNIQUE(name, kind)
100 );
101
102 CREATE VIRTUAL TABLE IF NOT EXISTS completions_fts USING fts5(
103 name,
104 description,
105 content='completions',
106 content_rowid='id'
107 );
108
109 CREATE TRIGGER IF NOT EXISTS completions_ai AFTER INSERT ON completions BEGIN
110 INSERT INTO completions_fts(rowid, name, description)
111 VALUES (new.id, new.name, new.description);
112 END;
113
114 CREATE TRIGGER IF NOT EXISTS completions_ad AFTER DELETE ON completions BEGIN
115 INSERT INTO completions_fts(completions_fts, rowid, name, description)
116 VALUES ('delete', old.id, old.name, old.description);
117 END;
118
119 CREATE TRIGGER IF NOT EXISTS completions_au AFTER UPDATE ON completions BEGIN
120 INSERT INTO completions_fts(completions_fts, rowid, name, description)
121 VALUES ('delete', old.id, old.name, old.description);
122 INSERT INTO completions_fts(rowid, name, description)
123 VALUES (new.id, new.name, new.description);
124 END;
125
126 CREATE INDEX IF NOT EXISTS idx_completions_name ON completions(name);
127 CREATE INDEX IF NOT EXISTS idx_completions_kind ON completions(kind);
128 CREATE INDEX IF NOT EXISTS idx_completions_frequency ON completions(frequency DESC);
129 "#,
130 )?;
131 Ok(())
132 }
133
134 pub fn add_completion(
135 &self,
136 name: &str,
137 kind: CompletionKind,
138 description: Option<&str>,
139 ) -> rusqlite::Result<()> {
140 self.conn.execute(
141 "INSERT OR IGNORE INTO completions (name, kind, description) VALUES (?1, ?2, ?3)",
142 params![name, kind.as_str(), description],
143 )?;
144 Ok(())
145 }
146
147 pub fn add_completions(
148 &self,
149 completions: &[(String, CompletionKind, Option<String>)],
150 ) -> rusqlite::Result<()> {
151 let tx = self.conn.unchecked_transaction()?;
152 {
153 let mut stmt = self.conn.prepare(
154 "INSERT OR IGNORE INTO completions (name, kind, description) VALUES (?1, ?2, ?3)",
155 )?;
156 for (name, kind, desc) in completions {
157 stmt.execute(params![name, kind.as_str(), desc.as_deref()])?;
158 }
159 }
160 tx.commit()?;
161 Ok(())
162 }
163
164 pub fn increment_frequency(&self, name: &str) -> rusqlite::Result<()> {
165 self.conn.execute(
166 "UPDATE completions SET frequency = frequency + 1 WHERE name = ?1",
167 params![name],
168 )?;
169 Ok(())
170 }
171
172 pub fn search(&self, query: &str, limit: usize) -> rusqlite::Result<Vec<Completion>> {
173 if query.is_empty() {
174 return self.get_top_by_frequency(limit);
175 }
176
177 let prefix_results = self.search_prefix(query, limit)?;
179 if prefix_results.len() >= limit {
180 return Ok(prefix_results);
181 }
182
183 self.search_fts(query, limit)
185 }
186
187 fn search_prefix(&self, prefix: &str, limit: usize) -> rusqlite::Result<Vec<Completion>> {
188 let mut stmt = self.conn.prepare(
189 "SELECT name, kind, description, frequency FROM completions
190 WHERE name LIKE ?1 || '%'
191 ORDER BY frequency DESC, name ASC
192 LIMIT ?2",
193 )?;
194
195 let rows = stmt.query_map(params![prefix, limit as i64], |row| {
196 Ok(Completion {
197 name: row.get(0)?,
198 kind: CompletionKind::from_str(&row.get::<_, String>(1)?),
199 description: row.get(2)?,
200 frequency: row.get(3)?,
201 })
202 })?;
203
204 rows.collect()
205 }
206
207 fn search_fts(&self, query: &str, limit: usize) -> rusqlite::Result<Vec<Completion>> {
208 let fts_query = format!("{}*", query);
209 let mut stmt = self.conn.prepare(
210 "SELECT c.name, c.kind, c.description, c.frequency
211 FROM completions c
212 JOIN completions_fts fts ON c.id = fts.rowid
213 WHERE completions_fts MATCH ?1
214 ORDER BY c.frequency DESC, rank
215 LIMIT ?2",
216 )?;
217
218 let rows = stmt.query_map(params![fts_query, limit as i64], |row| {
219 Ok(Completion {
220 name: row.get(0)?,
221 kind: CompletionKind::from_str(&row.get::<_, String>(1)?),
222 description: row.get(2)?,
223 frequency: row.get(3)?,
224 })
225 })?;
226
227 rows.collect()
228 }
229
230 fn get_top_by_frequency(&self, limit: usize) -> rusqlite::Result<Vec<Completion>> {
231 let mut stmt = self.conn.prepare(
232 "SELECT name, kind, description, frequency FROM completions
233 ORDER BY frequency DESC, name ASC
234 LIMIT ?1",
235 )?;
236
237 let rows = stmt.query_map(params![limit as i64], |row| {
238 Ok(Completion {
239 name: row.get(0)?,
240 kind: CompletionKind::from_str(&row.get::<_, String>(1)?),
241 description: row.get(2)?,
242 frequency: row.get(3)?,
243 })
244 })?;
245
246 rows.collect()
247 }
248
249 pub fn count(&self) -> rusqlite::Result<usize> {
250 self.conn
251 .query_row("SELECT COUNT(*) FROM completions", [], |row| row.get(0))
252 }
253
254 pub fn index_system_commands(&self) -> rusqlite::Result<usize> {
255 let path = std::env::var("PATH").unwrap_or_default();
256 let mut completions = Vec::new();
257
258 for dir in path.split(':') {
259 if let Ok(entries) = std::fs::read_dir(dir) {
260 for entry in entries.flatten() {
261 if let Ok(ft) = entry.file_type() {
262 if ft.is_file() || ft.is_symlink() {
263 if let Some(name) = entry.file_name().to_str() {
264 completions.push((name.to_string(), CompletionKind::Command, None));
265 }
266 }
267 }
268 }
269 }
270 }
271
272 let count = completions.len();
273 self.add_completions(&completions)?;
274 Ok(count)
275 }
276
277 pub fn index_shell_builtins(&self) -> rusqlite::Result<usize> {
278 let builtins = [
279 ("cd", "Change directory"),
280 ("pwd", "Print working directory"),
281 ("echo", "Print arguments"),
282 ("export", "Set environment variable"),
283 ("unset", "Unset environment variable"),
284 ("alias", "Define alias"),
285 ("unalias", "Remove alias"),
286 ("source", "Execute file in current shell"),
287 ("exit", "Exit the shell"),
288 ("jobs", "List background jobs"),
289 ("fg", "Bring job to foreground"),
290 ("bg", "Continue job in background"),
291 ("history", "Show command history"),
292 ("set", "Set shell options"),
293 ("unset", "Unset shell options"),
294 ("type", "Show command type"),
295 ("which", "Show command path"),
296 ("builtin", "Execute builtin command"),
297 ("command", "Execute external command"),
298 ("exec", "Replace shell with command"),
299 ("eval", "Evaluate arguments as command"),
300 ("read", "Read input"),
301 ("printf", "Formatted print"),
302 ("test", "Evaluate conditional expression"),
303 ("true", "Return success"),
304 ("false", "Return failure"),
305 (":", "Null command"),
306 ("return", "Return from function"),
307 ("break", "Break from loop"),
308 ("continue", "Continue loop"),
309 ("shift", "Shift positional parameters"),
310 ("wait", "Wait for background jobs"),
311 ("trap", "Set signal handler"),
312 ("umask", "Set file creation mask"),
313 ("ulimit", "Set resource limits"),
314 ("times", "Show shell times"),
315 ("kill", "Send signal to process"),
316 ("let", "Evaluate arithmetic expression"),
317 ("declare", "Declare variable"),
318 ("local", "Declare local variable"),
319 ("readonly", "Make variable readonly"),
320 ("typeset", "Declare variable type"),
321 ("hash", "Remember command path"),
322 ("dirs", "Show directory stack"),
323 ("pushd", "Push directory"),
324 ("popd", "Pop directory"),
325 ("getopts", "Parse options"),
326 ("enable", "Enable/disable builtins"),
327 ("logout", "Exit login shell"),
328 ("suspend", "Suspend shell"),
329 ("disown", "Remove job from table"),
330 ];
331
332 let completions: Vec<_> = builtins
333 .iter()
334 .map(|(name, desc)| {
335 (
336 name.to_string(),
337 CompletionKind::Builtin,
338 Some(desc.to_string()),
339 )
340 })
341 .collect();
342
343 let count = completions.len();
344 self.add_completions(&completions)?;
345 Ok(count)
346 }
347}
348
349impl Default for CompletionEngine {
350 fn default() -> Self {
351 Self::new().expect("Failed to create completion engine")
352 }
353}
354
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates an in-memory engine pre-loaded with `Command` entries.
    fn seeded_engine(entries: &[(&str, Option<&str>)]) -> CompletionEngine {
        let engine = CompletionEngine::in_memory().unwrap();
        for &(name, description) in entries {
            engine
                .add_completion(name, CompletionKind::Command, description)
                .unwrap();
        }
        engine
    }

    #[test]
    fn test_completion_engine() {
        let engine = seeded_engine(&[
            ("git", Some("Version control")),
            ("grep", Some("Search text")),
            ("gzip", Some("Compress files")),
        ]);

        let everything_with_g = engine.search("g", 10).unwrap();
        assert_eq!(everything_with_g.len(), 3);

        let narrowed = engine.search("gi", 10).unwrap();
        assert_eq!(narrowed.len(), 1);
        assert_eq!(narrowed[0].name, "git");
    }

    #[test]
    fn test_frequency_ranking() {
        let engine = seeded_engine(&[("aaa", None), ("aab", None)]);

        (0..5).for_each(|_| engine.increment_frequency("aab").unwrap());

        let ranked = engine.search("aa", 10).unwrap();
        assert_eq!(ranked[0].name, "aab");
    }
}