1use std::sync::{Arc, LazyLock};
49
50use regex::Regex;
51use thiserror::Error;
52
53static API_KEY_PATTERN: LazyLock<Regex> = LazyLock::new(|| {
56 Regex::new(r"\b[a-zA-Z0-9_-]{3,}\.[a-zA-Z0-9_-]{10,}\b").expect("invalid regex")
57});
58
59static SENSITIVE_PATTERNS: LazyLock<Vec<(Regex, &'static str)>> = LazyLock::new(|| {
60 vec![
61 (
62 Regex::new(r"(?i)(api[_-]?key\s*[=:]\s*)[^\s,]+").expect("invalid regex"),
63 "$1[FILTERED]",
64 ),
65 (
66 Regex::new(r"(?i)(password\s*[=:]\s*)[^\s,]+").expect("invalid regex"),
67 "$1[FILTERED]",
68 ),
69 (
70 Regex::new(r"(?i)(token\s*[=:]\s*)[^\s,]+").expect("invalid regex"),
71 "$1[FILTERED]",
72 ),
73 (
74 Regex::new(r"(?i)(secret\s*[=:]\s*)[^\s,]+").expect("invalid regex"),
75 "$1[FILTERED]",
76 ),
77 (
78 Regex::new(r"(?i)(bearer\s+[a-zA-Z0-9_-]+\.[a-zA-Z0-9_-]+)").expect("invalid regex"),
79 "bearer [FILTERED]",
80 ),
81 (
82 Regex::new(r"(?i)(authorization\s*:\s*Bearer\s+)[^\s,]+").expect("invalid regex"),
83 "$1[FILTERED]",
84 ),
85 ]
86});
87
88static CONTAINS_SENSITIVE_PATTERNS: LazyLock<Vec<Regex>> = LazyLock::new(|| {
89 vec![
90 Regex::new(r"(?i)api[_-]?key\s*[=:]").expect("invalid regex"),
91 Regex::new(r"(?i)password\s*[=:]").expect("invalid regex"),
92 Regex::new(r"(?i)token\s*[=:]").expect("invalid regex"),
93 Regex::new(r"(?i)secret\s*[=:]").expect("invalid regex"),
94 Regex::new(r"(?i)authorization\s*:\s*Bearer").expect("invalid regex"),
95 ]
96});
97
98pub fn mask_sensitive_info(text: &str) -> String {
132 let mut result = API_KEY_PATTERN.replace_all(text, "[FILTERED]").to_string();
133
134 for (re, replacement) in SENSITIVE_PATTERNS.iter() {
135 result = re.replace_all(&result, *replacement).to_string();
136 }
137
138 result
139}
140
141pub fn mask_api_key(text: &str) -> String {
146 API_KEY_PATTERN.replace_all(text, "[FILTERED]").to_string()
147}
148
149pub fn contains_sensitive_info(text: &str) -> bool {
151 if API_KEY_PATTERN.is_match(text) {
152 return true;
153 }
154
155 CONTAINS_SENSITIVE_PATTERNS
156 .iter()
157 .any(|re| re.is_match(text))
158}
159
160pub fn validate_api_key(api_key: &str) -> ZaiResult<()> {
185 if api_key.is_empty() {
186 return Err(ZaiError::ApiError {
187 code: 1200,
188 message: "API key cannot be empty".to_string(),
189 });
190 }
191
192 let parts: Vec<&str> = api_key.split('.').collect();
193 if parts.len() != 2 {
194 return Err(ZaiError::ApiError {
195 code: 1001,
196 message: "API key must be in format '<id>.<secret>'".to_string(),
197 });
198 }
199
200 let (id, secret) = (parts[0], parts[1]);
201
202 if id.is_empty() || secret.is_empty() {
203 return Err(ZaiError::ApiError {
204 code: 1200,
205 message: "API key id and secret must not be empty".to_string(),
206 });
207 }
208
209 let valid_chars = |s: &str| -> bool {
212 s.chars()
213 .all(|c| c.is_alphanumeric() || c == '_' || c == '-')
214 };
215
216 if !valid_chars(id) || !valid_chars(secret) {
217 return Err(ZaiError::ApiError {
218 code: 1200,
219 message: "API key contains invalid characters".to_string(),
220 });
221 }
222
223 if id.len() < 3 {
226 return Err(ZaiError::ApiError {
227 code: 1200,
228 message: "API key id is too short".to_string(),
229 });
230 }
231
232 if secret.len() < 10 {
233 return Err(ZaiError::ApiError {
234 code: 1200,
235 message: "API key secret is too short".to_string(),
236 });
237 }
238
239 Ok(())
240}
241
/// Error type returned by this client.
///
/// Most variants carry a numeric code (an HTTP status for `HttpError`, a business error
/// code otherwise) plus a message. `NetworkError` and `JsonError` instead wrap the
/// underlying library error. See [`ZaiError::code`] and [`ZaiError::message`].
#[derive(Error, Debug)]
pub enum ZaiError {
    /// Error classified by its HTTP status code (see [`ZaiError::from_api_response`]).
    #[error("HTTP error [{status}]: {message}")]
    HttpError { status: u16, message: String },

