1use crate::error::TorshError;
19
20#[cfg(feature = "std")]
21use std::collections::HashMap;
22#[cfg(feature = "std")]
23use std::sync::{Arc, Mutex, OnceLock};
24#[cfg(feature = "std")]
25use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
26
27#[cfg(not(feature = "std"))]
28use alloc::{
29 collections::BTreeMap as HashMap,
30 string::{String, ToString},
31 vec::Vec,
32};
33
/// Severity of a log event, ordered from least (`Trace`) to most (`Fatal`) severe.
///
/// The derived ordering follows the explicit `u8` discriminants, so levels can
/// be compared directly with `<` / `>=`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[repr(u8)]
pub enum LogLevel {
    /// Lowest severity.
    Trace = 0,
    /// One step above `Trace`.
    Debug = 1,
    /// One step above `Debug`.
    Info = 2,
    /// One step above `Info`.
    Warn = 3,
    /// One step above `Warn`.
    Error = 4,
    /// Highest severity.
    Fatal = 5,
}

impl LogLevel {
    /// Returns the upper-case label for this level (e.g. `"INFO"`).
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Trace => "TRACE",
            Self::Debug => "DEBUG",
            Self::Info => "INFO",
            Self::Warn => "WARN",
            Self::Error => "ERROR",
            Self::Fatal => "FATAL",
        }
    }

    /// Returns `true` when this level is at least as severe as `min_level`.
    pub fn is_enabled(&self, min_level: LogLevel) -> bool {
        // Comparing discriminants is equivalent to the derived `Ord`.
        let own_rank = *self as u8;
        let threshold = min_level as u8;
        own_rank >= threshold
    }
}
70
/// Stable numeric error codes.
///
/// Codes are grouped in blocks of one thousand by category; `0` means success
/// and `9999` is the catch-all.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum ErrorCode {
    /// No error.
    Success = 0,

    // 1xxx: shape-related errors.
    ShapeMismatch = 1000,
    InvalidShape = 1001,
    BroadcastError = 1002,
    DimensionMismatch = 1003,

    // 2xxx: indexing errors.
    IndexOutOfBounds = 2000,
    InvalidDimension = 2001,
    InvalidSlice = 2002,

    // 3xxx: type errors.
    TypeMismatch = 3000,
    UnsupportedType = 3001,
    ConversionError = 3002,

    // 4xxx: device errors.
    DeviceMismatch = 4000,
    DeviceUnavailable = 4001,
    DeviceError = 4002,

    // 5xxx: memory errors.
    AllocationFailed = 5000,
    OutOfMemory = 5001,
    InvalidAlignment = 5002,

    // 6xxx: computation errors.
    ComputeError = 6000,
    NumericalError = 6001,
    ConvergenceError = 6002,

    // 7xxx: I/O and (de)serialization errors.
    IoError = 7000,
    SerializationError = 7001,
    DeserializationError = 7002,

    // 8xxx: operation and state errors.
    InvalidOperation = 8000,
    NotImplemented = 8001,
    InvalidState = 8002,
    SynchronizationError = 8003,

    /// Catch-all for errors with no more specific code.
    Unknown = 9999,
}
123
124impl ErrorCode {
125 pub fn code(&self) -> u32 {
127 *self as u32
128 }
129
130 pub fn description(&self) -> &'static str {
132 match self {
133 ErrorCode::Success => "Success",
134 ErrorCode::ShapeMismatch => "Shape mismatch between tensors",
135 ErrorCode::InvalidShape => "Invalid tensor shape",
136 ErrorCode::BroadcastError => "Broadcasting error",
137 ErrorCode::DimensionMismatch => "Dimension mismatch",
138 ErrorCode::IndexOutOfBounds => "Index out of bounds",
139 ErrorCode::InvalidDimension => "Invalid dimension",
140 ErrorCode::InvalidSlice => "Invalid slice",
141 ErrorCode::TypeMismatch => "Type mismatch",
142 ErrorCode::UnsupportedType => "Unsupported type",
143 ErrorCode::ConversionError => "Type conversion error",
144 ErrorCode::DeviceMismatch => "Device mismatch",
145 ErrorCode::DeviceUnavailable => "Device unavailable",
146 ErrorCode::DeviceError => "Device error",
147 ErrorCode::AllocationFailed => "Memory allocation failed",
148 ErrorCode::OutOfMemory => "Out of memory",
149 ErrorCode::InvalidAlignment => "Invalid memory alignment",
150 ErrorCode::ComputeError => "Computation error",
151 ErrorCode::NumericalError => "Numerical error",
152 ErrorCode::ConvergenceError => "Convergence error",
153 ErrorCode::IoError => "I/O error",
154 ErrorCode::SerializationError => "Serialization error",
155 ErrorCode::DeserializationError => "Deserialization error",
156 ErrorCode::InvalidOperation => "Invalid operation",
157 ErrorCode::NotImplemented => "Not implemented",
158 ErrorCode::InvalidState => "Invalid state",
159 ErrorCode::SynchronizationError => "Synchronization error",
160 ErrorCode::Unknown => "Unknown error",
161 }
162 }
163
164 pub fn from_torsh_error(error: &TorshError) -> Self {
166 match error {
167 TorshError::ShapeMismatch { .. } => ErrorCode::ShapeMismatch,
168 TorshError::BroadcastError { .. } => ErrorCode::BroadcastError,
169 TorshError::InvalidShape(_) => ErrorCode::InvalidShape,
170 TorshError::IndexOutOfBounds { .. } => ErrorCode::IndexOutOfBounds,
171 TorshError::IndexError { .. } => ErrorCode::IndexOutOfBounds,
172 TorshError::InvalidDimension { .. } => ErrorCode::InvalidDimension,
173 TorshError::InvalidArgument(_) => ErrorCode::InvalidOperation,
174 TorshError::IoError(_) => ErrorCode::IoError,
175 TorshError::DeviceMismatch => ErrorCode::DeviceMismatch,
176 TorshError::NotImplemented(_) => ErrorCode::NotImplemented,
177 TorshError::AllocationError(_) => ErrorCode::AllocationFailed,
178 TorshError::InvalidOperation(_) => ErrorCode::InvalidOperation,
179 TorshError::ConversionError(_) => ErrorCode::ConversionError,
180 TorshError::InvalidState(_) => ErrorCode::InvalidState,
181 TorshError::UnsupportedOperation { .. } => ErrorCode::NotImplemented,
182 TorshError::ComputeError(_) => ErrorCode::ComputeError,
183 TorshError::SerializationError(_) => ErrorCode::SerializationError,
184 _ => ErrorCode::Unknown,
185 }
186 }
187}
188
189#[derive(Debug, Clone)]
191pub struct LogEvent {
192 pub timestamp: u64,
194 pub level: LogLevel,
196 pub message: String,
198 pub module_path: String,
200 pub file: String,
202 pub line: u32,
204 pub fields: HashMap<String, String>,
206 pub span_id: Option<u64>,
208 pub error_code: Option<ErrorCode>,
210}
211
212impl LogEvent {
213 pub fn new(
215 level: LogLevel,
216 message: String,
217 module_path: String,
218 file: String,
219 line: u32,
220 ) -> Self {
221 #[cfg(feature = "std")]
222 let timestamp = SystemTime::now()
223 .duration_since(UNIX_EPOCH)
224 .unwrap_or_default()
225 .as_secs();
226
227 #[cfg(not(feature = "std"))]
228 let timestamp = 0; Self {
231 timestamp,
232 level,
233 message,
234 module_path,
235 file,
236 line,
237 fields: HashMap::new(),
238 span_id: None,
239 error_code: None,
240 }
241 }
242
243 pub fn with_field(mut self, key: String, value: String) -> Self {
245 self.fields.insert(key, value);
246 self
247 }
248
249 pub fn with_span_id(mut self, span_id: u64) -> Self {
251 self.span_id = Some(span_id);
252 self
253 }
254
255 pub fn with_error_code(mut self, code: ErrorCode) -> Self {
257 self.error_code = Some(code);
258 self
259 }
260
261 pub fn format_structured(&self) -> String {
263 #[cfg(feature = "std")]
264 {
265 let mut parts = vec![
266 format!("timestamp={}", self.timestamp),
267 format!("level={}", self.level.as_str()),
268 format!("message=\"{}\"", self.message),
269 format!("module={}", self.module_path),
270 format!("file={}:{}", self.file, self.line),
271 ];
272
273 if let Some(span_id) = self.span_id {
274 parts.push(format!("span_id={}", span_id));
275 }
276
277 if let Some(error_code) = self.error_code {
278 parts.push(format!("error_code={}", error_code.code()));
279 }
280
281 for (key, value) in &self.fields {
282 parts.push(format!("{}=\"{}\"", key, value));
283 }
284
285 parts.join(" ")
286 }
287
288 #[cfg(not(feature = "std"))]
289 {
290 use alloc::vec;
291 let mut parts = vec![
292 format!("timestamp={}", self.timestamp),
293 format!("level={}", self.level.as_str()),
294 format!("message=\"{}\"", self.message),
295 format!("module={}", self.module_path),
296 format!("file={}:{}", self.file, self.line),
297 ];
298
299 if let Some(span_id) = self.span_id {
300 parts.push(format!("span_id={}", span_id));
301 }
302
303 if let Some(error_code) = self.error_code {
304 parts.push(format!("error_code={}", error_code.code()));
305 }
306
307 for (key, value) in &self.fields {
308 parts.push(format!("{}=\"{}\"", key, value));
309 }
310
311 parts.join(" ")
312 }
313 }
314}
315
/// An in-progress traced operation.
#[cfg(feature = "std")]
#[derive(Debug, Clone)]
pub struct Span {
    /// Identifier of this span.
    pub span_id: u64,
    /// Identifier of the parent span, if any.
    pub parent_id: Option<u64>,
    /// Name given to the span at creation.
    pub name: String,
    /// Monotonic instant captured when the span was created.
    pub start_time: Instant,
    /// Key/value attributes attached to the span.
    pub attributes: HashMap<String, String>,
    /// Events recorded on the span, in insertion order.
    pub events: Vec<SpanEvent>,
}

