futu_server/
listener.rs

1// TCP 监听器:接受客户端连接,管理连接池
2
3use std::sync::Arc;
4use std::time::Instant;
5
6use dashmap::DashMap;
7use tokio::net::TcpListener;
8use tokio::sync::mpsc;
9
10use futu_codec::header::ProtoFmtType;
11use futu_core::proto_id;
12
13use crate::conn::{ClientConn, ConnState, DisconnectNotify, IncomingRequest};
14use crate::metrics::GatewayMetrics;
15use crate::router::RequestRouter;
16
/// Upper bound on simultaneously pooled client connections.
///
/// The accept loop closes any new socket that arrives while the pool already
/// holds this many entries (and counts it in `rejected_connections`).
pub const MAX_CONNECTIONS: usize = 128;
19
/// Server configuration.
///
/// `Debug` is implemented by hand so that the RSA private key is never written
/// to logs: it is rendered as `Some("<redacted>")` / `None`.
#[derive(Clone)]
pub struct ServerConfig {
    /// Address passed to `TcpListener::bind`, e.g. `"127.0.0.1:11111"`.
    pub listen_addr: String,
    /// Server version reported in the InitConnect response.
    pub server_ver: i32,
    /// Logged-in user id reported in the InitConnect response.
    pub login_user_id: u64,
    /// KeepAlive interval handed to clients during InitConnect.
    pub keepalive_interval: i32,
    /// Optional RSA private key PEM contents. When set, InitConnect uses RSA
    /// encryption/decryption.
    pub rsa_private_key: Option<String>,
}

