Skip to main content

ward/github/
branch_protection.rs

1use anyhow::Result;
2use serde::{Deserialize, Serialize};
3
4use super::Client;
5
6#[derive(Debug, Clone, Default, Serialize, Deserialize)]
7pub struct BranchProtectionState {
8    pub required_pull_request_reviews: bool,
9    pub required_approving_review_count: u32,
10    pub dismiss_stale_reviews: bool,
11    pub require_code_owner_reviews: bool,
12    pub required_status_checks: bool,
13    pub strict_status_checks: bool,
14    pub enforce_admins: bool,
15    pub required_linear_history: bool,
16    pub allow_force_pushes: bool,
17    pub allow_deletions: bool,
18}
19
20#[derive(Debug, Deserialize)]
21struct BranchProtectionResponse {
22    #[serde(default)]
23    required_pull_request_reviews: Option<PullRequestReviewConfig>,
24    #[serde(default)]
25    required_status_checks: Option<StatusChecksConfig>,
26    #[serde(default)]
27    enforce_admins: Option<EnforceAdmins>,
28    #[serde(default)]
29    required_linear_history: Option<EnabledFlag>,
30    #[serde(default)]
31    allow_force_pushes: Option<EnabledFlag>,
32    #[serde(default)]
33    allow_deletions: Option<EnabledFlag>,
34}
35
36#[derive(Debug, Deserialize)]
37struct PullRequestReviewConfig {
38    #[serde(default)]
39    required_approving_review_count: u32,
40    #[serde(default)]
41    dismiss_stale_reviews: bool,
42    #[serde(default)]
43    require_code_owner_reviews: bool,
44}
45
46#[derive(Debug, Deserialize)]
47struct StatusChecksConfig {
48    #[serde(default)]
49    strict: bool,
50}
51
52#[derive(Debug, Deserialize)]
53struct EnforceAdmins {
54    enabled: bool,
55}
56
57#[derive(Debug, Deserialize)]
58struct EnabledFlag {
59    enabled: bool,
60}
61
62impl Client {
63    pub async fn get_branch_protection(
64        &self,
65        repo: &str,
66        branch: &str,
67    ) -> Result<Option<BranchProtectionState>> {
68        let resp = self
69            .get(&format!(
70                "/repos/{}/{repo}/branches/{branch}/protection",
71                self.org
72            ))
73            .await?;
74
75        if resp.status().as_u16() == 404 {
76            return Ok(None);
77        }
78
79        let status = resp.status();
80        if !status.is_success() {
81            let body = resp.text().await.unwrap_or_default();
82            anyhow::bail!(
83                "Failed to get branch protection for {repo}/{branch} (HTTP {status}): {body}"
84            );
85        }
86
87        let body: BranchProtectionResponse = resp.json().await?;
88
89        let state = BranchProtectionState {
90            required_pull_request_reviews: body.required_pull_request_reviews.is_some(),
91            required_approving_review_count: body
92                .required_pull_request_reviews
93                .as_ref()
94                .map(|r| r.required_approving_review_count)
95                .unwrap_or(0),
96            dismiss_stale_reviews: body
97                .required_pull_request_reviews
98                .as_ref()
99                .is_some_and(|r| r.dismiss_stale_reviews),
100            require_code_owner_reviews: body
101                .required_pull_request_reviews
102                .as_ref()
103                .is_some_and(|r| r.require_code_owner_reviews),
104            required_status_checks: body.required_status_checks.is_some(),
105            strict_status_checks: body
106                .required_status_checks
107                .as_ref()
108                .is_some_and(|r| r.strict),
109            enforce_admins: body.enforce_admins.as_ref().is_some_and(|e| e.enabled),
110            required_linear_history: body
111                .required_linear_history
112                .as_ref()
113                .is_some_and(|f| f.enabled),
114            allow_force_pushes: body.allow_force_pushes.as_ref().is_some_and(|f| f.enabled),
115            allow_deletions: body.allow_deletions.as_ref().is_some_and(|f| f.enabled),
116        };
117
118        Ok(Some(state))
119    }
120
121    pub async fn update_branch_protection(
122        &self,
123        repo: &str,
124        branch: &str,
125        config: &crate::config::manifest::BranchProtectionConfig,
126    ) -> Result<()> {
127        let required_status_checks = if config.require_status_checks {
128            serde_json::json!({
129                "strict": config.strict_status_checks,
130                "contexts": []
131            })
132        } else {
133            serde_json::Value::Null
134        };
135
136        let required_pull_request_reviews = if config.enabled {
137            serde_json::json!({
138                "required_approving_review_count": config.required_approvals,
139                "dismiss_stale_reviews": config.dismiss_stale_reviews,
140                "require_code_owner_reviews": config.require_code_owner_reviews
141            })
142        } else {
143            serde_json::Value::Null
144        };
145
146        let body = serde_json::json!({
147            "required_status_checks": required_status_checks,
148            "required_pull_request_reviews": required_pull_request_reviews,
149            "enforce_admins": config.enforce_admins,
150            "restrictions": null,
151            "required_linear_history": config.required_linear_history,
152            "allow_force_pushes": config.allow_force_pushes,
153            "allow_deletions": config.allow_deletions
154        });
155
156        let resp = self
157            .put_json(
158                &format!("/repos/{}/{repo}/branches/{branch}/protection", self.org),
159                &body,
160            )
161            .await?;
162
163        let status = resp.status();
164        if status.is_success() {
165            Ok(())
166        } else {
167            let body = resp.text().await.unwrap_or_default();
168            anyhow::bail!(
169                "Failed to update branch protection for {repo}/{branch} (HTTP {status}): {body}"
170            );
171        }
172    }
173}