1use std::collections::HashMap;
7use std::sync::Arc;
8
9use axum::extract::{Json, State};
10use axum::http::{HeaderMap, StatusCode};
11use axum::response::{IntoResponse, Response};
12use axum::routing::{get, post};
13use axum::Router;
14#[cfg(feature = "saas")]
15use chrono::Datelike;
16use serde::{Deserialize, Serialize};
17use tokio::sync::RwLock;
18use uuid::Uuid;
19
20use crate::audit::{AuditAction, AuditEntry, SharedAuditLogger};
21
/// Stripe-related settings used by the billing endpoints.
///
/// Empty strings in `stripe_webhook_secret`, `pro_price_id` and
/// `business_price_id` are treated as "not configured" by the code that reads
/// them.
#[derive(Clone, Debug)]
pub struct BillingConfig {
    /// Secret API key, sent as the basic-auth user on Stripe API requests.
    pub stripe_secret_key: String,
    /// Signing secret used to verify the `Stripe-Signature` header on webhooks.
    pub stripe_webhook_secret: String,
    /// Stripe Price ID charged for the Pro tier.
    pub pro_price_id: String,
    /// Stripe Price ID charged for the Business tier.
    pub business_price_id: String,
    /// Base URL of the web frontend, used to build default redirect/return URLs.
    pub frontend_url: String,
}
35
36impl BillingConfig {
37 pub fn from_env() -> Option<Self> {
40 let secret_key = std::env::var("STRIPE_SECRET_KEY").ok()?;
41 let webhook_secret =
42 std::env::var("STRIPE_WEBHOOK_SECRET").unwrap_or_else(|_| String::new());
43 let pro_price_id = std::env::var("STRIPE_PRO_PRICE_ID").unwrap_or_else(|_| String::new());
44 let business_price_id =
45 std::env::var("STRIPE_BUSINESS_PRICE_ID").unwrap_or_else(|_| String::new());
46 let frontend_url =
47 std::env::var("FRONTEND_URL").unwrap_or_else(|_| "http://localhost:5173".to_string());
48
49 Some(Self {
50 stripe_secret_key: secret_key,
51 stripe_webhook_secret: webhook_secret,
52 pro_price_id,
53 business_price_id,
54 frontend_url,
55 })
56 }
57
58 pub fn price_id_for_tier(&self, tier: &Tier) -> Option<&str> {
60 match tier {
61 Tier::Pro if !self.pro_price_id.is_empty() => Some(&self.pro_price_id),
62 Tier::Business if !self.business_price_id.is_empty() => Some(&self.business_price_id),
63 _ => None,
64 }
65 }
66
67 pub fn tier_for_price_id(&self, price_id: &str) -> Option<Tier> {
69 if !self.pro_price_id.is_empty() && price_id == self.pro_price_id {
70 Some(Tier::Pro)
71 } else if !self.business_price_id.is_empty() && price_id == self.business_price_id {
72 Some(Tier::Business)
73 } else {
74 None
75 }
76 }
77}
78
79#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
84#[serde(rename_all = "lowercase")]
85pub enum Tier {
86 Free,
87 Pro,
88 Business,
89 Enterprise,
90}
91
92impl Tier {
93 pub const fn event_limit(&self) -> Option<i64> {
94 match self {
95 Self::Free => Some(100_000),
96 Self::Pro => Some(10_000_000),
97 Self::Business => Some(100_000_000),
98 Self::Enterprise => None,
99 }
100 }
101
102 pub const fn display_name(&self) -> &str {
103 match self {
104 Self::Free => "Free",
105 Self::Pro => "Pro ($49/mo)",
106 Self::Business => "Business ($199/mo)",
107 Self::Enterprise => "Enterprise",
108 }
109 }
110}
111
112impl std::fmt::Display for Tier {
113 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114 match self {
115 Self::Free => write!(f, "free"),
116 Self::Pro => write!(f, "pro"),
117 Self::Business => write!(f, "business"),
118 Self::Enterprise => write!(f, "enterprise"),
119 }
120 }
121}
122
123impl std::str::FromStr for Tier {
124 type Err = String;
125
126 fn from_str(s: &str) -> Result<Self, Self::Err> {
127 match s {
128 "free" => Ok(Self::Free),
129 "pro" => Ok(Self::Pro),
130 "business" => Ok(Self::Business),
131 "enterprise" => Ok(Self::Enterprise),
132 other => Err(format!("unknown tier: {other}")),
133 }
134 }
135}
136
137#[derive(Debug)]
143pub struct UsageTracker {
144 buffer: HashMap<Uuid, i64>,
145 monthly_totals: HashMap<Uuid, i64>,
147 monthly_limits: HashMap<Uuid, i64>,
149}
150
151impl Default for UsageTracker {
152 fn default() -> Self {
153 Self::new()
154 }
155}
156
157impl UsageTracker {
158 pub fn new() -> Self {
159 Self {
160 buffer: HashMap::new(),
161 monthly_totals: HashMap::new(),
162 monthly_limits: HashMap::new(),
163 }
164 }
165
166 pub fn record_events(&mut self, org_id: Uuid, count: i64) {
167 *self.buffer.entry(org_id).or_insert(0) += count;
168 }
169
170 pub fn drain(&mut self) -> Vec<(Uuid, i64)> {
172 self.buffer.drain().collect()
173 }
174
175 pub fn get(&self, org_id: &Uuid) -> i64 {
176 self.buffer.get(org_id).copied().unwrap_or(0)
177 }
178
179 pub fn set_monthly_total(&mut self, org_id: Uuid, total: i64) {
181 self.monthly_totals.insert(org_id, total);
182 }
183
184 pub fn get_monthly_total(&self, org_id: &Uuid) -> Option<i64> {
186 self.monthly_totals.get(org_id).copied()
187 }
188
189 pub fn set_monthly_limit(&mut self, org_id: Uuid, limit: i64) {
191 self.monthly_limits.insert(org_id, limit);
192 }
193
194 pub fn get_monthly_limit(&self, org_id: &Uuid) -> Option<i64> {
196 self.monthly_limits.get(org_id).copied()
197 }
198
199 pub fn check_cached_limit(&self, org_id: &Uuid, additional: i64) -> Option<bool> {
202 let total = self.monthly_totals.get(org_id)?;
203 let limit = self.monthly_limits.get(org_id)?;
204 let buffered = self.buffer.get(org_id).copied().unwrap_or(0);
205 Some(total + buffered + additional <= *limit)
206 }
207}
208
209#[derive(Debug)]
214pub struct BillingState {
215 pub config: BillingConfig,
216 pub usage: RwLock<UsageTracker>,
217 pub http_client: reqwest::Client,
218 #[cfg(feature = "saas")]
219 pub db_pool: Option<varpulis_db::PgPool>,
220 pub audit_logger: Option<SharedAuditLogger>,
221}
222
223impl BillingState {
224 pub fn new(config: BillingConfig) -> Self {
225 Self {
226 config,
227 usage: RwLock::new(UsageTracker::new()),
228 http_client: reqwest::Client::new(),
229 #[cfg(feature = "saas")]
230 db_pool: None,
231 audit_logger: None,
232 }
233 }
234
235 pub fn with_audit_logger(mut self, logger: Option<SharedAuditLogger>) -> Self {
236 self.audit_logger = logger;
237 self
238 }
239
240 #[cfg(feature = "saas")]
241 pub fn with_db_pool(mut self, pool: varpulis_db::PgPool) -> Self {
242 self.db_pool = Some(pool);
243 self
244 }
245}
246
247pub type SharedBillingState = Arc<BillingState>;
248
249#[derive(Debug, Serialize)]
255pub struct UsageLimitExceeded {
256 pub tier: Tier,
257 pub limit: i64,
258 pub current_usage: i64,
259 pub message: String,
260}
261
262#[derive(Debug)]
264pub enum UsageCheckResult {
265 Ok,
267 ApproachingLimit { usage_percent: f64 },
269 Exceeded(UsageLimitExceeded),
271}
272
273impl BillingState {
274 #[cfg(feature = "saas")]
278 pub async fn check_usage_limit(
279 &self,
280 org_id: Uuid,
281 additional_events: i64,
282 ) -> UsageCheckResult {
283 let tier = if let Some(ref pool) = self.db_pool {
285 if let Ok(Some(org)) = varpulis_db::repo::get_organization(pool, org_id).await {
286 org.tier.parse().unwrap_or(Tier::Free)
287 } else {
288 Tier::Free
289 }
290 } else {
291 Tier::Free
292 };
293
294 let limit = match tier.event_limit() {
296 Some(l) => l,
297 None => return UsageCheckResult::Ok,
298 };
299
300 let db_usage = if let Some(ref pool) = self.db_pool {
302 let today = chrono::Utc::now().date_naive();
303 let start =
304 chrono::NaiveDate::from_ymd_opt(today.year(), today.month(), 1).unwrap_or(today);
305 if let Ok(rows) = varpulis_db::repo::get_usage(pool, org_id, start, today).await {
306 rows.iter().map(|r| r.events_processed).sum::<i64>()
307 } else {
308 0
309 }
310 } else {
311 0
312 };
313
314 let buffered = self.usage.read().await.get(&org_id);
316 let total = db_usage + buffered + additional_events;
317
318 if total > limit {
319 UsageCheckResult::Exceeded(UsageLimitExceeded {
320 tier: tier.clone(),
321 limit,
322 current_usage: db_usage + buffered,
323 message: format!(
324 "Usage limit exceeded for {} tier ({}/{} events this month). Upgrade to increase your limit.",
325 tier.display_name(),
326 db_usage + buffered,
327 limit,
328 ),
329 })
330 } else {
331 let usage_percent = (total as f64 / limit as f64) * 100.0;
332 if usage_percent >= 80.0 {
333 UsageCheckResult::ApproachingLimit { usage_percent }
334 } else {
335 UsageCheckResult::Ok
336 }
337 }
338 }
339
340 #[cfg(feature = "saas")]
343 pub async fn org_id_for_api_key(&self, raw_key: &str) -> Option<Uuid> {
344 use sha2::Digest;
345 let pool = self.db_pool.as_ref()?;
346 let hash = hex::encode(sha2::Sha256::digest(raw_key.as_bytes()));
347 let api_key = varpulis_db::repo::get_api_key_by_hash(pool, &hash)
348 .await
349 .ok()??;
350 Some(api_key.org_id)
351 }
352}
353
354pub fn usage_limit_response(err: &UsageLimitExceeded) -> Response {
356 let body = Json(serde_json::json!({
357 "error": "usage_limit_exceeded",
358 "message": err.message,
359 "tier": err.tier,
360 "limit": err.limit,
361 "current_usage": err.current_usage,
362 "upgrade_url": "/billing",
363 }));
364
365 (
366 StatusCode::TOO_MANY_REQUESTS,
367 [("Retry-After", "3600")],
368 body,
369 )
370 .into_response()
371}
372
/// Spawns a background task that runs every 60 seconds. On each tick it:
///
/// 1. Drains the in-memory usage buffer and persists each organization's count
///    for today via `record_usage`. Counts whose write fails are put back into
///    the buffer, so the next tick retries them instead of dropping them.
/// 2. Refreshes the tracker's cached monthly totals and limits for every
///    organization.
///
/// Database I/O runs while the tracker lock is *not* held. The write lock is only
/// taken briefly to drain, re-queue, or apply refreshed values, so request
/// handlers recording events are not stalled behind DB round-trips.
#[cfg(feature = "saas")]
pub fn spawn_usage_flush(state: SharedBillingState, pool: varpulis_db::PgPool) {
    tokio::spawn(async move {
        let mut interval = tokio::time::interval(std::time::Duration::from_secs(60));
        // A slow flush should push the next one back, not trigger a burst of catch-up ticks.
        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);

        loop {
            interval.tick().await;

            // 1. Persist buffered counts (lock held only for the drain itself).
            let entries = state.usage.write().await.drain();
            let today = chrono::Utc::now().date_naive();
            let mut failed = Vec::new();
            for &(org_id, count) in &entries {
                if let Err(e) =
                    varpulis_db::repo::record_usage(&pool, org_id, today, count, 0).await
                {
                    tracing::error!("Failed to flush usage for org {}: {}", org_id, e);
                    failed.push((org_id, count));
                }
            }
            if !failed.is_empty() {
                // Re-queue instead of losing billable events. Re-queued counts are
                // recorded under the date of the next flush.
                // NOTE(review): assumes a failed `record_usage` wrote nothing; if it
                // can partially succeed, this would double count.
                let mut tracker = state.usage.write().await;
                for (org_id, count) in failed {
                    tracker.record_events(org_id, count);
                }
            }

            // 2. Refresh cached totals/limits: query first, then apply under one short lock.
            match varpulis_db::repo::list_all_organizations(&pool).await {
                Ok(orgs) => {
                    let mut refreshed = Vec::with_capacity(orgs.len());
                    for org in &orgs {
                        let total = varpulis_db::repo::get_org_usage_summary(&pool, org.id)
                            .await
                            .ok();
                        refreshed.push((org.id, total, org.monthly_event_limit));
                    }

                    let mut tracker = state.usage.write().await;
                    for (org_id, total, limit) in refreshed {
                        // A failed summary read keeps the previously cached total.
                        if let Some(total) = total {
                            tracker.set_monthly_total(org_id, total);
                        }
                        tracker.set_monthly_limit(org_id, limit);
                    }
                }
                Err(e) => {
                    tracing::warn!("Failed to list organizations for usage refresh: {}", e);
                }
            }

            if !entries.is_empty() {
                tracing::debug!("Usage flush complete ({} orgs)", entries.len());
            }
        }
    });
}
413
414async fn stripe_post(
420 client: &reqwest::Client,
421 secret_key: &str,
422 endpoint: &str,
423 params: &[(&str, &str)],
424) -> Result<serde_json::Value, String> {
425 let resp = client
426 .post(format!("https://api.stripe.com/v1/{endpoint}"))
427 .basic_auth(secret_key, None::<&str>)
428 .form(params)
429 .send()
430 .await
431 .map_err(|e| format!("Stripe request failed: {e}"))?;
432
433 let status = resp.status();
434 let body: serde_json::Value = resp
435 .json()
436 .await
437 .map_err(|e| format!("Stripe response parse failed: {e}"))?;
438
439 if !status.is_success() {
440 let msg = body["error"]["message"]
441 .as_str()
442 .unwrap_or("Unknown Stripe error");
443 return Err(format!("Stripe API error ({status}): {msg}"));
444 }
445
446 Ok(body)
447}
448
449fn verify_stripe_signature(payload: &[u8], sig_header: &str, secret: &str) -> bool {
451 use hmac::{Hmac, Mac};
452 use sha2::Sha256;
453
454 let mut timestamp = "";
456 let mut signature = "";
457 for part in sig_header.split(',') {
458 if let Some(t) = part.strip_prefix("t=") {
459 timestamp = t;
460 } else if let Some(s) = part.strip_prefix("v1=") {
461 signature = s;
462 }
463 }
464
465 if timestamp.is_empty() || signature.is_empty() {
466 return false;
467 }
468
469 let signed_payload = format!(
471 "{}.{}",
472 timestamp,
473 std::str::from_utf8(payload).unwrap_or("")
474 );
475 let mut mac =
476 Hmac::<Sha256>::new_from_slice(secret.as_bytes()).expect("HMAC accepts any key size");
477 hmac::Mac::update(&mut mac, signed_payload.as_bytes());
478 let expected = hex::encode(mac.finalize().into_bytes());
479
480 expected == signature
482}
483
484async fn handle_usage(
490 State(state): State<Option<SharedBillingState>>,
491 headers: HeaderMap,
492) -> Response {
493 let auth_header = headers
494 .get("authorization")
495 .and_then(|v| v.to_str().ok())
496 .map(|s| s.to_string());
497
498 match state {
499 Some(s) => {
500 #[cfg(feature = "saas")]
502 if let Some(ref pool) = s.db_pool {
503 if let Some(org_id) = extract_org_id_from_header(&auth_header, &s) {
504 let today = chrono::Utc::now().date_naive();
505 let start = chrono::NaiveDate::from_ymd_opt(today.year(), today.month(), 1)
506 .unwrap_or(today);
507 if let Ok(rows) = varpulis_db::repo::get_usage(pool, org_id, start, today).await
508 {
509 let total: i64 = rows.iter().map(|r| r.events_processed).sum();
510 return (
511 StatusCode::OK,
512 Json(serde_json::json!({
513 "events_this_month": total,
514 "daily": rows.iter().map(|r| serde_json::json!({
515 "date": r.date.to_string(),
516 "events_processed": r.events_processed,
517 })).collect::<Vec<_>>(),
518 })),
519 )
520 .into_response();
521 }
522 }
523 }
524
525 let _ = auth_header;
527 let tracker = s.usage.read().await;
528 let orgs: Vec<serde_json::Value> = tracker
529 .buffer
530 .iter()
531 .map(|(org_id, count)| {
532 serde_json::json!({
533 "org_id": org_id.to_string(),
534 "events_today": count,
535 })
536 })
537 .collect();
538 (StatusCode::OK, Json(serde_json::json!({ "usage": orgs }))).into_response()
539 }
540 None => (
541 StatusCode::SERVICE_UNAVAILABLE,
542 Json(serde_json::json!({ "error": "Billing not configured" })),
543 )
544 .into_response(),
545 }
546}
547
548async fn handle_plan(
550 State(state): State<Option<SharedBillingState>>,
551 headers: HeaderMap,
552) -> Response {
553 let auth_header = headers
554 .get("authorization")
555 .and_then(|v| v.to_str().ok())
556 .map(|s| s.to_string());
557
558 match state {
559 Some(_s) => {
560 #[cfg(feature = "saas")]
562 if let Some(ref pool) = _s.db_pool {
563 if let Some(org_id) = extract_org_id_from_header(&auth_header, &_s) {
564 if let Ok(Some(org)) = varpulis_db::repo::get_organization(pool, org_id).await {
565 let tier: Tier = org.tier.parse().unwrap_or(Tier::Free);
566 return (
567 StatusCode::OK,
568 Json(serde_json::json!({
569 "tier": org.tier,
570 "event_limit": tier.event_limit(),
571 "display_name": tier.display_name(),
572 })),
573 )
574 .into_response();
575 }
576 }
577 }
578
579 let _ = auth_header;
581 (
582 StatusCode::OK,
583 Json(serde_json::json!({
584 "tier": "free",
585 "event_limit": 10_000,
586 "display_name": "Free",
587 })),
588 )
589 .into_response()
590 }
591 None => (
592 StatusCode::SERVICE_UNAVAILABLE,
593 Json(serde_json::json!({ "error": "Billing not configured" })),
594 )
595 .into_response(),
596 }
597}
598
599#[derive(Debug, Deserialize)]
600struct CheckoutRequest {
601 tier: Option<String>,
603 success_url: Option<String>,
604 cancel_url: Option<String>,
605}
606
607async fn handle_checkout(
609 State(state): State<Option<SharedBillingState>>,
610 headers: HeaderMap,
611 Json(body): Json<CheckoutRequest>,
612) -> Response {
613 let auth_header = headers
614 .get("authorization")
615 .and_then(|v| v.to_str().ok())
616 .map(|s| s.to_string());
617
618 match state {
619 Some(s) => {
620 let target_tier: Tier = body
622 .tier
623 .as_deref()
624 .unwrap_or("pro")
625 .parse()
626 .unwrap_or(Tier::Pro);
627
628 let price_id = match s.config.price_id_for_tier(&target_tier) {
629 Some(id) => id.to_string(),
630 None => {
631 return (
632 StatusCode::BAD_REQUEST,
633 Json(serde_json::json!({
634 "error": format!("Stripe Price ID not configured for {} tier", target_tier)
635 })),
636 )
637 .into_response();
638 }
639 };
640
641 let success_url = body
642 .success_url
643 .unwrap_or_else(|| format!("{}/billing?success=true", s.config.frontend_url));
644 let cancel_url = body
645 .cancel_url
646 .unwrap_or_else(|| format!("{}/billing", s.config.frontend_url));
647
648 let org_id_str = extract_org_id_str_from_header(&auth_header, &s).unwrap_or_default();
650
651 let mut params: Vec<(&str, &str)> = vec![
652 ("mode", "subscription"),
653 ("line_items[0][price]", &price_id),
654 ("line_items[0][quantity]", "1"),
655 ("success_url", &success_url),
656 ("cancel_url", &cancel_url),
657 ];
658
659 if !org_id_str.is_empty() {
660 params.push(("client_reference_id", &org_id_str));
661 }
662
663 #[allow(unused_mut)]
665 let mut customer_id = String::new();
666 #[cfg(feature = "saas")]
667 if let Some(ref pool) = s.db_pool {
668 if let Some(org_uuid) = extract_org_id_from_header(&auth_header, &s) {
669 if let Ok(Some(org)) = varpulis_db::repo::get_organization(pool, org_uuid).await
670 {
671 if let Some(cid) = org.stripe_customer_id {
672 customer_id = cid;
673 }
674 }
675 }
676 }
677
678 if !customer_id.is_empty() {
679 params.push(("customer", &customer_id));
680 }
681
682 match stripe_post(
683 &s.http_client,
684 &s.config.stripe_secret_key,
685 "checkout/sessions",
686 ¶ms,
687 )
688 .await
689 {
690 Ok(session) => {
691 let checkout_url = session["url"].as_str().unwrap_or("");
692 let session_id = session["id"].as_str().unwrap_or("");
693 if let Some(ref logger) = s.audit_logger {
695 logger
696 .log(
697 AuditEntry::new(
698 &org_id_str,
699 AuditAction::CheckoutStarted,
700 "/api/v1/billing/checkout",
701 )
702 .with_detail(format!("session: {session_id}")),
703 )
704 .await;
705 }
706 (
707 StatusCode::OK,
708 Json(serde_json::json!({
709 "checkout_url": checkout_url,
710 "session_id": session_id,
711 })),
712 )
713 .into_response()
714 }
715 Err(e) => {
716 tracing::error!("Stripe checkout failed: {}", e);
717 (
718 StatusCode::BAD_GATEWAY,
719 Json(serde_json::json!({"error": e})),
720 )
721 .into_response()
722 }
723 }
724 }
725 None => (
726 StatusCode::SERVICE_UNAVAILABLE,
727 Json(serde_json::json!({ "error": "Billing not configured" })),
728 )
729 .into_response(),
730 }
731}
732
733async fn handle_portal(
735 State(state): State<Option<SharedBillingState>>,
736 headers: HeaderMap,
737) -> Response {
738 let auth_header = headers
739 .get("authorization")
740 .and_then(|v| v.to_str().ok())
741 .map(|s| s.to_string());
742
743 match state {
744 Some(s) => {
745 #[allow(unused_mut)]
746 let mut customer_id = String::new();
747
748 #[cfg(feature = "saas")]
749 if let Some(ref pool) = s.db_pool {
750 if let Some(org_uuid) = extract_org_id_from_header(&auth_header, &s) {
751 if let Ok(Some(org)) = varpulis_db::repo::get_organization(pool, org_uuid).await
752 {
753 if let Some(cid) = org.stripe_customer_id {
754 customer_id = cid;
755 }
756 }
757 }
758 }
759 let _ = auth_header;
760
761 if customer_id.is_empty() {
762 return (
763 StatusCode::BAD_REQUEST,
764 Json(serde_json::json!({
765 "error": "No Stripe customer found. Upgrade first."
766 })),
767 )
768 .into_response();
769 }
770
771 let return_url = format!("{}/billing", s.config.frontend_url);
772 match stripe_post(
773 &s.http_client,
774 &s.config.stripe_secret_key,
775 "billing_portal/sessions",
776 &[("customer", &customer_id), ("return_url", &return_url)],
777 )
778 .await
779 {
780 Ok(session) => {
781 let portal_url = session["url"].as_str().unwrap_or("");
782 (
783 StatusCode::OK,
784 Json(serde_json::json!({
785 "portal_url": portal_url,
786 })),
787 )
788 .into_response()
789 }
790 Err(e) => {
791 tracing::error!("Stripe portal failed: {}", e);
792 (
793 StatusCode::BAD_GATEWAY,
794 Json(serde_json::json!({"error": e})),
795 )
796 .into_response()
797 }
798 }
799 }
800 None => (
801 StatusCode::SERVICE_UNAVAILABLE,
802 Json(serde_json::json!({ "error": "Billing not configured" })),
803 )
804 .into_response(),
805 }
806}
807
808async fn handle_webhook(
810 State(state): State<Option<SharedBillingState>>,
811 headers: HeaderMap,
812 body: bytes::Bytes,
813) -> Response {
814 let sig_header = headers
815 .get("stripe-signature")
816 .and_then(|v| v.to_str().ok())
817 .map(|s| s.to_string());
818
819 let s = match state {
820 Some(s) => s,
821 None => {
822 return (
823 StatusCode::SERVICE_UNAVAILABLE,
824 Json(serde_json::json!({"error": "Billing not configured"})),
825 )
826 .into_response();
827 }
828 };
829
830 if !s.config.stripe_webhook_secret.is_empty() {
832 let sig = sig_header.unwrap_or_default();
833 if !verify_stripe_signature(&body, &sig, &s.config.stripe_webhook_secret) {
834 return (
835 StatusCode::BAD_REQUEST,
836 Json(serde_json::json!({"error": "Invalid signature"})),
837 )
838 .into_response();
839 }
840 }
841
842 let event: serde_json::Value = match serde_json::from_slice(&body) {
844 Ok(v) => v,
845 Err(e) => {
846 tracing::error!("Webhook parse error: {}", e);
847 return (
848 StatusCode::BAD_REQUEST,
849 Json(serde_json::json!({"error": "Invalid JSON"})),
850 )
851 .into_response();
852 }
853 };
854
855 let event_type = event["type"].as_str().unwrap_or("");
856 tracing::info!("Stripe webhook: {}", event_type);
857
858 if let Some(ref logger) = s.audit_logger {
860 logger
861 .log(
862 AuditEntry::new(
863 "stripe",
864 AuditAction::WebhookReceived,
865 "/api/v1/billing/webhook",
866 )
867 .with_detail(event_type.to_string()),
868 )
869 .await;
870 }
871
872 #[cfg(feature = "saas")]
873 if let Some(ref pool) = s.db_pool {
874 match event_type {
875 "checkout.session.completed" => {
876 let obj = &event["data"]["object"];
877 let customer = obj["customer"].as_str().unwrap_or("");
878 let client_ref = obj["client_reference_id"].as_str().unwrap_or("");
879
880 if !client_ref.is_empty() && !customer.is_empty() {
881 if let Ok(org_id) = client_ref.parse::<uuid::Uuid>() {
882 if let Err(e) =
883 varpulis_db::repo::update_org_stripe_customer(pool, org_id, customer)
884 .await
885 {
886 tracing::error!("Failed to save Stripe customer: {}", e);
887 }
888
889 let subscription_id = obj["subscription"].as_str().unwrap_or("");
891 let new_tier = if !subscription_id.is_empty() {
892 if let Ok(sub) = stripe_post(
894 &s.http_client,
895 &s.config.stripe_secret_key,
896 &format!("subscriptions/{subscription_id}"),
897 &[],
898 )
899 .await
900 {
901 let price_id = sub["items"]["data"][0]["price"]["id"]
902 .as_str()
903 .unwrap_or("");
904 s.config.tier_for_price_id(price_id).unwrap_or(Tier::Pro)
905 } else {
906 Tier::Pro
907 }
908 } else {
909 Tier::Pro
910 };
911
912 let tier_str = new_tier.to_string();
913 if let Err(e) =
914 varpulis_db::repo::update_org_tier(pool, org_id, &tier_str).await
915 {
916 tracing::error!("Failed to update tier: {}", e);
917 }
918 if let Err(e) =
920 varpulis_db::repo::update_org_status(pool, org_id, "active").await
921 {
922 tracing::error!("Failed to update org status: {}", e);
923 }
924 tracing::info!(
925 "Org {} upgraded to {} (customer: {})",
926 org_id,
927 tier_str,
928 customer
929 );
930 if let Some(ref logger) = s.audit_logger {
932 logger
933 .log(
934 AuditEntry::new(
935 org_id.to_string(),
936 AuditAction::TierChange,
937 "/api/v1/billing/webhook",
938 )
939 .with_detail(format!("upgraded to {tier_str}")),
940 )
941 .await;
942 }
943 }
944 }
945 }
946 "customer.subscription.deleted" => {
947 let customer = event["data"]["object"]["customer"].as_str().unwrap_or("");
948 if !customer.is_empty() {
949 if let Ok(Some(org)) =
950 varpulis_db::repo::get_org_by_stripe_customer(pool, customer).await
951 {
952 if let Err(e) =
953 varpulis_db::repo::update_org_tier(pool, org.id, "free").await
954 {
955 tracing::error!("Failed to downgrade org: {}", e);
956 }
957 tracing::info!("Org {} downgraded to free", org.id);
958 if let Some(ref logger) = s.audit_logger {
960 logger
961 .log(
962 AuditEntry::new(
963 org.id.to_string(),
964 AuditAction::TierChange,
965 "/api/v1/billing/webhook",
966 )
967 .with_detail("downgraded to free"),
968 )
969 .await;
970 }
971 }
972 }
973 }
974 "customer.subscription.updated" => {
975 let obj = &event["data"]["object"];
976 let customer = obj["customer"].as_str().unwrap_or("");
977 let price_id = obj["items"]["data"][0]["price"]["id"]
978 .as_str()
979 .unwrap_or("");
980
981 if !customer.is_empty() && !price_id.is_empty() {
982 if let Some(new_tier) = s.config.tier_for_price_id(price_id) {
983 if let Ok(Some(org)) =
984 varpulis_db::repo::get_org_by_stripe_customer(pool, customer).await
985 {
986 let tier_str = new_tier.to_string();
987 if org.tier != tier_str {
988 if let Err(e) =
989 varpulis_db::repo::update_org_tier(pool, org.id, &tier_str)
990 .await
991 {
992 tracing::error!(
993 "Failed to update tier on subscription change: {}",
994 e
995 );
996 } else {
997 tracing::info!(
998 "Org {} tier changed to {} via subscription update",
999 org.id,
1000 tier_str
1001 );
1002 }
1003 }
1004 }
1005 }
1006 }
1007 }
1008 "invoice.payment_failed" => {
1009 let customer = event["data"]["object"]["customer"].as_str().unwrap_or("");
1010 tracing::warn!("Payment failed for customer {}", customer);
1011 }
1012 _ => {
1013 tracing::debug!("Unhandled webhook event: {}", event_type);
1014 }
1015 }
1016 }
1017
1018 #[cfg(not(feature = "saas"))]
1019 {
1020 let _ = event_type;
1021 tracing::debug!("Webhook received but saas feature not enabled");
1022 }
1023
1024 (StatusCode::OK, Json(serde_json::json!({"received": true}))).into_response()
1025}
1026
1027#[cfg_attr(not(feature = "saas"), allow(dead_code))]
1033fn extract_org_id_from_header(
1034 auth_header: &Option<String>,
1035 state: &BillingState,
1036) -> Option<uuid::Uuid> {
1037 extract_org_id_str_from_header(auth_header, state)?
1038 .parse()
1039 .ok()
1040}
1041
/// Returns the non-empty string `org_id` claim from a `Bearer` JWT in `auth_header`.
///
/// Returns `None` when any of these hold:
/// * the header is absent or lacks the `Bearer ` prefix;
/// * the token is empty or cannot be decoded;
/// * the claim is missing, not a string, or empty.
///
/// `_state` is currently unused.
///
/// SECURITY / NOTE(review): the token is decoded with
/// `jsonwebtoken::dangerous::insecure_decode`, so its signature is NOT verified.
/// Presumably no claim validation (e.g. `exp`) happens either — confirm. Any
/// client can therefore present a token carrying an arbitrary `org_id`. The
/// handlers in this file use the result to choose which organization's usage,
/// plan, Stripe customer and billing portal to act on. That is only safe if an
/// upstream layer has already validated this same bearer token. `billing_routes`
/// adds no such layer itself, so confirm where (if anywhere) validation happens.
fn extract_org_id_str_from_header(
    auth_header: &Option<String>,
    _state: &BillingState,
) -> Option<String> {
    let header = auth_header.as_ref()?;
    let token = header.strip_prefix("Bearer ")?.trim();
    if token.is_empty() {
        return None;
    }

    // Claims are read WITHOUT signature verification (see the security note above).
    let token_data = jsonwebtoken::dangerous::insecure_decode::<serde_json::Value>(token).ok()?;

    // Only a non-empty string claim is accepted.
    let org_id = token_data.claims["org_id"].as_str()?;
    if org_id.is_empty() {
        return None;
    }
    Some(org_id.to_string())
}
1063
1064pub fn billing_routes(state: Option<SharedBillingState>) -> Router {
1070 Router::new()
1071 .route("/api/v1/billing/usage", get(handle_usage))
1072 .route("/api/v1/billing/plan", get(handle_plan))
1073 .route("/api/v1/billing/checkout", post(handle_checkout))
1074 .route("/api/v1/billing/portal", post(handle_portal))
1075 .route("/api/v1/billing/webhook", post(handle_webhook))
1076 .with_state(state)
1077}
1078
#[cfg(test)]
mod tests {
    use axum::body::Body;
    use axum::http::Request;
    use tower::ServiceExt;

