1use std::sync::Arc;
4use std::sync::atomic::{AtomicU32, Ordering};
5
6use bytes::Bytes;
7use dashmap::DashMap;
8use futu_auth::Scope;
9use futu_codec::frame::FutuFrame;
10use tokio::sync::mpsc::error::TrySendError;
11
12use crate::conn::ClientConn;
13use crate::metrics::GatewayMetrics;
14use crate::subscription::SubscriptionManager;
15
16fn should_push_to(conn: &ClientConn, needed: Scope, event_label: &str) -> bool {
23 if conn.scopes.is_empty() {
24 return true; }
26 if conn.scopes.contains(&needed) {
27 return true;
28 }
29 let key_id = conn.key_id.as_deref().unwrap_or("<none>");
31 futu_auth::metrics::bump_ws_filtered(event_label, key_id);
32 false
33}
34
/// Observer that receives a copy of upstream pushes alongside client fan-out.
///
/// `PushDispatcher` invokes registered sinks synchronously, in registration
/// order. The calls happen before it looks up client subscribers, so a sink
/// sees every push of the relevant kind whether or not any client is
/// subscribed. The `body` slices are the raw, undecoded push payloads.
/// NOTE(review): calls are made inline on the push path — implementations
/// presumably should return quickly; confirm against the sinks in use.
pub trait ExternalPushSink: Send + Sync {
    /// Quote push for the subscription identified by `sec_key`, `sub_type`
    /// and `rehab_type`, carrying protocol `proto_id`.
    fn on_quote_push(
        &self,
        sec_key: &str,
        sub_type: i32,
        rehab_type: i32,
        proto_id: u32,
        body: &[u8],
    );

    /// Broadcast push (delivered to every notify-enabled client).
    fn on_broadcast_push(&self, proto_id: u32, body: &[u8]);

