tofu/
tofu.rs

1//! Tofu - A command-line tool for interacting with LLMs
2//!
3//! This library provides the core functionality for the Tofu CLI tool.
4
5#![forbid(unsafe_code)]
6#![warn(missing_docs)]
7
8use colored::Colorize;
9use dialoguer::Editor;
10use dialoguer::Input;
11use home::home_dir;
12use indicatif::{ProgressBar, ProgressStyle};
13use reqwest::Client;
14use serde::{Deserialize, Serialize};
15use std::error::Error;
16use std::fs;
17use std::io::Write;
18use std::path::PathBuf;
19
20mod theme;
21use theme::TofuTheme;
22
/// A single profile as stored in the config file.
///
/// The config file is either one such object (legacy layout) or a JSON map of
/// profile name to such an object; see `load_config`.
#[derive(Debug, Serialize, Deserialize)]
pub struct ConfigFile {
    /// Identifier of the LLM provider; `send_message` recognises
    /// `pollinations`, `google` and `openai`.
    pub provider: String,
    /// Model name placed in the `model` field of the request body.
    pub model: String,
    /// Optional streaming preference for this profile; left out of the
    /// serialized JSON when it is `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    /// Optional system prompt, sent as the leading `system` message; left out
    /// of the serialized JSON when it is `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_prompt: Option<String>,
}
37
/// One turn of the conversation history, serialized into the `messages`
/// array of the request body.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Message {
    /// Speaker of the turn; this file produces `"user"` and `"assistant"`.
    role: String,
    /// Text of the turn.
    content: String,
}
44
/// API keys loaded from the keys file (`keys.json`), one optional entry per
/// provider.
#[derive(Debug, Serialize, Deserialize)]
pub struct KeysFile {
    /// The Google API key, sent as a bearer token for the `google` provider
    pub google: Option<String>,
    /// The OpenAI API key, sent as a bearer token for the `openai` provider
    pub openai: Option<String>,
    /// The Anthropic API key (not read anywhere in this file)
    pub anthropic: Option<String>,
}
55
/// The main configuration for the Tofu application.
///
/// Combines per-invocation options with the profile loaded from the config
/// file.
#[derive(Debug)]
pub struct Config {
    /// Whether to enable verbose output (version banner, config dump and
    /// request body)
    pub verbose: bool,
    /// Whether to enable interactive mode; `run` then starts the prompt loop
    /// instead of sending `message` once
    pub interactive: bool,
    /// Optional message string, used by `run` in non-interactive mode
    pub message: Option<String>,
    /// Forwarded as the `stream` field of the request body; `Some(true)` also
    /// makes `send_message` read the response incrementally as `data:` lines
    pub stream: Option<bool>,
    /// Active profile; `send_message` returns an error when this is `None`
    pub file: Option<ConfigFile>,
}
70
71/// Runs the Tofu application with the given configuration.
72/// # Arguments
73/// * `config` - The configuration for the application
74/// # Returns
75/// Returns `Ok(())` on success, or an error if something went wrong.
76/// Gets the path to the config file.
77fn get_config_path() -> Result<PathBuf, Box<dyn Error>> {
78    let config_dir = if cfg!(windows) {
79        dirs::config_dir()
80            .ok_or("Could not determine config directory")?
81            .join("tofu")
82    } else {
83        home_dir()
84            .ok_or("Could not determine home directory")?
85            .join(".tofu")
86    };
87
88    // Create config directory if it doesn't exist
89    if !config_dir.exists() {
90        std::fs::create_dir_all(&config_dir)?;
91    }
92
93    Ok(config_dir.join("config.json"))
94}
95
96fn get_keys_path() -> Result<PathBuf, Box<dyn Error>> {
97    let config_dir = if cfg!(windows) {
98        dirs::config_dir()
99            .ok_or("Could not determine config directory")?
100            .join("tofu")
101    } else {
102        home_dir()
103            .ok_or("Could not determine home directory")?
104            .join(".tofu")
105    };
106
107    // Create config directory if it doesn't exist
108    if !config_dir.exists() {
109        std::fs::create_dir_all(&config_dir)?;
110    }
111
112    Ok(config_dir.join("keys.json"))
113}
114
115/// Loads the configuration from file.
116pub fn load_config(profile: Option<&str>) -> Result<ConfigFile, Box<dyn Error>> {
117    let config_path = get_config_path()?;
118
119    if !config_path.exists() {
120        // If config file doesn't exist, create it with default values as a multi-profile map
121        let default_config = ConfigFile {
122            provider: String::from("pollinations"),
123            model: String::from("openai"),
124            stream: Some(true),
125            system_prompt: Some(String::from("You are a helpful assistant named Tofu.")),
126        };
127        let gemini_config = ConfigFile {
128            provider: String::from("google"),
129            model: String::from("gemini-2.5-flash"),
130            stream: None,
131            system_prompt: None,
132        };
133        let openai_config = ConfigFile {
134            provider: String::from("openai"),
135            model: String::from("gpt-5-mini"),
136            stream: None,
137            system_prompt: None,
138        };
139        let profiles_json = serde_json::json!({
140            "default": &default_config,
141            "gemini": &gemini_config,
142            "openai": &openai_config
143        });
144        let config_json = serde_json::to_string_pretty(&profiles_json)?;
145        std::fs::write(&config_path, config_json)?;
146        return Ok(default_config);
147    }
148
149    let config_content = fs::read_to_string(&config_path)?;
150
151    // Try to parse as either legacy single-profile or multi-profile config
152    let root_value: serde_json::Value = serde_json::from_str(&config_content)
153        .map_err(|e| format!("Failed to parse config file: {}", e))?;
154
155    if let Some(obj) = root_value.as_object() {
156        let looks_like_legacy = obj.contains_key("provider")
157            || obj.contains_key("model")
158            || obj.contains_key("stream")
159            || obj.contains_key("system_prompt");
160
161        if looks_like_legacy {
162            // Legacy single-profile config
163            let cfg: ConfigFile = serde_json::from_value(root_value)
164                .map_err(|e| format!("Failed to parse legacy config: {}", e))?;
165            if cfg.provider.is_empty() || cfg.model.is_empty() {
166                return Err("Invalid config: provider and model must not be empty".into());
167            }
168            return Ok(cfg);
169        }
170
171        // Multi-profile config
172        let (selected_name, selected_value) = if let Some(name) = profile {
173            match obj.get(name) {
174                Some(v) => (name.to_string(), v.clone()),
175                None => {
176                    let available = obj.keys().cloned().collect::<Vec<_>>().join(", ");
177                    return Err(
178                        format!("Profile '{}' not found. Available: {}", name, available).into(),
179                    );
180                }
181            }
182        } else {
183            if let Some(v) = obj.get("default") {
184                (String::from("default"), v.clone())
185            } else {
186                match obj.iter().next() {
187                    Some((k, v)) => (k.clone(), v.clone()),
188                    None => return Err("Config file contains no profiles".into()),
189                }
190            }
191        };
192
193        let mut cfg: ConfigFile = serde_json::from_value(selected_value).map_err(|e| {
194            format!(
195                "Failed to parse selected profile '{}': {}",
196                selected_name, e
197            )
198        })?;
199
200        // If selected profile is not the default, fall back to default for None values
201        if selected_name != "default" {
202            if let Some(default_value) = obj.get("default") {
203                let default_cfg: ConfigFile = serde_json::from_value(default_value.clone())
204                    .map_err(|e| format!("Failed to parse default profile: {}", e))?;
205                // Fall back to default for None values
206                if cfg.stream.is_none() {
207                    cfg.stream = default_cfg.stream;
208                }
209                if cfg.system_prompt.is_none() {
210                    cfg.system_prompt = default_cfg.system_prompt;
211                }
212            }
213        }
214
215        if cfg.provider.is_empty() || cfg.model.is_empty() {
216            return Err("Invalid config: provider and model must not be empty".into());
217        }
218        return Ok(cfg);
219    }
220
221    Err("Invalid config: root must be a JSON object".into())
222}
223
224/// Loads the keys file from the default location.
225pub fn load_keys() -> Result<KeysFile, Box<dyn Error>> {
226    let keys_path = get_keys_path()?;
227
228    if !keys_path.exists() {
229        // If keys file doesn't exist, create it with default values as a multi-profile map
230        let default_keys = serde_json::json!({
231            "google": "",
232            "openai": "",
233            "anthropic": ""
234        });
235
236        let keys_json = serde_json::to_string_pretty(&default_keys)?;
237        std::fs::write(&keys_path, keys_json)?;
238        return Ok(KeysFile {
239            google: None,
240            openai: None,
241            anthropic: None,
242        });
243    }
244
245    let keys_content = fs::read_to_string(&keys_path)?;
246    let keys_json: serde_json::Value = serde_json::from_str(&keys_content)?;
247    let keys: KeysFile = serde_json::from_value(keys_json)?;
248
249    return Ok(keys);
250}
251
252/// Opens the config file in the default editor.
253///
254/// # Returns
255/// Returns `Ok(())` on success, or an error if something went wrong.
256pub fn open_config() -> Result<(), Box<dyn Error>> {
257    println!("Opening config file...");
258    let config_path = get_config_path()?;
259
260    // Ensure config file exists by trying to load it or create a default one
261    if let Err(e) = load_config(None) {
262        eprintln!("Warning: {}", e);
263        eprintln!("Opening editor to fix the config file...");
264    }
265
266    // Open the config file in the default editor
267    let editor = std::env::var("EDITOR").unwrap_or_else(|_| {
268        if cfg!(windows) {
269            String::from("notepad")
270        } else {
271            String::from("nano")
272        }
273    });
274
275    let status = std::process::Command::new(editor)
276        .arg(&config_path)
277        .status()?;
278
279    if !status.success() {
280        return Err(format!("Editor exited with status: {}", status).into());
281    }
282
283    // Try to load the config after editing, but don't fail if it's still invalid
284    if let Err(e) = load_config(None) {
285        eprintln!("Warning: The config file is still invalid: {}", e);
286        eprintln!("Please fix the config file and try again.");
287    }
288
289    Ok(())
290}
291
292/// Opens the keys file in the default editor.
293pub fn open_keys() -> Result<(), Box<dyn Error>> {
294    println!("Opening keys file...");
295    let config_path = get_keys_path()?;
296
297    // Ensure keys file exists by trying to load it or create a default one
298    if let Err(e) = load_keys() {
299        eprintln!("Warning: {}", e);
300        eprintln!("Opening editor to fix the keys file...");
301    }
302
303    // Open the file in the default editor
304    let editor = std::env::var("EDITOR").unwrap_or_else(|_| {
305        if cfg!(windows) {
306            String::from("notepad")
307        } else {
308            String::from("nano")
309        }
310    });
311
312    let status = std::process::Command::new(editor)
313        .arg(&config_path)
314        .status()?;
315
316    if !status.success() {
317        return Err(format!("Editor exited with status: {}", status).into());
318    }
319
320    Ok(())
321}
322
323/// Runs the Tofu application with the given configuration.
324/// # Arguments
325/// * `config` - The configuration for the application
326/// # Returns
327/// Returns `Ok(())` on success, or an error if something went wrong.
328/// Runs the Tofu application with the given configuration asynchronously.
329/// # Arguments
330/// * `config` - The configuration for the application
331/// # Returns
332/// Returns `Ok(())` on success, or an error if something went wrong.
333pub async fn run(config: Config) -> Result<(), Box<dyn Error>> {
334    if config.verbose {
335        println!(
336            "Tofu v{} initialized (verbose mode)",
337            env!("CARGO_PKG_VERSION")
338        );
339        println!("{:#?}", config);
340    }
341
342    if config.interactive {
343        run_interactive(config).await
344    } else {
345        let message = config.message.as_ref().unwrap_or(&String::new()).clone();
346        send_message(&message, &config, vec![]).await?;
347        Ok(())
348    }
349}
350
351async fn run_interactive(mut config: Config) -> Result<(), Box<dyn Error>> {
352    let mut conversation_history = vec![];
353
354    println!(
355        "{}",
356        format!("Tofu {}", env!("CARGO_PKG_VERSION")).bold().blue()
357    );
358    println!("{}", "Ctrl+C or /exit to exit".italic().dimmed());
359
360    loop {
361        let input: Result<String, _> = Input::with_theme(&TofuTheme::default()).interact_text();
362
363        match input {
364            Ok(mut line) => {
365                line = line.trim().to_string();
366                if line.is_empty() {
367                    continue;
368                }
369
370                // Check for commands starting with /
371                if line.starts_with('/') {
372                    let (should_exit, new_config, message_to_send) =
373                        handle_command(line.as_str(), &mut conversation_history)?;
374                    if let Some(new_file_config) = new_config {
375                        config.file = Some(new_file_config);
376                    }
377                    if should_exit {
378                        break; // Exit the loop if command returns true
379                    }
380                    if let Some(message) = message_to_send {
381                        // Process the multiline message like a regular input
382                        line = message;
383                    } else {
384                        continue; // Skip sending to model for commands that don't return a message
385                    }
386                }
387
388                // Add user message to conversation history
389                conversation_history.push(Message {
390                    role: "user".to_string(),
391                    content: line.to_string(),
392                });
393
394                // If length > 100 messages, remove the oldest message (keep system prompt)
395                if conversation_history.len() > 100 {
396                    conversation_history.remove(1);
397                }
398
399                // Send message and get response
400                match send_message(line.as_str(), &config, conversation_history.clone()).await {
401                    Ok(response_content) => {
402                        // Add assistant response to conversation history
403                        conversation_history.push(Message {
404                            role: "assistant".to_string(),
405                            content: response_content.clone(),
406                        });
407                    }
408                    Err(e) => {
409                        eprintln!("{}", format!("Error: {}", e).red());
410                        // Remove the failed message from history
411                        if !conversation_history.is_empty() {
412                            conversation_history.pop();
413                        }
414                        continue;
415                    }
416                }
417            }
418            Err(e) => {
419                eprintln!("{}", format!("Error reading input: {}", e).red());
420                break;
421            }
422        }
423    }
424
425    Ok(())
426}
427
428/// Handles special commands starting with /
429/// Returns a tuple: (should_exit, new_config_option, message_to_send)
430fn handle_command(
431    command: &str,
432    conversation_history: &mut Vec<Message>,
433) -> Result<(bool, Option<ConfigFile>, Option<String>), Box<dyn Error>> {
434    match command {
435        "/exit" | "/quit" | "/q" => Ok((true, None, None)),
436        "/help" | "/h" | "/?" | "/commands" | "/cmds" => {
437            println!("{}", "Available commands:".bold());
438            println!("  /help               - Show this help message");
439            println!("  /exit, /quit, /q    - Exit the program");
440            println!("  /profile <name>     - Switch to a different config profile");
441            println!("  /clear              - Clear conversation history");
442            println!("  /keys               - Open the API keys file");
443            println!("  /listprofiles, /lsp - List all available profiles");
444            println!("  /multiline, /ml, // - Enter multiline input mode");
445            Ok((false, None, None))
446        }
447        "/clear" => {
448            conversation_history.clear();
449            println!("{}", "Conversation history cleared.".blue());
450            Ok((false, None, None))
451        }
452        "/keys" | "/key" | "/apikeys" | "/apikey" => {
453            open_keys()?;
454            Ok((false, None, None))
455        }
456        cmd if cmd.starts_with("/profile") || cmd.starts_with("/p") => {
457            let parts: Vec<&str> = command.split_whitespace().collect();
458            if parts.len() != 2 {
459                println!("Usage: /profile <profile_name>");
460                return Ok((false, None, None));
461            }
462
463            let profile_name = parts[1];
464            match load_config(Some(profile_name)) {
465                Ok(new_config) => {
466                    println!(
467                        "{}",
468                        format!("Switched to profile '{}'", profile_name).green()
469                    );
470                    Ok((false, Some(new_config), None))
471                }
472                Err(e) => {
473                    eprintln!(
474                        "{}",
475                        format!("Failed to switch to profile '{}': {}", profile_name, e).red()
476                    );
477                    Ok((false, None, None))
478                }
479            }
480        }
481        "/listprofiles" | "/lsp" => {
482            let path = get_config_path()?;
483            let config = fs::read_to_string(&path)?;
484
485            // Parse the JSON config to get root keys (profile names)
486            let root_value: serde_json::Value = serde_json::from_str(&config)
487                .map_err(|e| format!("Failed to parse config file: {}", e))?;
488
489            println!("{}", "Available profiles:".bold());
490            if let Some(obj) = root_value.as_object() {
491                if obj.is_empty() {
492                    println!("  No profiles found");
493                } else {
494                    for key in obj.keys() {
495                        println!("  {}", key);
496                    }
497                }
498            } else {
499                eprintln!("  Invalid config format - expected JSON object");
500            }
501            Ok((false, None, None))
502        }
503        "/multiline" | "/ml" | "//" => {
504            if let Some(multiline_input) = Editor::new().edit("").unwrap() {
505                if !multiline_input.trim().is_empty() {
506                    println!("{}\n", multiline_input);
507                    // Return the multiline input as a message to be processed
508                    return Ok((false, None, Some(multiline_input)));
509                } else {
510                    println!("{}", "Empty input - cancelled".yellow());
511                }
512            } else {
513                eprintln!("{}", "Cancelled".red());
514            }
515            Ok((false, None, None))
516        }
517        _ => {
518            eprintln!(
519                "{}",
520                format!(
521                    "Unknown command: {}. Type /help for available commands.",
522                    command
523                )
524                .red()
525            );
526            Ok((false, None, None))
527        }
528    }
529}
530
531async fn send_message(
532    _message: &str,
533    config: &Config,
534    history: Vec<Message>,
535) -> Result<String, Box<dyn Error>> {
536    let spinner = ProgressBar::new_spinner();
537    spinner.enable_steady_tick(std::time::Duration::from_millis(100));
538    spinner.set_style(
539        ProgressStyle::with_template("{spinner:.blue} {msg} {elapsed:.bold}")
540            .unwrap()
541            .tick_chars("⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏"),
542    );
543    spinner.set_message("Thinking...");
544
545    // Build messages array with history
546    let mut messages = vec![];
547
548    if let Some(file) = &config.file {
549        if let Some(system_prompt) = &file.system_prompt {
550            messages.push(serde_json::json!({ "role": "system", "content": system_prompt }));
551        }
552    }
553
554    for msg in history {
555        messages.push(serde_json::json!({ "role": msg.role, "content": msg.content }));
556    }
557
558    let body = if let Some(file) = &config.file {
559        serde_json::json!({
560            "model": file.model,
561            "messages": messages,
562            "stream": config.stream,
563        })
564    } else {
565        return Err("No configuration file found".to_string().into());
566    };
567
568    // Send the message
569    let client = Client::new();
570
571    let mut response;
572
573    if config.verbose {
574        dbg!(&body);
575    }
576
577    let (url, auth_header) = if let Some(file) = &config.file {
578        match file.provider.as_str() {
579            "pollinations" => ("https://text.pollinations.ai/openai", None),
580            "google" => (
581                "https://generativelanguage.googleapis.com/v1beta/openai/chat/completions",
582                Some(format!("Bearer {}", load_keys().unwrap().google.unwrap())),
583            ),
584            "openai" => (
585                "https://api.openai.com/v1/chat/completions",
586                Some(format!("Bearer {}", load_keys().unwrap().openai.unwrap())),
587            ),
588            provider => {
589                return Err(format!("Unsupported provider: {}", provider).into());
590            }
591        }
592    } else {
593        return Err("No configuration file found".to_string().into());
594    };
595
596    let mut request = client
597        .post(url)
598        .header("Content-Type", "application/json")
599        .body(serde_json::to_string(&body)?);
600
601    if let Some(auth) = auth_header {
602        request = request.header("Authorization", auth);
603    }
604
605    response = request.send().await?;
606
607    if !response.status().is_success() {
608        return Err(format!("Request failed with status: {}", response.status()).into());
609    }
610
611    spinner.finish_and_clear();
612
613    if config.stream == Some(true) {
614        spinner.finish_and_clear();
615        let mut buffer = String::new();
616        let mut response_content = String::new();
617        let mut done = false;
618        while let Some(chunk) = response.chunk().await? {
619            let chunk_str = String::from_utf8_lossy(&chunk);
620            buffer.push_str(&chunk_str);
621
622            loop {
623                if let Some(newline_idx) = buffer.find('\n') {
624                    let line = buffer[..newline_idx].trim_end_matches('\r').to_string();
625                    buffer.drain(..=newline_idx);
626                    if line.is_empty() {
627                        continue;
628                    }
629
630                    if line.starts_with("data: ") {
631                        let payload = line[6..].trim();
632                        if payload == "[DONE]" {
633                            done = true;
634                            println!(); // Fixes issue on Linux where % sign shows at end
635                            break;
636                        } else {
637                            if let Ok(v) = serde_json::from_str::<serde_json::Value>(payload) {
638                                if let Some(choices) = v.get("choices").and_then(|c| c.as_array()) {
639                                    for choice in choices {
640                                        if let Some(delta) = choice.get("delta") {
641                                            if let Some(content) =
642                                                delta.get("content").and_then(|c| c.as_str())
643                                            {
644                                                print!("{}", content);
645                                                let _ = std::io::stdout().flush();
646                                                response_content.push_str(content);
647                                            }
648                                        } else if let Some(content) = choice
649                                            .get("message")
650                                            .and_then(|m| m.get("content"))
651                                            .and_then(|c| c.as_str())
652                                        {
653                                            print!("{}", content);
654                                            let _ = std::io::stdout().flush();
655                                            response_content.push_str(content);
656                                        }
657                                    }
658                                }
659                            }
660                        }
661                    } else if config.verbose {
662                        eprintln!("{}", line);
663                    }
664                } else {
665                    break;
666                }
667            }
668
669            if done {
670                break;
671            }
672        }
673        Ok(response_content)
674    } else {
675        let response_text = response.text().await?;
676        let json: serde_json::Value = serde_json::from_str(&response_text)?;
677        let content = json["choices"][0]["message"]["content"]
678            .as_str()
679            .unwrap_or("")
680            .replace("\\n", "\n")
681            .trim_matches('"')
682            .to_string();
683        spinner.finish_and_clear();
684        println!("\n{}\n", content);
685        Ok(content)
686    }
687}
688
#[cfg(test)]
mod tests {
    use super::*;

    /// Sends one non-streaming message through `run` with the `pollinations`
    /// provider and expects success. `run` issues a real HTTP request, so this
    /// test needs network access.
    #[tokio::test]
    async fn test_run() {
        let profile = ConfigFile {
            provider: "pollinations".to_string(),
            model: "openai".to_string(),
            stream: Some(false),
            system_prompt: Some("You are a helpful assistant named Tofu.".to_string()),
        };
        let config = Config {
            verbose: false,
            interactive: false,
            message: Some("Hello, world!".to_string()),
            stream: Some(false),
            file: Some(profile),
        };

        let outcome = run(config).await;
        assert!(outcome.is_ok());
    }
}