// semantic_search_cli/config.rs
use std::path::Path;

use anyhow::{Context, Result as AnyResult};
use semantic_search::Model;
use serde::Deserialize;

9#[derive(Deserialize, Debug)]
11pub struct Config {
12 #[serde(default)]
14 pub server: Server,
15 pub api: ApiConfig,
17 #[serde(default)]
19 pub bot: BotConfig,
20}
21
22#[derive(Deserialize, Debug)]
24pub struct Server {
25 #[serde(default = "defaults::server_port")]
27 pub port: u16,
28}
29
30impl Default for Server {
31 fn default() -> Self {
32 Self {
33 port: defaults::server_port(),
34 }
35 }
36}
37
38#[derive(Deserialize, Debug)]
40pub struct ApiConfig {
41 pub key: String,
43 #[serde(default)]
45 pub model: Model,
46}
47
48#[derive(Deserialize, Debug)]
50pub struct BotConfig {
51 #[serde(default)]
53 pub token: String,
54 #[serde(default)]
56 pub owner: u64,
57 #[serde(default)]
59 pub whitelist: Vec<u64>,
60 #[serde(default = "defaults::sticker_set")]
62 pub sticker_set: String,
63 #[serde(default = "defaults::num_results")]
65 pub num_results: usize,
66}
67
68impl Default for BotConfig {
69 fn default() -> Self {
70 Self {
71 token: String::new(),
72 owner: 0,
73 whitelist: Vec::new(),
74 num_results: defaults::num_results(),
75 sticker_set: defaults::sticker_set(),
76 }
77 }
78}
79
80fn parse_config_from_str(content: &str) -> Result<Config, toml::de::Error> {
86 toml::from_str(content)
87}
88
89pub fn parse_config<T>(path: T) -> AnyResult<Config>
95where
96 T: AsRef<Path>,
97{
98 let content = std::fs::read_to_string(path)?;
99 Ok(parse_config_from_str(&content)?)
100}
101
/// Fallback values for optional configuration keys, referenced by the serde
/// `default = "..."` attributes and the `Default` impls.
mod defaults {
    const SERVER_PORT: u16 = 8080;
    const NUM_RESULTS: usize = 8;
    const STICKER_SET: &str = "meme";

    /// Default for `server.port`.
    pub const fn server_port() -> u16 {
        SERVER_PORT
    }

    /// Default for `bot.num_results`.
    pub const fn num_results() -> usize {
        NUM_RESULTS
    }

    /// Default for `bot.sticker_set`.
    pub fn sticker_set() -> String {
        String::from(STICKER_SET)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `content` and checks the fields shared by most scenarios below.
    fn test(content: &str, port: u16, key: &str, model: Model, bot_token: &str) {
        let config = parse_config_from_str(content).unwrap();
        assert_eq!(config.server.port, port);
        assert_eq!(config.api.key, key);
        assert_eq!(config.api.model, model);
        assert_eq!(config.bot.token, bot_token);
    }

    /// Asserts that every `[bot]` field holds its default value.
    fn assert_bot_defaults(bot: &BotConfig) {
        assert_eq!(bot.token, "");
        assert_eq!(bot.owner, 0);
        assert!(bot.whitelist.is_empty());
        assert_eq!(bot.sticker_set, defaults::sticker_set());
        assert_eq!(bot.num_results, defaults::num_results());
    }

    #[test]
    fn parse_config_1() {
        let content = r#"
        [server]
        port = 8081

        [api]
        key = "test_key"

        [bot]
        token = "test_token"
        "#;
        test(
            content,
            8081,
            "test_key",
            Model::BgeLargeZhV1_5,
            "test_token",
        );
    }

    #[test]
    fn parse_config_2() {
        let content = r#"
        [server]
        port = 8080

        [api]
        key = "test_key"
        model = "BAAI/bge-large-zh-v1.5"
        "#;
        test(content, 8080, "test_key", Model::BgeLargeZhV1_5, "");
    }

    #[test]
    fn parse_config_3() {
        let content = r#"
        [server]

        [api]
        key = "test_key"
        model = "BAAI/bge-large-en-v1.5"
        "#;
        test(content, 8080, "test_key", Model::BgeLargeEnV1_5, "");
    }

    #[test]
    fn parse_config_4() {
        let content = r#"
        [api]
        key = "test_key"
        "#;
        test(content, 8080, "test_key", Model::BgeLargeZhV1_5, "");
    }

    #[test]
    fn parse_config_5() {
        let content = r#"
        [server]
        port = 8081

        [api]
        key = "test_key"

        [bot]
        "#;
        test(content, 8081, "test_key", Model::BgeLargeZhV1_5, "");
    }

    /// A missing `[bot]` table goes through `BotConfig::default()`, while an
    /// empty `[bot]` table goes through the per-field serde defaults. The two
    /// paths are written separately, so pin that they agree.
    #[test]
    fn parse_config_bot_defaults_agree() {
        let missing = parse_config_from_str("[api]\nkey = \"k\"\n").unwrap();
        let empty = parse_config_from_str("[api]\nkey = \"k\"\n\n[bot]\n").unwrap();
        assert_bot_defaults(&missing.bot);
        assert_bot_defaults(&empty.bot);
        assert_eq!(missing.server.port, defaults::server_port());
    }

    #[test]
    fn parse_config_bot_explicit() {
        let content = r#"
        [api]
        key = "test_key"

        [bot]
        token = "t"
        owner = 42
        whitelist = [1, 2, 3]
        sticker_set = "cats"
        num_results = 3
        "#;
        let config = parse_config_from_str(content).unwrap();
        assert_eq!(config.bot.token, "t");
        assert_eq!(config.bot.owner, 42);
        assert_eq!(config.bot.whitelist, vec![1_u64, 2, 3]);
        assert_eq!(config.bot.sticker_set, "cats");
        assert_eq!(config.bot.num_results, 3);
    }

    #[test]
    #[should_panic(expected = "missing field `api`")]
    fn parse_config_fail_1() {
        let content = r"
        [server]
        port = 8080
        ";
        test(content, 8080, "test_key", Model::BgeLargeZhV1_5, "");
    }

    #[test]
    #[should_panic(expected = "missing field `key`")]
    fn parse_config_fail_2() {
        let content = r"
        [api]
        ";
        test(content, 8080, "test_key", Model::BgeLargeZhV1_5, "");
    }
}