uv_workspace/
pyproject_mut.rs

1use std::fmt::{Display, Formatter};
2use std::path::Path;
3use std::str::FromStr;
4use std::{fmt, iter, mem};
5
6use itertools::Itertools;
7use serde::{Deserialize, Serialize};
8use thiserror::Error;
9use toml_edit::{
10    Array, ArrayOfTables, DocumentMut, Formatted, Item, RawString, Table, TomlError, Value,
11};
12
13use uv_cache_key::CanonicalUrl;
14use uv_distribution_types::Index;
15use uv_fs::PortablePath;
16use uv_normalize::{ExtraName, GroupName, PackageName};
17use uv_pep440::{Version, VersionParseError, VersionSpecifier, VersionSpecifiers};
18use uv_pep508::{MarkerTree, Requirement, VersionOrUrl};
19use uv_redacted::DisplaySafeUrl;
20
21use crate::pyproject::{DependencyType, Source};
22
23/// Raw and mutable representation of a `pyproject.toml`.
24///
25/// This is useful for operations that require editing an existing `pyproject.toml` while
26/// preserving comments and other structure, such as `uv add` and `uv remove`.
27pub struct PyProjectTomlMut {
28    doc: DocumentMut,
29    target: DependencyTarget,
30}
31
32#[derive(Error, Debug)]
33pub enum Error {
34    #[error("Failed to parse `pyproject.toml`")]
35    Parse(#[from] Box<TomlError>),
36    #[error("Failed to serialize `pyproject.toml`")]
37    Serialize(#[from] Box<toml::ser::Error>),
38    #[error("Failed to deserialize `pyproject.toml`")]
39    Deserialize(#[from] Box<toml::de::Error>),
40    #[error("Dependencies in `pyproject.toml` are malformed")]
41    MalformedDependencies,
42    #[error("Sources in `pyproject.toml` are malformed")]
43    MalformedSources,
44    #[error("Workspace in `pyproject.toml` is malformed")]
45    MalformedWorkspace,
46    #[error("Expected a dependency at index {0}")]
47    MissingDependency(usize),
48    #[error("Failed to parse `version` field of `pyproject.toml`")]
49    VersionParse(#[from] VersionParseError),
50    #[error("Cannot perform ambiguous update; found multiple entries for `{}`:\n{}", package_name, requirements.iter().map(|requirement| format!("- `{requirement}`")).join("\n"))]
51    Ambiguous {
52        package_name: PackageName,
53        requirements: Vec<Requirement>,
54    },
55    #[error("Unknown bound king {0}")]
56    UnknownBoundKind(String),
57}
58
/// Describes what happened to a TOML array when an entry was written into it.
///
/// Both variants carry the index of the affected entry.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum ArrayEdit {
    /// The entry that already existed at this index was updated in place.
    Update(usize),
    /// A new entry was inserted at this index (typically, the end of the array).
    Add(usize),
}
67
/// Where a comment sits relative to the content around it.
#[derive(Debug, Clone, PartialEq, Eq)]
enum CommentType {
    /// The comment occupies a line of its own.
    OwnLine,
    /// The comment trails other content on the same line.
    // NOTE(review): `leading_whitespace` is presumably the whitespace that precedes the comment
    // on its line; confirm against the code that constructs this variant.
    EndOfLine { leading_whitespace: String },
}
75
76#[derive(Debug, Clone)]
77struct Comment {
78    text: String,
79    kind: CommentType,
80}
81
82impl ArrayEdit {
83    pub fn index(&self) -> usize {
84        match self {
85            Self::Update(i) | Self::Add(i) => *i,
86        }
87    }
88}
89
90/// The default version specifier when adding a dependency.
91// While PEP 440 allows an arbitrary number of version digits, the `major` and `minor` build on
92// most projects sticking to two or three components and a SemVer-ish versioning system, so can
93// bump the major or minor version of a major.minor or major.minor.patch input version.
94#[derive(Clone, Copy, Debug, Default, Deserialize, PartialEq, Eq, Serialize)]
95#[serde(rename_all = "kebab-case")]
96#[cfg_attr(feature = "clap", derive(clap::ValueEnum))]
97#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
98pub enum AddBoundsKind {
99    /// Only a lower bound, e.g., `>=1.2.3`.
100    #[default]
101    Lower,
102    /// Allow the same major version, similar to the semver caret, e.g., `>=1.2.3, <2.0.0`.
103    ///
104    /// Leading zeroes are skipped, e.g. `>=0.1.2, <0.2.0`.
105    Major,
106    /// Allow the same minor version, similar to the semver tilde, e.g., `>=1.2.3, <1.3.0`.
107    ///
108    /// Leading zeroes are skipped, e.g. `>=0.1.2, <0.1.3`.
109    Minor,
110    /// Pin the exact version, e.g., `==1.2.3`.
111    ///
112    /// This option is not recommended, as versions are already pinned in the uv lockfile.
113    Exact,
114}
115
116impl Display for AddBoundsKind {
117    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
118        match self {
119            Self::Lower => write!(f, "lower"),
120            Self::Major => write!(f, "major"),
121            Self::Minor => write!(f, "minor"),
122            Self::Exact => write!(f, "exact"),
123        }
124    }
125}
126
127impl AddBoundsKind {
128    fn specifiers(self, version: Version) -> VersionSpecifiers {
129        // Nomenclature: "major" is the most significant component of the version, "minor" is the
130        // second most significant component, so most versions are either major.minor.patch or
131        // 0.major.minor.
132        match self {
133            Self::Lower => {
134                VersionSpecifiers::from(VersionSpecifier::greater_than_equal_version(version))
135            }
136            Self::Major => {
137                let leading_zeroes = version
138                    .release()
139                    .iter()
140                    .take_while(|digit| **digit == 0)
141                    .count();
142
143                // Special case: The version is 0.
144                if leading_zeroes == version.release().len() {
145                    let upper_bound = Version::new(
146                        [0, 1]
147                            .into_iter()
148                            .chain(iter::repeat_n(0, version.release().iter().skip(2).len())),
149                    );
150                    return VersionSpecifiers::from_iter([
151                        VersionSpecifier::greater_than_equal_version(version),
152                        VersionSpecifier::less_than_version(upper_bound),
153                    ]);
154                }
155
156                // Compute the new major version and pad it to the same length:
157                // 1.2.3 -> 2.0.0
158                // 1.2 -> 2.0
159                // 1 -> 2
160                // We ignore leading zeroes, adding Semver-style semantics to 0.x versions, too:
161                // 0.1.2 -> 0.2.0
162                // 0.0.1 -> 0.0.2
163                let major = version.release().get(leading_zeroes).copied().unwrap_or(0);
164                // The length of the lower bound minus the leading zero and bumped component.
165                let trailing_zeros = version.release().iter().skip(leading_zeroes + 1).len();
166                let upper_bound = Version::new(
167                    iter::repeat_n(0, leading_zeroes)
168                        .chain(iter::once(major + 1))
169                        .chain(iter::repeat_n(0, trailing_zeros)),
170                );
171
172                VersionSpecifiers::from_iter([
173                    VersionSpecifier::greater_than_equal_version(version),
174                    VersionSpecifier::less_than_version(upper_bound),
175                ])
176            }
177            Self::Minor => {
178                let leading_zeroes = version
179                    .release()
180                    .iter()
181                    .take_while(|digit| **digit == 0)
182                    .count();
183
184                // Special case: The version is 0.
185                if leading_zeroes == version.release().len() {
186                    let upper_bound = [0, 0, 1]
187                        .into_iter()
188                        .chain(iter::repeat_n(0, version.release().iter().skip(3).len()));
189                    return VersionSpecifiers::from_iter([
190                        VersionSpecifier::greater_than_equal_version(version),
191                        VersionSpecifier::less_than_version(Version::new(upper_bound)),
192                    ]);
193                }
194
195                // If both major and minor version are 0, the concept of bumping the minor version
196                // instead of the major version is not useful. Instead, we bump the next
197                // non-zero part of the version. This avoids extending the three components of 0.0.1
198                // to the four components of 0.0.1.1.
199                if leading_zeroes >= 2 {
200                    let most_significant =
201                        version.release().get(leading_zeroes).copied().unwrap_or(0);
202                    // The length of the lower bound minus the leading zero and bumped component.
203                    let trailing_zeros = version.release().iter().skip(leading_zeroes + 1).len();
204                    let upper_bound = Version::new(
205                        iter::repeat_n(0, leading_zeroes)
206                            .chain(iter::once(most_significant + 1))
207                            .chain(iter::repeat_n(0, trailing_zeros)),
208                    );
209                    return VersionSpecifiers::from_iter([
210                        VersionSpecifier::greater_than_equal_version(version),
211                        VersionSpecifier::less_than_version(upper_bound),
212                    ]);
213                }
214
215                // Compute the new minor version and pad it to the same length where possible:
216                // 1.2.3 -> 1.3.0
217                // 1.2 -> 1.3
218                // 1 -> 1.1
219                // We ignore leading zero, adding Semver-style semantics to 0.x versions, too:
220                // 0.1.2 -> 0.1.3
221                // 0.0.1 -> 0.0.2
222
223                // If the version has only one digit, say `1`, or if there are only leading zeroes,
224                // pad with zeroes.
225                let major = version.release().get(leading_zeroes).copied().unwrap_or(0);
226                let minor = version
227                    .release()
228                    .get(leading_zeroes + 1)
229                    .copied()
230                    .unwrap_or(0);
231                let upper_bound = Version::new(
232                    iter::repeat_n(0, leading_zeroes)
233                        .chain(iter::once(major))
234                        .chain(iter::once(minor + 1))
235                        .chain(iter::repeat_n(
236                            0,
237                            version.release().iter().skip(leading_zeroes + 2).len(),
238                        )),
239                );
240
241                VersionSpecifiers::from_iter([
242                    VersionSpecifier::greater_than_equal_version(version),
243                    VersionSpecifier::less_than_version(upper_bound),
244                ])
245            }
246            Self::Exact => {
247                VersionSpecifiers::from_iter([VersionSpecifier::equals_version(version)])
248            }
249        }
250    }
251}
252
/// The kind of document that dependencies are written to: a script file or a `pyproject.toml`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum DependencyTarget {
    /// The inline metadata of a PEP 723 script.
    Script,
    /// The `pyproject.toml` of a project.
    PyProjectToml,
}
261
262impl PyProjectTomlMut {
263    /// Initialize a [`PyProjectTomlMut`] from a [`str`].
264    pub fn from_toml(raw: &str, target: DependencyTarget) -> Result<Self, Error> {
265        Ok(Self {
266            doc: raw.parse().map_err(Box::new)?,
267            target,
268        })
269    }
270
271    /// Adds a project to the workspace.
272    pub fn add_workspace(&mut self, path: impl AsRef<Path>) -> Result<(), Error> {
273        // Get or create `tool.uv.workspace.members`.
274        let members = self
275            .doc
276            .entry("tool")
277            .or_insert(implicit())
278            .as_table_mut()
279            .ok_or(Error::MalformedWorkspace)?
280            .entry("uv")
281            .or_insert(implicit())
282            .as_table_mut()
283            .ok_or(Error::MalformedWorkspace)?
284            .entry("workspace")
285            .or_insert(Item::Table(Table::new()))
286            .as_table_mut()
287            .ok_or(Error::MalformedWorkspace)?
288            .entry("members")
289            .or_insert(Item::Value(Value::Array(Array::new())))
290            .as_array_mut()
291            .ok_or(Error::MalformedWorkspace)?;
292
293        // Add the path to the workspace.
294        members.push(PortablePath::from(path.as_ref()).to_string());
295
296        reformat_array_multiline(members);
297
298        Ok(())
299    }
300
301    /// Retrieves a mutable reference to the `project` [`Table`] of the TOML document, creating the
302    /// table if necessary.
303    ///
304    /// For a script, this returns the root table.
305    fn project(&mut self) -> Result<&mut Table, Error> {
306        let doc = match self.target {
307            DependencyTarget::Script => self.doc.as_table_mut(),
308            DependencyTarget::PyProjectToml => self
309                .doc
310                .entry("project")
311                .or_insert(Item::Table(Table::new()))
312                .as_table_mut()
313                .ok_or(Error::MalformedDependencies)?,
314        };
315        Ok(doc)
316    }
317
318    /// Retrieves an optional mutable reference to the `project` [`Table`], returning `None` if it
319    /// doesn't exist.
320    ///
321    /// For a script, this returns the root table.
322    fn project_mut(&mut self) -> Result<Option<&mut Table>, Error> {
323        let doc = match self.target {
324            DependencyTarget::Script => Some(self.doc.as_table_mut()),
325            DependencyTarget::PyProjectToml => self
326                .doc
327                .get_mut("project")
328                .map(|project| project.as_table_mut().ok_or(Error::MalformedSources))
329                .transpose()?,
330        };
331        Ok(doc)
332    }
333
334    /// Adds a dependency to `project.dependencies`.
335    ///
336    /// Returns `true` if the dependency was added, `false` if it was updated.
337    pub fn add_dependency(
338        &mut self,
339        req: &Requirement,
340        source: Option<&Source>,
341        raw: bool,
342    ) -> Result<ArrayEdit, Error> {
343        // Get or create `project.dependencies`.
344        let dependencies = self
345            .project()?
346            .entry("dependencies")
347            .or_insert(Item::Value(Value::Array(Array::new())))
348            .as_array_mut()
349            .ok_or(Error::MalformedDependencies)?;
350
351        let edit = add_dependency(req, dependencies, source.is_some(), raw)?;
352
353        if let Some(source) = source {
354            self.add_source(&req.name, source)?;
355        }
356
357        Ok(edit)
358    }
359
360    /// Adds a development dependency to `tool.uv.dev-dependencies`.
361    ///
362    /// Returns `true` if the dependency was added, `false` if it was updated.
363    pub fn add_dev_dependency(
364        &mut self,
365        req: &Requirement,
366        source: Option<&Source>,
367        raw: bool,
368    ) -> Result<ArrayEdit, Error> {
369        // Get or create `tool.uv.dev-dependencies`.
370        let dev_dependencies = self
371            .doc
372            .entry("tool")
373            .or_insert(implicit())
374            .as_table_mut()
375            .ok_or(Error::MalformedSources)?
376            .entry("uv")
377            .or_insert(Item::Table(Table::new()))
378            .as_table_mut()
379            .ok_or(Error::MalformedSources)?
380            .entry("dev-dependencies")
381            .or_insert(Item::Value(Value::Array(Array::new())))
382            .as_array_mut()
383            .ok_or(Error::MalformedDependencies)?;
384
385        let edit = add_dependency(req, dev_dependencies, source.is_some(), raw)?;
386
387        if let Some(source) = source {
388            self.add_source(&req.name, source)?;
389        }
390
391        Ok(edit)
392    }
393
394    /// Add an [`Index`] to `tool.uv.index`.
395    pub fn add_index(&mut self, index: &Index) -> Result<(), Error> {
396        let size = self.doc.len();
397        let existing = self
398            .doc
399            .entry("tool")
400            .or_insert(implicit())
401            .as_table_mut()
402            .ok_or(Error::MalformedSources)?
403            .entry("uv")
404            .or_insert(implicit())
405            .as_table_mut()
406            .ok_or(Error::MalformedSources)?
407            .entry("index")
408            .or_insert(Item::ArrayOfTables(ArrayOfTables::new()))
409            .as_array_of_tables_mut()
410            .ok_or(Error::MalformedSources)?;
411
412        // If there's already an index with the same name or URL, update it (and move it to the top).
413        let mut table = existing
414            .iter()
415            .find(|table| {
416                // If the index has the same name, reuse it.
417                if let Some(index) = index.name.as_deref() {
418                    if table
419                        .get("name")
420                        .and_then(|name| name.as_str())
421                        .is_some_and(|name| name == index)
422                    {
423                        return true;
424                    }
425                }
426
427                // If the index is the default, and there's another default index, reuse it.
428                if index.default
429                    && table
430                        .get("default")
431                        .is_some_and(|default| default.as_bool() == Some(true))
432                {
433                    return true;
434                }
435
436                // If there's another index with the same URL, reuse it.
437                if table
438                    .get("url")
439                    .and_then(|item| item.as_str())
440                    .and_then(|url| DisplaySafeUrl::parse(url).ok())
441                    .is_some_and(|url| {
442                        CanonicalUrl::new(&url) == CanonicalUrl::new(index.url.url())
443                    })
444                {
445                    return true;
446                }
447
448                false
449            })
450            .cloned()
451            .unwrap_or_default();
452
453        // If necessary, update the name.
454        if let Some(index) = index.name.as_deref() {
455            if table
456                .get("name")
457                .and_then(|name| name.as_str())
458                .is_none_or(|name| name != index)
459            {
460                let mut formatted = Formatted::new(index.to_string());
461                if let Some(value) = table.get("name").and_then(Item::as_value) {
462                    if let Some(prefix) = value.decor().prefix() {
463                        formatted.decor_mut().set_prefix(prefix.clone());
464                    }
465                    if let Some(suffix) = value.decor().suffix() {
466                        formatted.decor_mut().set_suffix(suffix.clone());
467                    }
468                }
469                table.insert("name", Value::String(formatted).into());
470            }
471        }
472
473        // If necessary, update the URL.
474        if table
475            .get("url")
476            .and_then(|item| item.as_str())
477            .is_none_or(|url| url != index.url.without_credentials().as_str())
478        {
479            let mut formatted = Formatted::new(index.url.without_credentials().to_string());
480            if let Some(value) = table.get("url").and_then(Item::as_value) {
481                if let Some(prefix) = value.decor().prefix() {
482                    formatted.decor_mut().set_prefix(prefix.clone());
483                }
484                if let Some(suffix) = value.decor().suffix() {
485                    formatted.decor_mut().set_suffix(suffix.clone());
486                }
487            }
488            table.insert("url", Value::String(formatted).into());
489        }
490
491        // If necessary, update the default.
492        if index.default {
493            if !table
494                .get("default")
495                .and_then(Item::as_bool)
496                .is_some_and(|default| default)
497            {
498                let mut formatted = Formatted::new(true);
499                if let Some(value) = table.get("default").and_then(Item::as_value) {
500                    if let Some(prefix) = value.decor().prefix() {
501                        formatted.decor_mut().set_prefix(prefix.clone());
502                    }
503                    if let Some(suffix) = value.decor().suffix() {
504                        formatted.decor_mut().set_suffix(suffix.clone());
505                    }
506                }
507                table.insert("default", Value::Boolean(formatted).into());
508            }
509        }
510
511        // Remove any replaced tables.
512        existing.retain(|table| {
513            // If the index has the same name, skip it.
514            if let Some(index) = index.name.as_deref() {
515                if table
516                    .get("name")
517                    .and_then(|name| name.as_str())
518                    .is_some_and(|name| name == index)
519                {
520                    return false;
521                }
522            }
523
524            // If there's another default index, skip it.
525            if index.default
526                && table
527                    .get("default")
528                    .is_some_and(|default| default.as_bool() == Some(true))
529            {
530                return false;
531            }
532
533            // If there's another index with the same URL, skip it.
534            if table
535                .get("url")
536                .and_then(|item| item.as_str())
537                .and_then(|url| DisplaySafeUrl::parse(url).ok())
538                .is_some_and(|url| CanonicalUrl::new(&url) == CanonicalUrl::new(index.url.url()))
539            {
540                return false;
541            }
542
543            true
544        });
545
546        // Set the position to the minimum, if it's not already the first element.
547        if let Some(min) = existing.iter().filter_map(Table::position).min() {
548            table.set_position(min);
549
550            // Increment the position of all existing elements.
551            for table in existing.iter_mut() {
552                if let Some(position) = table.position() {
553                    table.set_position(position + 1);
554                }
555            }
556        } else {
557            let position = isize::try_from(size).expect("TOML table size fits in `isize`");
558            table.set_position(position);
559        }
560
561        // Push the item to the table.
562        existing.push(table);
563
564        Ok(())
565    }
566
567    /// Adds a dependency to `project.optional-dependencies`.
568    ///
569    /// Returns `true` if the dependency was added, `false` if it was updated.
570    pub fn add_optional_dependency(
571        &mut self,
572        group: &ExtraName,
573        req: &Requirement,
574        source: Option<&Source>,
575        raw: bool,
576    ) -> Result<ArrayEdit, Error> {
577        // Get or create `project.optional-dependencies`.
578        let optional_dependencies = self
579            .project()?
580            .entry("optional-dependencies")
581            .or_insert(Item::Table(Table::new()))
582            .as_table_like_mut()
583            .ok_or(Error::MalformedDependencies)?;
584
585        // Try to find the existing group.
586        let existing_group = optional_dependencies.iter_mut().find_map(|(key, value)| {
587            if ExtraName::from_str(key.get()).is_ok_and(|g| g == *group) {
588                Some(value)
589            } else {
590                None
591            }
592        });
593
594        // If the group doesn't exist, create it.
595        let group = match existing_group {
596            Some(value) => value,
597            None => optional_dependencies
598                .entry(group.as_ref())
599                .or_insert(Item::Value(Value::Array(Array::new()))),
600        }
601        .as_array_mut()
602        .ok_or(Error::MalformedDependencies)?;
603
604        let added = add_dependency(req, group, source.is_some(), raw)?;
605
606        // If `project.optional-dependencies` is an inline table, reformat it.
607        //
608        // Reformatting can drop comments between keys, but you can't put comments
609        // between items in an inline table anyway.
610        if let Some(optional_dependencies) = self
611            .project()?
612            .get_mut("optional-dependencies")
613            .and_then(Item::as_inline_table_mut)
614        {
615            optional_dependencies.fmt();
616        }
617
618        if let Some(source) = source {
619            self.add_source(&req.name, source)?;
620        }
621
622        Ok(added)
623    }
624
625    /// Ensure that an optional dependency group exists, creating an empty group if it doesn't.
626    pub fn ensure_optional_dependency(&mut self, extra: &ExtraName) -> Result<(), Error> {
627        // Get or create `project.optional-dependencies`.
628        let optional_dependencies = self
629            .project()?
630            .entry("optional-dependencies")
631            .or_insert(Item::Table(Table::new()))
632            .as_table_like_mut()
633            .ok_or(Error::MalformedDependencies)?;
634
635        // Check if the extra already exists.
636        let extra_exists = optional_dependencies
637            .iter()
638            .any(|(key, _value)| ExtraName::from_str(key).is_ok_and(|e| e == *extra));
639
640        // If the extra doesn't exist, create it.
641        if !extra_exists {
642            optional_dependencies.insert(extra.as_ref(), Item::Value(Value::Array(Array::new())));
643        }
644
645        // If `project.optional-dependencies` is an inline table, reformat it.
646        //
647        // Reformatting can drop comments between keys, but you can't put comments
648        // between items in an inline table anyway.
649        if let Some(optional_dependencies) = self
650            .project()?
651            .get_mut("optional-dependencies")
652            .and_then(Item::as_inline_table_mut)
653        {
654            optional_dependencies.fmt();
655        }
656
657        Ok(())
658    }
659
660    /// Adds a dependency to `dependency-groups`.
661    ///
662    /// Returns `true` if the dependency was added, `false` if it was updated.
663    pub fn add_dependency_group_requirement(
664        &mut self,
665        group: &GroupName,
666        req: &Requirement,
667        source: Option<&Source>,
668        raw: bool,
669    ) -> Result<ArrayEdit, Error> {
670        // Get or create `dependency-groups`.
671        let dependency_groups = self
672            .doc
673            .entry("dependency-groups")
674            .or_insert(Item::Table(Table::new()))
675            .as_table_like_mut()
676            .ok_or(Error::MalformedDependencies)?;
677
678        let was_sorted = dependency_groups
679            .get_values()
680            .iter()
681            .filter_map(|(dotted_ks, _)| dotted_ks.first())
682            .map(|k| k.get())
683            .is_sorted();
684
685        // Try to find the existing group.
686        let existing_group = dependency_groups.iter_mut().find_map(|(key, value)| {
687            if GroupName::from_str(key.get()).is_ok_and(|g| g == *group) {
688                Some(value)
689            } else {
690                None
691            }
692        });
693
694        // If the group doesn't exist, create it.
695        let group = match existing_group {
696            Some(value) => value,
697            None => dependency_groups
698                .entry(group.as_ref())
699                .or_insert(Item::Value(Value::Array(Array::new()))),
700        }
701        .as_array_mut()
702        .ok_or(Error::MalformedDependencies)?;
703
704        let added = add_dependency(req, group, source.is_some(), raw)?;
705
706        // To avoid churn in pyproject.toml, we only sort new group keys if the
707        // existing keys were sorted.
708        if was_sorted {
709            dependency_groups.sort_values();
710        }
711
712        // If `dependency-groups` is an inline table, reformat it.
713        //
714        // Reformatting can drop comments between keys, but you can't put comments
715        // between items in an inline table anyway.
716        if let Some(dependency_groups) = self
717            .doc
718            .get_mut("dependency-groups")
719            .and_then(Item::as_inline_table_mut)
720        {
721            dependency_groups.fmt();
722        }
723
724        if let Some(source) = source {
725            self.add_source(&req.name, source)?;
726        }
727
728        Ok(added)
729    }
730
731    /// Ensure that a dependency group exists, creating an empty group if it doesn't.
732    pub fn ensure_dependency_group(&mut self, group: &GroupName) -> Result<(), Error> {
733        // Get or create `dependency-groups`.
734        let dependency_groups = self
735            .doc
736            .entry("dependency-groups")
737            .or_insert(Item::Table(Table::new()))
738            .as_table_like_mut()
739            .ok_or(Error::MalformedDependencies)?;
740
741        let was_sorted = dependency_groups
742            .get_values()
743            .iter()
744            .filter_map(|(dotted_ks, _)| dotted_ks.first())
745            .map(|k| k.get())
746            .is_sorted();
747
748        // Check if the group already exists.
749        let group_exists = dependency_groups
750            .iter()
751            .any(|(key, _value)| GroupName::from_str(key).is_ok_and(|g| g == *group));
752
753        // If the group doesn't exist, create it.
754        if !group_exists {
755            dependency_groups.insert(group.as_ref(), Item::Value(Value::Array(Array::new())));
756
757            // To avoid churn in pyproject.toml, we only sort new group keys if the
758            // existing keys were sorted.
759            if was_sorted {
760                dependency_groups.sort_values();
761            }
762        }
763
764        // If `dependency-groups` is an inline table, reformat it.
765        //
766        // Reformatting can drop comments between keys, but you can't put comments
767        // between items in an inline table anyway.
768        if let Some(dependency_groups) = self
769            .doc
770            .get_mut("dependency-groups")
771            .and_then(Item::as_inline_table_mut)
772        {
773            dependency_groups.fmt();
774        }
775
776        Ok(())
777    }
778
779    /// Set the constraint for a requirement for an existing dependency.
780    pub fn set_dependency_bound(
781        &mut self,
782        dependency_type: &DependencyType,
783        index: usize,
784        version: Version,
785        bound_kind: AddBoundsKind,
786    ) -> Result<(), Error> {
787        let group = match dependency_type {
788            DependencyType::Production => self.dependencies_array()?,
789            DependencyType::Dev => self.dev_dependencies_array()?,
790            DependencyType::Optional(extra) => self.optional_dependencies_array(extra)?,
791            DependencyType::Group(group) => self.dependency_groups_array(group)?,
792        };
793
794        let Some(req) = group.get(index) else {
795            return Err(Error::MissingDependency(index));
796        };
797
798        let mut req = req
799            .as_str()
800            .and_then(try_parse_requirement)
801            .ok_or(Error::MalformedDependencies)?;
802        req.version_or_url = Some(VersionOrUrl::VersionSpecifier(
803            bound_kind.specifiers(version),
804        ));
805        group.replace(index, req.to_string());
806
807        Ok(())
808    }
809
810    /// Get the TOML array for `project.dependencies`.
811    fn dependencies_array(&mut self) -> Result<&mut Array, Error> {
812        // Get or create `project.dependencies`.
813        let dependencies = self
814            .project()?
815            .entry("dependencies")
816            .or_insert(Item::Value(Value::Array(Array::new())))
817            .as_array_mut()
818            .ok_or(Error::MalformedDependencies)?;
819
820        Ok(dependencies)
821    }
822
823    /// Get the TOML array for `tool.uv.dev-dependencies`.
824    fn dev_dependencies_array(&mut self) -> Result<&mut Array, Error> {
825        // Get or create `tool.uv.dev-dependencies`.
826        let dev_dependencies = self
827            .doc
828            .entry("tool")
829            .or_insert(implicit())
830            .as_table_mut()
831            .ok_or(Error::MalformedSources)?
832            .entry("uv")
833            .or_insert(Item::Table(Table::new()))
834            .as_table_mut()
835            .ok_or(Error::MalformedSources)?
836            .entry("dev-dependencies")
837            .or_insert(Item::Value(Value::Array(Array::new())))
838            .as_array_mut()
839            .ok_or(Error::MalformedDependencies)?;
840
841        Ok(dev_dependencies)
842    }
843
844    /// Get the TOML array for a `project.optional-dependencies` entry.
845    fn optional_dependencies_array(&mut self, group: &ExtraName) -> Result<&mut Array, Error> {
846        // Get or create `project.optional-dependencies`.
847        let optional_dependencies = self
848            .project()?
849            .entry("optional-dependencies")
850            .or_insert(Item::Table(Table::new()))
851            .as_table_like_mut()
852            .ok_or(Error::MalformedDependencies)?;
853
854        // Try to find the existing extra.
855        let existing_key = optional_dependencies.iter().find_map(|(key, _value)| {
856            if ExtraName::from_str(key).is_ok_and(|g| g == *group) {
857                Some(key.to_string())
858            } else {
859                None
860            }
861        });
862
863        // If the group doesn't exist, create it.
864        let group = optional_dependencies
865            .entry(existing_key.as_deref().unwrap_or(group.as_ref()))
866            .or_insert(Item::Value(Value::Array(Array::new())))
867            .as_array_mut()
868            .ok_or(Error::MalformedDependencies)?;
869
870        Ok(group)
871    }
872
873    /// Get the TOML array for a `dependency-groups` entry.
874    fn dependency_groups_array(&mut self, group: &GroupName) -> Result<&mut Array, Error> {
875        // Get or create `dependency-groups`.
876        let dependency_groups = self
877            .doc
878            .entry("dependency-groups")
879            .or_insert(Item::Table(Table::new()))
880            .as_table_like_mut()
881            .ok_or(Error::MalformedDependencies)?;
882
883        // Try to find the existing group.
884        let existing_key = dependency_groups.iter().find_map(|(key, _value)| {
885            if GroupName::from_str(key).is_ok_and(|g| g == *group) {
886                Some(key.to_string())
887            } else {
888                None
889            }
890        });
891
892        // If the group doesn't exist, create it.
893        let group = dependency_groups
894            .entry(existing_key.as_deref().unwrap_or(group.as_ref()))
895            .or_insert(Item::Value(Value::Array(Array::new())))
896            .as_array_mut()
897            .ok_or(Error::MalformedDependencies)?;
898
899        Ok(group)
900    }
901
902    /// Adds a source to `tool.uv.sources`.
903    fn add_source(&mut self, name: &PackageName, source: &Source) -> Result<(), Error> {
904        // Get or create `tool.uv.sources`.
905        let sources = self
906            .doc
907            .entry("tool")
908            .or_insert(implicit())
909            .as_table_mut()
910            .ok_or(Error::MalformedSources)?
911            .entry("uv")
912            .or_insert(implicit())
913            .as_table_mut()
914            .ok_or(Error::MalformedSources)?
915            .entry("sources")
916            .or_insert(Item::Table(Table::new()))
917            .as_table_mut()
918            .ok_or(Error::MalformedSources)?;
919
920        if let Some(key) = find_source(name, sources) {
921            sources.remove(&key);
922        }
923        add_source(name, source, sources)?;
924
925        Ok(())
926    }
927
928    /// Removes all occurrences of dependencies with the given name.
929    pub fn remove_dependency(&mut self, name: &PackageName) -> Result<Vec<Requirement>, Error> {
930        // Try to get `project.dependencies`.
931        let Some(dependencies) = self
932            .project_mut()?
933            .and_then(|project| project.get_mut("dependencies"))
934            .map(|dependencies| {
935                dependencies
936                    .as_array_mut()
937                    .ok_or(Error::MalformedDependencies)
938            })
939            .transpose()?
940        else {
941            return Ok(Vec::new());
942        };
943
944        let requirements = remove_dependency(name, dependencies);
945        self.remove_source(name)?;
946
947        Ok(requirements)
948    }
949
950    /// Removes all occurrences of development dependencies with the given name.
951    pub fn remove_dev_dependency(&mut self, name: &PackageName) -> Result<Vec<Requirement>, Error> {
952        // Try to get `tool.uv.dev-dependencies`.
953        let Some(dev_dependencies) = self
954            .doc
955            .get_mut("tool")
956            .map(|tool| tool.as_table_mut().ok_or(Error::MalformedDependencies))
957            .transpose()?
958            .and_then(|tool| tool.get_mut("uv"))
959            .map(|tool_uv| tool_uv.as_table_mut().ok_or(Error::MalformedDependencies))
960            .transpose()?
961            .and_then(|tool_uv| tool_uv.get_mut("dev-dependencies"))
962            .map(|dependencies| {
963                dependencies
964                    .as_array_mut()
965                    .ok_or(Error::MalformedDependencies)
966            })
967            .transpose()?
968        else {
969            return Ok(Vec::new());
970        };
971
972        let requirements = remove_dependency(name, dev_dependencies);
973        self.remove_source(name)?;
974
975        Ok(requirements)
976    }
977
978    /// Removes all occurrences of optional dependencies in the group with the given name.
979    pub fn remove_optional_dependency(
980        &mut self,
981        name: &PackageName,
982        group: &ExtraName,
983    ) -> Result<Vec<Requirement>, Error> {
984        // Try to get `project.optional-dependencies.<group>`.
985        let Some(optional_dependencies) = self
986            .project_mut()?
987            .and_then(|project| project.get_mut("optional-dependencies"))
988            .map(|extras| {
989                extras
990                    .as_table_like_mut()
991                    .ok_or(Error::MalformedDependencies)
992            })
993            .transpose()?
994            .and_then(|extras| {
995                extras.iter_mut().find_map(|(key, value)| {
996                    if ExtraName::from_str(key.get()).is_ok_and(|g| g == *group) {
997                        Some(value)
998                    } else {
999                        None
1000                    }
1001                })
1002            })
1003            .map(|dependencies| {
1004                dependencies
1005                    .as_array_mut()
1006                    .ok_or(Error::MalformedDependencies)
1007            })
1008            .transpose()?
1009        else {
1010            return Ok(Vec::new());
1011        };
1012
1013        let requirements = remove_dependency(name, optional_dependencies);
1014        self.remove_source(name)?;
1015
1016        Ok(requirements)
1017    }
1018
1019    /// Removes all occurrences of the dependency in the group with the given name.
1020    pub fn remove_dependency_group_requirement(
1021        &mut self,
1022        name: &PackageName,
1023        group: &GroupName,
1024    ) -> Result<Vec<Requirement>, Error> {
1025        // Try to get `project.optional-dependencies.<group>`.
1026        let Some(group_dependencies) = self
1027            .doc
1028            .get_mut("dependency-groups")
1029            .map(|groups| {
1030                groups
1031                    .as_table_like_mut()
1032                    .ok_or(Error::MalformedDependencies)
1033            })
1034            .transpose()?
1035            .and_then(|groups| {
1036                groups.iter_mut().find_map(|(key, value)| {
1037                    if GroupName::from_str(key.get()).is_ok_and(|g| g == *group) {
1038                        Some(value)
1039                    } else {
1040                        None
1041                    }
1042                })
1043            })
1044            .map(|dependencies| {
1045                dependencies
1046                    .as_array_mut()
1047                    .ok_or(Error::MalformedDependencies)
1048            })
1049            .transpose()?
1050        else {
1051            return Ok(Vec::new());
1052        };
1053
1054        let requirements = remove_dependency(name, group_dependencies);
1055        self.remove_source(name)?;
1056
1057        Ok(requirements)
1058    }
1059
1060    /// Remove a matching source from `tool.uv.sources`, if it exists.
1061    fn remove_source(&mut self, name: &PackageName) -> Result<(), Error> {
1062        // If the dependency is still in use, don't remove the source.
1063        if !self.find_dependency(name, None).is_empty() {
1064            return Ok(());
1065        }
1066
1067        if let Some(sources) = self
1068            .doc
1069            .get_mut("tool")
1070            .map(|tool| tool.as_table_mut().ok_or(Error::MalformedSources))
1071            .transpose()?
1072            .and_then(|tool| tool.get_mut("uv"))
1073            .map(|tool_uv| tool_uv.as_table_mut().ok_or(Error::MalformedSources))
1074            .transpose()?
1075            .and_then(|tool_uv| tool_uv.get_mut("sources"))
1076            .map(|sources| sources.as_table_mut().ok_or(Error::MalformedSources))
1077            .transpose()?
1078        {
1079            if let Some(key) = find_source(name, sources) {
1080                sources.remove(&key);
1081
1082                // Remove the `tool.uv.sources` table if it is empty.
1083                if sources.is_empty() {
1084                    self.doc
1085                        .entry("tool")
1086                        .or_insert(implicit())
1087                        .as_table_mut()
1088                        .ok_or(Error::MalformedSources)?
1089                        .entry("uv")
1090                        .or_insert(implicit())
1091                        .as_table_mut()
1092                        .ok_or(Error::MalformedSources)?
1093                        .remove("sources");
1094                }
1095            }
1096        }
1097
1098        Ok(())
1099    }
1100
1101    /// Returns `true` if the `tool.uv.dev-dependencies` table is present.
1102    pub fn has_dev_dependencies(&self) -> bool {
1103        self.doc
1104            .get("tool")
1105            .and_then(Item::as_table)
1106            .and_then(|tool| tool.get("uv"))
1107            .and_then(Item::as_table)
1108            .and_then(|uv| uv.get("dev-dependencies"))
1109            .is_some()
1110    }
1111
1112    /// Returns `true` if the `dependency-groups` table is present and contains the given group.
1113    pub fn has_dependency_group(&self, group: &GroupName) -> bool {
1114        self.doc
1115            .get("dependency-groups")
1116            .and_then(Item::as_table)
1117            .and_then(|groups| groups.get(group.as_ref()))
1118            .is_some()
1119    }
1120
1121    /// Returns all the places in this `pyproject.toml` that contain a dependency with the given
1122    /// name.
1123    ///
1124    /// This method searches `project.dependencies`, `tool.uv.dev-dependencies`, and
1125    /// `tool.uv.optional-dependencies`.
1126    pub fn find_dependency(
1127        &self,
1128        name: &PackageName,
1129        marker: Option<&MarkerTree>,
1130    ) -> Vec<DependencyType> {
1131        let mut types = Vec::new();
1132
1133        if let Some(project) = self.doc.get("project").and_then(Item::as_table) {
1134            // Check `project.dependencies`.
1135            if let Some(dependencies) = project.get("dependencies").and_then(Item::as_array) {
1136                if !find_dependencies(name, marker, dependencies).is_empty() {
1137                    types.push(DependencyType::Production);
1138                }
1139            }
1140
1141            // Check `project.optional-dependencies`.
1142            if let Some(extras) = project
1143                .get("optional-dependencies")
1144                .and_then(Item::as_table)
1145            {
1146                for (extra, dependencies) in extras {
1147                    let Some(dependencies) = dependencies.as_array() else {
1148                        continue;
1149                    };
1150                    let Ok(extra) = ExtraName::from_str(extra) else {
1151                        continue;
1152                    };
1153
1154                    if !find_dependencies(name, marker, dependencies).is_empty() {
1155                        types.push(DependencyType::Optional(extra));
1156                    }
1157                }
1158            }
1159        }
1160
1161        // Check `dependency-groups`.
1162        if let Some(groups) = self.doc.get("dependency-groups").and_then(Item::as_table) {
1163            for (group, dependencies) in groups {
1164                let Some(dependencies) = dependencies.as_array() else {
1165                    continue;
1166                };
1167                let Ok(group) = GroupName::from_str(group) else {
1168                    continue;
1169                };
1170
1171                if !find_dependencies(name, marker, dependencies).is_empty() {
1172                    types.push(DependencyType::Group(group));
1173                }
1174            }
1175        }
1176
1177        // Check `tool.uv.dev-dependencies`.
1178        if let Some(dev_dependencies) = self
1179            .doc
1180            .get("tool")
1181            .and_then(Item::as_table)
1182            .and_then(|tool| tool.get("uv"))
1183            .and_then(Item::as_table)
1184            .and_then(|uv| uv.get("dev-dependencies"))
1185            .and_then(Item::as_array)
1186        {
1187            if !find_dependencies(name, marker, dev_dependencies).is_empty() {
1188                types.push(DependencyType::Dev);
1189            }
1190        }
1191
1192        types
1193    }
1194
1195    pub fn version(&mut self) -> Result<Version, Error> {
1196        let version = self
1197            .doc
1198            .get("project")
1199            .and_then(Item::as_table)
1200            .and_then(|project| project.get("version"))
1201            .and_then(Item::as_str)
1202            .ok_or(Error::MalformedWorkspace)?;
1203
1204        Ok(Version::from_str(version)?)
1205    }
1206
1207    pub fn has_dynamic_version(&mut self) -> bool {
1208        let Some(dynamic) = self
1209            .doc
1210            .get("project")
1211            .and_then(Item::as_table)
1212            .and_then(|project| project.get("dynamic"))
1213            .and_then(Item::as_array)
1214        else {
1215            return false;
1216        };
1217
1218        dynamic.iter().any(|val| val.as_str() == Some("version"))
1219    }
1220
1221    pub fn set_version(&mut self, version: &Version) -> Result<(), Error> {
1222        let project = self
1223            .doc
1224            .get_mut("project")
1225            .and_then(Item::as_table_mut)
1226            .ok_or(Error::MalformedWorkspace)?;
1227
1228        if let Some(existing) = project.get_mut("version") {
1229            if let Some(value) = existing.as_value_mut() {
1230                let mut formatted = Value::from(version.to_string());
1231                *formatted.decor_mut() = value.decor().clone();
1232                *value = formatted;
1233            } else {
1234                *existing = Item::Value(Value::from(version.to_string()));
1235            }
1236        } else {
1237            project.insert("version", Item::Value(Value::from(version.to_string())));
1238        }
1239
1240        Ok(())
1241    }
1242}
1243
1244/// Returns an implicit table.
1245fn implicit() -> Item {
1246    let mut table = Table::new();
1247    table.set_implicit(true);
1248    Item::Table(table)
1249}
1250
1251/// Adds a dependency to the given `deps` array.
1252///
1253/// Returns `true` if the dependency was added, `false` if it was updated.
1254pub fn add_dependency(
1255    req: &Requirement,
1256    deps: &mut Array,
1257    has_source: bool,
1258    raw: bool,
1259) -> Result<ArrayEdit, Error> {
1260    let mut to_replace = find_dependencies(&req.name, Some(&req.marker), deps);
1261
1262    match to_replace.as_slice() {
1263        [] => {
1264            #[derive(Debug, Copy, Clone)]
1265            enum Sort {
1266                /// The list is sorted in a case-insensitive manner.
1267                CaseInsensitive,
1268                /// The list is sorted naively in a case-insensitive manner.
1269                CaseInsensitiveNaive,
1270                /// The list is sorted in a case-sensitive manner.
1271                CaseSensitive,
1272                /// The list is sorted naively in a case-sensitive manner.
1273                CaseSensitiveNaive,
1274                /// The list is unsorted.
1275                Unsorted,
1276            }
1277
1278            fn is_sorted<T, I>(items: I) -> bool
1279            where
1280                I: IntoIterator<Item = T>,
1281                T: PartialOrd + Copy,
1282            {
1283                items.into_iter().tuple_windows().all(|(a, b)| a <= b)
1284            }
1285
1286            // `deps` are either requirements (strings) or include groups (inline tables).
1287            // Here we pull out just the requirements for determining the sort.
1288            let reqs: Vec<_> = deps.iter().filter_map(Value::as_str).collect();
1289            let reqs_lowercase: Vec<_> = reqs.iter().copied().map(str::to_lowercase).collect();
1290
1291            // Determine if the dependency list is sorted prior to
1292            // adding the new dependency; the new dependency list
1293            // will be sorted only when the original list is sorted
1294            // so that user's custom dependency ordering is preserved.
1295            //
1296            // Any items which aren't strings are ignored, e.g.
1297            // `{ include-group = "..." }` in dependency-groups.
1298            //
1299            // We account for both case-sensitive and case-insensitive sorting.
1300            let sort = if is_sorted(
1301                reqs_lowercase
1302                    .iter()
1303                    .map(String::as_str)
1304                    .map(split_specifiers),
1305            ) {
1306                Sort::CaseInsensitive
1307            } else if is_sorted(reqs.iter().copied().map(split_specifiers)) {
1308                Sort::CaseSensitive
1309            } else if is_sorted(reqs_lowercase.iter().map(String::as_str)) {
1310                Sort::CaseInsensitiveNaive
1311            } else if is_sorted(reqs) {
1312                Sort::CaseSensitiveNaive
1313            } else {
1314                Sort::Unsorted
1315            };
1316
1317            let req_string = if raw {
1318                req.displayable_with_credentials().to_string()
1319            } else {
1320                req.to_string()
1321            };
1322            let index = match sort {
1323                Sort::CaseInsensitive => deps.iter().position(|dep| {
1324                    dep.as_str().is_some_and(|dep| {
1325                        split_specifiers(&dep.to_lowercase())
1326                            > split_specifiers(&req_string.to_lowercase())
1327                    })
1328                }),
1329                Sort::CaseInsensitiveNaive => deps.iter().position(|dep| {
1330                    dep.as_str()
1331                        .is_some_and(|dep| dep.to_lowercase() > req_string.to_lowercase())
1332                }),
1333                Sort::CaseSensitive => deps.iter().position(|dep| {
1334                    dep.as_str()
1335                        .is_some_and(|dep| split_specifiers(dep) > split_specifiers(&req_string))
1336                }),
1337                Sort::CaseSensitiveNaive => deps
1338                    .iter()
1339                    .position(|dep| dep.as_str().is_some_and(|dep| *dep > *req_string)),
1340                Sort::Unsorted => None,
1341            };
1342            let index = index.unwrap_or_else(|| {
1343                // The dependency should be added to the end, ignoring any
1344                // `include-group` items. This preserves the order for users who
1345                // keep their `include-groups` at the bottom.
1346                deps.iter()
1347                    .enumerate()
1348                    .filter_map(|(i, dep)| if dep.is_str() { Some(i + 1) } else { None })
1349                    .last()
1350                    .unwrap_or(deps.len())
1351            });
1352
1353            let mut value = Value::from(req_string.as_str());
1354
1355            let decor = value.decor_mut();
1356
1357            // Ensure comments remain on the correct line, post-insertion
1358            match index {
1359                val if val == deps.len() => {
1360                    // If we're adding to the end of the list, treat trailing comments as leading comments
1361                    // on the added dependency.
1362                    //
1363                    // For example, given:
1364                    // ```toml
1365                    // dependencies = [
1366                    //     "anyio", # trailing comment
1367                    // ]
1368                    // ```
1369                    //
1370                    // If we add `flask` to the end, we want to retain the comment on `anyio`:
1371                    // ```toml
1372                    // dependencies = [
1373                    //     "anyio", # trailing comment
1374                    //     "flask",
1375                    // ]
1376                    // ```
1377                    decor.set_prefix(deps.trailing().clone());
1378                    deps.set_trailing("");
1379                }
1380                0 => {
1381                    // If the dependency is prepended to a non-empty list, do nothing
1382                }
1383                val => {
1384                    // Retain position of end-of-line comments when a dependency is inserted right below it.
1385                    //
1386                    // For example, given:
1387                    // ```toml
1388                    // dependencies = [
1389                    //     "anyio", # end-of-line comment
1390                    //     "flask",
1391                    // ]
1392                    // ```
1393                    //
1394                    // If we add `pydantic` (between `anyio` and `flask`), we want to retain the comment on `anyio`:
1395                    // ```toml
1396                    // dependencies = [
1397                    //     "anyio", # end-of-line comment
1398                    //     "pydantic",
1399                    //     "flask",
1400                    // ]
1401                    // ```
1402                    let targeted_decor = deps.get_mut(val).unwrap().decor_mut();
1403                    decor.set_prefix(targeted_decor.prefix().unwrap().clone());
1404                    targeted_decor.set_prefix(""); // Re-formatted later by `reformat_array_multiline`
1405                }
1406            }
1407
1408            deps.insert_formatted(index, value);
1409
1410            // `reformat_array_multiline` uses the indentation of the first dependency entry.
1411            // Therefore, we retrieve the indentation of the first dependency entry and apply it to
1412            // the new entry. Note that it is only necessary if the newly added dependency is going
1413            // to be the first in the list _and_ the dependency list was not empty prior to adding
1414            // the new dependency.
1415            if deps.len() > 1 && index == 0 {
1416                let prefix = deps
1417                    .clone()
1418                    .get(index + 1)
1419                    .unwrap()
1420                    .decor()
1421                    .prefix()
1422                    .unwrap()
1423                    .clone();
1424
1425                // However, if the prefix includes a comment, we don't want to duplicate it.
1426                // Depending on the location of the comment, we either want to leave it as-is, or
1427                // attach it to the entry that's being moved to the next line.
1428                //
1429                // For example, given:
1430                // ```toml
1431                // dependencies = [ # comment
1432                //     "flask",
1433                // ]
1434                // ```
1435                //
1436                // If we add `anyio` to the beginning, we want to retain the comment on the open
1437                // bracket:
1438                // ```toml
1439                // dependencies = [ # comment
1440                //     "anyio",
1441                //     "flask",
1442                // ]
1443                // ```
1444                //
1445                // However, given:
1446                // ```toml
1447                // dependencies = [
1448                //     # comment
1449                //     "flask",
1450                // ]
1451                // ```
1452                //
1453                // If we add `anyio` to the beginning, we want the comment to move down with the
1454                // existing entry:
1455                // entry:
1456                // ```toml
1457                // dependencies = [
1458                //     "anyio",
1459                //     # comment
1460                //     "flask",
1461                // ]
1462                if let Some(prefix) = prefix.as_str() {
1463                    // Treat anything before the first own-line comment as a prefix on the new
1464                    // entry; anything after the first own-line comment is a prefix on the existing
1465                    // entry.
1466                    //
1467                    // This is equivalent to using the first and last line content as the prefix for
1468                    // the new entry, and the rest as the prefix for the existing entry.
1469                    if let Some((first_line, rest)) = prefix.split_once(['\r', '\n']) {
1470                        // Determine the appropriate newline character.
1471                        let newline = {
1472                            let mut chars = prefix[first_line.len()..].chars();
1473                            match (chars.next(), chars.next()) {
1474                                (Some('\r'), Some('\n')) => "\r\n",
1475                                (Some('\r'), _) => "\r",
1476                                (Some('\n'), _) => "\n",
1477                                _ => "\n",
1478                            }
1479                        };
1480                        let last_line = rest.lines().last().unwrap_or_default();
1481                        let prefix = format!("{first_line}{newline}{last_line}");
1482                        deps.get_mut(index).unwrap().decor_mut().set_prefix(prefix);
1483
1484                        let prefix = format!("{newline}{rest}");
1485                        deps.get_mut(index + 1)
1486                            .unwrap()
1487                            .decor_mut()
1488                            .set_prefix(prefix);
1489                    } else {
1490                        deps.get_mut(index).unwrap().decor_mut().set_prefix(prefix);
1491                    }
1492                } else {
1493                    deps.get_mut(index).unwrap().decor_mut().set_prefix(prefix);
1494                }
1495            }
1496
1497            reformat_array_multiline(deps);
1498
1499            Ok(ArrayEdit::Add(index))
1500        }
1501        [_] => {
1502            let (i, mut old_req) = to_replace.remove(0);
1503            update_requirement(&mut old_req, req, has_source);
1504            deps.replace(i, old_req.to_string());
1505            reformat_array_multiline(deps);
1506            Ok(ArrayEdit::Update(i))
1507        }
1508        // Cannot perform ambiguous updates.
1509        _ => Err(Error::Ambiguous {
1510            package_name: req.name.clone(),
1511            requirements: to_replace
1512                .into_iter()
1513                .map(|(_, requirement)| requirement)
1514                .collect(),
1515        }),
1516    }
1517}
1518
1519/// Update an existing requirement.
1520fn update_requirement(old: &mut Requirement, new: &Requirement, has_source: bool) {
1521    // Add any new extras.
1522    let mut extras = old.extras.to_vec();
1523    extras.extend(new.extras.iter().cloned());
1524    extras.sort_unstable();
1525    extras.dedup();
1526    old.extras = extras.into_boxed_slice();
1527
1528    // Clear the requirement source if we are going to add to `tool.uv.sources`.
1529    if has_source {
1530        old.clear_url();
1531    }
1532
1533    // Update the source if a new one was specified.
1534    match &new.version_or_url {
1535        None => {}
1536        Some(VersionOrUrl::VersionSpecifier(specifier)) if specifier.is_empty() => {}
1537        Some(version_or_url) => old.version_or_url = Some(version_or_url.clone()),
1538    }
1539
1540    // Update the marker expression.
1541    if new.marker.contents().is_some() {
1542        old.marker = new.marker;
1543    }
1544}
1545
1546/// Removes all occurrences of dependencies with the given name from the given `deps` array.
1547fn remove_dependency(name: &PackageName, deps: &mut Array) -> Vec<Requirement> {
1548    // Remove matching dependencies.
1549    let removed = find_dependencies(name, None, deps)
1550        .into_iter()
1551        .rev() // Reverse to preserve indices as we remove them.
1552        .filter_map(|(i, _)| {
1553            deps.remove(i)
1554                .as_str()
1555                .and_then(|req| Requirement::from_str(req).ok())
1556        })
1557        .collect::<Vec<_>>();
1558
1559    if !removed.is_empty() {
1560        reformat_array_multiline(deps);
1561    }
1562
1563    removed
1564}
1565
1566/// Returns a `Vec` containing the all dependencies with the given name, along with their positions
1567/// in the array.
1568fn find_dependencies(
1569    name: &PackageName,
1570    marker: Option<&MarkerTree>,
1571    deps: &Array,
1572) -> Vec<(usize, Requirement)> {
1573    let mut to_replace = Vec::new();
1574    for (i, dep) in deps.iter().enumerate() {
1575        if let Some(req) = dep.as_str().and_then(try_parse_requirement) {
1576            if marker.is_none_or(|m| *m == req.marker) && *name == req.name {
1577                to_replace.push((i, req));
1578            }
1579        }
1580    }
1581    to_replace
1582}
1583
1584/// Returns the key in `tool.uv.sources` that matches the given package name.
1585fn find_source(name: &PackageName, sources: &Table) -> Option<String> {
1586    for (key, _) in sources {
1587        if PackageName::from_str(key).is_ok_and(|ref key| key == name) {
1588            return Some(key.to_string());
1589        }
1590    }
1591    None
1592}
1593
1594// Add a source to `tool.uv.sources`.
1595fn add_source(req: &PackageName, source: &Source, sources: &mut Table) -> Result<(), Error> {
1596    // Serialize as an inline table.
1597    let mut doc = toml::to_string(&source)
1598        .map_err(Box::new)?
1599        .parse::<DocumentMut>()
1600        .unwrap();
1601    let table = mem::take(doc.as_table_mut()).into_inline_table();
1602
1603    sources.insert(req.as_ref(), Item::Value(Value::InlineTable(table)));
1604
1605    Ok(())
1606}
1607
1608impl fmt::Display for PyProjectTomlMut {
1609    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1610        self.doc.fmt(f)
1611    }
1612}
1613
1614fn try_parse_requirement(req: &str) -> Option<Requirement> {
1615    Requirement::from_str(req).ok()
1616}
1617
1618/// Reformats a TOML array to multi line while trying to preserve all comments
1619/// and move them around. This also formats the array to have a trailing comma.
1620fn reformat_array_multiline(deps: &mut Array) {
1621    fn find_comments(s: Option<&RawString>) -> Box<dyn Iterator<Item = Comment> + '_> {
1622        let iter = s
1623            .and_then(|x| x.as_str())
1624            .unwrap_or("")
1625            .lines()
1626            .scan(
1627                (false, false),
1628                |(prev_line_was_empty, prev_line_was_comment), line| {
1629                    let trimmed_line = line.trim();
1630
1631                    if let Some((before, comment)) = line.split_once('#') {
1632                        let comment_text = format!("#{}", comment.trim_end());
1633
1634                        let comment_kind = if (*prev_line_was_empty) || (*prev_line_was_comment) {
1635                            CommentType::OwnLine
1636                        } else {
1637                            CommentType::EndOfLine {
1638                                leading_whitespace: before
1639                                    .chars()
1640                                    .rev()
1641                                    .take_while(|c| c.is_whitespace())
1642                                    .collect::<String>()
1643                                    .chars()
1644                                    .rev()
1645                                    .collect(),
1646                            }
1647                        };
1648
1649                        *prev_line_was_empty = trimmed_line.is_empty();
1650                        *prev_line_was_comment = true;
1651
1652                        Some(Some(Comment {
1653                            text: comment_text,
1654                            kind: comment_kind,
1655                        }))
1656                    } else {
1657                        *prev_line_was_empty = trimmed_line.is_empty();
1658                        *prev_line_was_comment = false;
1659                        Some(None)
1660                    }
1661                },
1662            )
1663            .flatten();
1664
1665        Box::new(iter)
1666    }
1667
1668    let mut indentation_prefix = None;
1669
1670    // Calculate the indentation prefix based on the indentation of the first dependency entry.
1671    if let Some(first_item) = deps.iter().next() {
1672        let decor_prefix = first_item
1673            .decor()
1674            .prefix()
1675            .and_then(|s| s.as_str())
1676            .and_then(|s| s.lines().last())
1677            .unwrap_or_default();
1678
1679        let decor_prefix = decor_prefix
1680            .split_once('#')
1681            .map(|(s, _)| s)
1682            .unwrap_or(decor_prefix);
1683
1684        indentation_prefix = (!decor_prefix.is_empty()).then_some(decor_prefix.to_string());
1685    }
1686
1687    let indentation_prefix_str = format!("\n{}", indentation_prefix.as_deref().unwrap_or("    "));
1688
1689    for item in deps.iter_mut() {
1690        let decor = item.decor_mut();
1691        let mut prefix = String::new();
1692
1693        for comment in find_comments(decor.prefix()).chain(find_comments(decor.suffix())) {
1694            match &comment.kind {
1695                CommentType::OwnLine => {
1696                    prefix.push_str(&indentation_prefix_str);
1697                }
1698                CommentType::EndOfLine { leading_whitespace } => {
1699                    prefix.push_str(leading_whitespace);
1700                }
1701            }
1702            prefix.push_str(&comment.text);
1703        }
1704        prefix.push_str(&indentation_prefix_str);
1705        decor.set_prefix(prefix);
1706        decor.set_suffix("");
1707    }
1708
1709    deps.set_trailing(&{
1710        let mut comments = find_comments(Some(deps.trailing())).peekable();
1711        let mut rv = String::new();
1712        if comments.peek().is_some() {
1713            for comment in comments {
1714                match &comment.kind {
1715                    CommentType::OwnLine => {
1716                        let indentation_prefix_str =
1717                            format!("\n{}", indentation_prefix.as_deref().unwrap_or("    "));
1718                        rv.push_str(&indentation_prefix_str);
1719                    }
1720                    CommentType::EndOfLine { leading_whitespace } => {
1721                        rv.push_str(leading_whitespace);
1722                    }
1723                }
1724                rv.push_str(&comment.text);
1725            }
1726        }
1727        if !rv.is_empty() || !deps.is_empty() {
1728            rv.push('\n');
1729        }
1730        rv
1731    });
1732    deps.set_trailing_comma(true);
1733}
1734
/// Splits a requirement string into its name part and its specifier part.
///
/// The split happens at the first occurrence of any of `>`, `<`, `=`, `~`, `!` or `@`; both
/// halves are trimmed of surrounding whitespace. When none of those characters is present,
/// the specifier part is empty.
///
/// The name is returned as written (no normalization): `flask>=1.0` yields
/// `("flask", ">=1.0")`, while `Flask>=1.0` yields `("Flask", ">=1.0")`.
///
/// Extras stay attached to the name, so `flask[dotenv]>=1.0` yields
/// `("flask[dotenv]", ">=1.0")`.
fn split_specifiers(req: &str) -> (&str, &str) {
    // With no operator present, splitting at the end leaves an empty specifier half.
    let boundary = req
        .find(['>', '<', '=', '~', '!', '@'])
        .unwrap_or(req.len());
    let (name, specifiers) = req.split_at(boundary);
    (name.trim(), specifiers.trim())
}
1750
#[cfg(test)]
mod test {
    use super::{AddBoundsKind, reformat_array_multiline, split_specifiers};
    use std::str::FromStr;
    use toml_edit::DocumentMut;
    use uv_pep440::Version;

