//! LLM abstraction layer: a [`Model`] wraps a pluggable [`Backend`]
//! (currently ChatGPT) behind the crate-private `LLMBackend` trait and
//! renders message groups into prompts via `PromptMessageBuilder`.
mod backends;
mod error;
pub mod message;
mod traits;

use backends::create_llm_model;
pub use error::Error;
use message::PromptMessageBuilder;
use traits::{LLMBackend, MessageBuilder};

/// Configuration for a concrete LLM backend.
///
/// Each variant carries the credentials/settings that backend needs;
/// `Model::new` converts it into a live backend implementation via
/// `create_llm_model`.
pub enum Backend {
    /// OpenAI ChatGPT, authenticated with `api_key` and targeting the
    /// named `model` (e.g. "gpt-3.5-turbo").
    ChatGPT { api_key: String, model: String },
    // TODO: future backends
    // Llama2Cpu { path: String },
}

/// A language-model handle: a boxed, dynamically-dispatched backend plus
/// the generation settings applied to every request.
///
/// NOTE(review): the original "manage the messages?" comment suggests
/// message management may be intended to live here later — confirm with
/// the author.
#[derive(Debug)]
pub struct Model {
    // Backend created from a `Backend` config by `create_llm_model`.
    backend: Box<dyn LLMBackend>,
    // Sampling temperature forwarded to the backend on each generation.
    temperature: f64,
}

impl Model {
    /// Sampling temperature used until `set_temperature` is called.
    const DEFAULT_TEMPERATURE: f64 = 0.9;

    /// Builds a `Model` for the given backend configuration.
    ///
    /// # Errors
    /// Returns an [`Error`] when the backend cannot be constructed from
    /// `config` (propagated from `create_llm_model`).
    pub fn new(config: Backend) -> Result<Model, Error> {
        let backend = create_llm_model(config)?;
        Ok(Self {
            backend,
            temperature: Self::DEFAULT_TEMPERATURE,
        })
    }

    /// Renders `context_message_group` into a single prompt string and asks
    /// the backend for a completion at the current temperature.
    ///
    /// # Errors
    /// Propagates any [`Error`] reported by the backend.
    pub async fn generate_response<T>(&self, context_message_group: T) -> Result<String, Error>
    where
        T: IntoIterator + Send,
        T::Item: MessageBuilder + Send,
    {
        // Build the prompt once, then hand a borrowed &str to the backend.
        let prompt = PromptMessageBuilder::new(context_message_group).build();
        self.backend
            .generate_response(self.temperature, prompt.as_str())
            .await
    }

    /// Sets the sampling temperature for subsequent generations.
    /// No range validation is performed; callers pass the value through
    /// to the backend as-is.
    pub fn set_temperature(&mut self, temperature: f64) {
        self.temperature = temperature;
    }

    /// Returns the current sampling temperature.
    pub fn temperature(&self) -> f64 {
        self.temperature
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::env;

    /// Smoke test: constructing a ChatGPT-backed `Model` succeeds.
    ///
    /// Skips (rather than fails) when the API key is not configured, so
    /// the suite stays green on machines without credentials.
    #[test]
    fn test_request() {
        dotenv::dotenv().ok();
        // `init()` panics if a logger was already installed (e.g. by a
        // sibling test); `try_init()` keeps the tests order-independent.
        let _ = env_logger::try_init();

        // NOTE(review): the var is "OPEN_API_KEY" — possibly a typo for
        // "OPENAI_API_KEY"; kept as-is, confirm against the .env file.
        let api_key = match env::var("OPEN_API_KEY") {
            Ok(key) => key,
            Err(_) => {
                eprintln!("OPEN_API_KEY not set; skipping test_request");
                return;
            }
        };

        assert!(Model::new(Backend::ChatGPT {
            api_key,
            model: "gpt-3.5-turbo".into(),
        })
        .is_ok());
    }
}