Skip to main content

varpulis_cli/
billing.rs

1//! Stripe billing integration for Varpulis Cloud.
2//!
3//! Provides usage tracking, tier management, and Stripe Checkout/Portal
4//! integration via REST endpoints.
5
6use std::collections::HashMap;
7use std::sync::Arc;
8
9use axum::extract::{Json, State};
10use axum::http::{HeaderMap, StatusCode};
11use axum::response::{IntoResponse, Response};
12use axum::routing::{get, post};
13use axum::Router;
14#[cfg(feature = "saas")]
15use chrono::Datelike;
16use serde::{Deserialize, Serialize};
17use tokio::sync::RwLock;
18use uuid::Uuid;
19
20use crate::audit::{AuditAction, AuditEntry, SharedAuditLogger};
21
22// ---------------------------------------------------------------------------
23// Configuration
24// ---------------------------------------------------------------------------
25
26/// Billing configuration loaded from environment variables.
27#[derive(Debug, Clone)]
28pub struct BillingConfig {
29    pub stripe_secret_key: String,
30    pub stripe_webhook_secret: String,
31    pub pro_price_id: String,
32    pub business_price_id: String,
33    pub frontend_url: String,
34}
35
36impl BillingConfig {
37    /// Build config from environment variables.
38    /// Returns None if Stripe is not configured.
39    pub fn from_env() -> Option<Self> {
40        let secret_key = std::env::var("STRIPE_SECRET_KEY").ok()?;
41        let webhook_secret =
42            std::env::var("STRIPE_WEBHOOK_SECRET").unwrap_or_else(|_| String::new());
43        let pro_price_id = std::env::var("STRIPE_PRO_PRICE_ID").unwrap_or_else(|_| String::new());
44        let business_price_id =
45            std::env::var("STRIPE_BUSINESS_PRICE_ID").unwrap_or_else(|_| String::new());
46        let frontend_url =
47            std::env::var("FRONTEND_URL").unwrap_or_else(|_| "http://localhost:5173".to_string());
48
49        Some(Self {
50            stripe_secret_key: secret_key,
51            stripe_webhook_secret: webhook_secret,
52            pro_price_id,
53            business_price_id,
54            frontend_url,
55        })
56    }
57
58    /// Get the Stripe price ID for a given tier.
59    pub fn price_id_for_tier(&self, tier: &Tier) -> Option<&str> {
60        match tier {
61            Tier::Pro if !self.pro_price_id.is_empty() => Some(&self.pro_price_id),
62            Tier::Business if !self.business_price_id.is_empty() => Some(&self.business_price_id),
63            _ => None,
64        }
65    }
66
67    /// Determine tier from a Stripe price ID.
68    pub fn tier_for_price_id(&self, price_id: &str) -> Option<Tier> {
69        if !self.pro_price_id.is_empty() && price_id == self.pro_price_id {
70            Some(Tier::Pro)
71        } else if !self.business_price_id.is_empty() && price_id == self.business_price_id {
72            Some(Tier::Business)
73        } else {
74            None
75        }
76    }
77}
78
79// ---------------------------------------------------------------------------
80// Tier
81// ---------------------------------------------------------------------------
82
83#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
84#[serde(rename_all = "lowercase")]
85pub enum Tier {
86    Free,
87    Pro,
88    Business,
89    Enterprise,
90}
91
92impl Tier {
93    pub const fn event_limit(&self) -> Option<i64> {
94        match self {
95            Self::Free => Some(100_000),
96            Self::Pro => Some(10_000_000),
97            Self::Business => Some(100_000_000),
98            Self::Enterprise => None,
99        }
100    }
101
102    pub const fn display_name(&self) -> &str {
103        match self {
104            Self::Free => "Free",
105            Self::Pro => "Pro ($49/mo)",
106            Self::Business => "Business ($199/mo)",
107            Self::Enterprise => "Enterprise",
108        }
109    }
110}
111
112impl std::fmt::Display for Tier {
113    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114        match self {
115            Self::Free => write!(f, "free"),
116            Self::Pro => write!(f, "pro"),
117            Self::Business => write!(f, "business"),
118            Self::Enterprise => write!(f, "enterprise"),
119        }
120    }
121}
122
123impl std::str::FromStr for Tier {
124    type Err = String;
125
126    fn from_str(s: &str) -> Result<Self, Self::Err> {
127        match s {
128            "free" => Ok(Self::Free),
129            "pro" => Ok(Self::Pro),
130            "business" => Ok(Self::Business),
131            "enterprise" => Ok(Self::Enterprise),
132            other => Err(format!("unknown tier: {other}")),
133        }
134    }
135}
136
137// ---------------------------------------------------------------------------
138// Usage tracking
139// ---------------------------------------------------------------------------
140
141/// In-memory buffer for event counts, flushed to DB periodically.
142#[derive(Debug)]
143pub struct UsageTracker {
144    buffer: HashMap<Uuid, i64>,
145    /// Cached DB monthly totals (reloaded every 60s during flush cycle).
146    monthly_totals: HashMap<Uuid, i64>,
147    /// Cached per-org monthly limits (loaded from DB).
148    monthly_limits: HashMap<Uuid, i64>,
149}
150
151impl Default for UsageTracker {
152    fn default() -> Self {
153        Self::new()
154    }
155}
156
157impl UsageTracker {
158    pub fn new() -> Self {
159        Self {
160            buffer: HashMap::new(),
161            monthly_totals: HashMap::new(),
162            monthly_limits: HashMap::new(),
163        }
164    }
165
166    pub fn record_events(&mut self, org_id: Uuid, count: i64) {
167        *self.buffer.entry(org_id).or_insert(0) += count;
168    }
169
170    /// Drain all buffered counts, returning `(org_id, event_count)` pairs.
171    pub fn drain(&mut self) -> Vec<(Uuid, i64)> {
172        self.buffer.drain().collect()
173    }
174
175    pub fn get(&self, org_id: &Uuid) -> i64 {
176        self.buffer.get(org_id).copied().unwrap_or(0)
177    }
178
179    /// Update the cached monthly total for an org (from DB).
180    pub fn set_monthly_total(&mut self, org_id: Uuid, total: i64) {
181        self.monthly_totals.insert(org_id, total);
182    }
183
184    /// Get the cached monthly total.
185    pub fn get_monthly_total(&self, org_id: &Uuid) -> Option<i64> {
186        self.monthly_totals.get(org_id).copied()
187    }
188
189    /// Update the cached monthly limit for an org.
190    pub fn set_monthly_limit(&mut self, org_id: Uuid, limit: i64) {
191        self.monthly_limits.insert(org_id, limit);
192    }
193
194    /// Get the cached monthly limit.
195    pub fn get_monthly_limit(&self, org_id: &Uuid) -> Option<i64> {
196        self.monthly_limits.get(org_id).copied()
197    }
198
199    /// Fast-path check: can this org process `additional` more events?
200    /// Returns None if cache miss (caller should fall through to DB query).
201    pub fn check_cached_limit(&self, org_id: &Uuid, additional: i64) -> Option<bool> {
202        let total = self.monthly_totals.get(org_id)?;
203        let limit = self.monthly_limits.get(org_id)?;
204        let buffered = self.buffer.get(org_id).copied().unwrap_or(0);
205        Some(total + buffered + additional <= *limit)
206    }
207}
208
209// ---------------------------------------------------------------------------
210// State
211// ---------------------------------------------------------------------------
212
213#[derive(Debug)]
214pub struct BillingState {
215    pub config: BillingConfig,
216    pub usage: RwLock<UsageTracker>,
217    pub http_client: reqwest::Client,
218    #[cfg(feature = "saas")]
219    pub db_pool: Option<varpulis_db::PgPool>,
220    pub audit_logger: Option<SharedAuditLogger>,
221}
222
223impl BillingState {
224    pub fn new(config: BillingConfig) -> Self {
225        Self {
226            config,
227            usage: RwLock::new(UsageTracker::new()),
228            http_client: reqwest::Client::new(),
229            #[cfg(feature = "saas")]
230            db_pool: None,
231            audit_logger: None,
232        }
233    }
234
235    pub fn with_audit_logger(mut self, logger: Option<SharedAuditLogger>) -> Self {
236        self.audit_logger = logger;
237        self
238    }
239
240    #[cfg(feature = "saas")]
241    pub fn with_db_pool(mut self, pool: varpulis_db::PgPool) -> Self {
242        self.db_pool = Some(pool);
243        self
244    }
245}
246
247pub type SharedBillingState = Arc<BillingState>;
248
249// ---------------------------------------------------------------------------
250// Usage limit enforcement
251// ---------------------------------------------------------------------------
252
253/// Error returned when usage exceeds tier limits.
254#[derive(Debug, Serialize)]
255pub struct UsageLimitExceeded {
256    pub tier: Tier,
257    pub limit: i64,
258    pub current_usage: i64,
259    pub message: String,
260}
261
262/// Result of a usage limit check: Ok with optional warning, or Err if exceeded.
263#[derive(Debug)]
264pub enum UsageCheckResult {
265    /// Within limits, no warning.
266    Ok,
267    /// Within limits but approaching (>80%).
268    ApproachingLimit { usage_percent: f64 },
269    /// Over limit.
270    Exceeded(UsageLimitExceeded),
271}
272
273impl BillingState {
274    /// Check if the org can process `additional_events` more events this month.
275    /// Returns `UsageCheckResult::Ok` or `::ApproachingLimit` if within limit,
276    /// `::Exceeded` if over.
277    #[cfg(feature = "saas")]
278    pub async fn check_usage_limit(
279        &self,
280        org_id: Uuid,
281        additional_events: i64,
282    ) -> UsageCheckResult {
283        // 1. Get org tier from DB
284        let tier = if let Some(ref pool) = self.db_pool {
285            if let Ok(Some(org)) = varpulis_db::repo::get_organization(pool, org_id).await {
286                org.tier.parse().unwrap_or(Tier::Free)
287            } else {
288                Tier::Free
289            }
290        } else {
291            Tier::Free
292        };
293
294        // Enterprise = no limit
295        let limit = match tier.event_limit() {
296            Some(l) => l,
297            None => return UsageCheckResult::Ok,
298        };
299
300        // 2. Get current month usage from DB
301        let db_usage = if let Some(ref pool) = self.db_pool {
302            let today = chrono::Utc::now().date_naive();
303            let start =
304                chrono::NaiveDate::from_ymd_opt(today.year(), today.month(), 1).unwrap_or(today);
305            if let Ok(rows) = varpulis_db::repo::get_usage(pool, org_id, start, today).await {
306                rows.iter().map(|r| r.events_processed).sum::<i64>()
307            } else {
308                0
309            }
310        } else {
311            0
312        };
313
314        // 3. Add in-memory buffer (not yet flushed to DB)
315        let buffered = self.usage.read().await.get(&org_id);
316        let total = db_usage + buffered + additional_events;
317
318        if total > limit {
319            UsageCheckResult::Exceeded(UsageLimitExceeded {
320                tier: tier.clone(),
321                limit,
322                current_usage: db_usage + buffered,
323                message: format!(
324                    "Usage limit exceeded for {} tier ({}/{} events this month). Upgrade to increase your limit.",
325                    tier.display_name(),
326                    db_usage + buffered,
327                    limit,
328                ),
329            })
330        } else {
331            let usage_percent = (total as f64 / limit as f64) * 100.0;
332            if usage_percent >= 80.0 {
333                UsageCheckResult::ApproachingLimit { usage_percent }
334            } else {
335                UsageCheckResult::Ok
336            }
337        }
338    }
339
340    /// Look up org_id for a raw API key from the database.
341    /// Hashes the key with SHA-256 and looks up the hash.
342    #[cfg(feature = "saas")]
343    pub async fn org_id_for_api_key(&self, raw_key: &str) -> Option<Uuid> {
344        use sha2::Digest;
345        let pool = self.db_pool.as_ref()?;
346        let hash = hex::encode(sha2::Sha256::digest(raw_key.as_bytes()));
347        let api_key = varpulis_db::repo::get_api_key_by_hash(pool, &hash)
348            .await
349            .ok()??;
350        Some(api_key.org_id)
351    }
352}
353
354/// Build a 429 Too Many Requests response from a usage limit error.
355pub fn usage_limit_response(err: &UsageLimitExceeded) -> Response {
356    let body = Json(serde_json::json!({
357        "error": "usage_limit_exceeded",
358        "message": err.message,
359        "tier": err.tier,
360        "limit": err.limit,
361        "current_usage": err.current_usage,
362        "upgrade_url": "/billing",
363    }));
364
365    (
366        StatusCode::TOO_MANY_REQUESTS,
367        [("Retry-After", "3600")],
368        body,
369    )
370        .into_response()
371}
372
373// ---------------------------------------------------------------------------
374// Usage flush task
375// ---------------------------------------------------------------------------
376
/// Spawn a background task that flushes in-memory usage counters to the DB every 60s
/// and reloads cached monthly totals/limits.
///
/// Each tick:
/// 1. drains the buffered counts and records them against today's date (UTC).
///    Counts whose write fails are put back into the buffer and retried on
///    the next tick instead of being dropped.
/// 2. refreshes every org's cached monthly total and limit. The DB is queried
///    first and the tracker's write lock is taken only to apply the results,
///    so event recording and limit checks are not blocked for the duration of
///    the per-org queries.
#[cfg(feature = "saas")]
pub fn spawn_usage_flush(state: SharedBillingState, pool: varpulis_db::PgPool) {
    tokio::spawn(async move {
        let mut interval = tokio::time::interval(std::time::Duration::from_secs(60));
        loop {
            interval.tick().await;

            // The write guard is a temporary, released at the end of this statement.
            let entries = state.usage.write().await.drain();
            let today = chrono::Utc::now().date_naive();

            let mut unflushed: Vec<(Uuid, i64)> = Vec::new();
            for &(org_id, count) in &entries {
                if let Err(e) =
                    varpulis_db::repo::record_usage(&pool, org_id, today, count, 0).await
                {
                    tracing::error!("Failed to flush usage for org {}: {}", org_id, e);
                    unflushed.push((org_id, count));
                }
            }
            if !unflushed.is_empty() {
                // Re-buffer so the events are retried rather than lost. They will be
                // recorded against the date of the retry.
                let mut tracker = state.usage.write().await;
                for (org_id, count) in unflushed {
                    tracker.record_events(org_id, count);
                }
            }

            // Reload monthly totals and limits from DB for cache. Query without
            // holding the lock, then apply everything under one short write lock.
            match varpulis_db::repo::list_all_organizations(&pool).await {
                Ok(orgs) => {
                    let mut refreshed = Vec::with_capacity(orgs.len());
                    for org in &orgs {
                        let total =
                            match varpulis_db::repo::get_org_usage_summary(&pool, org.id).await {
                                Ok(total) => Some(total),
                                Err(e) => {
                                    tracing::warn!(
                                        "Failed to reload usage total for org {}: {}",
                                        org.id,
                                        e
                                    );
                                    None
                                }
                            };
                        refreshed.push((org.id, total, org.monthly_event_limit));
                    }

                    let mut tracker = state.usage.write().await;
                    for (org_id, total, limit) in refreshed {
                        // On a failed query keep the previously cached total.
                        if let Some(total) = total {
                            tracker.set_monthly_total(org_id, total);
                        }
                        tracker.set_monthly_limit(org_id, limit);
                    }
                }
                Err(e) => {
                    tracing::error!(
                        "Failed to list organizations for usage cache reload: {}",
                        e
                    );
                }
            }

            if !entries.is_empty() {
                tracing::debug!("Usage flush complete ({} orgs)", entries.len());
            }
        }
    });
}
413
414// ---------------------------------------------------------------------------
415// Stripe helpers
416// ---------------------------------------------------------------------------
417
418/// Call the Stripe API with form-encoded body.
419async fn stripe_post(
420    client: &reqwest::Client,
421    secret_key: &str,
422    endpoint: &str,
423    params: &[(&str, &str)],
424) -> Result<serde_json::Value, String> {
425    let resp = client
426        .post(format!("https://api.stripe.com/v1/{endpoint}"))
427        .basic_auth(secret_key, None::<&str>)
428        .form(params)
429        .send()
430        .await
431        .map_err(|e| format!("Stripe request failed: {e}"))?;
432
433    let status = resp.status();
434    let body: serde_json::Value = resp
435        .json()
436        .await
437        .map_err(|e| format!("Stripe response parse failed: {e}"))?;
438
439    if !status.is_success() {
440        let msg = body["error"]["message"]
441            .as_str()
442            .unwrap_or("Unknown Stripe error");
443        return Err(format!("Stripe API error ({status}): {msg}"));
444    }
445
446    Ok(body)
447}
448
449/// Verify Stripe webhook signature (HMAC-SHA256).
450fn verify_stripe_signature(payload: &[u8], sig_header: &str, secret: &str) -> bool {
451    use hmac::{Hmac, Mac};
452    use sha2::Sha256;
453
454    // Parse signature header: "t=timestamp,v1=signature"
455    let mut timestamp = "";
456    let mut signature = "";
457    for part in sig_header.split(',') {
458        if let Some(t) = part.strip_prefix("t=") {
459            timestamp = t;
460        } else if let Some(s) = part.strip_prefix("v1=") {
461            signature = s;
462        }
463    }
464
465    if timestamp.is_empty() || signature.is_empty() {
466        return false;
467    }
468
469    // Compute expected signature
470    let signed_payload = format!(
471        "{}.{}",
472        timestamp,
473        std::str::from_utf8(payload).unwrap_or("")
474    );
475    let mut mac =
476        Hmac::<Sha256>::new_from_slice(secret.as_bytes()).expect("HMAC accepts any key size");
477    hmac::Mac::update(&mut mac, signed_payload.as_bytes());
478    let expected = hex::encode(mac.finalize().into_bytes());
479
480    // Constant-time comparison
481    expected == signature
482}
483
484// ---------------------------------------------------------------------------
485// Route handlers
486// ---------------------------------------------------------------------------
487
488/// GET /api/v1/billing/usage — get usage summary.
489async fn handle_usage(
490    State(state): State<Option<SharedBillingState>>,
491    headers: HeaderMap,
492) -> Response {
493    let auth_header = headers
494        .get("authorization")
495        .and_then(|v| v.to_str().ok())
496        .map(|s| s.to_string());
497
498    match state {
499        Some(s) => {
500            // Try DB first when saas is enabled
501            #[cfg(feature = "saas")]
502            if let Some(ref pool) = s.db_pool {
503                if let Some(org_id) = extract_org_id_from_header(&auth_header, &s) {
504                    let today = chrono::Utc::now().date_naive();
505                    let start = chrono::NaiveDate::from_ymd_opt(today.year(), today.month(), 1)
506                        .unwrap_or(today);
507                    if let Ok(rows) = varpulis_db::repo::get_usage(pool, org_id, start, today).await
508                    {
509                        let total: i64 = rows.iter().map(|r| r.events_processed).sum();
510                        return (
511                            StatusCode::OK,
512                            Json(serde_json::json!({
513                                "events_this_month": total,
514                                "daily": rows.iter().map(|r| serde_json::json!({
515                                    "date": r.date.to_string(),
516                                    "events_processed": r.events_processed,
517                                })).collect::<Vec<_>>(),
518                            })),
519                        )
520                            .into_response();
521                    }
522                }
523            }
524
525            // Fallback: in-memory buffer
526            let _ = auth_header;
527            let tracker = s.usage.read().await;
528            let orgs: Vec<serde_json::Value> = tracker
529                .buffer
530                .iter()
531                .map(|(org_id, count)| {
532                    serde_json::json!({
533                        "org_id": org_id.to_string(),
534                        "events_today": count,
535                    })
536                })
537                .collect();
538            (StatusCode::OK, Json(serde_json::json!({ "usage": orgs }))).into_response()
539        }
540        None => (
541            StatusCode::SERVICE_UNAVAILABLE,
542            Json(serde_json::json!({ "error": "Billing not configured" })),
543        )
544            .into_response(),
545    }
546}
547
548/// GET /api/v1/billing/plan — get current plan.
549async fn handle_plan(
550    State(state): State<Option<SharedBillingState>>,
551    headers: HeaderMap,
552) -> Response {
553    let auth_header = headers
554        .get("authorization")
555        .and_then(|v| v.to_str().ok())
556        .map(|s| s.to_string());
557
558    match state {
559        Some(_s) => {
560            // Try DB for real plan when saas enabled
561            #[cfg(feature = "saas")]
562            if let Some(ref pool) = _s.db_pool {
563                if let Some(org_id) = extract_org_id_from_header(&auth_header, &_s) {
564                    if let Ok(Some(org)) = varpulis_db::repo::get_organization(pool, org_id).await {
565                        let tier: Tier = org.tier.parse().unwrap_or(Tier::Free);
566                        return (
567                            StatusCode::OK,
568                            Json(serde_json::json!({
569                                "tier": org.tier,
570                                "event_limit": tier.event_limit(),
571                                "display_name": tier.display_name(),
572                            })),
573                        )
574                            .into_response();
575                    }
576                }
577            }
578
579            // Fallback: hardcoded free
580            let _ = auth_header;
581            (
582                StatusCode::OK,
583                Json(serde_json::json!({
584                    "tier": "free",
585                    "event_limit": 10_000,
586                    "display_name": "Free",
587                })),
588            )
589                .into_response()
590        }
591        None => (
592            StatusCode::SERVICE_UNAVAILABLE,
593            Json(serde_json::json!({ "error": "Billing not configured" })),
594        )
595            .into_response(),
596    }
597}
598
599#[derive(Debug, Deserialize)]
600struct CheckoutRequest {
601    /// Target tier: "pro" or "business". Defaults to "pro".
602    tier: Option<String>,
603    success_url: Option<String>,
604    cancel_url: Option<String>,
605}
606
607/// POST /api/v1/billing/checkout — create Stripe Checkout session.
608async fn handle_checkout(
609    State(state): State<Option<SharedBillingState>>,
610    headers: HeaderMap,
611    Json(body): Json<CheckoutRequest>,
612) -> Response {
613    let auth_header = headers
614        .get("authorization")
615        .and_then(|v| v.to_str().ok())
616        .map(|s| s.to_string());
617
618    match state {
619        Some(s) => {
620            // Determine target tier from request body
621            let target_tier: Tier = body
622                .tier
623                .as_deref()
624                .unwrap_or("pro")
625                .parse()
626                .unwrap_or(Tier::Pro);
627
628            let price_id = match s.config.price_id_for_tier(&target_tier) {
629                Some(id) => id.to_string(),
630                None => {
631                    return (
632                        StatusCode::BAD_REQUEST,
633                        Json(serde_json::json!({
634                            "error": format!("Stripe Price ID not configured for {} tier", target_tier)
635                        })),
636                    )
637                        .into_response();
638                }
639            };
640
641            let success_url = body
642                .success_url
643                .unwrap_or_else(|| format!("{}/billing?success=true", s.config.frontend_url));
644            let cancel_url = body
645                .cancel_url
646                .unwrap_or_else(|| format!("{}/billing", s.config.frontend_url));
647
648            // Build Stripe Checkout params
649            let org_id_str = extract_org_id_str_from_header(&auth_header, &s).unwrap_or_default();
650
651            let mut params: Vec<(&str, &str)> = vec![
652                ("mode", "subscription"),
653                ("line_items[0][price]", &price_id),
654                ("line_items[0][quantity]", "1"),
655                ("success_url", &success_url),
656                ("cancel_url", &cancel_url),
657            ];
658
659            if !org_id_str.is_empty() {
660                params.push(("client_reference_id", &org_id_str));
661            }
662
663            // Look up existing Stripe customer
664            #[allow(unused_mut)]
665            let mut customer_id = String::new();
666            #[cfg(feature = "saas")]
667            if let Some(ref pool) = s.db_pool {
668                if let Some(org_uuid) = extract_org_id_from_header(&auth_header, &s) {
669                    if let Ok(Some(org)) = varpulis_db::repo::get_organization(pool, org_uuid).await
670                    {
671                        if let Some(cid) = org.stripe_customer_id {
672                            customer_id = cid;
673                        }
674                    }
675                }
676            }
677
678            if !customer_id.is_empty() {
679                params.push(("customer", &customer_id));
680            }
681
682            match stripe_post(
683                &s.http_client,
684                &s.config.stripe_secret_key,
685                "checkout/sessions",
686                &params,
687            )
688            .await
689            {
690                Ok(session) => {
691                    let checkout_url = session["url"].as_str().unwrap_or("");
692                    let session_id = session["id"].as_str().unwrap_or("");
693                    // Audit log: checkout started
694                    if let Some(ref logger) = s.audit_logger {
695                        logger
696                            .log(
697                                AuditEntry::new(
698                                    &org_id_str,
699                                    AuditAction::CheckoutStarted,
700                                    "/api/v1/billing/checkout",
701                                )
702                                .with_detail(format!("session: {session_id}")),
703                            )
704                            .await;
705                    }
706                    (
707                        StatusCode::OK,
708                        Json(serde_json::json!({
709                            "checkout_url": checkout_url,
710                            "session_id": session_id,
711                        })),
712                    )
713                        .into_response()
714                }
715                Err(e) => {
716                    tracing::error!("Stripe checkout failed: {}", e);
717                    (
718                        StatusCode::BAD_GATEWAY,
719                        Json(serde_json::json!({"error": e})),
720                    )
721                        .into_response()
722                }
723            }
724        }
725        None => (
726            StatusCode::SERVICE_UNAVAILABLE,
727            Json(serde_json::json!({ "error": "Billing not configured" })),
728        )
729            .into_response(),
730    }
731}
732
733/// POST /api/v1/billing/portal — create Stripe Customer Portal session.
734async fn handle_portal(
735    State(state): State<Option<SharedBillingState>>,
736    headers: HeaderMap,
737) -> Response {
738    let auth_header = headers
739        .get("authorization")
740        .and_then(|v| v.to_str().ok())
741        .map(|s| s.to_string());
742
743    match state {
744        Some(s) => {
745            #[allow(unused_mut)]
746            let mut customer_id = String::new();
747
748            #[cfg(feature = "saas")]
749            if let Some(ref pool) = s.db_pool {
750                if let Some(org_uuid) = extract_org_id_from_header(&auth_header, &s) {
751                    if let Ok(Some(org)) = varpulis_db::repo::get_organization(pool, org_uuid).await
752                    {
753                        if let Some(cid) = org.stripe_customer_id {
754                            customer_id = cid;
755                        }
756                    }
757                }
758            }
759            let _ = auth_header;
760
761            if customer_id.is_empty() {
762                return (
763                    StatusCode::BAD_REQUEST,
764                    Json(serde_json::json!({
765                        "error": "No Stripe customer found. Upgrade first."
766                    })),
767                )
768                    .into_response();
769            }
770
771            let return_url = format!("{}/billing", s.config.frontend_url);
772            match stripe_post(
773                &s.http_client,
774                &s.config.stripe_secret_key,
775                "billing_portal/sessions",
776                &[("customer", &customer_id), ("return_url", &return_url)],
777            )
778            .await
779            {
780                Ok(session) => {
781                    let portal_url = session["url"].as_str().unwrap_or("");
782                    (
783                        StatusCode::OK,
784                        Json(serde_json::json!({
785                            "portal_url": portal_url,
786                        })),
787                    )
788                        .into_response()
789                }
790                Err(e) => {
791                    tracing::error!("Stripe portal failed: {}", e);
792                    (
793                        StatusCode::BAD_GATEWAY,
794                        Json(serde_json::json!({"error": e})),
795                    )
796                        .into_response()
797                }
798            }
799        }
800        None => (
801            StatusCode::SERVICE_UNAVAILABLE,
802            Json(serde_json::json!({ "error": "Billing not configured" })),
803        )
804            .into_response(),
805    }
806}
807
808/// POST /api/v1/billing/webhook — handle Stripe webhook events.
809async fn handle_webhook(
810    State(state): State<Option<SharedBillingState>>,
811    headers: HeaderMap,
812    body: bytes::Bytes,
813) -> Response {
814    let sig_header = headers
815        .get("stripe-signature")
816        .and_then(|v| v.to_str().ok())
817        .map(|s| s.to_string());
818
819    let s = match state {
820        Some(s) => s,
821        None => {
822            return (
823                StatusCode::SERVICE_UNAVAILABLE,
824                Json(serde_json::json!({"error": "Billing not configured"})),
825            )
826                .into_response();
827        }
828    };
829
830    // Verify signature
831    if !s.config.stripe_webhook_secret.is_empty() {
832        let sig = sig_header.unwrap_or_default();
833        if !verify_stripe_signature(&body, &sig, &s.config.stripe_webhook_secret) {
834            return (
835                StatusCode::BAD_REQUEST,
836                Json(serde_json::json!({"error": "Invalid signature"})),
837            )
838                .into_response();
839        }
840    }
841
842    // Parse event
843    let event: serde_json::Value = match serde_json::from_slice(&body) {
844        Ok(v) => v,
845        Err(e) => {
846            tracing::error!("Webhook parse error: {}", e);
847            return (
848                StatusCode::BAD_REQUEST,
849                Json(serde_json::json!({"error": "Invalid JSON"})),
850            )
851                .into_response();
852        }
853    };
854
855    let event_type = event["type"].as_str().unwrap_or("");
856    tracing::info!("Stripe webhook: {}", event_type);
857
858    // Audit log: webhook received
859    if let Some(ref logger) = s.audit_logger {
860        logger
861            .log(
862                AuditEntry::new(
863                    "stripe",
864                    AuditAction::WebhookReceived,
865                    "/api/v1/billing/webhook",
866                )
867                .with_detail(event_type.to_string()),
868            )
869            .await;
870    }
871
872    #[cfg(feature = "saas")]
873    if let Some(ref pool) = s.db_pool {
874        match event_type {
875            "checkout.session.completed" => {
876                let obj = &event["data"]["object"];
877                let customer = obj["customer"].as_str().unwrap_or("");
878                let client_ref = obj["client_reference_id"].as_str().unwrap_or("");
879
880                if !client_ref.is_empty() && !customer.is_empty() {
881                    if let Ok(org_id) = client_ref.parse::<uuid::Uuid>() {
882                        if let Err(e) =
883                            varpulis_db::repo::update_org_stripe_customer(pool, org_id, customer)
884                                .await
885                        {
886                            tracing::error!("Failed to save Stripe customer: {}", e);
887                        }
888
889                        // Determine tier from subscription price_id
890                        let subscription_id = obj["subscription"].as_str().unwrap_or("");
891                        let new_tier = if !subscription_id.is_empty() {
892                            // Fetch subscription to get price_id
893                            if let Ok(sub) = stripe_post(
894                                &s.http_client,
895                                &s.config.stripe_secret_key,
896                                &format!("subscriptions/{subscription_id}"),
897                                &[],
898                            )
899                            .await
900                            {
901                                let price_id = sub["items"]["data"][0]["price"]["id"]
902                                    .as_str()
903                                    .unwrap_or("");
904                                s.config.tier_for_price_id(price_id).unwrap_or(Tier::Pro)
905                            } else {
906                                Tier::Pro
907                            }
908                        } else {
909                            Tier::Pro
910                        };
911
912                        let tier_str = new_tier.to_string();
913                        if let Err(e) =
914                            varpulis_db::repo::update_org_tier(pool, org_id, &tier_str).await
915                        {
916                            tracing::error!("Failed to update tier: {}", e);
917                        }
918                        // Clear trial status on paid upgrade
919                        if let Err(e) =
920                            varpulis_db::repo::update_org_status(pool, org_id, "active").await
921                        {
922                            tracing::error!("Failed to update org status: {}", e);
923                        }
924                        tracing::info!(
925                            "Org {} upgraded to {} (customer: {})",
926                            org_id,
927                            tier_str,
928                            customer
929                        );
930                        // Audit log: tier upgrade
931                        if let Some(ref logger) = s.audit_logger {
932                            logger
933                                .log(
934                                    AuditEntry::new(
935                                        org_id.to_string(),
936                                        AuditAction::TierChange,
937                                        "/api/v1/billing/webhook",
938                                    )
939                                    .with_detail(format!("upgraded to {tier_str}")),
940                                )
941                                .await;
942                        }
943                    }
944                }
945            }
946            "customer.subscription.deleted" => {
947                let customer = event["data"]["object"]["customer"].as_str().unwrap_or("");
948                if !customer.is_empty() {
949                    if let Ok(Some(org)) =
950                        varpulis_db::repo::get_org_by_stripe_customer(pool, customer).await
951                    {
952                        if let Err(e) =
953                            varpulis_db::repo::update_org_tier(pool, org.id, "free").await
954                        {
955                            tracing::error!("Failed to downgrade org: {}", e);
956                        }
957                        tracing::info!("Org {} downgraded to free", org.id);
958                        // Audit log: tier downgrade
959                        if let Some(ref logger) = s.audit_logger {
960                            logger
961                                .log(
962                                    AuditEntry::new(
963                                        org.id.to_string(),
964                                        AuditAction::TierChange,
965                                        "/api/v1/billing/webhook",
966                                    )
967                                    .with_detail("downgraded to free"),
968                                )
969                                .await;
970                        }
971                    }
972                }
973            }
974            "customer.subscription.updated" => {
975                let obj = &event["data"]["object"];
976                let customer = obj["customer"].as_str().unwrap_or("");
977                let price_id = obj["items"]["data"][0]["price"]["id"]
978                    .as_str()
979                    .unwrap_or("");
980
981                if !customer.is_empty() && !price_id.is_empty() {
982                    if let Some(new_tier) = s.config.tier_for_price_id(price_id) {
983                        if let Ok(Some(org)) =
984                            varpulis_db::repo::get_org_by_stripe_customer(pool, customer).await
985                        {
986                            let tier_str = new_tier.to_string();
987                            if org.tier != tier_str {
988                                if let Err(e) =
989                                    varpulis_db::repo::update_org_tier(pool, org.id, &tier_str)
990                                        .await
991                                {
992                                    tracing::error!(
993                                        "Failed to update tier on subscription change: {}",
994                                        e
995                                    );
996                                } else {
997                                    tracing::info!(
998                                        "Org {} tier changed to {} via subscription update",
999                                        org.id,
1000                                        tier_str
1001                                    );
1002                                }
1003                            }
1004                        }
1005                    }
1006                }
1007            }
1008            "invoice.payment_failed" => {
1009                let customer = event["data"]["object"]["customer"].as_str().unwrap_or("");
1010                tracing::warn!("Payment failed for customer {}", customer);
1011            }
1012            _ => {
1013                tracing::debug!("Unhandled webhook event: {}", event_type);
1014            }
1015        }
1016    }
1017
1018    #[cfg(not(feature = "saas"))]
1019    {
1020        let _ = event_type;
1021        tracing::debug!("Webhook received but saas feature not enabled");
1022    }
1023
1024    (StatusCode::OK, Json(serde_json::json!({"received": true}))).into_response()
1025}
1026
1027// ---------------------------------------------------------------------------
1028// JWT claim extraction helpers
1029// ---------------------------------------------------------------------------
1030
1031/// Extract org_id UUID from Authorization header JWT.
1032#[cfg_attr(not(feature = "saas"), allow(dead_code))]
1033fn extract_org_id_from_header(
1034    auth_header: &Option<String>,
1035    state: &BillingState,
1036) -> Option<uuid::Uuid> {
1037    extract_org_id_str_from_header(auth_header, state)?
1038        .parse()
1039        .ok()
1040}
1041
1042/// Extract org_id string from Authorization header JWT.
1043fn extract_org_id_str_from_header(
1044    auth_header: &Option<String>,
1045    _state: &BillingState,
1046) -> Option<String> {
1047    let header = auth_header.as_ref()?;
1048    let token = header.strip_prefix("Bearer ")?.trim();
1049    if token.is_empty() {
1050        return None;
1051    }
1052
1053    // Decode JWT without full verification (billing state doesn't have JWT secret).
1054    // Use jsonwebtoken's dangerous decode to read claims without signature check.
1055    let token_data = jsonwebtoken::dangerous::insecure_decode::<serde_json::Value>(token).ok()?;
1056
1057    let org_id = token_data.claims["org_id"].as_str()?;
1058    if org_id.is_empty() {
1059        return None;
1060    }
1061    Some(org_id.to_string())
1062}
1063
1064// ---------------------------------------------------------------------------
1065// Route assembly
1066// ---------------------------------------------------------------------------
1067
1068/// Build billing routes. When `state` is None, endpoints return 503.
1069pub fn billing_routes(state: Option<SharedBillingState>) -> Router {
1070    Router::new()
1071        .route("/api/v1/billing/usage", get(handle_usage))
1072        .route("/api/v1/billing/plan", get(handle_plan))
1073        .route("/api/v1/billing/checkout", post(handle_checkout))
1074        .route("/api/v1/billing/portal", post(handle_portal))
1075        .route("/api/v1/billing/webhook", post(handle_webhook))
1076        .with_state(state)
1077}
1078
1079// ---------------------------------------------------------------------------
1080// Tests
1081// ---------------------------------------------------------------------------
1082
1083#[cfg(test)]
1084mod tests {
1085    use axum::body::Body;
1086    use axum::http::Request;
1087    use tower::ServiceExt;
1088
1089    use super::*;
1090
1091    fn get_req(uri: &str) -> Request<Body> {
1092        Request::builder()
1093            .method("GET")
1094            .uri(uri)
1095            .body(Body::empty())
1096            .unwrap()
1097    }
1098
1099    #[test]
1100    fn test_tier_event_limits() {
1101        assert_eq!(Tier::Free.event_limit(), Some(100_000));
1102        assert_eq!(Tier::Pro.event_limit(), Some(10_000_000));
1103        assert_eq!(Tier::Business.event_limit(), Some(100_000_000));
1104        assert_eq!(Tier::Enterprise.event_limit(), None);
1105    }
1106
1107    #[test]
1108    fn test_tier_display_name() {
1109        assert_eq!(Tier::Free.display_name(), "Free");
1110        assert_eq!(Tier::Pro.display_name(), "Pro ($49/mo)");
1111        assert_eq!(Tier::Business.display_name(), "Business ($199/mo)");
1112        assert_eq!(Tier::Enterprise.display_name(), "Enterprise");
1113    }
1114
1115    #[test]
1116    fn test_tier_from_str() {
1117        assert_eq!("free".parse::<Tier>(), Ok(Tier::Free));
1118        assert_eq!("pro".parse::<Tier>(), Ok(Tier::Pro));
1119        assert_eq!("business".parse::<Tier>(), Ok(Tier::Business));
1120        assert_eq!("enterprise".parse::<Tier>(), Ok(Tier::Enterprise));
1121        assert!("invalid".parse::<Tier>().is_err());
1122    }
1123
1124    #[test]
1125    fn test_tier_serialization() {
1126        let json = serde_json::to_string(&Tier::Business).unwrap();
1127        assert_eq!(json, "\"business\"");
1128        let deserialized: Tier = serde_json::from_str(&json).unwrap();
1129        assert_eq!(deserialized, Tier::Business);
1130    }
1131
1132    #[test]
1133    fn test_usage_tracker_record_and_drain() {
1134        let mut tracker = UsageTracker::new();
1135        let org = Uuid::new_v4();
1136
1137        tracker.record_events(org, 100);
1138        tracker.record_events(org, 50);
1139        assert_eq!(tracker.get(&org), 150);
1140
1141        let drained = tracker.drain();
1142        assert_eq!(drained.len(), 1);
1143        assert_eq!(drained[0], (org, 150));
1144
1145        // Buffer should be empty after drain
1146        assert_eq!(tracker.get(&org), 0);
1147    }
1148
1149    #[test]
1150    fn test_usage_tracker_multiple_orgs() {
1151        let mut tracker = UsageTracker::new();
1152        let org1 = Uuid::new_v4();
1153        let org2 = Uuid::new_v4();
1154
1155        tracker.record_events(org1, 100);
1156        tracker.record_events(org2, 200);
1157        tracker.record_events(org1, 50);
1158
1159        assert_eq!(tracker.get(&org1), 150);
1160        assert_eq!(tracker.get(&org2), 200);
1161    }
1162
1163    #[test]
1164    fn test_verify_stripe_signature() {
1165        use hmac::{Hmac, Mac};
1166        use sha2::Sha256;
1167
1168        let secret = "whsec_test123";
1169        let payload = b"{\"type\":\"test\"}";
1170        let timestamp = "1234567890";
1171
1172        // Compute expected signature
1173        let signed = format!("{}.{}", timestamp, std::str::from_utf8(payload).unwrap());
1174        let mut mac = Hmac::<Sha256>::new_from_slice(secret.as_bytes()).unwrap();
1175        hmac::Mac::update(&mut mac, signed.as_bytes());
1176        let sig = hex::encode(mac.finalize().into_bytes());
1177
1178        let header = format!("t={timestamp},v1={sig}");
1179
1180        assert!(verify_stripe_signature(payload, &header, secret));
1181        assert!(!verify_stripe_signature(payload, &header, "wrong_secret"));
1182        assert!(!verify_stripe_signature(b"tampered", &header, secret));
1183    }
1184
1185    #[tokio::test]
1186    async fn test_billing_routes_not_configured() {
1187        let app = billing_routes(None);
1188
1189        let res = app.oneshot(get_req("/api/v1/billing/plan")).await.unwrap();
1190
1191        assert_eq!(res.status(), 503);
1192    }
1193
1194    #[tokio::test]
1195    async fn test_billing_routes_usage() {
1196        let config = BillingConfig {
1197            stripe_secret_key: "sk_test_xxx".to_string(),
1198            stripe_webhook_secret: "whsec_xxx".to_string(),
1199            pro_price_id: "price_xxx".to_string(),
1200            business_price_id: "price_biz_xxx".to_string(),
1201            frontend_url: "http://localhost:5173".to_string(),
1202        };
1203        let state = Arc::new(BillingState::new(config));
1204        let app = billing_routes(Some(state));
1205
1206        let res = app.oneshot(get_req("/api/v1/billing/usage")).await.unwrap();
1207
1208        assert_eq!(res.status(), 200);
1209    }
1210
1211    #[tokio::test]
1212    async fn test_billing_routes_plan() {
1213        let config = BillingConfig {
1214            stripe_secret_key: "sk_test_xxx".to_string(),
1215            stripe_webhook_secret: "whsec_xxx".to_string(),
1216            pro_price_id: "price_xxx".to_string(),
1217            business_price_id: "price_biz_xxx".to_string(),
1218            frontend_url: "http://localhost:5173".to_string(),
1219        };
1220        let state = Arc::new(BillingState::new(config));
1221        let app = billing_routes(Some(state));
1222
1223        let res = app.oneshot(get_req("/api/v1/billing/plan")).await.unwrap();
1224
1225        assert_eq!(res.status(), 200);
1226        let body = axum::body::to_bytes(res.into_body(), usize::MAX)
1227            .await
1228            .unwrap();
1229        let body: serde_json::Value = serde_json::from_slice(&body).unwrap();
1230        assert_eq!(body["tier"], "free");
1231    }
1232
1233    #[tokio::test]
1234    async fn test_webhook_invalid_signature() {
1235        let config = BillingConfig {
1236            stripe_secret_key: "sk_test_xxx".to_string(),
1237            stripe_webhook_secret: "whsec_real_secret".to_string(),
1238            pro_price_id: "price_xxx".to_string(),
1239            business_price_id: "price_biz_xxx".to_string(),
1240            frontend_url: "http://localhost:5173".to_string(),
1241        };
1242        let state = Arc::new(BillingState::new(config));
1243        let app = billing_routes(Some(state));
1244
1245        let req: Request<Body> = Request::builder()
1246            .method("POST")
1247            .uri("/api/v1/billing/webhook")
1248            .header("stripe-signature", "t=123,v1=bad")
1249            .header("content-type", "application/octet-stream")
1250            .body(Body::from("{\"type\":\"test\"}"))
1251            .unwrap();
1252        let res = app.oneshot(req).await.unwrap();
1253
1254        assert_eq!(res.status(), 400);
1255    }
1256
1257    #[test]
1258    fn test_usage_limit_exceeded_serialization() {
1259        let err = UsageLimitExceeded {
1260            tier: Tier::Free,
1261            limit: 100_000,
1262            current_usage: 100_500,
1263            message: "Usage limit exceeded for Free tier".to_string(),
1264        };
1265        let json = serde_json::to_value(&err).unwrap();
1266        assert_eq!(json["tier"], "free");
1267        assert_eq!(json["limit"], 100_000);
1268        assert_eq!(json["current_usage"], 100_500);
1269    }
1270
1271    #[test]
1272    fn test_usage_limit_response_status() {
1273        let err = UsageLimitExceeded {
1274            tier: Tier::Free,
1275            limit: 100_000,
1276            current_usage: 110_000,
1277            message: "Limit exceeded".to_string(),
1278        };
1279        let resp = usage_limit_response(&err);
1280        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
1281    }
1282
1283    #[test]
1284    fn test_tier_enterprise_no_limit() {
1285        // Enterprise has no event limit
1286        assert_eq!(Tier::Enterprise.event_limit(), None);
1287    }
1288
1289    #[test]
1290    fn test_usage_tracker_tracks_independently() {
1291        let mut tracker = UsageTracker::new();
1292        let org = Uuid::new_v4();
1293
1294        // Record events in multiple increments
1295        tracker.record_events(org, 5_000);
1296        tracker.record_events(org, 3_000);
1297        tracker.record_events(org, 2_001);
1298
1299        // Should be cumulative
1300        assert_eq!(tracker.get(&org), 10_001);
1301    }
1302}