teloxide_ng/dispatching/
handler_description.rs1use std::collections::HashSet;
2
3use dptree::{
4 HandlerDescription,
5 description::{EventKind, InterestSet},
6};
7use teloxide_core_ng::types::AllowedUpdate;
8
9#[derive(Debug, Clone)]
13pub struct DpHandlerDescription {
14 allowed: InterestSet<Kind>,
15}
16
17impl DpHandlerDescription {
18 pub(crate) fn of(allowed: AllowedUpdate) -> Self {
19 let mut set = HashSet::with_capacity(1);
20 set.insert(Kind(allowed));
21 Self { allowed: InterestSet::new_filter(set) }
22 }
23
24 pub(crate) fn allowed_updates(&self) -> Vec<AllowedUpdate> {
25 self.allowed.observed.iter().map(|&Kind(x)| x).collect()
26 }
27}
28
29impl HandlerDescription for DpHandlerDescription {
30 fn entry() -> Self {
31 Self { allowed: HandlerDescription::entry() }
32 }
33
34 fn user_defined() -> Self {
35 Self { allowed: HandlerDescription::user_defined() }
36 }
37
38 fn merge_chain(&self, other: &Self) -> Self {
39 Self { allowed: self.allowed.merge_chain(&other.allowed) }
40 }
41
42 fn merge_branch(&self, other: &Self) -> Self {
43 Self { allowed: self.allowed.merge_branch(&other.allowed) }
44 }
45}
46
47#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)]
48struct Kind(AllowedUpdate);
49
50impl EventKind for Kind {
51 fn full_set() -> HashSet<Self> {
52 use AllowedUpdate::*;
53
54 [
67 Message,
68 EditedMessage,
69 ChannelPost,
70 EditedChannelPost,
71 BusinessConnection,
72 BusinessMessage,
73 EditedBusinessMessage,
74 DeletedBusinessMessages,
75 MessageReaction,
76 MessageReactionCount,
77 InlineQuery,
78 ChosenInlineResult,
79 CallbackQuery,
80 ShippingQuery,
81 PreCheckoutQuery,
82 PurchasedPaidMedia,
83 Poll,
84 PollAnswer,
85 MyChatMember,
86 ChatMember,
87 ChatJoinRequest,
88 ChatBoost,
89 RemovedChatBoost,
90 ]
91 .into_iter()
92 .map(Kind)
93 .collect()
94 }
95
96 fn empty_set() -> HashSet<Self> {
97 HashSet::new()
98 }
99}
100
#[cfg(test)]
mod tests {
    #[cfg(feature = "macros")]
    use crate::{
        self as teloxide_ng,
        dispatching::{HandlerExt, UpdateFilterExt, handler_description::Kind},
        types::{AllowedUpdate::*, Update},
        utils::command::BotCommands,
    };
    #[cfg(feature = "macros")]
    use dptree::description::EventKind;

    #[cfg(feature = "macros")]
    #[derive(BotCommands, Clone)]
    #[command(rename_rule = "lowercase")]
    enum Cmd {
        B,
    }

    /// A tree with a `my_chat_member` branch and a `message` branch (which
    /// itself nests a command filter) must report exactly those two update
    /// kinds as allowed.
    #[test]
    #[cfg(feature = "macros")]
    fn discussion_648() {
        let h =
            dptree::entry().branch(Update::filter_my_chat_member().endpoint(|| async {})).branch(
                Update::filter_message()
                    .branch(dptree::entry().filter_command::<Cmd>().endpoint(|| async {}))
                    .endpoint(|| async {}),
            );

        let mut v = h.description().allowed_updates();

        // Sort by discriminant so the comparison does not depend on the order
        // in which `allowed_updates` yields its items.
        v.sort_by_key(|&a| a as u8);

        assert_eq!(v, [Message, MyChatMember])
    }

    #[test]
    #[ignore = "this test requires `macros` feature"]
    #[cfg(not(feature = "macros"))]
    fn discussion_648() {
        panic!("this test requires `macros` feature")
    }

    /// `Kind::full_set` must contain every `AllowedUpdate` variant, and nothing else.
    #[test]
    #[cfg(feature = "macros")]
    fn allowed_updates_full_set() {
        let full_set = Kind::full_set();
        // Must list every variant named in the exhaustive `match` below.
        let allowed_updates_reference = [
            Message,
            EditedMessage,
            ChannelPost,
            EditedChannelPost,
            BusinessConnection,
            BusinessMessage,
            EditedBusinessMessage,
            DeletedBusinessMessages,
            MessageReaction,
            MessageReactionCount,
            InlineQuery,
            ChosenInlineResult,
            CallbackQuery,
            ShippingQuery,
            PreCheckoutQuery,
            PurchasedPaidMedia,
            Poll,
            PollAnswer,
            MyChatMember,
            ChatMember,
            ChatJoinRequest,
            ChatBoost,
            RemovedChatBoost,
        ];

        for update in allowed_updates_reference {
            // Exhaustive on purpose: a new `AllowedUpdate` variant makes this fail
            // to compile, forcing it to be added here, to the list above, and to
            // `Kind::full_set`.
            match update {
                Message
                | EditedMessage
                | ChannelPost
                | EditedChannelPost
                | MessageReaction
                | MessageReactionCount
                | InlineQuery
                | ChosenInlineResult
                | CallbackQuery
                | ShippingQuery
                | PreCheckoutQuery
                | PurchasedPaidMedia
                | Poll
                | PollAnswer
                | MyChatMember
                | ChatMember
                | ChatJoinRequest
                | ChatBoost
                | RemovedChatBoost
                | BusinessMessage
                | BusinessConnection
                | EditedBusinessMessage
                | DeletedBusinessMessages => {
                    assert!(
                        full_set.contains(&Kind(update)),
                        "{:?} is missing from `Kind::full_set`",
                        update
                    )
                }
            }
        }

        // Two-way check: the full set has no kinds beyond the reference list either.
        let reference: std::collections::HashSet<Kind> =
            allowed_updates_reference.into_iter().map(Kind).collect();
        assert_eq!(full_set, reference);
    }
}