1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
use std::panic::panic_any;
use std::sync::Arc;
use async_trait::async_trait;
use reqwest::Client;
use serde::Deserialize;
use serde_json::json;

/// A chat-completion backend that answers a (system prompt, user prompt) pair.
///
/// The visible implementation, [`AzureOpenAI`], sends `role_prompt` as the `system`
/// message and `user_prompt` as the `user` message and returns the reply text.
/// Failures are reported as plain `String` error messages.
#[async_trait]
pub trait ProviderApi {
    /// Sends both prompts to the provider and returns the generated text,
    /// or an error message describing why the call failed.
    async fn call(&self, role_prompt: &str, user_prompt: &str) -> Result<String, String>;
}

/// Provider selection and credentials.
///
/// Deserialized with serde internal tagging: the `type` field names the variant,
/// e.g. `{"type": "AzureOpenAI", "api_key": "...", "base_url": "...", "model": "..."}`.
///
/// `Debug` is implemented by hand so that `api_key` is printed as `"<redacted>"`
/// rather than leaking the secret into logs or panic messages.
#[derive(Deserialize, Clone)]
#[serde(tag = "type")]
pub enum ProviderConfig {
    /// OpenAI endpoint settings.
    OpenAI { api_key: String, base_url: String, model: String },
    /// Azure OpenAI endpoint settings; `model` is used as the deployment name.
    AzureOpenAI { api_key: String, base_url: String, model: String },
    /// Ollama endpoint settings.
    Ollama { api_key: String, base_url: String, model: String },
}

impl std::fmt::Debug for ProviderConfig {
    /// Formats like the derived `Debug` (`Variant { api_key: .., base_url: .., model: .. }`)
    /// but never reveals the API key.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (name, base_url, model) = match self {
            ProviderConfig::OpenAI { base_url, model, .. } => ("OpenAI", base_url, model),
            ProviderConfig::AzureOpenAI { base_url, model, .. } => {
                ("AzureOpenAI", base_url, model)
            }
            ProviderConfig::Ollama { base_url, model, .. } => ("Ollama", base_url, model),
        };
        f.debug_struct(name)
            .field("api_key", &"<redacted>")
            .field("base_url", base_url)
            .field("model", model)
            .finish()
    }
}

/// Azure OpenAI chat-completions client, typically built by [`new_provider`]
/// from [`ProviderConfig::AzureOpenAI`].
#[derive(Clone)]
pub struct AzureOpenAI {
    /// HTTP client used to POST the completion request.
    pub client: Client,
    /// Sent as the `model` field of the JSON request body.
    pub model: String,
    /// Sent in the `api-key` request header.
    pub api_key: String,
    /// Complete URL the request is POSTed to; as built by `new_provider` it already
    /// contains the deployment name and the `api-version` query parameter.
    pub url_full: String,
}

#[derive(serde::Deserialize)]
pub struct CompletionResponse {
    choices: Vec<Choice>,
}

#[derive(serde::Deserialize)]
pub struct Choice {
    message: Message,
}

#[derive(serde::Deserialize)]
pub struct Message {
    content: String,
}

#[async_trait]
impl ProviderApi for AzureOpenAI {
    /// Sends `role_prompt` as the `system` message and `user_prompt` as the `user`
    /// message to `self.url_full` and returns the content of the first choice.
    ///
    /// Returns `Ok("")` when the response contains no choices.
    ///
    /// # Errors
    ///
    /// Returns `Err` with a human-readable message when:
    /// - the request cannot be sent, or the response body cannot be read;
    /// - the service answers with a non-success HTTP status (the message then includes
    ///   the status and the raw response body);
    /// - the body is not a valid completion response.
    async fn call(&self, role_prompt: &str, user_prompt: &str) -> Result<String, String> {
        let body = json!({
            "model": &self.model,
            "messages": [
                { "role": "system", "content": role_prompt },
                { "role": "user", "content": user_prompt }
            ],
        });

        let response = self
            .client
            .post(&self.url_full)
            .header("api-key", &self.api_key)
            .json(&body)
            .send()
            .await
            .map_err(|e| e.to_string())?;

        // Read the status before `text()` consumes the response. Without this check an
        // error body (401, 429, 5xx, ...) would surface only as a confusing
        // "missing field `choices`" deserialization error.
        let status = response.status();
        let text = response.text().await.map_err(|e| e.to_string())?;
        if !status.is_success() {
            return Err(format!(
                "Azure OpenAI request failed with status {}: {}",
                status, text
            ));
        }

        let resp: CompletionResponse = serde_json::from_str(&text).map_err(|e| e.to_string())?;

        // Move the first choice's text out instead of cloning it; no choices -> "".
        Ok(resp
            .choices
            .into_iter()
            .next()
            .map(|choice| choice.message.content)
            .unwrap_or_default())
    }
}

/// Builds the [`ProviderApi`] implementation selected by `provider_type`.
///
/// For [`ProviderConfig::AzureOpenAI`] the request URL is
/// `{base_url}/openai/deployments/{model}/chat/completions?api-version=2024-02-01`.
/// Trailing `/` characters on `base_url` are dropped, so an endpoint copied with a
/// trailing slash does not produce a `//openai` path.
///
/// # Panics
///
/// Panics (with a `String` payload) for [`ProviderConfig::OpenAI`] and
/// [`ProviderConfig::Ollama`], which are not implemented yet. The message names only
/// the variant: it deliberately omits the configuration, which contains the API key.
pub fn new_provider(provider_type: &ProviderConfig) -> Arc<dyn ProviderApi + Send + Sync> {
    match provider_type {
        ProviderConfig::AzureOpenAI { api_key, base_url, model } => Arc::new(AzureOpenAI {
            client: Client::new(),
            model: model.clone(),
            api_key: api_key.clone(),
            url_full: format!(
                "{}/openai/deployments/{}/chat/completions?api-version=2024-02-01",
                base_url.trim_end_matches('/'),
                model,
            ),
        }),
        // Unsupported variants are listed explicitly (no `_` arm) so that adding a new
        // variant to `ProviderConfig` is a compile error here until it is handled.
        ProviderConfig::OpenAI { .. } => {
            panic_any(String::from("the provider not implemented yet: OpenAI"))
        }
        ProviderConfig::Ollama { .. } => {
            panic_any(String::from("the provider not implemented yet: Ollama"))
        }
    }
}