    use super::*;

    /// Empty-bodied GET request for `uri`.
    fn get_req(uri: &str) -> Request<Body> {
        Request::get(uri).body(Body::empty()).unwrap()
    }

    /// Config with placeholder Stripe values; only the webhook secret varies between tests.
    fn test_config(webhook_secret: &str) -> BillingConfig {
        BillingConfig {
            stripe_secret_key: "sk_test_xxx".to_string(),
            stripe_webhook_secret: webhook_secret.to_string(),
            pro_price_id: "price_xxx".to_string(),
            business_price_id: "price_biz_xxx".to_string(),
            frontend_url: "http://localhost:5173".to_string(),
        }
    }

    /// Billing router backed by a fresh, configured `BillingState`.
    fn configured_app(webhook_secret: &str) -> Router {
        let state = Arc::new(BillingState::new(test_config(webhook_secret)));
        billing_routes(Some(state))
    }

    /// Hex HMAC-SHA256 of `"{timestamp}.{payload}"`, the way Stripe signs webhooks.
    fn stripe_sig(secret: &str, timestamp: &str, payload: &[u8]) -> String {
        use hmac::{Hmac, Mac};
        use sha2::Sha256;

        let signed = format!("{}.{}", timestamp, std::str::from_utf8(payload).unwrap());
        let mut mac = Hmac::<Sha256>::new_from_slice(secret.as_bytes()).unwrap();
        hmac::Mac::update(&mut mac, signed.as_bytes());
        hex::encode(mac.finalize().into_bytes())
    }

