Skip to main content

varpulis_cli/
billing.rs

1//! Stripe billing integration for Varpulis Cloud.
2//!
3//! Provides usage tracking, tier management, and Stripe Checkout/Portal
4//! integration via REST endpoints.
5
6use axum::extract::{Json, State};
7use axum::http::{HeaderMap, StatusCode};
8use axum::response::{IntoResponse, Response};
9use axum::routing::{get, post};
10use axum::Router;
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::sync::Arc;
14use tokio::sync::RwLock;
15use uuid::Uuid;
16
17use crate::audit::{AuditAction, AuditEntry, SharedAuditLogger};
18
19#[cfg(feature = "saas")]
20use chrono::Datelike;
21
22// ---------------------------------------------------------------------------
23// Configuration
24// ---------------------------------------------------------------------------
25
/// Billing configuration loaded from environment variables.
///
/// `Debug` is implemented by hand so the Stripe secrets are never written to logs
/// or panic messages (`BillingState` derives `Debug` and embeds this struct).
#[derive(Clone)]
pub struct BillingConfig {
    /// Stripe secret API key (`STRIPE_SECRET_KEY`); sent as the basic-auth user
    /// on Stripe API calls.
    pub stripe_secret_key: String,
    /// Webhook signing secret (`STRIPE_WEBHOOK_SECRET`); empty when unset.
    pub stripe_webhook_secret: String,
    /// Stripe Price ID of the Pro subscription (`STRIPE_PRO_PRICE_ID`); empty when unset.
    pub pro_price_id: String,
    /// Base URL of the web frontend (`FRONTEND_URL`), stored without a trailing `/`.
    pub frontend_url: String,
}

impl std::fmt::Debug for BillingConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only reveal whether a secret is present, never its value.
        fn redact(secret: &str) -> &'static str {
            if secret.is_empty() {
                "<unset>"
            } else {
                "<redacted>"
            }
        }
        f.debug_struct("BillingConfig")
            .field("stripe_secret_key", &redact(&self.stripe_secret_key))
            .field("stripe_webhook_secret", &redact(&self.stripe_webhook_secret))
            .field("pro_price_id", &self.pro_price_id)
            .field("frontend_url", &self.frontend_url)
            .finish()
    }
}

