vibesort_rs/lib.rs
1//! # Vibesort
2//!
3//! A Rust library for sorting arrays using Large Language Models (LLMs).
4//!
5//! This library provides a simple interface to sort arrays by leveraging LLM APIs
6//! such as OpenAI, Anthropic, or other compatible services. It sends the array to
7//! the LLM and parses the sorted result.
8//!
9//! ## Features
10//!
11//! - Sort arrays of any type that implements `Display`, `Serialize`, and `DeserializeOwned`
12//! - Support for any LLM API compatible with OpenAI's chat completion format
13//! - Comprehensive error handling with detailed error messages
14//! - Async/await support using Tokio
15//!
16//! ## Example
17//!
18//! ```no_run
19//! use vibesort_rs::Vibesort;
20//!
21//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
22//! let sorter = Vibesort::new(
23//! "your-api-key",
24//! "gpt-3.5-turbo",
25//! "https://api.openai.com/v1",
26//! );
27//!
28//! let numbers = vec![3, 1, 4, 1, 5, 9, 2, 6];
29//! let sorted = sorter.sort(&numbers).await?;
30//! println!("{:?}", sorted); // [1, 1, 2, 3, 4, 5, 6, 9]
31//! # Ok(())
32//! # }
33//! ```
34
35use serde::{Deserialize, Serialize, de::DeserializeOwned};
36use std::fmt::Display;
37use thiserror::Error;
38
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::{header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Builds a `200 OK` chat-completion body whose first choice carries `content`.
    fn chat_completion(content: &str) -> ResponseTemplate {
        ResponseTemplate::new(200).set_body_json(serde_json::json!({
            "choices": [{
                "message": {
                    "content": content
                }
            }]
        }))
    }

    /// Starts a mock server that answers every `POST /chat/completions` with `response`.
    async fn serve_chat(response: ResponseTemplate) -> MockServer {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(response)
            .mount(&server)
            .await;
        server
    }

    #[test]
    fn test_vibesort_config() {
        let sorter = Vibesort::new("key", "model", "url");
        assert_eq!(sorter.api_key, "key");
        assert_eq!(sorter.model, "model");
        assert_eq!(sorter.base_url, "url");
    }

    #[tokio::test]
    async fn test_vibesort_with_mock() {
        let server = serve_chat(chat_completion("[1,1,2,3,4,5,6,9]")).await;
        let base_url = server.uri();
        let sorter = Vibesort::new("test-api-key", "test-model", &base_url);

        let numbers = vec![3, 1, 4, 1, 5, 9, 2, 6];
        let sorted = sorter
            .sort(&numbers)
            .await
            .expect("sorting against the mock should succeed");
        assert_eq!(sorted, vec![1, 1, 2, 3, 4, 5, 6, 9]);
    }

    #[tokio::test]
    async fn test_vibesort_api_error() {
        let server =
            serve_chat(ResponseTemplate::new(500).set_body_string("Internal Server Error")).await;
        let base_url = server.uri();
        let sorter = Vibesort::new("test-api-key", "test-model", &base_url);

        let numbers = vec![3, 1, 4];
        match sorter.sort(&numbers).await {
            // The server body should be carried through for diagnosis.
            Err(VibesortError::ApiError(msg)) => assert!(msg.contains("Internal Server Error")),
            other => panic!("Expected ApiError, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_vibesort_parse_error() {
        // A reply that is prose, not a JSON array.
        let server = serve_chat(chat_completion("Here is the sorted array: 1, 2, 3")).await;
        let base_url = server.uri();
        let sorter = Vibesort::new("test-api-key", "test-model", &base_url);

        let numbers = vec![3, 1, 2];
        match sorter.sort(&numbers).await {
            // Verify that the error message contains the LLM's response.
            Err(VibesortError::ParseError(msg)) => {
                assert!(msg.contains("Here is the sorted array: 1, 2, 3"))
            }
            other => panic!("Expected ParseError, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_vibesort_strips_markdown_fence() {
        let server = serve_chat(chat_completion("```json\n[1,2,3]\n```")).await;
        let base_url = server.uri();
        let sorter = Vibesort::new("test-api-key", "test-model", &base_url);

        let numbers = vec![3, 1, 2];
        let sorted = sorter
            .sort(&numbers)
            .await
            .expect("a fenced JSON array should parse");
        assert_eq!(sorted, vec![1, 2, 3]);
    }

    #[tokio::test]
    async fn test_vibesort_sends_bearer_token() {
        // Only a request carrying the expected header matches; anything else gets a 404.
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .and(header("authorization", "Bearer test-api-key"))
            .respond_with(chat_completion("[1,2,3]"))
            .expect(1)
            .mount(&server)
            .await;
        let base_url = server.uri();
        let sorter = Vibesort::new("test-api-key", "test-model", &base_url);

        let numbers = vec![3, 1, 2];
        let sorted = sorter
            .sort(&numbers)
            .await
            .expect("request should carry the bearer token");
        assert_eq!(sorted, vec![1, 2, 3]);
    }

    #[tokio::test]
    async fn test_vibesort_str_with_mock() {
        let server = serve_chat(chat_completion("[\"apple\",\"banana\",\"cherry\"]")).await;
        let base_url = server.uri();
        let sorter = Vibesort::new("test-api-key", "test-model", &base_url);

        let words = vec!["banana", "apple", "cherry"];
        let sorted = sorter
            .sort_str(&words)
            .await
            .expect("string sorting against the mock should succeed");
        assert_eq!(sorted, vec!["apple", "banana", "cherry"]);
    }
}
184
185/// Error types for vibesort operations.
186///
187/// This enum represents all possible errors that can occur during the sorting process.
188#[derive(Error, Debug)]
189pub enum VibesortError {
190 /// An error occurred while making the HTTP request to the LLM API.
191 #[error("HTTP request failed: {0}")]
192 HttpError(#[from] reqwest::Error),
193
194 /// An error occurred while parsing JSON (e.g., when serializing the input array
195 /// or deserializing the LLM response).
196 #[error("JSON parsing failed: {0}")]
197 JsonError(#[from] serde_json::Error),
198
199 /// The LLM API returned an error status code.
200 ///
201 /// This error includes the HTTP status code and the server's response body.
202 #[error("LLM API error: {0}")]
203 ApiError(String),
204
205 /// The LLM API response is missing required fields or has an invalid structure.
206 ///
207 /// This typically means the response doesn't contain a `choices` array or
208 /// the first choice doesn't have a `message` field.
209 #[error("Invalid response format from LLM")]
210 InvalidResponse,
211
212 /// The LLM returned content that cannot be parsed as a JSON array.
213 ///
214 /// This error includes the parsing error details and the actual content
215 /// returned by the LLM, which helps diagnose why the parsing failed.
216 #[error("Failed to parse LLM response as sorted array. LLM returned: {0}")]
217 ParseError(String),
218}
219
220/// OpenAI API request/response structures
221#[derive(Debug, Serialize)]
222struct ChatRequest<'a> {
223 model: &'a str,
224 messages: Vec<ChatMessage<'a>>,
225 temperature: f32,
226}
227
228#[derive(Debug, Serialize)]
229struct ChatMessage<'a> {
230 role: &'a str,
231 content: &'a str,
232}
233
234#[derive(Debug, Deserialize)]
235struct ChatMessageResponse {
236 content: String,
237}
238
239#[derive(Debug, Deserialize)]
240struct ChatResponse {
241 choices: Vec<Choice>,
242}
243
244#[derive(Debug, Deserialize)]
245struct Choice {
246 message: ChatMessageResponse,
247}
248
/// Client for sorting arrays using LLM APIs.
///
/// This struct holds (borrows) the configuration needed to communicate with an
/// LLM API and provides methods to sort arrays.
///
/// The `Debug` output redacts `api_key`, so formatting a `Vibesort` with `{:?}`
/// (for example in logs or panic messages) does not leak the secret.
///
/// # Example
///
/// ```no_run
/// use vibesort_rs::Vibesort;
///
/// let sorter = Vibesort::new(
///     "sk-...",
///     "gpt-3.5-turbo",
///     "https://api.openai.com/v1",
/// );
/// ```
#[derive(Clone)]
pub struct Vibesort<'a> {
    /// The API key for authenticating with the LLM service.
    pub api_key: &'a str,

    /// The model identifier to use (e.g., "gpt-3.5-turbo", "gpt-4").
    pub model: &'a str,

    /// The base URL of the LLM API endpoint (e.g., "https://api.openai.com/v1").
    pub base_url: &'a str,
}

impl std::fmt::Debug for Vibesort<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Vibesort")
            // Never print the credential itself.
            .field("api_key", &format_args!("<redacted>"))
            .field("model", &self.model)
            .field("base_url", &self.base_url)
            .finish()
    }
}
276
277impl<'a> Vibesort<'a> {
278 /// Creates a new `Vibesort` instance.
279 ///
280 /// # Arguments
281 ///
282 /// * `api_key` - The API key for authenticating with the LLM service
283 /// * `model` - The model identifier to use (e.g., "gpt-3.5-turbo", "gpt-4")
284 /// * `base_url` - The base URL of the LLM API endpoint
285 ///
286 /// # Example
287 ///
288 /// ```no_run
289 /// use vibesort_rs::Vibesort;
290 ///
291 /// let sorter = Vibesort::new(
292 /// "sk-1234567890abcdef",
293 /// "gpt-3.5-turbo",
294 /// "https://api.openai.com/v1",
295 /// );
296 /// ```
297 pub fn new(api_key: &'a str, model: &'a str, base_url: &'a str) -> Self {
298 Self {
299 api_key,
300 model,
301 base_url,
302 }
303 }
304
305 /// Sorts an array using an LLM.
306 ///
307 /// This method sends the input array to the configured LLM API and requests
308 /// it to sort the elements. The LLM is instructed to return only a JSON array
309 /// with the sorted elements, which is then parsed and returned.
310 ///
311 /// # Arguments
312 ///
313 /// * `items` - A slice of items to sort. Each item must implement:
314 /// - `Display` - For error messages
315 /// - `Serialize` - For serializing to JSON
316 /// - `DeserializeOwned` - For deserializing the sorted result
317 ///
318 /// # Returns
319 ///
320 /// Returns `Ok(Vec<T>)` with the sorted array if successful, or an error
321 /// if the API call fails, the response is invalid, or parsing fails.
322 ///
323 /// # Errors
324 ///
325 /// This method can return various errors:
326 /// - [`VibesortError::HttpError`] - Network or HTTP request errors
327 /// - [`VibesortError::ApiError`] - API returned an error status code
328 /// - [`VibesortError::InvalidResponse`] - Response format is invalid
329 /// - [`VibesortError::ParseError`] - LLM response cannot be parsed as a JSON array
330 /// - [`VibesortError::JsonError`] - JSON serialization/deserialization errors
331 ///
332 /// # Examples
333 ///
334 /// ## Sorting numbers
335 ///
336 /// ```no_run
337 /// use vibesort_rs::Vibesort;
338 ///
339 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
340 /// let sorter = Vibesort::new(
341 /// "your-api-key",
342 /// "gpt-3.5-turbo",
343 /// "https://api.openai.com/v1",
344 /// );
345 ///
346 /// let numbers = vec![3, 1, 4, 1, 5, 9, 2, 6];
347 /// let sorted = sorter.sort(&numbers).await?;
348 /// assert_eq!(sorted, vec![1, 1, 2, 3, 4, 5, 6, 9]);
349 /// # Ok(())
350 /// # }
351 /// ```
352 ///
353 /// ## Sorting strings
354 ///
355 /// ```no_run
356 /// use vibesort_rs::Vibesort;
357 ///
358 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
359 /// let sorter = Vibesort::new(
360 /// "your-api-key",
361 /// "gpt-3.5-turbo",
362 /// "https://api.openai.com/v1",
363 /// );
364 ///
365 /// let words: Vec<String> = vec!["banana", "apple", "cherry"]
366 /// .into_iter()
367 /// .map(|s| s.to_string())
368 /// .collect();
369 /// let sorted = sorter.sort(&words).await?;
370 /// # Ok(())
371 /// # }
372 /// ```
373 ///
374 /// ## Error handling
375 ///
376 /// ```no_run
377 /// use vibesort_rs::{Vibesort, VibesortError};
378 ///
379 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
380 /// let sorter = Vibesort::new(
381 /// "invalid-key",
382 /// "gpt-3.5-turbo",
383 /// "https://api.openai.com/v1",
384 /// );
385 ///
386 /// match sorter.sort(&vec![1, 2, 3]).await {
387 /// Ok(sorted) => println!("Sorted: {:?}", sorted),
388 /// Err(VibesortError::ApiError(msg)) => eprintln!("API error: {}", msg),
389 /// Err(e) => eprintln!("Other error: {}", e),
390 /// }
391 /// # Ok(())
392 /// # }
393 /// ```
394 pub async fn sort<T>(&self, items: &[T]) -> Result<Vec<T>, VibesortError>
395 where
396 T: Display + Serialize + DeserializeOwned,
397 {
398 // Serialize the input array to JSON
399 let json_array = serde_json::to_string(items)?;
400
401 // Build the API URL
402 let url = format!("{}/chat/completions", self.base_url);
403
404 // Create the HTTP client
405 let client = reqwest::Client::new();
406
407 // Prepare the request with system prompt and user prompt
408 let system_prompt = "You are a helpful assistant that sorts arrays. Sort the following JSON array with ascending order and return ONLY the sorted JSON array, nothing else.";
409 let request = ChatRequest {
410 model: self.model,
411 messages: vec![
412 ChatMessage {
413 role: "system",
414 content: system_prompt,
415 },
416 ChatMessage {
417 role: "user",
418 content: &json_array,
419 },
420 ],
421 temperature: 0.0, // Use 0.0 for deterministic sorting
422 };
423
424 // Send the request
425 let response = client
426 .post(&url)
427 .header("Authorization", format!("Bearer {}", self.api_key))
428 .header("Content-Type", "application/json")
429 .json(&request)
430 .send()
431 .await?;
432
433 // Check if the request was successful
434 let status = response.status();
435 if !status.is_success() {
436 let error_text = response.text().await.unwrap_or_default();
437 return Err(VibesortError::ApiError(format!(
438 "API returned status {}\nServer response: {}",
439 status, error_text
440 )));
441 }
442
443 // Parse the response
444 let chat_response: ChatResponse = response.json().await?;
445
446 // Extract the sorted array from the LLM's response
447 let mut sorted_json = chat_response
448 .choices
449 .first()
450 .ok_or(VibesortError::InvalidResponse)?
451 .message
452 .content
453 .trim()
454 .to_string();
455
456 // Strip markdown code blocks if present (e.g., ```json ... ```)
457 if sorted_json.starts_with("```") {
458 // Remove the opening ``` and optional language identifier
459 if let Some(start_idx) = sorted_json.find('\n') {
460 sorted_json = sorted_json[start_idx + 1..].to_string();
461 } else {
462 // No newline, just remove the ```
463 sorted_json = sorted_json[3..].to_string();
464 }
465 // Remove the closing ```
466 if sorted_json.ends_with("```") {
467 sorted_json = sorted_json[..sorted_json.len() - 3].trim().to_string();
468 }
469 }
470
471 // Parse the JSON array back to Vec<T>
472 let sorted: Vec<T> = serde_json::from_str(&sorted_json).map_err(|e| {
473 VibesortError::ParseError(format!(
474 "Failed to parse as JSON array: {}\nLLM returned: {}",
475 e, sorted_json
476 ))
477 })?;
478
479 Ok(sorted)
480 }
481
482 /// Sorts an array of strings using an LLM.
483 ///
484 /// This is a convenience method specifically for sorting string arrays.
485 /// It accepts a slice of string references and returns a vector of owned strings.
486 ///
487 /// # Arguments
488 ///
489 /// * `items` - A slice of string references to sort
490 ///
491 /// # Returns
492 ///
493 /// Returns `Ok(Vec<String>)` with the sorted array if successful, or an error
494 /// if the API call fails, the response is invalid, or parsing fails.
495 ///
496 /// # Errors
497 ///
498 /// This method can return the same errors as [`sort`](Self::sort):
499 /// - [`VibesortError::HttpError`] - Network or HTTP request errors
500 /// - [`VibesortError::ApiError`] - API returned an error status code
501 /// - [`VibesortError::InvalidResponse`] - Response format is invalid
502 /// - [`VibesortError::ParseError`] - LLM response cannot be parsed as a JSON array
503 /// - [`VibesortError::JsonError`] - JSON serialization/deserialization errors
504 ///
505 /// # Examples
506 ///
507 /// ```no_run
508 /// use vibesort_rs::Vibesort;
509 ///
510 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
511 /// let sorter = Vibesort::new(
512 /// "your-api-key",
513 /// "gpt-3.5-turbo",
514 /// "https://api.openai.com/v1",
515 /// );
516 ///
517 /// let words = vec!["banana", "apple", "cherry"];
518 /// let sorted = sorter.sort_str(&words).await?;
519 /// assert_eq!(sorted, vec!["apple", "banana", "cherry"]);
520 /// # Ok(())
521 /// # }
522 /// ```
523 pub async fn sort_str(&self, items: &[&str]) -> Result<Vec<String>, VibesortError> {
524 // Convert &[&str] to Vec<String> for serialization
525 let string_vec: Vec<String> = items.iter().map(|s| s.to_string()).collect();
526 self.sort(&string_vec).await
527 }
528}