1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
use std::fs;
use std::sync::Arc;
use actix_web::{Responder, web};
use clap::Parser;
use fancy_regex::Regex;
use serde::Deserialize;
use tracing::{error, info};
use crate::common::{HEADER_API_KEY, ShellRequest, ShellResponse};
use crate::prompts::Prompts;
use crate::providers::{ProviderApi, ProviderConfig};


lazy_static::lazy_static! {
    pub static ref CODE_BLOCK_RE: Regex = Regex::new(r"(?ms)```\w*(.*)```").unwrap();
}

#[derive(Debug, Deserialize, Clone)]
pub struct Config {
    pub provider: ProviderConfig,
}

impl Config {
    fn from_yaml(file_path: &str) -> Self {
        let config_content = fs::read_to_string(file_path).expect(
            &format!("Failed to read the configuration file: {}", file_path));
        serde_yaml::from_str(&config_content).expect(
            &format!("Failed to read the configuration file: {}", file_path))
    }
}

pub struct AppConfig {
    pub provider: Arc<dyn ProviderApi + Send + Sync>,
    pub prompts: Prompts,
}

pub async fn chat(
    request: web::Json<ShellRequest>,
    data: web::Data<Arc<AppConfig>>,
    key: web::Data<Arc<String>>,
    req: actix_web::HttpRequest,
) -> impl Responder {
    let api_key = req
        .headers()
        .get(HEADER_API_KEY)
        .and_then(|v| v.to_str().ok())
        .unwrap_or("");

    if api_key != key.as_str() {
        error!("Invalid API key: {}", api_key);
        return actix_web::HttpResponse::Unauthorized().body("Invalid API key");
    }

    let prompts = &data.prompts;
    let provider = &data.provider;

    let prompt = if !request.explain {
        &prompts.shell_prompt(&request.os, &request.shell)
    } else {
        &prompts.explain
    };

    let response = match provider.call(&prompt, &request.prompt).await {
        Ok(mut eval_str) => {
            if !request.explain {
                if let Ok(true) = CODE_BLOCK_RE.is_match(&eval_str) {
                    eval_str = extract_block(&eval_str);
                }
            }
            info!("{}/{}: {} => {}", &request.os, &request.shell, &&request.prompt, &eval_str);
            ShellResponse {
                result: eval_str,
                error: String::new(),
            }
        },
        Err(err) => {
            error!("Error calling provider: {}", err);

            ShellResponse {
                result: String::new(),
                error: err,
            }
        },
    };

    actix_web::HttpResponse::Ok().json(response)
}

pub fn extract_block(input: &str) -> String {
    let output: String = CODE_BLOCK_RE
        .captures_iter(input)
        .filter_map(|m| {
            m.ok()
                .and_then(|cap| cap.get(1))
                .map(|m| String::from(m.as_str()))
        })
        .collect();
    if output.is_empty() {
        input.trim().to_string()
    } else {
        output.trim().to_string()
    }
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
pub struct ServerCli {
    #[clap(short = 'c', long, env = "CONFIG", default_value = "config.yaml")]
    pub config: String,
    #[clap(short = 'a', long, env = "ADDRESS", default_value = "127.0.0.1")]
    pub address: String,
    #[clap(short = 'p', long, env = "PORT", default_value = "8080")]
    pub port: String,
    #[clap(short = 'k', long, env = "KEY")]
    pub key: Option<String>,
    #[clap(short = 'd', long, env = "LOGS_DIR")]
    pub logs_dir: Option<String>,
}

impl ServerCli {
    pub fn config(&self) -> Config {
        Config::from_yaml(&self.config)
    }

    pub fn endpoint(&self) -> String {
        format!("{}:{}", &self.address, &self.port)
    }
}