Skip to main content

vibesort_rs/
lib.rs

//! # Vibesort
//!
//! A Rust library for sorting arrays using Large Language Models (LLMs).
//!
//! This library provides a simple interface to sort arrays by leveraging LLM APIs
//! such as OpenAI, Anthropic, or other services that expose an OpenAI-compatible
//! `/chat/completions` endpoint. It sends the array to the LLM as JSON and parses
//! the sorted JSON array out of the model's reply.
//!
//! The ordering is produced entirely by the model and is not re-checked by this
//! library, so results are only as reliable as the model you configure.
//!
//! ## Features
//!
//! - Sort slices of any type that implements `Display`, `Serialize`, and `DeserializeOwned`
//! - Support for any LLM API compatible with OpenAI's chat completion format
//! - Comprehensive error handling with detailed error messages
//! - Async/await support using Tokio
//!
//! ## Example
//!
//! ```no_run
//! use vibesort_rs::Vibesort;
//!
//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
//! let sorter = Vibesort::new(
//!     "your-api-key",
//!     "gpt-3.5-turbo",
//!     "https://api.openai.com/v1",
//! );
//!
//! let numbers = vec![3, 1, 4, 1, 5, 9, 2, 6];
//! let sorted = sorter.sort(&numbers).await?;
//! println!("{:?}", sorted); // [1, 1, 2, 3, 4, 5, 6, 9]
//! # Ok(())
//! # }
//! ```

