// rustylms/chat.rs

use anyhow::Result;

use crate::{
    lmsserver::LMSServer,
    models::{ChatCompletionsRequest, ChatCompletionsResponse, Message},
};

8#[derive(Debug)]
9pub struct Chat {
10    model: String,
11    messages: Vec<Message>,
12    temperature: f32,
13    max_tokens: i32,
14}
15
16impl Chat {
17    /// Creates a new `Chat` with the selected model.
18    ///
19    /// # Example
20    ///
21    /// ```rust
22    /// use rustylms::{
23    ///     chat::Chat,
24    ///     lmsserver::LMSServer
25    /// };
26    ///
27    /// let server = LMSServer::new("http://localhost:1234");
28    /// let chat = Chat::new("model-name").system_prompt("You are a helpful assistant.").user_prompt("Why does iron rust?");
29    /// let completion = chat.get_completions(&server).await.expect("Could not get completions");
30    ///
31    /// println!("From assistant: {}", completion.get_message().unwrap().content);
32    /// ```
33    pub fn new<T>(model: T) -> Self
34    where
35        T: Into<String>,
36    {
37        Self {
38            model: model.into(),
39            messages: vec![],
40            temperature: 0.7,
41            max_tokens: -1,
42        }
43    }
44
45    /// Sets the temperature of the model. The default value for this is `0.7`.
46    pub fn temperature(mut self, temperature: f32) -> Self {
47        self.temperature = temperature;
48
49        self
50    }
51
52    /// Sets the maximum tokens a completion can generate. The default value is `-1` meaning no limit.
53    pub fn max_tokens(mut self, max_tokens: i32) -> Self {
54        self.max_tokens = max_tokens;
55
56        self
57    }
58
59    /// This function adds a system prompt to the messages.
60    ///
61    /// **NOTE:** This function doesn't clear the messages array before adding the system prompt!
62    pub fn system_prompt<T>(mut self, system_prompt: T) -> Self
63    where
64        T: Into<String>,
65    {
66        self.add_system_message(system_prompt);
67
68        self
69    }
70
71    /// This function adds a user prompt to the messages.
72    ///
73    /// **NOTE:** This function doesn't clear the messages array before adding the user prompt!
74    pub fn user_prompt<T>(mut self, user_prompt: T) -> Self
75    where
76        T: Into<String>,
77    {
78        self.add_user_message(user_prompt);
79
80        self
81    }
82
83    /// Clears all messages in the chat **including system prompts**.
84    pub fn clear_messages(&mut self) {
85        self.messages.clear()
86    }
87
88    /// Adds the provided `Message` to the chat.
89    pub fn add_message(&mut self, message: Message) {
90        self.messages.push(message)
91    }
92
93    /// Adds the provided message content as a system message.
94    pub fn add_system_message<T>(&mut self, message: T)
95    where
96        T: Into<String>,
97    {
98        self.add_message(Message::system(message))
99    }
100
101    /// Adds the provided message content as a message from the assistant.
102    pub fn add_assistant_message<T>(&mut self, message: T)
103    where
104        T: Into<String>,
105    {
106        self.add_message(Message::assistant(message))
107    }
108
109    /// Adds the provided message content as a message from the user.
110    pub fn add_user_message<T>(&mut self, message: T)
111    where
112        T: Into<String>,
113    {
114        self.add_message(Message::user(message))
115    }
116
117    /// Gets the completion from the server by sending the current `Chat` struct.
118    pub async fn get_completions(&self, server: &LMSServer) -> Result<ChatCompletionsResponse> {
119        let request = ChatCompletionsRequest {
120            max_tokens: self.max_tokens,
121            messages: self.messages.clone(),
122            model: self.model.clone(),
123            temperature: self.temperature,
124        };
125        let response = server.get_chat_completion(request).await?;
126
127        Ok(response)
128    }
129}