shell_gpt/
openai.rs

1use crate::errors::ShellGptError;
2use regex::Regex;
3use reqwest::blocking::Client as HttpClient;
4use serde::{Deserialize, Serialize};
5use std::env;
6
/// Identifier of the OpenAI chat model named in every completion request.
const OPENAI_MODEL: &str = "gpt-3.5-turbo";
8
/// Instructions placed ahead of the user's request when a bash script is wanted.
///
/// Every source line of the literal ends with a `\` continuation so that the prompt is one
/// single line of text: without it, a raw newline and the indentation of the following source
/// line would become part of the prompt.
const PRE_PROMPT_SHELL_SCRIPT: &str = "You are an expert at creating bash scripts. \
    I want you to generate a valid bash script following a specific request. \
    You must only answer with the script that will be run on the target system. \
    Do not write something like \"this is the script you asked:\", just print the script ONLY. \
    Do not write a warning message, only print the script itself.";
14
/// Selects which instructions, if any, are placed in front of the user's request.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PrePrompt {
    /// Send the request without any bash-script instructions.
    NoPrePrompt,
    /// Precede the request with the instructions asking for a bare bash script.
    ShellScript,
}
20
21pub fn ask_chatgpt(input: &str, pre_prompt: PrePrompt, api_key: &str) -> anyhow::Result<String> {
22    if let Ok(response) = env::var("OPENAI_API_RESPONSE_MOCK") {
23        return Ok(response);
24    }
25    let pre_prompt = get_pre_prompt(pre_prompt);
26    let messages = vec![Message::User(pre_prompt.as_str()), Message::User(input)];
27    request_chatgpt_api(messages, api_key)
28}
29
30/// Even with a pre-prompt indicating to not use code blocks and not give explanations,
31/// the model can output some. Extract it, if no code block found, just return the input string.
32pub fn extract_code_block_if_needed(str: &str) -> String {
33    let regex = Regex::new(r"```\w?\n([\s\S]*?)\n```").unwrap();
34    match regex.captures(str) {
35        Some(captures) if captures.len() > 0 => captures.get(0).unwrap().as_str().to_string(),
36        _ => str.to_string(),
37    }
38}
39
40fn get_pre_prompt(pre_prompt: PrePrompt) -> String {
41    match pre_prompt {
42        PrePrompt::NoPrePrompt => String::new(),
43        PrePrompt::ShellScript => PRE_PROMPT_SHELL_SCRIPT.to_string(),
44    }
45}
46
47fn request_chatgpt_api(messages: Vec<Message>, api_key: &str) -> anyhow::Result<String> {
48    let body = ChatRequestInput {
49        model: OPENAI_MODEL.to_string(),
50        messages,
51    };
52
53    let client = HttpClient::new();
54    let resp = client
55        .post("https://api.openai.com/v1/chat/completions")
56        .header("Authorization", format!("Bearer {api_key}"))
57        .json(&body)
58        .send()?;
59
60    if resp.status().is_success() {
61        let res: ChatResponse = resp.json()?;
62        Ok(res.choices.get(0).unwrap().message.content.clone())
63    } else {
64        let err = format!(
65            "Error when calling the OpenAI chat completion API - Status: {} - Body: {}",
66            resp.status(),
67            resp.text().unwrap()
68        );
69        Err(ShellGptError::ApiError(err))?
70    }
71}
72
73#[derive(Debug, Serialize)]
74#[serde(tag = "role", content = "content", rename_all = "lowercase")]
75pub enum Message<'a> {
76    System(&'a str),
77    Assistant(&'a str),
78    User(&'a str),
79}
80
81#[derive(Debug, Serialize)]
82struct ChatRequestInput<'a> {
83    model: String,
84    messages: Vec<Message<'a>>,
85}
86
87#[derive(Deserialize)]
88#[allow(dead_code)]
89struct ChatResponseUsage {
90    pub prompt_tokens: i64,
91    pub completion_tokens: i64,
92    pub total_tokens: i64,
93}
94
95#[derive(Deserialize)]
96#[allow(dead_code)]
97struct ChatResponseMessage {
98    pub role: String,
99    pub content: String,
100}
101
102#[derive(Deserialize)]
103#[allow(dead_code)]
104struct ChatResponseChoice {
105    pub index: i64,
106    pub message: ChatResponseMessage,
107    pub finish_reason: String,
108}
109
110#[derive(Deserialize)]
111#[allow(dead_code)]
112struct ChatResponse {
113    pub id: String,
114    pub object: String,
115    pub created: i64,
116    pub choices: Vec<ChatResponseChoice>,
117    pub usage: ChatResponseUsage,
118}