1use std::collections::{HashMap, HashSet};
8use std::net::SocketAddr;
9use std::sync::{Arc, RwLock};
10
11use axum::extract::connect_info::ConnectInfo;
12use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
13use axum::extract::{Query, State};
14use axum::http::{HeaderMap, StatusCode};
15use axum::response::IntoResponse;
16use chrono::Utc;
17use futures::{SinkExt, StreamExt};
18use tokio::sync::broadcast;
19
20use futu_auth::{KeyRecord, KeyStore, Scope};
21use futu_server::push::ExternalPushSink;
22
23use crate::adapter::RestState;
24
/// Upper bound, in bytes, applied by `ws_handler` to both a single incoming
/// WebSocket message and a single incoming frame on the REST `/ws` endpoint
/// (64 KiB). The server only interprets small JSON control messages from the
/// client (e.g. `{"action":"subscribe-notify"}`), so a small cap suffices.
pub const REST_WS_MAX_CONTROL_MESSAGE_SIZE_BYTES: usize = 64 * 1024;
29
/// One push event fanned out to REST WebSocket clients.
///
/// Each event is serialized with `serde_json` and sent as a text frame.
/// Optional fields that are `None` are omitted from the JSON. `required_scope`
/// is never serialized and is only used server-side for per-connection filtering.
#[derive(Clone, Debug, serde::Serialize)]
pub struct WsPushEvent {
    /// Event class, serialized under the JSON key `"type"`. Events built by
    /// `WsBroadcaster` use `"quote"`, `"notify"` or `"trade"`.
    #[serde(rename = "type")]
    pub event_type: String,
    /// Push class; the send path checks its mapped `Scope` against the
    /// connection's scopes and uses it to select the notify/trade filters.
    #[serde(skip)]
    pub required_scope: WsPushScope,
    /// Protocol id of the originating push.
    pub proto_id: u32,
    /// Security key of a quote push (set by `push_quote`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sec_key: Option<String>,
    /// Subscription type of a quote push (set by `push_quote`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sub_type: Option<i32>,
    /// Rehab type of a quote push, as passed to `push_quote`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub rehab_type: Option<i32>,
    /// Trading account id of a trade push (set by `push_trade`); trade events
    /// with an account id are run through the per-caller filter registry.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub acc_id: Option<u64>,
    /// Raw push body, encoded as standard (padded) base64.
    pub body_b64: String,
    /// Trading market of a trade push, when the producer supplied one; also
    /// handed to the filter registry for market allow-list checks.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub trd_market: Option<String>,
}
68
/// Push class carried by `WsPushEvent::required_scope`.
///
/// Each variant maps to the API-key `Scope` a connection must hold via
/// `WsPushScope::required_scope`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[non_exhaustive]
pub enum WsPushScope {
    /// Market-data pushes (built by `push_quote`). This is the `Default` variant.
    #[default]
    Quote,
    /// Broadcast/notification pushes (built by `push_broadcast`); delivered
    /// only to connections that sent a `subscribe-notify` control message.
    Notify,
    /// Trading-account pushes (built by `push_trade`); additionally subject to
    /// the per-caller account/market filter.
    Trade,
}
86
87impl WsPushScope {
88 pub fn required_scope(&self) -> Scope {
90 match self {
91 WsPushScope::Quote => Scope::QotRead,
92 WsPushScope::Notify => Scope::QotRead,
93 WsPushScope::Trade => Scope::AccRead,
94 }
95 }
96}
97
/// Fan-out hub for REST WebSocket push events.
///
/// Wraps a `tokio::sync::broadcast` sender. Clones share the same channel, so
/// any clone can publish. Each connected client obtains its own receiver via
/// `WsBroadcaster::subscribe`.
#[derive(Clone)]
pub struct WsBroadcaster {
    tx: broadcast::Sender<WsPushEvent>,
}
107
108impl WsBroadcaster {
109 pub fn new(capacity: usize) -> Self {
110 let (tx, _) = broadcast::channel(capacity);
111 Self { tx }
112 }
113
114 fn has_receivers(&self) -> bool {
115 self.tx.receiver_count() > 0
116 }
117
118 pub fn send(&self, event: WsPushEvent) {
120 if !self.has_receivers() {
121 return;
122 }
123 let proto_id = event.proto_id;
124 let event_type = event.event_type.clone();
125 if self.tx.send(event).is_err() {
126 tracing::debug!(
127 proto_id,
128 event_type,
129 receiver_count = self.tx.receiver_count(),
130 "rest ws broadcast send skipped"
131 );
132 }
133 }
134
135 pub fn subscribe(&self) -> broadcast::Receiver<WsPushEvent> {
137 self.tx.subscribe()
138 }
139
140 fn encode_body(body: &[u8]) -> String {
141 use base64::Engine;
142 base64::engine::general_purpose::STANDARD.encode(body)
143 }
144
145 pub fn push_quote(
153 &self,
154 sec_key: &str,
155 sub_type: i32,
156 rehab_type: i32,
157 proto_id: u32,
158 body: &[u8],
159 ) {
160 if !self.has_receivers() {
161 return;
162 }
163 self.send(WsPushEvent {
164 event_type: "quote".to_string(),
165 required_scope: WsPushScope::Quote,
166 proto_id,
167 sec_key: Some(sec_key.to_string()),
168 sub_type: Some(sub_type),
169 rehab_type: Some(rehab_type),
170 acc_id: None,
171 body_b64: Self::encode_body(body),
172 trd_market: None,
173 });
174 }
175
176 pub fn push_broadcast(&self, proto_id: u32, body: &[u8]) {
178 if !self.has_receivers() {
179 return;
180 }
181 self.send(WsPushEvent {
182 event_type: "notify".to_string(),
183 required_scope: WsPushScope::Notify,
184 proto_id,
185 sec_key: None,
186 sub_type: None,
187 rehab_type: None,
188 acc_id: None,
189 body_b64: Self::encode_body(body),
190 trd_market: None,
191 });
192 }
193
194 pub fn push_trade(&self, acc_id: u64, proto_id: u32, body: &[u8], trd_market: Option<&str>) {
200 if !self.has_receivers() {
201 return;
202 }
203 self.send(WsPushEvent {
204 event_type: "trade".to_string(),
205 required_scope: WsPushScope::Trade,
206 proto_id,
207 sec_key: None,
208 sub_type: None,
209 rehab_type: None,
210 acc_id: Some(acc_id),
211 body_b64: Self::encode_body(body),
212 trd_market: trd_market.map(|s| s.to_string()),
213 });
214 }
215}
216
217impl ExternalPushSink for WsBroadcaster {
219 fn on_quote_push(
220 &self,
221 sec_key: &str,
222 sub_type: i32,
223 rehab_type: i32,
224 proto_id: u32,
225 body: &[u8],
226 ) {
227 self.push_quote(sec_key, sub_type, rehab_type, proto_id, body);
228 }
229
230 fn on_broadcast_push(&self, proto_id: u32, body: &[u8]) {
231 self.push_broadcast(proto_id, body);
232 }
233
234 fn on_trade_push(&self, acc_id: u64, proto_id: u32, body: &[u8], trd_market: Option<&str>) {
235 self.push_trade(acc_id, proto_id, body, trd_market);
236 }
237}
238
239fn extract_ws_token(headers: &HeaderMap, query: &HashMap<String, String>) -> Option<String> {
244 if let Some(t) = query.get("token") {
245 return Some(t.clone());
246 }
247 headers
248 .get("authorization")
249 .and_then(|v| v.to_str().ok())
250 .and_then(|v| futu_auth_pipeline::parse_bearer_scheme(v).map(|s| s.to_string()))
251}
252
253fn authenticate_ws(
260 key_store: &KeyStore,
261 headers: &HeaderMap,
262 query: &HashMap<String, String>,
263) -> Result<Option<Arc<KeyRecord>>, (StatusCode, &'static str)> {
264 if !key_store.is_configured() {
265 return Ok(None);
266 }
267
268 let Some(token) = extract_ws_token(headers, query) else {
269 futu_auth::audit::reject(
270 "ws",
271 "/ws",
272 "<missing>",
273 "missing token (query or Authorization)",
274 );
275 return Err((StatusCode::UNAUTHORIZED, "missing api key"));
276 };
277
278 let Some(rec) = key_store.verify(&token) else {
279 futu_auth::audit::reject("ws", "/ws", "<invalid>", "invalid api key");
280 return Err((StatusCode::UNAUTHORIZED, "invalid api key"));
281 };
282
283 if rec.is_expired(Utc::now()) {
284 futu_auth::audit::reject("ws", "/ws", &rec.id, "key expired");
285 return Err((StatusCode::UNAUTHORIZED, "key expired"));
286 }
287
288 if !rec.scopes.contains(&Scope::QotRead) {
289 futu_auth::audit::reject("ws", "/ws", &rec.id, "missing qot:read scope");
292 return Err((StatusCode::FORBIDDEN, "forbidden"));
293 }
294
295 futu_auth::audit::allow("ws", "/ws", &rec.id, Some("qot:read"));
296 Ok(Some(rec))
297}
298
299pub async fn ws_handler(
301 ws: WebSocketUpgrade,
302 ConnectInfo(peer_addr): ConnectInfo<SocketAddr>,
303 headers: HeaderMap,
304 Query(query): Query<HashMap<String, String>>,
305 State(state): State<RestState>,
306) -> impl IntoResponse {
307 let peer_addr_string = peer_addr.to_string();
308 let session_id = headers
309 .get("x-request-id")
310 .or_else(|| headers.get("x-futu-session-id"))
311 .and_then(|v| v.to_str().ok())
312 .map(str::trim)
313 .filter(|v| !v.is_empty());
314 let audit_ctx =
315 futu_auth::audit::AuditContext::new(Some(peer_addr_string.as_str()), session_id);
316 let rec = match futu_auth::audit::with_context(audit_ctx.clone(), || {
317 authenticate_ws(&state.key_store, &headers, &query)
318 }) {
319 Ok(rec) => rec,
320 Err((code, msg)) => return (code, msg).into_response(),
321 };
322 let scopes: HashSet<Scope> = match &rec {
324 Some(r) => r.scopes.clone(),
325 None => all_scopes(),
326 };
327 let key_id = rec.as_ref().map(|r| r.id.clone());
328 let allowed_acc_ids = rec.as_ref().and_then(|r| r.allowed_acc_ids.clone());
332 let allowed_markets = rec.as_ref().and_then(|r| r.allowed_markets.clone());
336 let broadcaster = Arc::clone(&state.ws_broadcaster);
337 let rest_acc_subs = Arc::clone(&state.rest_acc_subscriptions);
340 let filter_registry = Arc::clone(&state.filter_registry);
343 let ctx = WsConnectionContext {
344 broadcaster,
345 scopes,
346 key_id,
347 allowed_acc_ids,
348 allowed_markets,
349 rest_acc_subscriptions: rest_acc_subs,
350 filter_registry,
351 };
352 ws.max_message_size(REST_WS_MAX_CONTROL_MESSAGE_SIZE_BYTES)
353 .max_frame_size(REST_WS_MAX_CONTROL_MESSAGE_SIZE_BYTES)
354 .on_upgrade(move |socket| handle_ws_connection(socket, ctx))
355 .into_response()
356}
357
358fn all_scopes() -> HashSet<Scope> {
360 [
361 Scope::QotRead,
362 Scope::AccRead,
363 Scope::TradeSimulate,
364 Scope::TradeReal,
365 ]
366 .into_iter()
367 .collect()
368}
369
/// Per-connection state built by `ws_handler` and moved into
/// `handle_ws_connection` once the WebSocket upgrade completes.
struct WsConnectionContext {
    /// Event source; the connection subscribes to it once on start.
    broadcaster: Arc<WsBroadcaster>,
    /// Scopes granted to this connection (all scopes when no key store is configured).
    scopes: HashSet<Scope>,
    /// Id of the authenticated key; `None` when no key store is configured.
    key_id: Option<String>,
    /// The key's account allow-list, if it has one.
    allowed_acc_ids: Option<HashSet<u64>>,
    /// The key's trading-market allow-list, if it has one.
    allowed_markets: Option<HashSet<String>>,
    /// REST account subscriptions keyed by key id. They are read for trade pushes
    /// to build this caller's subscription state.
    rest_acc_subscriptions: Arc<RwLock<HashMap<String, HashSet<u64>>>>,
    /// Filter chain that decides whether a trade push is dropped for this caller.
    filter_registry: Arc<futu_auth_pipeline::FilterRegistry>,
}
389
390async fn handle_ws_connection(socket: WebSocket, ctx: WsConnectionContext) {
391 let WsConnectionContext {
392 broadcaster,
393 scopes,
394 key_id,
395 allowed_acc_ids,
396 allowed_markets,
397 rest_acc_subscriptions,
398 filter_registry,
399 } = ctx;
400
401 let (mut ws_tx, mut ws_rx) = socket.split();
402 let mut push_rx = broadcaster.subscribe();
403
404 tracing::info!(
405 key_id = ?key_id,
406 scopes = ?scopes,
407 "WebSocket push client connected"
408 );
409
410 let notify_subscribed = Arc::new(std::sync::atomic::AtomicBool::new(false));
421 let notify_subscribed_for_send = Arc::clone(¬ify_subscribed);
422 let notify_subscribed_for_recv = Arc::clone(¬ify_subscribed);
423
424 let send_scopes = scopes.clone();
426 let send_key_id_str = key_id.clone().unwrap_or_else(|| "<none>".to_string());
427 let send_key_id_for_filter = key_id.clone();
428 let rest_subs_for_filter = Arc::clone(&rest_acc_subscriptions);
429 let mut send_task = tokio::spawn(async move {
430 loop {
431 let event = match push_rx.recv().await {
432 Ok(event) => event,
433 Err(broadcast::error::RecvError::Lagged(n)) => {
434 tracing::warn!(
435 skipped = n,
436 "REST WebSocket push client lagged, skipped events"
437 );
438 continue;
439 }
440 Err(broadcast::error::RecvError::Closed) => break,
441 };
442 if !send_scopes.contains(&event.required_scope.required_scope()) {
444 futu_auth::metrics::bump_ws_filtered(&event.event_type, &send_key_id_str);
446 continue;
447 }
448 if matches!(event.required_scope, WsPushScope::Notify)
451 && !notify_subscribed_for_send.load(std::sync::atomic::Ordering::Relaxed)
452 {
453 futu_auth::metrics::bump_ws_filtered("notify_unsub", &send_key_id_str);
454 continue;
455 }
456 if matches!(event.required_scope, WsPushScope::Trade)
477 && let Some(event_acc) = event.acc_id
478 {
479 let sub_state_owned: Option<HashSet<u64>> =
480 send_key_id_for_filter.as_ref().and_then(|kid| {
481 crate::adapter::with_rest_acc_subscriptions_read(
482 &rest_subs_for_filter,
483 |subs| subs.get(kid).cloned(),
484 )
485 });
486 let ctx = futu_auth_pipeline::PushEventCtx {
487 event_type: &event.event_type,
488 event_acc: Some(event_acc),
489 allowed_acc_ids: allowed_acc_ids.as_ref(),
490 sub_state: sub_state_owned.as_ref(),
491 event_trd_market: event.trd_market.as_deref(),
496 allowed_markets: allowed_markets.as_ref(),
497 };
498 if filter_registry.should_drop_event(&ctx) {
499 futu_auth::metrics::bump_ws_filtered("trade_market", &send_key_id_str);
504 continue;
505 }
506 }
507 let json = match serde_json::to_string(&event) {
508 Ok(j) => j,
509 Err(_) => continue,
510 };
511 if ws_tx.send(Message::Text(json.into())).await.is_err() {
512 break; }
514 }
515 });
516
517 let mut recv_task = tokio::spawn(async move {
519 while let Some(msg) = ws_rx.next().await {
520 match msg {
521 Ok(Message::Close(_)) | Err(_) => break,
522 Ok(Message::Ping(_data)) => {
523 }
525 Ok(Message::Text(text)) => {
529 if let Ok(val) = serde_json::from_str::<serde_json::Value>(&text)
530 && let Some(action) = val.get("action").and_then(|v| v.as_str())
531 {
532 match action {
533 "subscribe-notify" => {
534 notify_subscribed_for_recv
535 .store(true, std::sync::atomic::Ordering::Relaxed);
536 tracing::info!("WS client subscribed notify push");
537 }
538 "unsubscribe-notify" => {
539 notify_subscribed_for_recv
540 .store(false, std::sync::atomic::Ordering::Relaxed);
541 tracing::info!("WS client unsubscribed notify push");
542 }
543 other => {
544 tracing::debug!(action = %other, "WS client unknown action");
545 }
546 }
547 }
548 }
549 _ => {} }
551 }
552 });
553
554 tokio::select! {
558 _ = &mut send_task => {
559 recv_task.abort();
560 }
561 _ = &mut recv_task => {
562 send_task.abort();
563 }
564 }
565
566 tracing::info!("WebSocket push client disconnected");
567}
568
/// Test-only statement of the account-level drop rule for a trade push.
///
/// An event for account `event_acc` is dropped when either condition holds:
/// - `allowed_acc_ids` is a non-empty set that does not contain `event_acc`.
///   An empty or absent allow-list does not restrict.
/// - `sub_state` is present and does not contain `event_acc`. A present but
///   empty subscription set therefore drops every account.
#[cfg(test)]
pub(crate) fn should_drop_trade_event_for_caller(
    allowed_acc_ids: Option<&HashSet<u64>>,
    sub_state: Option<&HashSet<u64>>,
    event_acc: u64,
) -> bool {
    let outside_allow_list = allowed_acc_ids
        .is_some_and(|allowed| !allowed.is_empty() && !allowed.contains(&event_acc));
    let not_subscribed = sub_state.is_some_and(|subscribed| !subscribed.contains(&event_acc));
    outside_allow_list || not_subscribed
}
609
// Unit tests for this module live in an out-of-line `tests` submodule and are
// compiled only for test builds.
#[cfg(test)]
mod tests;