    #[test]
    fn test_tier_event_limits() {
        let cases: [(Tier, Option<i64>); 4] = [
            (Tier::Free, Some(100_000)),
            (Tier::Pro, Some(10_000_000)),
            (Tier::Business, Some(100_000_000)),
            (Tier::Enterprise, None),
        ];
        for (tier, limit) in cases {
            assert_eq!(tier.event_limit(), limit);
        }
    }

    #[test]
    fn test_tier_display_name() {
        let cases: [(Tier, &str); 4] = [
            (Tier::Free, "Free"),
            (Tier::Pro, "Pro ($49/mo)"),
            (Tier::Business, "Business ($199/mo)"),
            (Tier::Enterprise, "Enterprise"),
        ];
        for (tier, name) in cases {
            assert_eq!(tier.display_name(), name);
        }
    }

    #[test]
    fn test_tier_from_str() {
        let cases: [(&str, Tier); 4] = [
            ("free", Tier::Free),
            ("pro", Tier::Pro),
            ("business", Tier::Business),
            ("enterprise", Tier::Enterprise),
        ];
        for (text, tier) in cases {
            assert_eq!(text.parse::<Tier>(), Ok(tier));
        }
        assert!("invalid".parse::<Tier>().is_err());
    }

    #[test]
    fn test_tier_serialization() {
        let json = serde_json::to_string(&Tier::Business).unwrap();
        assert_eq!(json, "\"business\"");
        let round_trip: Tier = serde_json::from_str(&json).unwrap();
        assert_eq!(round_trip, Tier::Business);
    }

