toast_api/
unified_cli.rs

1use crate::utils::extract_org_id_from_cookie;
2use anyhow::{anyhow, Context, Result};
3use ctrlc;
4use std::fs;
5use std::future::Future;
6use std::io::{self, Write};
7use std::path::Path;
8use std::pin::Pin;
9use std::process::{self, Command, Stdio};
10use std::sync::atomic::{AtomicBool, Ordering};
11use std::sync::Arc;
12
13use crate::api::{Attachment, Claude, Session as ClaudeSession};
14use crate::config::{HAIKU_MODEL, MAX_INTERNAL_ITERS, OPUS_MODEL, SONNET_MODEL, SYSTEM_PROMPT};
15use crate::deepseek::{DeepSeek, Session as DeepSeekSession};
16use crate::utils::{extract_commands, prettify};
17use log::debug;
18
/// Unified CLI arguments for both Claude and DeepSeek.
///
/// All flags are plain booleans and default to `false` (Claude provider, default
/// model). When several model flags are set, the selected provider front end decides
/// which one takes precedence.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct UnifiedArgs {
    /// Use the DeepSeek provider instead of Claude.
    pub use_deepseek: bool,
    /// Opus model flag; the DeepSeek front end maps it to `deepseek-r1`.
    pub use_opus: bool,
    /// Sonnet model flag; the DeepSeek front end ignores it.
    pub use_sonnet: bool,
    /// Haiku model flag; the DeepSeek front end maps it to `deepseek-lite`.
    pub use_haiku: bool,
}
27
/// Run `command` through `sh -c` and render what it produced as text.
///
/// The report contains an `=== STDOUT ===` section and then an `=== STDERR ===`
/// section. Each appears only when that stream produced bytes; its bytes are decoded
/// lossily as UTF-8 and followed by a newline. The report always ends with
/// `Exit code: N`, where `N` is `-1` when the process reports no exit code.
///
/// # Errors
/// Returns an error if the shell cannot be spawned or waited on.
fn execute_command(command: &str) -> Result<String> {
    let child = Command::new("sh")
        .args(["-c", command])
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .spawn()?;
    let captured = child.wait_with_output()?;

    let mut report = String::new();
    for (label, bytes) in [("STDOUT", &captured.stdout), ("STDERR", &captured.stderr)] {
        if bytes.is_empty() {
            continue;
        }
        report += &format!("=== {label} ===\n{}\n", String::from_utf8_lossy(bytes));
    }

    let exit_code = match captured.status.code() {
        Some(code) => code,
        None => -1,
    };
    report += &format!("Exit code: {exit_code}");

    Ok(report)
}
59
60/// Read files into attachments for Claude
61fn collect_claude_attachments(paths: &[&str]) -> Result<Vec<Attachment>> {
62    const LIMIT: usize = 5;
63    const SIZE_LIMIT: u64 = 10 * 1024 * 1024;
64
65    if paths.len() > LIMIT {
66        return Err(anyhow!("cannot attach more than {LIMIT} files"));
67    }
68
69    let mut atts = Vec::new();
70
71    for p in paths {
72        if let Ok(meta) = fs::metadata(p) {
73            if meta.len() > SIZE_LIMIT {
74                eprintln!("Warning: file {p} is larger than 10 MB, skipping");
75                continue;
76            }
77
78            if let Ok(content) = fs::read_to_string(p) {
79                atts.push(Attachment {
80                    file_name: Path::new(p)
81                        .file_name()
82                        .unwrap_or_default()
83                        .to_string_lossy()
84                        .into(),
85                    size: meta.len(),
86                    content,
87                });
88            } else {
89                eprintln!("Warning: couldn't read file {p}");
90            }
91        } else {
92            eprintln!("Warning: couldn't access file {p}");
93        }
94    }
95
96    Ok(atts)
97}
98
99/// Run the unified CLI application with provider selection
100pub async fn run(args: UnifiedArgs) -> Result<()> {
101    // Set up Ctrl-C handler
102    let running = Arc::new(AtomicBool::new(true));
103    {
104        let running = running.clone();
105        ctrlc::set_handler(move || {
106            running.store(false, Ordering::SeqCst);
107            println!("\nGoodbye!");
108            process::exit(0);
109        })?;
110    }
111
112    if args.use_deepseek {
113        run_deepseek(args, running).await
114    } else {
115        run_claude(args, running).await
116    }
117}
118
119/// Run the CLI with DeepSeek provider
120async fn run_deepseek(args: UnifiedArgs, running: Arc<AtomicBool>) -> Result<()> {
121    // Load session values from config files
122    let config_dir = dirs::config_dir()
123        .ok_or_else(|| anyhow!("Could not determine config directory"))?
124        .join("toast")
125        .join("deepseek");
126
127    // Create config directory if it doesn't exist
128    if !config_dir.exists() {
129        fs::create_dir_all(&config_dir)?;
130    }
131
132    let auth_token_path = config_dir.join("auth_token");
133    let cookies_path = config_dir.join("cookies.json");
134
135    // Check auth token
136    let auth_token = if auth_token_path.exists() {
137        fs::read_to_string(&auth_token_path)
138            .context(format!(
139                "Failed to read auth token from {auth_token_path:?}"
140            ))?
141            .trim()
142            .to_string()
143    } else {
144        return Err(anyhow!(
145            "Auth token file not found at {:?}\n\nTo get your DeepSeek auth token:\n1. Go to chat.deepseek.com in your browser\n2. Open Developer Tools (F12)\n3. Go to Network tab\n4. Look for Authorization header in any request\n5. Save the token part (without 'Bearer ') to this file",
146            auth_token_path
147        ));
148    };
149
150    // Check cookies
151    let cookies = if cookies_path.exists() {
152        serde_json::from_str(
153            &fs::read_to_string(&cookies_path)
154                .context(format!("Failed to read cookies from {cookies_path:?}"))?,
155        )?
156    } else {
157        return Err(anyhow!(
158            "Cookies file not found at {:?}\n\nDeepSeek requires Cloudflare cookies.\nUse the deepseek4free library to generate them.",
159            cookies_path
160        ));
161    };
162
163    let session = DeepSeekSession {
164        auth_token,
165        cookies,
166    };
167
168    // Determine model based on flags - R1 is the reasoning model with thinking enabled
169    let model = if args.use_opus {
170        "deepseek-r1" // Use R1 reasoning model for opus
171    } else if args.use_haiku {
172        "deepseek-lite"
173    } else {
174        "deepseek-r1" // Default to R1 reasoning model
175    };
176
177    let mut deepseek = DeepSeek::new(session)?;
178
179    let stdin = io::stdin();
180    let mut stdout = io::stdout();
181
182    // Track if system prompt has been sent
183    let mut system_prompt_sent = false;
184
185    // Create a new chat session
186    println!("Starting new DeepSeek chat session...");
187    let chat_id = match deepseek.create_chat_session().await {
188        Ok(id) => {
189            println!("Session started with DeepSeek!\n");
190            id
191        }
192        Err(e) => {
193            return Err(anyhow!("Failed to create DeepSeek chat session: {}", e));
194        }
195    };
196
197    // Enable detailed thinking for reasoning model
198    let thinking_mode = if model == "deepseek-r1" {
199        crate::deepseek::ThinkingMode::Detailed
200    } else {
201        crate::deepseek::ThinkingMode::Simple
202    };
203    let search_mode = crate::deepseek::SearchMode::Disabled;
204
205    // Main chat loop - simplified like working deepseek_cli.rs
206    while running.load(Ordering::SeqCst) {
207        print!("You: ");
208        stdout.flush()?;
209
210        let mut buf = String::new();
211        match stdin.read_line(&mut buf) {
212            Ok(0) => {
213                // EOF reached, exit gracefully
214                println!("\nGoodbye!");
215                break;
216            }
217            Ok(_) => {
218                let input = buf.trim_end();
219
220                // Check for empty input or exit commands
221                if input.is_empty() {
222                    continue;
223                }
224
225                if input.eq_ignore_ascii_case("/exit")
226                    || input.eq_ignore_ascii_case("exit")
227                    || input == "x"
228                {
229                    break;
230                }
231
232                // Send message to API - let DeepSeek handle all commands in its response
233                print!("DeepSeek: ");
234                stdout.flush()?;
235
236                debug!("Sending to DeepSeek API...");
237
238                // Include system prompt only on first message
239                let system_prompt = if !system_prompt_sent {
240                    system_prompt_sent = true;
241                    Some(SYSTEM_PROMPT)
242                } else {
243                    None
244                };
245
246                match deepseek
247                    .chat_completion(
248                        &chat_id,
249                        input,
250                        None,
251                        thinking_mode,
252                        search_mode,
253                        system_prompt,
254                    )
255                    .await
256                {
257                    Ok(response) => {
258                        debug!("Got response, length: {}", response.len());
259                        println!("{}", prettify(&response));
260
261                        // Process commands in the response
262                        process_deepseek_commands(
263                            &mut deepseek,
264                            &chat_id,
265                            &response,
266                            thinking_mode,
267                            search_mode,
268                        )
269                        .await?;
270                    }
271                    Err(e) => {
272                        debug!("DeepSeek API error: {e}");
273                        eprintln!("\nError: {e}");
274                    }
275                }
276                println!();
277            }
278            Err(e) => {
279                eprintln!("Failed to read input: {e}");
280                break;
281            }
282        }
283    }
284
285    Ok(())
286}
287
288/// Process commands in DeepSeek's response
289async fn process_deepseek_commands(
290    deepseek: &mut DeepSeek,
291    chat_id: &str,
292    response: &str,
293    thinking_mode: crate::deepseek::ThinkingMode,
294    search_mode: crate::deepseek::SearchMode,
295) -> Result<()> {
296    process_deepseek_commands_internal(deepseek, chat_id, response, thinking_mode, search_mode, 0)
297        .await
298}
299
300fn process_deepseek_commands_internal<'a>(
301    deepseek: &'a mut DeepSeek,
302    chat_id: &'a str,
303    response: &'a str,
304    thinking_mode: crate::deepseek::ThinkingMode,
305    search_mode: crate::deepseek::SearchMode,
306    depth: usize,
307) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
308    Box::pin(async move {
309        // Limit recursion depth
310        const MAX_DEPTH: usize = 20;
311        if depth >= MAX_DEPTH {
312            println!("Maximum command processing depth reached ({MAX_DEPTH}). Returning to user.");
313            return Ok(());
314        }
315
316        // Extract read_file and exec commands
317        let (reads, execs) = extract_commands(response);
318
319        if reads.is_empty() && execs.is_empty() {
320            return Ok(());
321        }
322
323        // Short pause before processing commands
324        tokio::time::sleep(std::time::Duration::from_millis(500)).await;
325
326        // Process file reads
327        if !reads.is_empty() {
328            let mut file_contents = Vec::new();
329
330            for path in &reads {
331                match fs::read_to_string(path) {
332                    Ok(content) => {
333                        file_contents.push(format!("=== File: {path} ===\n{content}"));
334                    }
335                    Err(e) => {
336                        file_contents.push(format!("Error reading file {path}: {e}"));
337                    }
338                }
339            }
340
341            let file_message = format!(
342                "Here are the contents of the files you requested:\n\n{}",
343                file_contents.join("\n\n")
344            );
345
346            // print!("Sending file contents... ");
347            io::stdout().flush()?;
348
349            match deepseek
350                .chat_completion(
351                    chat_id,
352                    &file_message,
353                    None,
354                    thinking_mode,
355                    search_mode,
356                    None,
357                )
358                .await
359            {
360                Ok(response) => {
361                    println!("Done!");
362                    println!("DeepSeek: {}", prettify(&response));
363
364                    // Process next level of commands
365                    process_deepseek_commands_internal(
366                        deepseek,
367                        chat_id,
368                        &response,
369                        thinking_mode,
370                        search_mode,
371                        depth + 1,
372                    )
373                    .await?;
374                }
375                Err(e) => {
376                    println!("Error: {e}");
377                }
378            }
379        }
380
381        // Process exec commands
382        if !execs.is_empty() {
383            for cmd in &execs {
384                println!("\nExecuting: {cmd}");
385
386                match execute_command(cmd) {
387                    Ok(output) => {
388                        println!("{output}");
389
390                        print!("Sending command results... ");
391                        io::stdout().flush()?;
392
393                        let cmd_message = format!("Command executed: {cmd}\n\nOutput:\n{output}");
394
395                        match deepseek
396                            .chat_completion(
397                                chat_id,
398                                &cmd_message,
399                                None,
400                                thinking_mode,
401                                search_mode,
402                                None,
403                            )
404                            .await
405                        {
406                            Ok(response) => {
407                                println!("Done!");
408                                println!("DeepSeek: {}", prettify(&response));
409
410                                // Process next level of commands
411                                process_deepseek_commands_internal(
412                                    deepseek,
413                                    chat_id,
414                                    &response,
415                                    thinking_mode,
416                                    search_mode,
417                                    depth + 1,
418                                )
419                                .await?;
420                            }
421                            Err(e) => {
422                                println!("Error: {e}");
423                            }
424                        }
425                    }
426                    Err(e) => {
427                        println!("Error executing command: {e}");
428                    }
429                }
430            }
431        }
432
433        Ok(())
434    })
435}
436
437/// Run the CLI with Claude provider
438async fn run_claude(args: UnifiedArgs, running: Arc<AtomicBool>) -> Result<()> {
439    // Load session values from config files
440    let config_dir = dirs::config_dir()
441        .ok_or_else(|| anyhow!("Could not determine config directory"))?
442        .join("toast");
443
444    let cookie_path = config_dir.join("cookie");
445    let org_id_path = config_dir.join("org_id");
446
447    // Check if config directory exists, if not create it and provide instructions
448    if !config_dir.exists() {
449        fs::create_dir_all(&config_dir).context(format!(
450            "Failed to create config directory at {config_dir:?}"
451        ))?;
452        return Err(anyhow!(
453            "Configuration directory created at {:?}\n\nPlease create a cookie file with your Claude cookie", 
454            config_dir,
455        ));
456    }
457
458    // Check and load cookie
459    let cookie = if cookie_path.exists() {
460        fs::read_to_string(&cookie_path)
461            .context(format!("Failed to read cookie from {cookie_path:?}"))?
462            .trim()
463            .to_string()
464    } else {
465        return Err(anyhow!("Cookie file not found at {:?}", cookie_path,));
466    };
467
468    // Check and load org_id, or extract from cookie if file doesn't exist
469    let org_id = if org_id_path.exists() {
470        fs::read_to_string(&org_id_path)
471            .context(format!(
472                "Failed to read organization ID from {org_id_path:?}"
473            ))?
474            .trim()
475            .to_string()
476    } else {
477        // Try to extract org_id from cookie
478        if let Some(extracted_org_id) = extract_org_id_from_cookie(&cookie) {
479            // Save the extracted org_id to the file for future use
480            fs::write(&org_id_path, &extracted_org_id).context(format!(
481                "Failed to write organization ID to {org_id_path:?}"
482            ))?;
483            println!("Extracted organization ID from cookie and saved to {org_id_path:?}");
484            extracted_org_id
485        } else {
486            return Err(anyhow!(
487                "Organization ID file not found at {:?} and couldn't extract it from cookie.",
488                org_id_path,
489            ));
490        }
491    };
492
493    let user_agent =
494        "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:137.0) Gecko/20100101 Firefox/137.0"
495            .to_string();
496
497    let session = ClaudeSession {
498        cookie,
499        user_agent,
500        organization_id: org_id,
501    };
502
503    // Determine model based on flags
504    let model: &str = if args.use_opus {
505        OPUS_MODEL
506    } else if args.use_haiku {
507        HAIKU_MODEL
508    } else {
509        HAIKU_MODEL
510        // SONNET_MODEL
511    };
512
513    let claude = Claude::new(session.clone(), model)?;
514    println!("Starting new Claude chat session using model: {model}");
515
516    let stdin = io::stdin();
517    let mut stdout = io::stdout();
518    let mut chat_id = String::new();
519    let mut system_prompt_sent = false;
520
521    while running.load(Ordering::SeqCst) {
522        print!("You: ");
523        stdout.flush()?;
524        let mut buf = String::new();
525        stdin.read_line(&mut buf)?;
526        let input = buf.trim_end();
527        if input.is_empty() {
528            continue;
529        }
530        if input.eq_ignore_ascii_case("/exit") || input.eq_ignore_ascii_case("exit") || input == "x"
531        {
532            if !chat_id.is_empty() {
533                claude.delete_chat(&chat_id).await.ok();
534            }
535            break;
536        }
537
538        // Initialize chat
539        if chat_id.is_empty() {
540            chat_id = claude.create_chat().await.context("creating chat")?;
541        }
542
543        // Handle exec commands
544        if let Some(caps) = crate::utils::EXEC_RE.captures(input) {
545            let cmd = caps[1].to_string();
546            if !system_prompt_sent {
547                claude
548                    .send_message(&chat_id, SYSTEM_PROMPT, &[])
549                    .await
550                    .context("sending system prompt")?;
551                system_prompt_sent = true;
552            }
553
554            match execute_command(&cmd) {
555                Ok(output) => {
556                    let msg = format!("Command executed: {cmd}\n\n{output}");
557                    let ans = claude.send_message(&chat_id, &msg, &[]).await?;
558                    println!("Claude:\n{}", prettify(&ans));
559                    process_claude_commands(&claude, &chat_id, &ans).await?;
560                }
561                Err(e) => {
562                    eprintln!("Warning: command execution failed: {e}");
563                    let msg = format!("Command execution failed: {e}");
564                    let ans = claude.send_message(&chat_id, &msg, &[]).await?;
565                    println!("Claude:\n{}", prettify(&ans));
566                }
567            }
568            continue;
569        }
570
571        // Handle read_file commands
572        if let Some(caps) = crate::utils::READ_RE.captures(input) {
573            let paths: Vec<String> = caps[1].split_whitespace().map(String::from).collect();
574            let path_refs: Vec<&str> = paths.iter().map(String::as_str).collect();
575            if !system_prompt_sent {
576                claude
577                    .send_message(&chat_id, SYSTEM_PROMPT, &[])
578                    .await
579                    .context("sending system prompt")?;
580                system_prompt_sent = true;
581            }
582
583            let rest = input.strip_prefix(&caps[0]).unwrap_or("").trim();
584            let attachments = collect_claude_attachments(&path_refs).unwrap_or_default();
585            let ans = claude
586                .send_message(&chat_id, rest, &attachments)
587                .await
588                .context("sending user message")?;
589
590            println!("Claude:\n{}", prettify(&ans));
591            process_claude_commands(&claude, &chat_id, &ans).await?;
592        } else {
593            // Regular message
594            if !system_prompt_sent {
595                claude
596                    .send_message(&chat_id, SYSTEM_PROMPT, &[])
597                    .await
598                    .context("sending system prompt")?;
599                system_prompt_sent = true;
600            }
601
602            let ans = claude
603                .send_message(&chat_id, input, &[])
604                .await
605                .context("sending user message")?;
606
607            println!("Claude:\n{}", prettify(&ans));
608            process_claude_commands(&claude, &chat_id, &ans).await?;
609        }
610    }
611
612    Ok(())
613}
614
615/// Process Claude's responses for internal tool commands
616async fn process_claude_commands(claude: &Claude, chat_id: &str, response: &str) -> Result<()> {
617    process_claude_commands_internal(claude, chat_id, response, 0).await
618}
619
620fn process_claude_commands_internal<'a>(
621    claude: &'a Claude,
622    chat_id: &'a str,
623    response: &'a str,
624    depth: usize,
625) -> Pin<Box<dyn Future<Output = Result<()>> + 'a>> {
626    Box::pin(async move {
627        // Avoid infinite recursion
628        if depth >= MAX_INTERNAL_ITERS {
629            println!("Max internal iterations reached, returning to user.");
630            return Ok(());
631        }
632
633        let (reads, execs) = extract_commands(response);
634        if reads.is_empty() && execs.is_empty() {
635            return Ok(());
636        }
637
638        if !reads.is_empty() {
639            let atts =
640                collect_claude_attachments(&reads.iter().map(String::as_str).collect::<Vec<_>>())
641                    .unwrap_or_default();
642
643            match claude
644                .send_message(chat_id, "read_file response:", &atts)
645                .await
646            {
647                Ok(resp) => {
648                    println!("Claude:\n{}", prettify(&resp));
649                    return process_claude_commands_internal(claude, chat_id, &resp, depth + 1)
650                        .await;
651                }
652                Err(e) => {
653                    return Err(e);
654                }
655            }
656        }
657
658        if !execs.is_empty() {
659            let mut outputs = String::new();
660
661            for cmd in &execs {
662                match execute_command(cmd) {
663                    Ok(output) => outputs.push_str(&output),
664                    Err(e) => outputs.push_str(&format!("Command execution failed: {e}")),
665                }
666                outputs.push_str("\n\n---\n\n");
667            }
668
669            match claude.send_message(chat_id, &outputs, &[]).await {
670                Ok(resp) => {
671                    println!("Claude:\n{}", prettify(&resp));
672                    return process_claude_commands_internal(claude, chat_id, &resp, depth + 1)
673                        .await;
674                }
675                Err(e) => {
676                    return Err(e);
677                }
678            }
679        }
680
681        Ok(())
682    })
683}