semantic_search_cli/
config.rs1use anyhow::Result as AnyResult;
4use std::path::Path;
5
6use semantic_search::Model;
7use serde::Deserialize;
8
9#[derive(Deserialize, Debug)]
11pub struct Config {
12 #[serde(default)]
14 pub server: Server,
15 pub api: ApiConfig,
17 #[serde(default)]
19 pub bot: BotConfig,
20}
21
22#[derive(Deserialize, Debug)]
24#[serde(default)]
25pub struct Server {
26 pub port: u16,
28}
29
30impl Default for Server {
31 fn default() -> Self {
32 Self { port: 8080 }
33 }
34}
35
36#[derive(Deserialize, Debug)]
38pub struct ApiConfig {
39 pub key: String,
41 #[serde(default)]
43 pub model: Model,
44}
45
46#[derive(Deserialize, Debug)]
48#[serde(default)]
49pub struct BotConfig {
50 pub token: String,
52 pub owner: u64,
54 pub whitelist: Vec<u64>,
56 pub sticker_set: String,
58 pub num_results: usize,
60 pub postscript: String,
62}
63
64impl Default for BotConfig {
65 fn default() -> Self {
66 Self {
67 token: String::new(),
68 owner: 0,
69 whitelist: Vec::new(),
70 num_results: 8,
71 sticker_set: "meme".to_string(),
72 postscript: String::new(),
73 }
74 }
75}
76
77fn parse_config_from_str(content: &str) -> Result<Config, toml::de::Error> {
83 toml::from_str(content)
84}
85
86pub fn parse_config<T>(path: T) -> AnyResult<Config>
92where
93 T: AsRef<Path>,
94{
95 let content = std::fs::read_to_string(path)?;
96 Ok(parse_config_from_str(&content)?)
97}
98
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `content` and checks the fields shared by most cases.
    fn test(content: &str, port: u16, key: &str, model: Model, bot_token: &str) {
        let config = parse_config_from_str(content).unwrap();
        assert_eq!(config.server.port, port);
        assert_eq!(config.api.key, key);
        assert_eq!(config.api.model, model);
        assert_eq!(config.bot.token, bot_token);
    }

    #[test]
    fn parse_config_1() {
        let content = r#"
            [server]
            port = 8081

            [api]
            key = "test_key"

            [bot]
            token = "test_token"
        "#;
        test(
            content,
            8081,
            "test_key",
            Model::BgeLargeZhV1_5,
            "test_token",
        );
    }

    #[test]
    fn parse_config_2() {
        let content = r#"
            [server]
            port = 8080

            [api]
            key = "test_key"
            model = "BAAI/bge-large-zh-v1.5"
        "#;
        test(content, 8080, "test_key", Model::BgeLargeZhV1_5, "");
    }

    #[test]
    fn parse_config_3() {
        let content = r#"
            [server]

            [api]
            key = "test_key"
            model = "BAAI/bge-large-en-v1.5"
        "#;
        test(content, 8080, "test_key", Model::BgeLargeEnV1_5, "");
    }

    #[test]
    fn parse_config_4() {
        let content = r#"
            [api]
            key = "test_key"
        "#;
        test(content, 8080, "test_key", Model::BgeLargeZhV1_5, "");
    }

    #[test]
    fn parse_config_5() {
        let content = r#"
            [server]
            port = 8081

            [api]
            key = "test_key"

            [bot]
        "#;
        test(content, 8081, "test_key", Model::BgeLargeZhV1_5, "");
    }

    /// An omitted `[bot]` table yields the documented defaults for every field.
    #[test]
    fn parse_config_bot_defaults() {
        let content = r#"
            [api]
            key = "test_key"
        "#;
        let bot = parse_config_from_str(content).unwrap().bot;
        assert_eq!(bot.token, "");
        assert_eq!(bot.owner, 0);
        assert!(bot.whitelist.is_empty());
        assert_eq!(bot.sticker_set, "meme");
        assert_eq!(bot.num_results, 8);
        assert_eq!(bot.postscript, "");
    }

    /// Every `[bot]` field is read when present.
    #[test]
    fn parse_config_bot_fields() {
        let content = r#"
            [api]
            key = "test_key"

            [bot]
            token = "test_token"
            owner = 42
            whitelist = [1, 2, 3]
            sticker_set = "cats"
            num_results = 3
            postscript = "bye"
        "#;
        let bot = parse_config_from_str(content).unwrap().bot;
        assert_eq!(bot.token, "test_token");
        assert_eq!(bot.owner, 42);
        assert_eq!(bot.whitelist, vec![1, 2, 3]);
        assert_eq!(bot.sticker_set, "cats");
        assert_eq!(bot.num_results, 3);
        assert_eq!(bot.postscript, "bye");
    }

    /// A partial `[bot]` table falls back to defaults field by field.
    #[test]
    fn parse_config_bot_partial() {
        let content = r#"
            [api]
            key = "test_key"

            [bot]
            owner = 7
        "#;
        let bot = parse_config_from_str(content).unwrap().bot;
        assert_eq!(bot.owner, 7);
        assert_eq!(bot.sticker_set, "meme");
        assert_eq!(bot.num_results, 8);
    }

    #[test]
    fn parse_config_missing_file_is_err() {
        assert!(parse_config("this/path/does/not/exist/config.toml").is_err());
    }

    #[test]
    #[should_panic(expected = "missing field `api`")]
    fn parse_config_fail_1() {
        let content = r"
            [server]
            port = 8080
        ";
        test(content, 8080, "test_key", Model::BgeLargeZhV1_5, "");
    }

    #[test]
    #[should_panic(expected = "missing field `key`")]
    fn parse_config_fail_2() {
        let content = r"
            [api]
        ";
        test(content, 8080, "test_key", Model::BgeLargeZhV1_5, "");
    }
}