    /// Parses `toml`, reformats its `project.dependencies` array, and returns the document
    /// serialized back to a string.
    fn reformat_dependencies(toml: &str) -> String {
        let mut doc: DocumentMut = toml.parse().unwrap();

        let dependencies = doc["project"]["dependencies"]
            .as_array_mut()
            .expect("dependencies array");
        reformat_array_multiline(dependencies);

        doc.to_string()
    }

    /// Name and specifiers are split at the first operator, keeping case and extras.
    #[test]
    fn split() {
        let cases = [
            ("flask>=1.0", ("flask", ">=1.0")),
            ("Flask>=1.0", ("Flask", ">=1.0")),
            ("flask[dotenv]>=1.0", ("flask[dotenv]", ">=1.0")),
            ("flask[dotenv]", ("flask[dotenv]", "")),
            (
                "flask @ https://files.pythonhosted.org/packages/af/47/93213ee66ef8fae3b93b3e29206f6b251e65c97bd91d8e1c5596ef15af0a/flask-3.1.0-py3-none-any.whl",
                (
                    "flask",
                    "@ https://files.pythonhosted.org/packages/af/47/93213ee66ef8fae3b93b3e29206f6b251e65c97bd91d8e1c5596ef15af0a/flask-3.1.0-py3-none-any.whl",
                ),
            ),
        ];

        for (input, expected) in cases {
            assert_eq!(split_specifiers(input), expected, "{input}");
        }
    }