    #[test]
    fn test_usage_tracker_record_and_drain() {
        let mut tracker = UsageTracker::new();
        let org = Uuid::new_v4();

        for batch in [100, 50] {
            tracker.record_events(org, batch);
        }
        assert_eq!(tracker.get(&org), 150);

        assert_eq!(tracker.drain(), vec![(org, 150)]);
        assert_eq!(tracker.get(&org), 0);
    }

    #[test]
    fn test_usage_tracker_multiple_orgs() {
        let mut tracker = UsageTracker::new();
        let (org1, org2) = (Uuid::new_v4(), Uuid::new_v4());

        for (org, batch) in [(org1, 100), (org2, 200), (org1, 50)] {
            tracker.record_events(org, batch);
        }

        assert_eq!(tracker.get(&org1), 150);
        assert_eq!(tracker.get(&org2), 200);
    }

    #[test]
    fn test_verify_stripe_signature() {
        let secret = "whsec_test123";
        let payload = b"{\"type\":\"test\"}";
        let timestamp = "1234567890";
        let header = format!("t={timestamp},v1={}", stripe_sig(secret, timestamp, payload));

        assert!(verify_stripe_signature(payload, &header, secret));
        assert!(!verify_stripe_signature(payload, &header, "wrong_secret"));
        assert!(!verify_stripe_signature(b"tampered", &header, secret));
    }

