Skip to main content

futu_server/
listener.rs

1// TCP 监听器:接受客户端连接,管理连接池
2
3use std::sync::Arc;
4use std::sync::atomic::{AtomicI64, Ordering};
5use std::time::Instant;
6
7use dashmap::DashMap;
8use tokio::net::TcpListener;
9use tokio::sync::{mpsc, watch};
10
11use futu_codec::header::ProtoFmtType;
12use futu_core::proto_id;
13
14use crate::conn::{ClientConn, ConnState, DisconnectNotify, IncomingRequest};
15use crate::metrics::GatewayMetrics;
16use crate::router::RequestRouter;
17
/// Upper bound on simultaneously pooled client connections; further accepts are rejected.
pub const MAX_CONNECTIONS: usize = 128;
20
/// Capacity of the inbound request queue, shared by the raw TCP and WebSocket listeners.
///
/// Bounding the queue means a slow backend dispatcher pushes back on client sockets,
/// rather than letting client frames pile up in memory without limit.
pub(crate) const REQUEST_QUEUE_CAPACITY: usize = 4_096;
26
/// Static configuration for [`ApiServer`].
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Address the TCP listener binds to, e.g. `127.0.0.1:11111`.
    pub listen_addr: String,
    /// Server version reported to clients in the InitConnect response.
    pub server_ver: i32,
    /// Logged-in user id reported to clients in the InitConnect response.
    pub login_user_id: u64,
    /// KeepAlive heartbeat interval in seconds, reported to clients in the InitConnect response.
    pub keepalive_interval: i32,
    /// Optional RSA private key in PEM form; when present, InitConnect payloads are
    /// RSA-encrypted/decrypted.
    pub rsa_private_key: Option<String>,
}
41
/// Returns a fresh, shareable server-time offset (in seconds) initialised to zero.
#[must_use]
pub(crate) fn default_server_time_offset_secs() -> Arc<AtomicI64> {
    let zero_offset = AtomicI64::default();
    Arc::new(zero_offset)
}
46
/// Converts a local Unix timestamp (seconds) to server time by adding the current offset.
///
/// The addition saturates at the `i64` bounds instead of overflowing.
#[must_use]
pub(crate) fn server_now_ts_at(server_time_offset_secs: &AtomicI64, local_now_ts: i64) -> i64 {
    let offset = server_time_offset_secs.load(Ordering::Relaxed);
    offset.saturating_add(local_now_ts)
}
51
52#[must_use]
53pub(crate) fn server_now_ts(server_time_offset_secs: &AtomicI64) -> i64 {
54    server_now_ts_at(server_time_offset_secs, chrono::Utc::now().timestamp())
55}
56
57pub(crate) fn set_nodelay_with_log(
58    stream: &tokio::net::TcpStream,
59    peer_addr: std::net::SocketAddr,
60    surface: &'static str,
61) {
62    if let Err(error) = stream.set_nodelay(true) {
63        tracing::debug!(
64            peer = %peer_addr,
65            surface,
66            error = %error,
67            "tcp nodelay setup failed"
68        );
69    }
70}
71
72pub(crate) async fn shutdown_requested(shutdown_rx: &mut watch::Receiver<bool>) {
73    loop {
74        if *shutdown_rx.borrow() {
75            return;
76        }
77        if shutdown_rx.changed().await.is_err() {
78            return;
79        }
80    }
81}
82
83/// API 服务端
84pub struct ApiServer {
85    config: ServerConfig,
86    connections: Arc<DashMap<u64, ClientConn>>,
87    router: Arc<RequestRouter>,
88    subscriptions: Option<Arc<crate::subscription::SubscriptionManager>>,
89    metrics: Arc<GatewayMetrics>,
90    server_time_offset_secs: Arc<AtomicI64>,
91}
92
93impl ApiServer {
94    /// 创建新的服务端实例。不自动启动,需调用 [`ApiServer::run`] 进入接收循环。
95    pub fn new(config: ServerConfig) -> Self {
96        Self {
97            config,
98            connections: Arc::new(DashMap::new()),
99            router: Arc::new(RequestRouter::new()),
100            subscriptions: None,
101            metrics: Arc::new(GatewayMetrics::new()),
102            server_time_offset_secs: default_server_time_offset_secs(),
103        }
104    }
105
106    /// 设置订阅管理器,用于连接断开时自动清理订阅关系
107    pub fn set_subscriptions(&mut self, subs: Arc<crate::subscription::SubscriptionManager>) {
108        self.subscriptions = Some(subs);
109    }
110
111    /// 获取路由器引用(用于注册业务处理器)
112    pub fn router(&self) -> &Arc<RequestRouter> {
113        &self.router
114    }
115
116    /// 获取连接池引用(用于推送分发)
117    pub fn connections(&self) -> &Arc<DashMap<u64, ClientConn>> {
118        &self.connections
119    }
120
121    /// 设置外部监控指标(共享同一个 Arc,让 bridge 和 server 使用同一份计数器)
122    pub fn set_metrics(&mut self, metrics: Arc<GatewayMetrics>) {
123        self.metrics = metrics;
124    }
125
126    /// Inject backend server-time offset for SDK-facing API protocol fields.
127    pub fn set_server_time_offset_secs(&mut self, offset: Arc<AtomicI64>) {
128        self.server_time_offset_secs = offset;
129    }
130
131    /// 获取监控指标引用
132    pub fn metrics(&self) -> &Arc<GatewayMetrics> {
133        &self.metrics
134    }
135
136    /// 启动服务端监听
137    pub async fn run(&self) -> anyhow::Result<()> {
138        let (_shutdown_tx, shutdown_rx) = watch::channel(false);
139        self.run_until_shutdown(shutdown_rx).await
140    }
141
142    /// 启动服务端监听,并在 shutdown 信号到来时停止接受新连接。
143    pub async fn run_until_shutdown(
144        &self,
145        mut shutdown_rx: watch::Receiver<bool>,
146    ) -> anyhow::Result<()> {
147        let listener = TcpListener::bind(&self.config.listen_addr).await?;
148        tracing::info!(addr = %self.config.listen_addr, "API server listening");
149
150        let (req_tx, req_rx) = mpsc::channel::<IncomingRequest>(REQUEST_QUEUE_CAPACITY);
151        // The disconnect cleanup signal is intentionally unbounded: each item
152        // is a tiny `conn_id`, and blocking this path behind request-queue
153        // backpressure would risk leaking connection/subscription state.
154        let (disconnect_tx, mut disconnect_rx) = mpsc::unbounded_channel::<DisconnectNotify>();
155
156        // 启动请求处理任务
157        let connections = Arc::clone(&self.connections);
158        let router = Arc::clone(&self.router);
159        let config = self.config.clone();
160        let metrics = Arc::clone(&self.metrics);
161        let server_time_offset_secs = Arc::clone(&self.server_time_offset_secs);
162        tokio::spawn(async move {
163            process_requests(
164                req_rx,
165                connections,
166                router,
167                config,
168                metrics,
169                server_time_offset_secs,
170            )
171            .await;
172        });
173
174        // 启动连接清理任务(TCP 断开通知)
175        let cleanup_connections = Arc::clone(&self.connections);
176        let cleanup_subs = self.subscriptions.clone();
177        let cleanup_metrics = Arc::clone(&self.metrics);
178        tokio::spawn(async move {
179            while let Some(notify) = disconnect_rx.recv().await {
180                let removed = cleanup_connections.remove(&notify.conn_id);
181                if removed.is_some() {
182                    cleanup_metrics
183                        .total_disconnections
184                        .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
185                    // 清理该连接的所有订阅关系
186                    if let Some(ref subs) = cleanup_subs {
187                        subs.on_disconnect(notify.conn_id);
188                    }
189                    tracing::info!(
190                        conn_id = notify.conn_id,
191                        remaining = cleanup_connections.len(),
192                        "connection removed from pool"
193                    );
194                }
195            }
196        });
197
198        // 启动 KeepAlive 超时检测任务(对应 C++ OnTimeTicker,每 66 秒无活动断连)
199        let ka_connections = Arc::clone(&self.connections);
200        let ka_subs = self.subscriptions.clone();
201        let ka_metrics = Arc::clone(&self.metrics);
202        let mut ka_shutdown_rx = shutdown_rx.clone();
203        tokio::spawn(async move {
204            const CHECK_INTERVAL_SECS: u64 = 15;
205            const TIMEOUT_SECS: u64 = 66;
206            let mut interval =
207                tokio::time::interval(std::time::Duration::from_secs(CHECK_INTERVAL_SECS));
208            interval.tick().await; // 跳过首次立即触发
209            loop {
210                tokio::select! {
211                    _ = shutdown_requested(&mut ka_shutdown_rx) => {
212                        tracing::info!("API server keepalive task stopped by shutdown signal");
213                        break;
214                    }
215                    _ = interval.tick() => {}
216                }
217                let now = Instant::now();
218                let mut timed_out = Vec::new();
219                for entry in ka_connections.iter() {
220                    let conn = entry.value();
221                    if now.duration_since(conn.last_keepalive).as_secs() >= TIMEOUT_SECS {
222                        timed_out.push(conn.conn_id);
223                    }
224                }
225                for conn_id in timed_out {
226                    if ka_connections.remove(&conn_id).is_some() {
227                        ka_metrics
228                            .keepalive_timeouts
229                            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
230                        ka_metrics
231                            .total_disconnections
232                            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
233                        if let Some(ref subs) = ka_subs {
234                            subs.on_disconnect(conn_id);
235                        }
236                        tracing::info!(
237                            conn_id = conn_id,
238                            remaining = ka_connections.len(),
239                            "keepalive timeout, connection removed"
240                        );
241                    }
242                }
243            }
244        });
245
246        // 接受连接循环
247        let connections = Arc::clone(&self.connections);
248        let accept_metrics = Arc::clone(&self.metrics);
249        loop {
250            let (stream, peer_addr) = tokio::select! {
251                _ = shutdown_requested(&mut shutdown_rx) => {
252                    tracing::info!("API server accept loop stopped by shutdown signal");
253                    break;
254                }
255                accepted = listener.accept() => accepted?,
256            };
257
258            if connections.len() >= MAX_CONNECTIONS {
259                accept_metrics
260                    .rejected_connections
261                    .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
262                tracing::warn!(
263                    peer = %peer_addr,
264                    "max connections reached ({}), rejecting",
265                    MAX_CONNECTIONS
266                );
267                drop(stream);
268                continue;
269            }
270
271            accept_metrics
272                .total_connections
273                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
274
275            let conn_id = crate::conn::ClientConn::generate_conn_id();
276            let aes_key = crate::conn::ClientConn::generate_aes_key();
277            set_nodelay_with_log(&stream, peer_addr, "tcp");
278
279            tracing::info!(
280                conn_id = conn_id,
281                peer = %peer_addr,
282                total = connections.len() + 1,
283                "client connected"
284            );
285
286            let tx = crate::conn::run_connection(
287                stream,
288                conn_id,
289                aes_key,
290                req_tx.clone(),
291                disconnect_tx.clone(),
292                shutdown_rx.clone(),
293            )
294            .await;
295
296            let conn = ClientConn {
297                conn_id,
298                state: ConnState::Connected,
299                aes_key,
300                aes_encrypt_enabled: false,
301                proto_fmt_type: ProtoFmtType::Protobuf,
302                last_keepalive: Instant::now(),
303                recv_notify: false,
304                ai_type: 0,
305                keepalive_count: std::sync::atomic::AtomicU32::new(0),
306                tx,
307                // 原 TCP listener 不做 per-message scope 校验(保持兼容):
308                // key_id=None / scopes=空集 被 ws_listener 的 gate 解释为"legacy 全放行"
309                key_id: None,
310                scopes: std::collections::HashSet::new(),
311                // v1.4.105 D3 (Phase 4) T-B2: TCP listener 同样 legacy 模式 →
312                // allowed_markets None = 无限制 (push_trd_acc Layer 3 不 trigger).
313                allowed_markets: None,
314                // codex round 1 F4 (P2) v1.4.105: 同 legacy 模式, 无 acc_id 限制.
315                allowed_acc_ids: None,
316            };
317
318            connections.insert(conn_id, conn);
319        }
320
321        Ok(())
322    }
323
324    /// 向指定连接发送响应(自动处理 AES 加密)
325    pub async fn send_response(
326        connections: &DashMap<u64, ClientConn>,
327        conn_id: u64,
328        proto_id: u32,
329        serial_no: u32,
330        body: Vec<u8>,
331    ) {
332        if let Some(conn) = connections.get(&conn_id) {
333            let frame = conn.make_frame(proto_id, serial_no, bytes::Bytes::from(body));
334            if conn.tx.send(frame).await.is_err() {
335                tracing::warn!(
336                    conn_id = conn_id,
337                    "failed to send response, connection closed"
338                );
339            }
340        }
341    }
342}
343
344/// 处理所有连接的请求
345async fn process_requests(
346    mut req_rx: mpsc::Receiver<IncomingRequest>,
347    connections: Arc<DashMap<u64, ClientConn>>,
348    router: Arc<RequestRouter>,
349    config: ServerConfig,
350    metrics: Arc<GatewayMetrics>,
351    server_time_offset_secs: Arc<AtomicI64>,
352) {
353    while let Some(mut req) = req_rx.recv().await {
354        let conn_id = req.conn_id;
355        let proto_id_val = req.proto_id;
356        let serial_no = req.serial_no;
357        let req_start = Instant::now();
358
359        metrics
360            .total_requests
361            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
362
363        // 更新 last_keepalive(任何包都算活跃,对应 C++ m_nKeepAlive_Count_Curt++)
364        if let Some(mut conn) = connections.get_mut(&conn_id) {
365            conn.last_keepalive = Instant::now();
366        }
367
368        // v1.4.106 codex 0532 F3 (P2): daemon-internal proto_id (高位
369        // 0x8000_0000 bit) 绝不应从 raw TCP 公开 surface 进入 — 仅 REST
370        // handler 内部合成给 router. 显式 reject + log, 防探测 daemon
371        // 内部 routing.
372        if futu_auth::is_internal_proto_id(proto_id_val) {
373            metrics
374                .total_request_errors
375                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
376            tracing::warn!(
377                conn_id,
378                proto_id = proto_id_val,
379                "rejecting daemon-internal proto_id at raw TCP public surface (audit 0532 F3)"
380            );
381            continue;
382        }
383
384        // 非 InitConnect 请求需要 AES 解密(InitConnect 自身处理 RSA 解密)
385        if proto_id_val != proto_id::INIT_CONNECT
386            && let Some(conn) = connections.get(&conn_id)
387            && conn.aes_encrypt_enabled
388        {
389            match conn.decrypt_body(&req.body) {
390                Ok(decrypted) => {
391                    req.body = bytes::Bytes::from(decrypted);
392                }
393                Err(e) => {
394                    metrics
395                        .total_request_errors
396                        .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
397                    tracing::warn!(
398                        conn_id = conn_id,
399                        proto_id = proto_id_val,
400                        error = %e,
401                        "AES decrypt request failed, dropping"
402                    );
403                    continue;
404                }
405            }
406        }
407
408        // InitConnect 和 KeepAlive 内部处理
409        let response_body = futu_backend::delay_stats::with_api_request(
410            conn_id,
411            serial_no,
412            proto_id_val,
413            || async {
414                match proto_id_val {
415                    proto_id::INIT_CONNECT => match connections.get_mut(&conn_id) {
416                        Some(mut conn) => match conn.handle_init_connect(
417                            &req.body,
418                            config.server_ver,
419                            config.login_user_id,
420                            config.keepalive_interval,
421                            config.rsa_private_key.as_deref(),
422                        ) {
423                            Ok(body) => Some(body),
424                            Err(error) => {
425                                metrics
426                                    .total_request_errors
427                                    .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
428                                tracing::warn!(
429                                    conn_id,
430                                    proto_id = proto_id_val,
431                                    error = %error,
432                                    "InitConnect handling failed"
433                                );
434                                None
435                            }
436                        },
437                        None => {
438                            metrics
439                                .total_request_errors
440                                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
441                            tracing::warn!(
442                                conn_id,
443                                proto_id = proto_id_val,
444                                "InitConnect request received for missing connection"
445                            );
446                            None
447                        }
448                    },
449                    proto_id::KEEP_ALIVE => match connections.get(&conn_id) {
450                        Some(conn) => match conn
451                            .handle_keepalive_at(&req.body, server_now_ts(&server_time_offset_secs))
452                        {
453                            Ok(body) => Some(body),
454                            Err(error) => {
455                                metrics
456                                    .total_request_errors
457                                    .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
458                                tracing::warn!(
459                                    conn_id,
460                                    proto_id = proto_id_val,
461                                    error = %error,
462                                    "KeepAlive handling failed"
463                                );
464                                None
465                            }
466                        },
467                        None => {
468                            metrics
469                                .total_request_errors
470                                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
471                            tracing::warn!(
472                                conn_id,
473                                proto_id = proto_id_val,
474                                "KeepAlive request received for missing connection"
475                            );
476                            None
477                        }
478                    },
479                    _ => {
480                        // 委托给路由器
481                        router.dispatch(conn_id, &req).await
482                    }
483                }
484            },
485        )
486        .await;
487
488        // 记录延迟
489        metrics.record_latency_ns(req_start.elapsed().as_nanos() as u64);
490
491        // 发送响应
492        if let Some(body) = response_body {
493            metrics
494                .total_response_bytes
495                .fetch_add(body.len() as u64, std::sync::atomic::Ordering::Relaxed);
496            ApiServer::send_response(&connections, conn_id, proto_id_val, serial_no, body).await;
497        } else if proto_id_val != proto_id::INIT_CONNECT && proto_id_val != proto_id::KEEP_ALIVE {
498            metrics
499                .total_request_errors
500                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
501        }
502    }
503}
504
505#[cfg(test)]
506mod tests;