    /// Authentication failure; `from_api_response` maps business codes 1000–1004 and
    /// 1100 here.
    #[error("Authentication error [{code}]: {message}")]
    AuthError { code: u16, message: String },

    /// Account-level failure; `from_api_response` maps business codes 1110–1121 here.
    #[error("Account error [{code}]: {message}")]
    AccountError { code: u16, message: String },

    /// Request/parameter error. `from_api_response` maps business codes 1200–1234 here;
    /// it is also produced locally by `validate_api_key` and by the
    /// `validator::ValidationErrors` conversion.
    #[error("API error [{code}]: {message}")]
    ApiError { code: u16, message: String },

    /// Rate-limit error; `from_api_response` maps business codes 1300–1309 here.
    #[error("Rate limit error [{code}]: {message}")]
    RateLimitError { code: u16, message: String },

    /// Content-policy error.
    #[error("Content policy error [{code}]: {message}")]
    ContentPolicyError { code: u16, message: String },

    /// File-related error.
    #[error("File error [{code}]: {message}")]
    FileError { code: u16, message: String },

    /// `reqwest` error that carries no HTTP status. It is wrapped in `Arc` so that
    /// `ZaiError` can implement `Clone`.
    #[error("Network error: {0}")]
    NetworkError(Arc<reqwest::Error>),

    /// `serde_json` error, wrapped in `Arc` so that `ZaiError` can implement `Clone`.
    #[error("JSON error: {0}")]
    JsonError(Arc<serde_json::Error>),