#[cfg(feature = "std")]
impl Span {
    /// Creates a span that starts now, with no attributes and no events.
    pub fn new(span_id: u64, name: String, parent_id: Option<u64>) -> Self {
        let start_time = Instant::now();
        Self {
            span_id,
            parent_id,
            name,
            start_time,
            attributes: HashMap::new(),
            events: Vec::new(),
        }
    }

    /// Sets an attribute; an existing value under the same key is replaced.
    pub fn add_attribute(&mut self, key: String, value: String) {
        let _previous = self.attributes.insert(key, value);
    }

    /// Appends an event to the span.
    pub fn add_event(&mut self, event: SpanEvent) {
        self.events.push(event);
    }

    /// Time elapsed since the span was created.
    pub fn duration(&self) -> Duration {
        let started = self.start_time;
        started.elapsed()
    }

    /// Consumes the span and summarizes it as [`SpanMetrics`].
    ///
    /// The duration is measured at the moment of the call; the recorded
    /// events themselves are dropped and only their count is kept.
    pub fn close(self) -> SpanMetrics {
        let Span {
            span_id,
            name,
            start_time,
            attributes,
            events,
            ..
        } = self;

        SpanMetrics {
            span_id,
            name,
            duration: start_time.elapsed(),
            event_count: events.len(),
            attributes,
        }
    }
}
374
/// A named event recorded on a [`Span`].
#[cfg(feature = "std")]
#[derive(Debug, Clone)]
pub struct SpanEvent {
    /// Name of the event.
    pub name: String,
    /// Time associated with the event.
    // NOTE(review): presumably an offset from the owning span's start — confirm with callers.
    pub timestamp: Duration,
    /// Key/value attributes attached to the event.
    pub attributes: HashMap<String, String>,
}
386
/// Summary of a closed [`Span`], as produced by `Span::close`.
#[cfg(feature = "std")]
#[derive(Debug, Clone)]
pub struct SpanMetrics {
    /// Identifier of the span that was closed.
    pub span_id: u64,
    /// Name of the span.
    pub name: String,
    /// Time between span creation and the moment it was closed.
    pub duration: Duration,
    /// Number of events the span held when it was closed.
    pub event_count: usize,
    /// Attributes carried over from the span.
    pub attributes: HashMap<String, String>,
}
402
403#[derive(Debug, Clone)]
405pub struct TelemetryConfig {
406 pub min_log_level: LogLevel,
408 pub enable_tracing: bool,
410 pub buffer_size: usize,
412 pub console_output: bool,
414 pub structured_logging: bool,
416}
417
418impl Default for TelemetryConfig {
419 fn default() -> Self {
420 Self {
421 min_log_level: LogLevel::Info,
422 enable_tracing: false,
423 buffer_size: 1000,
424 console_output: true,
425 structured_logging: true,
426 }
427 }
428}
429
/// Collects log events and spans behind mutexes so it can be shared by reference.
#[cfg(feature = "std")]
pub struct TelemetrySystem {
    /// Settings fixed at construction.
    config: TelemetryConfig,
    /// Accepted events that have not been flushed yet.
    event_buffer: Mutex<Vec<LogEvent>>,
    /// Spans that were started and not yet ended, keyed by span id.
    active_spans: Mutex<HashMap<u64, Span>>,
    /// Id that the next started span will receive.
    next_span_id: Mutex<u64>,
    /// Metrics of every span ended so far (until `clear` is called).
    closed_spans: Mutex<Vec<SpanMetrics>>,
}

