1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
use async_trait::async_trait;
use reqwest::Client;
use serde::Deserialize;
use serde_json::json;
use std::panic::panic_any;
use std::sync::Arc;
use thiserror::Error;

#[derive(Debug, Error)]
pub enum ProviderError {
    #[error("HTTP Request Error: {0}")]
    RequestError(#[from] reqwest::Error),
    #[error("JSON Parsing Error: {0}")]
    JsonError(#[from] serde_json::Error),
    #[error("Unexpected Response Structure")]
    UnexpectedResponse(String),
}

/// A chat backend that answers one system prompt plus one user prompt.
#[async_trait]
pub trait ProviderApi {
    /// Sends `role_prompt` and `user_prompt` to the backend and returns its textual reply.
    ///
    /// # Errors
    ///
    /// Returns a [`ProviderError`] when the request fails or the reply cannot be interpreted.
    async fn call(
        &self,
        role_prompt: &str,
        user_prompt: &str,
    ) -> Result<String, ProviderError>;
}

/// Deserializable selection of an LLM backend plus its connection settings.
///
/// The enum is internally tagged: a `type` field names the variant, for example
/// `{"type": "AzureOpenAI", "api_key": "...", "api_url": "...", "model": "..."}`.
///
/// `Debug` is implemented by hand (below) so that `api_key` is redacted and never ends up in
/// logs or panic messages.
#[derive(Deserialize, Clone)]
#[serde(tag = "type")]
pub enum ProviderConfig {
    /// Settings for the `OpenAI` provider.
    OpenAI {
        api_key: String,
        api_url: String,
        model: String,
    },
    /// Settings for an Azure OpenAI resource. `new_provider` uses `model` as the deployment
    /// name in the request URL.
    AzureOpenAI {
        api_key: String,
        api_url: String,
        model: String,
    },
    /// Settings for the `Ollama` provider.
    Ollama {
        api_key: String,
        api_url: String,
        model: String,
    },
}

impl std::fmt::Debug for ProviderConfig {
    /// Formats like the derived `Debug` (`Variant { api_key: .., api_url: .., model: .. }`),
    /// except that the API key is replaced by a placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (variant, api_url, model) = match self {
            ProviderConfig::OpenAI { api_url, model, .. } => ("OpenAI", api_url, model),
            ProviderConfig::AzureOpenAI { api_url, model, .. } => ("AzureOpenAI", api_url, model),
            ProviderConfig::Ollama { api_url, model, .. } => ("Ollama", api_url, model),
        };
        f.debug_struct(variant)
            .field("api_key", &"<redacted>")
            .field("api_url", api_url)
            .field("model", model)
            .finish()
    }
}

/// Chat-completions provider backed by an Azure OpenAI deployment.
#[derive(Clone)]
pub struct AzureOpenAI {
    /// HTTP client used to send every request.
    pub client: Client,
    /// Value sent as the `model` field of the JSON request body.
    pub model: String,
    /// Secret sent in the `api-key` request header.
    pub api_key: String,
    /// Complete endpoint URL (query string included) that requests are POSTed to.
    pub url_full: String,
}

#[derive(serde::Deserialize)]
pub struct CompletionResponse {
    choices: Vec<Choice>,
}

#[derive(serde::Deserialize)]
pub struct Choice {
    message: Message,
}

#[derive(serde::Deserialize)]
pub struct Message {
    content: String,
}

#[async_trait]
impl ProviderApi for AzureOpenAI {
    /// POSTs a two-message chat (system `role_prompt`, then user `user_prompt`) to `url_full`,
    /// authenticating with the `api-key` header, and returns the `content` of the first choice.
    ///
    /// # Errors
    ///
    /// * [`ProviderError::RequestError`] if sending the request or reading the body fails.
    /// * [`ProviderError::UnexpectedResponse`] carrying the raw body in either of two cases:
    ///   the body does not deserialize as a chat-completions response, or it has no choices.
    ///   The HTTP status is not checked separately. A non-2xx reply therefore lands here
    ///   whenever its body lacks a valid `choices` array.
    async fn call(&self, role_prompt: &str, user_prompt: &str) -> Result<String, ProviderError> {
        let body = json!({
            "model": &self.model,
            "messages": [
                { "role": "system", "content": role_prompt },
                { "role": "user", "content": user_prompt },
            ],
        });

        let raw = self
            .client
            .post(&self.url_full)
            .header("api-key", &self.api_key)
            .json(&body)
            .send()
            .await?
            .text()
            .await?;

        // Parse failures are deliberately reported as `UnexpectedResponse` with the raw body,
        // not as `JsonError`, so callers can see what the service actually returned.
        let first_choice = serde_json::from_str::<CompletionResponse>(&raw)
            .ok()
            .and_then(|resp| resp.choices.into_iter().next());

        match first_choice {
            // Move the content out of the owned choice instead of cloning it.
            Some(choice) => Ok(choice.message.content),
            None => Err(ProviderError::UnexpectedResponse(raw)),
        }
    }
}

/// `api-version` query parameter appended to every Azure OpenAI chat-completions URL.
const AZURE_API_VERSION: &str = "2024-02-01";

/// Builds the provider described by `provider_type`.
///
/// For `ProviderConfig::AzureOpenAI`, the request URL is
/// `{api_url}/openai/deployments/{model}/chat/completions?api-version=...`. Any trailing `/`
/// on `api_url` is trimmed first so the URL never contains `//openai`.
///
/// # Panics
///
/// Panics for `OpenAI` and `Ollama`, which are not implemented yet. The panic message names
/// only the variant, so the configured `api_key` is never included in it.
pub fn new_provider(provider_type: &ProviderConfig) -> Arc<dyn ProviderApi + Send + Sync> {
    match provider_type {
        ProviderConfig::AzureOpenAI {
            api_key,
            api_url,
            model,
        } => {
            let base_url = api_url.trim_end_matches('/');
            Arc::new(AzureOpenAI {
                client: Client::new(),
                model: model.clone(),
                api_key: api_key.clone(),
                url_full: format!(
                    "{}/openai/deployments/{}/chat/completions?api-version={}",
                    base_url, model, AZURE_API_VERSION,
                ),
            })
        }
        // Listed explicitly (no `_`) so adding a variant forces a decision here.
        ProviderConfig::OpenAI { .. } => provider_not_implemented("OpenAI"),
        ProviderConfig::Ollama { .. } => provider_not_implemented("Ollama"),
    }
}

/// Panics with a `String` payload reporting that `provider` has no implementation yet.
fn provider_not_implemented(provider: &str) -> ! {
    panic_any(format!("the provider not implemented yet: {}", provider))
}