impl BillingConfig {
    /// Build config from environment variables.
    ///
    /// Returns `None` if Stripe is not configured, i.e. `STRIPE_SECRET_KEY` is
    /// unset or blank. `STRIPE_WEBHOOK_SECRET` and `STRIPE_PRO_PRICE_ID` default to
    /// empty strings. `FRONTEND_URL` defaults to `http://localhost:5173` and has any
    /// trailing `/` removed, because callers append paths such as `/billing`.
    pub fn from_env() -> Option<Self> {
        let stripe_secret_key = std::env::var("STRIPE_SECRET_KEY")
            .ok()
            .filter(|key| !key.trim().is_empty())?;
        let stripe_webhook_secret = std::env::var("STRIPE_WEBHOOK_SECRET").unwrap_or_default();
        let pro_price_id = std::env::var("STRIPE_PRO_PRICE_ID").unwrap_or_default();
        let frontend_url = std::env::var("FRONTEND_URL")
            .unwrap_or_else(|_| "http://localhost:5173".to_string())
            .trim_end_matches('/')
            .to_string();

        Some(Self {
            stripe_secret_key,
            stripe_webhook_secret,
            pro_price_id,
            frontend_url,
        })
    }
}
54
55// ---------------------------------------------------------------------------
56// Tier
57// ---------------------------------------------------------------------------
58
59#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
60#[serde(rename_all = "lowercase")]
61pub enum Tier {
62    Free,
63    Pro,
64    Enterprise,
65}
66
67impl Tier {
68    pub const fn event_limit(&self) -> Option<i64> {
69        match self {
70            Self::Free => Some(10_000),
71            Self::Pro => Some(10_000_000),
72            Self::Enterprise => None,
73        }
74    }
75
76    pub const fn display_name(&self) -> &str {
77        match self {
78            Self::Free => "Free",
79            Self::Pro => "Pro ($49/mo)",
80            Self::Enterprise => "Enterprise",
81        }
82    }
83}
84
85impl std::fmt::Display for Tier {
86    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87        match self {
88            Self::Free => write!(f, "free"),
89            Self::Pro => write!(f, "pro"),
90            Self::Enterprise => write!(f, "enterprise"),
91        }
92    }
93}
94
95impl std::str::FromStr for Tier {
96    type Err = String;
97
98    fn from_str(s: &str) -> Result<Self, Self::Err> {
99        match s {
100            "free" => Ok(Self::Free),
101            "pro" => Ok(Self::Pro),
102            "enterprise" => Ok(Self::Enterprise),
103            other => Err(format!("unknown tier: {other}")),
104        }
105    }
106}
107
108// ---------------------------------------------------------------------------
109// Usage tracking
110// ---------------------------------------------------------------------------
111
112/// In-memory buffer for event counts, flushed to DB periodically.
113#[derive(Debug)]
114pub struct UsageTracker {
115    buffer: HashMap<Uuid, i64>,
116}
117
118impl Default for UsageTracker {
119    fn default() -> Self {
120        Self::new()
121    }
122}
123
124impl UsageTracker {
125    pub fn new() -> Self {
126        Self {
127            buffer: HashMap::new(),
128        }
129    }
130
131    pub fn record_events(&mut self, org_id: Uuid, count: i64) {
132        *self.buffer.entry(org_id).or_insert(0) += count;
133    }
134
135    /// Drain all buffered counts, returning `(org_id, event_count)` pairs.
136    pub fn drain(&mut self) -> Vec<(Uuid, i64)> {
137        self.buffer.drain().collect()
138    }
139
140    pub fn get(&self, org_id: &Uuid) -> i64 {
141        self.buffer.get(org_id).copied().unwrap_or(0)
142    }
143}
144
145// ---------------------------------------------------------------------------
146// State
147// ---------------------------------------------------------------------------
148
149#[derive(Debug)]
150pub struct BillingState {
151    pub config: BillingConfig,
152    pub usage: RwLock<UsageTracker>,
153    pub http_client: reqwest::Client,
154    #[cfg(feature = "saas")]
155    pub db_pool: Option<varpulis_db::PgPool>,
156    pub audit_logger: Option<SharedAuditLogger>,
157}
158
159impl BillingState {
160    pub fn new(config: BillingConfig) -> Self {
161        Self {
162            config,
163            usage: RwLock::new(UsageTracker::new()),
164            http_client: reqwest::Client::new(),
165            #[cfg(feature = "saas")]
166            db_pool: None,
167            audit_logger: None,
168        }
169    }
170
171    pub fn with_audit_logger(mut self, logger: Option<SharedAuditLogger>) -> Self {
172        self.audit_logger = logger;
173        self
174    }
175
176    #[cfg(feature = "saas")]
177    pub fn with_db_pool(mut self, pool: varpulis_db::PgPool) -> Self {
178        self.db_pool = Some(pool);
179        self
180    }
181}
182
183pub type SharedBillingState = Arc<BillingState>;
184
185// ---------------------------------------------------------------------------
186// Usage limit enforcement
187// ---------------------------------------------------------------------------
188
189/// Error returned when usage exceeds tier limits.
190#[derive(Debug, Serialize)]
191pub struct UsageLimitExceeded {
192    pub tier: Tier,
193    pub limit: i64,
194    pub current_usage: i64,
195    pub message: String,
196}
197
198impl BillingState {
199    /// Check if the org can process `additional_events` more events this month.
200    /// Returns `Ok(())` if within limit, `Err(UsageLimitExceeded)` if over.
201    #[cfg(feature = "saas")]
202    pub async fn check_usage_limit(
203        &self,
204        org_id: Uuid,
205        additional_events: i64,
206    ) -> Result<(), UsageLimitExceeded> {
207        // 1. Get org tier from DB
208        let tier = if let Some(ref pool) = self.db_pool {
209            if let Ok(Some(org)) = varpulis_db::repo::get_organization(pool, org_id).await {
210                org.tier.parse().unwrap_or(Tier::Free)
211            } else {
212                Tier::Free
213            }
214        } else {
215            Tier::Free
216        };
217
218        // Enterprise = no limit
219        let limit = match tier.event_limit() {
220            Some(l) => l,
221            None => return Ok(()),
222        };
223
224        // 2. Get current month usage from DB
225        let db_usage = if let Some(ref pool) = self.db_pool {
226            let today = chrono::Utc::now().date_naive();
227            let start =
228                chrono::NaiveDate::from_ymd_opt(today.year(), today.month(), 1).unwrap_or(today);
229            if let Ok(rows) = varpulis_db::repo::get_usage(pool, org_id, start, today).await {
230                rows.iter().map(|r| r.events_processed).sum::<i64>()
231            } else {
232                0
233            }
234        } else {
235            0
236        };
237
238        // 3. Add in-memory buffer (not yet flushed to DB)
239        let buffered = self.usage.read().await.get(&org_id);
240        let total = db_usage + buffered + additional_events;
241
242        if total > limit {
243            Err(UsageLimitExceeded {
244                tier: tier.clone(),
245                limit,
246                current_usage: db_usage + buffered,
247                message: format!(
248                    "Usage limit exceeded for {} tier ({}/{} events this month). Upgrade to increase your limit.",
249                    tier.display_name(),
250                    db_usage + buffered,
251                    limit,
252                ),
253            })
254        } else {
255            Ok(())
256        }
257    }
258
259    /// Look up org_id for a raw API key from the database.
260    /// Hashes the key with SHA-256 and looks up the hash.
261    #[cfg(feature = "saas")]
262    pub async fn org_id_for_api_key(&self, raw_key: &str) -> Option<Uuid> {
263        use sha2::Digest;
264        let pool = self.db_pool.as_ref()?;
265        let hash = hex::encode(sha2::Sha256::digest(raw_key.as_bytes()));
266        let api_key = varpulis_db::repo::get_api_key_by_hash(pool, &hash)
267            .await
268            .ok()??;
269        Some(api_key.org_id)
270    }
271}
272
273/// Build a 429 Too Many Requests response from a usage limit error.
274pub fn usage_limit_response(err: &UsageLimitExceeded) -> Response {
275    let body = Json(serde_json::json!({
276        "error": "usage_limit_exceeded",
277        "message": err.message,
278        "tier": err.tier,
279        "limit": err.limit,
280        "current_usage": err.current_usage,
281        "upgrade_url": "/billing",
282    }));
283
284    (
285        StatusCode::TOO_MANY_REQUESTS,
286        [("Retry-After", "3600")],
287        body,
288    )
289        .into_response()
290}
291
292// ---------------------------------------------------------------------------
293// Usage flush task
294// ---------------------------------------------------------------------------
295
/// Spawn a background task that flushes in-memory usage counters to the DB every 60s.
///
/// Each tick drains the [`UsageTracker`] and issues one `record_usage` call per
/// organization, dated with the current UTC day. Counts whose write fails are put
/// back into the tracker, so they are retried on the next tick instead of being
/// lost (which would under-bill the organization).
#[cfg(feature = "saas")]
pub fn spawn_usage_flush(state: SharedBillingState, pool: varpulis_db::PgPool) {
    tokio::spawn(async move {
        let mut interval = tokio::time::interval(std::time::Duration::from_secs(60));
        loop {
            interval.tick().await;
            // The write guard is a temporary of this statement, so the lock is
            // released before any database I/O happens.
            let entries = state.usage.write().await.drain();
            if entries.is_empty() {
                continue;
            }

            let today = chrono::Utc::now().date_naive();
            let mut failed: Vec<(Uuid, i64)> = Vec::new();
            for (org_id, count) in entries {
                if let Err(e) =
                    varpulis_db::repo::record_usage(&pool, org_id, today, count, 0).await
                {
                    tracing::error!("Failed to flush usage for org {}: {}", org_id, e);
                    failed.push((org_id, count));
                }
            }

            if failed.is_empty() {
                tracing::debug!("Usage flush complete");
            } else {
                let requeued = failed.len();
                {
                    let mut tracker = state.usage.write().await;
                    for (org_id, count) in failed {
                        tracker.record_events(org_id, count);
                    }
                }
                tracing::warn!(
                    "Usage flush: re-queued {} org(s) after write failures",
                    requeued
                );
            }
        }
    });
}
319
320// ---------------------------------------------------------------------------
321// Stripe helpers
322// ---------------------------------------------------------------------------
323
324/// Call the Stripe API with form-encoded body.
325async fn stripe_post(
326    client: &reqwest::Client,
327    secret_key: &str,
328    endpoint: &str,
329    params: &[(&str, &str)],
330) -> Result<serde_json::Value, String> {
331    let resp = client
332        .post(format!("https://api.stripe.com/v1/{endpoint}"))
333        .basic_auth(secret_key, None::<&str>)
334        .form(params)
335        .send()
336        .await
337        .map_err(|e| format!("Stripe request failed: {e}"))?;
338
339    let status = resp.status();
340    let body: serde_json::Value = resp
341        .json()
342        .await
343        .map_err(|e| format!("Stripe response parse failed: {e}"))?;
344
345    if !status.is_success() {
346        let msg = body["error"]["message"]
347            .as_str()
348            .unwrap_or("Unknown Stripe error");
349        return Err(format!("Stripe API error ({status}): {msg}"));
350    }
351
352    Ok(body)
353}
354
355/// Verify Stripe webhook signature (HMAC-SHA256).
356fn verify_stripe_signature(payload: &[u8], sig_header: &str, secret: &str) -> bool {
357    use hmac::{Hmac, Mac};
358    use sha2::Sha256;
359
360    // Parse signature header: "t=timestamp,v1=signature"
361    let mut timestamp = "";
362    let mut signature = "";
363    for part in sig_header.split(',') {
364        if let Some(t) = part.strip_prefix("t=") {
365            timestamp = t;
366        } else if let Some(s) = part.strip_prefix("v1=") {
367            signature = s;
368        }
369    }
370
371    if timestamp.is_empty() || signature.is_empty() {
372        return false;
373    }
374
375    // Compute expected signature
376    let signed_payload = format!(
377        "{}.{}",
378        timestamp,
379        std::str::from_utf8(payload).unwrap_or("")
380    );
381    let mut mac =
382        Hmac::<Sha256>::new_from_slice(secret.as_bytes()).expect("HMAC accepts any key size");
383    hmac::Mac::update(&mut mac, signed_payload.as_bytes());
384    let expected = hex::encode(mac.finalize().into_bytes());
385
386    // Constant-time comparison
387    expected == signature
388}
389
390// ---------------------------------------------------------------------------
391// Route handlers
392// ---------------------------------------------------------------------------
393
394/// GET /api/v1/billing/usage — get usage summary.
395async fn handle_usage(
396    State(state): State<Option<SharedBillingState>>,
397    headers: HeaderMap,
398) -> Response {
399    let auth_header = headers
400        .get("authorization")
401        .and_then(|v| v.to_str().ok())
402        .map(|s| s.to_string());
403
404    match state {
405        Some(s) => {
406            // Try DB first when saas is enabled
407            #[cfg(feature = "saas")]
408            if let Some(ref pool) = s.db_pool {
409                if let Some(org_id) = extract_org_id_from_header(&auth_header, &s) {
410                    let today = chrono::Utc::now().date_naive();
411                    let start = chrono::NaiveDate::from_ymd_opt(today.year(), today.month(), 1)
412                        .unwrap_or(today);
413                    if let Ok(rows) = varpulis_db::repo::get_usage(pool, org_id, start, today).await
414                    {
415                        let total: i64 = rows.iter().map(|r| r.events_processed).sum();
416                        return (
417                            StatusCode::OK,
418                            Json(serde_json::json!({
419                                "events_this_month": total,
420                                "daily": rows.iter().map(|r| serde_json::json!({
421                                    "date": r.date.to_string(),
422                                    "events_processed": r.events_processed,
423                                })).collect::<Vec<_>>(),
424                            })),
425                        )
426                            .into_response();
427                    }
428                }
429            }
430
431            // Fallback: in-memory buffer
432            let _ = auth_header;
433            let tracker = s.usage.read().await;
434            let orgs: Vec<serde_json::Value> = tracker
435                .buffer
436                .iter()
437                .map(|(org_id, count)| {
438                    serde_json::json!({
439                        "org_id": org_id.to_string(),
440                        "events_today": count,
441                    })
442                })
443                .collect();
444            (StatusCode::OK, Json(serde_json::json!({ "usage": orgs }))).into_response()
445        }
446        None => (
447            StatusCode::SERVICE_UNAVAILABLE,
448            Json(serde_json::json!({ "error": "Billing not configured" })),
449        )
450            .into_response(),
451    }
452}
453
454/// GET /api/v1/billing/plan — get current plan.
455async fn handle_plan(
456    State(state): State<Option<SharedBillingState>>,
457    headers: HeaderMap,
458) -> Response {
459    let auth_header = headers
460        .get("authorization")
461        .and_then(|v| v.to_str().ok())
462        .map(|s| s.to_string());
463
464    match state {
465        Some(_s) => {
466            // Try DB for real plan when saas enabled
467            #[cfg(feature = "saas")]
468            if let Some(ref pool) = _s.db_pool {
469                if let Some(org_id) = extract_org_id_from_header(&auth_header, &_s) {
470                    if let Ok(Some(org)) = varpulis_db::repo::get_organization(pool, org_id).await {
471                        let tier: Tier = org.tier.parse().unwrap_or(Tier::Free);
472                        return (
473                            StatusCode::OK,
474                            Json(serde_json::json!({
475                                "tier": org.tier,
476                                "event_limit": tier.event_limit(),
477                                "display_name": tier.display_name(),
478                            })),
479                        )
480                            .into_response();
481                    }
482                }
483            }
484
485            // Fallback: hardcoded free
486            let _ = auth_header;
487            (
488                StatusCode::OK,
489                Json(serde_json::json!({
490                    "tier": "free",
491                    "event_limit": 10_000,
492                    "display_name": "Free",
493                })),
494            )
495                .into_response()
496        }
497        None => (
498            StatusCode::SERVICE_UNAVAILABLE,
499            Json(serde_json::json!({ "error": "Billing not configured" })),
500        )
501            .into_response(),
502    }
503}
504
505#[derive(Debug, Deserialize)]
506struct CheckoutRequest {
507    success_url: Option<String>,
508    cancel_url: Option<String>,
509}
510
511/// POST /api/v1/billing/checkout — create Stripe Checkout session.
512async fn handle_checkout(
513    State(state): State<Option<SharedBillingState>>,
514    headers: HeaderMap,
515    Json(body): Json<CheckoutRequest>,
516) -> Response {
517    let auth_header = headers
518        .get("authorization")
519        .and_then(|v| v.to_str().ok())
520        .map(|s| s.to_string());
521
522    match state {
523        Some(s) => {
524            if s.config.pro_price_id.is_empty() {
525                return (
526                    StatusCode::BAD_REQUEST,
527                    Json(serde_json::json!({
528                        "error": "Stripe Price ID not configured"
529                    })),
530                )
531                    .into_response();
532            }
533
534            let success_url = body
535                .success_url
536                .unwrap_or_else(|| format!("{}/billing?success=true", s.config.frontend_url));
537            let cancel_url = body
538                .cancel_url
539                .unwrap_or_else(|| format!("{}/billing", s.config.frontend_url));
540
541            // Build Stripe Checkout params
542            let org_id_str = extract_org_id_str_from_header(&auth_header, &s).unwrap_or_default();
543
544            let mut params: Vec<(&str, &str)> = vec![
545                ("mode", "subscription"),
546                ("line_items[0][price]", &s.config.pro_price_id),
547                ("line_items[0][quantity]", "1"),
548                ("success_url", &success_url),
549                ("cancel_url", &cancel_url),
550            ];
551
552            if !org_id_str.is_empty() {
553                params.push(("client_reference_id", &org_id_str));
554            }
555
556            // Look up existing Stripe customer
557            #[allow(unused_mut)]
558            let mut customer_id = String::new();
559            #[cfg(feature = "saas")]
560            if let Some(ref pool) = s.db_pool {
561                if let Some(org_uuid) = extract_org_id_from_header(&auth_header, &s) {
562                    if let Ok(Some(org)) = varpulis_db::repo::get_organization(pool, org_uuid).await
563                    {
564                        if let Some(cid) = org.stripe_customer_id {
565                            customer_id = cid;
566                        }
567                    }
568                }
569            }
570
571            if !customer_id.is_empty() {
572                params.push(("customer", &customer_id));
573            }
574
575            match stripe_post(
576                &s.http_client,
577                &s.config.stripe_secret_key,
578                "checkout/sessions",
579                &params,
580            )
581            .await
582            {
583                Ok(session) => {
584                    let checkout_url = session["url"].as_str().unwrap_or("");
585                    let session_id = session["id"].as_str().unwrap_or("");
586                    // Audit log: checkout started
587                    if let Some(ref logger) = s.audit_logger {
588                        logger
589                            .log(
590                                AuditEntry::new(
591                                    &org_id_str,
592                                    AuditAction::CheckoutStarted,
593                                    "/api/v1/billing/checkout",
594                                )
595                                .with_detail(format!("session: {session_id}")),
596                            )
597                            .await;
598                    }
599                    (
600                        StatusCode::OK,
601                        Json(serde_json::json!({
602                            "checkout_url": checkout_url,
603                            "session_id": session_id,
604                        })),
605                    )
606                        .into_response()
607                }
608                Err(e) => {
609                    tracing::error!("Stripe checkout failed: {}", e);
610                    (
611                        StatusCode::BAD_GATEWAY,
612                        Json(serde_json::json!({"error": e})),
613                    )
614                        .into_response()
615                }
616            }
617        }
618        None => (
619            StatusCode::SERVICE_UNAVAILABLE,
620            Json(serde_json::json!({ "error": "Billing not configured" })),
621        )
622            .into_response(),
623    }
624}
625
626/// POST /api/v1/billing/portal — create Stripe Customer Portal session.
627async fn handle_portal(
628    State(state): State<Option<SharedBillingState>>,
629    headers: HeaderMap,
630) -> Response {
631    let auth_header = headers
632        .get("authorization")
633        .and_then(|v| v.to_str().ok())
634        .map(|s| s.to_string());
635
636    match state {
637        Some(s) => {
638            #[allow(unused_mut)]
639            let mut customer_id = String::new();
640
641            #[cfg(feature = "saas")]
642            if let Some(ref pool) = s.db_pool {
643                if let Some(org_uuid) = extract_org_id_from_header(&auth_header, &s) {
644                    if let Ok(Some(org)) = varpulis_db::repo::get_organization(pool, org_uuid).await
645                    {
646                        if let Some(cid) = org.stripe_customer_id {
647                            customer_id = cid;
648                        }
649                    }
650                }
651            }
652            let _ = auth_header;
653
654            if customer_id.is_empty() {
655                return (
656                    StatusCode::BAD_REQUEST,
657                    Json(serde_json::json!({
658                        "error": "No Stripe customer found. Upgrade first."
659                    })),
660                )
661                    .into_response();
662            }
663
664            let return_url = format!("{}/billing", s.config.frontend_url);
665            match stripe_post(
666                &s.http_client,
667                &s.config.stripe_secret_key,
668                "billing_portal/sessions",
669                &[("customer", &customer_id), ("return_url", &return_url)],
670            )
671            .await
672            {
673                Ok(session) => {
674                    let portal_url = session["url"].as_str().unwrap_or("");
675                    (
676                        StatusCode::OK,
677                        Json(serde_json::json!({
678                            "portal_url": portal_url,
679                        })),
680                    )
681                        .into_response()
682                }
683                Err(e) => {
684                    tracing::error!("Stripe portal failed: {}", e);
685                    (
686                        StatusCode::BAD_GATEWAY,
687                        Json(serde_json::json!({"error": e})),
688                    )
689                        .into_response()
690                }
691            }
692        }
693        None => (
694            StatusCode::SERVICE_UNAVAILABLE,
695            Json(serde_json::json!({ "error": "Billing not configured" })),
696        )
697            .into_response(),
698    }
699}
700
701/// POST /api/v1/billing/webhook — handle Stripe webhook events.
702async fn handle_webhook(
703    State(state): State<Option<SharedBillingState>>,
704    headers: HeaderMap,
705    body: bytes::Bytes,
706) -> Response {
707    let sig_header = headers
708        .get("stripe-signature")
709        .and_then(|v| v.to_str().ok())
710        .map(|s| s.to_string());
711
712    let s = match state {
713        Some(s) => s,
714        None => {
715            return (
716                StatusCode::SERVICE_UNAVAILABLE,
717                Json(serde_json::json!({"error": "Billing not configured"})),
718            )
719                .into_response();
720        }
721    };
722
723    // Verify signature
724    if !s.config.stripe_webhook_secret.is_empty() {
725        let sig = sig_header.unwrap_or_default();
726        if !verify_stripe_signature(&body, &sig, &s.config.stripe_webhook_secret) {
727            return (
728                StatusCode::BAD_REQUEST,
729                Json(serde_json::json!({"error": "Invalid signature"})),
730            )
731                .into_response();
732        }
733    }
734
735    // Parse event
736    let event: serde_json::Value = match serde_json::from_slice(&body) {
737        Ok(v) => v,
738        Err(e) => {
739            tracing::error!("Webhook parse error: {}", e);
740            return (
741                StatusCode::BAD_REQUEST,
742                Json(serde_json::json!({"error": "Invalid JSON"})),
743            )
744                .into_response();
745        }
746    };
747
748    let event_type = event["type"].as_str().unwrap_or("");
749    tracing::info!("Stripe webhook: {}", event_type);
750
751    // Audit log: webhook received
752    if let Some(ref logger) = s.audit_logger {
753        logger
754            .log(
755                AuditEntry::new(
756                    "stripe",
757                    AuditAction::WebhookReceived,
758                    "/api/v1/billing/webhook",
759                )
760                .with_detail(event_type.to_string()),
761            )
762            .await;
763    }
764
765    #[cfg(feature = "saas")]
766    if let Some(ref pool) = s.db_pool {
767        match event_type {
768            "checkout.session.completed" => {
769                let obj = &event["data"]["object"];
770                let customer = obj["customer"].as_str().unwrap_or("");
771                let client_ref = obj["client_reference_id"].as_str().unwrap_or("");
772
773                if !client_ref.is_empty() && !customer.is_empty() {
774                    if let Ok(org_id) = client_ref.parse::<uuid::Uuid>() {
775                        if let Err(e) =
776                            varpulis_db::repo::update_org_stripe_customer(pool, org_id, customer)
777                                .await
778                        {
779                            tracing::error!("Failed to save Stripe customer: {}", e);
780                        }
781                        if let Err(e) =
782                            varpulis_db::repo::update_org_tier(pool, org_id, "pro").await
783                        {
784                            tracing::error!("Failed to update tier: {}", e);
785                        }
786                        tracing::info!("Org {} upgraded to pro (customer: {})", org_id, customer);
787                        // Audit log: tier upgrade
788                        if let Some(ref logger) = s.audit_logger {
789                            logger
790                                .log(
791                                    AuditEntry::new(
792                                        org_id.to_string(),
793                                        AuditAction::TierChange,
794                                        "/api/v1/billing/webhook",
795                                    )
796                                    .with_detail("upgraded to pro"),
797                                )
798                                .await;
799                        }
800                    }
801                }
802            }
803            "customer.subscription.deleted" => {
804                let customer = event["data"]["object"]["customer"].as_str().unwrap_or("");
805                if !customer.is_empty() {
806                    if let Ok(Some(org)) =
807                        varpulis_db::repo::get_org_by_stripe_customer(pool, customer).await
808                    {
809                        if let Err(e) =
810                            varpulis_db::repo::update_org_tier(pool, org.id, "free").await
811                        {
812                            tracing::error!("Failed to downgrade org: {}", e);
813                        }
814                        tracing::info!("Org {} downgraded to free", org.id);
815                        // Audit log: tier downgrade
816                        if let Some(ref logger) = s.audit_logger {
817                            logger
818                                .log(
819                                    AuditEntry::new(
820                                        org.id.to_string(),
821                                        AuditAction::TierChange,
822                                        "/api/v1/billing/webhook",
823                                    )
824                                    .with_detail("downgraded to free"),
825                                )
826                                .await;
827                        }
828                    }
829                }
830            }
831            "customer.subscription.updated" => {
832                // Could check for plan changes; for now just log
833                tracing::info!("Subscription updated (no-op)");
834            }
835            "invoice.payment_failed" => {
836                let customer = event["data"]["object"]["customer"].as_str().unwrap_or("");
837                tracing::warn!("Payment failed for customer {}", customer);
838            }
839            _ => {
840                tracing::debug!("Unhandled webhook event: {}", event_type);
841            }
842        }
843    }
844
845    #[cfg(not(feature = "saas"))]
846    {
847        let _ = event_type;
848        tracing::debug!("Webhook received but saas feature not enabled");
849    }
850
851    (StatusCode::OK, Json(serde_json::json!({"received": true}))).into_response()
852}
853
854// ---------------------------------------------------------------------------
855// JWT claim extraction helpers
856// ---------------------------------------------------------------------------
857
858/// Extract org_id UUID from Authorization header JWT.
859#[cfg_attr(not(feature = "saas"), allow(dead_code))]
860fn extract_org_id_from_header(
861    auth_header: &Option<String>,
862    state: &BillingState,
863) -> Option<uuid::Uuid> {
864    extract_org_id_str_from_header(auth_header, state)?
865        .parse()
866        .ok()
867}
868
869/// Extract org_id string from Authorization header JWT.
870fn extract_org_id_str_from_header(
871    auth_header: &Option<String>,
872    _state: &BillingState,
873) -> Option<String> {
874    let header = auth_header.as_ref()?;
875    let token = header.strip_prefix("Bearer ")?.trim();
876    if token.is_empty() {
877        return None;
878    }
879
880    // Decode JWT without full verification (billing state doesn't have JWT secret).
881    // Use jsonwebtoken's dangerous_insecure_decode to read claims.
882    use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
883    let mut validation = Validation::new(Algorithm::HS256);
884    validation.insecure_disable_signature_validation();
885    validation.validate_exp = false;
886    let token_data =
887        decode::<serde_json::Value>(token, &DecodingKey::from_secret(b""), &validation).ok()?;
888
889    let org_id = token_data.claims["org_id"].as_str()?;
890    if org_id.is_empty() {
891        return None;
892    }
893    Some(org_id.to_string())
894}
895
896// ---------------------------------------------------------------------------
897// Route assembly
898// ---------------------------------------------------------------------------
899
900/// Build billing routes. When `state` is None, endpoints return 503.
901pub fn billing_routes(state: Option<SharedBillingState>) -> Router {
902    Router::new()
903        .route("/api/v1/billing/usage", get(handle_usage))
904        .route("/api/v1/billing/plan", get(handle_plan))
905        .route("/api/v1/billing/checkout", post(handle_checkout))
906        .route("/api/v1/billing/portal", post(handle_portal))
907        .route("/api/v1/billing/webhook", post(handle_webhook))
908        .with_state(state)
909}
910
911// ---------------------------------------------------------------------------
912// Tests
913// ---------------------------------------------------------------------------
914
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::Request;
    use tower::ServiceExt;

    /// Far-future `exp` claim (2100-01-01). jsonwebtoken's default `Validation`
    /// requires `exp` to be present even though expiry is not checked.
    const FAR_FUTURE_EXP: u64 = 4_102_444_800;

    /// Build a bodiless GET request for `uri`.
    fn get_req(uri: &str) -> Request<Body> {
        Request::builder()
            .method("GET")
            .uri(uri)
            .body(Body::empty())
            .unwrap()
    }

    /// Billing state with dummy Stripe credentials and the given webhook secret.
    fn test_state(webhook_secret: &str) -> Arc<BillingState> {
        let config = BillingConfig {
            stripe_secret_key: "sk_test_xxx".to_string(),
            stripe_webhook_secret: webhook_secret.to_string(),
            pro_price_id: "price_xxx".to_string(),
            frontend_url: "http://localhost:5173".to_string(),
        };
        Arc::new(BillingState::new(config))
    }

    /// Encode `claims` as an HS256 JWT and wrap it in a `Bearer` header value.
    /// The signing secret is arbitrary: the extraction helpers do not verify it.
    fn bearer(claims: serde_json::Value) -> Option<String> {
        use jsonwebtoken::{encode, EncodingKey, Header};
        let token = encode(&Header::default(), &claims, &EncodingKey::from_secret(b"test"))
            .expect("encoding a JSON claims object with an HMAC key cannot fail");
        Some(format!("Bearer {token}"))
    }

    #[test]
    fn test_tier_event_limits() {
        assert_eq!(Tier::Free.event_limit(), Some(10_000));
        assert_eq!(Tier::Pro.event_limit(), Some(10_000_000));
        assert_eq!(Tier::Enterprise.event_limit(), None);
    }

    #[test]
    fn test_tier_display_name() {
        assert_eq!(Tier::Free.display_name(), "Free");
        assert_eq!(Tier::Pro.display_name(), "Pro ($49/mo)");
        assert_eq!(Tier::Enterprise.display_name(), "Enterprise");
    }

    #[test]
    fn test_tier_from_str() {
        assert_eq!("free".parse::<Tier>(), Ok(Tier::Free));
        assert_eq!("pro".parse::<Tier>(), Ok(Tier::Pro));
        assert_eq!("enterprise".parse::<Tier>(), Ok(Tier::Enterprise));
        assert!("invalid".parse::<Tier>().is_err());
    }

    #[test]
    fn test_tier_serialization() {
        let json = serde_json::to_string(&Tier::Pro).unwrap();
        assert_eq!(json, "\"pro\"");
        let deserialized: Tier = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized, Tier::Pro);
    }

    #[test]
    fn test_usage_tracker_record_and_drain() {
        let mut tracker = UsageTracker::new();
        let org = Uuid::new_v4();

        tracker.record_events(org, 100);
        tracker.record_events(org, 50);
        assert_eq!(tracker.get(&org), 150);

        let drained = tracker.drain();
        assert_eq!(drained.len(), 1);
        assert_eq!(drained[0], (org, 150));

        // Buffer should be empty after drain
        assert_eq!(tracker.get(&org), 0);
    }

    #[test]
    fn test_usage_tracker_multiple_orgs() {
        let mut tracker = UsageTracker::new();
        let org1 = Uuid::new_v4();
        let org2 = Uuid::new_v4();

        tracker.record_events(org1, 100);
        tracker.record_events(org2, 200);
        tracker.record_events(org1, 50);

        assert_eq!(tracker.get(&org1), 150);
        assert_eq!(tracker.get(&org2), 200);
    }

    #[test]
    fn test_verify_stripe_signature() {
        use hmac::{Hmac, Mac};
        use sha2::Sha256;

        let secret = "whsec_test123";
        let payload = b"{\"type\":\"test\"}";
        let timestamp = "1234567890";

        // Compute expected signature
        let signed = format!("{}.{}", timestamp, std::str::from_utf8(payload).unwrap());
        let mut mac = Hmac::<Sha256>::new_from_slice(secret.as_bytes()).unwrap();
        hmac::Mac::update(&mut mac, signed.as_bytes());
        let sig = hex::encode(mac.finalize().into_bytes());

        let header = format!("t={timestamp},v1={sig}");

        assert!(verify_stripe_signature(payload, &header, secret));
        assert!(!verify_stripe_signature(payload, &header, "wrong_secret"));
        assert!(!verify_stripe_signature(b"tampered", &header, secret));
    }

    #[test]
    fn test_extract_org_id_rejects_missing_or_malformed_header() {
        let state = test_state("whsec_xxx");

        assert_eq!(extract_org_id_str_from_header(&None, &state), None);
        assert_eq!(
            extract_org_id_str_from_header(&Some("Basic dXNlcjpwYXNz".to_string()), &state),
            None
        );
        assert_eq!(
            extract_org_id_str_from_header(&Some("Bearer ".to_string()), &state),
            None
        );
        assert_eq!(
            extract_org_id_str_from_header(&Some("Bearer not-a-jwt".to_string()), &state),
            None
        );
    }

    #[test]
    fn test_extract_org_id_from_bearer_token() {
        let state = test_state("whsec_xxx");
        let org = Uuid::new_v4();
        let header = bearer(serde_json::json!({
            "org_id": org.to_string(),
            "exp": FAR_FUTURE_EXP,
        }));

        assert_eq!(
            extract_org_id_str_from_header(&header, &state),
            Some(org.to_string())
        );
        assert_eq!(extract_org_id_from_header(&header, &state), Some(org));
    }

    #[test]
    fn test_extract_org_id_requires_non_empty_claim() {
        let state = test_state("whsec_xxx");

        let missing = bearer(serde_json::json!({ "exp": FAR_FUTURE_EXP }));
        assert_eq!(extract_org_id_str_from_header(&missing, &state), None);

        let empty = bearer(serde_json::json!({ "org_id": "", "exp": FAR_FUTURE_EXP }));
        assert_eq!(extract_org_id_str_from_header(&empty, &state), None);

        // A non-UUID claim is returned as a string but rejected by the UUID variant.
        let not_uuid = bearer(serde_json::json!({ "org_id": "acme", "exp": FAR_FUTURE_EXP }));
        assert_eq!(
            extract_org_id_str_from_header(&not_uuid, &state),
            Some("acme".to_string())
        );
        assert_eq!(extract_org_id_from_header(&not_uuid, &state), None);
    }

    #[tokio::test]
    async fn test_billing_routes_not_configured() {
        let app = billing_routes(None);

        let res = app.oneshot(get_req("/api/v1/billing/plan")).await.unwrap();

        assert_eq!(res.status(), 503);
    }

    #[tokio::test]
    async fn test_billing_routes_usage() {
        let app = billing_routes(Some(test_state("whsec_xxx")));

        let res = app.oneshot(get_req("/api/v1/billing/usage")).await.unwrap();

        assert_eq!(res.status(), 200);
    }

    #[tokio::test]
    async fn test_billing_routes_plan() {
        let app = billing_routes(Some(test_state("whsec_xxx")));

        let res = app.oneshot(get_req("/api/v1/billing/plan")).await.unwrap();

        assert_eq!(res.status(), 200);
        let body = axum::body::to_bytes(res.into_body(), usize::MAX)
            .await
            .unwrap();
        let body: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert_eq!(body["tier"], "free");
    }

    #[tokio::test]
    async fn test_webhook_invalid_signature() {
        let app = billing_routes(Some(test_state("whsec_real_secret")));

        let req: Request<Body> = Request::builder()
            .method("POST")
            .uri("/api/v1/billing/webhook")
            .header("stripe-signature", "t=123,v1=bad")
            .header("content-type", "application/octet-stream")
            .body(Body::from("{\"type\":\"test\"}"))
            .unwrap();
        let res = app.oneshot(req).await.unwrap();

        assert_eq!(res.status(), 400);
    }

    #[test]
    fn test_usage_limit_exceeded_serialization() {
        let err = UsageLimitExceeded {
            tier: Tier::Free,
            limit: 10_000,
            current_usage: 10_500,
            message: "Usage limit exceeded for Free tier".to_string(),
        };
        let json = serde_json::to_value(&err).unwrap();
        assert_eq!(json["tier"], "free");
        assert_eq!(json["limit"], 10_000);
        assert_eq!(json["current_usage"], 10_500);
    }

    #[test]
    fn test_usage_limit_response_status() {
        let err = UsageLimitExceeded {
            tier: Tier::Free,
            limit: 10_000,
            current_usage: 11_000,
            message: "Limit exceeded".to_string(),
        };
        let resp = usage_limit_response(&err);
        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
    }

    #[test]
    fn test_tier_enterprise_no_limit() {
        // Enterprise has no event limit
        assert_eq!(Tier::Enterprise.event_limit(), None);
    }

    #[test]
    fn test_usage_tracker_accumulates_increments() {
        let mut tracker = UsageTracker::new();
        let org = Uuid::new_v4();

        // Record events in multiple increments
        tracker.record_events(org, 5_000);
        tracker.record_events(org, 3_000);
        tracker.record_events(org, 2_001);

        // Should be cumulative
        assert_eq!(tracker.get(&org), 10_001);
    }
}