futu_rest/
ws.rs

1//! WebSocket 推送模块
2//!
3//! 在 REST API 端口上提供 /ws 路由,客户端通过 WebSocket 接收实时推送。
4//!
5//! 推送事件通过 broadcast channel 从 OpenD 核心分发到所有 WebSocket 客户端。
6
7use std::collections::{HashMap, HashSet};
8use std::sync::Arc;
9
10use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
11use axum::extract::{Query, State};
12use axum::http::{HeaderMap, StatusCode};
13use axum::response::IntoResponse;
14use chrono::Utc;
15use futures::{SinkExt, StreamExt};
16use tokio::sync::broadcast;
17
18use futu_auth::{KeyRecord, KeyStore, Scope};
19use futu_server::push::ExternalPushSink;
20
21use crate::adapter::RestState;
22
23/// WebSocket 推送事件
24#[derive(Clone, Debug, serde::Serialize)]
25pub struct WsPushEvent {
26    /// 推送类型: "quote", "trade", "notify"
27    #[serde(rename = "type")]
28    pub event_type: String,
29    /// 该事件需要哪个 scope 才能被某个 client 接收(filter 用,不发到客户端)
30    #[serde(skip)]
31    pub required_scope: WsPushScope,
32    /// 协议 ID
33    pub proto_id: u32,
34    /// 证券标识 (行情推送)
35    #[serde(skip_serializing_if = "Option::is_none")]
36    pub sec_key: Option<String>,
37    /// 订阅类型 (行情推送)
38    #[serde(skip_serializing_if = "Option::is_none")]
39    pub sub_type: Option<i32>,
40    /// 交易账户 ID (交易推送)
41    #[serde(skip_serializing_if = "Option::is_none")]
42    pub acc_id: Option<u64>,
43    /// protobuf body 的 base64 编码
44    pub body_b64: String,
45}
46
/// Category of a WebSocket push event, used to derive the minimum scope a
/// client needs in order to receive it.
///
/// - `Quote`  -> `qot:read`: market data.
/// - `Notify` -> `qot:read`: general notifications.
/// - `Trade`  -> `acc:read`: trade updates carry account data, so account
///   read permission is required.
///
/// The default category is `Quote`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum WsPushScope {
    /// Market-data push.
    #[default]
    Quote,
    /// General notification push.
    Notify,
    /// Trade push tied to an account.
    Trade,
}
59
60impl WsPushScope {
61    /// 该事件类型需要的 Scope;client 必须持有这个 scope 才能收到
62    pub fn required_scope(&self) -> Scope {
63        match self {
64            WsPushScope::Quote => Scope::QotRead,
65            WsPushScope::Notify => Scope::QotRead,
66            WsPushScope::Trade => Scope::AccRead,
67        }
68    }
69}
70
71/// WebSocket 推送广播器
72///
73/// OpenD 核心推送事件 → broadcast channel → 所有 WebSocket 客户端
74///
75/// 实现 `ExternalPushSink` trait,可直接嵌入 PushDispatcher。
76#[derive(Clone)]
77pub struct WsBroadcaster {
78    tx: broadcast::Sender<WsPushEvent>,
79}
80
81impl WsBroadcaster {
82    pub fn new(capacity: usize) -> Self {
83        let (tx, _) = broadcast::channel(capacity);
84        Self { tx }
85    }
86
87    /// 发送推送事件到所有 WebSocket 客户端
88    pub fn send(&self, event: WsPushEvent) {
89        // 忽略没有接收者的情况
90        let _ = self.tx.send(event);
91    }
92
93    /// 创建接收端
94    pub fn subscribe(&self) -> broadcast::Receiver<WsPushEvent> {
95        self.tx.subscribe()
96    }
97
98    fn encode_body(body: &[u8]) -> String {
99        use base64::Engine;
100        base64::engine::general_purpose::STANDARD.encode(body)
101    }
102
103    /// 发送行情推送
104    pub fn push_quote(&self, sec_key: &str, sub_type: i32, proto_id: u32, body: &[u8]) {
105        self.send(WsPushEvent {
106            event_type: "quote".to_string(),
107            required_scope: WsPushScope::Quote,
108            proto_id,
109            sec_key: Some(sec_key.to_string()),
110            sub_type: Some(sub_type),
111            acc_id: None,
112            body_b64: Self::encode_body(body),
113        });
114    }
115
116    /// 发送广播推送
117    pub fn push_broadcast(&self, proto_id: u32, body: &[u8]) {
118        self.send(WsPushEvent {
119            event_type: "notify".to_string(),
120            required_scope: WsPushScope::Notify,
121            proto_id,
122            sec_key: None,
123            sub_type: None,
124            acc_id: None,
125            body_b64: Self::encode_body(body),
126        });
127    }
128
129    /// 发送交易推送
130    pub fn push_trade(&self, acc_id: u64, proto_id: u32, body: &[u8]) {
131        self.send(WsPushEvent {
132            event_type: "trade".to_string(),
133            required_scope: WsPushScope::Trade,
134            proto_id,
135            sec_key: None,
136            sub_type: None,
137            acc_id: Some(acc_id),
138            body_b64: Self::encode_body(body),
139        });
140    }
141}
142
143/// 实现 ExternalPushSink,使 WsBroadcaster 可嵌入 PushDispatcher
144impl ExternalPushSink for WsBroadcaster {
145    fn on_quote_push(&self, sec_key: &str, sub_type: i32, proto_id: u32, body: &[u8]) {
146        self.push_quote(sec_key, sub_type, proto_id, body);
147    }
148
149    fn on_broadcast_push(&self, proto_id: u32, body: &[u8]) {
150        self.push_broadcast(proto_id, body);
151    }
152
153    fn on_trade_push(&self, acc_id: u64, proto_id: u32, body: &[u8]) {
154        self.push_trade(acc_id, proto_id, body);
155    }
156}
157
158/// WebSocket 握手鉴权:从 `?token=xxx` 查询参数或 `Authorization: Bearer` header 提取 token
159///
160/// 浏览器 WebSocket API 不允许设置自定义 header,所以优先支持 `?token=`;
161/// 原生客户端(curl / websocat / tokio-tungstenite)可以用任一方式。
162fn extract_ws_token(headers: &HeaderMap, query: &HashMap<String, String>) -> Option<String> {
163    if let Some(t) = query.get("token") {
164        return Some(t.clone());
165    }
166    headers
167        .get("authorization")
168        .and_then(|v| v.to_str().ok())
169        .and_then(|v| v.strip_prefix("Bearer ").map(|s| s.trim().to_string()))
170}
171
172/// 校验 WebSocket 握手的 token;返回 `Ok(Some(rec))` 表示 scope 模式 + 通过;
173/// `Ok(None)` 表示 legacy 模式(未配 KeyStore),所有事件无条件放行。
174///
175/// - `key_store.is_configured() == false` → 无条件放行(legacy 模式)
176/// - 配置了 KeyStore:必须有 token,且 key 有 `qot:read` scope(最低门槛,
177///   实际收哪些事件由后续 push filter 按 scope 决定)
178fn authenticate_ws(
179    key_store: &KeyStore,
180    headers: &HeaderMap,
181    query: &HashMap<String, String>,
182) -> Result<Option<Arc<KeyRecord>>, (StatusCode, &'static str)> {
183    if !key_store.is_configured() {
184        return Ok(None);
185    }
186
187    let Some(token) = extract_ws_token(headers, query) else {
188        futu_auth::audit::reject(
189            "ws",
190            "/ws",
191            "<missing>",
192            "missing token (query or Authorization)",
193        );
194        return Err((StatusCode::UNAUTHORIZED, "missing api key"));
195    };
196
197    let Some(rec) = key_store.verify(&token) else {
198        futu_auth::audit::reject("ws", "/ws", "<invalid>", "invalid api key");
199        return Err((StatusCode::UNAUTHORIZED, "invalid api key"));
200    };
201
202    if rec.is_expired(Utc::now()) {
203        futu_auth::audit::reject("ws", "/ws", &rec.id, "key expired");
204        return Err((StatusCode::UNAUTHORIZED, "key expired"));
205    }
206
207    if !rec.scopes.contains(&Scope::QotRead) {
208        futu_auth::audit::reject("ws", "/ws", &rec.id, "missing qot:read scope");
209        return Err((StatusCode::FORBIDDEN, "missing qot:read scope"));
210    }
211
212    futu_auth::audit::allow("ws", "/ws", &rec.id, Some("qot:read"));
213    Ok(Some(rec))
214}
215
216/// WebSocket 升级处理
217pub async fn ws_handler(
218    ws: WebSocketUpgrade,
219    headers: HeaderMap,
220    Query(query): Query<HashMap<String, String>>,
221    State(state): State<RestState>,
222) -> impl IntoResponse {
223    let rec = match authenticate_ws(&state.key_store, &headers, &query) {
224        Ok(rec) => rec,
225        Err((code, msg)) => return (code, msg).into_response(),
226    };
227    // legacy(rec=None)时给个"全 scope"快照让 filter 全放行;scope 模式用 rec.scopes
228    let scopes: HashSet<Scope> = match &rec {
229        Some(r) => r.scopes.clone(),
230        None => all_scopes(),
231    };
232    let key_id = rec.as_ref().map(|r| r.id.clone());
233    let broadcaster = Arc::clone(&state.ws_broadcaster);
234    ws.on_upgrade(move |socket| handle_ws_connection(socket, broadcaster, scopes, key_id))
235        .into_response()
236}
237
238/// 全 scope 集合(legacy 模式用)
239fn all_scopes() -> HashSet<Scope> {
240    [
241        Scope::QotRead,
242        Scope::AccRead,
243        Scope::TradeSimulate,
244        Scope::TradeReal,
245    ]
246    .into_iter()
247    .collect()
248}
249
250/// 处理单个 WebSocket 连接
251///
252/// `scopes` 是该连接 key 持有的 scope 集合,用于按 `WsPushScope::required_scope()`
253/// 过滤推送事件。例如只有 `qot:read` 的 key 不会收到 `trade` 类推送。
254async fn handle_ws_connection(
255    socket: WebSocket,
256    broadcaster: Arc<WsBroadcaster>,
257    scopes: HashSet<Scope>,
258    key_id: Option<String>,
259) {
260    let (mut ws_tx, mut ws_rx) = socket.split();
261    let mut push_rx = broadcaster.subscribe();
262
263    tracing::info!(
264        key_id = ?key_id,
265        scopes = ?scopes,
266        "WebSocket push client connected"
267    );
268
269    // 推送任务:从 broadcast channel 读取事件 → 按 scope 过滤 → 发送给客户端
270    let send_scopes = scopes.clone();
271    let send_key_id = key_id.clone().unwrap_or_else(|| "<none>".to_string());
272    let send_task = tokio::spawn(async move {
273        while let Ok(event) = push_rx.recv().await {
274            // 按 client scope 过滤:key 没这个 scope 就不发
275            if !send_scopes.contains(&event.required_scope.required_scope()) {
276                // 记一次"被挡住的推送",供 Prometheus `/metrics` 观察
277                futu_auth::metrics::bump_ws_filtered(&event.event_type, &send_key_id);
278                continue;
279            }
280            let json = match serde_json::to_string(&event) {
281                Ok(j) => j,
282                Err(_) => continue,
283            };
284            if ws_tx.send(Message::Text(json.into())).await.is_err() {
285                break; // 客户端断开
286            }
287        }
288    });
289
290    // 接收任务:处理客户端消息(ping/pong/close)
291    let recv_task = tokio::spawn(async move {
292        while let Some(msg) = ws_rx.next().await {
293            match msg {
294                Ok(Message::Close(_)) | Err(_) => break,
295                Ok(Message::Ping(data)) => {
296                    // axum 自动回复 pong,不需要手动处理
297                    let _ = data;
298                }
299                _ => {} // 忽略其他消息
300            }
301        }
302    });
303
304    // 任一任务结束则关闭连接
305    tokio::select! {
306        _ = send_task => {}
307        _ = recv_task => {}
308    }
309
310    tracing::info!("WebSocket push client disconnected");
311}
312
#[cfg(test)]
mod tests {
    use super::*;