    /// Trade push for account `acc_id`.
    ///
    /// `trd_market` is the market name decoded from the body, or `None` when
    /// it could not be determined.
    fn on_trade_push(
        &self,
        acc_id: u64,
        proto_id: u32,
        body: &[u8],
        trd_market: Option<&str>,
    );
}
70
71#[must_use]
87pub fn extract_trd_market_from_trade_body(proto_id: u32, body: &[u8]) -> Option<&'static str> {
88 use prost::Message;
89 let market_int = match proto_id {
90 2208 => {
92 let resp = match futu_proto::trd_update_order::Response::decode(body) {
93 Ok(resp) => resp,
94 Err(e) => {
95 tracing::debug!(
96 proto_id,
97 body_len = body.len(),
98 error = %e,
99 "trade push body decode failed while extracting trd_market"
100 );
101 return None;
102 }
103 };
104 resp.s2c?.header.trd_market
105 }
106 2218 => {
108 let resp = match futu_proto::trd_update_order_fill::Response::decode(body) {
109 Ok(resp) => resp,
110 Err(e) => {
111 tracing::debug!(
112 proto_id,
113 body_len = body.len(),
114 error = %e,
115 "trade push body decode failed while extracting trd_market"
116 );
117 return None;
118 }
119 };
120 resp.s2c?.header.trd_market
121 }
122 _ => return None,
124 };
125 match market_int {
128 1 => Some("HK"),
129 2 => Some("US"),
130 3 => Some("CN"),
131 4 => Some("HKCC"),
132 5 => Some("FUTURES"),
133 6 => Some("SG"),
134 7 => Some("CRYPTO"),
135 8 => Some("AU"),
136 10 => Some("FUTURES_SIMULATE_HK"),
137 11 => Some("FUTURES_SIMULATE_US"),
138 12 => Some("FUTURES_SIMULATE_SG"),
139 13 => Some("FUTURES_SIMULATE_JP"),
140 15 => Some("JP"),
141 111 => Some("MY"),
142 112 => Some("CA"),
143 113 => Some("HKFUND"),
144 123 => Some("USFUND"),
145 124 => Some("SGFUND"),
146 125 => Some("MYFUND"),
147 126 => Some("JPFUND"),
148 _ => None,
149 }
150}
151
152pub struct PushDispatcher {
154 connections: Arc<DashMap<u64, ClientConn>>,
155 subscriptions: Arc<SubscriptionManager>,
156 metrics: Option<Arc<GatewayMetrics>>,
157 push_serial_no: AtomicU32,
162 external_sinks: Vec<Arc<dyn ExternalPushSink>>,
164}
165
166impl PushDispatcher {
167 pub fn new(
171 connections: Arc<DashMap<u64, ClientConn>>,
172 subscriptions: Arc<SubscriptionManager>,
173 ) -> Self {
174 Self {
175 connections,
176 subscriptions,
177 metrics: None,
178 push_serial_no: AtomicU32::new(0),
179 external_sinks: Vec::new(),
180 }
181 }
182
183 pub fn with_metrics(mut self, metrics: Arc<GatewayMetrics>) -> Self {
185 self.metrics = Some(metrics);
186 self
187 }
188
189 pub fn with_external_sink(mut self, sink: Arc<dyn ExternalPushSink>) -> Self {
191 self.external_sinks.push(sink);
192 self
193 }
194
195 fn record_push(&self) {
196 if let Some(ref m) = self.metrics {
197 m.client_pushes_sent
198 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
199 }
200 }
201
202 fn record_push_send_failure(&self) {
203 if let Some(ref m) = self.metrics {
204 m.client_push_send_failures
205 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
206 }
207 }
208
209 fn record_qot_client_backpressure_drop(&self, sub_type: i32) {
210 if let Some(ref m) = self.metrics {
211 m.record_qot_client_push_backpressure_drop(sub_type);
212 }
213 }
214
215 fn next_push_serial_no(&self) -> u32 {
216 self.push_serial_no
217 .fetch_add(1, Ordering::Relaxed)
218 .wrapping_add(1)
219 }
220
221 async fn send_client_frame(
222 &self,
223 tx: tokio::sync::mpsc::Sender<FutuFrame>,
224 frame: FutuFrame,
225 push_path: &'static str,
226 ) {
227 if let Err(e) = tx.send(frame).await {
228 self.record_push_send_failure();
229 tracing::warn!(push_path, error = %e, "client push send failed");
230 return;
231 }
232 self.record_push();
233 }
234
235 fn try_send_qot_client_frame(
236 &self,
237 tx: tokio::sync::mpsc::Sender<FutuFrame>,
238 frame: FutuFrame,
239 sub_type: i32,
240 push_path: &'static str,
241 ) {
242 match tx.try_send(frame) {
243 Ok(()) => self.record_push(),
244 Err(TrySendError::Full(_frame)) => {
245 self.record_qot_client_backpressure_drop(sub_type);
246 tracing::warn!(
247 push_path,
248 sub_type,
249 "client quote push dropped because downstream channel is full"
250 );
251 }
252 Err(TrySendError::Closed(_frame)) => {
253 self.record_push_send_failure();
254 tracing::warn!(
255 push_path,
256 "client quote push send failed because downstream channel is closed"
257 );
258 }
259 }
260 }
261
262 pub async fn push_to_conn(&self, conn_id: u64, proto_id: u32, body: Vec<u8>) {
264 let push = self.connections.get(&conn_id).map(|conn| {
265 let frame = conn.make_frame(proto_id, self.next_push_serial_no(), Bytes::from(body));
266 (conn.tx.clone(), frame)
267 });
268 if let Some((tx, frame)) = push {
269 self.send_client_frame(tx, frame, "push_to_conn").await;
270 }
271 }
272
273 pub async fn push_qot_to_conn(&self, conn_id: u64, proto_id: u32, body: Vec<u8>) {
275 let push = self.connections.get(&conn_id).and_then(|conn| {
276 if !should_push_to(&conn, Scope::QotRead, "quote_first") {
277 return None;
278 }
279 let frame = conn.make_frame(proto_id, self.next_push_serial_no(), Bytes::from(body));
280 Some((conn.tx.clone(), frame))
281 });
282 if let Some((tx, frame)) = push {
283 self.try_send_qot_client_frame(tx, frame, 0, "push_qot_to_conn");
284 }
285 }
286
287 pub async fn push_notify(&self, proto_id: u32, body: Vec<u8>) {
289 let body = Bytes::from(body);
290 let body_sha1 = FutuFrame::body_sha1(&body);
291 let pushes: Vec<_> = self
292 .connections
293 .iter()
294 .filter_map(|entry| {
295 let conn = entry.value();
296 if !conn.recv_notify {
297 return None;
298 }
299 if !should_push_to(conn, Scope::QotRead, "notify") {
301 return None;
302 }
303 let serial_no = self.next_push_serial_no();
304 let frame = conn.make_frame_with_sha1(proto_id, serial_no, body.clone(), body_sha1);
305 Some((conn.tx.clone(), frame))
306 })
307 .collect();
308 for (tx, frame) in pushes {
309 self.send_client_frame(tx, frame, "push_notify").await;
310 }
311 }
312
313 pub async fn push_trd_acc(&self, acc_id: u64, proto_id: u32, body: Vec<u8>) {
315 let trd_market = extract_trd_market_from_trade_body(proto_id, &body);
318 for sink in &self.external_sinks {
320 sink.on_trade_push(acc_id, proto_id, &body, trd_market);
321 }
322 let body = Bytes::from(body);
323 let body_sha1 = FutuFrame::body_sha1(&body);
324 let subscribers = self.subscriptions.get_acc_subscribers(acc_id);
325 let pushes: Vec<_> = subscribers
326 .into_iter()
327 .filter_map(|conn_id| {
328 let conn = self.connections.get(&conn_id)?;
329 if !should_push_to(&conn, Scope::AccRead, "trade") {
331 return None;
332 }
333 if let Some(allowed_accs) = conn.allowed_acc_ids.as_ref()
343 && !allowed_accs.is_empty()
344 && !allowed_accs.contains(&acc_id)
345 {
346 let key_id = conn.key_id.as_deref().unwrap_or("<none>");
347 futu_auth::metrics::bump_ws_filtered("trade_acc_id", key_id);
348 return None;
349 }
350 if let (Some(market), Some(allowed_mkts)) =
356 (trd_market, conn.allowed_markets.as_ref())
357 && !allowed_mkts.is_empty()
358 && !allowed_mkts.contains(market)
359 {
360 let key_id = conn.key_id.as_deref().unwrap_or("<none>");
361 futu_auth::metrics::bump_ws_filtered("trade_market", key_id);
362 return None;
363 }
364 let serial_no = self.next_push_serial_no();
365 let frame = conn.make_frame_with_sha1(proto_id, serial_no, body.clone(), body_sha1);
366 Some((conn.tx.clone(), frame))
367 })
368 .collect();
369 for (tx, frame) in pushes {
370 self.send_client_frame(tx, frame, "push_trd_acc").await;
371 }
372 }
373
374 pub async fn push_broadcast(&self, proto_id: u32, body: Vec<u8>) {
377 for sink in &self.external_sinks {
379 sink.on_broadcast_push(proto_id, &body);
380 }
381 let body = Bytes::from(body);
382 let body_sha1 = FutuFrame::body_sha1(&body);
383 let pushes: Vec<_> = self
384 .connections
385 .iter()
386 .filter_map(|entry| {
387 let conn = entry.value();
388 if !conn.recv_notify {
389 return None;
390 }
391 if !should_push_to(conn, Scope::QotRead, "broadcast") {
392 return None;
393 }
394 let serial_no = self.next_push_serial_no();
395 let frame = conn.make_frame_with_sha1(proto_id, serial_no, body.clone(), body_sha1);
396 Some((conn.tx.clone(), frame))
397 })
398 .collect();
399 for (tx, frame) in pushes {
400 self.send_client_frame(tx, frame, "push_broadcast").await;
401 }
402 }
403
404 pub async fn push_qot(
420 &self,
421 security_key: &str,
422 sub_type: i32,
423 rehab_type: i32,
424 proto_id: u32,
425 body: Vec<u8>,
426 ) {
427 for sink in &self.external_sinks {
430 sink.on_quote_push(security_key, sub_type, rehab_type, proto_id, &body);
431 }
432 let body = Bytes::from(body);
433 let subscribers = self.subscriptions.get_qot_push_subscribers_by_cache_key(
435 security_key,
436 sub_type,
437 rehab_type,
438 );
439 let body_sha1 = FutuFrame::body_sha1(&body);
440 let pushes: Vec<_> = subscribers
441 .into_iter()
442 .filter_map(|conn_id| {
443 let conn = self.connections.get(&conn_id)?;
444 if !should_push_to(&conn, Scope::QotRead, "quote") {
445 return None;
446 }
447 let serial_no = self.next_push_serial_no();
448 let frame = conn.make_frame_with_sha1(proto_id, serial_no, body.clone(), body_sha1);
449 Some((conn.tx.clone(), frame))
450 })
451 .collect();
452 for (tx, frame) in pushes {
453 self.try_send_qot_client_frame(tx, frame, sub_type, "push_qot");
454 }
455 }
456}
457
458#[cfg(test)]
459mod tests;