1pub mod chat;
4pub mod models;
5pub mod text_to_image;
6pub mod text_to_speech;
7
8use base64::prelude::BASE64_STANDARD;
9use base64::Engine;
10use reqwest::header::HeaderMap;
11use reqwest::header::HeaderValue;
12use serde::Deserialize;
13use serde::Serialize;
14use std::collections::HashMap;
15use std::error::Error;
16use std::fs::File;
17use std::io::Read;
18use std::str::FromStr;
19
20pub(crate) fn request_headers(key: &Key) -> Result<HeaderMap, Box<dyn Error + Send + Sync>> {
21 let mut headers = HeaderMap::new();
22 headers.insert(
23 "Authorization",
24 HeaderValue::from_str(&format!("Bearer {}", key.key))?,
25 );
26 headers.insert("Content-Type", HeaderValue::from_str("application/json")?);
27 Ok(headers)
28}
29
30pub(crate) fn openai_base_url(provider: &Provider) -> String {
31 match provider {
32 Provider::Google => format!("{}/v1beta/openai", provider.domain()),
33 Provider::Groq => format!("{}/openai/v1", provider.domain()),
34 Provider::Hyperbolic => format!("{}/v1", provider.domain()),
35 Provider::Mistral => format!("{}/v1", provider.domain()),
36 Provider::OpenAI => format!("{}/v1", provider.domain()),
37 Provider::OpenAICompatible(domain) => domain.clone(),
38 Provider::SambaNova => format!("{}/v1", provider.domain()),
39 Provider::TogetherAI => format!("{}/v1", provider.domain()),
40 _ => format!("{}/v1/openai", provider.domain()),
41 }
42}
43
/// An API provider that requests can be sent to.
///
/// Each variant maps to an API root URL via [`Provider::domain`] and to the
/// name of its API-key variable via [`Provider::key_name`].
// NOTE(review): nothing visible in this definition contains a bare URL, so this
// `allow` may be left over from removed docs — confirm before dropping it.
#[allow(rustdoc::bare_urls)]
#[derive(Clone, Debug, Serialize, PartialEq)]
pub enum Provider {
    Amazon,
    Azure,
    Cerebras,
    DeepInfra,
    ElevenLabs,
    Fireworks,
    FriendliAI,
    Google,
    Groq,
    Hyperbolic,
    Mistral,
    Nebius,
    Novita,
    OpenAI,
    /// Any server exposing an OpenAI-compatible API. Holds the server's base
    /// URL, which is used as-is as both its domain and its API base URL.
    OpenAICompatible(String),
    SambaNova,
    TogetherAI,
}
68
69impl std::fmt::Display for Provider {
70 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71 write!(f, "{:?}", self)
72 }
73}
74
75impl Provider {
76 pub fn domain(&self) -> String {
77 match self {
78 Provider::Amazon => "https://api.amazon.com",
79 Provider::Azure => "https://api.azure.com",
80 Provider::Cerebras => "https://api.cerebras.ai",
81 Provider::DeepInfra => "https://api.deepinfra.com",
82 Provider::ElevenLabs => "https://api.elevenlabs.io",
83 Provider::Fireworks => "https://api.fireworks.ai",
84 Provider::FriendliAI => "https://api.friendli.ai",
85 Provider::Google => "https://generativelanguage.googleapis.com",
86 Provider::Groq => "https://api.groq.com",
87 Provider::Hyperbolic => "https://api.hyperbolic.xyz",
88 Provider::Mistral => "https://api.mistral.ai",
89 Provider::Nebius => "https://api.nebi.us",
90 Provider::Novita => "https://api.novita.ai",
91 Provider::OpenAI => "https://api.openai.com",
92 Provider::OpenAICompatible(base_url) => base_url,
93 Provider::SambaNova => "https://api.sambanova.ai",
94 Provider::TogetherAI => "https://api.together.xyz",
95 }
96 .to_string()
97 }
98 pub fn key_name(&self) -> String {
99 match self {
100 Provider::OpenAICompatible(_) => "OPENAI_COMPATIBLE_KEY".to_string(),
101 _ => self.to_string().to_uppercase() + "_KEY",
102 }
103 }
104}
105
106impl FromStr for Provider {
107 type Err = Box<dyn Error + Send + Sync>;
108
109 fn from_str(s: &str) -> Result<Self, Self::Err> {
110 let s = s.to_lowercase();
111 if s.starts_with("openai-compatible(") {
112 let s = s.strip_prefix("openai-compatible(").unwrap();
113 let s = s.strip_suffix(")").unwrap();
114 let mut domain = s.to_string();
115 if !domain.starts_with("https") {
116 if domain.contains("localhost") {
117 domain = format!("http://{}", domain);
118 } else {
119 domain = format!("https://{}", domain);
120 }
121 }
122 return Ok(Provider::OpenAICompatible(domain));
123 }
124 match s.as_str() {
125 "amazon" => Ok(Provider::Amazon),
126 "azure" => Ok(Provider::Azure),
127 "cerebras" => Ok(Provider::Cerebras),
128 "deepinfra" => Ok(Provider::DeepInfra),
129 "elevenlabs" => Ok(Provider::ElevenLabs),
130 "fireworks" => Ok(Provider::Fireworks),
131 "friendliai" => Ok(Provider::FriendliAI),
132 "google" => Ok(Provider::Google),
133 "groq" => Ok(Provider::Groq),
134 "hyperbolic" => Ok(Provider::Hyperbolic),
135 "mistral" => Ok(Provider::Mistral),
136 "nebi" => Ok(Provider::Nebius),
137 "novita" => Ok(Provider::Novita),
138 "openai" => Ok(Provider::OpenAI),
139 "sambanova" => Ok(Provider::SambaNova),
140 "togetherai" => Ok(Provider::TogetherAI),
141 _ => Err(format!("Unsupported provider: {s}.").into()),
142 }
143 }
144}
145
146#[derive(Clone, Debug, Deserialize)]
147pub enum SubContent {
148 TextContent { text: String },
149 ImageUrlContent { image_url: String },
150}
151
152impl SubContent {
153 pub fn new(r#type: &str, text: &str) -> Self {
154 match r#type {
155 "text" => Self::TextContent {
156 text: text.to_string(),
157 },
158 "image_url" => Self::ImageUrlContent {
159 image_url: text.to_string(),
160 },
161 _ => panic!("Invalid subcontent type: {}", r#type),
162 }
163 }
164}
165
166impl Serialize for SubContent {
167 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
168 where
169 S: serde::Serializer,
170 {
171 match self {
172 SubContent::TextContent { text } => serializer.serialize_str(text),
173 SubContent::ImageUrlContent { image_url } => {
174 let json = serde_json::json!({
175 "type": "image_url",
176 "image_url": {
177 "url": image_url
178 }
179 });
180 json.serialize(serializer)
181 }
182 }
183 }
184}
185
/// The body of a chat [`Message`]: plain text, or a list of parts.
#[derive(Clone, Debug)]
pub enum Content {
    /// Plain text; serialized as a string.
    Text(String),
    /// Multiple parts (text and/or image URLs); serialized as a sequence of
    /// [`SubContent`] values.
    Collection(Vec<SubContent>),
}
191
192impl std::fmt::Display for Content {
193 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
194 match self {
195 Content::Text(text) => write!(f, "{}", text),
196 Content::Collection(items) => {
197 write!(f, "{items:?}")
198 }
199 }
200 }
201}
202
203impl Serialize for Content {
204 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
205 where
206 S: serde::Serializer,
207 {
208 match self {
209 Content::Text(text) => serializer.serialize_str(text),
210 Content::Collection(items) => items.serialize(serializer),
211 }
212 }
213}
214
215impl<'de> Deserialize<'de> for Content {
216 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
217 where
218 D: serde::Deserializer<'de>,
219 {
220 let value = serde_json::Value::deserialize(deserializer)?;
221 if let serde_json::Value::String(text) = value {
222 Ok(Content::Text(text))
223 } else if let serde_json::Value::Array(items) = value {
224 let subcontent = items
225 .into_iter()
226 .map(SubContent::deserialize)
227 .collect::<Result<Vec<_>, _>>()
228 .unwrap();
229 Ok(Content::Collection(subcontent))
230 } else {
231 Err(serde::de::Error::custom("Invalid content format"))
232 }
233 }
234}
235
/// A single chat message: who is speaking and what they said.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Message {
    /// The speaker's role, passed through as-is and not validated here.
    // NOTE(review): presumably values like "system" / "user" / "assistant" —
    // confirm against callers.
    pub role: String,
    /// The message body.
    pub content: Content,
}
241
242impl Message {
243 pub fn from_str(role: &str, text: &str) -> Self {
244 Self {
245 role: role.to_string(),
246 content: Content::Text(text.to_string()),
247 }
248 }
249 pub fn from_image_url(role: &str, image_url: &str) -> Self {
250 Self {
251 role: role.to_string(),
252 content: Content::Collection(vec![SubContent::ImageUrlContent {
253 image_url: image_url.to_string(),
254 }]),
255 }
256 }
257 pub fn from_image_bytes(role: &str, image_type: &str, image: &[u8]) -> Self {
258 let base64 = BASE64_STANDARD.encode(image);
259 let image_url = format!("data:image/{image_type};base64,{base64}");
260 Self::from_image_url(role, &image_url)
261 }
262}
263
264#[derive(Clone, Debug)]
265pub struct Key {
266 pub provider: Provider,
267 pub key: String,
268}
269
/// The collection of API keys available to the crate.
#[derive(Clone, Debug)]
pub struct Keys {
    /// When built by `load_keys`: at most one key per provider, in the order
    /// that `load_keys` lists providers.
    pub keys: Vec<Key>,
}
274
275impl Keys {
276 pub fn for_provider(&self, provider: &Provider) -> Option<Key> {
277 fn finder(provider: &Provider, key: &Key) -> bool {
278 match provider {
279 Provider::OpenAICompatible(_) => {
280 matches!(&key.provider, Provider::OpenAICompatible(_))
281 }
282 _ => key.provider == *provider,
283 }
284 }
285
286 self.keys.iter().find(|key| finder(provider, key)).cloned()
287 }
288}
289
/// Parses a `.env`-style file into a map from variable names to values.
///
/// A file that cannot be opened (e.g. it does not exist) yields an empty map.
/// Parsing rules:
///
/// - Blank lines, `#` comments, and lines without `=` are ignored.
/// - An optional leading `export ` is dropped.
/// - Each line is split at the first `=`, so values may themselves contain
///   `=` (e.g. base64 keys).
/// - Name and value are trimmed, and one pair of matching surrounding quotes
///   (`"` or `'`) is removed from the value.
///
/// If a name appears more than once, the last occurrence wins.
///
/// # Panics
///
/// Panics if the file opens but cannot be read as text (e.g. invalid UTF-8).
fn load_env_file(path: &str) -> HashMap<String, String> {
    // Parses one line into `(name, value)`. Returns `None` for blank lines,
    // comments, lines without `=`, and lines with an empty name.
    fn parse_line(line: &str) -> Option<(String, String)> {
        let line = line.trim();
        if line.is_empty() || line.starts_with('#') {
            return None;
        }
        // Tolerate shell-style `export NAME=value` lines.
        let line = line.strip_prefix("export ").unwrap_or(line);
        let (name, value) = line.split_once('=')?;
        let name = name.trim();
        if name.is_empty() {
            return None;
        }
        Some((name.to_string(), unquote(value.trim()).to_string()))
    }

    // Removes one pair of matching surrounding `"` or `'` quotes, if present.
    fn unquote(value: &str) -> &str {
        for quote in ['"', '\''] {
            if let Some(inner) = value
                .strip_prefix(quote)
                .and_then(|rest| rest.strip_suffix(quote))
            {
                return inner;
            }
        }
        value
    }

    let mut env_content = String::new();
    // A missing file is not an error: keys may come from the process
    // environment instead.
    if let Ok(mut file) = File::open(path) {
        file.read_to_string(&mut env_content)
            .expect("Failed to read .env file");
    }
    env_content.lines().filter_map(parse_line).collect()
}
308
309pub fn load_keys(path: &str) -> Keys {
311 let env_map = load_env_file(path);
312
313 let mut keys = vec![];
314
315 let providers = [
316 Provider::Amazon,
317 Provider::Azure,
318 Provider::Cerebras,
319 Provider::DeepInfra,
320 Provider::ElevenLabs,
321 Provider::Fireworks,
322 Provider::FriendliAI,
323 Provider::Google,
324 Provider::Groq,
325 Provider::Hyperbolic,
326 Provider::Mistral,
327 Provider::Nebius,
328 Provider::Novita,
329 Provider::OpenAI,
330 Provider::OpenAICompatible("".to_string()),
331 Provider::SambaNova,
332 Provider::TogetherAI,
333 ];
334 for provider in providers {
335 if let Ok(key_value) = std::env::var(provider.key_name()) {
336 keys.push(Key {
337 provider: provider.clone(),
338 key: key_value,
339 });
340 } else if let Some(key_value) = env_map.get(&provider.key_name()) {
341 keys.push(Key {
342 provider: provider.clone(),
343 key: key_value.to_string(),
344 });
345 }
346 }
347 Keys { keys }
348}