    /// Fallback for errors that fit no other variant. This includes unrecognised
    /// business codes and `std::io::Error` conversions (which use code 0).
    #[error("Unknown error [{code}]: {message}")]
    Unknown { code: u16, message: String },
}
285
286impl ZaiError {
287 pub fn from_api_response(status: u16, api_code: u16, api_message: String) -> Self {
289 match status {
291 400 => ZaiError::HttpError {
292 status,
293 message: if api_message.is_empty() {
294 "Bad request - check your parameters".to_string()
295 } else {
296 api_message
297 },
298 },
299 401 => ZaiError::HttpError {
300 status,
301 message: "Unauthorized - check your API key".to_string(),
302 },
303 404 => ZaiError::HttpError {
304 status,
305 message: "Not found - requested resource doesn't exist".to_string(),
306 },
307 429 => ZaiError::HttpError {
308 status,
309 message: if api_message.is_empty() {
310 "Too many requests - rate limit exceeded".to_string()
311 } else {
312 api_message
313 },
314 },
315 434 => ZaiError::HttpError {
316 status,
317 message: "No API permission - feature not available".to_string(),
318 },
319 435 => ZaiError::HttpError {
320 status,
321 message: "File size exceeds 100MB limit".to_string(),
322 },
323 500 => ZaiError::HttpError {
324 status,
325 message: "Internal server error - try again later".to_string(),
326 },
327 _ => {
328 match api_code {
330 1000..=1004 | 1100 => ZaiError::AuthError {
332 code: api_code,
333 message: api_message,
334 },
335 1110..=1121 => ZaiError::AccountError {
337 code: api_code,
338 message: api_message,
339 },
340 1200..=1234 => ZaiError::ApiError {
342 code: api_code,
343 message: api_message,
344 },
345 1300..=1309 => ZaiError::RateLimitError {
347 code: api_code,
348 message: api_message,
349 },
350 _ => ZaiError::Unknown {
352 code: api_code,
353 message: if api_message.is_empty() {
354 "Unknown error".to_string()
355 } else {
356 api_message
357 },
358 },
359 }
360 },
361 }
362 }
363
364 pub fn is_rate_limit(&self) -> bool {
366 matches!(self, ZaiError::RateLimitError { .. })
367 }
368
369 pub fn is_auth_error(&self) -> bool {
371 matches!(self, ZaiError::AuthError { .. })
372 }
373
374 pub fn is_client_error(&self) -> bool {
376 match self {
377 ZaiError::HttpError { status, .. } => *status >= 400 && *status < 500,
378 ZaiError::AuthError { .. }
379 | ZaiError::AccountError { .. }
380 | ZaiError::ApiError { .. }
381 | ZaiError::RateLimitError { .. }
382 | ZaiError::ContentPolicyError { .. }
383 | ZaiError::FileError { .. } => true,
384 _ => false,
385 }
386 }
387
388 pub fn is_server_error(&self) -> bool {
390 match self {
391 ZaiError::HttpError { status, .. } => *status >= 500,
392 ZaiError::Unknown { code, .. } => *code >= 500,
393 _ => false,
394 }
395 }
396
397 pub fn compact(&self) -> String {
399 match self {
400 ZaiError::HttpError { status, message } => {
401 format!("HTTP[{}]: {}", status, message)
402 },
403 ZaiError::AuthError { code, message } => {
404 format!("AUTH[{}]: {}", code, message)
405 },
406 ZaiError::AccountError { code, message } => {
407 format!("ACCOUNT[{}]: {}", code, message)
408 },
409 ZaiError::ApiError { code, message } => {
410 format!("API[{}]: {}", code, message)
411 },
412 ZaiError::RateLimitError { code, message } => {
413 format!("RATE_LIMIT[{}]: {}", code, message)
414 },
415 ZaiError::ContentPolicyError { code, message } => {
416 format!("POLICY[{}]: {}", code, message)
417 },
418 ZaiError::FileError { code, message } => {
419 format!("FILE[{}]: {}", code, message)
420 },
421 ZaiError::NetworkError(err) => {
422 format!("NETWORK: {}", err)
423 },
424 ZaiError::JsonError(err) => {
425 format!("JSON: {}", err)
426 },
427 ZaiError::Unknown { code, message } => {
428 format!("UNKNOWN[{}]: {}", code, message)
429 },
430 }
431 }
432
433 pub fn code(&self) -> Option<u16> {
435 match self {
436 ZaiError::HttpError { status, .. } => Some(*status),
437 ZaiError::AuthError { code, .. } => Some(*code),
438 ZaiError::AccountError { code, .. } => Some(*code),
439 ZaiError::ApiError { code, .. } => Some(*code),
440 ZaiError::RateLimitError { code, .. } => Some(*code),
441 ZaiError::ContentPolicyError { code, .. } => Some(*code),
442 ZaiError::FileError { code, .. } => Some(*code),
443 ZaiError::NetworkError(_) => None,
444 ZaiError::JsonError(_) => None,
445 ZaiError::Unknown { code, .. } => Some(*code),
446 }
447 }
448
449 pub fn message(&self) -> String {
451 match self {
452 ZaiError::HttpError { message, .. } => message.clone(),
453 ZaiError::AuthError { message, .. } => message.clone(),
454 ZaiError::AccountError { message, .. } => message.clone(),
455 ZaiError::ApiError { message, .. } => message.clone(),
456 ZaiError::RateLimitError { message, .. } => message.clone(),
457 ZaiError::ContentPolicyError { message, .. } => message.clone(),
458 ZaiError::FileError { message, .. } => message.clone(),
459 ZaiError::NetworkError(err) => err.to_string(),
460 ZaiError::JsonError(err) => err.to_string(),
461 ZaiError::Unknown { message, .. } => message.clone(),
462 }
463 }
464}
465
466impl Clone for ZaiError {
467 fn clone(&self) -> Self {
468 match self {
469 ZaiError::HttpError { status, message } => ZaiError::HttpError {
470 status: *status,
471 message: message.clone(),
472 },
473 ZaiError::AuthError { code, message } => ZaiError::AuthError {
474 code: *code,
475 message: message.clone(),
476 },
477 ZaiError::AccountError { code, message } => ZaiError::AccountError {
478 code: *code,
479 message: message.clone(),
480 },
481 ZaiError::ApiError { code, message } => ZaiError::ApiError {
482 code: *code,
483 message: message.clone(),
484 },
485 ZaiError::RateLimitError { code, message } => ZaiError::RateLimitError {
486 code: *code,
487 message: message.clone(),
488 },
489 ZaiError::ContentPolicyError { code, message } => ZaiError::ContentPolicyError {
490 code: *code,
491 message: message.clone(),
492 },
493 ZaiError::FileError { code, message } => ZaiError::FileError {
494 code: *code,
495 message: message.clone(),
496 },
497 ZaiError::NetworkError(err) => ZaiError::NetworkError(Arc::clone(err)),
499 ZaiError::JsonError(err) => ZaiError::JsonError(Arc::clone(err)),
500 ZaiError::Unknown { code, message } => ZaiError::Unknown {
501 code: *code,
502 message: message.clone(),
503 },
504 }
505 }
506}
507
508pub type ZaiResult<T> = Result<T, ZaiError>;
510
511impl From<reqwest::Error> for ZaiError {
513 fn from(err: reqwest::Error) -> Self {
514 if let Some(status) = err.status() {
515 ZaiError::from_api_response(status.as_u16(), 0, err.to_string())
516 } else {
517 ZaiError::NetworkError(Arc::new(err))
518 }
519 }
520}
521
522impl From<serde_json::Error> for ZaiError {
524 fn from(err: serde_json::Error) -> Self {
525 ZaiError::JsonError(Arc::new(err))
526 }
527}
528
529impl From<validator::ValidationErrors> for ZaiError {
531 fn from(err: validator::ValidationErrors) -> Self {
532 ZaiError::ApiError {
533 code: 1200,
534 message: format!("Validation error: {:?}", err),
535 }
536 }
537}
538
539impl From<std::io::Error> for ZaiError {
541 fn from(err: std::io::Error) -> Self {
542 ZaiError::Unknown {
543 code: 0,
544 message: err.to_string(),
545 }
546 }
547}
548
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_api_response_bad_request() {
        let err = ZaiError::from_api_response(400, 0, "Invalid input".to_string());
        assert!(err.is_client_error());
        assert!(!err.is_server_error());
        assert_eq!(err.code(), Some(400));
    }

    #[test]
    fn test_from_api_response_unauthorized() {
        let err = ZaiError::from_api_response(401, 0, "".to_string());
        assert!(err.is_client_error());
        assert_eq!(err.message(), "Unauthorized - check your API key");
    }

    #[test]
    fn test_from_api_response_rate_limit() {
        // HTTP 429 is reported as a plain HTTP error, not as a business rate-limit error.
        let err = ZaiError::from_api_response(429, 1301, "Too many requests".to_string());
        assert!(err.is_client_error());
        assert!(!err.is_rate_limit());
        assert_eq!(err.code(), Some(429));

        // A rate-limit business code on a non-error status maps to `RateLimitError`.
        let err = ZaiError::from_api_response(200, 1301, "Too many requests".to_string());
        assert!(err.is_client_error());
        assert!(err.is_rate_limit());
        assert_eq!(err.code(), Some(1301));
    }

    #[test]
    fn test_from_api_response_server_error() {
        let err = ZaiError::from_api_response(500, 0, "".to_string());
        assert!(!err.is_client_error());
        assert!(err.is_server_error());
    }

    #[test]
    fn test_from_api_response_auth_error_code() {
        let err = ZaiError::from_api_response(200, 1001, "Invalid API key".to_string());
        assert!(err.is_auth_error());
        assert_eq!(err.code(), Some(1001));
        assert_eq!(err.message(), "Invalid API key");
    }

    #[test]
    fn test_from_api_response_account_error() {
        let err = ZaiError::from_api_response(200, 1110, "Account expired".to_string());
        assert!(err.is_client_error());
        assert_eq!(err.code(), Some(1110));
    }

    #[test]
    fn test_from_api_response_api_error() {
        let err = ZaiError::from_api_response(200, 1200, "Invalid parameters".to_string());
        assert!(err.is_client_error());
        assert_eq!(err.code(), Some(1200));
    }

    #[test]
    fn test_from_api_response_unknown_code() {
        let err = ZaiError::from_api_response(200, 9999, "Unknown error".to_string());
        assert!(!err.is_client_error());
        assert_eq!(err.code(), Some(9999));
    }

    #[test]
    fn test_compact() {
        let err = ZaiError::HttpError {
            status: 404,
            message: "Not found".to_string(),
        };
        assert_eq!(err.compact(), "HTTP[404]: Not found");

        let err = ZaiError::AuthError {
            code: 1001,
            message: "Invalid key".to_string(),
        };
        assert_eq!(err.compact(), "AUTH[1001]: Invalid key");
    }

    #[test]
    fn test_code() {
        let io_err =
            std::io::Error::new(std::io::ErrorKind::ConnectionRefused, "connection refused");
        let err = ZaiError::from(io_err);
        assert_eq!(err.code(), Some(0));

        let err = ZaiError::JsonError(std::sync::Arc::new(serde_json::Error::io(
            std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid JSON"),
        )));
        assert!(err.code().is_none());

        let err = ZaiError::HttpError {
            status: 500,
            message: "Server error".to_string(),
        };
        assert_eq!(err.code(), Some(500));
    }

    #[test]
    fn test_message() {
        let err = ZaiError::RateLimitError {
            code: 1300,
            message: "Too many requests".to_string(),
        };
        assert_eq!(err.message(), "Too many requests");
    }

    #[test]
    fn test_clone_preserves_variant_and_fields() {
        let err = ZaiError::RateLimitError {
            code: 1302,
            message: "slow down".to_string(),
        };
        let cloned = err.clone();
        assert!(cloned.is_rate_limit());
        assert_eq!(cloned.code(), Some(1302));
        assert_eq!(cloned.message(), "slow down");
    }

    #[test]
    fn test_from_io_error_is_unknown() {
        let io_err = std::io::Error::other("test error");
        let zai_err = ZaiError::from(io_err);
        match zai_err {
            ZaiError::Unknown { .. } => {},
            _ => panic!("Expected Unknown error for io::Error"),
        }
    }

    #[test]
    fn test_validate_api_key_valid() {
        assert!(validate_api_key("abc123.abcdefghijklmnopqrstuvwxyz").is_ok());
    }

    #[test]
    fn test_validate_api_key_minimum_lengths_and_separators() {
        // Exactly 3-character id and 10-character secret are accepted.
        assert!(validate_api_key("abc.0123456789").is_ok());
        // `_` and `-` are allowed in both parts.
        assert!(validate_api_key("a_b-c.x_y-z012345").is_ok());
    }

    #[test]
    fn test_validate_api_key_empty() {
        let result = validate_api_key("");
        assert!(result.is_err());
        match result {
            Err(ZaiError::ApiError { code, .. }) => {
                assert_eq!(code, 1200);
            },
            _ => panic!("Expected ApiError"),
        }
    }

    #[test]
    fn test_validate_api_key_no_dot() {
        let result = validate_api_key("invalid");
        assert!(result.is_err());
        match result {
            Err(ZaiError::ApiError { code, message }) => {
                assert_eq!(code, 1001);
                assert!(message.contains("format"));
            },
            _ => panic!("Expected ApiError"),
        }
    }

    #[test]
    fn test_validate_api_key_multiple_dots() {
        let result = validate_api_key("id.secret.extra");
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().code(), Some(1001));
    }

    #[test]
    fn test_validate_api_key_empty_id() {
        let result = validate_api_key(".secret123456789");
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().code(), Some(1200));
    }

    #[test]
    fn test_validate_api_key_empty_secret() {
        let result = validate_api_key("id123.");
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().code(), Some(1200));
    }

    #[test]
    fn test_validate_api_key_invalid_chars() {
        let result = validate_api_key("id$123.secret@456");
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().code(), Some(1200));
    }

    #[test]
    fn test_validate_api_key_id_too_short() {
        let result = validate_api_key("ab.abcdefghijklmn");
        assert!(result.is_err());
        assert!(result.unwrap_err().message().contains("id is too short"));
    }

    #[test]
    fn test_validate_api_key_secret_too_short() {
        let result = validate_api_key("id123.short");
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .message()
                .contains("secret is too short")
        );
    }

    #[test]
    fn test_mask_sensitive_info_api_key() {
        let text = "API key: abc123.abcdefghijklmnopqrstuvwxyz12345";
        let filtered = mask_sensitive_info(text);
        assert!(filtered.contains("[FILTERED]"));
        assert!(!filtered.contains("abc123"));
        assert!(!filtered.contains("abcdefghijklmnopqrstuvwxyz"));
    }

    #[test]
    fn test_mask_sensitive_info_password() {
        let text = "password: secret123, other text";
        let filtered = mask_sensitive_info(text);
        assert!(filtered.contains("[FILTERED]"));
        assert!(!filtered.contains("secret123"));
    }

    #[test]
    fn test_mask_sensitive_info_token() {
        let text = "token=abc123xyz, other content";
        let filtered = mask_sensitive_info(text);
        assert!(filtered.contains("[FILTERED]"));
        assert!(!filtered.contains("abc123xyz"));
    }

    #[test]
    fn test_mask_sensitive_info_bearer() {
        let text = "Authorization: Bearer abc123.abc1234567890";
        let filtered = mask_sensitive_info(text);
        assert!(filtered.contains("[FILTERED]"));
        assert!(!filtered.contains("abc123"));
    }

    #[test]
    fn test_mask_sensitive_info_bare_bearer() {
        // Secret part is too short for API_KEY_PATTERN; the bearer rule must catch it.
        let filtered = mask_sensitive_info("sending bearer abc.def");
        assert_eq!(filtered, "sending bearer [FILTERED]");
    }

    #[test]
    fn test_mask_sensitive_info_multiple() {
        let text = "api_key=abc123.xyz456, password=secret123";
        let filtered = mask_sensitive_info(text);
        let filtered_count = filtered.matches("[FILTERED]").count();
        assert_eq!(filtered_count, 2);
    }

    #[test]
    fn test_mask_sensitive_info_no_sensitive() {
        let text = "Regular text without sensitive information";
        let filtered = mask_sensitive_info(text);
        assert_eq!(filtered, text);
    }

    #[test]
    fn test_mask_api_key() {
        let text = "API key: abc123.abcdefghijklmnopqrstuvwxyz12345";
        let filtered = mask_api_key(text);
        assert!(filtered.contains("[FILTERED]"));
        assert!(!filtered.contains("abc123"));
    }

    #[test]
    fn test_contains_sensitive_info_api_key() {
        assert!(contains_sensitive_info("api_key: abc123.abc1234567890"));
        assert!(!contains_sensitive_info("regular text"));
    }

    #[test]
    fn test_contains_sensitive_info_password() {
        assert!(contains_sensitive_info("password: secret"));
        assert!(contains_sensitive_info("password=123"));
        assert!(!contains_sensitive_info("password"));
        assert!(!contains_sensitive_info("word:password"));
    }

    #[test]
    fn test_contains_sensitive_info_token() {
        assert!(contains_sensitive_info("token=abc123"));
        assert!(contains_sensitive_info("token: xyz123"));
        assert!(!contains_sensitive_info("token"));
        assert!(!contains_sensitive_info("tokenize this"));
    }
}