// shai/lib.rs

1#![allow(clippy::future_not_send)]
2
3pub mod cli;
4mod context;
5mod model;
6mod openai;
7mod prompts;
8
9use context::Context;
10use futures::Stream;
11use model::Task;
12use openai::{OpenAIError, OpenAIGPTModel};
13use serde::Deserialize;
14use thiserror::Error;
15
/// Per-task configuration: one variant for each CLI task the crate supports.
enum ConfigKind {
    /// Configuration for the "ask" task (natural-language request in).
    Ask(AskConfig),
    /// Configuration for the "explain" task.
    Explain(ExplainConfig),
}
20
21impl ConfigKind {
22    const fn model(&self) -> &ModelKind {
23        match self {
24            Self::Ask(config) => &config.model,
25            Self::Explain(config) => &config.model,
26        }
27    }
28}
29
/// User-supplied context for the "ask" task, deserialized from configuration.
#[derive(Deserialize)]
struct AskConfig {
    // Target operating system for the generated command, e.g. "Linux".
    operating_system: String,
    // Target shell, e.g. "Bash".
    shell: String,
    // Optional extra environment details to include in the model context.
    // NOTE(review): exact semantics live in context.rs — confirm there.
    environment: Option<Vec<String>>,
    // Optional list of programs available on the machine.
    // NOTE(review): presumably biases command generation — confirm.
    programs: Option<Vec<String>>,
    // `Option<()>` carries no path data, so this can only mean "include the
    // cwd or not". NOTE(review): looks like a placeholder for a real path type.
    cwd: Option<()>,
    // NOTE(review): depth of what (directory listing? context?) is not
    // visible from this file — confirm against context.rs.
    depth: Option<u32>,
    // Which backing model answers the request.
    model: ModelKind,
}
40
/// User-supplied context for the "explain" task, deserialized from
/// configuration. Mirrors `AskConfig` minus the `programs` field.
#[derive(Deserialize)]
struct ExplainConfig {
    // Operating system the command-to-explain runs on, e.g. "Linux".
    operating_system: String,
    // Shell the command is written for, e.g. "Bash".
    shell: String,
    // Optional extra environment details to include in the model context.
    environment: Option<Vec<String>>,
    // Which backing model answers the request.
    model: ModelKind,
    // `Option<()>` carries no path data — same placeholder as in `AskConfig`.
    cwd: Option<()>,
    // NOTE(review): meaning of `depth` not visible from this file — confirm.
    depth: Option<u32>,
}
50
51impl Default for AskConfig {
52    fn default() -> Self {
53        Self {
54            operating_system: "Linux".to_string(),
55            shell: "Bash".to_string(),
56            environment: None,
57            programs: None,
58            cwd: None,
59            depth: None,
60            model: ModelKind::OpenAIGPT(OpenAIGPTModel::GPT35Turbo),
61        }
62    }
63}
64
65impl Default for ExplainConfig {
66    fn default() -> Self {
67        Self {
68            operating_system: "Linux".to_string(),
69            shell: "Bash".to_string(),
70            environment: None,
71            cwd: None,
72            depth: None,
73            model: ModelKind::OpenAIGPT(OpenAIGPTModel::GPT35Turbo),
74        }
75    }
76}
77
/// Which model backend answers requests. Currently only OpenAI chat models
/// are wired up; planned variants are sketched below.
#[derive(Deserialize, Clone)]
enum ModelKind {
    /// An OpenAI chat model, selected via `OpenAIGPTModel`.
    OpenAIGPT(OpenAIGPTModel),
    // OpenAssistant // waiting for a minimal API, go guys :D
    // Local // ?
}
84
/// Error returned by `model_request`; erases the backend-specific error
/// behind a boxed `std::error::Error` so new backends need no new variants.
#[derive(Debug, Error)]
enum ModelError {
    #[error("ModelError: {0}")]
    Error(#[from] Box<dyn std::error::Error>),
}
90
91#[allow(unused)]
92async fn model_request(
93    model: ModelKind,
94    request: String,
95    context: Context,
96    task: Task,
97) -> Result<String, ModelError> {
98    match model {
99        ModelKind::OpenAIGPT(model) => model
100            .send(request, context, task)
101            .await
102            .map_err(|err| ModelError::Error(Box::new(err))),
103    }
104}
105
106async fn model_stream_request(
107    model: ModelKind,
108    request: String,
109    context: Context,
110    task: Task,
111) -> Result<impl Stream<Item = Result<String, OpenAIError>>, OpenAIError> {
112    match model {
113        ModelKind::OpenAIGPT(model) => model.send_streaming(request, context, task).await,
114    }
115}
116
117fn build_context_request(request: &str, context: Context) -> String {
118    String::from(context) + &format!("Here is your <task>: \n <task>{request}</task>")
119}
120
121// #[cfg(test)]
122// mod tests {
123//     use crate::{
124//         context::Context, model::Task, model_stream_request, openai::OpenAIGPTModel::GPT35Turbo,
125//         AskConfig, ConfigKind, ModelKind::OpenAIGPT,
126//     };
127//     use futures_util::StreamExt;
128//
129//     #[tokio::test]
130//     async fn ssh_tunnel() {
131//         let mut  response_stream = model_stream_request(OpenAIGPT(GPT35Turbo), 
132//             "make an ssh tunnel between port 8080 in this machine and port 1243 in the machine with IP 192.168.0.42".to_string(), 
133//             Context::from(ConfigKind::Ask(AskConfig::default())),
134//             Task::GenerateCommand
135//             ).await.unwrap();
136//         while response_stream.next().await.is_some() {
137//         }
138//     }
139// }