simple_llm_client/perplexity/mod.rs

1use dotenv::dotenv;
2use futures_util::StreamExt;
3
4use reqwest::Client;
5use std::{env, fs, path::Path};
6use tokio::fs::File;
7use tokio::io::{stdout, AsyncWriteExt};
8
9pub mod models;
10
11use models::*;
12
13// Function to format text as Markdown
14#[allow(dead_code)]
15fn format_as_markdown(text: &str, citations: &Option<Vec<Citation>>) -> String {
16    // Just return the raw text as it already contains markdown formatting
17    let mut formatted_text = text.to_string();
18
19    // If citations exist, add them to the bottom of the document
20    if let Some(citation_vec) = citations {
21        if !citation_vec.is_empty() {
22            formatted_text.push_str("\n\n## Sources\n\n");
23
24            for (i, citation) in citation_vec.iter().enumerate() {
25                if let Some(url) = citation.url.strip_prefix("http") {
26                    formatted_text.push_str(&format!(
27                        "{}. [{}](http{})\n",
28                        i + 1,
29                        citation.url,
30                        url
31                    ));
32                } else {
33                    formatted_text.push_str(&format!("{}. {}\n", i + 1, citation.url));
34                }
35            }
36        }
37    }
38
39    formatted_text
40}
41
42#[allow(dead_code)]
43pub async fn chat_completion(
44    model: &str,
45    messages: Vec<ChatMessage>,
46) -> Result<(), Box<dyn std::error::Error>> {
47    dotenv().ok();
48
49    let api_key = env::var("PERPLEXITY_API_KEY").expect("PERPLEXITY_API_KEY not set in .env");
50
51    let client = Client::new();
52    let request = ChatCompletionRequest {
53        model: model.to_string(),
54        messages,
55        stream: true,
56    };
57
58    let response = client
59        .post("https://api.perplexity.ai/chat/completions")
60        .header("Authorization", format!("Bearer {}", api_key))
61        .header("Content-Type", "application/json")
62        .json(&request)
63        .send()
64        .await?;
65
66    if !response.status().is_success() {
67        println!("Response status code: {}", response.status());
68        let response_body = response.text().await?;
69        println!("Response body: {}", response_body);
70        return Err("API request failed".into());
71    }
72
73    let mut stream = response.bytes_stream();
74    let mut stdout = stdout();
75
76    while let Some(chunk) = stream.next().await {
77        let chunk = chunk?;
78        let chunk_str = String::from_utf8(chunk.to_vec())?;
79
80        let events: Vec<&str> = chunk_str.split("\n\n").collect();
81
82        for event in events {
83            if event.starts_with("data:") {
84                let data = event.trim_start_matches("data:").trim();
85
86                if data == "[DONE]" {
87                    break;
88                }
89
90                if let Ok(completion_response) =
91                    serde_json::from_str::<ChatCompletionResponse>(data)
92                {
93                    if let Some(choice) = completion_response.choices.get(0) {
94                        if let Some(message) = &choice.message {
95                            stdout.write_all(message.content.as_bytes()).await?;
96                            stdout.flush().await?;
97                        }
98                    }
99                } else {
100                    eprintln!("Failed to parse SSE data: {}", data);
101                }
102            }
103        }
104    }
105
106    Ok(())
107}
108
109#[allow(dead_code)]
110pub async fn chat_completion_markdown(
111    model: &str,
112    messages: Vec<ChatMessage>,
113    path: Option<&Path>,
114    filename: &str,
115) -> Result<(), Box<dyn std::error::Error>> {
116    dotenv().ok();
117
118    let api_key = env::var("PERPLEXITY_API_KEY").expect("PERPLEXITY_API_KEY not set in .env");
119
120    // Use a non-streaming request to get the complete response at once
121    let client = Client::new();
122    let request = ChatCompletionRequest {
123        model: model.to_string(),
124        messages,
125        stream: false,
126    };
127
128    let response = client
129        .post("https://api.perplexity.ai/chat/completions")
130        .header("Authorization", format!("Bearer {}", api_key))
131        .header("Content-Type", "application/json")
132        .json(&request)
133        .send()
134        .await?;
135
136    if !response.status().is_success() {
137        println!("Response status code: {}", response.status());
138        let response_body = response.text().await?;
139        println!("Response body: {}", response_body);
140        return Err("API request failed".into());
141    }
142
143    // Get the complete JSON response
144    let response_text = response.text().await?;
145
146    // For debugging
147    let _ = fs::write("raw_response.json", &response_text);
148
149    // Parse the response
150    let completion_response: ChatCompletionResponse = serde_json::from_str(&response_text)?;
151
152    // Extract the content and citations from the response
153    let (content, citations) = if let Some(choice) = completion_response.choices.get(0) {
154        let content = choice
155            .message
156            .as_ref()
157            .map_or(String::new(), |m| m.content.clone());
158        (content, choice.citations.clone())
159    } else {
160        (String::new(), None)
161    };
162
163    // Format the content with proper citations
164    let formatted_content = format_as_markdown(&content, &citations);
165
166    // For debugging
167    let _ = fs::write("extracted_content.txt", &content);
168    let _ = fs::write("formatted_content.txt", &formatted_content);
169
170    // Write to file
171    let file_path = match path {
172        Some(p) => p.join(filename),
173        None => Path::new(filename).to_path_buf(),
174    };
175
176    let mut file = File::create(file_path).await?;
177    file.write_all(formatted_content.as_bytes()).await?;
178
179    Ok(())
180}
181
#[cfg(test)]
mod tests {
    use super::*;
    use std::{env, fs};

    /// Builds the system + user conversation shared by the live API tests.
    fn galaxy_messages(system_message: &str) -> Vec<ChatMessage> {
        vec![
            ChatMessage {
                role: "system".to_string(),
                content: system_message.to_string(),
            },
            ChatMessage {
                role: "user".to_string(),
                content: "How many stars are there in our galaxy?".to_string(),
            },
        ]
    }

    /// Loads `.env` and reports whether `PERPLEXITY_API_KEY` is available.
    ///
    /// These tests call the live API, so they are skipped (with a message)
    /// when no key is configured.
    fn api_key_available() -> bool {
        dotenv::dotenv().ok();
        if env::var("PERPLEXITY_API_KEY").is_ok() {
            true
        } else {
            println!("Skipping test because PERPLEXITY_API_KEY is not set");
            false
        }
    }

    #[tokio::test]
    async fn test_chat_completion() {
        if !api_key_available() {
            return;
        }

        let messages = galaxy_messages("Be precise and concise.");
        let result = super::chat_completion("sonar-pro", messages).await;
        assert!(result.is_ok(), "chat_completion failed: {:?}", result);
    }

    #[tokio::test]
    async fn test_chat_completion_markdown() {
        if !api_key_available() {
            return;
        }

        let messages = galaxy_messages(
            "Be precise and concise. Return the response as a markdown, not a JSON.",
        );
        // Write into the OS temp dir, with a per-process name, so the test does
        // not leave files in the repository or collide with concurrent runs.
        let dir = env::temp_dir();
        let filename = format!("perplexity_test_output_{}.md", std::process::id());

        let result =
            super::chat_completion_markdown("sonar-pro", messages, Some(dir.as_path()), &filename)
                .await;
        assert!(result.is_ok(), "chat_completion_markdown failed: {:?}", result);

        let output_path = dir.join(&filename);
        let written = fs::read_to_string(&output_path).expect("output file should exist");
        let _ = fs::remove_file(&output_path);
        assert!(!written.trim().is_empty(), "output file is empty");
    }
}