    /// Padding between an entry and its end-of-line comment survives reformatting.
    #[test]
    fn reformat_preserves_inline_comment_spacing() {
        let serialized = reformat_dependencies(
            r#"
[project]
dependencies = [
    "attrs>=25.4.0",     # comment
]
"#,
        );

        assert!(
            serialized.contains("\"attrs>=25.4.0\",     # comment"),
            "inline comment spacing should be preserved:\n{serialized}"
        );
    }

    /// An end-of-line comment written without any padding stays unpadded.
    #[test]
    fn reformat_preserves_inline_comment_without_padding() {
        let serialized = reformat_dependencies(
            r#"
[project]
dependencies = [
    "attrs>=25.4.0",#comment
]
"#,
        );

        assert!(
            serialized.contains("\"attrs>=25.4.0\",#comment"),
            "inline comment spacing without padding should be preserved:\n{serialized}"
        );
    }

    /// `Exact` pins the version with `==`.
    #[test]
    fn bound_kind_to_specifiers_exact() {
        let cases = [
            ("0", "==0"),
            ("0.0", "==0.0"),
            ("0.0.0", "==0.0.0"),
            ("0.1", "==0.1"),
            ("0.0.1", "==0.0.1"),
            ("0.0.0.1", "==0.0.0.1"),
            ("1.0.0", "==1.0.0"),
            ("1.2", "==1.2"),
            ("1.2.3", "==1.2.3"),
            ("1.2.3.4", "==1.2.3.4"),
            ("1.2.3.4a1.post1", "==1.2.3.4a1.post1"),
        ];

        for (version, expected) in cases {
            let parsed = Version::from_str(version).unwrap();
            let rendered = AddBoundsKind::Exact.specifiers(parsed).to_string();
            assert_eq!(rendered, expected, "{version}");
        }
    }

