// vulnera_advisor/sources/ghsa.rs

1use super::AdvisorySource;
2use crate::error::{AdvisoryError, Result};
3use crate::models::{
4    Advisory, Affected, Event, Package, Range, RangeType, Reference, ReferenceType,
5};
6use async_trait::async_trait;
7use chrono::{DateTime, Utc};
8use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
9use reqwest_retry::{RetryTransientMiddleware, policies::ExponentialBackoff};
10use serde::Deserialize;
11use serde_json::json;
12use std::time::Duration;
13use tracing::{debug, info, warn};
14
15pub struct GHSASource {
16    token: String,
17    client: ClientWithMiddleware,
18    api_url: String,
19}
20
21impl GHSASource {
22    pub fn new(token: String) -> Self {
23        // Build client with timeout and retry policy
24        let base_client = reqwest::Client::builder()
25            .timeout(Duration::from_secs(300))
26            .connect_timeout(Duration::from_secs(10))
27            .build()
28            .expect("Failed to build HTTP client");
29
30        let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
31        let client = ClientBuilder::new(base_client)
32            .with(RetryTransientMiddleware::new_with_policy(retry_policy))
33            .build();
34
35        Self {
36            token,
37            client,
38            api_url: "https://api.github.com/graphql".to_string(),
39        }
40    }
41
42    #[cfg(test)]
43    pub fn with_api_url(mut self, url: String) -> Self {
44        self.api_url = url;
45        self
46    }
47}
48
49#[async_trait]
50impl AdvisorySource for GHSASource {
51    async fn fetch(&self, since: Option<DateTime<Utc>>) -> Result<Vec<Advisory>> {
52        let mut advisories = Vec::new();
53        let mut cursor: Option<String> = None;
54        let mut page_count = 0;
55
56        info!(
57            "Starting GHSA sync{}",
58            since
59                .map(|d| format!(" since {}", d))
60                .unwrap_or_else(|| " (full)".to_string())
61        );
62
63        loop {
64            page_count += 1;
65
66            let query = r#"
67            query($cursor: String, $updatedSince: DateTime) {
68                securityAdvisories(first: 100, after: $cursor, updatedSince: $updatedSince) {
69                    pageInfo {
70                        hasNextPage
71                        endCursor
72                    }
73                    nodes {
74                        ghsaId
75                        summary
76                        description
77                        publishedAt
78                        updatedAt
79                        references {
80                            url
81                        }
82                        identifiers {
83                            type
84                            value
85                        }
86                        vulnerabilities(first: 100) {
87                            nodes {
88                                package {
89                                    name
90                                    ecosystem
91                                }
92                                vulnerableVersionRange
93                                firstPatchedVersion {
94                                    identifier
95                                }
96                            }
97                        }
98                    }
99                }
100            }
101            "#;
102
103            let variables = if let Some(since_dt) = since {
104                json!({
105                    "cursor": cursor,
106                    "updatedSince": since_dt.to_rfc3339(),
107                })
108            } else {
109                json!({
110                    "cursor": cursor,
111                    "updatedSince": serde_json::Value::Null,
112                })
113            };
114
115            let body = serde_json::to_string(&json!({
116                "query": query,
117                "variables": variables
118            }))?;
119
120            let response = self
121                .client
122                .post(&self.api_url)
123                .header("Authorization", format!("Bearer {}", self.token))
124                .header("User-Agent", "vulnera-advisors")
125                .header("Content-Type", "application/json")
126                .body(body)
127                .send()
128                .await?;
129
130            if !response.status().is_success() {
131                let status = response.status();
132                let text = response.text().await?;
133                warn!("GHSA API error {}: {}", status, text);
134                return Err(AdvisoryError::source_fetch(
135                    "GHSA",
136                    format!("API returned {}: {}", status, text),
137                ));
138            }
139
140            let data: GraphQlResponse = response.json().await?;
141
142            if let Some(errors) = data.errors {
143                warn!("GraphQL errors: {:?}", errors);
144                return Err(AdvisoryError::source_fetch(
145                    "GHSA",
146                    format!("GraphQL errors: {:?}", errors),
147                ));
148            }
149
150            if let Some(data) = data.data {
151                for advisory_node in data.security_advisories.nodes {
152                    // Map to canonical Advisory
153                    let mut references: Vec<Reference> = advisory_node
154                        .references
155                        .iter()
156                        .map(|r| Reference {
157                            reference_type: ReferenceType::Web,
158                            url: r.url.clone(),
159                        })
160                        .collect();
161
162                    // Add identifiers as aliases
163                    let mut aliases = Vec::new();
164                    for id in &advisory_node.identifiers {
165                        aliases.push(id.value.clone());
166                    }
167
168                    // Add identifiers as references/aliases
169                    for id in &advisory_node.identifiers {
170                        references.push(Reference {
171                            reference_type: ReferenceType::Other,
172                            url: format!("{}:{}", id.id_type, id.value),
173                        });
174                    }
175
176                    let mut affected = Vec::new();
177                    for vuln in advisory_node.vulnerabilities.nodes {
178                        affected.push(Affected {
179                            package: Package {
180                                ecosystem: vuln.package.ecosystem,
181                                name: vuln.package.name,
182                                purl: None,
183                            },
184                            ranges: vec![Range {
185                                range_type: RangeType::Ecosystem,
186                                events: vec![
187                                    Event::Introduced("0".to_string()),
188                                    Event::Fixed(
189                                        vuln.first_patched_version
190                                            .map(|v| v.identifier)
191                                            .unwrap_or_else(|| "0.0.0".to_string()),
192                                    ),
193                                ],
194                                repo: None,
195                            }],
196                            versions: vec![],
197                            ecosystem_specific: Some(json!({
198                                "vulnerable_range": vuln.vulnerable_version_range
199                            })),
200                            database_specific: None,
201                        });
202                    }
203
204                    advisories.push(Advisory {
205                        id: advisory_node.ghsa_id,
206                        summary: Some(advisory_node.summary),
207                        details: Some(advisory_node.description),
208                        affected,
209                        references,
210                        published: Some(advisory_node.published_at),
211                        modified: Some(advisory_node.updated_at),
212                        aliases: Some(aliases),
213                        database_specific: Some(json!({ "source": "GHSA" })),
214                        enrichment: None,
215                    });
216                }
217
218                if data.security_advisories.page_info.has_next_page {
219                    cursor = data.security_advisories.page_info.end_cursor;
220                    if page_count % 10 == 0 {
221                        info!(
222                            "GHSA sync progress: {} pages, {} advisories so far",
223                            page_count,
224                            advisories.len()
225                        );
226                    }
227                    debug!("Fetching next page of GHSA advisories...");
228                } else {
229                    break;
230                }
231            } else {
232                break;
233            }
234        }
235
236        info!("Fetched {} advisories from GHSA", advisories.len());
237        Ok(advisories)
238    }
239
240    fn name(&self) -> &str {
241        "GHSA"
242    }
243}
244
245#[derive(Deserialize)]
246struct GraphQlResponse {
247    data: Option<Data>,
248    errors: Option<Vec<serde_json::Value>>,
249}
250
251#[derive(Deserialize)]
252struct Data {
253    #[serde(rename = "securityAdvisories")]
254    security_advisories: SecurityAdvisories,
255}
256
257#[derive(Deserialize)]
258struct SecurityAdvisories {
259    #[serde(rename = "pageInfo")]
260    page_info: PageInfo,
261    nodes: Vec<GhsaAdvisoryNode>,
262}
263
264#[derive(Deserialize)]
265struct PageInfo {
266    #[serde(rename = "hasNextPage")]
267    has_next_page: bool,
268    #[serde(rename = "endCursor")]
269    end_cursor: Option<String>,
270}
271
272#[derive(Deserialize)]
273struct GhsaAdvisoryNode {
274    #[serde(rename = "ghsaId")]
275    ghsa_id: String,
276    summary: String,
277    description: String,
278    #[serde(rename = "publishedAt")]
279    published_at: DateTime<Utc>,
280    #[serde(rename = "updatedAt")]
281    updated_at: DateTime<Utc>,
282    references: Vec<GhsaReference>,
283    identifiers: Vec<GhsaIdentifier>,
284    vulnerabilities: GhsaVulnerabilitiesConnection,
285}
286
287#[derive(Deserialize)]
288struct GhsaVulnerabilitiesConnection {
289    nodes: Vec<GhsaVulnerability>,
290}
291
292#[derive(Deserialize)]
293struct GhsaVulnerability {
294    package: GhsaPackage,
295    #[serde(rename = "vulnerableVersionRange")]
296    vulnerable_version_range: String,
297    #[serde(rename = "firstPatchedVersion")]
298    first_patched_version: Option<GhsaVersion>,
299}
300
301#[derive(Deserialize)]
302struct GhsaReference {
303    url: String,
304}
305
306#[derive(Deserialize)]
307struct GhsaIdentifier {
308    #[serde(rename = "type")]
309    id_type: String,
310    value: String,
311}
312
313#[derive(Deserialize)]
314struct GhsaPackage {
315    name: String,
316    ecosystem: String,
317}
318
319#[derive(Deserialize)]
320struct GhsaVersion {
321    identifier: String,
322}
323
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::{body_string_contains, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Starts a mock GraphQL endpoint and a `GHSASource` pointed at it.
    async fn mock_source() -> (MockServer, GHSASource) {
        let mock_server = MockServer::start().await;
        let source = GHSASource::new("fake_token".to_string()).with_api_url(mock_server.uri());
        (mock_server, source)
    }

    /// Wraps advisory `nodes` in a single `securityAdvisories` response page.
    fn page(
        nodes: serde_json::Value,
        has_next_page: bool,
        end_cursor: Option<&str>,
    ) -> serde_json::Value {
        json!({
            "data": {
                "securityAdvisories": {
                    "pageInfo": {
                        "hasNextPage": has_next_page,
                        "endCursor": end_cursor
                    },
                    "nodes": nodes
                }
            }
        })
    }

    /// One advisory with a CVE identifier and a single patched NPM package.
    fn sample_node() -> serde_json::Value {
        json!({
            "ghsaId": "GHSA-xxxx-yyyy-zzzz",
            "summary": "Test Advisory",
            "description": "This is a test advisory",
            "publishedAt": "2023-01-01T00:00:00Z",
            "updatedAt": "2023-01-02T00:00:00Z",
            "references": [
                { "url": "https://example.com" }
            ],
            "identifiers": [
                { "type": "CVE", "value": "CVE-2023-1234" }
            ],
            "vulnerabilities": {
                "nodes": [
                    {
                        "package": {
                            "name": "test-package",
                            "ecosystem": "NPM"
                        },
                        "vulnerableVersionRange": "< 1.0.0",
                        "firstPatchedVersion": {
                            "identifier": "1.0.0"
                        }
                    }
                ]
            }
        })
    }

    #[tokio::test]
    async fn test_fetch_advisories_full() {
        let (mock_server, source) = mock_source().await;

        Mock::given(method("POST"))
            .and(path("/"))
            .and(body_string_contains("securityAdvisories"))
            // A full sync must not send an `updatedSince` value.
            .and(body_string_contains("\"updatedSince\":null"))
            .respond_with(
                ResponseTemplate::new(200).set_body_json(page(json!([sample_node()]), false, None)),
            )
            .mount(&mock_server)
            .await;

        let advisories = source.fetch(None).await.unwrap();
        assert_eq!(advisories.len(), 1);

        let advisory = &advisories[0];
        assert_eq!(advisory.id, "GHSA-xxxx-yyyy-zzzz");
        assert_eq!(advisory.affected.len(), 1);
        assert_eq!(advisory.affected[0].package.name, "test-package");
        assert!(advisory
            .aliases
            .as_ref()
            .expect("aliases should be populated")
            .iter()
            .any(|alias| alias == "CVE-2023-1234"));
        assert!(advisory
            .references
            .iter()
            .any(|reference| reference.url == "https://example.com"));
    }

    #[tokio::test]
    async fn test_fetch_advisories_since() {
        let (mock_server, source) = mock_source().await;
        let since = Utc::now();

        Mock::given(method("POST"))
            .and(path("/"))
            // Match the timestamp itself: the name `updatedSince` also appears in
            // the query text, so matching on it alone proves nothing.
            .and(body_string_contains(since.to_rfc3339()))
            .respond_with(ResponseTemplate::new(200).set_body_json(page(json!([]), false, None)))
            .mount(&mock_server)
            .await;

        let advisories = source.fetch(Some(since)).await.unwrap();
        assert_eq!(advisories.len(), 0);
    }

    #[tokio::test]
    async fn test_fetch_follows_pagination() {
        let (mock_server, source) = mock_source().await;

        Mock::given(method("POST"))
            .and(path("/"))
            .and(body_string_contains("\"cursor\":null"))
            .respond_with(
                ResponseTemplate::new(200).set_body_json(page(json!([]), true, Some("page-2"))),
            )
            .expect(1)
            .mount(&mock_server)
            .await;

        Mock::given(method("POST"))
            .and(path("/"))
            .and(body_string_contains("\"cursor\":\"page-2\""))
            .respond_with(
                ResponseTemplate::new(200).set_body_json(page(json!([sample_node()]), false, None)),
            )
            .expect(1)
            .mount(&mock_server)
            .await;

        let advisories = source.fetch(None).await.unwrap();
        assert_eq!(advisories.len(), 1);
        assert_eq!(advisories[0].id, "GHSA-xxxx-yyyy-zzzz");
    }

    #[tokio::test]
    async fn test_fetch_fails_on_http_error() {
        let (mock_server, source) = mock_source().await;

        // 401 is not a transient status, so the retry middleware does not delay the test.
        Mock::given(method("POST"))
            .and(path("/"))
            .respond_with(ResponseTemplate::new(401).set_body_string("Bad credentials"))
            .mount(&mock_server)
            .await;

        assert!(source.fetch(None).await.is_err());
    }

    #[tokio::test]
    async fn test_fetch_fails_on_graphql_errors() {
        let (mock_server, source) = mock_source().await;

        Mock::given(method("POST"))
            .and(path("/"))
            .respond_with(
                ResponseTemplate::new(200)
                    .set_body_json(json!({ "errors": [{ "message": "boom" }] })),
            )
            .mount(&mock_server)
            .await;

        assert!(source.fetch(None).await.is_err());
    }
}