// libsubconverter/settings/toml_deserializer.rs

1use serde::de::{MapAccess, SeqAccess, Visitor};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::fmt;
5
6use crate::models::{
7    cron::CronTaskConfig, BalanceStrategy, ProxyGroupConfig, ProxyGroupType, RegexMatchConfig,
8    RulesetConfig,
9};
10use crate::settings::settings::toml_settings::TemplateSettings;
11
12pub trait ImportableInToml: serde::de::DeserializeOwned + Clone {
13    fn is_import_node(&self) -> bool;
14    fn get_import_path(&self) -> Option<String>;
15    fn try_from_toml_value(value: &toml::Value) -> Result<Self, Box<dyn std::error::Error>> {
16        Ok(value.clone().try_into()?)
17    }
18}
19
20/// Stream rule configuration
21#[derive(Debug, Clone, Serialize, Deserialize, Default)]
22#[serde(default)]
23pub struct RegexMatchRuleInToml {
24    #[serde(rename = "match")]
25    pub match_str: Option<String>,
26
27    #[serde(alias = "emoji")]
28    pub replace: Option<String>,
29    pub script: Option<String>,
30    pub import: Option<String>,
31}
32
33impl Into<RegexMatchConfig> for RegexMatchRuleInToml {
34    fn into(self) -> RegexMatchConfig {
35        let mut config = RegexMatchConfig::new(
36            self.match_str.unwrap_or_default(),
37            self.replace.unwrap_or_default(),
38        );
39        config.compile();
40        config
41    }
42}
43
44impl ImportableInToml for RegexMatchRuleInToml {
45    fn is_import_node(&self) -> bool {
46        self.import.is_some()
47    }
48
49    fn get_import_path(&self) -> Option<String> {
50        self.import.clone()
51    }
52}
53
54/// Ruleset configuration
55#[derive(Debug, Clone, Serialize, Deserialize, Default)]
56#[serde(default)]
57pub struct RulesetConfigInToml {
58    pub group: String,
59    pub ruleset: Option<String>,
60    #[serde(rename = "type")]
61    pub ruleset_type: Option<String>,
62    pub interval: Option<u32>,
63    pub import: Option<String>,
64}
65
66impl ImportableInToml for RulesetConfigInToml {
67    fn is_import_node(&self) -> bool {
68        self.import.is_some()
69    }
70
71    fn get_import_path(&self) -> Option<String> {
72        self.import.clone()
73    }
74}
75
76impl Into<RulesetConfig> for RulesetConfigInToml {
77    fn into(self) -> RulesetConfig {
78        RulesetConfig {
79            url: self.ruleset.unwrap_or_default(),
80            group: self.group,
81            interval: self.interval.unwrap_or(300),
82        }
83    }
84}
85
/// Serde default for `ProxyGroupConfigInToml::url`: the generate_204 probe URL.
fn default_test_url() -> Option<String> {
    let probe = String::from("http://www.gstatic.com/generate_204");
    Some(probe)
}
89
/// Serde default for `ProxyGroupConfigInToml::interval`.
fn default_interval() -> Option<u32> {
    let interval: u32 = 300;
    Some(interval)
}
93
94/// Proxy group configuration
95#[derive(Debug, Clone, Serialize, Deserialize, Default)]
96#[serde(default)]
97pub struct ProxyGroupConfigInToml {
98    pub name: String,
99    #[serde(rename = "type")]
100    pub group_type: String,
101    pub strategy: Option<String>,
102    pub rule: Vec<String>,
103    #[serde(default = "default_test_url")]
104    pub url: Option<String>,
105    #[serde(default = "default_interval")]
106    pub interval: Option<u32>,
107    pub lazy: Option<bool>,
108    pub tolerance: Option<u32>,
109    pub timeout: Option<u32>,
110    pub disable_udp: Option<bool>,
111    pub import: Option<String>,
112}
113
114impl ImportableInToml for ProxyGroupConfigInToml {
115    fn is_import_node(&self) -> bool {
116        self.import.is_some()
117    }
118
119    fn get_import_path(&self) -> Option<String> {
120        self.import.clone()
121    }
122}
123
124impl Into<ProxyGroupConfig> for ProxyGroupConfigInToml {
125    fn into(self) -> ProxyGroupConfig {
126        let group_type = match self.group_type.as_str() {
127            "select" => ProxyGroupType::Select,
128            "url-test" => ProxyGroupType::URLTest,
129            "load-balance" => ProxyGroupType::LoadBalance,
130            "fallback" => ProxyGroupType::Fallback,
131            "relay" => ProxyGroupType::Relay,
132            "ssid" => ProxyGroupType::SSID,
133            "smart" => ProxyGroupType::Smart,
134            _ => ProxyGroupType::Select, // 默认为 Select
135        };
136
137        // 处理 strategy 字段
138        let strategy = match self.strategy.as_deref() {
139            Some("consistent-hashing") => BalanceStrategy::ConsistentHashing,
140            Some("round-robin") => BalanceStrategy::RoundRobin,
141            _ => BalanceStrategy::ConsistentHashing,
142        };
143
144        // 创建基本的 ProxyGroupConfig
145        let mut config = ProxyGroupConfig {
146            name: self.name,
147            group_type,
148            proxies: self.rule,
149            url: self.url.unwrap_or_default(),
150            interval: self.interval.unwrap_or(300),
151            tolerance: self.tolerance.unwrap_or(0),
152            timeout: self.timeout.unwrap_or(5),
153            lazy: self.lazy.unwrap_or(false),
154            disable_udp: self.disable_udp.unwrap_or(false),
155            strategy,
156            // 添加缺失的字段
157            persistent: false,
158            evaluate_before_use: false,
159            using_provider: Vec::new(),
160        };
161
162        // 根据不同的代理组类型设置特定属性
163        match config.group_type {
164            ProxyGroupType::URLTest | ProxyGroupType::Smart => {
165                // 这些类型需要 URL 和 interval
166                if config.url.is_empty() {
167                    config.url = "http://www.gstatic.com/generate_204".to_string();
168                }
169            }
170            ProxyGroupType::LoadBalance => {
171                // 负载均衡需要 URL、interval 和 strategy
172                if config.url.is_empty() {
173                    config.url = "http://www.gstatic.com/generate_204".to_string();
174                }
175            }
176            ProxyGroupType::Fallback => {
177                // 故障转移需要 URL 和 interval
178                if config.url.is_empty() {
179                    config.url = "http://www.gstatic.com/generate_204".to_string();
180                }
181            }
182            _ => {}
183        }
184
185        config
186    }
187}
188
189/// Task configuration
190#[derive(Debug, Clone, Serialize, Deserialize, Default)]
191#[serde(default)]
192pub struct TaskConfigInToml {
193    pub name: String,
194    pub cronexp: String,
195    pub path: String,
196    pub timeout: u32,
197    pub import: Option<String>,
198}
199
200impl ImportableInToml for TaskConfigInToml {
201    fn is_import_node(&self) -> bool {
202        self.import.is_some()
203    }
204
205    fn get_import_path(&self) -> Option<String> {
206        self.import.clone()
207    }
208}
209
210impl Into<CronTaskConfig> for TaskConfigInToml {
211    fn into(self) -> CronTaskConfig {
212        CronTaskConfig {
213            name: self.name,
214            cron_exp: self.cronexp,
215            path: self.path,
216            timeout: self.timeout,
217        }
218    }
219}
220
221pub fn deserialize_template_as_template_settings<'de, D>(
222    deserializer: D,
223) -> Result<TemplateSettings, D::Error>
224where
225    D: serde::Deserializer<'de>,
226{
227    struct TemplateSettingsVisitor;
228
229    impl<'de> Visitor<'de> for TemplateSettingsVisitor {
230        type Value = TemplateSettings;
231
232        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
233            formatter.write_str("a TemplateSettings struct")
234        }
235
236        fn visit_map<V>(self, mut map: V) -> Result<Self::Value, V::Error>
237        where
238            V: MapAccess<'de>,
239        {
240            let mut template_settings = TemplateSettings::default();
241            while let Some(key) = map.next_key::<String>()? {
242                let value = map.next_value::<String>()?;
243                if key == "template_path" {
244                    template_settings.template_path = value.clone();
245                } else {
246                    template_settings.globals.insert(key, value);
247                }
248            }
249            Ok(template_settings)
250        }
251    }
252
253    deserializer.deserialize_any(TemplateSettingsVisitor)
254}
255
256/// Template argument structure for deserialization
257#[derive(Debug, Clone, Deserialize, Default)]
258struct TemplateArgument {
259    pub key: String,
260    pub value: String,
261}
262
263pub fn deserialize_template_args_as_hash_map<'de, D>(
264    deserializer: D,
265) -> Result<Option<HashMap<String, String>>, D::Error>
266where
267    D: serde::Deserializer<'de>,
268{
269    struct TemplateArgsVisitor;
270
271    impl<'de> Visitor<'de> for TemplateArgsVisitor {
272        type Value = Option<HashMap<String, String>>;
273
274        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
275            formatter.write_str("a sequence of template arguments or a map of key-value pairs")
276        }
277
278        fn visit_seq<S>(self, mut seq: S) -> Result<Self::Value, S::Error>
279        where
280            S: SeqAccess<'de>,
281        {
282            let mut template_args = HashMap::new();
283
284            while let Some(item) = seq.next_element::<TemplateArgument>()? {
285                template_args.insert(item.key, item.value);
286            }
287
288            if template_args.is_empty() {
289                Ok(None)
290            } else {
291                Ok(Some(template_args))
292            }
293        }
294
295        fn visit_none<E>(self) -> Result<Self::Value, E>
296        where
297            E: serde::de::Error,
298        {
299            Ok(None)
300        }
301
302        fn visit_unit<E>(self) -> Result<Self::Value, E>
303        where
304            E: serde::de::Error,
305        {
306            Ok(None)
307        }
308
309        fn visit_map<M>(self, mut map: M) -> Result<Self::Value, M::Error>
310        where
311            M: MapAccess<'de>,
312        {
313            let mut template_args = HashMap::new();
314
315            while let Some((key, value)) = map.next_entry::<String, String>()? {
316                template_args.insert(key, value);
317            }
318
319            if template_args.is_empty() {
320                Ok(None)
321            } else {
322                Ok(Some(template_args))
323            }
324        }
325    }
326
327    deserializer.deserialize_any(TemplateArgsVisitor)
328}