1use crate::domain::Domain;
13use chrono::{DateTime, Utc};
14use serde::{Deserialize, Serialize};
15use std::collections::HashSet;
16use std::fmt;
17use std::sync::Arc;
18use thiserror::Error;
19
20#[derive(Debug, Error, Clone)]
22pub enum LicenseError {
23 #[error("Domain '{0}' is not licensed")]
25 DomainNotLicensed(Domain),
26
27 #[error("Feature '{0}' is not licensed")]
29 FeatureNotLicensed(String),
30
31 #[error("GPU-native kernels require Enterprise license")]
33 GpuNativeNotLicensed,
34
35 #[error("License expired at {0}")]
37 Expired(DateTime<Utc>),
38
39 #[error("Maximum kernel count ({0}) exceeded")]
41 KernelLimitExceeded(usize),
42
43 #[error("License validation failed: {0}")]
45 ValidationFailed(String),
46
47 #[error("Invalid license key")]
49 InvalidKey,
50
51 #[error("No valid license found")]
53 NotFound,
54}
55
56pub type LicenseResult<T> = std::result::Result<T, LicenseError>;
58
59pub trait LicenseValidator: Send + Sync + fmt::Debug {
63 fn validate_domain(&self, domain: Domain) -> LicenseResult<()>;
65
66 fn validate_feature(&self, feature: &str) -> LicenseResult<()>;
68
69 fn gpu_native_enabled(&self) -> bool;
71
72 fn licensed_domains(&self) -> &[Domain];
74
75 fn expires_at(&self) -> Option<DateTime<Utc>>;
77
78 fn is_valid(&self) -> bool {
80 if let Some(expiry) = self.expires_at() {
81 Utc::now() < expiry
82 } else {
83 true }
85 }
86
87 fn tier(&self) -> LicenseTier;
89
90 fn max_kernels(&self) -> Option<usize>;
92}
93
94#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
96pub enum LicenseTier {
97 Development,
99 Community,
101 Professional,
103 Enterprise,
105}
106
107impl LicenseTier {
108 #[must_use]
110 pub const fn supports_gpu_native(&self) -> bool {
111 matches!(self, LicenseTier::Development | LicenseTier::Enterprise)
112 }
113
114 #[must_use]
116 pub const fn default_max_kernels(&self) -> Option<usize> {
117 match self {
118 LicenseTier::Development => None, LicenseTier::Community => Some(5),
120 LicenseTier::Professional => Some(50),
121 LicenseTier::Enterprise => None, }
123 }
124}
125
126#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
128pub struct LicenseId(pub String);
129
130impl LicenseId {
131 #[must_use]
133 pub fn new(id: impl Into<String>) -> Self {
134 Self(id.into())
135 }
136}
137
138impl fmt::Display for LicenseId {
139 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
140 write!(f, "{}", self.0)
141 }
142}
143
144#[derive(Debug, Clone, Serialize, Deserialize)]
146pub struct License {
147 pub id: LicenseId,
149
150 pub tier: LicenseTier,
152
153 pub domains: HashSet<Domain>,
155
156 pub features: HashSet<String>,
158
159 pub gpu_native: bool,
161
162 pub expires_at: Option<DateTime<Utc>>,
164
165 pub max_kernels: Option<usize>,
167
168 pub holder: String,
170}
171
172impl License {
173 #[must_use]
175 pub fn development() -> Self {
176 Self {
177 id: LicenseId::new("dev-license"),
178 tier: LicenseTier::Development,
179 domains: Domain::ALL.iter().copied().collect(),
180 features: HashSet::new(), gpu_native: true,
182 expires_at: None,
183 max_kernels: None,
184 holder: "Development".to_string(),
185 }
186 }
187
188 #[must_use]
190 pub fn enterprise(holder: impl Into<String>, expires_at: Option<DateTime<Utc>>) -> Self {
191 Self {
192 id: LicenseId::new(format!("enterprise-{}", chrono::Utc::now().timestamp())),
193 tier: LicenseTier::Enterprise,
194 domains: Domain::ALL.iter().copied().collect(),
195 features: HashSet::new(), gpu_native: true,
197 expires_at,
198 max_kernels: None,
199 holder: holder.into(),
200 }
201 }
202
203 #[must_use]
205 pub fn professional(
206 holder: impl Into<String>,
207 domains: HashSet<Domain>,
208 expires_at: Option<DateTime<Utc>>,
209 ) -> Self {
210 Self {
211 id: LicenseId::new(format!("professional-{}", chrono::Utc::now().timestamp())),
212 tier: LicenseTier::Professional,
213 domains,
214 features: HashSet::new(),
215 gpu_native: false, expires_at,
217 max_kernels: Some(50),
218 holder: holder.into(),
219 }
220 }
221
222 #[must_use]
224 pub fn community(holder: impl Into<String>) -> Self {
225 let mut domains = HashSet::new();
226 domains.insert(Domain::Core);
227 domains.insert(Domain::GraphAnalytics);
228 domains.insert(Domain::StatisticalML);
229
230 Self {
231 id: LicenseId::new(format!("community-{}", chrono::Utc::now().timestamp())),
232 tier: LicenseTier::Community,
233 domains,
234 features: HashSet::new(),
235 gpu_native: false,
236 expires_at: None, max_kernels: Some(5),
238 holder: holder.into(),
239 }
240 }
241
242 #[must_use]
244 pub fn with_domain(mut self, domain: Domain) -> Self {
245 self.domains.insert(domain);
246 self
247 }
248
249 #[must_use]
251 pub fn with_feature(mut self, feature: impl Into<String>) -> Self {
252 self.features.insert(feature.into());
253 self
254 }
255}
256
257#[derive(Debug)]
259pub struct StandardLicenseValidator {
260 license: License,
261}
262
263impl StandardLicenseValidator {
264 #[must_use]
266 pub fn new(license: License) -> Self {
267 Self { license }
268 }
269
270 #[must_use]
272 pub fn license(&self) -> &License {
273 &self.license
274 }
275}
276
277impl LicenseValidator for StandardLicenseValidator {
278 fn validate_domain(&self, domain: Domain) -> LicenseResult<()> {
279 if !self.is_valid() {
281 return Err(LicenseError::Expired(
282 self.license.expires_at.unwrap_or_else(Utc::now),
283 ));
284 }
285
286 if self.license.tier == LicenseTier::Development {
288 return Ok(());
289 }
290
291 if self.license.tier == LicenseTier::Enterprise {
293 return Ok(());
294 }
295
296 if self.license.domains.contains(&domain) {
298 Ok(())
299 } else {
300 Err(LicenseError::DomainNotLicensed(domain))
301 }
302 }
303
304 fn validate_feature(&self, feature: &str) -> LicenseResult<()> {
305 if !self.is_valid() {
307 return Err(LicenseError::Expired(
308 self.license.expires_at.unwrap_or_else(Utc::now),
309 ));
310 }
311
312 if matches!(
314 self.license.tier,
315 LicenseTier::Development | LicenseTier::Enterprise
316 ) {
317 return Ok(());
318 }
319
320 if self.license.features.contains(feature) {
322 return Ok(());
323 }
324
325 if let Some((domain_str, _)) = feature.split_once('.') {
328 if let Some(domain) = Domain::parse(domain_str) {
329 if self.license.domains.contains(&domain) {
330 return Ok(());
331 }
332 }
333 }
334
335 Err(LicenseError::FeatureNotLicensed(feature.to_string()))
336 }
337
338 fn gpu_native_enabled(&self) -> bool {
339 self.license.gpu_native && self.license.tier.supports_gpu_native()
340 }
341
342 fn licensed_domains(&self) -> &[Domain] {
343 if matches!(
345 self.license.tier,
346 LicenseTier::Development | LicenseTier::Enterprise
347 ) {
348 Domain::ALL
349 } else {
350 Domain::ALL }
354 }
355
356 fn expires_at(&self) -> Option<DateTime<Utc>> {
357 self.license.expires_at
358 }
359
360 fn tier(&self) -> LicenseTier {
361 self.license.tier
362 }
363
364 fn max_kernels(&self) -> Option<usize> {
365 self.license.max_kernels
366 }
367}
368
369#[derive(Debug, Default, Clone)]
373pub struct DevelopmentLicense;
374
375impl LicenseValidator for DevelopmentLicense {
376 fn validate_domain(&self, _domain: Domain) -> LicenseResult<()> {
377 Ok(()) }
379
380 fn validate_feature(&self, _feature: &str) -> LicenseResult<()> {
381 Ok(()) }
383
384 fn gpu_native_enabled(&self) -> bool {
385 true
386 }
387
388 fn licensed_domains(&self) -> &[Domain] {
389 Domain::ALL
390 }
391
392 fn expires_at(&self) -> Option<DateTime<Utc>> {
393 None }
395
396 fn tier(&self) -> LicenseTier {
397 LicenseTier::Development
398 }
399
400 fn max_kernels(&self) -> Option<usize> {
401 None }
403}
404
405#[derive(Debug)]
409pub struct LicenseGuard<'a> {
410 validator: &'a dyn LicenseValidator,
411 domain: Domain,
412}
413
414impl<'a> LicenseGuard<'a> {
415 #[must_use]
417 pub fn new(validator: &'a dyn LicenseValidator, domain: Domain) -> Self {
418 Self { validator, domain }
419 }
420
421 pub fn check(&self) -> LicenseResult<()> {
423 self.validator.validate_domain(self.domain)
424 }
425
426 pub fn check_feature(&self, feature: &str) -> LicenseResult<()> {
428 self.validator.validate_feature(feature)
429 }
430
431 pub fn check_gpu_native(&self) -> LicenseResult<()> {
433 if self.validator.gpu_native_enabled() {
434 Ok(())
435 } else {
436 Err(LicenseError::GpuNativeNotLicensed)
437 }
438 }
439}
440
441pub type SharedLicenseValidator = Arc<dyn LicenseValidator>;
443
444#[must_use]
446pub fn dev_license() -> SharedLicenseValidator {
447 Arc::new(DevelopmentLicense)
448}
449
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_development_license() {
        let dev = DevelopmentLicense;

        for &domain in &[Domain::GraphAnalytics, Domain::RiskAnalytics] {
            assert!(dev.validate_domain(domain).is_ok(), "{:?} rejected", domain);
        }
        assert!(dev.validate_feature("GraphAnalytics.PageRank").is_ok());
        assert!(dev.gpu_native_enabled());
        assert!(dev.is_valid());
        assert_eq!(dev.tier(), LicenseTier::Development);
    }

    #[test]
    fn test_community_license() {
        let validator = StandardLicenseValidator::new(License::community("Test User"));

        // Domains bundled with the community tier are accepted.
        for &domain in &[Domain::Core, Domain::GraphAnalytics, Domain::StatisticalML] {
            assert!(validator.validate_domain(domain).is_ok(), "{:?} rejected", domain);
        }

        // Anything outside the bundle is rejected.
        assert!(validator.validate_domain(Domain::RiskAnalytics).is_err());

        assert!(!validator.gpu_native_enabled());
        assert_eq!(validator.max_kernels(), Some(5));
    }

    #[test]
    fn test_enterprise_license() {
        let validator =
            StandardLicenseValidator::new(License::enterprise("Enterprise User", None));

        for &domain in &[Domain::GraphAnalytics, Domain::RiskAnalytics, Domain::Banking] {
            assert!(validator.validate_domain(domain).is_ok(), "{:?} rejected", domain);
        }

        assert!(validator.gpu_native_enabled());
        assert_eq!(validator.max_kernels(), None);
    }

    #[test]
    fn test_expired_license() {
        let yesterday = Utc::now() - chrono::Duration::days(1);
        let license = License {
            expires_at: Some(yesterday),
            ..License::enterprise("Expired User", None)
        };
        let validator = StandardLicenseValidator::new(license);

        assert!(!validator.is_valid());
        assert!(validator.validate_domain(Domain::Core).is_err());
    }

    #[test]
    fn test_license_guard() {
        let dev = DevelopmentLicense;
        let guard = LicenseGuard::new(&dev, Domain::GraphAnalytics);

        assert!(guard.check().is_ok());
        assert!(guard.check_feature("GraphAnalytics.PageRank").is_ok());
        assert!(guard.check_gpu_native().is_ok());
    }

    #[test]
    fn test_license_tier_properties() {
        let expectations = [
            (LicenseTier::Development, true),
            (LicenseTier::Enterprise, true),
            (LicenseTier::Professional, false),
            (LicenseTier::Community, false),
        ];
        for &(tier, gpu_native) in &expectations {
            assert_eq!(tier.supports_gpu_native(), gpu_native, "{:?}", tier);
        }
    }
}