futu_server/
ws_listener.rs

1// WebSocket 监听器:接受 WebSocket 连接,复用 TCP 的请求路由和连接池
2//
3// 每个 WebSocket 二进制消息 = 一个完整的 FutuAPI 帧(44 字节帧头 + body)。
4// 与 TCP 共享同一个 connections DashMap、RequestRouter、SubscriptionManager。
5//
6// ## v1.0 鉴权
7//
8// 握手阶段(accept_hdr_async)校验 HTTP `Authorization: Bearer <token>` 或
9// `?token=<plaintext>` query —— 通过 `KeyStore::verify` 得到 `KeyRecord`,
10// 把 scope 集合和 key_id 存到 `ClientConn`。每条消息进 `ws_process_requests`
11// 时按 `futu_auth::scope_for_proto_id(proto_id)` 查所需 scope,不匹配 → 不 dispatch、
12// 记 audit reject。`trade:real` 额外跑 `check_and_commit` 过一道 rate + hours
13// 全局闸门。未注入 KeyStore(TCP listener 或 legacy 模式)→ scopes 空集被
14// 解释为"全放行",保持向后兼容。
15
16use std::collections::{HashMap, HashSet};
17use std::sync::Arc;
18use std::time::Instant;
19
20use bytes::BytesMut;
21use chrono::Utc;
22use dashmap::DashMap;
23use futures::{SinkExt, StreamExt};
24use tokio::net::TcpListener;
25use tokio::sync::mpsc;
26use tokio_tungstenite::tungstenite::handshake::server::{ErrorResponse, Request, Response};
27use tokio_tungstenite::tungstenite::http::StatusCode;
28use tokio_tungstenite::tungstenite::protocol::Message;
29
30use futu_auth::{CheckCtx, KeyRecord, KeyStore, LimitOutcome, RuntimeCounters, Scope};
31use futu_codec::frame::FutuFrame;
32use futu_codec::header::{FutuHeader, ProtoFmtType, HEADER_SIZE};
33use futu_core::proto_id;
34
35use crate::conn::{ClientConn, ConnState, DisconnectNotify, IncomingRequest};
36use crate::listener::{ServerConfig, MAX_CONNECTIONS};
37use crate::router::RequestRouter;
38
39/// WebSocket 服务端
40pub struct WsServer {
41    listen_addr: String,
42    config: ServerConfig,
43    connections: Arc<DashMap<u64, ClientConn>>,
44    router: Arc<RequestRouter>,
45    subscriptions: Option<Arc<crate::subscription::SubscriptionManager>>,
46    /// v1.0:握手时做 Bearer token 鉴权。None 或 `!is_configured()` → legacy 模式放行
47    key_store: Option<Arc<KeyStore>>,
48    /// v1.0:跨 REST / gRPC / WS 共享的限额 counters
49    counters: Option<Arc<RuntimeCounters>>,
50}
51
52impl WsServer {
53    /// 创建 WsServer,共享 TCP 的连接池、路由器、订阅管理器(无鉴权,向后兼容)
54    pub fn new(
55        listen_addr: String,
56        config: ServerConfig,
57        connections: Arc<DashMap<u64, ClientConn>>,
58        router: Arc<RequestRouter>,
59        subscriptions: Option<Arc<crate::subscription::SubscriptionManager>>,
60    ) -> Self {
61        Self::with_auth(
62            listen_addr,
63            config,
64            connections,
65            router,
66            subscriptions,
67            None,
68            None,
69        )
70    }
71
72    /// v1.0 入口:同时接入 KeyStore + 共享 RuntimeCounters 做握手鉴权和 per-message
73    /// scope / 限额检查。`key_store = None` 或未配置时保持 legacy(全放行)。
74    #[allow(clippy::too_many_arguments)]
75    pub fn with_auth(
76        listen_addr: String,
77        config: ServerConfig,
78        connections: Arc<DashMap<u64, ClientConn>>,
79        router: Arc<RequestRouter>,
80        subscriptions: Option<Arc<crate::subscription::SubscriptionManager>>,
81        key_store: Option<Arc<KeyStore>>,
82        counters: Option<Arc<RuntimeCounters>>,
83    ) -> Self {
84        Self {
85            listen_addr,
86            config,
87            connections,
88            router,
89            subscriptions,
90            key_store,
91            counters,
92        }
93    }
94
95    /// 启动 WebSocket 服务端监听
96    pub async fn run(&self) -> anyhow::Result<()> {
97        let listener = TcpListener::bind(&self.listen_addr).await?;
98        tracing::info!(addr = %self.listen_addr, "WebSocket server listening");
99
100        let (req_tx, req_rx) = mpsc::unbounded_channel::<IncomingRequest>();
101        let (disconnect_tx, mut disconnect_rx) = mpsc::unbounded_channel::<DisconnectNotify>();
102
103        // 启动请求处理任务(与 TCP 共享同一逻辑)
104        let connections = Arc::clone(&self.connections);
105        let router = Arc::clone(&self.router);
106        let config = self.config.clone();
107        let counters_for_process = self.counters.clone();
108        let key_store_for_process = self.key_store.clone();
109        let scope_mode = self.key_store.as_ref().is_some_and(|ks| ks.is_configured());
110        tokio::spawn(async move {
111            ws_process_requests(
112                req_rx,
113                connections,
114                router,
115                config,
116                counters_for_process,
117                key_store_for_process,
118                scope_mode,
119            )
120            .await;
121        });
122
123        // 启动连接清理任务
124        let cleanup_connections = Arc::clone(&self.connections);
125        let cleanup_subs = self.subscriptions.clone();
126        tokio::spawn(async move {
127            while let Some(notify) = disconnect_rx.recv().await {
128                let removed = cleanup_connections.remove(&notify.conn_id);
129                if removed.is_some() {
130                    if let Some(ref subs) = cleanup_subs {
131                        subs.on_disconnect(notify.conn_id);
132                    }
133                    tracing::info!(
134                        conn_id = notify.conn_id,
135                        remaining = cleanup_connections.len(),
136                        "ws connection removed from pool"
137                    );
138                }
139            }
140        });
141
142        // 接受连接循环
143        let connections = Arc::clone(&self.connections);
144        let key_store_accept = self.key_store.clone();
145        if !scope_mode {
146            tracing::warn!(
147                "WS server running WITHOUT API key auth (legacy mode); \
148                 all WS clients are open. Pass KeyStore via with_auth() to enable."
149            );
150        }
151        loop {
152            let (stream, peer_addr) = listener.accept().await?;
153
154            if connections.len() >= MAX_CONNECTIONS {
155                tracing::warn!(
156                    peer = %peer_addr,
157                    "max connections reached ({}), rejecting ws client",
158                    MAX_CONNECTIONS,
159                );
160                drop(stream);
161                continue;
162            }
163
164            let conn_id = ClientConn::generate_conn_id();
165            let aes_key = ClientConn::generate_aes_key();
166            stream.set_nodelay(true).ok();
167
168            tracing::info!(
169                conn_id = conn_id,
170                peer = %peer_addr,
171                total = connections.len() + 1,
172                "ws client connected"
173            );
174
175            let (tx, authed) = run_ws_connection(
176                stream,
177                conn_id,
178                aes_key,
179                req_tx.clone(),
180                disconnect_tx.clone(),
181                key_store_accept.clone(),
182            )
183            .await;
184
185            // 握手鉴权失败 → run_ws_connection 已经 drop 连接;这里什么都不做
186            let Some(authed) = authed else {
187                continue;
188            };
189
190            let (key_id, scopes) = match authed {
191                AuthResult::Authenticated(rec) => (Some(rec.id.clone()), rec.scopes.clone()),
192                AuthResult::Legacy => (None, HashSet::new()),
193            };
194
195            let conn = ClientConn {
196                conn_id,
197                state: ConnState::Connected,
198                aes_key,
199                aes_encrypt_enabled: false,
200                proto_fmt_type: ProtoFmtType::Protobuf,
201                last_keepalive: Instant::now(),
202                keepalive_count: std::sync::atomic::AtomicU32::new(0),
203                tx,
204                key_id,
205                scopes,
206            };
207
208            connections.insert(conn_id, conn);
209        }
210    }
211}
212
213/// 握手鉴权结果:scope 模式下的 KeyRecord 或 legacy 全放行
214enum AuthResult {
215    Authenticated(Arc<KeyRecord>),
216    Legacy,
217}
218
219/// 运行单个 WebSocket 连接的收发循环
220///
221/// 接收端:WS Binary Message → 解析为 FutuFrame → 发到 req_tx
222/// 发送端:frame_rx 收到 FutuFrame → 编码为字节 → 发送 WS Binary Message
223///
224/// 返回 `(frame_tx, Option<AuthResult>)` —— `None` 表示握手 / 鉴权失败,调用方
225/// 不应把这个 conn_id 加入连接池。
226async fn run_ws_connection(
227    stream: tokio::net::TcpStream,
228    conn_id: u64,
229    _aes_key: [u8; 16],
230    req_tx: mpsc::UnboundedSender<IncomingRequest>,
231    disconnect_tx: mpsc::UnboundedSender<DisconnectNotify>,
232    key_store: Option<Arc<KeyStore>>,
233) -> (mpsc::Sender<FutuFrame>, Option<AuthResult>) {
234    let (frame_tx, mut frame_rx) = mpsc::channel::<FutuFrame>(256);
235
236    // 握手回调里验证 token + scope,结果存到 authed_slot;accept_hdr_async 内部
237    // 只读 callback 的 Ok/Err 决定要不要升级,所以 KeyRecord 要借 Mutex 传出来
238    let authed_slot: Arc<std::sync::Mutex<Option<AuthResult>>> =
239        Arc::new(std::sync::Mutex::new(None));
240    let slot_cb = Arc::clone(&authed_slot);
241    let store_cb = key_store.clone();
242
243    #[allow(clippy::result_large_err)] // ErrorResponse 是 tungstenite 的类型,我们无法改其大小
244    let callback = move |req: &Request, resp: Response| -> Result<Response, ErrorResponse> {
245        // legacy:未注入 KeyStore 或未配置 keys.json → 全放行
246        let Some(store) = store_cb.as_ref() else {
247            *slot_cb.lock().unwrap() = Some(AuthResult::Legacy);
248            return Ok(resp);
249        };
250        if !store.is_configured() {
251            *slot_cb.lock().unwrap() = Some(AuthResult::Legacy);
252            return Ok(resp);
253        }
254
255        // 提取 token:?token=... 或 Authorization: Bearer ...
256        let token = extract_ws_token(req);
257        let Some(token) = token else {
258            futu_auth::audit::reject("ws", "/ws", "<missing>", "missing token");
259            return Err(make_err_response(
260                StatusCode::UNAUTHORIZED,
261                "missing api key (use ?token=... or Authorization: Bearer ...)",
262            ));
263        };
264
265        let Some(rec) = store.verify(&token) else {
266            futu_auth::audit::reject("ws", "/ws", "<invalid>", "invalid api key");
267            return Err(make_err_response(
268                StatusCode::UNAUTHORIZED,
269                "invalid api key",
270            ));
271        };
272
273        if rec.is_expired(Utc::now()) {
274            futu_auth::audit::reject("ws", "/ws", &rec.id, "key expired");
275            return Err(make_err_response(StatusCode::UNAUTHORIZED, "key expired"));
276        }
277
278        // 最低门槛:qot:read。真正按 proto_id 的细粒度检查留给 process_requests
279        if !rec.scopes.contains(&Scope::QotRead) {
280            futu_auth::audit::reject("ws", "/ws", &rec.id, "missing qot:read");
281            return Err(make_err_response(
282                StatusCode::FORBIDDEN,
283                "missing qot:read scope",
284            ));
285        }
286
287        futu_auth::audit::allow("ws", "/ws", &rec.id, Some("qot:read"));
288        *slot_cb.lock().unwrap() = Some(AuthResult::Authenticated(rec));
289        Ok(resp)
290    };
291
292    let ws_stream = match tokio_tungstenite::accept_hdr_async(stream, callback).await {
293        Ok(ws) => ws,
294        Err(e) => {
295            tracing::warn!(conn_id = conn_id, error = %e, "ws handshake failed");
296            let _ = disconnect_tx.send(DisconnectNotify { conn_id });
297            return (frame_tx, None);
298        }
299    };
300
301    // 握手成功 → slot_cb 已填,否则是严重 bug
302    let authed = authed_slot
303        .lock()
304        .unwrap()
305        .take()
306        .expect("authed_slot must be filled after successful handshake");
307
308    let (mut ws_sink, mut ws_stream_rx) = ws_stream.split();
309
310    // 发送任务:FutuFrame → 编码为字节 → WS Binary
311    tokio::spawn(async move {
312        while let Some(frame) = frame_rx.recv().await {
313            let mut buf = BytesMut::new();
314            frame.header.encode(&mut buf);
315            buf.extend_from_slice(&frame.body);
316            let msg = Message::Binary(buf.freeze().into());
317            if let Err(e) = ws_sink.send(msg).await {
318                tracing::warn!(conn_id = conn_id, error = %e, "ws send failed");
319                break;
320            }
321        }
322    });
323
324    // 接收任务:WS Binary → 解析 FutuFrame → req_tx
325    tokio::spawn(async move {
326        while let Some(result) = ws_stream_rx.next().await {
327            match result {
328                Ok(msg) => {
329                    let data = match msg {
330                        Message::Binary(data) => data,
331                        Message::Close(_) => {
332                            tracing::info!(conn_id = conn_id, "ws client sent close");
333                            break;
334                        }
335                        Message::Ping(_) | Message::Pong(_) => {
336                            // tokio-tungstenite 自动处理 ping/pong
337                            continue;
338                        }
339                        _ => {
340                            // 忽略 Text 等其他类型
341                            continue;
342                        }
343                    };
344
345                    // 解析 FutuAPI 帧(44 字节帧头 + body)
346                    if data.len() < HEADER_SIZE {
347                        tracing::warn!(
348                            conn_id = conn_id,
349                            len = data.len(),
350                            "ws message too short for futu header"
351                        );
352                        continue;
353                    }
354
355                    let header_buf = BytesMut::from(&data[..]);
356                    let header = match FutuHeader::peek(&header_buf) {
357                        Ok(Some(h)) => h,
358                        Ok(None) => {
359                            tracing::warn!(conn_id = conn_id, "ws header peek returned None");
360                            continue;
361                        }
362                        Err(e) => {
363                            tracing::warn!(conn_id = conn_id, error = %e, "ws invalid futu header");
364                            continue;
365                        }
366                    };
367
368                    let expected_len = HEADER_SIZE + header.body_len as usize;
369                    if data.len() < expected_len {
370                        tracing::warn!(
371                            conn_id = conn_id,
372                            expected = expected_len,
373                            actual = data.len(),
374                            "ws message shorter than expected frame size"
375                        );
376                        continue;
377                    }
378
379                    let body = bytes::Bytes::copy_from_slice(&data[HEADER_SIZE..expected_len]);
380
381                    let req = IncomingRequest {
382                        conn_id,
383                        proto_id: header.proto_id,
384                        serial_no: header.serial_no,
385                        proto_fmt_type: header.proto_fmt_type,
386                        body,
387                    };
388
389                    if req_tx.send(req).is_err() {
390                        break;
391                    }
392                }
393                Err(e) => {
394                    tracing::warn!(conn_id = conn_id, error = %e, "ws recv error");
395                    break;
396                }
397            }
398        }
399        tracing::info!(conn_id = conn_id, "ws connection closed");
400        let _ = disconnect_tx.send(DisconnectNotify { conn_id });
401    });
402
403    (frame_tx, Some(authed))
404}
405
406/// 从握手 Request 里提取 token:优先 ?token=...,再 Authorization: Bearer ...
407///
408/// 浏览器 WS API 不允许设置自定义 header,所以优先支持 query;原生客户端
409/// (curl / Futu SDK 之类)两种都可以
410fn extract_ws_token(req: &Request) -> Option<String> {
411    if let Some(q) = req.uri().query() {
412        // 手写 query 解析,避免引 url / percent-encoding 新依赖
413        let params: HashMap<&str, &str> =
414            q.split('&').filter_map(|kv| kv.split_once('=')).collect();
415        if let Some(v) = params.get("token") {
416            if !v.is_empty() {
417                return Some((*v).to_string());
418            }
419        }
420    }
421    req.headers()
422        .get("authorization")
423        .and_then(|v| v.to_str().ok())
424        .and_then(|v| v.strip_prefix("Bearer ").map(|s| s.trim().to_string()))
425        .filter(|s| !s.is_empty())
426}
427
428/// 构造 tungstenite 握手失败的 HTTP 响应
429fn make_err_response(code: StatusCode, msg: &str) -> ErrorResponse {
430    let body = Some(format!(r#"{{"error":"{msg}"}}"#));
431    let mut resp = tokio_tungstenite::tungstenite::http::Response::new(body);
432    *resp.status_mut() = code;
433    resp.headers_mut().insert(
434        "content-type",
435        tokio_tungstenite::tungstenite::http::HeaderValue::from_static("application/json"),
436    );
437    resp
438}
439
440/// 处理 WebSocket 连接的请求(逻辑与 TCP 的 process_requests 相同,额外做 scope / 限额)
441///
442/// `scope_mode`:KeyStore 配置了时为 true,此时按 proto_id 查 scope 并与
443/// 连接里存的 scopes 比对;未匹配直接丢弃请求 + 写 audit reject,不发响应。
444/// `counters`:trade:real 请求通过 scope 后再跑一次 check_and_commit。
445async fn ws_process_requests(
446    mut req_rx: mpsc::UnboundedReceiver<IncomingRequest>,
447    connections: Arc<DashMap<u64, ClientConn>>,
448    router: Arc<RequestRouter>,
449    config: ServerConfig,
450    counters: Option<Arc<RuntimeCounters>>,
451    key_store: Option<Arc<KeyStore>>,
452    scope_mode: bool,
453) {
454    use crate::listener::ApiServer;
455
456    while let Some(mut req) = req_rx.recv().await {
457        let conn_id = req.conn_id;
458        let proto_id_val = req.proto_id;
459        let serial_no = req.serial_no;
460
461        // 更新 last_keepalive(任何包都算活跃)
462        if let Some(mut conn) = connections.get_mut(&conn_id) {
463            conn.last_keepalive = Instant::now();
464        }
465
466        // v1.0 per-message scope gate —— 只在 scope_mode 启用
467        if scope_mode {
468            if let Some(needed) = futu_auth::scope_for_proto_id(proto_id_val) {
469                // 取该连接的 scope 集合和 key_id 快照(拿完锁就释放)
470                let (scopes, key_id_snap) = match connections.get(&conn_id) {
471                    Some(conn) => (conn.scopes.clone(), conn.key_id.clone()),
472                    None => {
473                        tracing::warn!(conn_id, proto_id = proto_id_val, "ws req on unknown conn");
474                        continue;
475                    }
476                };
477                let key_id_str = key_id_snap.as_deref().unwrap_or("<none>");
478                // 从 KeyStore 按 id 查最新 limits(SIGHUP reload 后能立刻生效,
479                // 对齐 MCP 的 SIGHUP-aware 行为)
480                let limits_snap = match (&key_store, &key_id_snap) {
481                    (Some(ks), Some(id)) => {
482                        ks.get_by_id(id).map(|r| r.limits()).unwrap_or_default()
483                    }
484                    _ => futu_auth::Limits::default(),
485                };
486                if !scopes.contains(&needed) {
487                    futu_auth::audit::reject(
488                        "ws",
489                        &format!("proto_id={proto_id_val}"),
490                        key_id_str,
491                        &format!("missing scope {needed}"),
492                    );
493                    continue;
494                }
495                // trade:real 多跑一次 rate + hours 全局闸门
496                if needed == Scope::TradeReal {
497                    if let Some(c) = &counters {
498                        let ctx = CheckCtx {
499                            market: String::new(),
500                            symbol: String::new(),
501                            order_value: None,
502                            trd_side: None,
503                        };
504                        if let LimitOutcome::Reject(reason) =
505                            c.check_and_commit(key_id_str, &limits_snap, &ctx, Utc::now())
506                        {
507                            futu_auth::audit::reject(
508                                "ws",
509                                &format!("proto_id={proto_id_val}"),
510                                key_id_str,
511                                &format!("limit: {reason}"),
512                            );
513                            continue;
514                        }
515                    }
516                }
517                futu_auth::audit::allow(
518                    "ws",
519                    &format!("proto_id={proto_id_val}"),
520                    key_id_str,
521                    Some(needed.as_str()),
522                );
523            }
524            // scope_for_proto_id 返回 None(1xxx 系统)→ 放行不 audit
525        }
526
527        // 非 InitConnect 请求需要 AES 解密
528        if proto_id_val != proto_id::INIT_CONNECT {
529            if let Some(conn) = connections.get(&conn_id) {
530                if conn.aes_encrypt_enabled {
531                    match conn.decrypt_body(&req.body) {
532                        Ok(decrypted) => {
533                            req.body = bytes::Bytes::from(decrypted);
534                        }
535                        Err(e) => {
536                            tracing::warn!(
537                                conn_id = conn_id,
538                                proto_id = proto_id_val,
539                                error = %e,
540                                "ws AES decrypt request failed, dropping"
541                            );
542                            continue;
543                        }
544                    }
545                }
546            }
547        }
548
549        let response_body = match proto_id_val {
550            proto_id::INIT_CONNECT => {
551                if let Some(mut conn) = connections.get_mut(&conn_id) {
552                    conn.handle_init_connect(
553                        &req.body,
554                        config.server_ver,
555                        config.login_user_id,
556                        config.keepalive_interval,
557                        config.rsa_private_key.as_deref(),
558                    )
559                    .ok()
560                } else {
561                    None
562                }
563            }
564            proto_id::KEEP_ALIVE => {
565                if let Some(conn) = connections.get(&conn_id) {
566                    conn.handle_keepalive(&req.body).ok()
567                } else {
568                    None
569                }
570            }
571            _ => router.dispatch(conn_id, &req).await,
572        };
573
574        if let Some(body) = response_body {
575            ApiServer::send_response(&connections, conn_id, proto_id_val, serial_no, body).await;
576        }
577    }
578}
579
#[cfg(test)]
mod tests {
    use super::*;
    use tokio_tungstenite::tungstenite::http::Request as HttpRequest;

    /// Runs `extract_ws_token` on a handshake request for `uri`, optionally carrying an
    /// `authorization` header with the given value.
    fn token_for(uri: &str, authorization: Option<&str>) -> Option<String> {
        let builder = HttpRequest::builder().uri(uri);
        let builder = match authorization {
            Some(value) => builder.header("authorization", value),
            None => builder,
        };
        let request = builder.body(()).unwrap();
        extract_ws_token(&request)
    }

    #[test]
    fn extract_token_from_query() {
        assert_eq!(token_for("/ws?token=abc123", None).as_deref(), Some("abc123"));
    }

    #[test]
    fn extract_token_from_bearer_header() {
        assert_eq!(token_for("/ws", Some("Bearer xyz789")).as_deref(), Some("xyz789"));
    }

    #[test]
    fn extract_token_query_preferred_over_header() {
        let token = token_for("/ws?token=from-query", Some("Bearer from-header"));
        assert_eq!(token.as_deref(), Some("from-query"));
    }

    #[test]
    fn extract_token_empty_query_falls_back_to_header() {
        let token = token_for("/ws?token=", Some("Bearer from-header"));
        assert_eq!(token.as_deref(), Some("from-header"));
    }

    #[test]
    fn extract_token_missing_everywhere() {
        for uri in ["/ws", "/ws?foo=bar"] {
            assert_eq!(token_for(uri, None), None, "uri = {uri}");
        }
    }

    #[test]
    fn extract_token_non_bearer_auth_ignored() {
        assert_eq!(token_for("/ws", Some("Basic asdf")), None);
    }

    #[test]
    fn extract_token_bearer_empty_after_prefix() {
        // Nothing after "Bearer " → None (the empty string is filtered out)
        assert_eq!(token_for("/ws", Some("Bearer ")), None);
    }
}