update_pypi_deps/
lib.rs

1#![warn(clippy::pedantic)]
2
3use Error as Err;
4
5use std::collections::HashMap;
6use std::convert::Infallible;
7use std::fmt::Display;
8use std::io::{self, Write};
9use std::ops::{Deref, DerefMut};
10use std::result;
11use std::sync::Arc;
12
13use clap::Parser;
14use thiserror;
15use tokio::sync::Semaphore;
16use tracing::{debug, warn};
17
18mod pypi;
19
/// Crate-wide result alias whose error type is fixed to this crate's [`Error`].
pub type Result<T> = result::Result<T, Error>;
21
/// All failures this crate reports.
///
/// Variants marked `#[from]` get a `From` impl generated by `thiserror`, which
/// is what lets `?` convert the underlying error automatically.
#[derive(thiserror::Error, Debug)]
pub enum Error {
    /// An I/O operation failed; display is forwarded to the wrapped error.
    #[error(transparent)]
    Io(#[from] io::Error),

    /// Awaiting a spawned tokio task failed.
    #[error(transparent)]
    Join(#[from] tokio::task::JoinError),

    /// The dependency sections of the input were missing or malformed; the
    /// payload is a fixed description of what was wrong.
    #[error("error parsing dependencies: {}", .0)]
    ParseDeps(&'static str),

    /// A `tracing_subscriber` filter directive could not be parsed.
    // NOTE(review): not constructed in this file; presumably used wherever
    // logging is configured — confirm.
    #[error(transparent)]
    ParseFilter(#[from] tracing_subscriber::filter::ParseError),

    /// The input could not be parsed as TOML. The message is fixed; the
    /// wrapped error is kept as the error source.
    #[error("parsing input as toml failed")]
    ParseToml(#[from] toml::de::Error),

    /// An HTTP request made through `reqwest` failed.
    // NOTE(review): presumably raised from the `pypi` module — confirm.
    #[error(transparent)]
    Reqwest(#[from] reqwest::Error),

    /// Serializing TOML failed; the payload is a fixed description.
    // NOTE(review): not constructed in this file — confirm it is still needed.
    #[error("error serializing toml: {}", .0)]
    SerializeToml(&'static str),

    /// Catch-all for failures that fit no other variant, carrying a free-form
    /// description.
    #[error("unknown error: {}", .0)]
    Unknown(String),
}
48
49impl From<Infallible> for Error {
50    fn from(_: Infallible) -> Self {
51        unreachable!()
52    }
53}
54
// Command-line configuration, parsed with clap's derive API.
//
// NOTE: plain `//` comments are used here on purpose. clap turns `///` doc
// comments into help text, so adding or rewording doc comments on this struct
// or its fields would change the `--help` output.
#[derive(Parser, Debug)]
#[command(author, version, about)]
pub struct Config {
    /// File from which to parse dependencies
    #[arg(short, long, default_value = "pyproject.toml")]
    input: String,

    /// Number of PyPI requests to send in parallel
    // Passed to `fetch_latest_versions` as its concurrency limit.
    #[arg(short, long, default_value = "10")]
    requests: usize,
}
66
/// The dependency declarations read from the `project` table of the input.
#[derive(Clone, Debug)]
struct PypiDeps {
    /// Entries of `project.dependencies`.
    dependencies: Dependencies,
    /// Entries of `project.optional-dependencies`, keyed by group name.
    optional_dependencies: HashMap<String, Dependencies>,
}
72
73impl TryFrom<TomlTable> for PypiDeps {
74    type Error = Error;
75
76    fn try_from(toml: TomlTable) -> Result<Self> {
77        let project = toml
78            .get("project")
79            .ok_or_else(|| Err::ParseDeps("no such section: project"))?;
80        let dependencies: Dependencies = project
81            .get("dependencies")
82            .ok_or_else(|| {
83                Err::ParseDeps("no such section: project.dependencies")
84            })?
85            .try_into()?;
86
87        let optional_dependencies: HashMap<_, _> = project
88            .get("optional-dependencies")
89            .and_then(|v| v.as_table())
90            .map_or(Ok(HashMap::new()), |t| {
91                t.iter()
92                    .map(|(k, v)| {
93                        let deps: Dependencies = v.try_into()?;
94                        Ok((k.clone(), deps))
95                    })
96                    .collect::<Result<_>>()
97            })?;
98
99        Ok(Self {
100            dependencies,
101            optional_dependencies,
102        })
103    }
104}
105
/// A list of parsed dependency entries.
///
/// Each entry is `(name, constraint)`, where `constraint` is `None` for a
/// name-only dependency and `Some((operator, version))` otherwise, e.g.
/// `("cryptography", Some(("~=", "41.0")))`.
#[derive(Clone, Debug)]
struct Dependencies(Vec<(String, Option<(String, String)>)>);
108
109impl TryFrom<&toml::Value> for Dependencies {
110    type Error = Error;
111
112    fn try_from(value: &toml::Value) -> Result<Self> {
113        use Err::ParseDeps;
114        let Some(vals) = value.as_array() else {
115            return Err(ParseDeps("value was not array"));
116        };
117        let vec: Vec<_> = vals
118            .iter()
119            .map(|line| {
120                let line = line.to_string();
121                let trimmed = line.trim_matches('"');
122                debug!("processing line: {trimmed}");
123
124                let pat = {
125                    // https://packaging.python.org/en/latest/specifications/version-specifiers/#id4
126                    // Ordering is important as we return the first match
127                    let pats = ["===", "~=", "==", "!=", "<=", ">=", "<", ">"];
128                    pats.into_iter().find(|&pat| trimmed.contains(pat) )
129                };
130
131                // No version, name only
132                let Some(pat) = pat else {
133                    return Ok((trimmed.to_owned(), None));
134                };
135                let mut splitter = trimmed.split(pat);
136
137                let Some(name) = splitter.next().map(str::to_string) else {
138                    return Err(ParseDeps("line without valid dependency name"));
139                };
140                let Some(version) = splitter.next().map(str::to_string) else {
141                    return Err(ParseDeps(
142                        "dependency with constraint (e.g. `==`) but no version",
143                    ));
144                };
145                Ok((name, Some((pat.to_owned(), version))))
146            })
147            .collect::<Result<_>>()?;
148        Ok(Self(vec))
149    }
150}
151
152async fn fetch_latest_versions(
153    deps: &PypiDeps,
154    concurrency: usize,
155) -> Result<HashMap<String, String>> {
156    let semaphore = Arc::new(Semaphore::new(concurrency));
157
158    let mut handles = Vec::new();
159    for (name, _) in &deps.dependencies.0 {
160        let semaphore = semaphore.clone();
161        let name = name.clone();
162
163        let handle = tokio::spawn(async move {
164            let _permit = semaphore.acquire().await.unwrap();
165            let version = pypi::find_latest(&name).await;
166            drop(_permit);
167            (name, version)
168        });
169        handles.push(handle);
170    }
171
172    for opt_deps in deps.optional_dependencies.values() {
173        for (name, _) in &opt_deps.0 {
174            let semaphore = semaphore.clone();
175            let name = name.clone();
176
177            let handle = tokio::spawn(async move {
178                let _permit = semaphore.acquire().await.unwrap();
179                let version = pypi::find_latest(&name).await;
180                drop(_permit);
181                (name, version)
182            });
183            handles.push(handle);
184        }
185    }
186
187    let mut latest_versions = HashMap::new();
188    for handle in handles {
189        match handle.await? {
190            (name, Ok(version)) => {
191                latest_versions.insert(name, version);
192            }
193            (name, Err(e)) => {
194                warn!("unable to find latest version of {name} due to {e}");
195                continue;
196            }
197        }
198    }
199    Ok(latest_versions)
200}
201
202fn update_versions(
203    deps: &mut PypiDeps,
204    latest_versions: &HashMap<String, String>,
205) -> Result<()> {
206    for (name, constraints) in deps.dependencies.0.iter_mut() {
207        let constraint = constraints
208            .as_ref()
209            .map(|(c, _)| c.to_string())
210            .unwrap_or("==".to_string());
211        let latest = latest_versions.get(name).ok_or_else(|| Error::Unknown(
212            format!("dependency {} should already be in map of latest dependencies but not found:\n{:?}", name, latest_versions)))?;
213        *constraints = Some((constraint, latest.clone()));
214    }
215
216    for (_, opt_deps) in deps.optional_dependencies.iter_mut() {
217        for (name, constraints) in opt_deps.0.iter_mut() {
218            let constraint = constraints
219                .as_ref()
220                .map(|(c, _)| c.to_string())
221                .unwrap_or("==".to_string());
222            let latest = latest_versions.get(name).ok_or_else(|| Error::Unknown(
223            format!("dependency {} should already be in map of latest dependencies but not found:\n{:?}", name, latest_versions)))?;
224            *constraints = Some((constraint, latest.clone()));
225        }
226    }
227    Ok(())
228}
229
/// Newtype wrapper around a parsed [`toml::Table`].
///
/// It dereferences to the inner table and is the input type of
/// `PypiDeps`'s `TryFrom` conversion.
#[derive(Clone)]
struct TomlTable(toml::Table);
232
233impl Deref for TomlTable {
234    type Target = toml::Table;
235
236    fn deref(&self) -> &Self::Target {
237        &self.0
238    }
239}
240
241impl DerefMut for TomlTable {
242    fn deref_mut(&mut self) -> &mut Self::Target {
243        &mut self.0
244    }
245}
246
247impl Display for PypiDeps {
248    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
249        if !self.dependencies.0.is_empty() {
250            writeln!(f, "dependencies = [")?;
251            for dep in &self.dependencies.0 {
252                match dep {
253                    (ref name, Some((constraint, version))) => {
254                        writeln!(f, "    \"{name}{constraint}{version}\"")?;
255                    }
256                    (ref name, None) => {
257                        writeln!(f, "    \"{name}\"")?;
258                    }
259                }
260            }
261            writeln!(f, "]")?;
262        }
263
264        if self.optional_dependencies.is_empty() {
265            return Ok(());
266        }
267
268        for (opt_name, deps) in &self.optional_dependencies {
269            writeln!(f, "\n{opt_name} = [")?;
270            for dep in &deps.0 {
271                match dep {
272                    (ref name, Some((constraint, version))) => {
273                        writeln!(f, "    \"{name}{constraint}{version}\"")?;
274                    }
275                    (ref name, None) => {
276                        writeln!(f, "    \"{name}\"")?;
277                    }
278                }
279            }
280            writeln!(f, "]")?;
281        }
282        Ok(())
283    }
284}
285
286#[tracing::instrument]
287#[tokio::main]
288pub async fn run() -> Result<()> {
289    let config = Config::parse();
290
291    let input = std::fs::read_to_string(config.input)?;
292    let toml = TomlTable(input.parse()?);
293    let mut deps: PypiDeps = toml.try_into()?;
294    let latest_versions =
295        fetch_latest_versions(&deps, config.requests).await?;
296
297    update_versions(&mut deps, &latest_versions)?;
298
299    write!(io::stdout(), "Newest versions:\n\n{deps}")?;
300
301    Ok(())
302}
303
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses the fixture `tests/files/pyproject.toml` and checks one required
    /// and one optional dependency.
    #[test]
    fn test_parse_deps() {
        let raw =
            std::fs::read_to_string("tests/files/pyproject.toml").unwrap();
        let table: toml::Table = raw.parse().unwrap();
        let deps = PypiDeps::try_from(TomlTable(table)).unwrap();

        let cryptography = (
            "cryptography".to_string(),
            Some(("~=".to_string(), "41.0".to_string())),
        );
        assert!(deps.dependencies.0.contains(&cryptography));

        let black = (
            "black".to_string(),
            Some(("==".to_string(), "22.12.0".to_string())),
        );
        let test_deps = deps.optional_dependencies.get("test").unwrap();
        assert!(test_deps.0.contains(&black));
    }
}
330}