35use serde::{Deserialize, Serialize, de::DeserializeOwned};
36use std::fmt::Display;
37use thiserror::Error;
38
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Starts a mock server whose `POST /chat/completions` endpoint answers with `template`.
    ///
    /// The returned server must be kept alive for as long as requests are made to it.
    async fn mock_completions(template: ResponseTemplate) -> MockServer {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/chat/completions"))
            .respond_with(template)
            .mount(&server)
            .await;
        server
    }

    /// A 200 chat-completion body whose first choice carries `content` as the model reply.
    fn reply_with(content: &str) -> ResponseTemplate {
        ResponseTemplate::new(200).set_body_json(serde_json::json!({
            "choices": [{
                "message": {
                    "content": content
                }
            }]
        }))
    }

    #[test]
    fn test_vibesort_config() {
        let sorter = Vibesort::new("key", "model", "url");
        assert_eq!(sorter.api_key, "key");
        assert_eq!(sorter.model, "model");
        assert_eq!(sorter.base_url, "url");
    }

    #[tokio::test]
    async fn test_vibesort_with_mock() {
        let server = mock_completions(reply_with("[1,1,2,3,4,5,6,9]")).await;
        let base_url = server.uri();
        let sorter = Vibesort::new("test-api-key", "test-model", base_url.as_str());

        let numbers = vec![3, 1, 4, 1, 5, 9, 2, 6];
        let sorted = sorter.sort(&numbers).await.expect("sort should succeed");
        assert_eq!(sorted, vec![1, 1, 2, 3, 4, 5, 6, 9]);
    }

    #[tokio::test]
    async fn test_vibesort_api_error() {
        let server = mock_completions(
            ResponseTemplate::new(500).set_body_string("Internal Server Error"),
        )
        .await;
        let base_url = server.uri();
        let sorter = Vibesort::new("test-api-key", "test-model", base_url.as_str());

        let numbers = vec![3, 1, 4];
        let result = sorter.sort(&numbers).await;
        assert!(
            matches!(result, Err(VibesortError::ApiError(_))),
            "expected ApiError, got {:?}",
            result
        );
    }

    #[tokio::test]
    async fn test_vibesort_parse_error() {
        // A chatty, non-JSON reply must surface as a ParseError carrying the raw text.
        let server = mock_completions(reply_with("Here is the sorted array: 1, 2, 3")).await;
        let base_url = server.uri();
        let sorter = Vibesort::new("test-api-key", "test-model", base_url.as_str());

        let numbers = vec![3, 1, 2];
        match sorter.sort(&numbers).await {
            Err(VibesortError::ParseError(msg)) => {
                assert!(msg.contains("Here is the sorted array: 1, 2, 3"));
            }
            other => panic!("expected ParseError, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_vibesort_strips_markdown_fence() {
        // Models frequently wrap JSON in a ```json fence; it must be removed before parsing.
        let server = mock_completions(reply_with("```json\n[1,2,3]\n```")).await;
        let base_url = server.uri();
        let sorter = Vibesort::new("test-api-key", "test-model", base_url.as_str());

        let numbers = vec![3, 1, 2];
        let sorted = sorter.sort(&numbers).await.expect("fenced reply should parse");
        assert_eq!(sorted, vec![1, 2, 3]);
    }

    #[tokio::test]
    async fn test_vibesort_empty_choices() {
        let server = mock_completions(
            ResponseTemplate::new(200).set_body_json(serde_json::json!({ "choices": [] })),
        )
        .await;
        let base_url = server.uri();
        let sorter = Vibesort::new("test-api-key", "test-model", base_url.as_str());

        let numbers = vec![3, 1, 2];
        let result = sorter.sort(&numbers).await;
        assert!(
            matches!(result, Err(VibesortError::InvalidResponse)),
            "expected InvalidResponse, got {:?}",
            result
        );
    }

    #[tokio::test]
    async fn test_vibesort_str_with_mock() {
        let server = mock_completions(reply_with("[\"apple\",\"banana\",\"cherry\"]")).await;
        let base_url = server.uri();
        let sorter = Vibesort::new("test-api-key", "test-model", base_url.as_str());

        let words = vec!["banana", "apple", "cherry"];
        let sorted = sorter.sort_str(&words).await.expect("sort_str should succeed");
        assert_eq!(sorted, vec!["apple", "banana", "cherry"]);
    }
}

185/// Error types for vibesort operations.
186///
187/// This enum represents all possible errors that can occur during the sorting process.
188#[derive(Error, Debug)]
189pub enum VibesortError {
190    /// An error occurred while making the HTTP request to the LLM API.
191    #[error("HTTP request failed: {0}")]
192    HttpError(#[from] reqwest::Error),
193
194    /// An error occurred while parsing JSON (e.g., when serializing the input array
195    /// or deserializing the LLM response).
196    #[error("JSON parsing failed: {0}")]
197    JsonError(#[from] serde_json::Error),
198
199    /// The LLM API returned an error status code.
200    ///
201    /// This error includes the HTTP status code and the server's response body.
202    #[error("LLM API error: {0}")]
203    ApiError(String),
204
205    /// The LLM API response is missing required fields or has an invalid structure.
206    ///
207    /// This typically means the response doesn't contain a `choices` array or
208    /// the first choice doesn't have a `message` field.
209    #[error("Invalid response format from LLM")]
210    InvalidResponse,
211
212    /// The LLM returned content that cannot be parsed as a JSON array.
213    ///
214    /// This error includes the parsing error details and the actual content
215    /// returned by the LLM, which helps diagnose why the parsing failed.
216    #[error("Failed to parse LLM response as sorted array. LLM returned: {0}")]
217    ParseError(String),
218}
219
220/// OpenAI API request/response structures
221#[derive(Debug, Serialize)]
222struct ChatRequest<'a> {
223    model: &'a str,
224    messages: Vec<ChatMessage<'a>>,
225    temperature: f32,
226}
227
228#[derive(Debug, Serialize)]
229struct ChatMessage<'a> {
230    role: &'a str,
231    content: &'a str,
232}
233
234#[derive(Debug, Deserialize)]
235struct ChatMessageResponse {
236    content: String,
237}
238
239#[derive(Debug, Deserialize)]
240struct ChatResponse {
241    choices: Vec<Choice>,
242}
243
244#[derive(Debug, Deserialize)]
245struct Choice {
246    message: ChatMessageResponse,
247}
248
/// Client for sorting arrays using LLM APIs.
///
/// This struct holds the configuration needed to communicate with an LLM API
/// and provides methods to sort arrays.
///
/// Its `Debug` output redacts `api_key`, so the secret does not end up in logs
/// or panic messages that format the client with `{:?}`.
///
/// # Example
///
/// ```no_run
/// use vibesort_rs::Vibesort;
///
/// let sorter = Vibesort::new(
///     "sk-...",
///     "gpt-3.5-turbo",
///     "https://api.openai.com/v1",
/// );
/// ```
#[derive(Clone)]
pub struct Vibesort<'a> {
    /// The API key for authenticating with the LLM service.
    pub api_key: &'a str,

    /// The model identifier to use (e.g., "gpt-3.5-turbo", "gpt-4").
    pub model: &'a str,

    /// The base URL of the LLM API endpoint (e.g., "https://api.openai.com/v1").
    pub base_url: &'a str,
}

impl std::fmt::Debug for Vibesort<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never print the credential. Model and endpoint are not secret and are
        // useful when debugging.
        f.debug_struct("Vibesort")
            .field("api_key", &"<redacted>")
            .field("model", &self.model)
            .field("base_url", &self.base_url)
            .finish()
    }
}

