taplo_common/
util.rs

1use globset::{Glob, GlobSetBuilder};
2use percent_encoding::percent_decode_str;
3use serde_json::Value;
4use std::{
5    borrow::Cow,
6    hash::{Hash, Hasher},
7    path::{Path, PathBuf},
8    sync::Arc,
9};
10
11#[derive(Debug, Clone)]
12pub struct GlobRule {
13    include: globset::GlobSet,
14    exclude: globset::GlobSet,
15}
16
17impl GlobRule {
18    pub fn new(
19        include: impl IntoIterator<Item = impl AsRef<str>>,
20        exclude: impl IntoIterator<Item = impl AsRef<str>>,
21    ) -> Result<Self, anyhow::Error> {
22        let mut inc = GlobSetBuilder::new();
23        for glob in include {
24            inc.add(Glob::new(glob.as_ref())?);
25        }
26
27        let mut exc = GlobSetBuilder::new();
28        for glob in exclude {
29            exc.add(Glob::new(glob.as_ref())?);
30        }
31
32        Ok(Self {
33            include: inc.build()?,
34            exclude: exc.build()?,
35        })
36    }
37
38    pub fn is_match(&self, text: impl AsRef<Path>) -> bool {
39        if !self.include.is_match(text.as_ref()) {
40            return false;
41        }
42
43        !self.exclude.is_match(text.as_ref())
44    }
45}
46
47#[derive(Eq)]
48pub struct ArcHashValue(pub Arc<Value>);
49
50impl Hash for ArcHashValue {
51    fn hash<H: Hasher>(&self, state: &mut H) {
52        HashValue(&self.0).hash(state);
53    }
54}
55
56impl PartialEq for ArcHashValue {
57    fn eq(&self, other: &Self) -> bool {
58        self.0 == other.0
59    }
60}
61
62#[derive(Eq)]
63pub struct HashValue<'v>(pub &'v Value);
64
65impl PartialEq for HashValue<'_> {
66    fn eq(&self, other: &Self) -> bool {
67        self.0 == other.0
68    }
69}
70
71impl Hash for HashValue<'_> {
72    fn hash<H: Hasher>(&self, state: &mut H) {
73        match &self.0 {
74            Value::Null => 0.hash(state),
75            Value::Bool(v) => v.hash(state),
76            Value::Number(v) => v.hash(state),
77            Value::String(v) => v.hash(state),
78            Value::Array(v) => {
79                for v in v {
80                    HashValue(v).hash(state);
81                }
82            }
83            Value::Object(v) => {
84                for (k, v) in v {
85                    k.hash(state);
86                    HashValue(v).hash(state);
87                }
88            }
89        }
90    }
91}
92
/// Conversion of path-like values into the canonical form Taplo works with.
pub trait Normalize {
    /// Normalizes `self` for use within Taplo.
    ///
    /// Normalization:
    ///
    /// - decodes percent-encoded characters (e.g. `%20` becomes a space), then
    /// - on Windows, replaces every `\` with `/`.
    ///
    /// Implementations may return the value unchanged when it cannot be
    /// normalized, e.g. a path that is not valid UTF-8.
    #[must_use]
    fn normalize(self) -> Self;
}
101
102impl Normalize for PathBuf {
103    fn normalize(self) -> Self {
104        match self.to_str() {
105            Some(s) => (*normalize_str(s)).into(),
106            None => self,
107        }
108    }
109}
110
111pub(crate) fn normalize_str(s: &str) -> Cow<str> {
112    let Some(percent_decoded) = percent_decode_str(s).decode_utf8().ok() else {
113        return s.into();
114    };
115
116    if cfg!(windows) {
117        percent_decoded.replace('\\', "/").into()
118    } else {
119        percent_decoded
120    }
121}
122
/// Builds the HTTP client used for remote requests.
///
/// Every request made through the returned client is bounded by `timeout`.
///
/// If the `TAPLO_EXTRA_CA_CERTS` environment variable is set to a non-empty
/// path, the certificate stored there is added as an extra trusted root CA.
/// Files with a `.der` extension (case-insensitive) are decoded as DER; any
/// other file is decoded as PEM. A certificate that cannot be read or parsed is
/// logged and skipped rather than failing client construction.
///
/// # Errors
///
/// Returns the error from `reqwest::ClientBuilder::build` if the client cannot
/// be constructed.
#[cfg(all(not(target_arch = "wasm32"), feature = "reqwest"))]
#[tracing::instrument]
pub fn get_reqwest_client(timeout: std::time::Duration) -> Result<reqwest::Client, reqwest::Error> {
    // Adds the CA certificate stored at `path` to `builder`; failures are
    // logged, not propagated.
    #[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
    fn get_certs(
        mut builder: reqwest::ClientBuilder,
        path: &std::ffi::OsStr,
    ) -> reqwest::ClientBuilder {
        // Reads one certificate, picking the DER or PEM decoder by extension.
        fn get_cert(path: &Path) -> Result<reqwest::Certificate, anyhow::Error> {
            let is_der = path
                .extension()
                .is_some_and(|ext| ext.eq_ignore_ascii_case("der"));
            let buf = std::fs::read(path)?;
            tracing::info!(
                "Found a custom CA {}. Reading the CA...",
                path.to_string_lossy()
            );
            if is_der {
                Ok(reqwest::Certificate::from_der(&buf)?)
            } else {
                Ok(reqwest::Certificate::from_pem(&buf)?)
            }
        }

        match get_cert(Path::new(path)) {
            Ok(cert) => {
                builder = builder.add_root_certificate(cert);
                tracing::info!(?path, "Added the custom CA");
            }
            Err(err) => {
                // Include the path so the failing file can be identified.
                tracing::error!(?path, error = %err, "Could not parse the custom CA");
            }
        }
        builder
    }

    // Fallback when no TLS backend is compiled in: extra CAs cannot be loaded.
    // Takes `&OsStr` like the TLS variant so the shared call site below
    // compiles under either feature set.
    #[cfg(not(any(feature = "native-tls", feature = "rustls-tls")))]
    fn get_certs(
        builder: reqwest::ClientBuilder,
        path: &std::ffi::OsStr,
    ) -> reqwest::ClientBuilder {
        tracing::error!(?path, "Could not load certs, taplo was built without TLS");
        builder
    }

    let mut builder = reqwest::Client::builder().timeout(timeout);
    // An empty value is treated like an unset variable instead of reading "".
    if let Some(path) = std::env::var_os("TAPLO_EXTRA_CA_CERTS").filter(|p| !p.is_empty()) {
        builder = get_certs(builder, &path);
    }
    builder.build()
}