1use axum::extract::{Json, State};
7use axum::http::{HeaderMap, StatusCode};
8use axum::response::{IntoResponse, Response};
9use axum::routing::{get, post};
10use axum::Router;
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::sync::Arc;
14use tokio::sync::RwLock;
15use uuid::Uuid;
16
17use crate::audit::{AuditAction, AuditEntry, SharedAuditLogger};
18
19#[cfg(feature = "saas")]
20use chrono::Datelike;
21
/// Stripe and frontend settings used by the billing endpoints.
///
/// `Debug` is implemented by hand so the Stripe credentials are never written to
/// logs: secret fields only report whether they are set.
#[derive(Clone)]
pub struct BillingConfig {
    /// Stripe secret API key, sent as the basic-auth user on Stripe API calls.
    pub stripe_secret_key: String,
    /// Signing secret used to verify incoming Stripe webhook signatures.
    pub stripe_webhook_secret: String,
    /// Stripe Price ID used for Pro checkout sessions; empty means "not configured".
    pub pro_price_id: String,
    /// Base URL of the web frontend, used to build redirect / return URLs.
    pub frontend_url: String,
}

impl std::fmt::Debug for BillingConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Report presence only; the secret values themselves must never be printed.
        fn redact(secret: &str) -> &'static str {
            if secret.is_empty() {
                "<unset>"
            } else {
                "<redacted>"
            }
        }

        f.debug_struct("BillingConfig")
            .field("stripe_secret_key", &redact(&self.stripe_secret_key))
            .field("stripe_webhook_secret", &redact(&self.stripe_webhook_secret))
            .field("pro_price_id", &self.pro_price_id)
            .field("frontend_url", &self.frontend_url)
            .finish()
    }
}
34
35impl BillingConfig {
36 pub fn from_env() -> Option<Self> {
39 let secret_key = std::env::var("STRIPE_SECRET_KEY").ok()?;
40 let webhook_secret =
41 std::env::var("STRIPE_WEBHOOK_SECRET").unwrap_or_else(|_| String::new());
42 let pro_price_id = std::env::var("STRIPE_PRO_PRICE_ID").unwrap_or_else(|_| String::new());
43 let frontend_url =
44 std::env::var("FRONTEND_URL").unwrap_or_else(|_| "http://localhost:5173".to_string());
45
46 Some(Self {
47 stripe_secret_key: secret_key,
48 stripe_webhook_secret: webhook_secret,
49 pro_price_id,
50 frontend_url,
51 })
52 }
53}
54
55#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
60#[serde(rename_all = "lowercase")]
61pub enum Tier {
62 Free,
63 Pro,
64 Enterprise,
65}
66
67impl Tier {
68 pub const fn event_limit(&self) -> Option<i64> {
69 match self {
70 Self::Free => Some(10_000),
71 Self::Pro => Some(10_000_000),
72 Self::Enterprise => None,
73 }
74 }
75
76 pub const fn display_name(&self) -> &str {
77 match self {
78 Self::Free => "Free",
79 Self::Pro => "Pro ($49/mo)",
80 Self::Enterprise => "Enterprise",
81 }
82 }
83}
84
85impl std::fmt::Display for Tier {
86 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87 match self {
88 Self::Free => write!(f, "free"),
89 Self::Pro => write!(f, "pro"),
90 Self::Enterprise => write!(f, "enterprise"),
91 }
92 }
93}
94
95impl std::str::FromStr for Tier {
96 type Err = String;
97
98 fn from_str(s: &str) -> Result<Self, Self::Err> {
99 match s {
100 "free" => Ok(Self::Free),
101 "pro" => Ok(Self::Pro),
102 "enterprise" => Ok(Self::Enterprise),
103 other => Err(format!("unknown tier: {other}")),
104 }
105 }
106}
107
108#[derive(Debug)]
114pub struct UsageTracker {
115 buffer: HashMap<Uuid, i64>,
116}
117
118impl Default for UsageTracker {
119 fn default() -> Self {
120 Self::new()
121 }
122}
123
124impl UsageTracker {
125 pub fn new() -> Self {
126 Self {
127 buffer: HashMap::new(),
128 }
129 }
130
131 pub fn record_events(&mut self, org_id: Uuid, count: i64) {
132 *self.buffer.entry(org_id).or_insert(0) += count;
133 }
134
135 pub fn drain(&mut self) -> Vec<(Uuid, i64)> {
137 self.buffer.drain().collect()
138 }
139
140 pub fn get(&self, org_id: &Uuid) -> i64 {
141 self.buffer.get(org_id).copied().unwrap_or(0)
142 }
143}
144
145#[derive(Debug)]
150pub struct BillingState {
151 pub config: BillingConfig,
152 pub usage: RwLock<UsageTracker>,
153 pub http_client: reqwest::Client,
154 #[cfg(feature = "saas")]
155 pub db_pool: Option<varpulis_db::PgPool>,
156 pub audit_logger: Option<SharedAuditLogger>,
157}
158
159impl BillingState {
160 pub fn new(config: BillingConfig) -> Self {
161 Self {
162 config,
163 usage: RwLock::new(UsageTracker::new()),
164 http_client: reqwest::Client::new(),
165 #[cfg(feature = "saas")]
166 db_pool: None,
167 audit_logger: None,
168 }
169 }
170
171 pub fn with_audit_logger(mut self, logger: Option<SharedAuditLogger>) -> Self {
172 self.audit_logger = logger;
173 self
174 }
175
176 #[cfg(feature = "saas")]
177 pub fn with_db_pool(mut self, pool: varpulis_db::PgPool) -> Self {
178 self.db_pool = Some(pool);
179 self
180 }
181}
182
183pub type SharedBillingState = Arc<BillingState>;
184
185#[derive(Debug, Serialize)]
191pub struct UsageLimitExceeded {
192 pub tier: Tier,
193 pub limit: i64,
194 pub current_usage: i64,
195 pub message: String,
196}
197
198impl BillingState {
199 #[cfg(feature = "saas")]
202 pub async fn check_usage_limit(
203 &self,
204 org_id: Uuid,
205 additional_events: i64,
206 ) -> Result<(), UsageLimitExceeded> {
207 let tier = if let Some(ref pool) = self.db_pool {
209 if let Ok(Some(org)) = varpulis_db::repo::get_organization(pool, org_id).await {
210 org.tier.parse().unwrap_or(Tier::Free)
211 } else {
212 Tier::Free
213 }
214 } else {
215 Tier::Free
216 };
217
218 let limit = match tier.event_limit() {
220 Some(l) => l,
221 None => return Ok(()),
222 };
223
224 let db_usage = if let Some(ref pool) = self.db_pool {
226 let today = chrono::Utc::now().date_naive();
227 let start =
228 chrono::NaiveDate::from_ymd_opt(today.year(), today.month(), 1).unwrap_or(today);
229 if let Ok(rows) = varpulis_db::repo::get_usage(pool, org_id, start, today).await {
230 rows.iter().map(|r| r.events_processed).sum::<i64>()
231 } else {
232 0
233 }
234 } else {
235 0
236 };
237
238 let buffered = self.usage.read().await.get(&org_id);
240 let total = db_usage + buffered + additional_events;
241
242 if total > limit {
243 Err(UsageLimitExceeded {
244 tier: tier.clone(),
245 limit,
246 current_usage: db_usage + buffered,
247 message: format!(
248 "Usage limit exceeded for {} tier ({}/{} events this month). Upgrade to increase your limit.",
249 tier.display_name(),
250 db_usage + buffered,
251 limit,
252 ),
253 })
254 } else {
255 Ok(())
256 }
257 }
258
259 #[cfg(feature = "saas")]
262 pub async fn org_id_for_api_key(&self, raw_key: &str) -> Option<Uuid> {
263 use sha2::Digest;
264 let pool = self.db_pool.as_ref()?;
265 let hash = hex::encode(sha2::Sha256::digest(raw_key.as_bytes()));
266 let api_key = varpulis_db::repo::get_api_key_by_hash(pool, &hash)
267 .await
268 .ok()??;
269 Some(api_key.org_id)
270 }
271}
272
273pub fn usage_limit_response(err: &UsageLimitExceeded) -> Response {
275 let body = Json(serde_json::json!({
276 "error": "usage_limit_exceeded",
277 "message": err.message,
278 "tier": err.tier,
279 "limit": err.limit,
280 "current_usage": err.current_usage,
281 "upgrade_url": "/billing",
282 }));
283
284 (
285 StatusCode::TOO_MANY_REQUESTS,
286 [("Retry-After", "3600")],
287 body,
288 )
289 .into_response()
290}
291
/// Spawns a background task that persists buffered usage to the database every 60 seconds.
///
/// On each tick the in-memory counts are drained and written per organization as
/// today's usage. If a write fails, that organization's count is put back into the
/// buffer so it is retried on the next tick rather than silently lost.
#[cfg(feature = "saas")]
pub fn spawn_usage_flush(state: SharedBillingState, pool: varpulis_db::PgPool) {
    tokio::spawn(async move {
        let mut interval = tokio::time::interval(std::time::Duration::from_secs(60));
        loop {
            interval.tick().await;
            // Take the counts and release the write lock before doing any DB I/O.
            let entries = state.usage.write().await.drain();
            if entries.is_empty() {
                continue;
            }
            let today = chrono::Utc::now().date_naive();
            let mut requeued = 0usize;
            for (org_id, count) in entries {
                if let Err(e) =
                    varpulis_db::repo::record_usage(&pool, org_id, today, count, 0).await
                {
                    tracing::error!("Failed to flush usage for org {}: {}", org_id, e);
                    // Re-buffer the count; dropping it would under-bill the organization.
                    state.usage.write().await.record_events(org_id, count);
                    requeued += 1;
                }
            }
            if requeued == 0 {
                tracing::debug!("Usage flush complete");
            } else {
                tracing::warn!(
                    "Usage flush finished with {} org(s) re-queued after errors",
                    requeued
                );
            }
        }
    });
}
319
320async fn stripe_post(
326 client: &reqwest::Client,
327 secret_key: &str,
328 endpoint: &str,
329 params: &[(&str, &str)],
330) -> Result<serde_json::Value, String> {
331 let resp = client
332 .post(format!("https://api.stripe.com/v1/{endpoint}"))
333 .basic_auth(secret_key, None::<&str>)
334 .form(params)
335 .send()
336 .await
337 .map_err(|e| format!("Stripe request failed: {e}"))?;
338
339 let status = resp.status();
340 let body: serde_json::Value = resp
341 .json()
342 .await
343 .map_err(|e| format!("Stripe response parse failed: {e}"))?;
344
345 if !status.is_success() {
346 let msg = body["error"]["message"]
347 .as_str()
348 .unwrap_or("Unknown Stripe error");
349 return Err(format!("Stripe API error ({status}): {msg}"));
350 }
351
352 Ok(body)
353}
354
355fn verify_stripe_signature(payload: &[u8], sig_header: &str, secret: &str) -> bool {
357 use hmac::{Hmac, Mac};
358 use sha2::Sha256;
359
360 let mut timestamp = "";
362 let mut signature = "";
363 for part in sig_header.split(',') {
364 if let Some(t) = part.strip_prefix("t=") {
365 timestamp = t;
366 } else if let Some(s) = part.strip_prefix("v1=") {
367 signature = s;
368 }
369 }
370
371 if timestamp.is_empty() || signature.is_empty() {
372 return false;
373 }
374
375 let signed_payload = format!(
377 "{}.{}",
378 timestamp,
379 std::str::from_utf8(payload).unwrap_or("")
380 );
381 let mut mac =
382 Hmac::<Sha256>::new_from_slice(secret.as_bytes()).expect("HMAC accepts any key size");
383 hmac::Mac::update(&mut mac, signed_payload.as_bytes());
384 let expected = hex::encode(mac.finalize().into_bytes());
385
386 expected == signature
388}
389
390async fn handle_usage(
396 State(state): State<Option<SharedBillingState>>,
397 headers: HeaderMap,
398) -> Response {
399 let auth_header = headers
400 .get("authorization")
401 .and_then(|v| v.to_str().ok())
402 .map(|s| s.to_string());
403
404 match state {
405 Some(s) => {
406 #[cfg(feature = "saas")]
408 if let Some(ref pool) = s.db_pool {
409 if let Some(org_id) = extract_org_id_from_header(&auth_header, &s) {
410 let today = chrono::Utc::now().date_naive();
411 let start = chrono::NaiveDate::from_ymd_opt(today.year(), today.month(), 1)
412 .unwrap_or(today);
413 if let Ok(rows) = varpulis_db::repo::get_usage(pool, org_id, start, today).await
414 {
415 let total: i64 = rows.iter().map(|r| r.events_processed).sum();
416 return (
417 StatusCode::OK,
418 Json(serde_json::json!({
419 "events_this_month": total,
420 "daily": rows.iter().map(|r| serde_json::json!({
421 "date": r.date.to_string(),
422 "events_processed": r.events_processed,
423 })).collect::<Vec<_>>(),
424 })),
425 )
426 .into_response();
427 }
428 }
429 }
430
431 let _ = auth_header;
433 let tracker = s.usage.read().await;
434 let orgs: Vec<serde_json::Value> = tracker
435 .buffer
436 .iter()
437 .map(|(org_id, count)| {
438 serde_json::json!({
439 "org_id": org_id.to_string(),
440 "events_today": count,
441 })
442 })
443 .collect();
444 (StatusCode::OK, Json(serde_json::json!({ "usage": orgs }))).into_response()
445 }
446 None => (
447 StatusCode::SERVICE_UNAVAILABLE,
448 Json(serde_json::json!({ "error": "Billing not configured" })),
449 )
450 .into_response(),
451 }
452}
453
454async fn handle_plan(
456 State(state): State<Option<SharedBillingState>>,
457 headers: HeaderMap,
458) -> Response {
459 let auth_header = headers
460 .get("authorization")
461 .and_then(|v| v.to_str().ok())
462 .map(|s| s.to_string());
463
464 match state {
465 Some(_s) => {
466 #[cfg(feature = "saas")]
468 if let Some(ref pool) = _s.db_pool {
469 if let Some(org_id) = extract_org_id_from_header(&auth_header, &_s) {
470 if let Ok(Some(org)) = varpulis_db::repo::get_organization(pool, org_id).await {
471 let tier: Tier = org.tier.parse().unwrap_or(Tier::Free);
472 return (
473 StatusCode::OK,
474 Json(serde_json::json!({
475 "tier": org.tier,
476 "event_limit": tier.event_limit(),
477 "display_name": tier.display_name(),
478 })),
479 )
480 .into_response();
481 }
482 }
483 }
484
485 let _ = auth_header;
487 (
488 StatusCode::OK,
489 Json(serde_json::json!({
490 "tier": "free",
491 "event_limit": 10_000,
492 "display_name": "Free",
493 })),
494 )
495 .into_response()
496 }
497 None => (
498 StatusCode::SERVICE_UNAVAILABLE,
499 Json(serde_json::json!({ "error": "Billing not configured" })),
500 )
501 .into_response(),
502 }
503}
504
505#[derive(Debug, Deserialize)]
506struct CheckoutRequest {
507 success_url: Option<String>,
508 cancel_url: Option<String>,
509}
510
511async fn handle_checkout(
513 State(state): State<Option<SharedBillingState>>,
514 headers: HeaderMap,
515 Json(body): Json<CheckoutRequest>,
516) -> Response {
517 let auth_header = headers
518 .get("authorization")
519 .and_then(|v| v.to_str().ok())
520 .map(|s| s.to_string());
521
522 match state {
523 Some(s) => {
524 if s.config.pro_price_id.is_empty() {
525 return (
526 StatusCode::BAD_REQUEST,
527 Json(serde_json::json!({
528 "error": "Stripe Price ID not configured"
529 })),
530 )
531 .into_response();
532 }
533
534 let success_url = body
535 .success_url
536 .unwrap_or_else(|| format!("{}/billing?success=true", s.config.frontend_url));
537 let cancel_url = body
538 .cancel_url
539 .unwrap_or_else(|| format!("{}/billing", s.config.frontend_url));
540
541 let org_id_str = extract_org_id_str_from_header(&auth_header, &s).unwrap_or_default();
543
544 let mut params: Vec<(&str, &str)> = vec![
545 ("mode", "subscription"),
546 ("line_items[0][price]", &s.config.pro_price_id),
547 ("line_items[0][quantity]", "1"),
548 ("success_url", &success_url),
549 ("cancel_url", &cancel_url),
550 ];
551
552 if !org_id_str.is_empty() {
553 params.push(("client_reference_id", &org_id_str));
554 }
555
556 #[allow(unused_mut)]
558 let mut customer_id = String::new();
559 #[cfg(feature = "saas")]
560 if let Some(ref pool) = s.db_pool {
561 if let Some(org_uuid) = extract_org_id_from_header(&auth_header, &s) {
562 if let Ok(Some(org)) = varpulis_db::repo::get_organization(pool, org_uuid).await
563 {
564 if let Some(cid) = org.stripe_customer_id {
565 customer_id = cid;
566 }
567 }
568 }
569 }
570
571 if !customer_id.is_empty() {
572 params.push(("customer", &customer_id));
573 }
574
575 match stripe_post(
576 &s.http_client,
577 &s.config.stripe_secret_key,
578 "checkout/sessions",
579 ¶ms,
580 )
581 .await
582 {
583 Ok(session) => {
584 let checkout_url = session["url"].as_str().unwrap_or("");
585 let session_id = session["id"].as_str().unwrap_or("");
586 if let Some(ref logger) = s.audit_logger {
588 logger
589 .log(
590 AuditEntry::new(
591 &org_id_str,
592 AuditAction::CheckoutStarted,
593 "/api/v1/billing/checkout",
594 )
595 .with_detail(format!("session: {session_id}")),
596 )
597 .await;
598 }
599 (
600 StatusCode::OK,
601 Json(serde_json::json!({
602 "checkout_url": checkout_url,
603 "session_id": session_id,
604 })),
605 )
606 .into_response()
607 }
608 Err(e) => {
609 tracing::error!("Stripe checkout failed: {}", e);
610 (
611 StatusCode::BAD_GATEWAY,
612 Json(serde_json::json!({"error": e})),
613 )
614 .into_response()
615 }
616 }
617 }
618 None => (
619 StatusCode::SERVICE_UNAVAILABLE,
620 Json(serde_json::json!({ "error": "Billing not configured" })),
621 )
622 .into_response(),
623 }
624}
625
626async fn handle_portal(
628 State(state): State<Option<SharedBillingState>>,
629 headers: HeaderMap,
630) -> Response {
631 let auth_header = headers
632 .get("authorization")
633 .and_then(|v| v.to_str().ok())
634 .map(|s| s.to_string());
635
636 match state {
637 Some(s) => {
638 #[allow(unused_mut)]
639 let mut customer_id = String::new();
640
641 #[cfg(feature = "saas")]
642 if let Some(ref pool) = s.db_pool {
643 if let Some(org_uuid) = extract_org_id_from_header(&auth_header, &s) {
644 if let Ok(Some(org)) = varpulis_db::repo::get_organization(pool, org_uuid).await
645 {
646 if let Some(cid) = org.stripe_customer_id {
647 customer_id = cid;
648 }
649 }
650 }
651 }
652 let _ = auth_header;
653
654 if customer_id.is_empty() {
655 return (
656 StatusCode::BAD_REQUEST,
657 Json(serde_json::json!({
658 "error": "No Stripe customer found. Upgrade first."
659 })),
660 )
661 .into_response();
662 }
663
664 let return_url = format!("{}/billing", s.config.frontend_url);
665 match stripe_post(
666 &s.http_client,
667 &s.config.stripe_secret_key,
668 "billing_portal/sessions",
669 &[("customer", &customer_id), ("return_url", &return_url)],
670 )
671 .await
672 {
673 Ok(session) => {
674 let portal_url = session["url"].as_str().unwrap_or("");
675 (
676 StatusCode::OK,
677 Json(serde_json::json!({
678 "portal_url": portal_url,
679 })),
680 )
681 .into_response()
682 }
683 Err(e) => {
684 tracing::error!("Stripe portal failed: {}", e);
685 (
686 StatusCode::BAD_GATEWAY,
687 Json(serde_json::json!({"error": e})),
688 )
689 .into_response()
690 }
691 }
692 }
693 None => (
694 StatusCode::SERVICE_UNAVAILABLE,
695 Json(serde_json::json!({ "error": "Billing not configured" })),
696 )
697 .into_response(),
698 }
699}
700
701async fn handle_webhook(
703 State(state): State<Option<SharedBillingState>>,
704 headers: HeaderMap,
705 body: bytes::Bytes,
706) -> Response {
707 let sig_header = headers
708 .get("stripe-signature")
709 .and_then(|v| v.to_str().ok())
710 .map(|s| s.to_string());
711
712 let s = match state {
713 Some(s) => s,
714 None => {
715 return (
716 StatusCode::SERVICE_UNAVAILABLE,
717 Json(serde_json::json!({"error": "Billing not configured"})),
718 )
719 .into_response();
720 }
721 };
722
723 if !s.config.stripe_webhook_secret.is_empty() {
725 let sig = sig_header.unwrap_or_default();
726 if !verify_stripe_signature(&body, &sig, &s.config.stripe_webhook_secret) {
727 return (
728 StatusCode::BAD_REQUEST,
729 Json(serde_json::json!({"error": "Invalid signature"})),
730 )
731 .into_response();
732 }
733 }
734
735 let event: serde_json::Value = match serde_json::from_slice(&body) {
737 Ok(v) => v,
738 Err(e) => {
739 tracing::error!("Webhook parse error: {}", e);
740 return (
741 StatusCode::BAD_REQUEST,
742 Json(serde_json::json!({"error": "Invalid JSON"})),
743 )
744 .into_response();
745 }
746 };
747
748 let event_type = event["type"].as_str().unwrap_or("");
749 tracing::info!("Stripe webhook: {}", event_type);
750
751 if let Some(ref logger) = s.audit_logger {
753 logger
754 .log(
755 AuditEntry::new(
756 "stripe",
757 AuditAction::WebhookReceived,
758 "/api/v1/billing/webhook",
759 )
760 .with_detail(event_type.to_string()),
761 )
762 .await;
763 }
764
765 #[cfg(feature = "saas")]
766 if let Some(ref pool) = s.db_pool {
767 match event_type {
768 "checkout.session.completed" => {
769 let obj = &event["data"]["object"];
770 let customer = obj["customer"].as_str().unwrap_or("");
771 let client_ref = obj["client_reference_id"].as_str().unwrap_or("");
772
773 if !client_ref.is_empty() && !customer.is_empty() {
774 if let Ok(org_id) = client_ref.parse::<uuid::Uuid>() {
775 if let Err(e) =
776 varpulis_db::repo::update_org_stripe_customer(pool, org_id, customer)
777 .await
778 {
779 tracing::error!("Failed to save Stripe customer: {}", e);
780 }
781 if let Err(e) =
782 varpulis_db::repo::update_org_tier(pool, org_id, "pro").await
783 {
784 tracing::error!("Failed to update tier: {}", e);
785 }
786 tracing::info!("Org {} upgraded to pro (customer: {})", org_id, customer);
787 if let Some(ref logger) = s.audit_logger {
789 logger
790 .log(
791 AuditEntry::new(
792 org_id.to_string(),
793 AuditAction::TierChange,
794 "/api/v1/billing/webhook",
795 )
796 .with_detail("upgraded to pro"),
797 )
798 .await;
799 }
800 }
801 }
802 }
803 "customer.subscription.deleted" => {
804 let customer = event["data"]["object"]["customer"].as_str().unwrap_or("");
805 if !customer.is_empty() {
806 if let Ok(Some(org)) =
807 varpulis_db::repo::get_org_by_stripe_customer(pool, customer).await
808 {
809 if let Err(e) =
810 varpulis_db::repo::update_org_tier(pool, org.id, "free").await
811 {
812 tracing::error!("Failed to downgrade org: {}", e);
813 }
814 tracing::info!("Org {} downgraded to free", org.id);
815 if let Some(ref logger) = s.audit_logger {
817 logger
818 .log(
819 AuditEntry::new(
820 org.id.to_string(),
821 AuditAction::TierChange,
822 "/api/v1/billing/webhook",
823 )
824 .with_detail("downgraded to free"),
825 )
826 .await;
827 }
828 }
829 }
830 }
831 "customer.subscription.updated" => {
832 tracing::info!("Subscription updated (no-op)");
834 }
835 "invoice.payment_failed" => {
836 let customer = event["data"]["object"]["customer"].as_str().unwrap_or("");
837 tracing::warn!("Payment failed for customer {}", customer);
838 }
839 _ => {
840 tracing::debug!("Unhandled webhook event: {}", event_type);
841 }
842 }
843 }
844
845 #[cfg(not(feature = "saas"))]
846 {
847 let _ = event_type;
848 tracing::debug!("Webhook received but saas feature not enabled");
849 }
850
851 (StatusCode::OK, Json(serde_json::json!({"received": true}))).into_response()
852}
853
854#[cfg_attr(not(feature = "saas"), allow(dead_code))]
860fn extract_org_id_from_header(
861 auth_header: &Option<String>,
862 state: &BillingState,
863) -> Option<uuid::Uuid> {
864 extract_org_id_str_from_header(auth_header, state)?
865 .parse()
866 .ok()
867}
868
869fn extract_org_id_str_from_header(
871 auth_header: &Option<String>,
872 _state: &BillingState,
873) -> Option<String> {
874 let header = auth_header.as_ref()?;
875 let token = header.strip_prefix("Bearer ")?.trim();
876 if token.is_empty() {
877 return None;
878 }
879
880 use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
883 let mut validation = Validation::new(Algorithm::HS256);
884 validation.insecure_disable_signature_validation();
885 validation.validate_exp = false;
886 let token_data =
887 decode::<serde_json::Value>(token, &DecodingKey::from_secret(b""), &validation).ok()?;
888
889 let org_id = token_data.claims["org_id"].as_str()?;
890 if org_id.is_empty() {
891 return None;
892 }
893 Some(org_id.to_string())
894}
895
896pub fn billing_routes(state: Option<SharedBillingState>) -> Router {
902 Router::new()
903 .route("/api/v1/billing/usage", get(handle_usage))
904 .route("/api/v1/billing/plan", get(handle_plan))
905 .route("/api/v1/billing/checkout", post(handle_checkout))
906 .route("/api/v1/billing/portal", post(handle_portal))
907 .route("/api/v1/billing/webhook", post(handle_webhook))
908 .with_state(state)
909}
910
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::Request;
    use tower::ServiceExt;

    /// Builds a bodyless GET request for `uri`.
    fn get_req(uri: &str) -> Request<Body> {
        Request::builder()
            .method("GET")
            .uri(uri)
            .body(Body::empty())
            .unwrap()
    }

    /// Placeholder Stripe configuration; only the webhook secret varies between tests.
    fn test_config(webhook_secret: &str) -> BillingConfig {
        BillingConfig {
            stripe_secret_key: "sk_test_xxx".to_string(),
            stripe_webhook_secret: webhook_secret.to_string(),
            pro_price_id: "price_xxx".to_string(),
            frontend_url: "http://localhost:5173".to_string(),
        }
    }

    /// Billing router over a fresh in-memory state (no database pool).
    fn configured_app(webhook_secret: &str) -> Router {
        let state = Arc::new(BillingState::new(test_config(webhook_secret)));
        billing_routes(Some(state))
    }

    #[test]
    fn test_tier_event_limits() {
        let cases: [(Tier, Option<i64>); 3] = [
            (Tier::Free, Some(10_000)),
            (Tier::Pro, Some(10_000_000)),
            (Tier::Enterprise, None),
        ];
        for (tier, limit) in cases {
            assert_eq!(tier.event_limit(), limit);
        }
    }

    #[test]
    fn test_tier_display_name() {
        let cases = [
            (Tier::Free, "Free"),
            (Tier::Pro, "Pro ($49/mo)"),
            (Tier::Enterprise, "Enterprise"),
        ];
        for (tier, name) in cases {
            assert_eq!(tier.display_name(), name);
        }
    }

    #[test]
    fn test_tier_from_str() {
        let cases = [
            ("free", Tier::Free),
            ("pro", Tier::Pro),
            ("enterprise", Tier::Enterprise),
        ];
        for (raw, tier) in cases {
            assert_eq!(raw.parse::<Tier>(), Ok(tier));
        }
        assert!("invalid".parse::<Tier>().is_err());
    }

    #[test]
    fn test_tier_serialization() {
        let encoded = serde_json::to_string(&Tier::Pro).unwrap();
        assert_eq!(encoded, "\"pro\"");
        let decoded: Tier = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded, Tier::Pro);
    }

    #[test]
    fn test_usage_tracker_record_and_drain() {
        let org = Uuid::new_v4();
        let mut tracker = UsageTracker::new();

        for batch in [100, 50] {
            tracker.record_events(org, batch);
        }
        assert_eq!(tracker.get(&org), 150);

        let drained = tracker.drain();
        assert_eq!(drained, vec![(org, 150)]);
        assert_eq!(tracker.get(&org), 0);
    }

    #[test]
    fn test_usage_tracker_multiple_orgs() {
        let (first, second) = (Uuid::new_v4(), Uuid::new_v4());
        let mut tracker = UsageTracker::new();

        for (org, batch) in [(first, 100), (second, 200), (first, 50)] {
            tracker.record_events(org, batch);
        }

        assert_eq!(tracker.get(&first), 150);
        assert_eq!(tracker.get(&second), 200);
    }

    #[test]
    fn test_verify_stripe_signature() {
        use hmac::{Hmac, Mac};
        use sha2::Sha256;

        let secret = "whsec_test123";
        let payload: &[u8] = b"{\"type\":\"test\"}";
        let timestamp = "1234567890";

        let mut mac = Hmac::<Sha256>::new_from_slice(secret.as_bytes()).unwrap();
        let signed = format!("{}.{}", timestamp, std::str::from_utf8(payload).unwrap());
        hmac::Mac::update(&mut mac, signed.as_bytes());
        let header = format!(
            "t={timestamp},v1={}",
            hex::encode(mac.finalize().into_bytes())
        );

        assert!(verify_stripe_signature(payload, &header, secret));
        assert!(!verify_stripe_signature(payload, &header, "wrong_secret"));
        assert!(!verify_stripe_signature(b"tampered", &header, secret));
    }

    #[tokio::test]
    async fn test_billing_routes_not_configured() {
        let res = billing_routes(None)
            .oneshot(get_req("/api/v1/billing/plan"))
            .await
            .unwrap();
        assert_eq!(res.status(), 503);
    }

    #[tokio::test]
    async fn test_billing_routes_usage() {
        let res = configured_app("whsec_xxx")
            .oneshot(get_req("/api/v1/billing/usage"))
            .await
            .unwrap();
        assert_eq!(res.status(), 200);
    }

    #[tokio::test]
    async fn test_billing_routes_plan() {
        let res = configured_app("whsec_xxx")
            .oneshot(get_req("/api/v1/billing/plan"))
            .await
            .unwrap();
        assert_eq!(res.status(), 200);

        let raw = axum::body::to_bytes(res.into_body(), usize::MAX)
            .await
            .unwrap();
        let json: serde_json::Value = serde_json::from_slice(&raw).unwrap();
        assert_eq!(json["tier"], "free");
    }

    #[tokio::test]
    async fn test_webhook_invalid_signature() {
        let req = Request::builder()
            .method("POST")
            .uri("/api/v1/billing/webhook")
            .header("stripe-signature", "t=123,v1=bad")
            .header("content-type", "application/octet-stream")
            .body(Body::from("{\"type\":\"test\"}"))
            .unwrap();
        let res = configured_app("whsec_real_secret")
            .oneshot(req)
            .await
            .unwrap();
        assert_eq!(res.status(), 400);
    }

    #[test]
    fn test_usage_limit_exceeded_serialization() {
        let json = serde_json::to_value(&UsageLimitExceeded {
            tier: Tier::Free,
            limit: 10_000,
            current_usage: 10_500,
            message: "Usage limit exceeded for Free tier".to_string(),
        })
        .unwrap();
        assert_eq!(json["tier"], "free");
        assert_eq!(json["limit"], 10_000);
        assert_eq!(json["current_usage"], 10_500);
    }

    #[test]
    fn test_usage_limit_response_status() {
        let resp = usage_limit_response(&UsageLimitExceeded {
            tier: Tier::Free,
            limit: 10_000,
            current_usage: 11_000,
            message: "Limit exceeded".to_string(),
        });
        assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
    }

    #[test]
    fn test_tier_enterprise_no_limit() {
        assert!(Tier::Enterprise.event_limit().is_none());
    }

    #[test]
    fn test_usage_tracker_tracks_independently() {
        let org = Uuid::new_v4();
        let mut tracker = UsageTracker::new();

        for batch in [5_000, 3_000, 2_001] {
            tracker.record_events(org, batch);
        }

        assert_eq!(tracker.get(&org), 10_001);
    }
}