//! Miscellaneous helpers for working with WGSL syntax trees.

use crate::{decl::Decl, expr::Expr, stmt::Stmt};

mod parents;

use bitflags::bitflags;
use gramatika::{ArcStr, Span, Spanned, Substr};
use lsp_types::{Position, Range};
pub use parents::{find_parent, find_parents};

/// A single syntax-tree node, wrapping one of the three node categories
/// (declaration, statement, or expression) in a common type.
#[derive(DebugLisp, Clone)]
pub enum SyntaxNode {
	/// A declaration node.
	Decl(Decl),
	/// A statement node.
	Stmt(Stmt),
	/// An expression node.
	Expr(Expr),
}

bitflags! {
	/// Bit flags naming the categories of syntax node.
	///
	/// Each category occupies its own bit of the underlying `u8`, so several
	/// categories can be represented by a single value.
	#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
	pub struct SyntaxKind: u8 {
		/// Flag for declarations.
		const DECL = 1 << 0;
		/// Flag for statements.
		const STMT = 1 << 1;
		/// Flag for expressions.
		const EXPR = 1 << 2;
	}
}

/// Associates a type with the [`SyntaxKind`] flag(s) that describe it.
pub trait GetSyntaxKind {
	/// The kind flag(s) for the implementing type.
	const KIND: SyntaxKind;
}
/// Declarations are identified by the [`SyntaxKind::DECL`] flag.
impl GetSyntaxKind for Decl {
	const KIND: SyntaxKind = SyntaxKind::DECL;
}
/// Statements are identified by the [`SyntaxKind::STMT`] flag.
impl GetSyntaxKind for Stmt {
	const KIND: SyntaxKind = SyntaxKind::STMT;
}
/// Expressions are identified by the [`SyntaxKind::EXPR`] flag.
impl GetSyntaxKind for Expr {
	const KIND: SyntaxKind = SyntaxKind::EXPR;
}
/// A [`SyntaxNode`] can wrap any category of node, so its kind is the set of
/// all defined flags.
impl GetSyntaxKind for SyntaxNode {
	const KIND: SyntaxKind = SyntaxKind::all();
}

impl Spanned for SyntaxNode {
	/// Returns the span of the wrapped node by delegating to the `span()`
	/// implementation of whichever variant is held.
	fn span(&self) -> gramatika::Span {
		match self {
			Self::Decl(inner) => inner.span(),
			Self::Stmt(inner) => inner.span(),
			Self::Expr(inner) => inner.span(),
		}
	}
}

impl From<SyntaxNode> for Decl {
	/// Unwraps the declaration held by a [`SyntaxNode::Decl`].
	///
	/// # Panics
	///
	/// Panics if `value` is any variant other than [`SyntaxNode::Decl`].
	fn from(value: SyntaxNode) -> Self {
		if let SyntaxNode::Decl(decl) = value {
			decl
		} else {
			panic!("Expected a `SyntaxNode::Decl(...)`")
		}
	}
}

impl From<SyntaxNode> for Stmt {
	/// Unwraps the statement held by a [`SyntaxNode::Stmt`].
	///
	/// # Panics
	///
	/// Panics if `value` is any variant other than [`SyntaxNode::Stmt`].
	fn from(value: SyntaxNode) -> Self {
		if let SyntaxNode::Stmt(stmt) = value {
			stmt
		} else {
			panic!("Expected a `SyntaxNode::Stmt(...)`")
		}
	}
}

impl From<SyntaxNode> for Expr {
	/// Unwraps the expression held by a [`SyntaxNode::Expr`].
	///
	/// # Panics
	///
	/// Panics if `value` is any variant other than [`SyntaxNode::Expr`].
	fn from(value: SyntaxNode) -> Self {
		if let SyntaxNode::Expr(expr) = value {
			expr
		} else {
			panic!("Expected a `SyntaxNode::Expr(...)`")
		}
	}
}

/// Conversion of a value into an LSP [`Range`].
pub trait ToRange {
	/// Consumes `self` and returns the corresponding [`Range`].
	fn to_range(self) -> Range;
}

impl ToRange for Span {
	/// Builds a [`Range`] whose `start` and `end` positions carry the line and
	/// character values of this span's `start` and `end`.
	fn to_range(self) -> Range {
		// `as _` casts each value to the integer type of the corresponding
		// `Position` field; the target type is inferred from the field.
		let start = Position {
			line: self.start.line as _,
			character: self.start.character as _,
		};
		let end = Position {
			line: self.end.line as _,
			character: self.end.character as _,
		};

		Range { start, end }
	}
}

/// Builds one [`Substr`] that runs from the beginning of `start` to the end of
/// `end`, where both are slices of the same parent string.
///
/// # Panics
///
/// Panics when either of these conditions holds:
///
/// - `start` and `end` do not point to the same `ArcStr` allocation
/// - `end.range().end` is not strictly greater than `start.range().start`
pub(crate) fn join_substrs(start: &Substr, end: &Substr) -> Substr {
	assert!(ArcStr::ptr_eq(start.parent(), end.parent()));
	assert!(end.range().end > start.range().start);

	let bounds = start.range().start..end.range().end;

	// SAFETY: both endpoints are taken from existing `Substr`s that the first
	// assertion proved share one parent, and the second assertion guarantees
	// the range is non-empty and correctly ordered. This presumes
	// `from_parts_unchecked` only requires an in-bounds range on char
	// boundaries of the parent -- confirm against the arcstr documentation.
	unsafe { Substr::from_parts_unchecked(ArcStr::clone(start.parent()), bounds) }
}