1#![allow(clippy::multiple_crate_versions)]
2extern crate core;
5
6use core::fmt;
7use std::{
8 fmt::{Display, Formatter},
9 future::Future,
10};
11
12use anyhow::{bail, Context};
13use derive_build::Build;
14use derive_more::Constructor;
15pub use ext::OpenAiStreamExt;
16use futures_util::{Stream, StreamExt, TryStreamExt};
17pub use reqwest;
18use reqwest::Response;
19use schemars::JsonSchema;
20use serde::{
21 de,
22 de::{DeserializeOwned, Visitor},
23 Deserialize, Deserializer, Serialize,
24};
25use serde_json::Value;
26use tokio::sync::mpsc;
27use tokio_stream::wrappers::ReceiverStream;
28
29use crate::util::schema;
30
31mod ext;
32mod util;
/// Serde visitor used to read function-call arguments that arrive either as a
/// JSON-encoded string or as an inline JSON map.
///
/// The wrapped `Option<Value>` is never read by the visitor; the only
/// construction site in this file passes `None`.
struct StringOrStruct(Option<Value>);
34
35impl<'de> Visitor<'de> for StringOrStruct {
36 type Value = Option<Value>;
37
38 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
39 formatter.write_str("string or structure")
40 }
41
42 fn visit_str<E: de::Error>(self, value: &str) -> Result<Self::Value, E> {
43 match serde_json::from_str(value) {
44 Ok(val) => Ok(Some(val)),
45 Err(_) => Err(E::custom("expected valid json in string format")),
46 }
47 }
48
49 fn visit_map<M>(self, visitor: M) -> Result<Self::Value, M::Error>
50 where
51 M: de::MapAccess<'de>,
52 {
53 let val = Value::deserialize(de::value::MapAccessDeserializer::new(visitor))?;
54 Ok(Some(val))
55 }
56}
57
58fn deserialize_arguments<'de, D>(deserializer: D) -> Result<Option<Value>, D::Error>
59where
60 D: Deserializer<'de>,
61{
62 deserializer.deserialize_any(StringOrStruct(None))
63}
64
65#[inline]
70pub fn openai_key() -> anyhow::Result<String> {
71 std::env::var("OPENAI_API_KEY")
72 .context("no OpenAI key specified. Set the variable OPENAI_API_KEY")
73}
74
/// Handle used to talk to the OpenAI HTTP API.
///
/// Bundles the `reqwest` client that sends requests with the API key that is
/// placed in the `Authorization` header of every request.
#[derive(Clone)]
pub struct Client {
    /// HTTP client used for every request.
    client: reqwest::Client,
    /// Secret sent as a bearer token.
    api_key: String,
}
81
82impl Client {
83 #[must_use]
85 pub fn new(client: reqwest::Client, api_key: impl Into<String>) -> Self {
86 let api_key = api_key.into();
87 Self { client, api_key }
88 }
89
90 pub fn simple() -> anyhow::Result<Self> {
93 let key = openai_key()?;
94 Ok(Self::new(reqwest::Client::default(), key))
95 }
96}
97
/// Request body for the text-completion endpoint, used by `Client::text` and
/// `Client::stream_text`.
#[derive(Clone, Serialize)]
pub struct TextRequest<'a> {
    /// Completion model to query.
    pub model: Completions,
    /// Prompt text to complete.
    pub prompt: &'a str,
    /// Sampling temperature.
    pub temperature: f64,

    /// Stop sequences; left out of the serialized request when empty.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub stop: Vec<&'a str>,

    /// Optional `n` parameter (presumably the number of completions — confirm
    /// against the API reference).
    pub n: Option<usize>,
    /// Value sent as `max_tokens`.
    pub max_tokens: usize,
}
116
117impl Default for TextRequest<'_> {
118 fn default() -> Self {
119 Self {
120 model: Completions::Davinci,
121 prompt: "",
122 temperature: 0.0,
123 stop: Vec::new(),
124 n: None,
125 max_tokens: 1_000,
126 }
127 }
128}
129
/// Request body sent by `Client::embed` to the embeddings endpoint.
#[derive(Copy, Clone, Serialize, Deserialize)]
struct EmbedRequest<'a> {
    /// Text to embed.
    input: &'a str,
    /// Name of the embedding model.
    model: &'a str,
}
138
/// One entry of the `choices` array in a [`TextResponse`].
#[derive(Clone, Serialize, Deserialize)]
struct TextResponseChoice {
    /// Completion text of this choice.
    text: String,
}
143
/// The part of a text-completion response that `Client::text` reads.
#[derive(Clone, Serialize, Deserialize)]
struct TextResponse {
    /// Returned completions, in response order.
    choices: Vec<TextResponseChoice>,
}
148
/// One entry of the `data` array in an [`EmbedResponse`].
#[derive(Clone, Serialize, Deserialize)]
struct EmbedDataFrame {
    /// Embedding vector for the input.
    embedding: Vec<f32>,
}
153
/// The part of an embeddings response that `Client::embed` reads.
#[derive(Clone, Serialize, Deserialize)]
struct EmbedResponse {
    /// Returned embeddings; `Client::embed` uses only the first entry.
    data: Vec<EmbedDataFrame>,
}
158
/// Borrowed completion-request payload.
///
/// NOTE(review): nothing in the visible part of this file references this
/// type — confirm whether it is still needed.
#[derive(Serialize, Deserialize)]
struct DavinciiData<'a> {
    /// Model name.
    model: &'a str,
    /// Prompt text.
    prompt: &'a str,
    /// Sampling temperature.
    temperature: f64,
    /// Value sent as `max_tokens`.
    max_tokens: usize,
}
166
/// Model family selector.
///
/// `text_repr` maps every variant to a text-completion model name, while
/// `embed_repr` only knows an embedding model for [`Model::Ada`].
#[derive(Copy, Clone, Default, PartialEq, Eq, Debug)]
pub enum Model {
    /// Default variant; text model `text-davinci-003`.
    #[default]
    Davinci,
    /// Text model `text-curie-001`.
    Curie,
    /// Text model `text-babbage-001`.
    Babbage,
    /// Text model `text-ada-001`; embedding model `text-embedding-ada-002`.
    Ada,
}
180
/// Chat model identifiers; each variant (de)serializes as the model name in
/// its `rename` attribute.
#[derive(Serialize, Deserialize, Default, Debug, PartialEq, Eq, Copy, Clone)]
pub enum ChatModel {
    /// `gpt-4-turbo-preview`; the default chat model.
    #[serde(rename = "gpt-4-turbo-preview")]
    #[default]
    Gpt4TurboPreview,

