tokio_openai/
lib.rs

#![allow(clippy::multiple_crate_versions)]
//! API for `OpenAI`

extern crate core;

use core::fmt;
use std::{
    fmt::{Display, Formatter},
    future::Future,
};

use anyhow::{bail, Context};
use derive_build::Build;
use derive_more::Constructor;
pub use ext::OpenAiStreamExt;
use futures_util::{Stream, StreamExt, TryStreamExt};
pub use reqwest;
use reqwest::Response;
use schemars::JsonSchema;
use serde::{
    de,
    de::{DeserializeOwned, Visitor},
    Deserialize, Deserializer, Serialize,
};
use serde_json::Value;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;

use crate::util::schema;

mod ext;
mod util;

struct StringOrStruct;

impl<'de> Visitor<'de> for StringOrStruct {
    type Value = Option<Value>;

    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        formatter.write_str("string or structure")
    }

    fn visit_str<E: de::Error>(self, value: &str) -> Result<Self::Value, E> {
        match serde_json::from_str(value) {
            Ok(val) => Ok(Some(val)),
            Err(_) => Err(E::custom("expected valid JSON in string format")),
        }
    }

    fn visit_map<M>(self, visitor: M) -> Result<Self::Value, M::Error>
    where
        M: de::MapAccess<'de>,
    {
        let val = Value::deserialize(de::value::MapAccessDeserializer::new(visitor))?;
        Ok(Some(val))
    }
}

/// Deserializes a field that the API may send either as a JSON object or as a
/// JSON-encoded string
fn deserialize_arguments<'de, D>(deserializer: D) -> Result<Option<Value>, D::Error>
where
    D: Deserializer<'de>,
{
    deserializer.deserialize_any(StringOrStruct)
}

/// Grab the `OpenAI` key from the environment
///
/// # Errors
/// Will return `Err` if the variable `OPENAI_API_KEY` is not set
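///
/// A minimal sketch:
///
/// ```no_run
/// let key = tokio_openai::openai_key().expect("OPENAI_API_KEY not set");
/// ```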
#[inline]
pub fn openai_key() -> anyhow::Result<String> {
    std::env::var("OPENAI_API_KEY")
        .context("no OpenAI key specified. Set the variable OPENAI_API_KEY")
}

/// The `OpenAI` client
#[derive(Clone)]
pub struct Client {
    client: reqwest::Client,
    api_key: String,
}

impl Client {
    /// Create a new [`Client`] from a [`reqwest::Client`] and an API key
    #[must_use]
    pub fn new(client: reqwest::Client, api_key: impl Into<String>) -> Self {
        let api_key = api_key.into();
        Self { client, api_key }
    }

    /// # Errors
    /// Will return `Err` if no `OpenAI` key is defined
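    ///
    /// A minimal sketch (assumes `OPENAI_API_KEY` is set):
    ///
    /// ```no_run
    /// let client = tokio_openai::Client::simple().expect("could not create client");
    /// ```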
    pub fn simple() -> anyhow::Result<Self> {
        let key = openai_key()?;
        Ok(Self::new(reqwest::Client::default(), key))
    }
}

/// ```json
/// {"model": "text-davinci-003", "prompt": "Say this is a test", "temperature": 0, "max_tokens": 7}
/// ```
#[derive(Clone, Serialize)]
pub struct TextRequest<'a> {
    pub model: Completions,
    pub prompt: &'a str,
    pub temperature: f64,

    /// Up to 4 sequences where the API will stop generating further tokens. The returned text will
    /// not contain the stop sequence.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub stop: Vec<&'a str>,

    /// Number of completions to generate
    pub n: Option<usize>,
    pub max_tokens: usize,
}

impl Default for TextRequest<'_> {
    fn default() -> Self {
        Self {
            model: Completions::Davinci,
            prompt: "",
            temperature: 0.0,
            stop: Vec::new(),
            n: None,
            max_tokens: 1_000,
        }
    }
}

/// ```json
/// {"input": "Your text string goes here", "model": "text-embedding-ada-002"}
/// ```
#[derive(Copy, Clone, Serialize, Deserialize)]
struct EmbedRequest<'a> {
    input: &'a str,
    model: &'a str,
}

#[derive(Clone, Serialize, Deserialize)]
struct TextResponseChoice {
    text: String,
}

#[derive(Clone, Serialize, Deserialize)]
struct TextResponse {
    choices: Vec<TextResponseChoice>,
}

#[derive(Clone, Serialize, Deserialize)]
struct EmbedDataFrame {
    embedding: Vec<f32>,
}

#[derive(Clone, Serialize, Deserialize)]
struct EmbedResponse {
    data: Vec<EmbedDataFrame>,
}

#[derive(Serialize, Deserialize)]
struct DavinciiData<'a> {
    model: &'a str,
    prompt: &'a str,
    temperature: f64,
    max_tokens: usize,
}

/// The text model we are using. See <https://openai.com/api/pricing/>
#[derive(Copy, Clone, Default, PartialEq, Eq, Debug)]
pub enum Model {
    /// The Davinci model
    #[default]
    Davinci,
    /// The Curie model
    Curie,
    /// The Babbage model
    Babbage,
    /// The Ada model
    Ada,
}

#[derive(Serialize, Deserialize, Default, Debug, PartialEq, Eq, Copy, Clone)]
pub enum ChatModel {
    #[serde(rename = "gpt-4-turbo-preview")]
    #[default]
    Gpt4TurboPreview,

    #[serde(rename = "gpt-4-1106-preview")]
    Gpt4_1106,

    #[serde(rename = "gpt-4-0613")]
    Gpt4_0613,

    #[serde(rename = "gpt-4")]
    Gpt4,

    #[serde(rename = "gpt-3.5-turbo")]
    Turbo,

    #[serde(rename = "gpt-3.5-turbo-0301")]
    Turbo0301,
}

/// ```json
/// {"role": "system", "content": "You are a helpful assistant."},
/// {"role": "user", "content": "Who won the world series in 2020?"},
/// {"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
/// {"role": "user", "content": "Where was it played?"}
/// ```
#[derive(Serialize, Deserialize, Debug, Copy, Clone, PartialOrd, PartialEq, Ord, Eq)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum Role {
    System,
    User,
    Assistant,
    Function,
}

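/// A single chat message.
///
/// A small sketch of the convenience constructors:
///
/// ```
/// use tokio_openai::Msg;
///
/// let sys = Msg::system("You are a helpful assistant.");
/// let user = Msg::user("Who won the world series in 2020?");
/// assert_eq!(format!("{sys}"), "You are a helpful assistant.");
/// assert_eq!(format!("{user}"), "Who won the world series in 2020?");
/// ```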
#[derive(Serialize, Deserialize, Debug, Clone, Constructor)]
pub struct Msg {
    /// The role of the author of this message
    pub role: Role,
    pub content: Option<String>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub function_call: Option<FunctionCall>,
}

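/// A function call requested by the model.
///
/// The API may send `arguments` either as a JSON object or as a JSON-encoded
/// string; both forms deserialize to the same `Option<Value>`. A small sketch:
///
/// ```
/// use tokio_openai::FunctionCall;
///
/// let call: FunctionCall =
///     serde_json::from_str(r#"{"name":"f","arguments":"{\"x\":1}"}"#).unwrap();
/// assert_eq!(call.arguments.unwrap()["x"], 1);
/// ```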
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct FunctionCall {
    pub name: String,

    #[serde(deserialize_with = "deserialize_arguments")]
    pub arguments: Option<Value>,
}

impl FunctionCall {
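    /// Deserializes the call's `arguments` into a typed value.
    ///
    /// A small sketch with an assumed argument shape:
    ///
    /// ```
    /// use serde::Deserialize;
    /// use tokio_openai::FunctionCall;
    ///
    /// #[derive(Deserialize)]
    /// struct Args {
    ///     x: i64,
    /// }
    ///
    /// let call: FunctionCall =
    ///     serde_json::from_str(r#"{"name":"f","arguments":{"x":1}}"#).unwrap();
    /// let args: Args = call.into_struct().unwrap();
    /// assert_eq!(args.x, 1);
    /// ```
    ///
    /// # Errors
    /// Returns `Err` if `arguments` is missing or does not match `T`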
    pub fn into_struct<T: DeserializeOwned>(self) -> anyhow::Result<T> {
        let args = self.arguments.context("no arguments")?;
        let res = serde_json::from_value(args).context("failed to deserialize arguments")?;
        Ok(res)
    }
}

impl Default for Msg {
    fn default() -> Self {
        Self::system("")
    }
}

impl Msg {
    pub fn system(content: impl Into<String>) -> Self {
        Self::new(Role::System, Some(content.into()), None, None)
    }

    pub fn user(content: impl Into<String>) -> Self {
        Self::new(Role::User, Some(content.into()), None, None)
    }

    pub fn assistant(content: impl Into<String>) -> Self {
        Self::new(Role::Assistant, Some(content.into()), None, None)
    }

    pub fn function(name: impl Into<String>, content: impl Serialize) -> anyhow::Result<Self> {
        let name = name.into();
        let content = serde_json::to_value(content)?;
        let content = serde_json::to_string(&content)?;

        Ok(Self::new(Role::Function, Some(content), Some(name), None))
    }
}

/// A single part of a streamed chat delta
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "snake_case")]
pub enum Delta {
    /// The role of the message author (typically sent in the first chunk)
    Role(Role),
    Content(String),
}

impl Display for Msg {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match &self.content {
            None => f.write_str(""),
            Some(content) => f.write_str(content),
        }
    }
}

#[allow(clippy::trivially_copy_pass_by_ref)]
fn real_is_one(input: &f64) -> bool {
    (*input - 1.0).abs() < f64::EPSILON
}

#[allow(clippy::trivially_copy_pass_by_ref)]
const fn int_is_one(input: &u32) -> bool {
    *input == 1
}

#[allow(clippy::trivially_copy_pass_by_ref)]
const fn int_is_zero(input: &u32) -> bool {
    *input == 0
}

const fn empty<T>(input: &[T]) -> bool {
    input.is_empty()
}

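/// A request to the chat completions endpoint.
///
/// A builder-style sketch (the chaining setters come from the `Build` derive;
/// this assumes the derive generates one setter per field):
///
/// ```
/// use tokio_openai::{ChatModel, ChatRequest, Msg};
///
/// let req = ChatRequest::default()
///     .model(ChatModel::Turbo)
///     .temperature(0.2)
///     .message(Msg::system("You are terse."))
///     .message(Msg::user("Hello"));
/// assert_eq!(req.messages.len(), 2);
/// ```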
#[derive(Debug, Build, Serialize, Clone)]
pub struct ChatRequest {
    pub model: ChatModel,
    pub messages: Vec<Msg>,

    /// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the
    /// output more random, while lower values like 0.2 will make it more focused and
    /// deterministic.
    ///
    /// OpenAI generally recommends altering this or `top_p` but not both.
    #[serde(skip_serializing_if = "real_is_one")]
    #[default = 1.0]
    pub temperature: f64,

    /// An alternative to sampling with temperature, called nucleus sampling, where the model
    /// considers the results of the tokens with `top_p` probability mass. So 0.1 means only the
    /// tokens comprising the top 10% probability mass are considered.
    ///
    /// OpenAI generally recommends altering this or temperature but not both.
    #[serde(skip_serializing_if = "real_is_one")]
    #[default = 1.0]
    pub top_p: f64,

    /// How many chat completion choices to generate for each input message.
    #[serde(skip_serializing_if = "int_is_one")]
    #[default = 1]
    pub n: u32,

    #[serde(skip_serializing_if = "empty", rename = "stop")]
    pub stop_at: Vec<String>,

    /// Maximum number of tokens to generate
    ///
    /// If 0, no limit is applied
    #[serde(skip_serializing_if = "int_is_zero")]
    pub max_tokens: u32,

    #[serde(skip_serializing_if = "empty")]
    pub functions: Vec<Function>,
}

impl ChatRequest {
    #[must_use]
    pub fn sys_msg(mut self, msg: impl Into<String>) -> Self {
        self.messages.push(Msg::system(msg));
        self
    }

    #[must_use]
    pub fn user_msg(mut self, msg: impl Into<String>) -> Self {
        self.messages.push(Msg::user(msg));
        self
    }

    #[must_use]
    pub fn assistant_msg(mut self, msg: impl Into<String>) -> Self {
        self.messages.push(Msg::assistant(msg));
        self
    }
}

impl Default for ChatRequest {
    fn default() -> Self {
        Self::new()
    }
}

impl<'a> From<&'a str> for ChatRequest {
    fn from(input: &'a str) -> Self {
        Self {
            messages: vec![Msg::user(input)],
            ..Self::default()
        }
    }
}

impl<'a> From<&'a String> for ChatRequest {
    fn from(input: &'a String) -> Self {
        Self::from(input.as_str())
    }
}

impl<'a> From<&'a [Msg]> for ChatRequest {
    fn from(input: &'a [Msg]) -> Self {
        Self {
            messages: input.to_vec(),
            ..Self::default()
        }
    }
}

impl<const N: usize> From<[Msg; N]> for ChatRequest {
    fn from(input: [Msg; N]) -> Self {
        Self {
            messages: input.to_vec(),
            ..Self::default()
        }
    }
}

#[derive(Serialize, Deserialize, Debug)]
pub struct ChatChoice {
    pub message: Msg,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Function {
    pub name: String,
    pub description: Option<String>,
    pub parameters: Option<Value>,
}

impl Function {
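    /// Creates a function definition whose `parameters` field is the JSON
    /// Schema derived from `Input`.
    ///
    /// A sketch (assumes the default `schemars` derive feature; `Input` is
    /// selected with a turbofish, the string parameters are inferred):
    ///
    /// ```
    /// use schemars::JsonSchema;
    /// use tokio_openai::Function;
    ///
    /// #[derive(JsonSchema)]
    /// struct Weather {
    ///     lat: f64,
    ///     lon: f64,
    /// }
    ///
    /// let f = Function::new::<Weather, _, _>("weather", "Get the weather for a location");
    /// assert_eq!(f.name, "weather");
    /// assert!(f.parameters.is_some());
    /// ```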
    // named generics (rather than `impl Into<String>`) so that `Input` can be
    // supplied explicitly; turbofish is rejected when `impl Trait` appears in
    // argument position
    pub fn new<Input, Name, Desc>(name: Name, description: Desc) -> Self
    where
        Input: JsonSchema,
        Name: Into<String>,
        Desc: Into<String>,
    {
        let schema = schema::<Input>();
        Self {
            name: name.into(),
            description: Some(description.into()),
            parameters: Some(schema),
        }
    }
}

#[derive(Serialize, Deserialize, Debug)]
pub struct ChatResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub choices: Vec<ChatChoice>,
}

impl ChatResponse {
    pub fn take_first(self) -> Option<ChatChoice> {
        self.choices.into_iter().next()
    }
}

/// The completions model we are using. See <https://openai.com/api/pricing/>
#[derive(Deserialize, Serialize, Copy, Clone, Default, Eq, PartialEq, Debug)]
#[allow(unused)]
pub enum Completions {
    /// The Davinci model
    #[serde(rename = "text-davinci-003")]
    #[default]
    Davinci,

    /// The Curie model
    #[serde(rename = "text-curie-001")]
    Curie,

    /// The Babbage model
    #[serde(rename = "text-babbage-001")]
    Babbage,

    /// The Ada model
    #[serde(rename = "text-ada-001")]
    Ada,
}

impl Model {
    const fn embed_repr(self) -> Option<&'static str> {
        match self {
            Self::Davinci | Self::Curie | Self::Babbage => None,
            Self::Ada => Some("text-embedding-ada-002"),
        }
    }

    #[allow(unused)]
    const fn text_repr(self) -> &'static str {
        match self {
            Self::Davinci => "text-davinci-003",
            Self::Curie => "text-curie-001",
            Self::Babbage => "text-babbage-001",
            Self::Ada => "text-ada-001",
        }
    }
}

impl Client {
    fn request(
        &self,
        url: &str,
        request: &impl Serialize,
    ) -> impl Future<Output = reqwest::Result<Response>> {
        self.client
            .post(url)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .json(request)
            .send()
    }

    /// Calls the embedding API
    ///
    /// - turns an `input` [`str`] into a vector
    ///
    /// # Errors
    /// Returns `Err` if there is a network error communicating with `OpenAI`
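    ///
    /// A minimal sketch (performs a network call; assumes tokio's `macros`
    /// feature for the runtime attribute):
    ///
    /// ```no_run
    /// # #[tokio::main]
    /// # async fn main() -> anyhow::Result<()> {
    /// let client = tokio_openai::Client::simple()?;
    /// let embedding: Vec<f32> = client.embed("hello world").await?;
    /// assert!(!embedding.is_empty());
    /// # Ok(())
    /// # }
    /// ```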
    pub async fn embed(&self, input: &str) -> anyhow::Result<Vec<f32>> {
        let request = EmbedRequest {
            input,
            // SAFETY: `embed_repr` always returns `Some` for `Model::Ada`
            model: unsafe { Model::Ada.embed_repr().unwrap_unchecked() },
        };

        let embed: EmbedResponse = self
            .request("https://api.openai.com/v1/embeddings", &request)
            .await
            .context("could not complete embed request")?
            .json()
            .await?;

        let result = embed
            .data
            .into_iter()
            .next()
            .context("no data for embedding")?
            .embedding;

        Ok(result)
    }

    /// # Errors
    /// Returns `Err` if there is a network error communicating with `OpenAI`
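    ///
    /// A sketch that inspects the first returned choice:
    ///
    /// ```no_run
    /// # #[tokio::main]
    /// # async fn main() -> anyhow::Result<()> {
    /// use tokio_openai::{ChatRequest, Msg};
    ///
    /// let client = tokio_openai::Client::simple()?;
    /// let req = ChatRequest::from([Msg::system("Be terse."), Msg::user("Hi")]);
    /// let res = client.raw_chat(&req).await?;
    /// if let Some(choice) = res.take_first() {
    ///     println!("{}", choice.message);
    /// }
    /// # Ok(())
    /// # }
    /// ```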
    pub async fn raw_chat(&self, req: &ChatRequest) -> anyhow::Result<ChatResponse> {
        let response: String = self
            .request("https://api.openai.com/v1/chat/completions", req)
            .await
            .context("could not complete chat request")?
            .text()
            .await?;

        let response = match serde_json::from_str(&response) {
            Ok(response) => response,
            Err(e) => {
                return Err(anyhow::anyhow!(
                    "could not parse chat response {response}: {e}"
                ));
            }
        };

        Ok(response)
    }

    /// # Errors
    /// Returns `Err` if there is a network error communicating with `OpenAI`
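    ///
    /// A minimal sketch using the `From<&str>` conversion:
    ///
    /// ```no_run
    /// # #[tokio::main]
    /// # async fn main() -> anyhow::Result<()> {
    /// let client = tokio_openai::Client::simple()?;
    /// let reply = client.chat("What is 2 + 2?").await?;
    /// println!("{reply}");
    /// # Ok(())
    /// # }
    /// ```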
    pub async fn chat(&self, req: impl Into<ChatRequest> + Send) -> anyhow::Result<String> {
        let req = req.into();
        let response = self.raw_chat(&req).await?;
        let choice = response
            .choices
            .into_iter()
            .next()
            .context("no choices for chat")?;

        choice.message.content.context("no content for chat")
    }

    /// # Errors
    /// Returns `Err` if there is a network error communicating with `OpenAI`
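    ///
    /// A consumption sketch (assumes `futures_util::StreamExt` for `next`):
    ///
    /// ```no_run
    /// # #[tokio::main]
    /// # async fn main() -> anyhow::Result<()> {
    /// use futures_util::StreamExt;
    ///
    /// let client = tokio_openai::Client::simple()?;
    /// let req = tokio_openai::TextRequest {
    ///     prompt: "Say this is a test",
    ///     ..Default::default()
    /// };
    /// let mut stream = client.stream_text(req).await?;
    /// while let Some(chunk) = stream.next().await {
    ///     print!("{}", chunk?);
    /// }
    /// # Ok(())
    /// # }
    /// ```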
    pub async fn stream_text(
        &self,
        req: TextRequest<'_>,
    ) -> anyhow::Result<impl Stream<Item = anyhow::Result<String>>> {
        #[derive(Clone, Serialize)]
        struct TextStreamRequest<'a> {
            stream: bool,

            #[serde(flatten)]
            req: TextRequest<'a>,
        }

        #[derive(Deserialize, Debug)]
        struct TextStreamData {
            text: Option<String>,
        }

        #[derive(Deserialize, Debug)]
        struct TextStreamResponse {
            choices: Vec<TextStreamData>,
        }

        let req = TextStreamRequest { stream: true, req };

        let response = self
            .request("https://api.openai.com/v1/completions", &req)
            .await
            .context("could not complete text request")?;

        let stream = response
            .bytes_stream()
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
            .into_async_read();

        let mut messages = event_stream_processor::get_messages(stream);

        let (tx, rx) = mpsc::channel(100);

        fn message_to_data(
            message: anyhow::Result<event_stream_processor::Message>,
        ) -> anyhow::Result<Option<String>> {
            let message = message?;
            let data = message.data.context("no data")?;

            if data == "[DONE]" {
                return Ok(None);
            }

            let Ok(data) = serde_json::from_str::<TextStreamResponse>(&data) else {
                return Ok(None);
            };

            let choice = data.choices.into_iter().next().context("no choices")?;

            let Some(content) = choice.text else {
                return Ok(Some(String::new()));
            };

            Ok(Some(content))
        }

        tokio::spawn(async move {
            while let Some(msg) = messages.next().await {
                let msg = message_to_data(msg);
                match msg {
                    Ok(None) => {
                        return;
                    }
                    Ok(Some(msg)) => {
                        if tx.send(Ok(msg)).await.is_err() {
                            return;
                        }
                    }
                    Err(e) => {
                        if tx.send(Err(e)).await.is_err() {
                            return;
                        }
                    }
                }
            }
        });

        Ok(ReceiverStream::from(rx))
    }

    /// # Errors
    /// Returns `Err` if there is a network error communicating with `OpenAI`
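    ///
    /// A consumption sketch:
    ///
    /// ```no_run
    /// # #[tokio::main]
    /// # async fn main() -> anyhow::Result<()> {
    /// use futures_util::StreamExt;
    ///
    /// let client = tokio_openai::Client::simple()?;
    /// let mut stream = client.stream_chat("Tell me a story").await?;
    /// while let Some(chunk) = stream.next().await {
    ///     print!("{}", chunk?);
    /// }
    /// # Ok(())
    /// # }
    /// ```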
    pub async fn stream_chat(
        &self,
        req: impl Into<ChatRequest> + Send,
    ) -> anyhow::Result<impl Stream<Item = anyhow::Result<String>>> {
        #[derive(Serialize)]
        struct ChatStreamRequest {
            stream: bool,

            #[serde(flatten)]
            req: ChatRequest,
        }

        #[derive(Serialize, Deserialize, Debug, Clone)]
        struct ChatStreamMessage {
            delta: Delta,
        }

        #[derive(Serialize, Deserialize, Debug, Clone)]
        struct ChatStreamResponse {
            choices: Vec<ChatStreamMessage>,
        }

        let req = req.into();

        let req = ChatStreamRequest { stream: true, req };

        let response = self
            .request("https://api.openai.com/v1/chat/completions", &req)
            .await
            .context("could not complete chat request")?;

        let stream = response
            .bytes_stream()
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
            .into_async_read();

        let mut messages = event_stream_processor::get_messages(stream);

        let (tx, rx) = mpsc::channel(100);

        fn message_to_data(
            message: anyhow::Result<event_stream_processor::Message>,
        ) -> anyhow::Result<Option<String>> {
            let message = message?;
            let data = message.data.context("no data")?;

            if data == "[DONE]" {
                return Ok(None);
            }

            let Ok(data) = serde_json::from_str::<ChatStreamResponse>(&data) else {
                return Ok(None);
            };

            let choice = data.choices.into_iter().next().context("no choices")?;

            let Delta::Content(content) = choice.delta else {
                return Ok(Some(String::new()));
            };

            Ok(Some(content))
        }

        tokio::spawn(async move {
            while let Some(msg) = messages.next().await {
                let msg = message_to_data(msg);
                match msg {
                    Ok(None) => {
                        return;
                    }
                    Ok(Some(msg)) => {
                        if tx.send(Ok(msg)).await.is_err() {
                            return;
                        }
                    }
                    Err(e) => {
                        if tx.send(Err(e)).await.is_err() {
                            return;
                        }
                    }
                }
            }
        });

        Ok(ReceiverStream::from(rx))
    }

    /// # Errors
    /// Will return `Err` if it cannot properly contact the `OpenAI` API.
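    ///
    /// A minimal sketch:
    ///
    /// ```no_run
    /// # #[tokio::main]
    /// # async fn main() -> anyhow::Result<()> {
    /// let client = tokio_openai::Client::simple()?;
    /// let req = tokio_openai::TextRequest {
    ///     prompt: "Say this is a test",
    ///     ..Default::default()
    /// };
    /// let completions: Vec<String> = client.text(req).await?;
    /// # Ok(())
    /// # }
    /// ```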
    pub async fn text(&self, request: TextRequest<'_>) -> anyhow::Result<Vec<String>> {
        let text = self
            .request("https://api.openai.com/v1/completions", &request)
            .await
            .context("could not complete text request")?
            .text()
            .await
            .context("could not convert response into text")?;

        let json: TextResponse = match serde_json::from_str(&text) {
            Ok(res) => res,
            Err(e) => bail!("error {e} parsing json {text}"),
        };

        let choices = json.choices.into_iter().map(|e| e.text).collect();
        Ok(choices)
    }
}

#[cfg(test)]
mod tests {
    use approx::relative_eq;
    use once_cell::sync::Lazy;
    use pretty_assertions::assert_eq;
    use serde_json::json;

    use crate::{ChatChoice, ChatModel, ChatRequest, Completions, Function, Model, Msg, Role};

    static API: Lazy<crate::Client> =
        Lazy::new(|| crate::Client::simple().expect("could not create client"));

    #[tokio::test]
    async fn test_chat_raw() {
        let req = ChatRequest {
            model: ChatModel::Turbo,
            messages: vec![
                Msg {
                    role: Role::System,
                    content: Some(
                        "You are a helpful assistant that translates English to French."
                            .to_string(),
                    ),
                    ..Msg::default()
                },
                Msg {
                    role: Role::User,
                    content: Some(
                        "Translate the following English text to French: Hello".to_string(),
                    ),
                    ..Msg::default()
                },
            ],
            ..ChatRequest::default()
        };

        let choices = API.raw_chat(&req).await.unwrap().choices;

        let [ChatChoice { message }] = choices.as_slice() else {
            panic!("no choices");
        };

        // prune all non-alphanumeric characters
        let message = message
            .content
            .as_ref()
            .unwrap()
            .replace(|c: char| !c.is_ascii_alphanumeric(), "")
            .to_ascii_lowercase();

        assert!(message.contains("bonjour"));
    }

    #[tokio::test]
    async fn test_chat() {
        let request = ChatRequest {
            model: ChatModel::Turbo,
            messages: vec![
                Msg {
                    role: Role::System,
                    content: Some(
                        "You are a helpful assistant that translates English to French."
                            .to_string(),
                    ),
                    ..Msg::default()
                },
                Msg {
                    role: Role::User,
                    content: Some(
                        "Translate the following English text to French: Hello".to_string(),
                    ),
                    ..Msg::default()
                },
            ],
            ..ChatRequest::default()
        };

        let res = API.chat(request).await.unwrap();

        // prune all non-alphanumeric characters
        let choice = res
            .replace(|c: char| !c.is_ascii_alphanumeric(), "")
            .to_ascii_lowercase();

        assert!(choice.contains("bonjour"));
    }

    /// Test that default construction does not panic
    #[test]
    fn test_text_request() {
        crate::TextRequest::default();
    }

    #[test]
    fn test_message() {
        {
            let msg = Msg::system("hello");
            assert_eq!("hello", format!("{msg}"));
            let msg = serde_json::to_string(&msg).unwrap();
            assert_eq!(msg, r#"{"role":"system","content":"hello"}"#);
        }

        {
            let msg = Msg::user("hello");
            assert_eq!("hello", format!("{msg}"));
            let msg = serde_json::to_string(&msg).unwrap();
            assert_eq!(msg, r#"{"role":"user","content":"hello"}"#);
        }

        {
            let msg = Msg::assistant("hello");
            assert_eq!("hello", format!("{msg}"));
            let msg = serde_json::to_string(&msg).unwrap();
            assert_eq!(msg, r#"{"role":"assistant","content":"hello"}"#);
        }
    }

    #[test]
    fn test_chat_builder() {
        let req = ChatRequest::default()
            .model(ChatModel::Turbo)
            .temperature(1.2)
            .message(Msg::system("hello"))
            .message(Msg::user("hello"))
            .top_p(1.0)
            .n(3)
            .stop_at("\n")
            .stop_at("#####");

        assert_eq!(req.model, ChatModel::Turbo);
        assert!(relative_eq!(req.temperature, 1.2));
        assert_eq!(req.messages.len(), 2);
        assert!(relative_eq!(req.top_p, 1.0));
        assert_eq!(req.n, 3);
        assert_eq!(req.stop_at, vec!["\n", "#####"]);
    }

    #[test]
    fn test_chat_from() {
        let req = ChatRequest::from("hello");
        assert_eq!(req.messages.len(), 1);
        assert_eq!(req.messages[0].content, Some("hello".to_string()));
        assert_eq!(req.messages[0].role, Role::User);
        assert_eq!(req.n, 1);

        let req = ChatRequest::from(&"hello".to_string());
        assert_eq!(req.messages.len(), 1);
        assert_eq!(req.messages[0].content, Some("hello".to_string()));
        assert_eq!(req.messages[0].role, Role::User);
        assert_eq!(req.n, 1);

        let messages = [Msg::user("hello"), Msg::assistant("world")];
        let req = ChatRequest::from(messages.as_slice());
        assert_eq!(req.messages.len(), 2);
        assert_eq!(req.messages[0].content, Some("hello".to_string()));
        assert_eq!(req.messages[0].role, Role::User);
        assert_eq!(req.messages[1].content, Some("world".to_string()));
        assert_eq!(req.messages[1].role, Role::Assistant);
        assert_eq!(req.n, 1);

        let messages = [Msg::user("hello"), Msg::assistant("world")];
        let req = ChatRequest::from(messages);
        assert_eq!(req.messages.len(), 2);
        assert_eq!(req.messages[0].content, Some("hello".to_string()));
        assert_eq!(req.messages[0].role, Role::User);
        assert_eq!(req.messages[1].content, Some("world".to_string()));
        assert_eq!(req.messages[1].role, Role::Assistant);
        assert_eq!(req.n, 1);
    }

    #[test]
    fn test_completions() {
        let completion = Completions::default();
        assert_eq!(completion, Completions::Davinci);
    }

    #[test]
    fn test_chat_model() {
        let model = ChatModel::default();
        assert_eq!(model, ChatModel::Gpt4TurboPreview);
    }

    #[test]
    fn test_model() {
        let model = Model::default();
        assert_eq!(model, Model::Davinci);
        assert_eq!(model.embed_repr(), None);
        assert_eq!(model.text_repr(), "text-davinci-003");

        let model = Model::Curie;
        assert_eq!(model.embed_repr(), None);
        assert_eq!(model.text_repr(), "text-curie-001");

        let model = Model::Babbage;
        assert_eq!(model.embed_repr(), None);
        assert_eq!(model.text_repr(), "text-babbage-001");

        let model = Model::Ada;
        assert_eq!(model.embed_repr().unwrap(), "text-embedding-ada-002");
        assert_eq!(model.text_repr(), "text-ada-001");
    }

    #[tokio::test]
    async fn test_function() {
        let request = ChatRequest::new();

        let function = Function {
            name: "weather".to_string(),
            description: Some("Get the weather for a location".to_string()),
            parameters: Some(json!({
                "type": "object",
                "properties": {
                    "lat": { "type": "number" },
                    "lon": { "type": "number" },
                    "unit": { "type": "string", "enum": ["celsius", "fahrenheit"] },
                },
                "required": ["lat", "lon"],
            })),
        };

        let request = request
            .function(function)
            .user_msg("What's the weather like in Svalbard");

        println!("{}", serde_json::to_string_pretty(&request).unwrap());

        let response = API.raw_chat(&request).await.unwrap();

        let first_choice = response.choices.into_iter().next().unwrap();

        let msg = first_choice.message;

        let call = serde_json::to_string_pretty(&msg.function_call).unwrap();

        println!("call: {call}");
    }
}