1use super::AdvisorySource;
2use crate::error::{AdvisoryError, Result};
3use crate::models::{
4 Advisory, Affected, Event, Package, Range, RangeType, Reference, ReferenceType,
5};
6use async_trait::async_trait;
7use chrono::{DateTime, Utc};
8use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
9use reqwest_retry::{RetryTransientMiddleware, policies::ExponentialBackoff};
10use serde::Deserialize;
11use serde_json::json;
12use std::time::Duration;
13use tracing::{debug, info, warn};
14
/// Advisory source backed by the GitHub Security Advisory (GHSA) GraphQL API.
pub struct GHSASource {
    /// API token, sent as a `Bearer` credential in the `Authorization` header.
    token: String,
    /// HTTP client wrapped with transient-failure retry middleware.
    client: ClientWithMiddleware,
    /// GraphQL endpoint that queries are POSTed to.
    api_url: String,
}
20
21impl GHSASource {
22 pub fn new(token: String) -> Self {
23 let base_client = reqwest::Client::builder()
25 .timeout(Duration::from_secs(300))
26 .connect_timeout(Duration::from_secs(10))
27 .build()
28 .expect("Failed to build HTTP client");
29
30 let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
31 let client = ClientBuilder::new(base_client)
32 .with(RetryTransientMiddleware::new_with_policy(retry_policy))
33 .build();
34
35 Self {
36 token,
37 client,
38 api_url: "https://api.github.com/graphql".to_string(),
39 }
40 }
41
42 #[cfg(test)]
43 pub fn with_api_url(mut self, url: String) -> Self {
44 self.api_url = url;
45 self
46 }
47}
48
49#[async_trait]
50impl AdvisorySource for GHSASource {
51 async fn fetch(&self, since: Option<DateTime<Utc>>) -> Result<Vec<Advisory>> {
52 let mut advisories = Vec::new();
53 let mut cursor: Option<String> = None;
54 let mut page_count = 0;
55
56 info!(
57 "Starting GHSA sync{}",
58 since
59 .map(|d| format!(" since {}", d))
60 .unwrap_or_else(|| " (full)".to_string())
61 );
62
63 loop {
64 page_count += 1;
65
66 let query = r#"
67 query($cursor: String, $updatedSince: DateTime) {
68 securityAdvisories(first: 100, after: $cursor, updatedSince: $updatedSince) {
69 pageInfo {
70 hasNextPage
71 endCursor
72 }
73 nodes {
74 ghsaId
75 summary
76 description
77 publishedAt
78 updatedAt
79 references {
80 url
81 }
82 identifiers {
83 type
84 value
85 }
86 vulnerabilities(first: 100) {
87 nodes {
88 package {
89 name
90 ecosystem
91 }
92 vulnerableVersionRange
93 firstPatchedVersion {
94 identifier
95 }
96 }
97 }
98 }
99 }
100 }
101 "#;
102
103 let variables = if let Some(since_dt) = since {
104 json!({
105 "cursor": cursor,
106 "updatedSince": since_dt.to_rfc3339(),
107 })
108 } else {
109 json!({
110 "cursor": cursor,
111 "updatedSince": serde_json::Value::Null,
112 })
113 };
114
115 let body = serde_json::to_string(&json!({
116 "query": query,
117 "variables": variables
118 }))?;
119
120 let response = self
121 .client
122 .post(&self.api_url)
123 .header("Authorization", format!("Bearer {}", self.token))
124 .header("User-Agent", "vulnera-advisors")
125 .header("Content-Type", "application/json")
126 .body(body)
127 .send()
128 .await?;
129
130 if !response.status().is_success() {
131 let status = response.status();
132 let text = response.text().await?;
133 warn!("GHSA API error {}: {}", status, text);
134 return Err(AdvisoryError::source_fetch(
135 "GHSA",
136 format!("API returned {}: {}", status, text),
137 ));
138 }
139
140 let data: GraphQlResponse = response.json().await?;
141
142 if let Some(errors) = data.errors {
143 warn!("GraphQL errors: {:?}", errors);
144 return Err(AdvisoryError::source_fetch(
145 "GHSA",
146 format!("GraphQL errors: {:?}", errors),
147 ));
148 }
149
150 if let Some(data) = data.data {
151 for advisory_node in data.security_advisories.nodes {
152 let mut references: Vec<Reference> = advisory_node
154 .references
155 .iter()
156 .map(|r| Reference {
157 reference_type: ReferenceType::Web,
158 url: r.url.clone(),
159 })
160 .collect();
161
162 let mut aliases = Vec::new();
164 for id in &advisory_node.identifiers {
165 aliases.push(id.value.clone());
166 }
167
168 for id in &advisory_node.identifiers {
170 references.push(Reference {
171 reference_type: ReferenceType::Other,
172 url: format!("{}:{}", id.id_type, id.value),
173 });
174 }
175
176 let mut affected = Vec::new();
177 for vuln in advisory_node.vulnerabilities.nodes {
178 affected.push(Affected {
179 package: Package {
180 ecosystem: vuln.package.ecosystem,
181 name: vuln.package.name,
182 purl: None,
183 },
184 ranges: vec![Range {
185 range_type: RangeType::Ecosystem,
186 events: vec![
187 Event::Introduced("0".to_string()),
188 Event::Fixed(
189 vuln.first_patched_version
190 .map(|v| v.identifier)
191 .unwrap_or_else(|| "0.0.0".to_string()),
192 ),
193 ],
194 repo: None,
195 }],
196 versions: vec![],
197 ecosystem_specific: Some(json!({
198 "vulnerable_range": vuln.vulnerable_version_range
199 })),
200 database_specific: None,
201 });
202 }
203
204 advisories.push(Advisory {
205 id: advisory_node.ghsa_id,
206 summary: Some(advisory_node.summary),
207 details: Some(advisory_node.description),
208 affected,
209 references,
210 published: Some(advisory_node.published_at),
211 modified: Some(advisory_node.updated_at),
212 aliases: Some(aliases),
213 database_specific: Some(json!({ "source": "GHSA" })),
214 enrichment: None,
215 });
216 }
217
218 if data.security_advisories.page_info.has_next_page {
219 cursor = data.security_advisories.page_info.end_cursor;
220 if page_count % 10 == 0 {
221 info!(
222 "GHSA sync progress: {} pages, {} advisories so far",
223 page_count,
224 advisories.len()
225 );
226 }
227 debug!("Fetching next page of GHSA advisories...");
228 } else {
229 break;
230 }
231 } else {
232 break;
233 }
234 }
235
236 info!("Fetched {} advisories from GHSA", advisories.len());
237 Ok(advisories)
238 }
239
240 fn name(&self) -> &str {
241 "GHSA"
242 }
243}
244
/// Top-level envelope of a GraphQL response; both members are optional.
#[derive(Deserialize)]
struct GraphQlResponse {
    /// Query result, when the response carries one.
    data: Option<Data>,
    /// Raw error objects, kept untyped so they can be logged verbatim.
    errors: Option<Vec<serde_json::Value>>,
}
250
251#[derive(Deserialize)]
252struct Data {
253 #[serde(rename = "securityAdvisories")]
254 security_advisories: SecurityAdvisories,
255}
256
257#[derive(Deserialize)]
258struct SecurityAdvisories {
259 #[serde(rename = "pageInfo")]
260 page_info: PageInfo,
261 nodes: Vec<GhsaAdvisoryNode>,
262}
263
264#[derive(Deserialize)]
265struct PageInfo {
266 #[serde(rename = "hasNextPage")]
267 has_next_page: bool,
268 #[serde(rename = "endCursor")]
269 end_cursor: Option<String>,
270}
271
272#[derive(Deserialize)]
273struct GhsaAdvisoryNode {
274 #[serde(rename = "ghsaId")]
275 ghsa_id: String,
276 summary: String,
277 description: String,
278 #[serde(rename = "publishedAt")]
279 published_at: DateTime<Utc>,
280 #[serde(rename = "updatedAt")]
281 updated_at: DateTime<Utc>,
282 references: Vec<GhsaReference>,
283 identifiers: Vec<GhsaIdentifier>,
284 vulnerabilities: GhsaVulnerabilitiesConnection,
285}
286
/// The `vulnerabilities` connection of an advisory; only `nodes` is read.
#[derive(Deserialize)]
struct GhsaVulnerabilitiesConnection {
    /// Affected-package entries returned for the advisory.
    nodes: Vec<GhsaVulnerability>,
}
291
292#[derive(Deserialize)]
293struct GhsaVulnerability {
294 package: GhsaPackage,
295 #[serde(rename = "vulnerableVersionRange")]
296 vulnerable_version_range: String,
297 #[serde(rename = "firstPatchedVersion")]
298 first_patched_version: Option<GhsaVersion>,
299}
300
/// A reference link attached to an advisory.
#[derive(Deserialize)]
struct GhsaReference {
    /// The reference URL.
    url: String,
}
305
/// A typed identifier attached to an advisory (a `type`/`value` pair).
#[derive(Deserialize)]
struct GhsaIdentifier {
    // `type` is a Rust keyword, so the field is renamed on the Rust side.
    #[serde(rename = "type")]
    id_type: String,
    /// The identifier string itself.
    value: String,
}
312
/// Package coordinates of an affected-package entry.
#[derive(Deserialize)]
struct GhsaPackage {
    /// Package name.
    name: String,
    /// Ecosystem label as returned by the API.
    ecosystem: String,
}
318
/// A package version reference, as used for `firstPatchedVersion`.
#[derive(Deserialize)]
struct GhsaVersion {
    /// The version string.
    identifier: String,
}
323
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::{body_string_contains, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Starts a mock endpoint that answers any `POST /` whose body contains
    /// `needle` with `payload`, and returns it together with a source pointed
    /// at it. The server must be kept alive for the duration of the test.
    async fn mock_ghsa(needle: &str, payload: serde_json::Value) -> (MockServer, GHSASource) {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/"))
            .and(body_string_contains(needle))
            .respond_with(ResponseTemplate::new(200).set_body_json(payload))
            .mount(&server)
            .await;

        let source = GHSASource::new("fake_token".to_string()).with_api_url(server.uri());
        (server, source)
    }

    /// Wraps `nodes` in a single, final page of the `securityAdvisories` response.
    fn final_page(nodes: serde_json::Value) -> serde_json::Value {
        json!({
            "data": {
                "securityAdvisories": {
                    "pageInfo": {
                        "hasNextPage": false,
                        "endCursor": null
                    },
                    "nodes": nodes
                }
            }
        })
    }

    #[tokio::test]
    async fn test_fetch_advisories_full() {
        let payload = final_page(json!([
            {
                "ghsaId": "GHSA-xxxx-yyyy-zzzz",
                "summary": "Test Advisory",
                "description": "This is a test advisory",
                "publishedAt": "2023-01-01T00:00:00Z",
                "updatedAt": "2023-01-02T00:00:00Z",
                "references": [
                    { "url": "https://example.com" }
                ],
                "identifiers": [
                    { "type": "CVE", "value": "CVE-2023-1234" }
                ],
                "vulnerabilities": {
                    "nodes": [
                        {
                            "package": {
                                "name": "test-package",
                                "ecosystem": "NPM"
                            },
                            "vulnerableVersionRange": "< 1.0.0",
                            "firstPatchedVersion": {
                                "identifier": "1.0.0"
                            }
                        }
                    ]
                }
            }
        ]));
        let (_server, source) = mock_ghsa("securityAdvisories", payload).await;

        let fetched = source.fetch(None).await.unwrap();

        assert_eq!(fetched.len(), 1);
        let advisory = &fetched[0];
        assert_eq!(advisory.id, "GHSA-xxxx-yyyy-zzzz");
        assert_eq!(advisory.affected.len(), 1);
        assert_eq!(advisory.affected[0].package.name, "test-package");
    }

    #[tokio::test]
    async fn test_fetch_advisories_since() {
        let (_server, source) = mock_ghsa("updatedSince", final_page(json!([]))).await;

        let since = Utc::now();
        let fetched = source.fetch(Some(since)).await.unwrap();

        assert_eq!(fetched.len(), 0);
    }
}