    /// `gpt-4-1106-preview`.
    #[serde(rename = "gpt-4-1106-preview")]
    Gpt4_1106,

    /// `gpt-4-0613`.
    #[serde(rename = "gpt-4-0613")]
    Gpt4_0613,

    /// `gpt-4`.
    #[serde(rename = "gpt-4")]
    Gpt4,
    /// `gpt-3.5-turbo`.
    #[serde(rename = "gpt-3.5-turbo")]
    Turbo,

    /// `gpt-3.5-turbo-0301`.
    #[serde(rename = "gpt-3.5-turbo-0301")]
    Turbo0301,
}
201
/// Author of a chat message; (de)serialized in `snake_case`
/// (`"system"`, `"user"`, `"assistant"`, `"function"`).
///
/// Marked `#[non_exhaustive]`, so downstream `match`es need a wildcard arm.
#[derive(
    Serialize,
    Deserialize,
    Debug,
    Copy,
    Clone,
    PartialOrd,
    PartialEq,
    Ord,
    Eq
)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum Role {
    /// System message.
    System,
    /// Message written by the user.
    User,
    /// Message written by the assistant.
    Assistant,
    /// Message carrying the result of a function call.
    Function,
}
227
/// A single chat message.
///
/// `derive_more::Constructor` generates `Msg::new(role, content, name,
/// function_call)` with the arguments in field order.
#[derive(Serialize, Deserialize, Debug, Clone, Constructor)]
pub struct Msg {
    /// Who wrote the message.
    pub role: Role,
    /// Message text. It has no skip attribute, so it is serialized even when
    /// it is `None`.
    pub content: Option<String>,

