1use std::sync::Arc;
4use std::time::Instant;
5
6use dashmap::DashMap;
7use tokio::net::TcpListener;
8use tokio::sync::mpsc;
9
10use futu_codec::header::ProtoFmtType;
11use futu_core::proto_id;
12
13use crate::conn::{ClientConn, ConnState, DisconnectNotify, IncomingRequest};
14use crate::metrics::GatewayMetrics;
15use crate::router::RequestRouter;
16
/// Upper bound on simultaneously pooled client connections.
///
/// Once the connection pool already holds this many entries, the accept loop counts a
/// rejection in the metrics and closes the newly accepted socket instead of serving it.
pub const MAX_CONNECTIONS: usize = 128;
19
/// Static settings for an `ApiServer`.
///
/// `listen_addr` is used to bind the TCP listener; the remaining fields are forwarded
/// verbatim to the per-connection InitConnect handler.
#[derive(Clone, Debug)]
pub struct ServerConfig {
    /// Address handed to `TcpListener::bind`, in `host:port` form.
    pub listen_addr: String,
    /// Forwarded to InitConnect handling; presumably the server version reported to
    /// clients — confirm against `ClientConn::handle_init_connect`.
    pub server_ver: i32,
    /// Forwarded to InitConnect handling; presumably the logged-in account id reported
    /// to clients — confirm.
    pub login_user_id: u64,
    /// Forwarded to InitConnect handling; units presumably seconds — confirm.
    pub keepalive_interval: i32,
    /// Optional RSA private key text, forwarded to InitConnect handling as
    /// `Option<&str>`; presumably PEM-encoded — confirm.
    pub rsa_private_key: Option<String>,
}
30
31pub struct ApiServer {
33 config: ServerConfig,
34 connections: Arc<DashMap<u64, ClientConn>>,
35 router: Arc<RequestRouter>,
36 subscriptions: Option<Arc<crate::subscription::SubscriptionManager>>,
37 metrics: Arc<GatewayMetrics>,
38}
39
40impl ApiServer {
41 pub fn new(config: ServerConfig) -> Self {
42 Self {
43 config,
44 connections: Arc::new(DashMap::new()),
45 router: Arc::new(RequestRouter::new()),
46 subscriptions: None,
47 metrics: Arc::new(GatewayMetrics::new()),
48 }
49 }
50
51 pub fn set_subscriptions(&mut self, subs: Arc<crate::subscription::SubscriptionManager>) {
53 self.subscriptions = Some(subs);
54 }
55
56 pub fn router(&self) -> &Arc<RequestRouter> {
58 &self.router
59 }
60
61 pub fn connections(&self) -> &Arc<DashMap<u64, ClientConn>> {
63 &self.connections
64 }
65
66 pub fn set_metrics(&mut self, metrics: Arc<GatewayMetrics>) {
68 self.metrics = metrics;
69 }
70
71 pub fn metrics(&self) -> &Arc<GatewayMetrics> {
73 &self.metrics
74 }
75
76 pub async fn run(&self) -> anyhow::Result<()> {
78 let listener = TcpListener::bind(&self.config.listen_addr).await?;
79 tracing::info!(addr = %self.config.listen_addr, "API server listening");
80
81 let (req_tx, req_rx) = mpsc::unbounded_channel::<IncomingRequest>();
82 let (disconnect_tx, mut disconnect_rx) = mpsc::unbounded_channel::<DisconnectNotify>();
83
84 let connections = Arc::clone(&self.connections);
86 let router = Arc::clone(&self.router);
87 let config = self.config.clone();
88 let metrics = Arc::clone(&self.metrics);
89 tokio::spawn(async move {
90 process_requests(req_rx, connections, router, config, metrics).await;
91 });
92
93 let cleanup_connections = Arc::clone(&self.connections);
95 let cleanup_subs = self.subscriptions.clone();
96 let cleanup_metrics = Arc::clone(&self.metrics);
97 tokio::spawn(async move {
98 while let Some(notify) = disconnect_rx.recv().await {
99 let removed = cleanup_connections.remove(¬ify.conn_id);
100 if removed.is_some() {
101 cleanup_metrics
102 .total_disconnections
103 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
104 if let Some(ref subs) = cleanup_subs {
106 subs.on_disconnect(notify.conn_id);
107 }
108 tracing::info!(
109 conn_id = notify.conn_id,
110 remaining = cleanup_connections.len(),
111 "connection removed from pool"
112 );
113 }
114 }
115 });
116
117 let ka_connections = Arc::clone(&self.connections);
119 let ka_subs = self.subscriptions.clone();
120 let ka_metrics = Arc::clone(&self.metrics);
121 tokio::spawn(async move {
122 const CHECK_INTERVAL_SECS: u64 = 15;
123 const TIMEOUT_SECS: u64 = 66;
124 let mut interval =
125 tokio::time::interval(std::time::Duration::from_secs(CHECK_INTERVAL_SECS));
126 interval.tick().await; loop {
128 interval.tick().await;
129 let now = Instant::now();
130 let mut timed_out = Vec::new();
131 for entry in ka_connections.iter() {
132 let conn = entry.value();
133 if now.duration_since(conn.last_keepalive).as_secs() >= TIMEOUT_SECS {
134 timed_out.push(conn.conn_id);
135 }
136 }
137 for conn_id in timed_out {
138 if ka_connections.remove(&conn_id).is_some() {
139 ka_metrics
140 .keepalive_timeouts
141 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
142 ka_metrics
143 .total_disconnections
144 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
145 if let Some(ref subs) = ka_subs {
146 subs.on_disconnect(conn_id);
147 }
148 tracing::info!(
149 conn_id = conn_id,
150 remaining = ka_connections.len(),
151 "keepalive timeout, connection removed"
152 );
153 }
154 }
155 }
156 });
157
158 let connections = Arc::clone(&self.connections);
160 let accept_metrics = Arc::clone(&self.metrics);
161 loop {
162 let (stream, peer_addr) = listener.accept().await?;
163
164 if connections.len() >= MAX_CONNECTIONS {
165 accept_metrics
166 .rejected_connections
167 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
168 tracing::warn!(
169 peer = %peer_addr,
170 "max connections reached ({}), rejecting",
171 MAX_CONNECTIONS
172 );
173 drop(stream);
174 continue;
175 }
176
177 accept_metrics
178 .total_connections
179 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
180
181 let conn_id = crate::conn::ClientConn::generate_conn_id();
182 let aes_key = crate::conn::ClientConn::generate_aes_key();
183 stream.set_nodelay(true).ok();
184
185 tracing::info!(
186 conn_id = conn_id,
187 peer = %peer_addr,
188 total = connections.len() + 1,
189 "client connected"
190 );
191
192 let tx = crate::conn::run_connection(
193 stream,
194 conn_id,
195 aes_key,
196 req_tx.clone(),
197 disconnect_tx.clone(),
198 )
199 .await;
200
201 let conn = ClientConn {
202 conn_id,
203 state: ConnState::Connected,
204 aes_key,
205 aes_encrypt_enabled: false,
206 proto_fmt_type: ProtoFmtType::Protobuf,
207 last_keepalive: Instant::now(),
208 keepalive_count: std::sync::atomic::AtomicU32::new(0),
209 tx,
210 key_id: None,
213 scopes: std::collections::HashSet::new(),
214 };
215
216 connections.insert(conn_id, conn);
217 }
218 }
219
220 pub async fn send_response(
222 connections: &DashMap<u64, ClientConn>,
223 conn_id: u64,
224 proto_id: u32,
225 serial_no: u32,
226 body: Vec<u8>,
227 ) {
228 if let Some(conn) = connections.get(&conn_id) {
229 let frame = conn.make_frame(proto_id, serial_no, bytes::Bytes::from(body));
230 if conn.tx.send(frame).await.is_err() {
231 tracing::warn!(
232 conn_id = conn_id,
233 "failed to send response, connection closed"
234 );
235 }
236 }
237 }
238}
239
240async fn process_requests(
242 mut req_rx: mpsc::UnboundedReceiver<IncomingRequest>,
243 connections: Arc<DashMap<u64, ClientConn>>,
244 router: Arc<RequestRouter>,
245 config: ServerConfig,
246 metrics: Arc<GatewayMetrics>,
247) {
248 while let Some(mut req) = req_rx.recv().await {
249 let conn_id = req.conn_id;
250 let proto_id_val = req.proto_id;
251 let serial_no = req.serial_no;
252 let req_start = Instant::now();
253
254 metrics
255 .total_requests
256 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
257
258 if let Some(mut conn) = connections.get_mut(&conn_id) {
260 conn.last_keepalive = Instant::now();
261 }
262
263 if proto_id_val != proto_id::INIT_CONNECT {
265 if let Some(conn) = connections.get(&conn_id) {
266 if conn.aes_encrypt_enabled {
267 match conn.decrypt_body(&req.body) {
268 Ok(decrypted) => {
269 req.body = bytes::Bytes::from(decrypted);
270 }
271 Err(e) => {
272 metrics
273 .total_request_errors
274 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
275 tracing::warn!(
276 conn_id = conn_id,
277 proto_id = proto_id_val,
278 error = %e,
279 "AES decrypt request failed, dropping"
280 );
281 continue;
282 }
283 }
284 }
285 }
286 }
287
288 let response_body = match proto_id_val {
290 proto_id::INIT_CONNECT => {
291 if let Some(mut conn) = connections.get_mut(&conn_id) {
292 conn.handle_init_connect(
293 &req.body,
294 config.server_ver,
295 config.login_user_id,
296 config.keepalive_interval,
297 config.rsa_private_key.as_deref(),
298 )
299 .ok()
300 } else {
301 None
302 }
303 }
304 proto_id::KEEP_ALIVE => {
305 if let Some(conn) = connections.get(&conn_id) {
306 conn.handle_keepalive(&req.body).ok()
307 } else {
308 None
309 }
310 }
311 _ => {
312 router.dispatch(conn_id, &req).await
314 }
315 };
316
317 metrics.record_latency_ns(req_start.elapsed().as_nanos() as u64);
319
320 if let Some(body) = response_body {
322 metrics
323 .total_response_bytes
324 .fetch_add(body.len() as u64, std::sync::atomic::Ordering::Relaxed);
325 ApiServer::send_response(&connections, conn_id, proto_id_val, serial_no, body).await;
326 } else if proto_id_val != proto_id::INIT_CONNECT && proto_id_val != proto_id::KEEP_ALIVE {
327 metrics
328 .total_request_errors
329 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
330 }
331 }
332}