Skip to main content

futu_server/
conn.rs

1// 单连接管理:状态机、帧收发、加密、心跳超时
2
3use std::collections::HashSet;
4use std::sync::atomic::{AtomicU32, Ordering};
5use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
6
7use bytes::Bytes;
8use futures::{SinkExt, StreamExt};
9use tokio::net::TcpStream;
10use tokio::sync::{mpsc, watch};
11use tokio_util::codec::Framed;
12
13use futu_auth::Scope;
14use futu_codec::FutuCodec;
15use futu_codec::frame::FutuFrame;
16use futu_codec::header::ProtoFmtType;
17use futu_core::error::FutuError;
18use futu_net::encrypt;
19
20/// 连接状态
21#[derive(Debug, Clone, Copy, PartialEq, Eq)]
22#[non_exhaustive]
23pub enum ConnState {
24    /// TCP / WebSocket 刚建立,尚未完成 InitConnect 握手
25    Connected,
26    /// 已完成 InitConnect,等待首次业务请求
27    Initialized,
28    /// 正常交互中(业务请求 / KeepAlive / push 均已流转)
29    Active,
30    /// 连接已断开(客户端主动关闭 / 被动超时 / IO 错误)
31    Disconnected,
32}
33
34/// 单个客户端连接
35pub struct ClientConn {
36    /// 随机连接 ID(对应 C++ `GetRand_MilliTimeAndU22`)
37    pub conn_id: u64,
38    /// 连接状态:InitConnect 前 / 后 / 已断开
39    pub state: ConnState,
40    /// 随机 AES-128 key(InitConnect 响应里下发给客户端)
41    pub aes_key: [u8; 16],
42    /// AES 加解密已启用(InitConnect 完成且配置了 RSA 时为 true)
43    pub aes_encrypt_enabled: bool,
44    /// 该连接协商的 proto 格式(Protobuf / JSON)
45    pub proto_fmt_type: ProtoFmtType,
46    /// 上次收到 KeepAlive 的时间,用于超时检查
47    pub last_keepalive: Instant,
48    /// InitConnect.C2S.recvNotify:此连接是否接收市场状态 / 交易解锁等通知。
49    ///
50    /// C++ 在 `APIServer_InitConnect.cpp` 里把该字段写入 ConnInfo;
51    /// `RegQotPush` / `Qot_Sub(isRegOrUnRegPush)` 不会修改这个开关。
52    pub recv_notify: bool,
53    /// InitConnect.C2S.aiType:AI 调用类型。C++ 10.7
54    /// `APIServer_InitConnect.cpp` 缺省为 0,并写入连接状态。
55    pub ai_type: i32,
56    /// 已收到的 KeepAlive 计数(监控用)
57    pub keepalive_count: AtomicU32,
58    /// 发送帧到此连接
59    pub tx: mpsc::Sender<FutuFrame>,
60
61    // ---- v1.0 WS per-message scope 鉴权(raw TCP legacy 兼容:scopes 空集全放行)----
62    /// 该连接绑定的 API key id;WS 握手时填,未配 keys.json 时为 None
63    pub key_id: Option<String>,
64    /// 该连接持有的 scope 集合;空集 = legacy 模式 / TCP 直连,scope 检查放行
65    pub scopes: HashSet<Scope>,
66    /// v1.4.105 D3 (Phase 4) T-B2: 该连接 caller key 的 `allowed_markets` 硬
67    /// 限额 (大写字符串 set, e.g. {"HK","US"}). `None` = 无限制 (legacy 模式
68    /// / TCP 直连默认全开 / 未配 allowed_markets); `Some(set)` 非空 → push 端
69    /// 应过滤 trd_market 不在 set 中的 trade event.
70    ///
71    /// **触发**: WS handshake 时从 `KeyRecord.allowed_markets` 拷贝过来.
72    /// `PushDispatcher::push_trd_acc` 端 Layer 3 filter 检查. 与
73    /// `caller_allowed_acc_ids` (Layer 1, per-call snapshot in IncomingRequest)
74    /// 区别: 本字段是 per-conn snapshot (handshake 时一次性), 不随 per-call
75    /// 重读 — KeyRecord SIGHUP reload 后**仅新建连接生效**, 老连接保持 snapshot
76    /// (与 `scopes` / `caller_allowed_acc_ids` 的 snapshot 语义一致).
77    pub allowed_markets: Option<std::sync::Arc<HashSet<String>>>,
78    /// codex round 1 F4 (P2) v1.4.105: 该连接 caller key 的 `allowed_acc_ids`
79    /// 硬限额 (per-conn snapshot, handshake 时一次性). `None` / `Some(empty)` =
80    /// 无限制 (legacy 模式 / TCP 直连默认全开 / 未配 allowed_acc_ids);
81    /// `Some(non-empty set)` →
82    /// `PushDispatcher::push_trd_acc` 端 push-time 硬过滤 acc_id 不在 set 中
83    /// 的 trade event (Layer 1, 与 `allowed_markets` 的 Layer 3 互补).
84    /// Deny-all 使用 sentinel `{0}`,不使用空集合。
85    ///
86    /// **触发**: codex F4 指出 raw TCP push 端只查 `acc:read` scope +
87    /// `allowed_markets`, 不查 `allowed_acc_ids`. 即使 request-time
88    /// `SubAccPushHandler` 已阻止越权订阅, stale subscription / KeyRecord
89    /// reload 后窄化的 acc 范围 / 历史 bug 留下的 conn→acc 关系 仍可能让 push
90    /// 漏 leak. 本字段提供第二层 push-time 兜底.
91    ///
92    /// 与 `caller_allowed_acc_ids` (IncomingRequest, per-call) 区别: 本字段
93    /// 在 push-time 用 (无 IncomingRequest), per-conn snapshot 与 `scopes` /
94    /// `allowed_markets` 的 snapshot 语义一致.
95    pub allowed_acc_ids: Option<std::sync::Arc<HashSet<u64>>>,
96}
97
98/// 从连接接收到的请求
99#[derive(Debug)]
100pub struct IncomingRequest {
101    /// 发送请求的连接 ID(用于响应路由 + SubscriptionManager / cache 记账)
102    ///
103    /// **跨 surface 命名空间分配**(v1.4.106 codex 0517 ζ25-redo F2 沉淀):
104    /// - raw TCP listener: `ClientConn::generate_conn_id()` 派生(u32 范围)
105    /// - REST: `crates/futu-rest/src/routes/qot.rs::REST_SHARED_CONN`
106    ///   = `0xFFFF_FFFE`(u32 上限附近, 单值共享)
107    /// - gRPC: `crates/futu-grpc/src/auth.rs::GRPC_STABLE_CONN_NAMESPACE`
108    ///   = `0x4000_0000_0000_0000`(bit 62 namespace, 按 caller 派生)
109    /// - WS / MCP: 通常派生自所属物理 TCP 连接的 conn_id
110    ///
111    /// 各 surface 不重叠. 加新 surface 时分配一个 namespace base, 不要与
112    /// 上述 4 个段重合.
113    pub conn_id: u64,
114    /// 协议 ID(对齐 C++ `NN_ProtoCmd_*`)
115    pub proto_id: u32,
116    /// 序列号(和 Response 配对,供 client 端请求-响应匹配)
117    pub serial_no: u32,
118    /// 请求体 proto 格式(Protobuf / JSON)
119    pub proto_fmt_type: ProtoFmtType,
120    /// 请求 body(已解密后的明文)
121    pub body: Bytes,
122    /// v1.4.38 Phase 4: 订单幂等 key(由 REST `Idempotency-Key` header / gRPC
123    /// metadata / WS envelope / MCP tool args 填入)。None 表示客户端未传,
124    /// handler 走无幂等直通 path(backward-compat)。
125    pub idempotency_key: Option<String>,
126    /// v1.4.106 codex 0920 F1 (P1): caller key id 副本 (per-call snapshot,
127    /// 由 surface adapter 层从 KeyRecord 读取后填入).
128    ///
129    /// **目标**: idempotency cache key namespace 必须含 caller key id, 否则
130    /// 不同 caller 用同 Idempotency-Key 会跨 caller 命中老 response —— 严重
131    /// 跨账户数据泄漏 + 重复下单 silent fail.
132    ///
133    /// `None` = 无 caller 标识 (legacy TCP / 未 auth) → namespace 用 `<no_key>`
134    /// 占位符. `Some("alice")` = WS / MCP / REST 已 auth 的 caller —— namespace
135    /// 用 `<caller_key_id="alice">`, 不与其他 caller 串.
136    pub caller_key_id: Option<String>,
137    /// v1.4.105 D2 contract-hardening 补丁: caller key 的 `allowed_acc_ids` 硬限额
138    /// 副本 (per-call snapshot, 在 surface adapter 层从 KeyRecord 读取后填入).
139    ///
140    /// **目标**: 让 dispatch-time handlers (e.g. `SubAccPushHandler` 注册 acc_id
141    /// 到 SubscriptionManager) 也能 enforce per-acc whitelist — 即使上游 pipeline
142    /// body-aware step 已 enforce, 让 handler 自己 defense-in-depth 防 future
143    /// regression (新 surface 加进来漏调 pipeline body-aware).
144    ///
145    /// `None` / `Some(empty)` = caller 无 acc_id 限制 (legacy mode 或 unrestricted
146    /// key) → handler 不 filter; `Some(non-empty set)` → handler 应 reject 不在
147    /// set 中的 acc_id. Deny-all 使用 sentinel `{0}`,不使用空集合。
148    pub caller_allowed_acc_ids: Option<std::sync::Arc<std::collections::HashSet<u64>>>,
149}
150
151impl IncomingRequest {
152    /// codex 0522 F4 v1.4.106: cross-surface 单测 hook. 构 IncomingRequest
153    /// 并填 caller scope (`caller_key_id` + `caller_allowed_acc_ids`) — 让
154    /// REST / gRPC / WS / MCP 等 surface 的 adapter 都用同一构造路径, 防
155    /// "某个 surface 漏填字段" silent regression.
156    ///
157    /// 之前 4 surface 各写一份 struct literal, 加新字段需逐个改, 漏一个就
158    /// 出现 silent None — 与坑 #54 schema-only fix 同模式 (实装符号 vs 真
159    /// 行为差距). 本 helper 是 single point, 加新字段 schema 自动 propagate.
160    ///
161    /// **注意**: 本 helper 不 take ownership of body — caller 已 own bytes.
162    /// idempotency_key / caller_key_id 接 String 而非 &str 让 caller 决定
163    /// 是 clone 还是 move.
164    pub fn builder(
165        conn_id: u64,
166        proto_id: u32,
167        serial_no: u32,
168        proto_fmt_type: ProtoFmtType,
169        body: Bytes,
170    ) -> IncomingRequestBuilder {
171        IncomingRequestBuilder {
172            request: Self {
173                conn_id,
174                proto_id,
175                serial_no,
176                proto_fmt_type,
177                body,
178                idempotency_key: None,
179                caller_allowed_acc_ids: None,
180                caller_key_id: None,
181            },
182        }
183    }
184}
185
186/// Thin builder for `IncomingRequest`.
187///
188/// The base request shape is the wire envelope; idempotency and caller scope are
189/// optional per-surface decorations. Keeping those defaults here avoids every
190/// REST / gRPC / raw WS / MCP adapter spelling out `None` independently.
191#[derive(Debug)]
192pub struct IncomingRequestBuilder {
193    request: IncomingRequest,
194}
195
196impl IncomingRequestBuilder {
197    pub fn with_idempotency_key(mut self, idempotency_key: Option<String>) -> Self {
198        self.request.idempotency_key = idempotency_key;
199        self
200    }
201
202    pub fn with_caller_scope(
203        mut self,
204        caller_allowed_acc_ids: Option<std::sync::Arc<HashSet<u64>>>,
205        caller_key_id: Option<String>,
206    ) -> Self {
207        self.request.caller_allowed_acc_ids = caller_allowed_acc_ids;
208        self.request.caller_key_id = caller_key_id;
209        self
210    }
211
212    pub fn build(self) -> IncomingRequest {
213        self.request
214    }
215}
216
217impl From<IncomingRequestBuilder> for IncomingRequest {
218    fn from(builder: IncomingRequestBuilder) -> Self {
219        builder.build()
220    }
221}
222
223fn conn_id_epoch_elapsed_or_zero() -> Duration {
224    match SystemTime::now().duration_since(UNIX_EPOCH) {
225        Ok(elapsed) => elapsed,
226        Err(err) => {
227            tracing::warn!(
228                error = %err,
229                "system wall clock is before UNIX_EPOCH; using zero duration fallback for conn_id"
230            );
231            Duration::ZERO
232        }
233    }
234}
235
236impl ClientConn {
237    /// 生成随机连接 ID(与 C++ 的 GetRand_MilliTimeAndU22 对应)
238    pub fn generate_conn_id() -> u64 {
239        let millis = conn_id_epoch_elapsed_or_zero().as_millis() as u64;
240        let rand_part: u32 = rand::random();
241        (millis << 22) | (rand_part as u64 & 0x3FFFFF)
242    }
243
244    /// 生成随机 AES key(16 字节 hex 字符串的 ASCII 字节)
245    pub fn generate_aes_key() -> [u8; 16] {
246        let rand_val: u64 = rand::random();
247        let hex = format!("{rand_val:016X}");
248        let mut key = [0u8; 16];
249        key.copy_from_slice(hex.as_bytes());
250        key
251    }
252
253    /// 创建发送帧,自动处理 AES 加密
254    ///
255    /// 当 aes_encrypt_enabled 为 true 时:
256    /// - SHA1 基于明文计算
257    /// - body 使用 AES-128 ECB 加密
258    /// - header.body_len 更新为密文长度
259    ///
260    /// 对应 C++ APIServerCS_Conn::OnSendPacketData 的加密逻辑
261    pub fn make_frame(&self, proto_id: u32, serial_no: u32, body: Bytes) -> FutuFrame {
262        let body_sha1 = FutuFrame::body_sha1(&body);
263        self.make_frame_with_sha1(proto_id, serial_no, body, body_sha1)
264    }
265
266    /// 创建发送帧,复用调用方已计算的明文 body SHA1。
267    pub fn make_frame_with_sha1(
268        &self,
269        proto_id: u32,
270        serial_no: u32,
271        body: Bytes,
272        body_sha1: [u8; 20],
273    ) -> FutuFrame {
274        if self.aes_encrypt_enabled {
275            let encrypted = encrypt::aes_ecb_encrypt(&self.aes_key, &body);
276            FutuFrame::with_sha1(proto_id, serial_no, Bytes::from(encrypted), body_sha1)
277        } else {
278            FutuFrame::with_sha1(proto_id, serial_no, body, body_sha1)
279        }
280    }
281
282    /// 解密请求 body(如果启用了 AES 加密)
283    ///
284    /// 对应 C++ APIServerCS_Conn::OnRecvPacket 的解密逻辑
285    pub fn decrypt_body(&self, body: &[u8]) -> Result<Vec<u8>, FutuError> {
286        if self.aes_encrypt_enabled {
287            encrypt::aes_ecb_decrypt(&self.aes_key, body).map_err(|e| {
288                tracing::warn!(conn_id = self.conn_id, error = %e, "AES decrypt body failed");
289                e
290            })
291        } else {
292            Ok(body.to_vec())
293        }
294    }
295
296    /// 处理 InitConnect 请求,返回 InitConnect 响应 body
297    ///
298    /// 当配置了 RSA 私钥时:
299    /// - C2S 请求 body 使用 RSA 公钥加密(需要用私钥解密)
300    /// - S2C 响应 body 使用 RSA 公钥加密(客户端用私钥解密)
301    ///
302    /// 对应 C++ APIServer::OnRecvInitConnect
303    pub fn handle_init_connect(
304        &mut self,
305        body: &[u8],
306        server_ver: i32,
307        login_user_id: u64,
308        keepalive_interval: i32,
309        rsa_private_key: Option<&str>,
310    ) -> Result<Vec<u8>, FutuError> {
311        // 1. 解密 C2S(如果配置了 RSA)
312        let decrypted_body;
313        let req_body = if let Some(rsa_key) = rsa_private_key {
314            decrypted_body =
315                futu_net::encrypt::rsa_private_decrypt_blocks(rsa_key, body).map_err(|e| {
316                    tracing::warn!(error = %e, "RSA decrypt InitConnect C2S failed");
317                    e
318                })?;
319            tracing::debug!(
320                encrypted_len = body.len(),
321                decrypted_len = decrypted_body.len(),
322                "RSA decrypted InitConnect C2S"
323            );
324            &decrypted_body[..]
325        } else {
326            body
327        };
328
329        let req: futu_proto::init_connect::Request =
330            prost::Message::decode(req_body).map_err(FutuError::Proto)?;
331
332        self.state = ConnState::Initialized;
333        self.recv_notify = req.c2s.recv_notify.unwrap_or(false);
334        self.ai_type = req.c2s.ai_type.unwrap_or(0);
335
336        // 当配置了 RSA 时,后续所有帧使用 AES 加解密
337        if rsa_private_key.is_some() {
338            self.aes_encrypt_enabled = true;
339            tracing::debug!(conn_id = self.conn_id, "AES body encryption enabled");
340        }
341
342        let aes_key_str = std::str::from_utf8(&self.aes_key)
343            .map_err(|e| FutuError::Codec(format!("invalid InitConnect conn_aes_key: {e}")))?
344            .to_string();
345
346        let resp = futu_proto::init_connect::Response {
347            ret_type: 0,
348            ret_msg: None,
349            err_code: None,
350            s2c: Some(futu_proto::init_connect::S2c {
351                server_ver,
352                login_user_id,
353                conn_id: self.conn_id,
354                conn_aes_key: aes_key_str,
355                keep_alive_interval: keepalive_interval,
356                aes_cb_civ: None,
357                user_attribution: None,
358            }),
359        };
360
361        let resp_body = prost::Message::encode_to_vec(&resp);
362
363        // 2. 加密 S2C(如果配置了 RSA)
364        if let Some(rsa_key) = rsa_private_key {
365            let encrypted = futu_net::encrypt::rsa_public_encrypt_blocks(rsa_key, &resp_body)
366                .map_err(|e| {
367                    tracing::warn!(error = %e, "RSA encrypt InitConnect S2C failed");
368                    e
369                })?;
370            tracing::debug!(
371                plaintext_len = resp_body.len(),
372                encrypted_len = encrypted.len(),
373                "RSA encrypted InitConnect S2C"
374            );
375            Ok(encrypted)
376        } else {
377            Ok(resp_body)
378        }
379    }
380
381    /// 处理 KeepAlive 请求。
382    ///
383    /// This compatibility helper falls back to the local daemon clock. The
384    /// TCP / WebSocket dispatch paths must inject the backend-adjusted server
385    /// time through [`Self::handle_keepalive_at`] to match C++ OpenD.
386    pub fn handle_keepalive(&self, body: &[u8]) -> Result<Vec<u8>, FutuError> {
387        self.handle_keepalive_at(body, chrono::Utc::now().timestamp())
388    }
389
390    /// 处理 KeepAlive 请求,使用调用方注入的 server time。
391    ///
392    /// Ref: C++ `APIServerCS_Conn.cpp:370-373` replies with
393    /// `INNBiz_SvrTime::GetSvrTimeStamp()`.
394    pub fn handle_keepalive_at(
395        &self,
396        body: &[u8],
397        server_now_ts: i64,
398    ) -> Result<Vec<u8>, FutuError> {
399        let _req: futu_proto::keep_alive::Request =
400            prost::Message::decode(body).map_err(FutuError::Proto)?;
401
402        self.keepalive_count.fetch_add(1, Ordering::Relaxed);
403
404        let resp = futu_proto::keep_alive::Response {
405            ret_type: 0,
406            ret_msg: None,
407            err_code: None,
408            s2c: Some(futu_proto::keep_alive::S2c {
409                time: server_now_ts,
410            }),
411        };
412
413        Ok(prost::Message::encode_to_vec(&resp))
414    }
415}
416
417/// 连接断开通知
418pub struct DisconnectNotify {
419    /// 被断开的连接 ID(订阅 / push / auth 状态清理用)
420    pub conn_id: u64,
421}
422
423/// 通知 listener 清理连接;cleanup task 已退出时至少留下可观测日志。
424pub(crate) fn notify_disconnect(
425    disconnect_tx: &mpsc::UnboundedSender<DisconnectNotify>,
426    conn_id: u64,
427    reason: &'static str,
428) {
429    if let Err(e) = disconnect_tx.send(DisconnectNotify { conn_id }) {
430        tracing::warn!(
431            conn_id,
432            reason,
433            error = %e,
434            "disconnect cleanup notification failed"
435        );
436    }
437}
438
439/// 运行单个连接的收发循环
440///
441/// 返回 (接收请求的 channel, 连接信息)
442pub async fn run_connection(
443    stream: TcpStream,
444    conn_id: u64,
445    _aes_key: [u8; 16],
446    req_tx: mpsc::Sender<IncomingRequest>,
447    disconnect_tx: mpsc::UnboundedSender<DisconnectNotify>,
448    mut shutdown_rx: watch::Receiver<bool>,
449) -> mpsc::Sender<FutuFrame> {
450    let (frame_tx, mut frame_rx) = mpsc::channel::<FutuFrame>(256);
451
452    let framed = Framed::new(stream, FutuCodec);
453    let (mut sink, mut stream) = framed.split();
454
455    // 发送任务
456    let send_disconnect_tx = disconnect_tx.clone();
457    tokio::spawn(async move {
458        while let Some(frame) = frame_rx.recv().await {
459            if let Err(e) = sink.send(frame).await {
460                tracing::warn!(conn_id = conn_id, error = %e, "send failed");
461                notify_disconnect(&send_disconnect_tx, conn_id, "tcp send failed");
462                break;
463            }
464        }
465    });
466
467    // 接收任务
468    tokio::spawn(async move {
469        loop {
470            let result = tokio::select! {
471                changed = shutdown_rx.changed() => {
472                    if changed.is_err() || *shutdown_rx.borrow() {
473                        tracing::info!(
474                            conn_id = conn_id,
475                            "connection receive loop stopped by shutdown signal"
476                        );
477                        break;
478                    }
479                    continue;
480                }
481                result = stream.next() => result,
482            };
483            let Some(result) = result else {
484                break;
485            };
486            match result {
487                Ok(frame) => {
488                    // TCP wire currently has no idempotency or key scope fields.
489                    // The builder keeps legacy defaults explicit in one place.
490                    let req = IncomingRequest::builder(
491                        conn_id,
492                        frame.header.proto_id,
493                        frame.header.serial_no,
494                        frame.header.proto_fmt_type,
495                        frame.body,
496                    )
497                    .build();
498                    if req_tx.send(req).await.is_err() {
499                        break;
500                    }
501                }
502                Err(e) => {
503                    tracing::warn!(conn_id = conn_id, error = %e, "recv error");
504                    break;
505                }
506            }
507        }
508        tracing::info!(conn_id = conn_id, "connection closed");
509        // 通知 listener 清理连接
510        notify_disconnect(&disconnect_tx, conn_id, "tcp receive loop closed");
511    });
512
513    frame_tx
514}
515
516#[cfg(test)]
517mod tests;