use crate::errors::ShellGptError;
use regex::Regex;
use reqwest::blocking::Client as HttpClient;
use serde::{Deserialize, Serialize};
use std::env;

const OPENAI_MODEL: &str = "gpt-3.5-turbo";

const PRE_PROMPT_SHELL_SCRIPT: &str = "You are an expert at creating bash scripts. \
    I want you to generate a valid bash script following a specific request. \
    You must only answer with the script that will be run on the target system. \
    Do not write something like \"this is the script you asked for:\", just print the script ONLY. \
    Do not write a warning message, only print the script itself.";

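/// Selects the pre-prompt prepended to the user's request before it is sent to the API.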
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum PrePrompt {
    NoPrePrompt,
    ShellScript,
}

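/// Sends `input` to the OpenAI chat completion API, prefixed with the selected
/// pre-prompt, and returns the raw text of the reply. Setting the
/// `OPENAI_API_RESPONSE_MOCK` environment variable short-circuits the network
/// call with a canned response, which is useful for tests.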
pub fn ask_chatgpt(input: &str, pre_prompt: PrePrompt, api_key: &str) -> anyhow::Result<String> {
    // Allow tests (or offline runs) to bypass the API with a canned response.
    if let Ok(response) = env::var("OPENAI_API_RESPONSE_MOCK") {
        return Ok(response);
    }
    let pre_prompt = get_pre_prompt(pre_prompt);
    // Only include the pre-prompt message when there actually is one.
    let mut messages = Vec::new();
    if !pre_prompt.is_empty() {
        messages.push(Message::User(pre_prompt.as_str()));
    }
    messages.push(Message::User(input));
    request_chatgpt_api(messages, api_key)
}

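/// Returns the body of the first fenced code block in `text`, or `text`
/// unchanged when no code block is present.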
pub fn extract_code_block_if_needed(text: &str) -> String {
    // Match a fenced code block with an optional language tag, e.g. ```bash ... ```.
    let regex = Regex::new(r"```\w*\n([\s\S]*?)\n```").unwrap();
    match regex.captures(text).and_then(|captures| captures.get(1)) {
        Some(code) => code.as_str().to_string(),
        None => text.to_string(),
    }
}

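/// Maps a `PrePrompt` variant to its prompt text (empty for `NoPrePrompt`).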
fn get_pre_prompt(pre_prompt: PrePrompt) -> String {
    match pre_prompt {
        PrePrompt::NoPrePrompt => String::new(),
        PrePrompt::ShellScript => PRE_PROMPT_SHELL_SCRIPT.to_string(),
    }
}

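/// Calls the OpenAI chat completion endpoint with the given messages and
/// returns the content of the first choice.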
fn request_chatgpt_api(messages: Vec<Message>, api_key: &str) -> anyhow::Result<String> {
    let body = ChatRequestInput {
        model: OPENAI_MODEL.to_string(),
        messages,
    };

    let client = HttpClient::new();
    let resp = client
        .post("https://api.openai.com/v1/chat/completions")
        .header("Authorization", format!("Bearer {api_key}"))
        .json(&body)
        .send()?;

    if resp.status().is_success() {
        let res: ChatResponse = resp.json()?;
        // The API returns a list of choices; take the first one instead of
        // panicking when the list is empty.
        let choice = res
            .choices
            .first()
            .ok_or_else(|| ShellGptError::ApiError("OpenAI API returned no choices".to_string()))?;
        Ok(choice.message.content.clone())
    } else {
        let status = resp.status();
        let error_body = resp.text().unwrap_or_default();
        let err = format!(
            "Error when calling the OpenAI chat completion API - Status: {status} - Body: {error_body}"
        );
        Err(ShellGptError::ApiError(err))?
    }
}

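// Serialized with serde's adjacently tagged representation, so each variant
// becomes `{"role": "<variant>", "content": "<text>"}`, the message shape the
// chat completion API expects.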
#[derive(Debug, Serialize)]
#[serde(tag = "role", content = "content", rename_all = "lowercase")]
pub enum Message<'a> {
    System(&'a str),
    Assistant(&'a str),
    User(&'a str),
}

#[derive(Debug, Serialize)]
struct ChatRequestInput<'a> {
    model: String,
    messages: Vec<Message<'a>>,
}

#[derive(Deserialize)]
#[allow(dead_code)]
struct ChatResponseUsage {
    pub prompt_tokens: i64,
    pub completion_tokens: i64,
    pub total_tokens: i64,
}

#[derive(Deserialize)]
#[allow(dead_code)]
struct ChatResponseMessage {
    pub role: String,
    pub content: String,
}

#[derive(Deserialize)]
#[allow(dead_code)]
struct ChatResponseChoice {
    pub index: i64,
    pub message: ChatResponseMessage,
    pub finish_reason: String,
}

#[derive(Deserialize)]
#[allow(dead_code)]
struct ChatResponse {
    pub id: String,
    pub object: String,
    pub created: i64,
    pub choices: Vec<ChatResponseChoice>,
    pub usage: ChatResponseUsage,
}
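
// Minimal sanity checks for `extract_code_block_if_needed`; a sketch rather
// than an exhaustive test suite, covering one fenced reply and one plain reply.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn extracts_the_body_of_a_fenced_code_block() {
        let reply = "```bash\necho \"hello\"\n```";
        assert_eq!(extract_code_block_if_needed(reply), "echo \"hello\"");
    }

    #[test]
    fn returns_the_input_unchanged_when_there_is_no_code_block() {
        let reply = "echo \"hello\"";
        assert_eq!(extract_code_block_if_needed(reply), reply);
    }
}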