    #[tokio::test]
    async fn test_billing_routes_not_configured() {
        let res = billing_routes(None)
            .oneshot(get_req("/api/v1/billing/plan"))
            .await
            .unwrap();
        assert_eq!(res.status(), 503);
    }

    #[tokio::test]
    async fn test_billing_routes_usage() {
        let res = configured_app("whsec_xxx")
            .oneshot(get_req("/api/v1/billing/usage"))
            .await
            .unwrap();
        assert_eq!(res.status(), 200);
    }

    #[tokio::test]
    async fn test_billing_routes_plan() {
        let res = configured_app("whsec_xxx")
            .oneshot(get_req("/api/v1/billing/plan"))
            .await
            .unwrap();
        assert_eq!(res.status(), 200);

        let bytes = axum::body::to_bytes(res.into_body(), usize::MAX)
            .await
            .unwrap();
        let json: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
        assert_eq!(json["tier"], "free");
    }

    #[tokio::test]
    async fn test_webhook_invalid_signature() {
        let req = Request::post("/api/v1/billing/webhook")
            .header("stripe-signature", "t=123,v1=bad")
            .header("content-type", "application/octet-stream")
            .body(Body::from("{\"type\":\"test\"}"))
            .unwrap();
        let res = configured_app("whsec_real_secret")
            .oneshot(req)
            .await
            .unwrap();
        assert_eq!(res.status(), 400);
    }

