1use anyhow::Result;
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16use std::sync::Arc;
17use tokio::sync::RwLock;
18use uuid::Uuid;
19
/// Settings for a GitHub OAuth application.
///
/// `Debug` is implemented by hand (instead of derived) so that
/// `client_secret` is never written to logs or panic messages.
#[derive(Clone)]
pub struct GitHubOAuthConfig {
    /// OAuth app client id (read from `GITHUB_CLIENT_ID` by [`GitHubOAuthConfig::from_env`]).
    pub client_id: String,
    /// OAuth app client secret (`GITHUB_CLIENT_SECRET`); redacted in `Debug` output.
    pub client_secret: String,
    /// Callback URL passed as `redirect_uri` (`GITHUB_REDIRECT_URI`).
    pub redirect_uri: String,
    /// Requested scopes, space separated (`GITHUB_SCOPE`).
    pub scope: String,
}

impl std::fmt::Debug for GitHubOAuthConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GitHubOAuthConfig")
            .field("client_id", &self.client_id)
            .field("client_secret", &"<redacted>")
            .field("redirect_uri", &self.redirect_uri)
            .field("scope", &self.scope)
            .finish()
    }
}

impl GitHubOAuthConfig {
    /// Builds a config from environment variables.
    ///
    /// Returns `None` unless `GITHUB_CLIENT_ID` and `GITHUB_CLIENT_SECRET` are
    /// both set to non-empty values. `GITHUB_REDIRECT_URI` defaults to
    /// `http://localhost:28428/auth/github/callback` and `GITHUB_SCOPE` defaults
    /// to `read:user user:email` when unset or empty.
    pub fn from_env() -> Option<Self> {
        // Unset, non-UTF-8 and empty variables are all treated as "missing":
        // an empty client id/secret can never produce a working config.
        fn non_empty(key: &str) -> Option<String> {
            std::env::var(key).ok().filter(|value| !value.is_empty())
        }

        Some(Self {
            client_id: non_empty("GITHUB_CLIENT_ID")?,
            client_secret: non_empty("GITHUB_CLIENT_SECRET")?,
            redirect_uri: non_empty("GITHUB_REDIRECT_URI")
                .unwrap_or_else(|| "http://localhost:28428/auth/github/callback".to_string()),
            scope: non_empty("GITHUB_SCOPE")
                .unwrap_or_else(|| "read:user user:email".to_string()),
        })
    }

    /// Returns the GitHub authorization URL the user should be redirected to.
    ///
    /// Every query parameter (including `client_id` and `state`) is encoded
    /// with [`Self::encode_query_component`].
    pub fn authorization_url(&self, state: &str) -> String {
        format!(
            "https://github.com/login/oauth/authorize?client_id={}&redirect_uri={}&scope={}&state={}",
            Self::encode_query_component(&self.client_id),
            Self::encode_query_component(&self.redirect_uri),
            Self::encode_query_component(&self.scope),
            Self::encode_query_component(state)
        )
    }

