semantic_search_cli/
config.rs

1//! Configuration file parser.
2
3use anyhow::Result as AnyResult;
4use std::path::Path;
5
6use semantic_search::Model;
7use serde::Deserialize;
8
9/// Structure of the configuration file.
10#[derive(Deserialize, Debug)]
11pub struct Config {
12    /// Server configuration.
13    #[serde(default)]
14    pub server: Server,
15    /// API configuration.
16    pub api: ApiConfig,
17    /// Telegram bot configuration.
18    #[serde(default)]
19    pub bot: BotConfig,
20}
21
22/// Server configuration.
23#[derive(Deserialize, Debug)]
24#[serde(default)]
25pub struct Server {
26    /// Port for the server. Default is 8080.
27    pub port: u16,
28}
29
30impl Default for Server {
31    fn default() -> Self {
32        Self { port: 8080 }
33    }
34}
35
36/// API configuration.
37#[derive(Deserialize, Debug)]
38pub struct ApiConfig {
39    /// API key for Silicon Cloud.
40    pub key: String,
41    /// Model to use for embedding.
42    #[serde(default)]
43    pub model: Model,
44}
45
46/// Telegram bot configuration.
47#[derive(Deserialize, Debug)]
48#[serde(default)]
49pub struct BotConfig {
50    /// Telegram bot token.
51    pub token: String,
52    /// Telegram user ID of the bot owner.
53    pub owner: u64,
54    /// Whitelisted user IDs.
55    pub whitelist: Vec<u64>,
56    /// Sticker set id prefix for the bot.
57    pub sticker_set: String,
58    /// Number of results to return.
59    pub num_results: usize,
60    /// Postscript to be appended after the help message.
61    pub postscript: String,
62}
63
64impl Default for BotConfig {
65    fn default() -> Self {
66        Self {
67            token: String::new(),
68            owner: 0,
69            whitelist: Vec::new(),
70            num_results: 8,
71            sticker_set: "meme".to_string(),
72            postscript: String::new(),
73        }
74    }
75}
76
77/// Parse the configuration into a `Config` structure.
78///
79/// # Errors
80///
81/// Returns an [`Error`](toml::de::Error) if the configuration file is not valid, like missing fields.
82fn parse_config_from_str(content: &str) -> Result<Config, toml::de::Error> {
83    toml::from_str(content)
84}
85
86/// Parse the configuration file into a `Config` structure.
87///
88/// # Errors
89///
90/// Returns an [IO error](std::io::Error) if reading fails, or a [TOML error](toml::de::Error) if parsing fails.
91pub fn parse_config<T>(path: T) -> AnyResult<Config>
92where
93    T: AsRef<Path>,
94{
95    let content = std::fs::read_to_string(path)?;
96    Ok(parse_config_from_str(&content)?)
97}
98
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `toml_text` and compare the server port, API key, model and bot
    /// token against the expected values.
    ///
    /// Parsing failures surface as an `unwrap` panic carrying the TOML error,
    /// which is what the `should_panic` cases below match against.
    fn assert_parsed(
        toml_text: &str,
        expected_port: u16,
        expected_key: &str,
        expected_model: Model,
        expected_token: &str,
    ) {
        let parsed = parse_config_from_str(toml_text).unwrap();
        assert_eq!(parsed.server.port, expected_port);
        assert_eq!(parsed.api.key, expected_key);
        assert_eq!(parsed.api.model, expected_model);
        assert_eq!(parsed.bot.token, expected_token);
    }

    /// All three tables present; the model is left to its default.
    #[test]
    fn parse_config_1() {
        let toml_text = r#"
            [server]
            port = 8081

            [api]
            key = "test_key"

            [bot]
            token = "test_token"
        "#;
        assert_parsed(
            toml_text,
            8081,
            "test_key",
            Model::BgeLargeZhV1_5,
            "test_token",
        );
    }

    /// Explicit model, no `[bot]` table.
    #[test]
    fn parse_config_2() {
        let toml_text = r#"
            [server]
            port = 8080

            [api]
            key = "test_key"
            model = "BAAI/bge-large-zh-v1.5"
        "#;
        assert_parsed(toml_text, 8080, "test_key", Model::BgeLargeZhV1_5, "");
    }

    /// Empty `[server]` table falls back to the default port.
    #[test]
    fn parse_config_3() {
        let toml_text = r#"
            [server]

            [api]
            key = "test_key"
            model = "BAAI/bge-large-en-v1.5"
        "#;
        assert_parsed(toml_text, 8080, "test_key", Model::BgeLargeEnV1_5, "");
    }

    /// Only the mandatory `[api]` table is given.
    #[test]
    fn parse_config_4() {
        let toml_text = r#"
            [api]
            key = "test_key"
        "#;
        assert_parsed(toml_text, 8080, "test_key", Model::BgeLargeZhV1_5, "");
    }

    /// Empty `[bot]` table yields an empty token.
    #[test]
    fn parse_config_5() {
        let toml_text = r#"
            [server]
            port = 8081

            [api]
            key = "test_key"

            [bot]
        "#;
        assert_parsed(toml_text, 8081, "test_key", Model::BgeLargeZhV1_5, "");
    }

    /// Omitting the `[api]` table is rejected.
    #[test]
    #[should_panic(expected = "missing field `api`")]
    fn parse_config_fail_1() {
        let toml_text = r"
            [server]
            port = 8080
        ";
        assert_parsed(toml_text, 8080, "test_key", Model::BgeLargeZhV1_5, "");
    }

    /// An `[api]` table without `key` is rejected.
    #[test]
    #[should_panic(expected = "missing field `key`")]
    fn parse_config_fail_2() {
        let toml_text = r"
            [api]
        ";
        assert_parsed(toml_text, 8080, "test_key", Model::BgeLargeZhV1_5, "");
    }
}