1use std::collections::HashMap;
11
12use aes_gcm::{
13 aead::{Aead, KeyInit},
14 Aes256Gcm, Key, Nonce,
15};
16use anyhow::{bail, Context, Result};
17use base64::{engine::general_purpose::STANDARD as B64, Engine as _};
18use serde::{Deserialize, Serialize};
19use sha2::{Digest, Sha256};
20
21use crate::oauth::OAuthCredential;
22
/// Everything a sync push/pull transfers between machines.
///
/// `encrypt_bundle` serializes this to JSON before encrypting it, and
/// `decrypt_bundle` deserializes it after decrypting.
///
/// NOTE(review): `Debug` is derived, so `{:?}` prints `config_toml` verbatim
/// plus whatever `OAuthCredential`'s own `Debug` impl emits — confirm that
/// does not expose tokens before logging a bundle.
#[derive(Debug, Serialize, Deserialize)]
pub struct SyncBundle {
    /// Configuration contents as a TOML string (presumably the tool's config file — confirm).
    pub config_toml: String,
    /// OAuth credentials keyed by a string — presumably an account name/id; verify against caller.
    pub accounts: HashMap<String, OAuthCredential>,
}
32
33pub fn generate_code() -> String {
39 let bytes = crate::oauth::rand_bytes::<9>();
40 format!("SH-{}", hex::encode(bytes))
41}
42
43pub fn validate_code(code: &str) -> Result<()> {
45 if !code.starts_with("SH-") || code.len() != 21 {
46 bail!("Invalid transfer code format. Expected SH-<18 hex chars> (e.g. SH-a3f2b1c4d5e6f7a8b9).");
47 }
48 if !code[3..].chars().all(|c| c.is_ascii_hexdigit()) {
49 bail!("Invalid transfer code — must be hex characters after 'SH-'.");
50 }
51 Ok(())
52}
53
54fn derive_key(code: &str) -> [u8; 32] {
59 let hash = Sha256::digest(code.as_bytes());
60 hash.into()
61}
62
63pub fn encrypt_bundle(bundle: &SyncBundle, code: &str) -> Result<String> {
65 let json = serde_json::to_vec(bundle).context("failed to serialize bundle")?;
66
67 let key_bytes = derive_key(code);
68 let key = Key::<Aes256Gcm>::from_slice(&key_bytes);
69 let cipher = Aes256Gcm::new(key);
70
71 let nonce_bytes = crate::oauth::rand_bytes::<12>();
72 let nonce = Nonce::from_slice(&nonce_bytes);
73
74 let ciphertext = cipher
75 .encrypt(nonce, json.as_slice())
76 .map_err(|e| anyhow::anyhow!("encryption failed: {e}"))?;
77
78 let mut wire = Vec::with_capacity(12 + ciphertext.len());
80 wire.extend_from_slice(&nonce_bytes);
81 wire.extend_from_slice(&ciphertext);
82
83 Ok(B64.encode(wire))
84}
85
86pub fn decrypt_bundle(payload_b64: &str, code: &str) -> Result<SyncBundle> {
88 let wire = B64
89 .decode(payload_b64)
90 .context("invalid base64 in payload")?;
91
92 if wire.len() < 12 {
93 bail!("payload too short");
94 }
95
96 let (nonce_bytes, ciphertext) = wire.split_at(12);
97
98 let key_bytes = derive_key(code);
99 let key = Key::<Aes256Gcm>::from_slice(&key_bytes);
100 let cipher = Aes256Gcm::new(key);
101 let nonce = Nonce::from_slice(nonce_bytes);
102
103 let plaintext = cipher
104 .decrypt(nonce, ciphertext)
105 .map_err(|_| anyhow::anyhow!("decryption failed — wrong code or corrupted payload"))?;
106
107 serde_json::from_slice::<SyncBundle>(&plaintext).context("failed to deserialize bundle")
108}
109
110pub async fn push_to_relay(code: &str, payload: &str, relay_url: &str) -> Result<()> {
116 let client = reqwest::Client::builder()
117 .timeout(std::time::Duration::from_secs(15))
118 .build()?;
119
120 let body = serde_json::json!({ "code": code, "payload": payload });
121
122 let resp = client
123 .post(format!("{relay_url}/bundle"))
124 .json(&body)
125 .send()
126 .await
127 .context("failed to reach relay")?;
128
129 if !resp.status().is_success() {
130 let status = resp.status();
131 let text = resp.text().await.unwrap_or_default();
132 bail!("relay returned {status}: {text}");
133 }
134
135 Ok(())
136}
137
138pub async fn pull_from_relay(code: &str, relay_url: &str) -> Result<String> {
141 let client = reqwest::Client::builder()
142 .timeout(std::time::Duration::from_secs(15))
143 .build()?;
144
145 let resp = client
146 .get(format!("{relay_url}/bundle/{code}"))
147 .send()
148 .await
149 .context("failed to reach relay")?;
150
151 if resp.status() == reqwest::StatusCode::NOT_FOUND {
152 bail!("Code not found or already used. Codes are one-time use — run `shunt push` again to get a new one.");
153 }
154
155 if !resp.status().is_success() {
156 let status = resp.status();
157 let text = resp.text().await.unwrap_or_default();
158 bail!("relay returned {status}: {text}");
159 }
160
161 let json: serde_json::Value = resp.json().await.context("invalid response from relay")?;
162 json["payload"]
163 .as_str()
164 .map(|s| s.to_owned())
165 .context("relay response missing 'payload' field")
166}