rust_ai/utils/config.rs
1//!
2//! # Configuration
3//!
//! Configuration-related types and functions.
5//!
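//! Note: the configuration is read from `config.yml`, which must be located in the
//! current working directory, unless the `RUST_AI_CONFIG` environment variable is
//! set; in that case the variable's value itself is parsed as the YAML configuration.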
7//!
8//! ## Example
9//!
10//! ```yaml
//! openai:
//!   api_key: sk-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
//! azure:
//!   speech:
//!     key: 4c7eXXXXXXXXXXXXXXXXXXXXXXX54c32
//!     region: westus
17//! ```
18
19////////////////////////////////////////////////////////////////////////////////
20
21use serde_yaml;
22use std::{fs::read_to_string, path::PathBuf};
23
24/// Configurations from `config.yml`
25///
26/// Example contents:
27/// ```yaml
28/// openai:
29/// api_key: sk-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
30/// base_endpoint: ""
31/// azure:
32/// speech:
33/// key: 4c7eXXXXXXXXXXXXXXXXXXXXXXX54c32
34/// region: westus
35/// ```
36///
37/// # Examples
38///
39/// ```rust
40/// use rust_ai::utils::config::Config;
41///
42/// let config = Config::load();
43/// ```
44#[derive(Debug, serde::Serialize, serde::Deserialize, Default)]
45pub struct Config {
46 /// OpenAI config mappping
47 pub openai: OpenAi,
48
49 /// Azure config mapping
50 pub azure: Azure,
51}
52
53impl Config {
54 /// Load contents from local config file.
55 pub fn load() -> Result<Self, Box<dyn std::error::Error>> {
56 if let Ok(config_contents) = std::env::var("RUST_AI_CONFIG") {
57 return match serde_yaml::from_str(&config_contents) {
58 Ok(config) => Ok(config),
59 Err(e) => {
60 log::error!(target: "global", "Unable to parse config: {:?}", e);
61 Err(e.into())
62 }
63 };
64 } else {
65 let config_path = PathBuf::from("config.yml");
66 if !config_path.exists() {
67 return Err("`config.yml` doesn't exist!".into());
68 }
69
70 return if let Ok(config_contents) = read_to_string(config_path) {
71 match serde_yaml::from_str(&config_contents) {
72 Ok(config) => Ok(config),
73 Err(e) => {
74 log::error!(target: "global", "Unable to parse config: {:?}", e);
75 Err(e.into())
76 }
77 }
78 } else {
79 Err("Unable to read `config.yml`".into())
80 };
81 }
82 }
83}
84
85/// A mapping for OpenAI configuration contents
86#[derive(Debug, serde::Serialize, serde::Deserialize, Default)]
87pub struct OpenAi {
88 /// API key obtained from <https://openai.com>.
89 pub api_key: String,
90
91 /// OpenAI Organization ID
92 pub org_id: Option<String>,
93
94 /// Alternative base endpoint for OpenAI.
95 pub base_endpoint: Option<String>,
96}
97
98impl OpenAi {
99 pub fn base_endpoint(&self) -> String {
100 self.base_endpoint
101 .clone()
102 .unwrap_or("https://api.openai.com".to_string())
103 }
104}
105
106/// A mapping for Azure (Global) configuration contents
107#[derive(Debug, serde::Serialize, serde::Deserialize, Default)]
108pub struct Azure {
109 /// Configuration content for cognitive/speech.
110 pub speech: AzureSpeech,
111}
112
113/// Service key for use in multiple Azure services.
114#[derive(Debug, serde::Serialize, serde::Deserialize, Default)]
115pub struct AzureSpeech {
116 /// Key content from <https://portal.azure.com/>
117 pub key: String,
118
119 /// Region name from <https://portal.azure.com/>
120 pub region: String,
121}