#[cfg(feature = "std")]
impl TelemetrySystem {
    /// Creates an empty system; span ids start at 1 and the event buffer is
    /// pre-allocated for `config.buffer_size` entries.
    pub fn new(config: TelemetryConfig) -> Self {
        let event_buffer = Mutex::new(Vec::with_capacity(config.buffer_size));
        Self {
            config,
            event_buffer,
            active_spans: Mutex::new(HashMap::new()),
            next_span_id: Mutex::new(1),
            closed_spans: Mutex::new(Vec::new()),
        }
    }

    /// Acquires `mutex`, panicking if it was poisoned by a panic in another holder.
    fn lock_or_panic<T>(mutex: &Mutex<T>) -> std::sync::MutexGuard<'_, T> {
        mutex.lock().expect("lock should not be poisoned")
    }

    /// Records `event` if its level passes the configured minimum.
    ///
    /// Accepted events are optionally echoed to standard error and then
    /// appended to the buffer; once the buffer length reaches
    /// `buffer_size` it is flushed (which currently empties it).
    pub fn log(&self, event: LogEvent) {
        let cfg = &self.config;
        if !event.level.is_enabled(cfg.min_log_level) {
            return;
        }

        match (cfg.console_output, cfg.structured_logging) {
            (false, _) => {}
            (true, true) => eprintln!("{}", event.format_structured()),
            (true, false) => eprintln!("[{}] {}", event.level.as_str(), event.message),
        }

        let mut pending = Self::lock_or_panic(&self.event_buffer);
        pending.push(event);

        if pending.len() >= cfg.buffer_size {
            self.flush_events(&mut pending);
        }
    }

    /// Starts a new span and returns its freshly allocated id.
    pub fn start_span(&self, name: String, parent_id: Option<u64>) -> u64 {
        // The id counter stays locked until the span is registered.
        let mut counter = Self::lock_or_panic(&self.next_span_id);
        let span_id = *counter;
        *counter = span_id + 1;

        let span = Span::new(span_id, name, parent_id);
        let mut active = Self::lock_or_panic(&self.active_spans);
        active.insert(span_id, span);

        span_id
    }

    /// Sets an attribute on an active span; unknown ids are silently ignored.
    pub fn span_add_attribute(&self, span_id: u64, key: String, value: String) {
        let mut active = Self::lock_or_panic(&self.active_spans);
        if let Some(span) = active.get_mut(&span_id) {
            span.add_attribute(key, value);
        }
    }

    /// Ends an active span, stores its metrics and returns a copy of them.
    ///
    /// Returns `None` when `span_id` does not refer to an active span.
    pub fn end_span(&self, span_id: u64) -> Option<SpanMetrics> {
        let mut active = Self::lock_or_panic(&self.active_spans);
        let finished = active.remove(&span_id)?;
        let metrics = finished.close();

        let mut archive = Self::lock_or_panic(&self.closed_spans);
        archive.push(metrics.clone());
        Some(metrics)
    }

    /// Flushes the buffered events; at present this only discards them.
    fn flush_events(&self, buffer: &mut Vec<LogEvent>) {
        buffer.clear();
    }

    /// Returns a copy of the events currently held in the buffer.
    pub fn get_events(&self) -> Vec<LogEvent> {
        let pending = Self::lock_or_panic(&self.event_buffer);
        pending.to_vec()
    }

    /// Returns a copy of the metrics of all spans ended so far.
    pub fn get_span_metrics(&self) -> Vec<SpanMetrics> {
        let archive = Self::lock_or_panic(&self.closed_spans);
        archive.to_vec()
    }

    /// Empties the event buffer and the closed-span metrics; active spans
    /// and the id counter are left untouched.
    pub fn clear(&self) {
        let mut pending = Self::lock_or_panic(&self.event_buffer);
        pending.clear();
        let mut archive = Self::lock_or_panic(&self.closed_spans);
        archive.clear();
    }
}
570
/// Process-wide telemetry instance, created on first use and never replaced.
#[cfg(feature = "std")]
static TELEMETRY: OnceLock<Arc<TelemetrySystem>> = OnceLock::new();
574
/// Initializes the global telemetry system with `config`.
///
/// Only the first initialization takes effect: if the global instance already
/// exists (from an earlier call or from `telemetry()`), `config` is dropped
/// and the existing instance is kept.
#[cfg(feature = "std")]
pub fn init_telemetry(config: TelemetryConfig) {
    let build = move || Arc::new(TelemetrySystem::new(config));
    let _existing_or_new = TELEMETRY.get_or_init(build);
}
580
/// Returns a handle to the global telemetry system.
///
/// If it has not been initialized yet, it is created with
/// `TelemetryConfig::default()`.
#[cfg(feature = "std")]
pub fn telemetry() -> Arc<TelemetrySystem> {
    let shared =
        TELEMETRY.get_or_init(|| Arc::new(TelemetrySystem::new(TelemetryConfig::default())));
    Arc::clone(shared)
}
588
/// Logs a message at the given level through the global telemetry system.
///
/// Usage: `log!(level, message)` or `log!(level, message, key => value, ...)`.
/// The message, keys and values are converted with `to_string()`; the call
/// site's module path, file and line are recorded on the event.
// NOTE(review): the `cfg(feature = "std")` below is part of the expansion, so it
// is evaluated against the features of the crate that invokes the macro.
#[macro_export]
macro_rules! log {
    ($lvl:expr, $text:expr $(, $k:expr => $v:expr)*) => {{
        #[cfg(feature = "std")]
        {
            let record = $crate::telemetry::LogEvent::new(
                $lvl,
                $text.to_string(),
                module_path!().to_string(),
                file!().to_string(),
                line!(),
            );
            $(
                let record = record.with_field($k.to_string(), $v.to_string());
            )*
            $crate::telemetry::telemetry().log(record);
        }
    }};
}
609
/// Logs a message at `Trace` level; accepts optional `key => value` fields.
#[macro_export]
macro_rules! trace {
    ($text:expr $(, $k:expr => $v:expr)*) => {
        $crate::log!($crate::telemetry::LogLevel::Trace, $text $(, $k => $v)*)
    };
}
617
/// Logs a message at `Debug` level; accepts optional `key => value` fields.
#[macro_export]
macro_rules! debug {
    ($text:expr $(, $k:expr => $v:expr)*) => {
        $crate::log!($crate::telemetry::LogLevel::Debug, $text $(, $k => $v)*)
    };
}
624
/// Logs a message at `Info` level; accepts optional `key => value` fields.
#[macro_export]
macro_rules! info {
    ($text:expr $(, $k:expr => $v:expr)*) => {
        $crate::log!($crate::telemetry::LogLevel::Info, $text $(, $k => $v)*)
    };
}
631
/// Logs a message at `Warn` level; accepts optional `key => value` fields.
#[macro_export]
macro_rules! warn {
    ($text:expr $(, $k:expr => $v:expr)*) => {
        $crate::log!($crate::telemetry::LogLevel::Warn, $text $(, $k => $v)*)
    };
}
638
/// Logs a message at `Error` level; accepts optional `key => value` fields.
#[macro_export]
macro_rules! error {
    ($text:expr $(, $k:expr => $v:expr)*) => {
        $crate::log!($crate::telemetry::LogLevel::Error, $text $(, $k => $v)*)
    };
}
645
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an event located in `test.rs` at the given line.
    fn event_at(level: LogLevel, message: &str, module: &str, line: u32) -> LogEvent {
        LogEvent::new(
            level,
            message.to_string(),
            module.to_string(),
            "test.rs".to_string(),
            line,
        )
    }

    /// Builds a telemetry system with console output disabled.
    #[cfg(feature = "std")]
    fn quiet_system(min_log_level: LogLevel) -> TelemetrySystem {
        TelemetrySystem::new(TelemetryConfig {
            min_log_level,
            console_output: false,
            ..Default::default()
        })
    }

    #[test]
    fn test_log_level_ordering() {
        // Each level must compare strictly greater than the one before it.
        let ascending = [
            LogLevel::Trace,
            LogLevel::Debug,
            LogLevel::Info,
            LogLevel::Warn,
            LogLevel::Error,
            LogLevel::Fatal,
        ];
        for pair in ascending.windows(2) {
            assert!(pair[1] > pair[0]);
        }
    }

    #[test]
    fn test_log_level_is_enabled() {
        let min_level = LogLevel::Info;

        for suppressed in [LogLevel::Trace, LogLevel::Debug] {
            assert!(!suppressed.is_enabled(min_level));
        }
        for accepted in [LogLevel::Info, LogLevel::Warn, LogLevel::Error] {
            assert!(accepted.is_enabled(min_level));
        }
    }

    #[test]
    fn test_error_code_mapping() {
        let expected = [
            (ErrorCode::ShapeMismatch, 1000),
            (ErrorCode::IndexOutOfBounds, 2000),
            (ErrorCode::TypeMismatch, 3000),
            (ErrorCode::DeviceMismatch, 4000),
            (ErrorCode::AllocationFailed, 5000),
        ];
        for (code, number) in expected {
            assert_eq!(code.code(), number);
        }
    }

    #[test]
    fn test_error_code_from_torsh_error() {
        let shape_error = TorshError::InvalidShape("test".to_string());
        assert_eq!(
            ErrorCode::from_torsh_error(&shape_error),
            ErrorCode::InvalidShape
        );

        let device_error = TorshError::DeviceMismatch;
        assert_eq!(
            ErrorCode::from_torsh_error(&device_error),
            ErrorCode::DeviceMismatch
        );
    }

    #[test]
    fn test_log_event_creation() {
        let event = event_at(LogLevel::Info, "test message", "test_module", 42);

        assert_eq!(event.level, LogLevel::Info);
        assert_eq!(event.message, "test message");
        assert_eq!(event.line, 42);
    }

    #[test]
    fn test_log_event_with_metadata() {
        let event = event_at(LogLevel::Error, "error occurred", "test_module", 10)
            .with_field("tensor_id".to_string(), "123".to_string())
            .with_error_code(ErrorCode::ComputeError);

        assert!(event.fields.contains_key("tensor_id"));
        assert_eq!(event.error_code, Some(ErrorCode::ComputeError));
    }

    #[test]
    #[cfg(feature = "std")]
    fn test_telemetry_system() {
        let system = quiet_system(LogLevel::Debug);

        system.log(event_at(LogLevel::Info, "test", "test", 1));

        let recorded = system.get_events();
        assert_eq!(recorded.len(), 1);
        assert_eq!(recorded[0].message, "test");
    }

    #[test]
    #[cfg(feature = "std")]
    fn test_span_creation() {
        let system = TelemetrySystem::new(TelemetryConfig::default());

        let span_id = system.start_span("test_operation".to_string(), None);
        system.span_add_attribute(span_id, "key".to_string(), "value".to_string());

        let metrics = system
            .end_span(span_id)
            .expect("end_span should succeed");
        assert_eq!(metrics.name, "test_operation");
        assert!(metrics.attributes.contains_key("key"));
    }

    #[test]
    #[cfg(feature = "std")]
    fn test_log_filtering() {
        let system = quiet_system(LogLevel::Warn);

        // Below the threshold: must be dropped.
        system.log(event_at(LogLevel::Info, "info", "test", 1));
        // At or above the threshold: must be kept.
        system.log(event_at(LogLevel::Error, "error", "test", 2));

        let recorded = system.get_events();
        assert_eq!(recorded.len(), 1);
        assert_eq!(recorded[0].level, LogLevel::Error);
    }
}