    /// Optional name; omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,

    /// Optional function call; omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function_call: Option<FunctionCall>,
}
240
/// A function call attached to a [`Msg`].
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct FunctionCall {
    /// Name of the function to call.
    pub name: String,

    /// Call arguments. On input they may be a JSON-encoded string or an
    /// inline JSON object; both are parsed into a `Value` by
    /// `deserialize_arguments`.
    #[serde(deserialize_with = "deserialize_arguments")]
    pub arguments: Option<Value>,
}
248
249impl FunctionCall {
250 pub fn into_struct<T: DeserializeOwned>(self) -> anyhow::Result<T> {
251 let args = self.arguments.context("no arguments")?;
252 let res = serde_json::from_value(args).context("failed to deserialize arguments")?;
253 Ok(res)
254 }
255}
256
257impl Default for Msg {
258 fn default() -> Self {
259 Self::system("")
260 }
261}
262
263impl Msg {
264 pub fn system(content: impl Into<String>) -> Self {
265 Self::new(Role::System, Some(content.into()), None, None)
266 }
267
268 pub fn user(content: impl Into<String>) -> Self {
269 Self::new(Role::User, Some(content.into()), None, None)
270 }
271
272 pub fn assistant(content: impl Into<String>) -> Self {
273 Self::new(Role::Assistant, Some(content.into()), None, None)
274 }
275
276 pub fn function(name: impl Into<String>, content: impl Serialize) -> anyhow::Result<Self> {
277 let name = name.into();
278 let content = serde_json::to_value(content)?;
279 let content = serde_json::to_string(&content)?;
280
281 Ok(Self::new(Role::Function, Some(content), Some(name), None))
282 }
283}
284
/// One incremental update of a streamed chat response.
///
/// Variant names are (de)serialized in `snake_case`, i.e. as `role` or
/// `content`.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "snake_case")]
pub enum Delta {
    /// The update carries the author's role.
    Role(Role),
    /// The update carries a piece of message text.
    Content(String),
}
292
293impl Display for Msg {
294 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
295 match &self.content {
296 None => f.write_str(""),
297 Some(content) => f.write_str(content),
298 }
299 }
300}
301
/// Serde `skip_serializing_if` predicate: `true` when `input` differs from
/// `1.0` by less than `f64::EPSILON`.
// Takes a reference because serde calls skip predicates with `&T`.
#[allow(clippy::trivially_copy_pass_by_ref)]
fn real_is_one(input: &f64) -> bool {
    let distance = (*input - 1.0).abs();
    distance < f64::EPSILON
}
306
/// Serde `skip_serializing_if` predicate: `true` when `input` is exactly `1`.
// Takes a reference because serde calls skip predicates with `&T`.
#[allow(clippy::trivially_copy_pass_by_ref)]
const fn int_is_one(input: &u32) -> bool {
    matches!(*input, 1)
}
311
/// Serde `skip_serializing_if` predicate: `true` when `input` is exactly `0`.
// Takes a reference because serde calls skip predicates with `&T`.
#[allow(clippy::trivially_copy_pass_by_ref)]
const fn int_is_zero(input: &u32) -> bool {
    matches!(*input, 0)
}
316
/// Serde `skip_serializing_if` predicate: `true` when the slice has no elements.
const fn empty<T>(input: &[T]) -> bool {
    matches!(input, [])
}
320
/// Request body for the chat-completions endpoint.
///
/// The `Build` derive generates the constructor and builder-style setters used
/// elsewhere in this file (for example `ChatRequest::new()`, `.model(..)`,
/// `.message(..)`, `.stop_at(..)`, `.function(..)`). Fields whose value equals
/// the default named in their `skip_serializing_if` predicate are left out of
/// the serialized request.
#[derive(Debug, Build, Serialize, Clone)]
pub struct ChatRequest {
    /// Chat model to query.
    pub model: ChatModel,
    /// Conversation so far, in order.
    pub messages: Vec<Msg>,

