futu_server/conn.rs
1// 单连接管理:状态机、帧收发、加密、心跳超时
2
3use std::collections::HashSet;
4use std::sync::atomic::{AtomicU32, Ordering};
5use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
6
7use bytes::Bytes;
8use futures::{SinkExt, StreamExt};
9use tokio::net::TcpStream;
10use tokio::sync::{mpsc, watch};
11use tokio_util::codec::Framed;
12
13use futu_auth::Scope;
14use futu_codec::FutuCodec;
15use futu_codec::frame::FutuFrame;
16use futu_codec::header::ProtoFmtType;
17use futu_core::error::FutuError;
18use futu_net::encrypt;
19
20/// 连接状态
21#[derive(Debug, Clone, Copy, PartialEq, Eq)]
22#[non_exhaustive]
23pub enum ConnState {
24 /// TCP / WebSocket 刚建立,尚未完成 InitConnect 握手
25 Connected,
26 /// 已完成 InitConnect,等待首次业务请求
27 Initialized,
28 /// 正常交互中(业务请求 / KeepAlive / push 均已流转)
29 Active,
30 /// 连接已断开(客户端主动关闭 / 被动超时 / IO 错误)
31 Disconnected,
32}
33
34/// 单个客户端连接
35pub struct ClientConn {
36 /// 随机连接 ID(对应 C++ `GetRand_MilliTimeAndU22`)
37 pub conn_id: u64,
38 /// 连接状态:InitConnect 前 / 后 / 已断开
39 pub state: ConnState,
40 /// 随机 AES-128 key(InitConnect 响应里下发给客户端)
41 pub aes_key: [u8; 16],
42 /// AES 加解密已启用(InitConnect 完成且配置了 RSA 时为 true)
43 pub aes_encrypt_enabled: bool,
44 /// 该连接协商的 proto 格式(Protobuf / JSON)
45 pub proto_fmt_type: ProtoFmtType,
46 /// 上次收到 KeepAlive 的时间,用于超时检查
47 pub last_keepalive: Instant,
48 /// InitConnect.C2S.recvNotify:此连接是否接收市场状态 / 交易解锁等通知。
49 ///
50 /// C++ 在 `APIServer_InitConnect.cpp` 里把该字段写入 ConnInfo;
51 /// `RegQotPush` / `Qot_Sub(isRegOrUnRegPush)` 不会修改这个开关。
52 pub recv_notify: bool,
53 /// InitConnect.C2S.aiType:AI 调用类型。C++ 10.7
54 /// `APIServer_InitConnect.cpp` 缺省为 0,并写入连接状态。
55 pub ai_type: i32,
56 /// 已收到的 KeepAlive 计数(监控用)
57 pub keepalive_count: AtomicU32,
58 /// 发送帧到此连接
59 pub tx: mpsc::Sender<FutuFrame>,
60
61 // ---- v1.0 WS per-message scope 鉴权(raw TCP legacy 兼容:scopes 空集全放行)----
62 /// 该连接绑定的 API key id;WS 握手时填,未配 keys.json 时为 None
63 pub key_id: Option<String>,
64 /// 该连接持有的 scope 集合;空集 = legacy 模式 / TCP 直连,scope 检查放行
65 pub scopes: HashSet<Scope>,
66 /// v1.4.105 D3 (Phase 4) T-B2: 该连接 caller key 的 `allowed_markets` 硬
67 /// 限额 (大写字符串 set, e.g. {"HK","US"}). `None` = 无限制 (legacy 模式
68 /// / TCP 直连默认全开 / 未配 allowed_markets); `Some(set)` 非空 → push 端
69 /// 应过滤 trd_market 不在 set 中的 trade event.
70 ///
71 /// **触发**: WS handshake 时从 `KeyRecord.allowed_markets` 拷贝过来.
72 /// `PushDispatcher::push_trd_acc` 端 Layer 3 filter 检查. 与
73 /// `caller_allowed_acc_ids` (Layer 1, per-call snapshot in IncomingRequest)
74 /// 区别: 本字段是 per-conn snapshot (handshake 时一次性), 不随 per-call
75 /// 重读 — KeyRecord SIGHUP reload 后**仅新建连接生效**, 老连接保持 snapshot
76 /// (与 `scopes` / `caller_allowed_acc_ids` 的 snapshot 语义一致).
77 pub allowed_markets: Option<std::sync::Arc<HashSet<String>>>,
78 /// codex round 1 F4 (P2) v1.4.105: 该连接 caller key 的 `allowed_acc_ids`
79 /// 硬限额 (per-conn snapshot, handshake 时一次性). `None` / `Some(empty)` =
80 /// 无限制 (legacy 模式 / TCP 直连默认全开 / 未配 allowed_acc_ids);
81 /// `Some(non-empty set)` →
82 /// `PushDispatcher::push_trd_acc` 端 push-time 硬过滤 acc_id 不在 set 中
83 /// 的 trade event (Layer 1, 与 `allowed_markets` 的 Layer 3 互补).
84 /// Deny-all 使用 sentinel `{0}`,不使用空集合。
85 ///
86 /// **触发**: codex F4 指出 raw TCP push 端只查 `acc:read` scope +
87 /// `allowed_markets`, 不查 `allowed_acc_ids`. 即使 request-time
88 /// `SubAccPushHandler` 已阻止越权订阅, stale subscription / KeyRecord
89 /// reload 后窄化的 acc 范围 / 历史 bug 留下的 conn→acc 关系 仍可能让 push
90 /// 漏 leak. 本字段提供第二层 push-time 兜底.
91 ///
92 /// 与 `caller_allowed_acc_ids` (IncomingRequest, per-call) 区别: 本字段
93 /// 在 push-time 用 (无 IncomingRequest), per-conn snapshot 与 `scopes` /
94 /// `allowed_markets` 的 snapshot 语义一致.
95 pub allowed_acc_ids: Option<std::sync::Arc<HashSet<u64>>>,
96}
97
98/// 从连接接收到的请求
99#[derive(Debug)]
100pub struct IncomingRequest {
101 /// 发送请求的连接 ID(用于响应路由 + SubscriptionManager / cache 记账)
102 ///
103 /// **跨 surface 命名空间分配**(v1.4.106 codex 0517 ζ25-redo F2 沉淀):
104 /// - raw TCP listener: `ClientConn::generate_conn_id()` 派生(u32 范围)
105 /// - REST: `crates/futu-rest/src/routes/qot.rs::REST_SHARED_CONN`
106 /// = `0xFFFF_FFFE`(u32 上限附近, 单值共享)
107 /// - gRPC: `crates/futu-grpc/src/auth.rs::GRPC_STABLE_CONN_NAMESPACE`
108 /// = `0x4000_0000_0000_0000`(bit 62 namespace, 按 caller 派生)
109 /// - WS / MCP: 通常派生自所属物理 TCP 连接的 conn_id
110 ///
111 /// 各 surface 不重叠. 加新 surface 时分配一个 namespace base, 不要与
112 /// 上述 4 个段重合.
113 pub conn_id: u64,
114 /// 协议 ID(对齐 C++ `NN_ProtoCmd_*`)
115 pub proto_id: u32,
116 /// 序列号(和 Response 配对,供 client 端请求-响应匹配)
117 pub serial_no: u32,
118 /// 请求体 proto 格式(Protobuf / JSON)
119 pub proto_fmt_type: ProtoFmtType,
120 /// 请求 body(已解密后的明文)
121 pub body: Bytes,
122 /// v1.4.38 Phase 4: 订单幂等 key(由 REST `Idempotency-Key` header / gRPC
123 /// metadata / WS envelope / MCP tool args 填入)。None 表示客户端未传,
124 /// handler 走无幂等直通 path(backward-compat)。
125 pub idempotency_key: Option<String>,
126 /// v1.4.106 codex 0920 F1 (P1): caller key id 副本 (per-call snapshot,
127 /// 由 surface adapter 层从 KeyRecord 读取后填入).
128 ///
129 /// **目标**: idempotency cache key namespace 必须含 caller key id, 否则
130 /// 不同 caller 用同 Idempotency-Key 会跨 caller 命中老 response —— 严重
131 /// 跨账户数据泄漏 + 重复下单 silent fail.
132 ///
133 /// `None` = 无 caller 标识 (legacy TCP / 未 auth) → namespace 用 `<no_key>`
134 /// 占位符. `Some("alice")` = WS / MCP / REST 已 auth 的 caller —— namespace
135 /// 用 `<caller_key_id="alice">`, 不与其他 caller 串.
136 pub caller_key_id: Option<String>,
137 /// v1.4.105 D2 contract-hardening 补丁: caller key 的 `allowed_acc_ids` 硬限额
138 /// 副本 (per-call snapshot, 在 surface adapter 层从 KeyRecord 读取后填入).
139 ///
140 /// **目标**: 让 dispatch-time handlers (e.g. `SubAccPushHandler` 注册 acc_id
141 /// 到 SubscriptionManager) 也能 enforce per-acc whitelist — 即使上游 pipeline
142 /// body-aware step 已 enforce, 让 handler 自己 defense-in-depth 防 future
143 /// regression (新 surface 加进来漏调 pipeline body-aware).
144 ///
145 /// `None` / `Some(empty)` = caller 无 acc_id 限制 (legacy mode 或 unrestricted
146 /// key) → handler 不 filter; `Some(non-empty set)` → handler 应 reject 不在
147 /// set 中的 acc_id. Deny-all 使用 sentinel `{0}`,不使用空集合。
148 pub caller_allowed_acc_ids: Option<std::sync::Arc<std::collections::HashSet<u64>>>,
149}
150
151impl IncomingRequest {
152 /// codex 0522 F4 v1.4.106: cross-surface 单测 hook. 构 IncomingRequest
153 /// 并填 caller scope (`caller_key_id` + `caller_allowed_acc_ids`) — 让
154 /// REST / gRPC / WS / MCP 等 surface 的 adapter 都用同一构造路径, 防
155 /// "某个 surface 漏填字段" silent regression.
156 ///
157 /// 之前 4 surface 各写一份 struct literal, 加新字段需逐个改, 漏一个就
158 /// 出现 silent None — 与坑 #54 schema-only fix 同模式 (实装符号 vs 真
159 /// 行为差距). 本 helper 是 single point, 加新字段 schema 自动 propagate.
160 ///
161 /// **注意**: 本 helper 不 take ownership of body — caller 已 own bytes.
162 /// idempotency_key / caller_key_id 接 String 而非 &str 让 caller 决定
163 /// 是 clone 还是 move.
164 pub fn builder(
165 conn_id: u64,
166 proto_id: u32,
167 serial_no: u32,
168 proto_fmt_type: ProtoFmtType,
169 body: Bytes,
170 ) -> IncomingRequestBuilder {
171 IncomingRequestBuilder {
172 request: Self {
173 conn_id,
174 proto_id,
175 serial_no,
176 proto_fmt_type,
177 body,
178 idempotency_key: None,
179 caller_allowed_acc_ids: None,
180 caller_key_id: None,
181 },
182 }
183 }
184}
185
186/// Thin builder for `IncomingRequest`.
187///
188/// The base request shape is the wire envelope; idempotency and caller scope are
189/// optional per-surface decorations. Keeping those defaults here avoids every
190/// REST / gRPC / raw WS / MCP adapter spelling out `None` independently.
191#[derive(Debug)]
192pub struct IncomingRequestBuilder {
193 request: IncomingRequest,
194}
195
196impl IncomingRequestBuilder {
197 pub fn with_idempotency_key(mut self, idempotency_key: Option<String>) -> Self {
198 self.request.idempotency_key = idempotency_key;
199 self
200 }
201
202 pub fn with_caller_scope(
203 mut self,
204 caller_allowed_acc_ids: Option<std::sync::Arc<HashSet<u64>>>,
205 caller_key_id: Option<String>,
206 ) -> Self {
207 self.request.caller_allowed_acc_ids = caller_allowed_acc_ids;
208 self.request.caller_key_id = caller_key_id;
209 self
210 }
211
212 pub fn build(self) -> IncomingRequest {
213 self.request
214 }
215}
216
217impl From<IncomingRequestBuilder> for IncomingRequest {
218 fn from(builder: IncomingRequestBuilder) -> Self {
219 builder.build()
220 }
221}
222
223fn conn_id_epoch_elapsed_or_zero() -> Duration {
224 match SystemTime::now().duration_since(UNIX_EPOCH) {
225 Ok(elapsed) => elapsed,
226 Err(err) => {
227 tracing::warn!(
228 error = %err,
229 "system wall clock is before UNIX_EPOCH; using zero duration fallback for conn_id"
230 );
231 Duration::ZERO
232 }
233 }
234}
235
236impl ClientConn {
237 /// 生成随机连接 ID(与 C++ 的 GetRand_MilliTimeAndU22 对应)
238 pub fn generate_conn_id() -> u64 {
239 let millis = conn_id_epoch_elapsed_or_zero().as_millis() as u64;
240 let rand_part: u32 = rand::random();
241 (millis << 22) | (rand_part as u64 & 0x3FFFFF)
242 }
243
244 /// 生成随机 AES key(16 字节 hex 字符串的 ASCII 字节)
245 pub fn generate_aes_key() -> [u8; 16] {
246 let rand_val: u64 = rand::random();
247 let hex = format!("{rand_val:016X}");
248 let mut key = [0u8; 16];
249 key.copy_from_slice(hex.as_bytes());
250 key
251 }
252
253 /// 创建发送帧,自动处理 AES 加密
254 ///
255 /// 当 aes_encrypt_enabled 为 true 时:
256 /// - SHA1 基于明文计算
257 /// - body 使用 AES-128 ECB 加密
258 /// - header.body_len 更新为密文长度
259 ///
260 /// 对应 C++ APIServerCS_Conn::OnSendPacketData 的加密逻辑
261 pub fn make_frame(&self, proto_id: u32, serial_no: u32, body: Bytes) -> FutuFrame {
262 let body_sha1 = FutuFrame::body_sha1(&body);
263 self.make_frame_with_sha1(proto_id, serial_no, body, body_sha1)
264 }
265
266 /// 创建发送帧,复用调用方已计算的明文 body SHA1。
267 pub fn make_frame_with_sha1(
268 &self,
269 proto_id: u32,
270 serial_no: u32,
271 body: Bytes,
272 body_sha1: [u8; 20],
273 ) -> FutuFrame {
274 if self.aes_encrypt_enabled {
275 let encrypted = encrypt::aes_ecb_encrypt(&self.aes_key, &body);
276 FutuFrame::with_sha1(proto_id, serial_no, Bytes::from(encrypted), body_sha1)
277 } else {
278 FutuFrame::with_sha1(proto_id, serial_no, body, body_sha1)
279 }
280 }
281
282 /// 解密请求 body(如果启用了 AES 加密)
283 ///
284 /// 对应 C++ APIServerCS_Conn::OnRecvPacket 的解密逻辑
285 pub fn decrypt_body(&self, body: &[u8]) -> Result<Vec<u8>, FutuError> {
286 if self.aes_encrypt_enabled {
287 encrypt::aes_ecb_decrypt(&self.aes_key, body).map_err(|e| {
288 tracing::warn!(conn_id = self.conn_id, error = %e, "AES decrypt body failed");
289 e
290 })
291 } else {
292 Ok(body.to_vec())
293 }
294 }
295
296 /// 处理 InitConnect 请求,返回 InitConnect 响应 body
297 ///
298 /// 当配置了 RSA 私钥时:
299 /// - C2S 请求 body 使用 RSA 公钥加密(需要用私钥解密)
300 /// - S2C 响应 body 使用 RSA 公钥加密(客户端用私钥解密)
301 ///
302 /// 对应 C++ APIServer::OnRecvInitConnect
303 pub fn handle_init_connect(
304 &mut self,
305 body: &[u8],
306 server_ver: i32,
307 login_user_id: u64,
308 keepalive_interval: i32,
309 rsa_private_key: Option<&str>,
310 ) -> Result<Vec<u8>, FutuError> {
311 // 1. 解密 C2S(如果配置了 RSA)
312 let decrypted_body;
313 let req_body = if let Some(rsa_key) = rsa_private_key {
314 decrypted_body =
315 futu_net::encrypt::rsa_private_decrypt_blocks(rsa_key, body).map_err(|e| {
316 tracing::warn!(error = %e, "RSA decrypt InitConnect C2S failed");
317 e
318 })?;
319 tracing::debug!(
320 encrypted_len = body.len(),
321 decrypted_len = decrypted_body.len(),
322 "RSA decrypted InitConnect C2S"
323 );
324 &decrypted_body[..]
325 } else {
326 body
327 };
328
329 let req: futu_proto::init_connect::Request =
330 prost::Message::decode(req_body).map_err(FutuError::Proto)?;
331
332 self.state = ConnState::Initialized;
333 self.recv_notify = req.c2s.recv_notify.unwrap_or(false);
334 self.ai_type = req.c2s.ai_type.unwrap_or(0);
335
336 // 当配置了 RSA 时,后续所有帧使用 AES 加解密
337 if rsa_private_key.is_some() {
338 self.aes_encrypt_enabled = true;
339 tracing::debug!(conn_id = self.conn_id, "AES body encryption enabled");
340 }
341
342 let aes_key_str = std::str::from_utf8(&self.aes_key)
343 .map_err(|e| FutuError::Codec(format!("invalid InitConnect conn_aes_key: {e}")))?
344 .to_string();
345
346 let resp = futu_proto::init_connect::Response {
347 ret_type: 0,
348 ret_msg: None,
349 err_code: None,
350 s2c: Some(futu_proto::init_connect::S2c {
351 server_ver,
352 login_user_id,
353 conn_id: self.conn_id,
354 conn_aes_key: aes_key_str,
355 keep_alive_interval: keepalive_interval,
356 aes_cb_civ: None,
357 user_attribution: None,
358 }),
359 };
360
361 let resp_body = prost::Message::encode_to_vec(&resp);
362
363 // 2. 加密 S2C(如果配置了 RSA)
364 if let Some(rsa_key) = rsa_private_key {
365 let encrypted = futu_net::encrypt::rsa_public_encrypt_blocks(rsa_key, &resp_body)
366 .map_err(|e| {
367 tracing::warn!(error = %e, "RSA encrypt InitConnect S2C failed");
368 e
369 })?;
370 tracing::debug!(
371 plaintext_len = resp_body.len(),
372 encrypted_len = encrypted.len(),
373 "RSA encrypted InitConnect S2C"
374 );
375 Ok(encrypted)
376 } else {
377 Ok(resp_body)
378 }
379 }
380
381 /// 处理 KeepAlive 请求。
382 ///
383 /// This compatibility helper falls back to the local daemon clock. The
384 /// TCP / WebSocket dispatch paths must inject the backend-adjusted server
385 /// time through [`Self::handle_keepalive_at`] to match C++ OpenD.
386 pub fn handle_keepalive(&self, body: &[u8]) -> Result<Vec<u8>, FutuError> {
387 self.handle_keepalive_at(body, chrono::Utc::now().timestamp())
388 }
389
390 /// 处理 KeepAlive 请求,使用调用方注入的 server time。
391 ///
392 /// Ref: C++ `APIServerCS_Conn.cpp:370-373` replies with
393 /// `INNBiz_SvrTime::GetSvrTimeStamp()`.
394 pub fn handle_keepalive_at(
395 &self,
396 body: &[u8],
397 server_now_ts: i64,
398 ) -> Result<Vec<u8>, FutuError> {
399 let _req: futu_proto::keep_alive::Request =
400 prost::Message::decode(body).map_err(FutuError::Proto)?;
401
402 self.keepalive_count.fetch_add(1, Ordering::Relaxed);
403
404 let resp = futu_proto::keep_alive::Response {
405 ret_type: 0,
406 ret_msg: None,
407 err_code: None,
408 s2c: Some(futu_proto::keep_alive::S2c {
409 time: server_now_ts,
410 }),
411 };
412
413 Ok(prost::Message::encode_to_vec(&resp))
414 }
415}
416
417/// 连接断开通知
418pub struct DisconnectNotify {
419 /// 被断开的连接 ID(订阅 / push / auth 状态清理用)
420 pub conn_id: u64,
421}
422
423/// 通知 listener 清理连接;cleanup task 已退出时至少留下可观测日志。
424pub(crate) fn notify_disconnect(
425 disconnect_tx: &mpsc::UnboundedSender<DisconnectNotify>,
426 conn_id: u64,
427 reason: &'static str,
428) {
429 if let Err(e) = disconnect_tx.send(DisconnectNotify { conn_id }) {
430 tracing::warn!(
431 conn_id,
432 reason,
433 error = %e,
434 "disconnect cleanup notification failed"
435 );
436 }
437}
438
439/// 运行单个连接的收发循环
440///
441/// 返回 (接收请求的 channel, 连接信息)
442pub async fn run_connection(
443 stream: TcpStream,
444 conn_id: u64,
445 _aes_key: [u8; 16],
446 req_tx: mpsc::Sender<IncomingRequest>,
447 disconnect_tx: mpsc::UnboundedSender<DisconnectNotify>,
448 mut shutdown_rx: watch::Receiver<bool>,
449) -> mpsc::Sender<FutuFrame> {
450 let (frame_tx, mut frame_rx) = mpsc::channel::<FutuFrame>(256);
451
452 let framed = Framed::new(stream, FutuCodec);
453 let (mut sink, mut stream) = framed.split();
454
455 // 发送任务
456 let send_disconnect_tx = disconnect_tx.clone();
457 tokio::spawn(async move {
458 while let Some(frame) = frame_rx.recv().await {
459 if let Err(e) = sink.send(frame).await {
460 tracing::warn!(conn_id = conn_id, error = %e, "send failed");
461 notify_disconnect(&send_disconnect_tx, conn_id, "tcp send failed");
462 break;
463 }
464 }
465 });
466
467 // 接收任务
468 tokio::spawn(async move {
469 loop {
470 let result = tokio::select! {
471 changed = shutdown_rx.changed() => {
472 if changed.is_err() || *shutdown_rx.borrow() {
473 tracing::info!(
474 conn_id = conn_id,
475 "connection receive loop stopped by shutdown signal"
476 );
477 break;
478 }
479 continue;
480 }
481 result = stream.next() => result,
482 };
483 let Some(result) = result else {
484 break;
485 };
486 match result {
487 Ok(frame) => {
488 // TCP wire currently has no idempotency or key scope fields.
489 // The builder keeps legacy defaults explicit in one place.
490 let req = IncomingRequest::builder(
491 conn_id,
492 frame.header.proto_id,
493 frame.header.serial_no,
494 frame.header.proto_fmt_type,
495 frame.body,
496 )
497 .build();
498 if req_tx.send(req).await.is_err() {
499 break;
500 }
501 }
502 Err(e) => {
503 tracing::warn!(conn_id = conn_id, error = %e, "recv error");
504 break;
505 }
506 }
507 }
508 tracing::info!(conn_id = conn_id, "connection closed");
509 // 通知 listener 清理连接
510 notify_disconnect(&disconnect_tx, conn_id, "tcp receive loop closed");
511 });
512
513 frame_tx
514}
515
516#[cfg(test)]
517mod tests;