// simple_llm_client/openai/mod.rs
use dotenv::dotenv;
use futures_util::StreamExt;
use reqwest::Client;
use std::{env, fmt::Write as _, fs, path::Path};
use tokio::fs::File;
use tokio::io::{stdout, AsyncWriteExt};

pub mod models;
use models::*;

11#[allow(dead_code)]
13fn format_as_markdown(text: &str, citations: &Option<Vec<Citation>>) -> String {
14 let mut formatted = text.to_string();
15 if let Some(citations_vec) = citations {
16 if !citations_vec.is_empty() {
17 formatted.push_str("\n\n## Sources\n\n");
18 for (i, citation) in citations_vec.iter().enumerate() {
19 if let Some(url) = citation.url.strip_prefix("http") {
20 formatted.push_str(&format!("{}. [{}](http{})\n", i + 1, citation.url, url));
21 } else {
22 formatted.push_str(&format!("{}. {}\n", i + 1, citation.url));
23 }
24 }
25 }
26 }
27 formatted
28}
29
30#[allow(dead_code)]
31pub async fn chat_completion(
32 model: &str,
33 messages: Vec<ChatMessage>,
34) -> Result<(), Box<dyn std::error::Error>> {
35 dotenv().ok();
36 let api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set in .env");
37 let client = Client::new();
38 let request = ChatCompletionRequest {
39 model: model.to_string(),
40 messages,
41 stream: true,
42 temperature: None,
43 max_tokens: None,
44 };
45 let response = client
46 .post("https://api.openai.com/v1/chat/completions")
47 .header("Authorization", format!("Bearer {}", api_key))
48 .header("Content-Type", "application/json")
49 .json(&request)
50 .send()
51 .await?;
52 if !response.status().is_success() {
53 println!("Response status code: {}", response.status());
54 let response_body = response.text().await?;
55 println!("Response body: {}", response_body);
56 return Err("API request failed".into());
57 }
58 let mut stream = response.bytes_stream();
59 let mut stdout = stdout();
60 while let Some(chunk) = stream.next().await {
61 let chunk = chunk?;
62 let chunk_str = String::from_utf8(chunk.to_vec())?;
63 let events: Vec<&str> = chunk_str.split("\n\n").collect();
64 for event in events {
65 if event.starts_with("data:") {
66 let data = event.trim_start_matches("data:").trim();
67 if data == "[DONE]" {
68 break;
69 }
70 if let Ok(completion_response) =
71 serde_json::from_str::<ChatCompletionResponse>(data)
72 {
73 if let Some(choice) = completion_response.choices.get(0) {
74 if let Some(message) = &choice.message {
75 stdout.write_all(message.content.as_bytes()).await?;
76 stdout.flush().await?;
77 }
78 }
79 } else {
80 eprintln!("Failed to parse SSE data: {}", data);
81 }
82 }
83 }
84 }
85 Ok(())
86}
87
88#[allow(dead_code)]
89pub async fn chat_completion_markdown(
90 model: &str,
91 messages: Vec<ChatMessage>,
92 path: Option<&Path>,
93 filename: &str,
94) -> Result<(), Box<dyn std::error::Error>> {
95 dotenv().ok();
96 let api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set in .env");
97 let client = Client::new();
98 let request = ChatCompletionRequest {
99 model: model.to_string(),
100 messages,
101 stream: false,
102 temperature: None,
103 max_tokens: None,
104 };
105 let response = client
106 .post("https://api.openai.com/v1/chat/completions")
107 .header("Authorization", format!("Bearer {}", api_key))
108 .header("Content-Type", "application/json")
109 .json(&request)
110 .send()
111 .await?;
112 if !response.status().is_success() {
113 println!("Response status code: {}", response.status());
114 let response_body = response.text().await?;
115 println!("Response body: {}", response_body);
116 return Err("API request failed".into());
117 }
118 let response_text = response.text().await?;
119 let _ = fs::write("raw_openai_response.json", &response_text);
120 let completion_response: ChatCompletionResponse =
121 serde_json::from_str(&response_text)?;
122 let (content, citations) = if let Some(choice) = completion_response.choices.get(0) {
123 let content = choice
124 .message
125 .as_ref()
126 .map_or(String::new(), |m| m.content.clone());
127 (content, choice.citations.clone())
128 } else {
129 (String::new(), None)
130 };
131 let formatted_content = format_as_markdown(&content, &citations);
132 let file_path = match path {
133 Some(p) => p.join(filename),
134 None => Path::new(filename).to_path_buf(),
135 };
136 let mut file = File::create(file_path).await?;
137 file.write_all(formatted_content.as_bytes()).await?;
138 Ok(())
139}
140
#[cfg(test)]
mod tests {
    // Live integration tests against the OpenAI API. Each test returns early (and passes)
    // when `OPENAI_API_KEY` is not configured.
    use super::*;
    use std::env;

    const MODEL: &str = "gpt-3.5-turbo";
    const QUESTION: &str = "How many stars are there in our galaxy?";

    /// Loads `.env` and reports whether `OPENAI_API_KEY` is set, printing a skip notice if not.
    fn api_key_available() -> bool {
        dotenv::dotenv().ok();
        let available = env::var("OPENAI_API_KEY").is_ok();
        if !available {
            println!("Skipping test because OPENAI_API_KEY is not set");
        }
        available
    }

    /// Builds the two-message conversation used by the tests: `system_prompt`, then `QUESTION`.
    fn conversation(system_prompt: &str) -> Vec<ChatMessage> {
        vec![
            ChatMessage {
                role: "system".to_string(),
                content: system_prompt.to_string(),
            },
            ChatMessage {
                role: "user".to_string(),
                content: QUESTION.to_string(),
            },
        ]
    }

    #[tokio::test]
    async fn test_chat_completion() {
        if !api_key_available() {
            return;
        }
        let result = chat_completion(MODEL, conversation("Be precise and concise.")).await;
        assert!(result.is_ok(), "chat_completion failed: {:?}", result.err());
    }

    #[tokio::test]
    async fn test_chat_completion_markdown() {
        if !api_key_available() {
            return;
        }
        // Write into the OS temp directory so the test leaves no artifact in the working tree.
        let dir = env::temp_dir();
        let filename = "test_openai_output.md";
        let messages = conversation("Be precise and concise. Return the response as markdown.");
        let result =
            chat_completion_markdown(MODEL, messages, Some(dir.as_path()), filename).await;
        assert!(
            result.is_ok(),
            "chat_completion_markdown failed: {:?}",
            result.err()
        );
        let output = dir.join(filename);
        assert!(output.is_file(), "expected {} to be written", output.display());
        let _ = std::fs::remove_file(&output);
    }
}