    /// Sampling temperature; defaults to `1.0` and is omitted when it is `1.0`.
    #[serde(skip_serializing_if = "real_is_one")]
    #[default = 1.0]
    pub temperature: f64,

    /// `top_p` parameter; defaults to `1.0` and is omitted when it is `1.0`.
    #[serde(skip_serializing_if = "real_is_one")]
    #[default = 1.0]
    pub top_p: f64,

    /// `n` parameter; defaults to `1` and is omitted when it is `1`.
    #[serde(skip_serializing_if = "int_is_one")]
    #[default = 1]
    pub n: u32,

    /// Stop sequences; serialized under the key `stop` and omitted when empty.
    #[serde(skip_serializing_if = "empty", rename = "stop")]
    pub stop_at: Vec<String>,

    /// `max_tokens` parameter; omitted from the request when it is `0`.
    #[serde(skip_serializing_if = "int_is_zero")]
    pub max_tokens: u32,

    /// Functions offered to the model; omitted when empty.
    #[serde(skip_serializing_if = "empty")]
    pub functions: Vec<Function>,
}
361
362impl ChatRequest {
363 #[must_use]
364 pub fn sys_msg(mut self, msg: impl Into<String>) -> Self {
365 self.messages.push(Msg::system(msg));
366 self
367 }
368
369 #[must_use]
370 pub fn user_msg(mut self, msg: impl Into<String>) -> Self {
371 self.messages.push(Msg::user(msg));
372 self
373 }
374
375 #[must_use]
376 pub fn assistant_msg(mut self, msg: impl Into<String>) -> Self {
377 self.messages.push(Msg::assistant(msg));
378 self
379 }
380}
381
/// The default request is whatever `ChatRequest::new()` produces.
///
/// NOTE(review): `new()` is not defined in this file; it is presumably
/// generated by the `Build` derive on `ChatRequest` — confirm.
impl Default for ChatRequest {
    fn default() -> Self {
        Self::new()
    }
}
387
388impl<'a> From<&'a str> for ChatRequest {
389 fn from(input: &'a str) -> Self {
390 Self {
391 messages: vec![Msg::user(input)],
392 ..Self::default()
393 }
394 }
395}
396
397impl<'a> From<&'a String> for ChatRequest {
398 fn from(input: &'a String) -> Self {
399 Self::from(input.as_str())
400 }
401}
402
403impl<'a> From<&'a [Msg]> for ChatRequest {
405 fn from(input: &'a [Msg]) -> Self {
406 Self {
407 messages: input.to_vec(),
408 ..Self::default()
409 }
410 }
411}
412
413impl<const N: usize> From<[Msg; N]> for ChatRequest {
415 fn from(input: [Msg; N]) -> Self {
416 Self {
417 messages: input.to_vec(),
418 ..Self::default()
419 }
420 }
421}
422
/// One entry of the `choices` array in a [`ChatResponse`].
#[derive(Serialize, Deserialize, Debug)]
pub struct ChatChoice {
    /// Message produced for this choice.
    pub message: Msg,
}
427
/// Description of a function that is sent along with a [`ChatRequest`].
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Function {
    /// Function name.
    pub name: String,
    /// Optional human-readable description.
    pub description: Option<String>,
    /// Optional JSON value describing the parameters; `Function::new` fills it
    /// with the schema generated for its `Input` type.
    pub parameters: Option<Value>,
}
434
435impl Function {
436 pub fn new<Input: JsonSchema>(name: impl Into<String>, description: impl Into<String>) -> Self {
437 let schema = schema::<Input>();
438 Self {
439 name: name.into(),
440 description: Some(description.into()),
441 parameters: Some(schema),
442 }
443 }
444}
445
/// Response body of the chat-completions endpoint, as parsed by
/// `Client::raw_chat`.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChatResponse {
    /// The `id` field of the response.
    pub id: String,
    /// The `object` field of the response.
    pub object: String,
    /// The `created` field of the response (presumably a Unix timestamp —
    /// confirm against the API reference).
    pub created: u64,
    /// Returned choices, in response order.
    pub choices: Vec<ChatChoice>,
}
453
454impl ChatResponse {
455 pub fn take_first(self) -> Option<ChatChoice> {
456 self.choices.into_iter().next()
457 }
458}
459
/// Text-completion model identifiers; each variant (de)serializes as the model
/// name in its `rename` attribute.
#[derive(Deserialize, Serialize, Copy, Clone, Default, Eq, PartialEq, Debug)]
#[allow(unused)]
pub enum Completions {
    /// `text-davinci-003`; the default completion model.
    #[serde(rename = "text-davinci-003")]
    #[default]
    Davinci,