impl std::fmt::Debug for ServerConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ServerConfig")
            .field("listen_addr", &self.listen_addr)
            .field("server_ver", &self.server_ver)
            .field("login_user_id", &self.login_user_id)
            .field("keepalive_interval", &self.keepalive_interval)
            // Only reveal whether a key is configured, never its contents.
            .field(
                "rsa_private_key",
                &self.rsa_private_key.as_ref().map(|_| "<redacted>"),
            )
            .finish()
    }
}
30
31/// API 服务端
32pub struct ApiServer {
33    config: ServerConfig,
34    connections: Arc<DashMap<u64, ClientConn>>,
35    router: Arc<RequestRouter>,
36    subscriptions: Option<Arc<crate::subscription::SubscriptionManager>>,
37    metrics: Arc<GatewayMetrics>,
38}
39
40impl ApiServer {
41    pub fn new(config: ServerConfig) -> Self {
42        Self {
43            config,
44            connections: Arc::new(DashMap::new()),
45            router: Arc::new(RequestRouter::new()),
46            subscriptions: None,
47            metrics: Arc::new(GatewayMetrics::new()),
48        }
49    }
50
51    /// 设置订阅管理器,用于连接断开时自动清理订阅关系
52    pub fn set_subscriptions(&mut self, subs: Arc<crate::subscription::SubscriptionManager>) {
53        self.subscriptions = Some(subs);
54    }
55
56    /// 获取路由器引用(用于注册业务处理器)
57    pub fn router(&self) -> &Arc<RequestRouter> {
58        &self.router
59    }
60
61    /// 获取连接池引用(用于推送分发)
62    pub fn connections(&self) -> &Arc<DashMap<u64, ClientConn>> {
63        &self.connections
64    }
65
66    /// 设置外部监控指标(共享同一个 Arc,让 bridge 和 server 使用同一份计数器)
67    pub fn set_metrics(&mut self, metrics: Arc<GatewayMetrics>) {
68        self.metrics = metrics;
69    }
70
71    /// 获取监控指标引用
72    pub fn metrics(&self) -> &Arc<GatewayMetrics> {
73        &self.metrics
74    }
75
76    /// 启动服务端监听
77    pub async fn run(&self) -> anyhow::Result<()> {
78        let listener = TcpListener::bind(&self.config.listen_addr).await?;
79        tracing::info!(addr = %self.config.listen_addr, "API server listening");
80
81        let (req_tx, req_rx) = mpsc::unbounded_channel::<IncomingRequest>();
82        let (disconnect_tx, mut disconnect_rx) = mpsc::unbounded_channel::<DisconnectNotify>();
83
84        // 启动请求处理任务
85        let connections = Arc::clone(&self.connections);
86        let router = Arc::clone(&self.router);
87        let config = self.config.clone();
88        let metrics = Arc::clone(&self.metrics);
89        tokio::spawn(async move {
90            process_requests(req_rx, connections, router, config, metrics).await;
91        });
92
93        // 启动连接清理任务(TCP 断开通知)
94        let cleanup_connections = Arc::clone(&self.connections);
95        let cleanup_subs = self.subscriptions.clone();
96        let cleanup_metrics = Arc::clone(&self.metrics);
97        tokio::spawn(async move {
98            while let Some(notify) = disconnect_rx.recv().await {
99                let removed = cleanup_connections.remove(&notify.conn_id);
100                if removed.is_some() {
101                    cleanup_metrics
102                        .total_disconnections
103                        .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
104                    // 清理该连接的所有订阅关系
105                    if let Some(ref subs) = cleanup_subs {
106                        subs.on_disconnect(notify.conn_id);
107                    }
108                    tracing::info!(
109                        conn_id = notify.conn_id,
110                        remaining = cleanup_connections.len(),
111                        "connection removed from pool"
112                    );
113                }
114            }
115        });
116
117        // 启动 KeepAlive 超时检测任务(对应 C++ OnTimeTicker,每 66 秒无活动断连)
118        let ka_connections = Arc::clone(&self.connections);
119        let ka_subs = self.subscriptions.clone();
120        let ka_metrics = Arc::clone(&self.metrics);
121        tokio::spawn(async move {
122            const CHECK_INTERVAL_SECS: u64 = 15;
123            const TIMEOUT_SECS: u64 = 66;
124            let mut interval =
125                tokio::time::interval(std::time::Duration::from_secs(CHECK_INTERVAL_SECS));
126            interval.tick().await; // 跳过首次立即触发
127            loop {
128                interval.tick().await;
129                let now = Instant::now();
130                let mut timed_out = Vec::new();
131                for entry in ka_connections.iter() {
132                    let conn = entry.value();
133                    if now.duration_since(conn.last_keepalive).as_secs() >= TIMEOUT_SECS {
134                        timed_out.push(conn.conn_id);
135                    }
136                }
137                for conn_id in timed_out {
138                    if ka_connections.remove(&conn_id).is_some() {
139                        ka_metrics
140                            .keepalive_timeouts
141                            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
142                        ka_metrics
143                            .total_disconnections
144                            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
145                        if let Some(ref subs) = ka_subs {
146                            subs.on_disconnect(conn_id);
147                        }
148                        tracing::info!(
149                            conn_id = conn_id,
150                            remaining = ka_connections.len(),
151                            "keepalive timeout, connection removed"
152                        );
153                    }
154                }
155            }
156        });
157
158        // 接受连接循环
159        let connections = Arc::clone(&self.connections);
160        let accept_metrics = Arc::clone(&self.metrics);
161        loop {
162            let (stream, peer_addr) = listener.accept().await?;
163
164            if connections.len() >= MAX_CONNECTIONS {
165                accept_metrics
166                    .rejected_connections
167                    .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
168                tracing::warn!(
169                    peer = %peer_addr,
170                    "max connections reached ({}), rejecting",
171                    MAX_CONNECTIONS
172                );
173                drop(stream);
174                continue;
175            }
176
177            accept_metrics
178                .total_connections
179                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
180
181            let conn_id = crate::conn::ClientConn::generate_conn_id();
182            let aes_key = crate::conn::ClientConn::generate_aes_key();
183            stream.set_nodelay(true).ok();
184
185            tracing::info!(
186                conn_id = conn_id,
187                peer = %peer_addr,
188                total = connections.len() + 1,
189                "client connected"
190            );
191
192            let tx = crate::conn::run_connection(
193                stream,
194                conn_id,
195                aes_key,
196                req_tx.clone(),
197                disconnect_tx.clone(),
198            )
199            .await;
200
201            let conn = ClientConn {
202                conn_id,
203                state: ConnState::Connected,
204                aes_key,
205                aes_encrypt_enabled: false,
206                proto_fmt_type: ProtoFmtType::Protobuf,
207                last_keepalive: Instant::now(),
208                keepalive_count: std::sync::atomic::AtomicU32::new(0),
209                tx,
210                // 原 TCP listener 不做 per-message scope 校验(保持兼容):
211                // key_id=None / scopes=空集 被 ws_listener 的 gate 解释为"legacy 全放行"
212                key_id: None,
213                scopes: std::collections::HashSet::new(),
214            };
215
216            connections.insert(conn_id, conn);
217        }
218    }
219
220    /// 向指定连接发送响应(自动处理 AES 加密)
221    pub async fn send_response(
222        connections: &DashMap<u64, ClientConn>,
223        conn_id: u64,
224        proto_id: u32,
225        serial_no: u32,
226        body: Vec<u8>,
227    ) {
228        if let Some(conn) = connections.get(&conn_id) {
229            let frame = conn.make_frame(proto_id, serial_no, bytes::Bytes::from(body));
230            if conn.tx.send(frame).await.is_err() {
231                tracing::warn!(
232                    conn_id = conn_id,
233                    "failed to send response, connection closed"
234                );
235            }
236        }
237    }
238}
239
240/// 处理所有连接的请求
241async fn process_requests(
242    mut req_rx: mpsc::UnboundedReceiver<IncomingRequest>,
243    connections: Arc<DashMap<u64, ClientConn>>,
244    router: Arc<RequestRouter>,
245    config: ServerConfig,
246    metrics: Arc<GatewayMetrics>,
247) {
248    while let Some(mut req) = req_rx.recv().await {
249        let conn_id = req.conn_id;
250        let proto_id_val = req.proto_id;
251        let serial_no = req.serial_no;
252        let req_start = Instant::now();
253
254        metrics
255            .total_requests
256            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
257
258        // 更新 last_keepalive(任何包都算活跃,对应 C++ m_nKeepAlive_Count_Curt++)
259        if let Some(mut conn) = connections.get_mut(&conn_id) {
260            conn.last_keepalive = Instant::now();
261        }
262
263        // 非 InitConnect 请求需要 AES 解密(InitConnect 自身处理 RSA 解密)
264        if proto_id_val != proto_id::INIT_CONNECT {
265            if let Some(conn) = connections.get(&conn_id) {
266                if conn.aes_encrypt_enabled {
267                    match conn.decrypt_body(&req.body) {
268                        Ok(decrypted) => {
269                            req.body = bytes::Bytes::from(decrypted);
270                        }
271                        Err(e) => {
272                            metrics
273                                .total_request_errors
274                                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
275                            tracing::warn!(
276                                conn_id = conn_id,
277                                proto_id = proto_id_val,
278                                error = %e,
279                                "AES decrypt request failed, dropping"
280                            );
281                            continue;
282                        }
283                    }
284                }
285            }
286        }
287
288        // InitConnect 和 KeepAlive 内部处理
289        let response_body = match proto_id_val {
290            proto_id::INIT_CONNECT => {
291                if let Some(mut conn) = connections.get_mut(&conn_id) {
292                    conn.handle_init_connect(
293                        &req.body,
294                        config.server_ver,
295                        config.login_user_id,
296                        config.keepalive_interval,
297                        config.rsa_private_key.as_deref(),
298                    )
299                    .ok()
300                } else {
301                    None
302                }
303            }
304            proto_id::KEEP_ALIVE => {
305                if let Some(conn) = connections.get(&conn_id) {
306                    conn.handle_keepalive(&req.body).ok()
307                } else {
308                    None
309                }
310            }
311            _ => {
312                // 委托给路由器
313                router.dispatch(conn_id, &req).await
314            }
315        };
316
317        // 记录延迟
318        metrics.record_latency_ns(req_start.elapsed().as_nanos() as u64);
319
320        // 发送响应
321        if let Some(body) = response_body {
322            metrics
323                .total_response_bytes
324                .fetch_add(body.len() as u64, std::sync::atomic::Ordering::Relaxed);
325            ApiServer::send_response(&connections, conn_id, proto_id_val, serial_no, body).await;
326        } else if proto_id_val != proto_id::INIT_CONNECT && proto_id_val != proto_id::KEEP_ALIVE {
327            metrics
328                .total_request_errors
329                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
330        }
331    }
332}