    /// Every push category, in declaration order.
    const ALL_KINDS: [WsPushScope; 3] =
        [WsPushScope::Quote, WsPushScope::Notify, WsPushScope::Trade];

    /// Mirrors the connection filter: is `kind` deliverable to a client
    /// holding `granted`?
    fn allows(granted: &HashSet<Scope>, kind: WsPushScope) -> bool {
        granted.contains(&kind.required_scope())
    }

    /// A `qot:read`-only client receives Quote and Notify (both require
    /// `QotRead`) but must not receive Trade pushes.
    #[test]
    fn scope_filter_blocks_trade_for_qot_only_client() {
        let granted = HashSet::from([Scope::QotRead]);

        assert!(allows(&granted, WsPushScope::Quote));
        assert!(allows(&granted, WsPushScope::Notify));
        assert!(!allows(&granted, WsPushScope::Trade));
    }

    /// A client with `qot:read + acc:read` receives all three categories.
    #[test]
    fn scope_filter_allows_all_for_qot_plus_acc() {
        let granted = HashSet::from([Scope::QotRead, Scope::AccRead]);

        for kind in ALL_KINDS {
            assert!(allows(&granted, kind), "{:?} should be allowed", kind);
        }
    }

    /// Legacy mode (full scope set) receives all three categories.
    #[test]
    fn legacy_all_scopes_allows_everything() {
        let granted = all_scopes();

        for kind in ALL_KINDS {
            assert!(allows(&granted, kind));
        }
    }

    /// Each `push_*` helper stamps the event with the matching `event_type`
    /// string and `required_scope` category.
    #[test]
    fn event_type_matches_scope_category() {
        let broadcaster = WsBroadcaster::new(4);
        let mut rx = broadcaster.subscribe();

        broadcaster.push_quote("HK.00700", 1, 0, b"x");
        broadcaster.push_broadcast(0, b"x");
        broadcaster.push_trade(42, 0, b"x");

        let expected = [
            ("quote", WsPushScope::Quote),
            ("notify", WsPushScope::Notify),
            ("trade", WsPushScope::Trade),
        ];
        for (event_type, scope) in expected {
            let event = rx.try_recv().unwrap();
            assert_eq!(event.event_type, event_type);
            assert_eq!(event.required_scope, scope);
        }
    }
}