    /// `text-curie-001`.
    #[serde(rename = "text-curie-001")]
    Curie,
    /// `text-babbage-001`.
    #[serde(rename = "text-babbage-001")]
    Babbage,
    /// `text-ada-001`.
    #[serde(rename = "text-ada-001")]
    Ada,
}
479
480impl Model {
481 const fn embed_repr(self) -> Option<&'static str> {
482 match self {
483 Self::Davinci | Self::Curie | Self::Babbage => None,
484 Self::Ada => Some("text-embedding-ada-002"),
485 }
486 }
487
488 #[allow(unused)]
489 const fn text_repr(self) -> &'static str {
490 match self {
491 Self::Davinci => "text-davinci-003",
492 Self::Curie => "text-curie-001",
493 Self::Babbage => "text-babbage-001",
494 Self::Ada => "text-ada-001",
495 }
496 }
497}
498
499impl Client {
500 fn request(
501 &self,
502 url: &str,
503 request: &impl Serialize,
504 ) -> impl Future<Output = reqwest::Result<Response>> {
505 self.client
506 .post(url)
507 .header("Authorization", format!("Bearer {}", self.api_key))
508 .json(request)
509 .send()
510 }
511
512 pub async fn embed(&self, input: &str) -> anyhow::Result<Vec<f32>> {
519 let request = EmbedRequest {
520 input,
521 model: unsafe { Model::Ada.embed_repr().unwrap_unchecked() },
522 };
523
524 let embed: EmbedResponse = self
525 .request("https://api.openai.com/v1/embeddings", &request)
526 .await
527 .context("could not complete embed request")?
528 .json()
529 .await?;
530
531 let result = embed
532 .data
533 .into_iter()
534 .next()
535 .context("no data for embedding")?
536 .embedding;
537
538 Ok(result)
539 }
540
541 pub async fn raw_chat(&self, req: &ChatRequest) -> anyhow::Result<ChatResponse> {
544 let response: String = self
545 .request("https://api.openai.com/v1/chat/completions", req)
546 .await
547 .context("could not complete chat request")?
548 .text()
549 .await?;
550
551 let response = match serde_json::from_str(&response) {
552 Ok(response) => response,
553 Err(e) => {
554 return Err(anyhow::anyhow!(
555 "could not parse chat response {response}: {e}"
556 ));
557 }
558 };
559
560 Ok(response)
561 }
562
563 pub async fn chat(&self, req: impl Into<ChatRequest> + Send) -> anyhow::Result<String> {
566 let req = req.into();
567 let response = self.raw_chat(&req).await?;
568 let choice = response
569 .choices
570 .into_iter()
571 .next()
572 .context("no choices for chat")?;
573
574 choice.message.content.context("no content for chat")
575 }
576
577 pub async fn stream_text(
580 &self,
581 req: TextRequest<'_>,
582 ) -> anyhow::Result<impl Stream<Item = anyhow::Result<String>>> {
583 #[derive(Clone, Serialize)]
584 pub struct TextStreamRequest<'a> {
585 stream: bool,
586
587 #[serde(flatten)]
588 req: TextRequest<'a>,
589 }
590
591 #[derive(Deserialize, Debug)]
592 pub struct TextStreamData {
593 pub text: Option<String>,
594 }
595
596 #[derive(Deserialize, Debug)]
597 pub struct TextStreamResponse {
598 pub choices: Vec<TextStreamData>,
599 }
600
601 let req = TextStreamRequest { stream: true, req };
602
603 let response = self
604 .request("https://api.openai.com/v1/completions", &req)
605 .await
606 .context("could not complete chat request")?;
607
608 let stream = response
609 .bytes_stream()
610 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
611 .into_async_read();
612
613 let mut messages = event_stream_processor::get_messages(stream);
614
615 let (tx, rx) = mpsc::channel(100);
616
617 fn message_to_data(
618 message: anyhow::Result<event_stream_processor::Message>,
619 ) -> anyhow::Result<Option<String>> {
620 let message = message?;
621 let data = message.data.context("no data")?;
622
623 if &data == "[DONE]" {
624 return Ok(None);
625 }
626
627 let Ok(data) = serde_json::from_str::<TextStreamResponse>(&data) else {
628 return Ok(None);
629 };
630
631 let choice = data.choices.into_iter().next().context("no choices")?;
632
633 let Some(content) = choice.text else {
634 return Ok(Some(String::new()));
635 };
636
637 Ok(Some(content))
638 }
639
640 tokio::spawn(async move {
641 while let Some(msg) = messages.next().await {
642 let msg = message_to_data(msg);
643 match msg {
644 Ok(None) => {
645 return;
646 }
647 Ok(Some(msg)) => {
648 if tx.send(Ok(msg)).await.is_err() {
649 return;
650 }
651 }
652 Err(e) => {
653 if tx.send(Err(e)).await.is_err() {
654 return;
655 }
656 }
657 }
658 }
659 });
660
661 Ok(ReceiverStream::from(rx))
662 }
663
664 pub async fn stream_chat(
667 &self,
668 req: impl Into<ChatRequest> + Send,
669 ) -> anyhow::Result<impl Stream<Item = anyhow::Result<String>>> {
670 #[derive(Serialize)]
671 struct ChatStreamRequest {
672 stream: bool,
673
674 #[serde(flatten)]
675 req: ChatRequest,
676 }
677
678 #[derive(Serialize, Deserialize, Debug, Clone)]
679 struct ChatStreamMessage {
680 pub delta: Delta,
681 }
682
683 #[derive(Serialize, Deserialize, Debug, Clone)]
684 struct ChatStreamResponse {
685 pub choices: Vec<ChatStreamMessage>,
686 }
687
688 let req = req.into();
689
690 let req = ChatStreamRequest { stream: true, req };
691
692 let response = self
693 .request("https://api.openai.com/v1/chat/completions", &req)
694 .await
695 .context("could not complete chat request")?;
696
697 let stream = response
698 .bytes_stream()
699 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
700 .into_async_read();
701
702 let mut messages = event_stream_processor::get_messages(stream);
703
704 let (tx, rx) = mpsc::channel(100);
705
706 fn message_to_data(
707 message: anyhow::Result<event_stream_processor::Message>,
708 ) -> anyhow::Result<Option<String>> {
709 let message = message?;
710 let data = message.data.context("no data")?;
711
712 if &data == "[DONE]" {
713 return Ok(None);
714 }
715
716 let Ok(data) = serde_json::from_str::<ChatStreamResponse>(&data) else {
717 return Ok(None);
718 };
719
720 let choice = data.choices.into_iter().next().context("no choices")?;
721
722 let Delta::Content(content) = choice.delta else {
723 return Ok(Some(String::new()));
724 };
725
726 Ok(Some(content))
727 }
728
729 tokio::spawn(async move {
730 while let Some(msg) = messages.next().await {
731 let msg = message_to_data(msg);
732 match msg {
733 Ok(None) => {
734 return;
735 }
736 Ok(Some(msg)) => {
737 if tx.send(Ok(msg)).await.is_err() {
738 return;
739 }
740 }
741 Err(e) => {
742 if tx.send(Err(e)).await.is_err() {
743 return;
744 }
745 }
746 }
747 }
748 });
749
750 Ok(ReceiverStream::from(rx))
751 }
752
753 pub async fn text(&self, request: TextRequest<'_>) -> anyhow::Result<Vec<String>> {
756 let text = self
757 .request("https://api.openai.com/v1/completions", &request)
758 .await
759 .context("could not complete text request")?
760 .text()
761 .await
762 .context("could not convert into text")?;
763
764 let json: TextResponse = match serde_json::from_str(&text) {
765 Ok(res) => res,
766 Err(e) => bail!("error {e} parsing json {text}"),
767 };
768
769 let choices = json.choices.into_iter().map(|e| e.text).collect();
770 Ok(choices)
771 }
772}
773
774#[cfg(test)]
775mod tests {
776 use approx::relative_eq;
777 use once_cell::sync::Lazy;
778 use pretty_assertions::assert_eq;
779 use serde_json::json;
780
781 use crate::{ChatChoice, ChatModel, ChatRequest, Completions, Function, Model, Msg, Role};
782
783 static API: Lazy<crate::Client> =
784 Lazy::new(|| crate::Client::simple().expect("could not create client"));
785
786 #[tokio::test]
787 async fn test_chat_raw() {
788 let req = ChatRequest {
789 model: ChatModel::Turbo,
790 messages: vec![
791 Msg {
792 role: Role::System,
793 content: Some(
794 "You are a helpful assistant that translates English to French."
795 .to_string(),
796 ),
797 ..Msg::default()
798 },
799 Msg {
800 role: Role::User,
801 content: Some(
802 "Translate the following English text to French: Hello".to_string(),
803 ),
804 ..Msg::default()
805 },
806 ],
807 ..ChatRequest::default()
808 };
809
810 let choices = API.raw_chat(&req).await.unwrap().choices;
811
812 let [ChatChoice { message }] = choices.as_slice() else {
813 panic!("no choices");
814 };
815
816 let message = message
817 .content
819 .as_ref()
820 .unwrap()
821 .replace(|c: char| !c.is_ascii_alphanumeric(), "")
822 .to_ascii_lowercase();
823
824 assert!(message.contains("bonjour"));
825 }
826
827 #[tokio::test]
828 async fn test_chat() {
829 let request = ChatRequest {
830 model: ChatModel::Turbo,
831 messages: vec![
832 Msg {
833 role: Role::System,
834 content: Some(
835 "You are a helpful assistant that translates English to French."
836 .to_string(),
837 ),
838 ..Msg::default()
839 },
840 Msg {
841 role: Role::User,
842 content: Some(
843 "Translate the following English text to French: Hello".to_string(),
844 ),
845 ..Msg::default()
846 },
847 ],
848 ..ChatRequest::default()
849 };
850
851 let res = API.chat(request).await.unwrap();
852
853 let choice = res
854 .replace(|c: char| !c.is_ascii_alphanumeric(), "")
856 .to_ascii_lowercase();
857
858 assert!(choice.contains("bonjour"));
859 }
860
861 #[test]
863 fn test_text_request() {
864 crate::TextRequest::default();
866 }
867
868 #[test]
869 fn test_message() {
870 {
871 let msg = Msg::system("hello");
872 assert_eq!("hello", format!("{msg}"));
873 let msg = serde_json::to_string(&msg).unwrap();
874 assert_eq!(msg, r#"{"role":"system","content":"hello"}"#);
875 }
876
877 {
878 let msg = Msg::user("hello");
879 assert_eq!("hello", format!("{msg}"));
880 let msg = serde_json::to_string(&msg).unwrap();
881 assert_eq!(msg, r#"{"role":"user","content":"hello"}"#);
882 }
883
884 {
885 let msg = Msg::assistant("hello");
886 assert_eq!("hello", format!("{msg}"));
887 let msg = serde_json::to_string(&msg).unwrap();
888 assert_eq!(msg, r#"{"role":"assistant","content":"hello"}"#);
889 }
890 }
891
892 #[test]
893 fn test_chat_builder() {
894 let req = ChatRequest::default()
895 .model(ChatModel::Turbo)
896 .temperature(1.2)
897 .message(Msg::system("hello"))
898 .message(Msg::user("hello"))
899 .top_p(1.0)
900 .n(3)
901 .stop_at("\n")
902 .stop_at("#####");
903
904 assert_eq!(req.model, ChatModel::Turbo);
905 assert!(relative_eq!(req.temperature, 1.2));
906 assert_eq!(req.messages.len(), 2);
907 assert!(relative_eq!(req.top_p, 1.0));
908 assert_eq!(req.n, 3);
909 assert_eq!(req.stop_at, vec!["\n", "#####"]);
910 }
911
912 #[test]
913 fn test_chat_from() {
914 let req = ChatRequest::from("hello");
915 assert_eq!(req.messages.len(), 1);
916 assert_eq!(req.messages[0].content, Some("hello".to_string()));
917 assert_eq!(req.messages[0].role, Role::User);
918 assert_eq!(req.n, 1);
919
920 let req = ChatRequest::from(&"hello".to_string());
921 assert_eq!(req.messages.len(), 1);
922 assert_eq!(req.messages[0].content, Some("hello".to_string()));
923 assert_eq!(req.messages[0].role, Role::User);
924 assert_eq!(req.n, 1);
925
926 let messages = [Msg::user("hello"), Msg::assistant("world")];
927 let req = ChatRequest::from(messages.as_slice());
928 assert_eq!(req.messages.len(), 2);
929 assert_eq!(req.messages[0].content, Some("hello".to_string()));
930 assert_eq!(req.messages[0].role, Role::User);
931 assert_eq!(req.messages[1].content, Some("world".to_string()));
932 assert_eq!(req.messages[1].role, Role::Assistant);
933 assert_eq!(req.n, 1);
934
935 let messages = [Msg::user("hello"), Msg::assistant("world")];
936 let req = ChatRequest::from(messages);
937 assert_eq!(req.messages.len(), 2);
938 assert_eq!(req.messages[0].content, Some("hello".to_string()));
939 assert_eq!(req.messages[0].role, Role::User);
940 assert_eq!(req.messages[1].content, Some("world".to_string()));
941 assert_eq!(req.messages[1].role, Role::Assistant);
942 assert_eq!(req.n, 1);
943 }
944
945 #[test]
946 fn test_completions() {
947 let completion = Completions::default();
948 assert_eq!(completion, Completions::Davinci);
949 }
950
951 #[test]
952 fn test_chat_model() {
953 let model = ChatModel::default();
954 assert_eq!(model, ChatModel::Gpt4);
955 }
956
957 #[test]
958 fn test_model() {
959 let model = Model::default();
960 assert_eq!(model, Model::Davinci);
961 assert_eq!(model.embed_repr(), None);
962 assert_eq!(model.text_repr(), "text-davinci-003");
963
964 let model = Model::Curie;
965 assert_eq!(model.embed_repr(), None);
966 assert_eq!(model.text_repr(), "text-curie-001");
967
968 let model = Model::Babbage;
969 assert_eq!(model.embed_repr(), None);
970 assert_eq!(model.text_repr(), "text-babbage-001");
971
972 let model = Model::Ada;
973 assert_eq!(model.embed_repr().unwrap(), "text-embedding-ada-002");
974 assert_eq!(model.text_repr(), "text-ada-001");
975 }
976
977 #[tokio::test]
978 async fn test_function() {
979 let request = ChatRequest::new();
980
981 let function = Function {
982 name: "weather".to_string(),
983 description: Some("Get the weather for a location".to_string()),
984 parameters: Some(json!({
985 "type": "object",
986 "properties": {
987 "lat": {
988 "type": "number",
989 },
990 "lon": {
991 "type": "number",
992 },
993 "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
994 },
995 "required": ["lat", "lon"],
996 })),
997 };
998
999 let request = request
1000 .function(function)
1001 .user_msg("What's the weather like in Svalbard");
1002
1003 println!("{}", serde_json::to_string_pretty(&request).unwrap());
1004
1005 let response = API.raw_chat(&request).await.unwrap();
1006
1007 let first_choice = response.choices.into_iter().next().unwrap();
1008
1009 let msg = first_choice.message;
1010
1011 let call = serde_json::to_string_pretty(&msg.function_call).unwrap();
1012
1013 println!("call: {}", call);
1014 }
1015}