    #[test]
    fn test_usage_limit_exceeded_serialization() {
        let err = UsageLimitExceeded {
            tier: Tier::Free,
            limit: 100_000,
            current_usage: 100_500,
            message: "Usage limit exceeded for Free tier".to_string(),
        };
        let json = serde_json::to_value(&err).unwrap();
        assert_eq!(json["tier"], "free");
        assert_eq!(json["limit"], 100_000);
        assert_eq!(json["current_usage"], 100_500);
    }

    #[test]
    fn test_usage_limit_response_status() {
        let err = UsageLimitExceeded {
            tier: Tier::Free,
            limit: 100_000,
            current_usage: 110_000,
            message: "Limit exceeded".to_string(),
        };
        assert_eq!(
            usage_limit_response(&err).status(),
            StatusCode::TOO_MANY_REQUESTS
        );
    }

    #[test]
    fn test_tier_enterprise_no_limit() {
        assert!(Tier::Enterprise.event_limit().is_none());
    }

    #[test]
    fn test_usage_tracker_tracks_independently() {
        let mut tracker = UsageTracker::new();
        let org = Uuid::new_v4();

        for batch in [5_000, 3_000, 2_001] {
            tracker.record_events(org, batch);
        }

        assert_eq!(tracker.get(&org), 10_001);
    }
}