1use std::sync::Arc;
12
13use regex;
14use thiserror::Error;
15
16pub fn mask_sensitive_info(text: &str) -> String {
50 let mut result = text.to_string();
51
52 let api_key_pattern = regex::Regex::new(r"\b[a-zA-Z0-9_-]{3,}\.[a-zA-Z0-9_-]{10,}\b").unwrap();
54 result = api_key_pattern
55 .replace_all(&result, "[FILTERED]")
56 .to_string();
57
58 let patterns = vec![
60 (r"(?i)(api[_-]?key\s*[=:]\s*)[^\s,]+", "$1[FILTERED]"),
61 (r"(?i)(password\s*[=:]\s*)[^\s,]+", "$1[FILTERED]"),
62 (r"(?i)(token\s*[=:]\s*)[^\s,]+", "$1[FILTERED]"),
63 (r"(?i)(secret\s*[=:]\s*)[^\s,]+", "$1[FILTERED]"),
64 (
65 r"(?i)(bearer\s+[a-zA-Z0-9_-]+\.[a-zA-Z0-9_-]+)",
66 "bearer [FILTERED]",
67 ),
68 (
69 r"(?i)(authorization\s*:\s*Bearer\s+)[^\s,]+",
70 "$1[FILTERED]",
71 ),
72 ];
73
74 for (pattern, replacement) in patterns {
75 let re = regex::Regex::new(pattern).unwrap();
76 result = re.replace_all(&result, replacement).to_string();
77 }
78
79 result
80}
81
82pub fn mask_api_key(text: &str) -> String {
87 let pattern = regex::Regex::new(r"\b[a-zA-Z0-9_-]{3,}\.[a-zA-Z0-9_-]{10,}\b").unwrap();
88 pattern.replace_all(text, "[FILTERED]").to_string()
89}
90
91pub fn contains_sensitive_info(text: &str) -> bool {
93 let api_key_pattern = regex::Regex::new(r"\b[a-zA-Z0-9_-]{3,}\.[a-zA-Z0-9_-]{10,}\b").unwrap();
94
95 if api_key_pattern.is_match(text) {
96 return true;
97 }
98
99 let patterns = vec![
100 r"(?i)api[_-]?key\s*[=:]",
101 r"(?i)password\s*[=:]",
102 r"(?i)token\s*[=:]",
103 r"(?i)secret\s*[=:]",
104 r"(?i)authorization\s*:\s*Bearer",
105 ];
106
107 for pattern in patterns {
108 if regex::Regex::new(pattern).unwrap().is_match(text) {
109 return true;
110 }
111 }
112
113 false
114}
115
116pub fn validate_api_key(api_key: &str) -> ZaiResult<()> {
141 if api_key.is_empty() {
142 return Err(ZaiError::ApiError {
143 code: 1200,
144 message: "API key cannot be empty".to_string(),
145 });
146 }
147
148 let parts: Vec<&str> = api_key.split('.').collect();
149 if parts.len() != 2 {
150 return Err(ZaiError::ApiError {
151 code: 1001,
152 message: "API key must be in format '<id>.<secret>'".to_string(),
153 });
154 }
155
156 let (id, secret) = (parts[0], parts[1]);
157
158 if id.is_empty() || secret.is_empty() {
159 return Err(ZaiError::ApiError {
160 code: 1200,
161 message: "API key id and secret must not be empty".to_string(),
162 });
163 }
164
165 let valid_chars = |s: &str| -> bool {
168 s.chars()
169 .all(|c| c.is_alphanumeric() || c == '_' || c == '-')
170 };
171
172 if !valid_chars(id) || !valid_chars(secret) {
173 return Err(ZaiError::ApiError {
174 code: 1200,
175 message: "API key contains invalid characters".to_string(),
176 });
177 }
178
179 if id.len() < 3 {
182 return Err(ZaiError::ApiError {
183 code: 1200,
184 message: "API key id is too short".to_string(),
185 });
186 }
187
188 if secret.len() < 10 {
189 return Err(ZaiError::ApiError {
190 code: 1200,
191 message: "API key secret is too short".to_string(),
192 });
193 }
194
195 Ok(())
196}
197
/// Errors produced by this client.
///
/// Most variants carry a numeric `code` plus a human-readable `message`; `HttpError`
/// carries the HTTP `status` instead. `ZaiError::from_api_response` maps HTTP statuses
/// and API business codes onto these variants. The `Display` text of each variant is
/// defined by its `#[error(...)]` attribute.
#[derive(Error, Debug)]
pub enum ZaiError {
    /// Error identified by an HTTP status code.
    #[error("HTTP error [{status}]: {message}")]
    HttpError { status: u16, message: String },