    /// Form-style encodes one query component.
    ///
    /// Unreserved characters (`A-Z a-z 0-9 - _ . ~`) are kept, a space becomes
    /// `+`, and every other byte of the UTF-8 encoding becomes `%XX`. Working on
    /// bytes (not chars) keeps multi-byte characters such as `é` correct
    /// (`%C3%A9`).
    fn encode_query_component(input: &str) -> String {
        const HEX: &[u8; 16] = b"0123456789ABCDEF";
        let mut out = String::with_capacity(input.len());
        for byte in input.bytes() {
            match byte {
                b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'_' | b'.' | b'~' => {
                    out.push(byte as char)
                }
                b' ' => out.push('+'),
                _ => {
                    out.push('%');
                    out.push(HEX[(byte >> 4) as usize] as char);
                    out.push(HEX[(byte & 0x0F) as usize] as char);
                }
            }
        }
        out
    }
}
61
/// GitHub user profile as deserialized from the `https://api.github.com/user`
/// response in `fetch_github_user`.
///
/// Only these fields are captured; other JSON fields in the response are
/// ignored by the derived `Deserialize`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitHubUser {
    /// The `id` field of the GitHub user JSON.
    pub id: i64,
    /// The `login` field; `UserSession::display_name` falls back to it when
    /// `name` is absent.
    pub login: String,
    /// The `name` field; `None` when the JSON value is `null` or missing.
    pub name: Option<String>,
    /// The `email` field; `None` when the JSON value is `null` or missing.
    pub email: Option<String>,
    /// The `avatar_url` field.
    pub avatar_url: String,
}
71
/// An authenticated session belonging to a GitHub user.
///
/// Created by `UserSession::new` (and `SessionStore::create_session`) and
/// looked up in `SessionStore` by `session_id`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserSession {
    /// Random UUID v4 string; the key under which `SessionStore` stores it.
    pub session_id: String,
    /// Profile of the user who authenticated.
    pub github_user: GitHubUser,
    /// UTC time at which the session was created.
    pub created_at: chrono::DateTime<chrono::Utc>,
    /// UTC time after which `is_expired` reports `true`.
    pub expires_at: chrono::DateTime<chrono::Utc>,
    /// Optional participant id for this session; `None` on creation.
    // NOTE(review): assigned by code outside this view — confirm what it refers to.
    pub participant_id: Option<String>,
}
82
83impl UserSession {
84 pub fn new(github_user: GitHubUser) -> Self {
85 let now = chrono::Utc::now();
86 Self {
87 session_id: Uuid::new_v4().to_string(),
88 github_user,
89 created_at: now,
90 expires_at: now + chrono::Duration::hours(24),
91 participant_id: None,
92 }
93 }
94
95 pub fn is_expired(&self) -> bool {
96 chrono::Utc::now() > self.expires_at
97 }
98
99 pub fn display_name(&self) -> String {
100 self.github_user
101 .name
102 .clone()
103 .unwrap_or_else(|| self.github_user.login.clone())
104 }
105}
106
/// In-memory store of user sessions and outstanding OAuth `state` values.
///
/// Everything lives in plain `HashMap`s; nothing is persisted. For sharing
/// between tasks, wrap it as a `SharedSessionStore`.
pub struct SessionStore {
    /// Sessions keyed by `UserSession::session_id`.
    sessions: HashMap<String, UserSession>,
    /// OAuth `state` values that were issued but not yet validated, mapped
    /// to the time they were issued.
    pending_states: HashMap<String, chrono::DateTime<chrono::Utc>>,
}
113
114impl SessionStore {
115 pub fn new() -> Self {
116 Self {
117 sessions: HashMap::new(),
118 pending_states: HashMap::new(),
119 }
120 }
121
122 pub fn create_oauth_state(&mut self) -> String {
124 let state = Uuid::new_v4().to_string();
125 self.pending_states.insert(state.clone(), chrono::Utc::now());
126 state
127 }
128
129 pub fn validate_oauth_state(&mut self, state: &str) -> bool {
131 if let Some(created) = self.pending_states.remove(state) {
132 chrono::Utc::now() - created < chrono::Duration::minutes(10)
134 } else {
135 false
136 }
137 }
138
139 pub fn create_session(&mut self, github_user: GitHubUser) -> UserSession {
141 let session = UserSession::new(github_user);
142 self.sessions.insert(session.session_id.clone(), session.clone());
143 session
144 }
145
146 pub fn get_session(&self, session_id: &str) -> Option<&UserSession> {
148 self.sessions.get(session_id).filter(|s| !s.is_expired())
149 }
150
151 pub fn get_session_mut(&mut self, session_id: &str) -> Option<&mut UserSession> {
153 self.sessions.get_mut(session_id).filter(|s| !s.is_expired())
154 }
155
156 pub fn remove_session(&mut self, session_id: &str) {
158 self.sessions.remove(session_id);
159 }
160
161 pub fn cleanup_expired(&mut self) {
163 let now = chrono::Utc::now();
164 self.sessions.retain(|_, s| s.expires_at > now);
165 self.pending_states
166 .retain(|_, created| now - *created < chrono::Duration::minutes(10));
167 }
168
169 pub fn session_count(&self) -> usize {
171 self.sessions.len()
172 }
173}
174
175impl Default for SessionStore {
176 fn default() -> Self {
177 Self::new()
178 }
179}
180
181pub type SharedSessionStore = Arc<RwLock<SessionStore>>;
183
184pub fn create_session_store() -> SharedSessionStore {
186 Arc::new(RwLock::new(SessionStore::new()))
187}
188
189pub async fn exchange_code_for_token(
191 config: &GitHubOAuthConfig,
192 code: &str,
193) -> Result<String> {
194 #[derive(Deserialize)]
195 struct TokenResponse {
196 access_token: String,
197 }
198
199 let client = reqwest::Client::new();
200 let response = client
201 .post("https://github.com/login/oauth/access_token")
202 .header("Accept", "application/json")
203 .form(&[
204 ("client_id", &config.client_id),
205 ("client_secret", &config.client_secret),
206 ("code", &code.to_string()),
207 ("redirect_uri", &config.redirect_uri),
208 ])
209 .send()
210 .await?;
211
212 let token_response: TokenResponse = response.json().await?;
213 Ok(token_response.access_token)
214}
215
216pub async fn fetch_github_user(access_token: &str) -> Result<GitHubUser> {
218 let client = reqwest::Client::new();
219 let response = client
220 .get("https://api.github.com/user")
221 .header("Authorization", format!("Bearer {}", access_token))
222 .header("User-Agent", "Smart-Tree-Daemon")
223 .send()
224 .await?;
225
226 let user: GitHubUser = response.json().await?;
227 Ok(user)
228}
229
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a deterministic test user, optionally with a display name.
    fn sample_user(name: Option<&str>) -> GitHubUser {
        GitHubUser {
            id: 12345,
            login: "testuser".to_string(),
            name: name.map(str::to_string),
            email: Some("test@example.com".to_string()),
            avatar_url: "https://example.com/avatar.png".to_string(),
        }
    }

    #[test]
    fn test_session_creation() {
        let session = UserSession::new(sample_user(Some("Test User")));
        assert!(!session.is_expired());
        assert_eq!(session.display_name(), "Test User");
    }

    #[test]
    fn test_display_name_falls_back_to_login() {
        let session = UserSession::new(sample_user(None));
        assert_eq!(session.display_name(), "testuser");
    }

    #[test]
    fn test_oauth_state() {
        let mut store = SessionStore::new();
        let state = store.create_oauth_state();
        assert!(store.validate_oauth_state(&state));
        // States are single-use: a replay must be rejected.
        assert!(!store.validate_oauth_state(&state));
    }

    #[test]
    fn test_unknown_oauth_state_rejected() {
        let mut store = SessionStore::new();
        assert!(!store.validate_oauth_state("never-issued"));
    }

    #[test]
    fn test_session_store_roundtrip() {
        let mut store = SessionStore::new();
        let id = store.create_session(sample_user(None)).session_id;
        assert_eq!(store.session_count(), 1);
        assert!(store.get_session(&id).is_some());
        store.remove_session(&id);
        assert!(store.get_session(&id).is_none());
        assert_eq!(store.session_count(), 0);
    }

    #[test]
    fn test_expired_session_hidden_then_cleaned_up() {
        let mut store = SessionStore::new();
        let id = store.create_session(sample_user(None)).session_id;
        store
            .sessions
            .get_mut(&id)
            .expect("session was just inserted")
            .expires_at = chrono::Utc::now() - chrono::Duration::hours(1);

        assert!(store.get_session(&id).is_none());
        // Still stored until cleanup runs.
        assert_eq!(store.session_count(), 1);
        store.cleanup_expired();
        assert_eq!(store.session_count(), 0);
    }

    #[test]
    fn test_authorization_url_encoding() {
        let config = GitHubOAuthConfig {
            client_id: "abc123".to_string(),
            client_secret: "secret".to_string(),
            redirect_uri: "http://localhost:28428/auth/github/callback".to_string(),
            scope: "read:user user:email".to_string(),
        };
        assert_eq!(
            config.authorization_url("state-1"),
            "https://github.com/login/oauth/authorize?client_id=abc123&redirect_uri=http%3A%2F%2Flocalhost%3A28428%2Fauth%2Fgithub%2Fcallback&scope=read%3Auser+user%3Aemail&state=state-1"
        );
    }
}