libsubconverter/settings/toml_deserializer.rs

1use serde::de::{MapAccess, SeqAccess, Visitor};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::fmt;
5
6use crate::models::{
7    cron::CronTaskConfig, BalanceStrategy, ProxyGroupConfig, ProxyGroupType, RegexMatchConfig,
8    RulesetConfig,
9};
10use crate::settings::settings::toml_settings::TemplateSettings;
11
12pub trait ImportableInToml: serde::de::DeserializeOwned + Clone {
13    fn is_import_node(&self) -> bool;
14    fn get_import_path(&self) -> Option<String>;
15    fn try_from_toml_value(value: &toml::Value) -> Result<Self, Box<dyn std::error::Error>> {
16        Ok(value.clone().try_into()?)
17    }
18}
19
20/// Stream rule configuration
21#[derive(Debug, Clone, Serialize, Deserialize, Default)]
22#[serde(default)]
23pub struct RegexMatchRuleInToml {
24    #[serde(rename = "match")]
25    pub match_str: Option<String>,
26
27    #[serde(alias = "emoji")]
28    pub replace: Option<String>,
29    pub script: Option<String>,
30    pub import: Option<String>,
31}
32
33impl Into<RegexMatchConfig> for RegexMatchRuleInToml {
34    fn into(self) -> RegexMatchConfig {
35        let mut config = RegexMatchConfig::new(
36            self.match_str.unwrap_or_default(),
37            self.replace.unwrap_or_default(),
38            self.script.unwrap_or_default(),
39        );
40        config.compile();
41        config
42    }
43}
44
45impl ImportableInToml for RegexMatchRuleInToml {
46    fn is_import_node(&self) -> bool {
47        self.import.is_some()
48    }
49
50    fn get_import_path(&self) -> Option<String> {
51        self.import.clone()
52    }
53}
54
55/// Ruleset configuration
56#[derive(Debug, Clone, Serialize, Deserialize, Default)]
57#[serde(default)]
58pub struct RulesetConfigInToml {
59    pub group: String,
60    pub ruleset: Option<String>,
61    #[serde(rename = "type")]
62    pub ruleset_type: Option<String>,
63    pub interval: Option<u32>,
64    pub import: Option<String>,
65}
66
67impl ImportableInToml for RulesetConfigInToml {
68    fn is_import_node(&self) -> bool {
69        self.import.is_some()
70    }
71
72    fn get_import_path(&self) -> Option<String> {
73        self.import.clone()
74    }
75}
76
77impl Into<RulesetConfig> for RulesetConfigInToml {
78    fn into(self) -> RulesetConfig {
79        RulesetConfig {
80            url: self.ruleset.unwrap_or_default(),
81            group: self.group,
82            interval: self.interval.unwrap_or(300),
83        }
84    }
85}
86
/// Serde default for `ProxyGroupConfigInToml::url` when the key is absent.
fn default_test_url() -> Option<String> {
    let url = String::from("http://www.gstatic.com/generate_204");
    Some(url)
}
90
/// Serde default for `ProxyGroupConfigInToml::interval` when the key is absent.
fn default_interval() -> Option<u32> {
    Option::from(300_u32)
}
94
95/// Proxy group configuration
96#[derive(Debug, Clone, Serialize, Deserialize, Default)]
97#[serde(default)]
98pub struct ProxyGroupConfigInToml {
99    pub name: String,
100    #[serde(rename = "type")]
101    pub group_type: String,
102    pub strategy: Option<String>,
103    pub rule: Vec<String>,
104    #[serde(default = "default_test_url")]
105    pub url: Option<String>,
106    #[serde(default = "default_interval")]
107    pub interval: Option<u32>,
108    pub lazy: Option<bool>,
109    pub tolerance: Option<u32>,
110    pub timeout: Option<u32>,
111    pub disable_udp: Option<bool>,
112    pub import: Option<String>,
113}
114
115impl ImportableInToml for ProxyGroupConfigInToml {
116    fn is_import_node(&self) -> bool {
117        self.import.is_some()
118    }
119
120    fn get_import_path(&self) -> Option<String> {
121        self.import.clone()
122    }
123}
124
125impl Into<ProxyGroupConfig> for ProxyGroupConfigInToml {
126    fn into(self) -> ProxyGroupConfig {
127        let group_type = match self.group_type.as_str() {
128            "select" => ProxyGroupType::Select,
129            "url-test" => ProxyGroupType::URLTest,
130            "load-balance" => ProxyGroupType::LoadBalance,
131            "fallback" => ProxyGroupType::Fallback,
132            "relay" => ProxyGroupType::Relay,
133            "ssid" => ProxyGroupType::SSID,
134            "smart" => ProxyGroupType::Smart,
135            _ => ProxyGroupType::Select, // 默认为 Select
136        };
137
138        // 处理 strategy 字段
139        let strategy = match self.strategy.as_deref() {
140            Some("consistent-hashing") => BalanceStrategy::ConsistentHashing,
141            Some("round-robin") => BalanceStrategy::RoundRobin,
142            _ => BalanceStrategy::ConsistentHashing,
143        };
144
145        // 创建基本的 ProxyGroupConfig
146        let mut config = ProxyGroupConfig {
147            name: self.name,
148            group_type,
149            proxies: self.rule,
150            url: self.url.unwrap_or_default(),
151            interval: self.interval.unwrap_or(300),
152            tolerance: self.tolerance.unwrap_or(0),
153            timeout: self.timeout.unwrap_or(5),
154            lazy: self.lazy.unwrap_or(false),
155            disable_udp: self.disable_udp.unwrap_or(false),
156            strategy,
157            // 添加缺失的字段
158            persistent: false,
159            evaluate_before_use: false,
160            using_provider: Vec::new(),
161        };
162
163        // 根据不同的代理组类型设置特定属性
164        match config.group_type {
165            ProxyGroupType::URLTest | ProxyGroupType::Smart => {
166                // 这些类型需要 URL 和 interval
167                if config.url.is_empty() {
168                    config.url = "http://www.gstatic.com/generate_204".to_string();
169                }
170            }
171            ProxyGroupType::LoadBalance => {
172                // 负载均衡需要 URL、interval 和 strategy
173                if config.url.is_empty() {
174                    config.url = "http://www.gstatic.com/generate_204".to_string();
175                }
176            }
177            ProxyGroupType::Fallback => {
178                // 故障转移需要 URL 和 interval
179                if config.url.is_empty() {
180                    config.url = "http://www.gstatic.com/generate_204".to_string();
181                }
182            }
183            _ => {}
184        }
185
186        config
187    }
188}
189
190/// Task configuration
191#[derive(Debug, Clone, Serialize, Deserialize, Default)]
192#[serde(default)]
193pub struct TaskConfigInToml {
194    pub name: String,
195    pub cronexp: String,
196    pub path: String,
197    pub timeout: u32,
198    pub import: Option<String>,
199}
200
201impl ImportableInToml for TaskConfigInToml {
202    fn is_import_node(&self) -> bool {
203        self.import.is_some()
204    }
205
206    fn get_import_path(&self) -> Option<String> {
207        self.import.clone()
208    }
209}
210
211impl Into<CronTaskConfig> for TaskConfigInToml {
212    fn into(self) -> CronTaskConfig {
213        CronTaskConfig {
214            name: self.name,
215            cron_exp: self.cronexp,
216            path: self.path,
217            timeout: self.timeout,
218        }
219    }
220}
221
222pub fn deserialize_template_as_template_settings<'de, D>(
223    deserializer: D,
224) -> Result<TemplateSettings, D::Error>
225where
226    D: serde::Deserializer<'de>,
227{
228    struct TemplateSettingsVisitor;
229
230    impl<'de> Visitor<'de> for TemplateSettingsVisitor {
231        type Value = TemplateSettings;
232
233        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
234            formatter.write_str("a TemplateSettings struct")
235        }
236
237        fn visit_map<V>(self, mut map: V) -> Result<Self::Value, V::Error>
238        where
239            V: MapAccess<'de>,
240        {
241            let mut template_settings = TemplateSettings::default();
242            while let Some(key) = map.next_key::<String>()? {
243                let value = map.next_value::<String>()?;
244                if key == "template_path" {
245                    template_settings.template_path = value.clone();
246                } else {
247                    template_settings.globals.insert(key, value);
248                }
249            }
250            Ok(template_settings)
251        }
252    }
253
254    deserializer.deserialize_any(TemplateSettingsVisitor)
255}
256
257/// Template argument structure for deserialization
258#[derive(Debug, Clone, Deserialize, Default)]
259struct TemplateArgument {
260    pub key: String,
261    pub value: String,
262}
263
264pub fn deserialize_template_args_as_hash_map<'de, D>(
265    deserializer: D,
266) -> Result<Option<HashMap<String, String>>, D::Error>
267where
268    D: serde::Deserializer<'de>,
269{
270    struct TemplateArgsVisitor;
271
272    impl<'de> Visitor<'de> for TemplateArgsVisitor {
273        type Value = Option<HashMap<String, String>>;
274
275        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
276            formatter.write_str("a sequence of template arguments or a map of key-value pairs")
277        }
278
279        fn visit_seq<S>(self, mut seq: S) -> Result<Self::Value, S::Error>
280        where
281            S: SeqAccess<'de>,
282        {
283            let mut template_args = HashMap::new();
284
285            while let Some(item) = seq.next_element::<TemplateArgument>()? {
286                template_args.insert(item.key, item.value);
287            }
288
289            if template_args.is_empty() {
290                Ok(None)
291            } else {
292                Ok(Some(template_args))
293            }
294        }
295
296        fn visit_none<E>(self) -> Result<Self::Value, E>
297        where
298            E: serde::de::Error,
299        {
300            Ok(None)
301        }
302
303        fn visit_unit<E>(self) -> Result<Self::Value, E>
304        where
305            E: serde::de::Error,
306        {
307            Ok(None)
308        }
309
310        fn visit_map<M>(self, mut map: M) -> Result<Self::Value, M::Error>
311        where
312            M: MapAccess<'de>,
313        {
314            let mut template_args = HashMap::new();
315
316            while let Some((key, value)) = map.next_entry::<String, String>()? {
317                template_args.insert(key, value);
318            }
319
320            if template_args.is_empty() {
321                Ok(None)
322            } else {
323                Ok(Some(template_args))
324            }
325        }
326    }
327
328    deserializer.deserialize_any(TemplateArgsVisitor)
329}