    /// Authentication error, identified by an API business code.
    #[error("Authentication error [{code}]: {message}")]
    AuthError { code: u16, message: String },

    /// Account error, identified by an API business code.
    #[error("Account error [{code}]: {message}")]
    AccountError { code: u16, message: String },

    /// Generic API error (also used for local API-key and request validation failures).
    #[error("API error [{code}]: {message}")]
    ApiError { code: u16, message: String },

    /// Rate-limit error, identified by an API business code.
    #[error("Rate limit error [{code}]: {message}")]
    RateLimitError { code: u16, message: String },

    /// Content policy error.
    #[error("Content policy error [{code}]: {message}")]
    ContentPolicyError { code: u16, message: String },

    /// File-related error.
    #[error("File error [{code}]: {message}")]
    FileError { code: u16, message: String },

    /// `reqwest` failure without an HTTP status, held in an `Arc` so clones share it.
    #[error("Network error: {0}")]
    NetworkError(Arc<reqwest::Error>),

    /// `serde_json` failure, held in an `Arc` so clones share it.
    #[error("JSON error: {0}")]
    JsonError(Arc<serde_json::Error>),

    /// Anything not covered above, with whatever code was available (`0` if none).
    #[error("Unknown error [{code}]: {message}")]
    Unknown { code: u16, message: String },
}
241
242impl ZaiError {
243 pub fn from_api_response(status: u16, api_code: u16, api_message: String) -> Self {
245 match status {
247 400 => ZaiError::HttpError {
248 status,
249 message: if api_message.is_empty() {
250 "Bad request - check your parameters".to_string()
251 } else {
252 api_message
253 },
254 },
255 401 => ZaiError::HttpError {
256 status,
257 message: "Unauthorized - check your API key".to_string(),
258 },
259 404 => ZaiError::HttpError {
260 status,
261 message: "Not found - requested resource doesn't exist".to_string(),
262 },
263 429 => ZaiError::HttpError {
264 status,
265 message: if api_message.is_empty() {
266 "Too many requests - rate limit exceeded".to_string()
267 } else {
268 api_message
269 },
270 },
271 434 => ZaiError::HttpError {
272 status,
273 message: "No API permission - feature not available".to_string(),
274 },
275 435 => ZaiError::HttpError {
276 status,
277 message: "File size exceeds 100MB limit".to_string(),
278 },
279 500 => ZaiError::HttpError {
280 status,
281 message: "Internal server error - try again later".to_string(),
282 },
283 _ => {
284 match api_code {
286 1000..=1004 | 1100 => ZaiError::AuthError {
288 code: api_code,
289 message: api_message,
290 },
291 1110..=1121 => ZaiError::AccountError {
293 code: api_code,
294 message: api_message,
295 },
296 1200..=1234 => ZaiError::ApiError {
298 code: api_code,
299 message: api_message,
300 },
301 1300..=1309 => ZaiError::RateLimitError {
303 code: api_code,
304 message: api_message,
305 },
306 _ => ZaiError::Unknown {
308 code: api_code,
309 message: if api_message.is_empty() {
310 "Unknown error".to_string()
311 } else {
312 api_message
313 },
314 },
315 }
316 },
317 }
318 }
319
320 pub fn is_rate_limit(&self) -> bool {
322 matches!(self, ZaiError::RateLimitError { .. })
323 }
324
325 pub fn is_auth_error(&self) -> bool {
327 matches!(self, ZaiError::AuthError { .. })
328 }
329
330 pub fn is_client_error(&self) -> bool {
332 match self {
333 ZaiError::HttpError { status, .. } => *status >= 400 && *status < 500,
334 ZaiError::AuthError { .. }
335 | ZaiError::AccountError { .. }
336 | ZaiError::ApiError { .. }
337 | ZaiError::RateLimitError { .. }
338 | ZaiError::ContentPolicyError { .. }
339 | ZaiError::FileError { .. } => true,
340 _ => false,
341 }
342 }
343
344 pub fn is_server_error(&self) -> bool {
346 match self {
347 ZaiError::HttpError { status, .. } => *status >= 500,
348 ZaiError::Unknown { code, .. } => *code >= 500,
349 _ => false,
350 }
351 }
352
353 pub fn compact(&self) -> String {
355 match self {
356 ZaiError::HttpError { status, message } => {
357 format!("HTTP[{}]: {}", status, message)
358 },
359 ZaiError::AuthError { code, message } => {
360 format!("AUTH[{}]: {}", code, message)
361 },
362 ZaiError::AccountError { code, message } => {
363 format!("ACCOUNT[{}]: {}", code, message)
364 },
365 ZaiError::ApiError { code, message } => {
366 format!("API[{}]: {}", code, message)
367 },
368 ZaiError::RateLimitError { code, message } => {
369 format!("RATE_LIMIT[{}]: {}", code, message)
370 },
371 ZaiError::ContentPolicyError { code, message } => {
372 format!("POLICY[{}]: {}", code, message)
373 },
374 ZaiError::FileError { code, message } => {
375 format!("FILE[{}]: {}", code, message)
376 },
377 ZaiError::NetworkError(err) => {
378 format!("NETWORK: {}", err)
379 },
380 ZaiError::JsonError(err) => {
381 format!("JSON: {}", err)
382 },
383 ZaiError::Unknown { code, message } => {
384 format!("UNKNOWN[{}]: {}", code, message)
385 },
386 }
387 }
388
389 pub fn code(&self) -> Option<u16> {
391 match self {
392 ZaiError::HttpError { status, .. } => Some(*status),
393 ZaiError::AuthError { code, .. } => Some(*code),
394 ZaiError::AccountError { code, .. } => Some(*code),
395 ZaiError::ApiError { code, .. } => Some(*code),
396 ZaiError::RateLimitError { code, .. } => Some(*code),
397 ZaiError::ContentPolicyError { code, .. } => Some(*code),
398 ZaiError::FileError { code, .. } => Some(*code),
399 ZaiError::NetworkError(_) => None,
400 ZaiError::JsonError(_) => None,
401 ZaiError::Unknown { code, .. } => Some(*code),
402 }
403 }
404
405 pub fn message(&self) -> String {
407 match self {
408 ZaiError::HttpError { message, .. } => message.clone(),
409 ZaiError::AuthError { message, .. } => message.clone(),
410 ZaiError::AccountError { message, .. } => message.clone(),
411 ZaiError::ApiError { message, .. } => message.clone(),
412 ZaiError::RateLimitError { message, .. } => message.clone(),
413 ZaiError::ContentPolicyError { message, .. } => message.clone(),
414 ZaiError::FileError { message, .. } => message.clone(),
415 ZaiError::NetworkError(err) => err.to_string(),
416 ZaiError::JsonError(err) => err.to_string(),
417 ZaiError::Unknown { message, .. } => message.clone(),
418 }
419 }
420}
421
422impl Clone for ZaiError {
423 fn clone(&self) -> Self {
424 match self {
425 ZaiError::HttpError { status, message } => ZaiError::HttpError {
426 status: *status,
427 message: message.clone(),
428 },
429 ZaiError::AuthError { code, message } => ZaiError::AuthError {
430 code: *code,
431 message: message.clone(),
432 },
433 ZaiError::AccountError { code, message } => ZaiError::AccountError {
434 code: *code,
435 message: message.clone(),
436 },
437 ZaiError::ApiError { code, message } => ZaiError::ApiError {
438 code: *code,
439 message: message.clone(),
440 },
441 ZaiError::RateLimitError { code, message } => ZaiError::RateLimitError {
442 code: *code,
443 message: message.clone(),
444 },
445 ZaiError::ContentPolicyError { code, message } => ZaiError::ContentPolicyError {
446 code: *code,
447 message: message.clone(),
448 },
449 ZaiError::FileError { code, message } => ZaiError::FileError {
450 code: *code,
451 message: message.clone(),
452 },
453 ZaiError::NetworkError(err) => ZaiError::NetworkError(Arc::clone(err)),
455 ZaiError::JsonError(err) => ZaiError::JsonError(Arc::clone(err)),
456 ZaiError::Unknown { code, message } => ZaiError::Unknown {
457 code: *code,
458 message: message.clone(),
459 },
460 }
461 }
462}
463
/// Shorthand for `Result<T, ZaiError>`.
pub type ZaiResult<T> = Result<T, ZaiError>;
466
467impl From<reqwest::Error> for ZaiError {
469 fn from(err: reqwest::Error) -> Self {
470 if let Some(status) = err.status() {
471 ZaiError::from_api_response(status.as_u16(), 0, err.to_string())
472 } else {
473 ZaiError::NetworkError(Arc::new(err))
474 }
475 }
476}
477
478impl From<serde_json::Error> for ZaiError {
480 fn from(err: serde_json::Error) -> Self {
481 ZaiError::JsonError(Arc::new(err))
482 }
483}
484
485impl From<validator::ValidationErrors> for ZaiError {
487 fn from(err: validator::ValidationErrors) -> Self {
488 ZaiError::ApiError {
489 code: 1200,
490 message: format!("Validation error: {:?}", err),
491 }
492 }
493}
494
495impl From<std::io::Error> for ZaiError {
497 fn from(err: std::io::Error) -> Self {
498 ZaiError::Unknown {
499 code: 0,
500 message: err.to_string(),
501 }
502 }
503}
504
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_api_response_bad_request() {
        let err = ZaiError::from_api_response(400, 0, "Invalid input".to_string());
        assert!(err.is_client_error());
        assert!(!err.is_server_error());
        assert_eq!(err.code(), Some(400));
        assert_eq!(err.message(), "Invalid input");
    }

    #[test]
    fn test_from_api_response_bad_request_default_message() {
        let err = ZaiError::from_api_response(400, 0, String::new());
        assert_eq!(err.message(), "Bad request - check your parameters");
    }

    #[test]
    fn test_from_api_response_unauthorized() {
        let err = ZaiError::from_api_response(401, 0, "".to_string());
        assert!(err.is_client_error());
        assert_eq!(err.message(), "Unauthorized - check your API key");
    }

    #[test]
    fn test_from_api_response_not_found() {
        let err = ZaiError::from_api_response(404, 0, String::new());
        assert!(err.is_client_error());
        assert_eq!(err.code(), Some(404));
        assert_eq!(err.message(), "Not found - requested resource doesn't exist");
    }

    #[test]
    fn test_from_api_response_rate_limit() {
        // An HTTP 429 is classified by its status, so it is an `HttpError`, not a
        // `RateLimitError`, even when a rate-limit business code accompanies it.
        let err = ZaiError::from_api_response(429, 1301, "Too many requests".to_string());
        assert!(err.is_client_error());
        assert!(!err.is_rate_limit());
        assert_eq!(err.code(), Some(429));

        // With a status that has no dedicated mapping, the business code decides.
        let err = ZaiError::from_api_response(200, 1301, "Too many requests".to_string());
        assert!(err.is_client_error());
        assert!(err.is_rate_limit());
        assert_eq!(err.code(), Some(1301));
    }

    #[test]
    fn test_from_api_response_server_error() {
        let err = ZaiError::from_api_response(500, 0, "".to_string());
        assert!(!err.is_client_error());
        assert!(err.is_server_error());
    }

    #[test]
    fn test_from_api_response_auth_error_code() {
        let err = ZaiError::from_api_response(200, 1001, "Invalid API key".to_string());
        assert!(err.is_auth_error());
        assert_eq!(err.code(), Some(1001));
        assert_eq!(err.message(), "Invalid API key");
    }

    #[test]
    fn test_from_api_response_account_error() {
        let err = ZaiError::from_api_response(200, 1110, "Account expired".to_string());
        assert!(err.is_client_error());
        assert_eq!(err.code(), Some(1110));
    }

    #[test]
    fn test_from_api_response_api_error() {
        let err = ZaiError::from_api_response(200, 1200, "Invalid parameters".to_string());
        assert!(err.is_client_error());
        assert_eq!(err.code(), Some(1200));
    }

    #[test]
    fn test_from_api_response_unknown_code() {
        let err = ZaiError::from_api_response(200, 9999, "Unknown error".to_string());
        assert!(!err.is_client_error());
        assert_eq!(err.code(), Some(9999));
    }

    #[test]
    fn test_compact() {
        let err = ZaiError::HttpError {
            status: 404,
            message: "Not found".to_string(),
        };
        assert_eq!(err.compact(), "HTTP[404]: Not found");

        let err = ZaiError::AuthError {
            code: 1001,
            message: "Invalid key".to_string(),
        };
        assert_eq!(err.compact(), "AUTH[1001]: Invalid key");
    }

    #[test]
    fn test_code() {
        let io_err =
            std::io::Error::new(std::io::ErrorKind::ConnectionRefused, "connection refused");
        let err = ZaiError::from(io_err);
        assert_eq!(err.code(), Some(0));

        let err = ZaiError::JsonError(std::sync::Arc::new(serde_json::Error::io(
            std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid JSON"),
        )));
        assert!(err.code().is_none());

        let err = ZaiError::HttpError {
            status: 500,
            message: "Server error".to_string(),
        };
        assert_eq!(err.code(), Some(500));
    }

    #[test]
    fn test_message() {
        let err = ZaiError::RateLimitError {
            code: 1300,
            message: "Too many requests".to_string(),
        };
        assert_eq!(err.message(), "Too many requests");
    }

    #[test]
    fn test_from_io_error() {
        let io_err = std::io::Error::other("test error");
        let zai_err = ZaiError::from(io_err);
        match zai_err {
            ZaiError::Unknown { .. } => {},
            _ => panic!("Expected Unknown error for io::Error"),
        }
    }

    #[test]
    fn test_clone_copies_fields_and_shares_wrapped_errors() {
        let err = ZaiError::ApiError {
            code: 1200,
            message: "Invalid parameters".to_string(),
        };
        assert_eq!(err.clone().compact(), err.compact());

        let err = ZaiError::JsonError(std::sync::Arc::new(serde_json::Error::io(
            std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid JSON"),
        )));
        match (&err, &err.clone()) {
            (ZaiError::JsonError(a), ZaiError::JsonError(b)) => {
                assert!(std::sync::Arc::ptr_eq(a, b));
            },
            _ => panic!("Expected JsonError"),
        }
    }

    #[test]
    fn test_validate_api_key_valid() {
        assert!(validate_api_key("abc123.abcdefghijklmnopqrstuvwxyz").is_ok());
    }

    #[test]
    fn test_validate_api_key_allows_underscore_and_dash() {
        assert!(validate_api_key("my_id-1.secret_value-123").is_ok());
    }

    #[test]
    fn test_validate_api_key_empty() {
        let result = validate_api_key("");
        assert!(result.is_err());
        match result {
            Err(ZaiError::ApiError { code, .. }) => {
                assert_eq!(code, 1200);
            },
            _ => panic!("Expected ApiError"),
        }
    }

    #[test]
    fn test_validate_api_key_no_dot() {
        let result = validate_api_key("invalid");
        assert!(result.is_err());
        match result {
            Err(ZaiError::ApiError { code, message }) => {
                assert_eq!(code, 1001);
                assert!(message.contains("format"));
            },
            _ => panic!("Expected ApiError"),
        }
    }

    #[test]
    fn test_validate_api_key_multiple_dots() {
        let result = validate_api_key("id.secret.extra");
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().code(), Some(1001));
    }

    #[test]
    fn test_validate_api_key_empty_id() {
        let result = validate_api_key(".secret123456789");
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().code(), Some(1200));
    }

    #[test]
    fn test_validate_api_key_empty_secret() {
        let result = validate_api_key("id123.");
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().code(), Some(1200));
    }

    #[test]
    fn test_validate_api_key_invalid_chars() {
        let result = validate_api_key("id$123.secret@456");
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().code(), Some(1200));
    }

    #[test]
    fn test_validate_api_key_id_too_short() {
        let result = validate_api_key("ab.abcdefghijklmn");
        assert!(result.is_err());
        assert!(result.unwrap_err().message().contains("id is too short"));
    }

    #[test]
    fn test_validate_api_key_secret_too_short() {
        let result = validate_api_key("id123.short");
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .message()
                .contains("secret is too short")
        );
    }

    #[test]
    fn test_mask_sensitive_info_api_key() {
        let text = "API key: abc123.abcdefghijklmnopqrstuvwxyz12345";
        let filtered = mask_sensitive_info(text);
        assert!(filtered.contains("[FILTERED]"));
        assert!(!filtered.contains("abc123"));
        assert!(!filtered.contains("abcdefghijklmnopqrstuvwxyz"));
    }

    #[test]
    fn test_mask_sensitive_info_password() {
        let text = "password: secret123, other text";
        let filtered = mask_sensitive_info(text);
        assert!(filtered.contains("[FILTERED]"));
        assert!(!filtered.contains("secret123"));
    }

    #[test]
    fn test_mask_sensitive_info_token() {
        let text = "token=abc123xyz, other content";
        let filtered = mask_sensitive_info(text);
        assert!(filtered.contains("[FILTERED]"));
        assert!(!filtered.contains("abc123xyz"));
    }

    #[test]
    fn test_mask_sensitive_info_bearer() {
        let text = "Authorization: Bearer abc123.abc1234567890";
        let filtered = mask_sensitive_info(text);
        assert!(filtered.contains("[FILTERED]"));
        assert!(!filtered.contains("abc123"));
    }

    #[test]
    fn test_mask_sensitive_info_multiple() {
        let text = "api_key=abc123.xyz456, password=secret123";
        let filtered = mask_sensitive_info(text);
        let filtered_count = filtered.matches("[FILTERED]").count();
        assert_eq!(filtered_count, 2);
    }

    #[test]
    fn test_mask_sensitive_info_no_sensitive() {
        let text = "Regular text without sensitive information";
        let filtered = mask_sensitive_info(text);
        assert_eq!(filtered, text);
    }

    #[test]
    fn test_mask_api_key() {
        let text = "API key: abc123.abcdefghijklmnopqrstuvwxyz12345";
        let filtered = mask_api_key(text);
        assert!(filtered.contains("[FILTERED]"));
        assert!(!filtered.contains("abc123"));
    }

    #[test]
    fn test_mask_api_key_leaves_short_tokens() {
        // "txt" is shorter than the 10-character secret minimum and "id" is shorter than
        // the 3-character id minimum, so neither looks like a key.
        let text = "file.txt and id.short";
        assert_eq!(mask_api_key(text), text);
    }

    #[test]
    fn test_contains_sensitive_info_api_key() {
        assert!(contains_sensitive_info("api_key: abc123.abc1234567890"));
        assert!(!contains_sensitive_info("regular text"));
    }

    #[test]
    fn test_contains_sensitive_info_password() {
        assert!(contains_sensitive_info("password: secret"));
        assert!(contains_sensitive_info("password=123"));
        assert!(!contains_sensitive_info("password"));
        assert!(!contains_sensitive_info("word:password"));
    }

    #[test]
    fn test_contains_sensitive_info_token() {
        assert!(contains_sensitive_info("token=abc123"));
        assert!(contains_sensitive_info("token: xyz123"));
        assert!(!contains_sensitive_info("token"));
        assert!(!contains_sensitive_info("tokenize this"));
    }
}