    /// `Lower` only adds a `>=` bound.
    #[test]
    fn bound_kind_to_specifiers_lower() {
        let cases = [
            ("0", ">=0"),
            ("0.0", ">=0.0"),
            ("0.0.0", ">=0.0.0"),
            ("0.1", ">=0.1"),
            ("0.0.1", ">=0.0.1"),
            ("0.0.0.1", ">=0.0.0.1"),
            ("1", ">=1"),
            ("1.0.0", ">=1.0.0"),
            ("1.2", ">=1.2"),
            ("1.2.3", ">=1.2.3"),
            ("1.2.3.4", ">=1.2.3.4"),
            ("1.2.3.4a1.post1", ">=1.2.3.4a1.post1"),
        ];

        for (version, expected) in cases {
            let parsed = Version::from_str(version).unwrap();
            let rendered = AddBoundsKind::Lower.specifiers(parsed).to_string();
            assert_eq!(rendered, expected, "{version}");
        }
    }

    /// `Major` adds a lower bound plus an upper bound; the table pins the exact upper bound
    /// produced for each input, including versions with leading zero segments.
    #[test]
    fn bound_kind_to_specifiers_major() {
        let cases = [
            ("0", ">=0, <0.1"),
            ("0.0", ">=0.0, <0.1"),
            ("0.0.0", ">=0.0.0, <0.1.0"),
            ("0.0.0.0", ">=0.0.0.0, <0.1.0.0"),
            ("0.1", ">=0.1, <0.2"),
            ("0.0.1", ">=0.0.1, <0.0.2"),
            ("0.0.1.1", ">=0.0.1.1, <0.0.2.0"),
            ("0.0.0.1", ">=0.0.0.1, <0.0.0.2"),
            ("1", ">=1, <2"),
            ("1.0.0", ">=1.0.0, <2.0.0"),
            ("1.2", ">=1.2, <2.0"),
            ("1.2.3", ">=1.2.3, <2.0.0"),
            ("1.2.3.4", ">=1.2.3.4, <2.0.0.0"),
            ("1.2.3.4a1.post1", ">=1.2.3.4a1.post1, <2.0.0.0"),
        ];

        for (version, expected) in cases {
            let parsed = Version::from_str(version).unwrap();
            let rendered = AddBoundsKind::Major.specifiers(parsed).to_string();
            assert_eq!(rendered, expected, "{version}");
        }
    }