277impl<'a> Vibesort<'a> {
278    /// Creates a new `Vibesort` instance.
279    ///
280    /// # Arguments
281    ///
282    /// * `api_key` - The API key for authenticating with the LLM service
283    /// * `model` - The model identifier to use (e.g., "gpt-3.5-turbo", "gpt-4")
284    /// * `base_url` - The base URL of the LLM API endpoint
285    ///
286    /// # Example
287    ///
288    /// ```no_run
289    /// use vibesort_rs::Vibesort;
290    ///
291    /// let sorter = Vibesort::new(
292    ///     "sk-1234567890abcdef",
293    ///     "gpt-3.5-turbo",
294    ///     "https://api.openai.com/v1",
295    /// );
296    /// ```
297    pub fn new(api_key: &'a str, model: &'a str, base_url: &'a str) -> Self {
298        Self {
299            api_key,
300            model,
301            base_url,
302        }
303    }
304
305    /// Sorts an array using an LLM.
306    ///
307    /// This method sends the input array to the configured LLM API and requests
308    /// it to sort the elements. The LLM is instructed to return only a JSON array
309    /// with the sorted elements, which is then parsed and returned.
310    ///
311    /// # Arguments
312    ///
313    /// * `items` - A slice of items to sort. Each item must implement:
314    ///   - `Display` - For error messages
315    ///   - `Serialize` - For serializing to JSON
316    ///   - `DeserializeOwned` - For deserializing the sorted result
317    ///
318    /// # Returns
319    ///
320    /// Returns `Ok(Vec<T>)` with the sorted array if successful, or an error
321    /// if the API call fails, the response is invalid, or parsing fails.
322    ///
323    /// # Errors
324    ///
325    /// This method can return various errors:
326    /// - [`VibesortError::HttpError`] - Network or HTTP request errors
327    /// - [`VibesortError::ApiError`] - API returned an error status code
328    /// - [`VibesortError::InvalidResponse`] - Response format is invalid
329    /// - [`VibesortError::ParseError`] - LLM response cannot be parsed as a JSON array
330    /// - [`VibesortError::JsonError`] - JSON serialization/deserialization errors
331    ///
332    /// # Examples
333    ///
334    /// ## Sorting numbers
335    ///
336    /// ```no_run
337    /// use vibesort_rs::Vibesort;
338    ///
339    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
340    /// let sorter = Vibesort::new(
341    ///     "your-api-key",
342    ///     "gpt-3.5-turbo",
343    ///     "https://api.openai.com/v1",
344    /// );
345    ///
346    /// let numbers = vec![3, 1, 4, 1, 5, 9, 2, 6];
347    /// let sorted = sorter.sort(&numbers).await?;
348    /// assert_eq!(sorted, vec![1, 1, 2, 3, 4, 5, 6, 9]);
349    /// # Ok(())
350    /// # }
351    /// ```
352    ///
353    /// ## Sorting strings
354    ///
355    /// ```no_run
356    /// use vibesort_rs::Vibesort;
357    ///
358    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
359    /// let sorter = Vibesort::new(
360    ///     "your-api-key",
361    ///     "gpt-3.5-turbo",
362    ///     "https://api.openai.com/v1",
363    /// );
364    ///
365    /// let words: Vec<String> = vec!["banana", "apple", "cherry"]
366    ///     .into_iter()
367    ///     .map(|s| s.to_string())
368    ///     .collect();
369    /// let sorted = sorter.sort(&words).await?;
370    /// # Ok(())
371    /// # }
372    /// ```
373    ///
374    /// ## Error handling
375    ///
376    /// ```no_run
377    /// use vibesort_rs::{Vibesort, VibesortError};
378    ///
379    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
380    /// let sorter = Vibesort::new(
381    ///     "invalid-key",
382    ///     "gpt-3.5-turbo",
383    ///     "https://api.openai.com/v1",
384    /// );
385    ///
386    /// match sorter.sort(&vec![1, 2, 3]).await {
387    ///     Ok(sorted) => println!("Sorted: {:?}", sorted),
388    ///     Err(VibesortError::ApiError(msg)) => eprintln!("API error: {}", msg),
389    ///     Err(e) => eprintln!("Other error: {}", e),
390    /// }
391    /// # Ok(())
392    /// # }
393    /// ```
394    pub async fn sort<T>(&self, items: &[T]) -> Result<Vec<T>, VibesortError>
395    where
396        T: Display + Serialize + DeserializeOwned,
397    {
398        // Serialize the input array to JSON
399        let json_array = serde_json::to_string(items)?;
400
401        // Build the API URL
402        let url = format!("{}/chat/completions", self.base_url);
403
404        // Create the HTTP client
405        let client = reqwest::Client::new();
406
407        // Prepare the request with system prompt and user prompt
408        let system_prompt = "You are a helpful assistant that sorts arrays. Sort the following JSON array with ascending order and return ONLY the sorted JSON array, nothing else.";
409        let request = ChatRequest {
410            model: self.model,
411            messages: vec![
412                ChatMessage {
413                    role: "system",
414                    content: system_prompt,
415                },
416                ChatMessage {
417                    role: "user",
418                    content: &json_array,
419                },
420            ],
421            temperature: 0.0, // Use 0.0 for deterministic sorting
422        };
423
424        // Send the request
425        let response = client
426            .post(&url)
427            .header("Authorization", format!("Bearer {}", self.api_key))
428            .header("Content-Type", "application/json")
429            .json(&request)
430            .send()
431            .await?;
432
433        // Check if the request was successful
434        let status = response.status();
435        if !status.is_success() {
436            let error_text = response.text().await.unwrap_or_default();
437            return Err(VibesortError::ApiError(format!(
438                "API returned status {}\nServer response: {}",
439                status, error_text
440            )));
441        }
442
443        // Parse the response
444        let chat_response: ChatResponse = response.json().await?;
445
446        // Extract the sorted array from the LLM's response
447        let mut sorted_json = chat_response
448            .choices
449            .first()
450            .ok_or(VibesortError::InvalidResponse)?
451            .message
452            .content
453            .trim()
454            .to_string();
455
456        // Strip markdown code blocks if present (e.g., ```json ... ```)
457        if sorted_json.starts_with("```") {
458            // Remove the opening ``` and optional language identifier
459            if let Some(start_idx) = sorted_json.find('\n') {
460                sorted_json = sorted_json[start_idx + 1..].to_string();
461            } else {
462                // No newline, just remove the ```
463                sorted_json = sorted_json[3..].to_string();
464            }
465            // Remove the closing ```
466            if sorted_json.ends_with("```") {
467                sorted_json = sorted_json[..sorted_json.len() - 3].trim().to_string();
468            }
469        }
470
471        // Parse the JSON array back to Vec<T>
472        let sorted: Vec<T> = serde_json::from_str(&sorted_json).map_err(|e| {
473            VibesortError::ParseError(format!(
474                "Failed to parse as JSON array: {}\nLLM returned: {}",
475                e, sorted_json
476            ))
477        })?;
478
479        Ok(sorted)
480    }
481
482    /// Sorts an array of strings using an LLM.
483    ///
484    /// This is a convenience method specifically for sorting string arrays.
485    /// It accepts a slice of string references and returns a vector of owned strings.
486    ///
487    /// # Arguments
488    ///
489    /// * `items` - A slice of string references to sort
490    ///
491    /// # Returns
492    ///
493    /// Returns `Ok(Vec<String>)` with the sorted array if successful, or an error
494    /// if the API call fails, the response is invalid, or parsing fails.
495    ///
496    /// # Errors
497    ///
498    /// This method can return the same errors as [`sort`](Self::sort):
499    /// - [`VibesortError::HttpError`] - Network or HTTP request errors
500    /// - [`VibesortError::ApiError`] - API returned an error status code
501    /// - [`VibesortError::InvalidResponse`] - Response format is invalid
502    /// - [`VibesortError::ParseError`] - LLM response cannot be parsed as a JSON array
503    /// - [`VibesortError::JsonError`] - JSON serialization/deserialization errors
504    ///
505    /// # Examples
506    ///
507    /// ```no_run
508    /// use vibesort_rs::Vibesort;
509    ///
510    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
511    /// let sorter = Vibesort::new(
512    ///     "your-api-key",
513    ///     "gpt-3.5-turbo",
514    ///     "https://api.openai.com/v1",
515    /// );
516    ///
517    /// let words = vec!["banana", "apple", "cherry"];
518    /// let sorted = sorter.sort_str(&words).await?;
519    /// assert_eq!(sorted, vec!["apple", "banana", "cherry"]);
520    /// # Ok(())
521    /// # }
522    /// ```
523    pub async fn sort_str(&self, items: &[&str]) -> Result<Vec<String>, VibesortError> {
524        // Convert &[&str] to Vec<String> for serialization
525        let string_vec: Vec<String> = items.iter().map(|s| s.to_string()).collect();
526        self.sort(&string_vec).await
527    }
528}