1use crate::config::{get_config_path, load_or_create_config, save_config, Config};
2use prettytable::{Table, Row, Cell};
3use reqwest::Client;
4use serde_json::Value;
5use std::io::{self, Write};
6use std::process::Command as ShellCommand;
7use std::fs;
8use async_trait::async_trait;
9use anyhow::Result;
10use futures_util::StreamExt;
11use crate::db::init_db;
12
13#[async_trait]
14pub trait AIProvider {
15 async fn ask_openai(&self, _messages: Vec<serde_json::Value>) -> Result<String> {
16 Err(anyhow::anyhow!("Not implemented"))
17 }
18 async fn ask_ollama(&self, _prompt: &str) -> Result<String> {
19 Err(anyhow::anyhow!("Not implemented"))
20 }
21}
22
/// Credentials and model selection for the OpenAI chat backend.
pub struct OpenAIProvider {
    // Model id sent in the request body (e.g. "gpt-4o").
    pub model: String,
    // Bearer token used against api.openai.com.
    pub api_key: String,
}
27
28#[async_trait]
29impl AIProvider for OpenAIProvider {
30 async fn ask_openai(&self, messages: Vec<serde_json::Value>) -> Result<String> {
31 let client = reqwest::Client::new();
32 let body = serde_json::json!({
33 "model": self.model,
34 "messages": messages,
35 "stream": true
36 });
37 let res = client
38 .post("https://api.openai.com/v1/chat/completions")
39 .bearer_auth(&self.api_key)
40 .json(&body)
41 .send()
42 .await?;
43 let status = res.status();
44 let mut full = String::new();
45 if !status.is_success() {
46 let err_text = res.text().await.unwrap_or_default();
47 eprintln!("OpenAI API error: {}\n{}", status, err_text);
48 return Ok(String::new());
49 }
50 let mut stream = res.bytes_stream();
51 let mut got_content = false;
52 while let Some(chunk) = stream.next().await {
53 let chunk = chunk?;
54 for line in chunk.split(|&b| b == b'\n') {
55 if line.starts_with(b"data: ") {
56 let json = &line[6..];
57 if json == b"[DONE]" { continue; }
58 if let Ok(val) = serde_json::from_slice::<serde_json::Value>(json) {
59 if let Some(content) = val["choices"][0]["delta"]["content"].as_str() {
60 print!("{}", content);
61 std::io::stdout().flush().ok();
62 full.push_str(content);
63 got_content = true;
64 }
65 }
66 } else if !line.is_empty() {
67 }
69 }
70 }
71 if !got_content {
72 eprintln!("No response from OpenAI. Check your API key, model, or network.");
73 }
74 println!();
75 Ok(full)
76 }
77}
78
/// Model selection for the local Ollama CLI backend.
pub struct OllamaProvider {
    // Name of a locally installed model, as listed by `ollama list`.
    pub model: String,
}
82
83#[async_trait]
84impl AIProvider for OllamaProvider {
85 async fn ask_ollama(&self, prompt: &str) -> Result<String> {
86 use std::process::Command;
87 let output = Command::new("ollama")
88 .arg("run")
89 .arg(&self.model)
90 .arg(prompt)
91 .output()?;
92 let response = String::from_utf8_lossy(&output.stdout).to_string();
93 Ok(response)
94 }
95}
96
/// Runtime-selected chat backend; wraps exactly one concrete provider.
pub enum Provider {
    // Remote OpenAI chat-completions backend.
    OpenAI(OpenAIProvider),
    // Local Ollama CLI backend.
    Ollama(OllamaProvider),
}
101
102#[async_trait]
103impl AIProvider for Provider {
104 async fn ask_openai(&self, messages: Vec<serde_json::Value>) -> Result<String> {
105 match self {
106 Provider::OpenAI(p) => p.ask_openai(messages).await,
107 _ => Err(anyhow::anyhow!("Not OpenAI provider")),
108 }
109 }
110 async fn ask_ollama(&self, prompt: &str) -> Result<String> {
111 match self {
112 Provider::Ollama(p) => p.ask_ollama(prompt).await,
113 _ => Err(anyhow::anyhow!("Not Ollama provider")),
114 }
115 }
116}
117
118async fn fetch_openai_models(api_key: &str) -> Vec<String> {
119 let client = Client::new();
120 let res = client
121 .get("https://api.openai.com/v1/models")
122 .bearer_auth(api_key)
123 .send()
124 .await
125 .expect("failed to fetch models");
126 let j: Value = res.json().await.expect("invalid JSON");
127 j["data"]
128 .as_array()
129 .unwrap()
130 .iter()
131 .map(|m| m["id"].as_str().unwrap().to_string())
132 .collect()
133}
134
/// Lists locally installed Ollama models.
///
/// `ollama list` prints a header row followed by one model per line; the
/// first whitespace-separated column is the model name. Returns an empty
/// list (after a stderr message) when the binary cannot be run — the
/// previous version panicked if Ollama was not installed, and all callers
/// already handle an empty result.
fn fetch_ollama_local() -> Vec<String> {
    let out = match ShellCommand::new("ollama").arg("list").output() {
        Ok(o) => o,
        Err(e) => {
            eprintln!("Failed to run `ollama list` (is Ollama installed?): {}", e);
            return Vec::new();
        }
    };
    String::from_utf8_lossy(&out.stdout)
        .lines()
        .map(str::trim)
        .filter(|l| !l.is_empty())
        .skip(1) // header row
        .filter_map(|line| line.split_whitespace().next().map(str::to_string))
        .collect()
}
160
161pub fn clear_history() {
163 let chat_id = match get_current_chat_id() {
164 Some(id) => id,
165 None => { eprintln!("No current chat selected. Start or switch to a chat first."); return; }
166 };
167 let conn = match init_db() {
168 Ok(c) => c,
169 Err(e) => { eprintln!("DB error: {}", e); return; }
170 };
171 if let Err(e) = conn.execute("DELETE FROM messages WHERE chat_id = ?1", [chat_id]) {
172 eprintln!("Failed to clear history: {}", e);
173 } else {
174 println!("✅ Cleared history for current chat");
175 }
176}
177
178pub fn setup() {
180 print!("Choose backend (1) ollama (2) openai: ");
182 io::stdout().flush().unwrap();
183 let mut c = String::new(); io::stdin().read_line(&mut c).unwrap();
184 let src = match c.trim() {
185 "1" => "ollama",
186 "2" => "openai",
187 _ => { eprintln!("invalid"); return; }
188 }.to_string();
189
190 let mut key = None;
191 if src == "openai" {
192 print!("Enter OpenAI API key: ");
193 io::stdout().flush().unwrap();
194 let mut k2 = String::new(); io::stdin().read_line(&mut k2).unwrap();
195 key = Some(k2.trim().to_string());
196 }
197
198 let default_model = if src == "openai" { "gpt-4".to_string() } else {
200 let loc = fetch_ollama_local();
201 loc.get(0).cloned().unwrap_or_else(|| {
202 eprintln!("no local ollama model installed");
203 std::process::exit(1);
204 })
205 };
206
207 let cfg = Config { source: src.clone(), model: default_model, openai_api_key: key };
208 save_config(&cfg);
209 println!("✅ setup complete");
210 println!("⚙️ config saved at {}", get_config_path().display());
211}
212
213pub fn show_config_path() {
215 println!("{}", get_config_path().display());
216}
217
218pub async fn set_gpt(gpt_model: &str) {
220 let mut cfg = load_or_create_config();
221 cfg.source = "openai".into();
222
223 if cfg.openai_api_key.is_none() {
225 print!("Enter OpenAI API key: ");
226 io::stdout().flush().unwrap();
227 let mut k = String::new(); io::stdin().read_line(&mut k).unwrap();
228 cfg.openai_api_key = Some(k.trim().to_string());
229 }
230
231 let models = fetch_openai_models(cfg.openai_api_key.as_ref().unwrap()).await;
233 if models.iter().any(|m| m == gpt_model) {
234 cfg.model = gpt_model.to_string();
235 save_config(&cfg);
236 println!("Switched to OpenAI model: {}", gpt_model);
237 println!("⚙️ config saved at {}", get_config_path().display());
238 } else {
239 println!("Model '{}' not found in available OpenAI models.", gpt_model);
240 println!("Available models:");
241 for model in models.iter().take(10) {
242 println!(" {}", model);
243 }
244 if models.len() > 10 {
245 println!(" ... and {} more", models.len() - 10);
246 println!("Run 'yo list' to see all available models");
247 }
248 }
249}
250
/// Switches the active backend to "ollama" or "openai", prompting for an
/// API key when needed and keeping a previously chosen model when it
/// still matches the new backend.
pub async fn switch(model: &str) {
    let mut cfg = load_or_create_config();
    if model == "openai" {
        cfg.source = "openai".into();
        // Only ask for a key if none is stored yet.
        if cfg.openai_api_key.is_none() {
            print!("Enter OpenAI API key: ");
            io::stdout().flush().unwrap();
            let mut k = String::new(); io::stdin().read_line(&mut k).unwrap();
            cfg.openai_api_key = Some(k.trim().to_string());
        }

        // Keep the stored model if it already looks like an OpenAI model
        // (gpt-*, o1/o3/o4, dall-e); otherwise fall back to gpt-4o.
        if !cfg.model.starts_with("gpt-") &&
            !["o1", "o3", "o4", "dall-e"].iter().any(|prefix| cfg.model.starts_with(prefix)) {
            cfg.model = "gpt-4o".to_string();
        }

        println!("Switched to OpenAI model: {}", cfg.model);
        save_config(&cfg);
        println!("⚙️ config saved at {}", get_config_path().display());
        return;
    } else if model == "ollama" {
        cfg.source = "ollama".into();
        let loc = fetch_ollama_local();
        if loc.is_empty() {
            eprintln!("❌ No local Ollama models found. Please install one first with:");
            eprintln!(" ollama pull llama3");
            eprintln!("\nVisit https://ollama.com/search to discover available models.");
            return;
        }

        // If the stored model is NOT an OpenAI-style name and is still
        // installed locally, keep using it and return early.
        if !cfg.model.starts_with("gpt-") &&
            !["o1", "o3", "o4", "dall-e"].iter().any(|prefix| cfg.model.starts_with(prefix)) {
            if loc.iter().any(|m| m == &cfg.model) {
                println!("Using previously selected Ollama model: {}", cfg.model);
                save_config(&cfg);
                println!("⚙️ config saved at {}", get_config_path().display());
                return;
            }
        }

        // Otherwise default to the first locally installed model.
        cfg.model = loc[0].clone();
        println!("Switched to Ollama model: {}", cfg.model);
    } else {
        eprintln!("usage: yo switch <ollama|openai>");
        return;
    }
    // Only the fall-through Ollama path reaches here.
    save_config(&cfg);
    println!("switched to {}:{}", cfg.source, cfg.model);
    println!("⚙️ config saved at {}", get_config_path().display());
}
309
310pub async fn list_models() {
312 let mut table = Table::new();
313 table.add_row(Row::new(vec![Cell::new("Src"), Cell::new("Model"), Cell::new("You")]));
314
315 let cfg = load_or_create_config();
316 if let Some(key) = cfg.openai_api_key.as_deref() {
317 for m in fetch_openai_models(key).await {
318 let you = if cfg.source=="openai" && cfg.model==m { "✔" } else { "" };
319 table.add_row(Row::new(vec![Cell::new("OpenAI"), Cell::new(&m), Cell::new(you)]));
320 }
321 }
322 for m in fetch_ollama_local() {
323 let you = if cfg.source=="ollama" && cfg.model==m { "✔" } else { "" };
324 table.add_row(Row::new(vec![Cell::new("Ollama"), Cell::new(&m), Cell::new(you)]));
325 }
326 table.printstd();
327}
328
/// Sends a question to the configured backend, persisting both the user
/// prompt and the assistant reply in the current chat's history.
pub async fn ask(question: &[String]) {
    let chat_id = match get_current_chat_id() {
        Some(id) => id,
        None => { eprintln!("No current chat selected. Start or switch to a chat first."); return; }
    };
    let conn = match init_db() {
        Ok(c) => c,
        Err(e) => { eprintln!("DB error: {}", e); return; }
    };
    // The CLI passes the question as separate words; rejoin them.
    let prompt = question.join(" ");
    let _ = conn.execute(
        "INSERT INTO messages (chat_id, role, content) VALUES (?1, 'user', ?2)",
        (&chat_id, &prompt),
    );
    let cfg = load_or_create_config();
    // Full chronological history — including the prompt just inserted.
    let mut stmt = conn.prepare("SELECT role, content FROM messages WHERE chat_id = ?1 ORDER BY created_at ASC").unwrap();
    let history: Vec<(String, String)> = stmt
        .query_map([chat_id], |row| Ok((row.get(0)?, row.get(1)?)))
        .unwrap()
        .flatten()
        .collect();
    match cfg.source.as_str() {
        "openai" => {
            // OpenAI takes structured role/content messages.
            let mut messages = vec![serde_json::json!({
                "role": "system",
                "content": "You are a helpful AI assistant."
            })];
            for (role, content) in &history {
                messages.push(serde_json::json!({"role": role, "content": content}));
            }
            // NOTE(review): the prompt is already present in `history`
            // (inserted above), so this appends it a second time —
            // confirm whether the duplication is intentional.
            messages.push(serde_json::json!({"role": "user", "content": &prompt}));
            // NOTE(review): panics if no API key is configured; consider
            // checking before unwrap.
            let provider = Provider::OpenAI(OpenAIProvider {
                model: cfg.model.clone(),
                api_key: cfg.openai_api_key.clone().unwrap(),
            });
            match provider.ask_openai(messages).await {
                Ok(response) => {
                    let _ = conn.execute(
                        "INSERT INTO messages (chat_id, role, content) VALUES (?1, 'assistant', ?2)",
                        (&chat_id, &response),
                    );
                }
                Err(e) => {
                    eprintln!("Error during AI call: {}", e);
                }
            }
        }
        "ollama" => {
            // Ollama gets a single flattened "User:/AI:" transcript.
            let mut full_prompt = String::new();
            for (role, content) in &history {
                let who = match role.as_str() {
                    "user" => "User",
                    "assistant" => "AI",
                    _ => role.as_str(),
                };
                full_prompt.push_str(&format!("{}: {}\n", who, content));
            }
            // NOTE(review): as above, the prompt already appears in the
            // history, so it is appended twice here.
            full_prompt.push_str(&format!("User: {}\n", &prompt));
            let provider = Provider::Ollama(OllamaProvider {
                model: cfg.model.clone(),
            });
            match provider.ask_ollama(&full_prompt).await {
                Ok(response) => {
                    // Ollama output is not streamed, so print it whole.
                    println!("{}", response);
                    let _ = conn.execute(
                        "INSERT INTO messages (chat_id, role, content) VALUES (?1, 'assistant', ?2)",
                        (&chat_id, &response),
                    );
                }
                Err(e) => {
                    eprintln!("Error during AI call: {}", e);
                }
            }
        }
        _ => eprintln!("Unknown backend: {}", cfg.source),
    }
}
409
/// Prints the active backend and model, plus backend-specific details:
/// a few lines of `ollama show` output, or a masked OpenAI API key.
pub fn show_current() {
    let cfg = load_or_create_config();

    println!("📋 Current AI Configuration");
    println!("---------------------------");
    println!("Backend: {}", cfg.source);
    println!("Model: {}", cfg.model);

    if cfg.source == "ollama" {
        // Best effort: show the first lines of `ollama show <model>`;
        // any spawn failure is silently ignored.
        let output = ShellCommand::new("ollama")
            .args(["show", &cfg.model])
            .output();

        if let Ok(out) = output {
            let info = String::from_utf8_lossy(&out.stdout);
            if !info.is_empty() {
                let lines: Vec<&str> = info.lines().take(5).collect();
                if !lines.is_empty() {
                    println!("\nModel Details:");
                    for line in lines {
                        println!(" {}", line);
                    }
                }
            }
        }
    } else if cfg.source == "openai" {
        if let Some(api_key) = cfg.openai_api_key.as_deref() {
            if api_key.len() > 7 {
                // Show the first 7 bytes, mask the rest (mask length is
                // len/4 so the real key length is not revealed).
                // NOTE(review): byte slicing would panic on a non-ASCII
                // boundary — fine for OpenAI keys, but worth confirming.
                let visible_part = &api_key[..7];
                let masked_part = "*".repeat(api_key.len() / 4);
                println!("\nAPI Key: {}{}", visible_part, masked_part);
            } else {
                println!("\nAPI Key: {}", "*".repeat(api_key.len()));
            }
        } else {
            println!("\nAPI Key: [not set]");
        }
    }

    println!("\n💡 Use 'yo list' to see all available models");
}
452
453const CURRENT_CHAT_FILE: &str = "current_chat";
454
455fn set_current_chat_id(chat_id: i64) {
456 let config_dir = dirs::home_dir().unwrap().join(".config").join("yo");
457 let file_path = config_dir.join(CURRENT_CHAT_FILE);
458 let _ = fs::write(file_path, chat_id.to_string());
459}
460
461fn get_current_chat_id() -> Option<i64> {
462 let config_dir = dirs::home_dir().unwrap().join(".config").join("yo");
463 let file_path = config_dir.join(CURRENT_CHAT_FILE);
464 if let Ok(s) = fs::read_to_string(file_path) {
465 s.trim().parse().ok()
466 } else {
467 None
468 }
469}
470
471pub fn new_chat(title: Option<String>) {
472 let conn = match init_db() {
473 Ok(c) => c,
474 Err(e) => { eprintln!("DB error: {}", e); return; }
475 };
476 let title = title.unwrap_or_else(|| "New Chat".to_string());
477 let res = conn.execute(
478 "INSERT INTO chats (title) VALUES (?1)",
479 [&title],
480 );
481 match res {
482 Ok(_) => {
483 let chat_id = conn.last_insert_rowid();
484 set_current_chat_id(chat_id);
485 println!("✅ Started new chat '{}' (id: {})", title, chat_id);
486 },
487 Err(e) => eprintln!("Failed to create chat: {}", e),
488 }
489}
490
491pub fn list_chats() {
492 let conn = match init_db() {
493 Ok(c) => c,
494 Err(e) => { eprintln!("DB error: {}", e); return; }
495 };
496 let mut stmt = match conn.prepare("SELECT id, title, created_at FROM chats ORDER BY created_at DESC") {
497 Ok(s) => s,
498 Err(e) => { eprintln!("Query error: {}", e); return; }
499 };
500 let rows = stmt.query_map([], |row| {
501 Ok((row.get::<_, i64>(0)?, row.get::<_, String>(1)?, row.get::<_, String>(2)?))
502 });
503 match rows {
504 Ok(rows) => {
505 println!("\nChats:");
506 for row in rows.flatten() {
507 println!(" [{}] {} (created: {})", row.0, row.1, row.2);
508 }
509 },
510 Err(e) => eprintln!("Failed to list chats: {}", e),
511 }
512}
513
514pub fn switch_chat(chat_id: i64) {
515 let conn = match init_db() {
516 Ok(c) => c,
517 Err(e) => { eprintln!("DB error: {}", e); return; }
518 };
519 let mut stmt = match conn.prepare("SELECT id, title FROM chats WHERE id = ?1") {
520 Ok(s) => s,
521 Err(e) => { eprintln!("Query error: {}", e); return; }
522 };
523 let result = stmt.query_row([chat_id], |row| {
524 Ok(row.get::<_, String>(1)?)
525 });
526 match result {
527 Ok(title) => {
528 set_current_chat_id(chat_id);
529 println!("✅ Switched to chat [{}] {}", chat_id, title);
530 }
531 Err(rusqlite::Error::QueryReturnedNoRows) => {
532 eprintln!("Chat ID {} not found.", chat_id);
533 }
534 Err(e) => eprintln!("Failed to switch chat: {}", e),
535 }
536}
537
538pub fn set_profile(pair: &str) {
539 let parts: Vec<&str> = pair.splitn(2, '=').collect();
540 if parts.len() != 2 {
541 eprintln!("Invalid format. Use key=value");
542 return;
543 }
544 let key = parts[0].trim();
545 let value = parts[1].trim();
546 let conn = match init_db() {
547 Ok(c) => c,
548 Err(e) => { eprintln!("DB error: {}", e); return; }
549 };
550 let res = conn.execute(
551 "INSERT INTO user_profile (key, value) VALUES (?1, ?2) ON CONFLICT(key) DO UPDATE SET value=excluded.value",
552 [key, value],
553 );
554 match res {
555 Ok(_) => println!("✅ Set profile: {} = {}", key, value),
556 Err(e) => eprintln!("Failed to set profile: {}", e),
557 }
558}
559
560pub fn summarize_chat(chat_id: i64) {
561 let conn = match init_db() {
562 Ok(c) => c,
563 Err(e) => { eprintln!("DB error: {}", e); return; }
564 };
565 let mut stmt = match conn.prepare("SELECT role, content FROM messages WHERE chat_id = ?1 ORDER BY created_at ASC") {
566 Ok(s) => s,
567 Err(e) => { eprintln!("Query error: {}", e); return; }
568 };
569 let rows = stmt.query_map([chat_id], |row| {
570 Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?))
571 });
572 match rows {
573 Ok(rows) => {
574 let mut full_chat = String::new();
575 for row in rows.flatten() {
576 let (role, content) = row;
577 let who = match role.as_str() { "user" => "You", "assistant" => "AI", _ => &role };
578 full_chat.push_str(&format!("{}: {}\n", who, content));
579 }
580 println!("\n--- Chat #{} Summary (stub) ---\n{}\n-------------------------------\n", chat_id, full_chat);
581 },
583 Err(e) => eprintln!("Failed to summarize chat: {}", e),
584 }
585}
586
587pub fn search_chats(query: &str) {
588 let conn = match init_db() {
589 Ok(c) => c,
590 Err(e) => { eprintln!("DB error: {}", e); return; }
591 };
592 let sql = "SELECT chat_id, created_at, role, content FROM messages WHERE content LIKE ?1 ORDER BY chat_id, created_at";
593 let pattern = format!("%{}%", query);
594 let mut stmt = match conn.prepare(sql) {
595 Ok(s) => s,
596 Err(e) => { eprintln!("Query error: {}", e); return; }
597 };
598 let rows = stmt.query_map([pattern], |row| {
599 Ok((row.get::<_, i64>(0)?, row.get::<_, String>(1)?, row.get::<_, String>(2)?, row.get::<_, String>(3)?))
600 });
601 match rows {
602 Ok(rows) => {
603 println!("\n--- Search Results for '{}' ---", query);
604 for row in rows.flatten() {
605 let (chat_id, ts, role, content) = row;
606 let who = match role.as_str() { "user" => "You", "assistant" => "AI", _ => &role };
607 println!("[chat {}] [{}] {}: {}", chat_id, ts, who, content);
608 }
609 println!("-------------------------------\n");
610 },
611 Err(e) => eprintln!("Failed to search chats: {}", e),
612 }
613}
614
615pub fn view_chat() {
616 let chat_id = match get_current_chat_id() {
617 Some(id) => id,
618 None => { eprintln!("No current chat selected. Start or switch to a chat first."); return; }
619 };
620 let conn = match init_db() {
621 Ok(c) => c,
622 Err(e) => { eprintln!("DB error: {}", e); return; }
623 };
624 let mut stmt = match conn.prepare("SELECT created_at, role, content FROM messages WHERE chat_id = ?1 ORDER BY created_at ASC") {
625 Ok(s) => s,
626 Err(e) => { eprintln!("Query error: {}", e); return; }
627 };
628 let rows = stmt.query_map([chat_id], |row| {
629 Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?, row.get::<_, String>(2)?))
630 });
631 match rows {
632 Ok(rows) => {
633 println!("\n--- Chat History (chat id: {}) ---", chat_id);
634 for row in rows.flatten() {
635 let (ts, role, content) = row;
636 let who = match role.as_str() { "user" => "You", "assistant" => "AI", _ => &role };
637 println!("[{}] {}: {}", ts, who, content);
638 }
639 println!("-------------------------------\n");
640 },
641 Err(e) => eprintln!("Failed to view chat: {}", e),
642 }
643}
644
645pub fn delete_chat(chat_id: i64) {
646 println!("Are you sure you want to delete chat {}? This cannot be undone! (y/N): ", chat_id);
647 io::stdout().flush().unwrap();
648 let mut input = String::new();
649 io::stdin().read_line(&mut input).unwrap();
650 if input.trim().to_lowercase() == "y" {
651 let conn = match init_db() {
652 Ok(c) => c,
653 Err(e) => { eprintln!("DB error: {}", e); return; }
654 };
655 if let Err(e) = conn.execute("DELETE FROM messages WHERE chat_id = ?1", [chat_id]) {
656 eprintln!("Failed to delete chat messages: {}", e);
657 }
658 if let Err(e) = conn.execute("DELETE FROM chats WHERE id = ?1", [chat_id]) {
659 eprintln!("Failed to delete chat: {}", e);
660 }
661 println!("✅ Deleted chat {}", chat_id);
662 } else {
663 println!("Aborted.");
664 }
665}
666
667pub fn clear_all_chats() {
668 println!("Are you sure you want to delete ALL chats and messages? This cannot be undone! (y/N): ");
669 io::stdout().flush().unwrap();
670 let mut input = String::new();
671 io::stdin().read_line(&mut input).unwrap();
672 if input.trim().to_lowercase() == "y" {
673 let conn = match init_db() {
674 Ok(c) => c,
675 Err(e) => { eprintln!("DB error: {}", e); return; }
676 };
677 if let Err(e) = conn.execute("DELETE FROM messages", []) {
678 eprintln!("Failed to clear messages: {}", e);
679 }
680 if let Err(e) = conn.execute("DELETE FROM chats", []) {
681 eprintln!("Failed to clear chats: {}", e);
682 }
683 println!("✅ All chats and messages deleted");
684 } else {
685 println!("Aborted.");
686 }
687}
688