simple_llm_client/openai/mod.rs

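//! A minimal OpenAI chat-completions client: `chat_completion` streams a
//! reply to stdout over Server-Sent Events, while `chat_completion_markdown`
//! fetches a complete reply and renders it to a Markdown file.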
use dotenv::dotenv;
use futures_util::StreamExt;
use reqwest::Client;
use std::{env, fs, path::Path};
use tokio::fs::File;
use tokio::io::{stdout, AsyncWriteExt};

pub mod models;
use models::*;
/// Formats response text as Markdown, appending a numbered "Sources" section
/// when citations are present.
#[allow(dead_code)]
fn format_as_markdown(text: &str, citations: &Option<Vec<Citation>>) -> String {
    let mut formatted = text.to_string();
    if let Some(citations_vec) = citations {
        if !citations_vec.is_empty() {
            formatted.push_str("\n\n## Sources\n\n");
            for (i, citation) in citations_vec.iter().enumerate() {
                // Render http(s) URLs as Markdown links; leave anything else
                // as plain text.
                if citation.url.starts_with("http") {
                    formatted.push_str(&format!(
                        "{}. [{}]({})\n",
                        i + 1,
                        citation.url,
                        citation.url
                    ));
                } else {
                    formatted.push_str(&format!("{}. {}\n", i + 1, citation.url));
                }
            }
        }
    }
    formatted
}
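/// Streams a chat completion from the OpenAI chat-completions endpoint,
/// writing each content fragment to stdout as it arrives.
///
/// A minimal usage sketch; the crate and module paths are assumed from the
/// file layout, and the doctest is `no_run` so it compiles without a network
/// call or an API key:
///
/// ```no_run
/// use simple_llm_client::openai::{chat_completion, models::ChatMessage};
///
/// #[tokio::main]
/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
///     let messages = vec![ChatMessage {
///         role: "user".to_string(),
///         content: "How many stars are there in our galaxy?".to_string(),
///     }];
///     chat_completion("gpt-3.5-turbo", messages).await?;
///     Ok(())
/// }
/// ```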
#[allow(dead_code)]
pub async fn chat_completion(
    model: &str,
    messages: Vec<ChatMessage>,
) -> Result<(), Box<dyn std::error::Error>> {
    dotenv().ok();
    let api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set in .env");
    let client = Client::new();
    let request = ChatCompletionRequest {
        model: model.to_string(),
        messages,
        // `stream: true` asks the API to deliver the reply incrementally as
        // Server-Sent Events rather than a single JSON body.
        stream: true,
        temperature: None,
        max_tokens: None,
    };
    let response = client
        .post("https://api.openai.com/v1/chat/completions")
        .header("Authorization", format!("Bearer {}", api_key))
        .header("Content-Type", "application/json")
        .json(&request)
        .send()
        .await?;
    if !response.status().is_success() {
        eprintln!("Response status code: {}", response.status());
        let response_body = response.text().await?;
        eprintln!("Response body: {}", response_body);
        return Err("API request failed".into());
    }
    let mut stream = response.bytes_stream();
    let mut stdout = stdout();
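    // The body is a Server-Sent Events stream: events are framed as
    // "data: <json>\n\n" and the stream ends with "data: [DONE]". Note that
    // this simple parser assumes each network chunk contains whole events; a
    // production client would buffer partial events across chunk boundaries.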
    'stream: while let Some(chunk) = stream.next().await {
        let chunk = chunk?;
        let chunk_str = String::from_utf8(chunk.to_vec())?;
        for event in chunk_str.split("\n\n") {
            if let Some(data) = event.strip_prefix("data:") {
                let data = data.trim();
                if data == "[DONE]" {
                    // Exit the outer loop too; a plain `break` would only
                    // leave the per-chunk event loop.
                    break 'stream;
                }
                if let Ok(completion_response) =
                    serde_json::from_str::<ChatCompletionResponse>(data)
                {
                    // The OpenAI API sends streamed content under `delta`;
                    // how that payload maps onto `message` here depends on
                    // the serde definitions in models.rs.
                    if let Some(choice) = completion_response.choices.first() {
                        if let Some(message) = &choice.message {
                            stdout.write_all(message.content.as_bytes()).await?;
                            stdout.flush().await?;
                        }
                    }
                } else {
                    eprintln!("Failed to parse SSE data: {}", data);
                }
            }
        }
    }
    Ok(())
}
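/// Requests a chat completion (non-streaming) and writes the reply to a
/// Markdown file, appending a "## Sources" section when the response carries
/// citations.
///
/// A minimal usage sketch under the same path assumptions as above; the
/// output filename is illustrative, and the doctest is `no_run` to avoid a
/// live API call:
///
/// ```no_run
/// use simple_llm_client::openai::{chat_completion_markdown, models::ChatMessage};
/// use std::path::Path;
///
/// #[tokio::main]
/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
///     let messages = vec![ChatMessage {
///         role: "user".to_string(),
///         content: "How many stars are there in our galaxy?".to_string(),
///     }];
///     // "answer.md" is an example filename, written to the current directory.
///     chat_completion_markdown("gpt-3.5-turbo", messages, Some(Path::new(".")), "answer.md")
///         .await?;
///     Ok(())
/// }
/// ```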
#[allow(dead_code)]
pub async fn chat_completion_markdown(
    model: &str,
    messages: Vec<ChatMessage>,
    path: Option<&Path>,
    filename: &str,
) -> Result<(), Box<dyn std::error::Error>> {
    dotenv().ok();
    let api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set in .env");
    let client = Client::new();
    let request = ChatCompletionRequest {
        model: model.to_string(),
        messages,
        // Non-streaming: the whole completion arrives as a single JSON body.
        stream: false,
        temperature: None,
        max_tokens: None,
    };
    let response = client
        .post("https://api.openai.com/v1/chat/completions")
        .header("Authorization", format!("Bearer {}", api_key))
        .header("Content-Type", "application/json")
        .json(&request)
        .send()
        .await?;
    if !response.status().is_success() {
        eprintln!("Response status code: {}", response.status());
        let response_body = response.text().await?;
        eprintln!("Response body: {}", response_body);
        return Err("API request failed".into());
    }
    let response_text = response.text().await?;
    // Best-effort debug dump of the raw response; a write failure is
    // deliberately ignored.
    let _ = fs::write("raw_openai_response.json", &response_text);
    let completion_response: ChatCompletionResponse =
        serde_json::from_str(&response_text)?;
    let (content, citations) = if let Some(choice) = completion_response.choices.first() {
        let content = choice
            .message
            .as_ref()
            .map_or(String::new(), |m| m.content.clone());
        (content, choice.citations.clone())
    } else {
        (String::new(), None)
    };
    let formatted_content = format_as_markdown(&content, &citations);
    let file_path = match path {
        Some(p) => p.join(filename),
        None => Path::new(filename).to_path_buf(),
    };
    let mut file = File::create(file_path).await?;
    file.write_all(formatted_content.as_bytes()).await?;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::{env, path::Path};
    use tokio;
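    // A synchronous check of the Markdown formatter. It exercises only the
    // citation-free path, so it needs no API key and makes no assumptions
    // about the fields of `Citation` in models.rs.
    #[test]
    fn test_format_as_markdown_without_citations() {
        let text = "Hello, world.";
        // With no citations, the text should pass through unchanged.
        assert_eq!(format_as_markdown(text, &None), text);
    }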
    #[tokio::test]
    async fn test_chat_completion() {
        dotenv::dotenv().ok();
        if env::var("OPENAI_API_KEY").is_ok() {
            let messages = vec![
                ChatMessage {
                    role: "system".to_string(),
                    content: "Be precise and concise.".to_string(),
                },
                ChatMessage {
                    role: "user".to_string(),
                    content: "How many stars are there in our galaxy?".to_string(),
                },
            ];
            let result = chat_completion("gpt-3.5-turbo", messages).await;
            assert!(result.is_ok());
        } else {
            println!("Skipping test because OPENAI_API_KEY is not set");
        }
    }
    #[tokio::test]
    async fn test_chat_completion_markdown() {
        dotenv::dotenv().ok();
        if env::var("OPENAI_API_KEY").is_ok() {
            let messages = vec![
                ChatMessage {
                    role: "system".to_string(),
                    content: "Be precise and concise. Return the response as markdown.".to_string(),
                },
                ChatMessage {
                    role: "user".to_string(),
                    content: "How many stars are there in our galaxy?".to_string(),
                },
            ];
            let path = Path::new(".");
            let filename = "test_openai_output.md";
            let result =
                chat_completion_markdown("gpt-3.5-turbo", messages, Some(path), filename).await;
            assert!(result.is_ok());
        } else {
            println!("Skipping test because OPENAI_API_KEY is not set");
        }
    }
}