    /// `Minor` adds a lower bound plus an upper bound; the table pins the exact upper bound
    /// produced for each input, including versions with leading zero segments.
    #[test]
    fn bound_kind_to_specifiers_minor() {
        let cases = [
            ("0", ">=0, <0.0.1"),
            ("0.0", ">=0.0, <0.0.1"),
            ("0.0.0", ">=0.0.0, <0.0.1"),
            ("0.0.0.0", ">=0.0.0.0, <0.0.1.0"),
            ("0.1", ">=0.1, <0.1.1"),
            ("0.0.1", ">=0.0.1, <0.0.2"),
            ("0.0.1.1", ">=0.0.1.1, <0.0.2.0"),
            ("0.0.0.1", ">=0.0.0.1, <0.0.0.2"),
            ("1", ">=1, <1.1"),
            ("1.0.0", ">=1.0.0, <1.1.0"),
            ("1.2", ">=1.2, <1.3"),
            ("1.2.3", ">=1.2.3, <1.3.0"),
            ("1.2.3.4", ">=1.2.3.4, <1.3.0.0"),
            ("1.2.3.4a1.post1", ">=1.2.3.4a1.post1, <1.3.0.0"),
        ];

        for (version, expected) in cases {
            let parsed = Version::from_str(version).unwrap();
            let rendered = AddBoundsKind::Minor.specifiers(parsed).to_string();
            assert_eq!(rendered, expected, "{version}");
        }
    }
}