1use std::collections::HashMap;
11
12use aes_gcm::{
13 aead::{Aead, KeyInit},
14 Aes256Gcm, Key, Nonce,
15};
16use anyhow::{bail, Context, Result};
17use base64::{engine::general_purpose::STANDARD as B64, Engine as _};
18use serde::{Deserialize, Serialize};
19use sha2::{Digest, Sha256};
20
21use crate::oauth::OAuthCredential;
22
/// Data that is serialized to JSON and encrypted by `encrypt_bundle`, and restored by
/// `decrypt_bundle`.
///
/// NOTE(review): `#[derive(Debug)]` will print `accounts` through `OAuthCredential`'s own
/// `Debug` impl — confirm that impl redacts secrets before logging this type.
#[derive(Debug, Serialize, Deserialize)]
pub struct SyncBundle {
    /// Config file contents carried as raw text (presumably TOML, per the field name; not parsed
    /// here).
    pub config_toml: String,
    /// Stored OAuth credentials; presumably keyed by account name — TODO confirm against callers.
    pub accounts: HashMap<String, OAuthCredential>,
}
32
33pub fn generate_code() -> String {
39 let bytes = crate::oauth::rand_bytes::<9>();
40 format!("SH-{}", hex::encode(bytes))
41}
42
43pub fn generate_remote_code() -> String {
45 let bytes = crate::oauth::rand_bytes::<9>();
46 format!("RM-{}", hex::encode(bytes))
47}
48
49pub fn validate_remote_code(code: &str) -> Result<()> {
51 if !code.starts_with("RM-") || code.len() != 21 {
52 anyhow::bail!("Invalid remote code format. Expected RM-<18 hex chars>.");
53 }
54 if !code[3..].chars().all(|c| c.is_ascii_hexdigit()) {
55 anyhow::bail!("Invalid remote code — must be hex characters after 'RM-'.");
56 }
57 Ok(())
58}
59
60pub fn validate_code(code: &str) -> Result<()> {
62 if !code.starts_with("SH-") || code.len() != 21 {
63 bail!("Invalid transfer code format. Expected SH-<18 hex chars> (e.g. SH-a3f2b1c4d5e6f7a8b9).");
64 }
65 if !code[3..].chars().all(|c| c.is_ascii_hexdigit()) {
66 bail!("Invalid transfer code — must be hex characters after 'SH-'.");
67 }
68 Ok(())
69}
70
71fn derive_key(code: &str) -> [u8; 32] {
76 let hash = Sha256::digest(code.as_bytes());
77 hash.into()
78}
79
80pub fn encrypt_bundle(bundle: &SyncBundle, code: &str) -> Result<String> {
82 let json = serde_json::to_vec(bundle).context("failed to serialize bundle")?;
83
84 let key_bytes = derive_key(code);
85 let key = Key::<Aes256Gcm>::from_slice(&key_bytes);
86 let cipher = Aes256Gcm::new(key);
87
88 let nonce_bytes = crate::oauth::rand_bytes::<12>();
89 let nonce = Nonce::from_slice(&nonce_bytes);
90
91 let ciphertext = cipher
92 .encrypt(nonce, json.as_slice())
93 .map_err(|e| anyhow::anyhow!("encryption failed: {e}"))?;
94
95 let mut wire = Vec::with_capacity(12 + ciphertext.len());
97 wire.extend_from_slice(&nonce_bytes);
98 wire.extend_from_slice(&ciphertext);
99
100 Ok(B64.encode(wire))
101}
102
103pub fn decrypt_bundle(payload_b64: &str, code: &str) -> Result<SyncBundle> {
105 let wire = B64
106 .decode(payload_b64)
107 .context("invalid base64 in payload")?;
108
109 if wire.len() < 12 {
110 bail!("payload too short");
111 }
112
113 let (nonce_bytes, ciphertext) = wire.split_at(12);
114
115 let key_bytes = derive_key(code);
116 let key = Key::<Aes256Gcm>::from_slice(&key_bytes);
117 let cipher = Aes256Gcm::new(key);
118 let nonce = Nonce::from_slice(nonce_bytes);
119
120 let plaintext = cipher
121 .decrypt(nonce, ciphertext)
122 .map_err(|_| anyhow::anyhow!("decryption failed — wrong code or corrupted payload"))?;
123
124 serde_json::from_slice::<SyncBundle>(&plaintext).context("failed to deserialize bundle")
125}
126
127pub async fn push_to_relay(code: &str, payload: &str, relay_url: &str) -> Result<()> {
133 let client = reqwest::Client::builder()
134 .timeout(std::time::Duration::from_secs(15))
135 .build()?;
136
137 let body = serde_json::json!({ "code": code, "payload": payload });
138
139 let resp = client
140 .post(format!("{relay_url}/bundle"))
141 .json(&body)
142 .send()
143 .await
144 .context("failed to reach relay")?;
145
146 if !resp.status().is_success() {
147 let status = resp.status();
148 let text = resp.text().await.unwrap_or_default();
149 bail!("relay returned {status}: {text}");
150 }
151
152 Ok(())
153}
154
155pub fn encrypt_bytes(data: &[u8], code: &str) -> Result<String> {
158 let key_bytes = derive_key(code);
159 let key = Key::<Aes256Gcm>::from_slice(&key_bytes);
160 let cipher = Aes256Gcm::new(key);
161 let nonce_bytes = crate::oauth::rand_bytes::<12>();
162 let nonce = Nonce::from_slice(&nonce_bytes);
163 let ciphertext = cipher
164 .encrypt(nonce, data)
165 .map_err(|e| anyhow::anyhow!("encryption failed: {e}"))?;
166 let mut wire = Vec::with_capacity(12 + ciphertext.len());
167 wire.extend_from_slice(&nonce_bytes);
168 wire.extend_from_slice(&ciphertext);
169 Ok(B64.encode(wire))
170}
171
172pub fn decrypt_bytes(payload_b64: &str, code: &str) -> Result<Vec<u8>> {
174 let wire = B64.decode(payload_b64).context("invalid base64 in payload")?;
175 if wire.len() < 12 { anyhow::bail!("payload too short"); }
176 let (nonce_bytes, ciphertext) = wire.split_at(12);
177 let key_bytes = derive_key(code);
178 let key = Key::<Aes256Gcm>::from_slice(&key_bytes);
179 let cipher = Aes256Gcm::new(key);
180 let nonce = Nonce::from_slice(nonce_bytes);
181 cipher
182 .decrypt(nonce, ciphertext)
183 .map_err(|_| anyhow::anyhow!("decryption failed — wrong code or corrupted payload"))
184}
185
186pub async fn pull_from_relay(code: &str, relay_url: &str) -> Result<String> {
189 let client = reqwest::Client::builder()
190 .timeout(std::time::Duration::from_secs(15))
191 .build()?;
192
193 let resp = client
194 .get(format!("{relay_url}/bundle/{code}"))
195 .send()
196 .await
197 .context("failed to reach relay")?;
198
199 if resp.status() == reqwest::StatusCode::NOT_FOUND {
200 bail!("Code not found or already used. Codes are one-time use — run `shunt push` again to get a new one.");
201 }
202
203 if !resp.status().is_success() {
204 let status = resp.status();
205 let text = resp.text().await.unwrap_or_default();
206 bail!("relay returned {status}: {text}");
207 }
208
209 let json: serde_json::Value = resp.json().await.context("invalid response from relay")?;
210 json["payload"]
211 .as_str()
212 .map(|s| s.to_owned())
213 .context("relay response missing 'payload' field")
214}