futu_grpc/
server.rs

1//! gRPC 服务实现
2//!
3//! FutuOpenD 服务通过通用的 proto_id + body 方式,
4//! 将所有请求转发到现有的 RequestRouter。
5//! 支持流式推送:行情、交易、广播事件通过 SubscribePush 接口推送给客户端。
6
7use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
8use std::sync::Arc;
9
10use bytes::Bytes;
11use tokio::sync::{broadcast, mpsc};
12use tokio_stream::wrappers::ReceiverStream;
13use tonic::{Request, Response, Status};
14
15use chrono::Utc;
16use futu_auth::{CheckCtx, KeyRecord, KeyStore, LimitOutcome, RuntimeCounters, Scope};
17use futu_codec::header::ProtoFmtType;
18use futu_proto::{trd_modify_order, trd_place_order};
19use futu_server::conn::IncomingRequest;
20use futu_server::push::ExternalPushSink;
21use futu_server::router::RequestRouter;
22use prost::Message;
23
24use crate::auth::{authenticate, check_scope, scope_for_proto};
25use crate::proto::futu_open_d_server::{FutuOpenD, FutuOpenDServer};
26use crate::proto::{FutuRequest, FutuResponse, PushEvent, SubscribePushRequest};
27
28/// gRPC 推送广播器
29///
30/// 实现 `ExternalPushSink` trait,接收 PushDispatcher 的推送事件,
31/// 通过 broadcast channel 分发给所有 SubscribePush 流式连接。
32#[derive(Clone)]
33pub struct GrpcPushBroadcaster {
34    tx: broadcast::Sender<PushEvent>,
35}
36
37impl GrpcPushBroadcaster {
38    pub fn new(capacity: usize) -> Self {
39        let (tx, _) = broadcast::channel(capacity);
40        Self { tx }
41    }
42
43    /// 创建接收端
44    pub fn subscribe(&self) -> broadcast::Receiver<PushEvent> {
45        self.tx.subscribe()
46    }
47
48    fn send(&self, event: PushEvent) {
49        let _ = self.tx.send(event);
50    }
51}
52
53impl ExternalPushSink for GrpcPushBroadcaster {
54    fn on_quote_push(&self, sec_key: &str, sub_type: i32, proto_id: u32, body: &[u8]) {
55        self.send(PushEvent {
56            proto_id,
57            sec_key: sec_key.to_string(),
58            sub_type,
59            body: body.to_vec(),
60            event_type: "quote".to_string(),
61            acc_id: 0,
62        });
63    }
64
65    fn on_broadcast_push(&self, proto_id: u32, body: &[u8]) {
66        self.send(PushEvent {
67            proto_id,
68            sec_key: String::new(),
69            sub_type: 0,
70            body: body.to_vec(),
71            event_type: "notify".to_string(),
72            acc_id: 0,
73        });
74    }
75
76    fn on_trade_push(&self, acc_id: u64, proto_id: u32, body: &[u8]) {
77        self.send(PushEvent {
78            proto_id,
79            sec_key: String::new(),
80            sub_type: 0,
81            body: body.to_vec(),
82            event_type: "trade".to_string(),
83            acc_id,
84        });
85    }
86}
87
88/// gRPC 服务实现
89pub struct FutuGrpcService {
90    router: Arc<RequestRouter>,
91    push_broadcaster: Arc<GrpcPushBroadcaster>,
92    key_store: Arc<KeyStore>,
93    counters: Arc<RuntimeCounters>,
94    conn_id_counter: AtomicU64,
95    serial_counter: AtomicU32,
96}
97
98impl FutuGrpcService {
99    pub fn new(router: Arc<RequestRouter>, push_broadcaster: Arc<GrpcPushBroadcaster>) -> Self {
100        Self::with_auth(
101            router,
102            push_broadcaster,
103            Arc::new(KeyStore::empty()),
104            Arc::new(RuntimeCounters::new()),
105        )
106    }
107
108    /// 仅 key_store,counters 新建(向后兼容,调用方暂未接 counters 时用)
109    pub fn with_key_store(
110        router: Arc<RequestRouter>,
111        push_broadcaster: Arc<GrpcPushBroadcaster>,
112        key_store: Arc<KeyStore>,
113    ) -> Self {
114        Self::with_auth(
115            router,
116            push_broadcaster,
117            key_store,
118            Arc::new(RuntimeCounters::new()),
119        )
120    }
121
122    /// 完整构造:同时接 key_store + counters(v1.0 推荐入口)
123    ///
124    /// `counters` 应由 main 全进程共享:REST / gRPC / MCP 共用一个实例才能保证
125    /// rate limit / 日累计跨接口一致
126    pub fn with_auth(
127        router: Arc<RequestRouter>,
128        push_broadcaster: Arc<GrpcPushBroadcaster>,
129        key_store: Arc<KeyStore>,
130        counters: Arc<RuntimeCounters>,
131    ) -> Self {
132        Self {
133            router,
134            push_broadcaster,
135            key_store,
136            counters,
137            conn_id_counter: AtomicU64::new(20_000_000), // gRPC 虚拟连接从 20M 开始
138            serial_counter: AtomicU32::new(1),
139        }
140    }
141
142    fn next_conn_id(&self) -> u64 {
143        self.conn_id_counter.fetch_add(1, Ordering::Relaxed)
144    }
145
146    fn next_serial(&self) -> u32 {
147        self.serial_counter.fetch_add(1, Ordering::Relaxed)
148    }
149}
150
151#[tonic::async_trait]
152impl FutuOpenD for FutuGrpcService {
153    /// 通用请求-响应
154    async fn request(
155        &self,
156        request: Request<FutuRequest>,
157    ) -> Result<Response<FutuResponse>, Status> {
158        // 鉴权:提取 Bearer token → 验证 key → 按 proto_id 校验 scope
159        let authed = authenticate(&self.key_store, &request)?;
160
161        let req = request.into_inner();
162
163        if req.proto_id == 0 {
164            return Err(Status::invalid_argument("proto_id is required"));
165        }
166
167        if let Some(needed) = scope_for_proto(req.proto_id) {
168            check_scope(&authed, req.proto_id, needed)?;
169
170            // trade:real 的请求过一次通用 rate + hours 闸门(CheckCtx 全空,
171            // 不解析 body,只挂全局闸门;细粒度检查 handler 层做)
172            if needed == Scope::TradeReal {
173                if let Some(rec) = authed.as_ref() {
174                    let ctx = CheckCtx {
175                        market: String::new(),
176                        symbol: String::new(),
177                        order_value: None,
178                        trd_side: None,
179                    };
180                    if let LimitOutcome::Reject(reason) =
181                        self.counters
182                            .check_and_commit(&rec.id, &rec.limits(), &ctx, Utc::now())
183                    {
184                        let endpoint = format!("proto_id={}", req.proto_id);
185                        futu_auth::audit::reject(
186                            "grpc",
187                            &endpoint,
188                            &rec.id,
189                            &format!("limit: {reason}"),
190                        );
191                        return Err(Status::resource_exhausted(format!(
192                            "limit check failed: {reason}"
193                        )));
194                    }
195
196                    // v1.2 handler 层 full CheckCtx:按 proto_id prost decode body
197                    // 提取 market/symbol/value/side/daily(rate 已 commit,用
198                    // check_full_skip_rate 不重复计 rate)
199                    grpc_handler_full_check(&self.counters, rec, req.proto_id, &req.body)?;
200                }
201            }
202        }
203
204        let incoming = IncomingRequest {
205            conn_id: self.next_conn_id(),
206            proto_id: req.proto_id,
207            serial_no: self.next_serial(),
208            proto_fmt_type: ProtoFmtType::Protobuf,
209            body: Bytes::from(req.body),
210        };
211
212        match self.router.dispatch(incoming.conn_id, &incoming).await {
213            Some(resp_bytes) => Ok(Response::new(FutuResponse {
214                ret_type: 0,
215                ret_msg: String::new(),
216                proto_id: req.proto_id,
217                body: resp_bytes,
218            })),
219            None => Ok(Response::new(FutuResponse {
220                ret_type: -1,
221                ret_msg: "handler returned no response".to_string(),
222                proto_id: req.proto_id,
223                body: Vec::new(),
224            })),
225        }
226    }
227
228    type SubscribePushStream = ReceiverStream<Result<PushEvent, Status>>;
229
230    /// 流式推送:客户端建立连接后持续接收行情、交易、广播推送
231    ///
232    /// v1.1:按订阅 key 的 scope 过滤推送 —— `qot:read`-only 的 key 不会收到
233    /// `trade` 类(账户交易回报),对齐 REST `/ws` v0.9.0 加的 push filter。
234    async fn subscribe_push(
235        &self,
236        request: Request<SubscribePushRequest>,
237    ) -> Result<Response<Self::SubscribePushStream>, Status> {
238        // 鉴权:握手最低门槛 qot:read;后续推送按 scope 过滤
239        let authed = authenticate(&self.key_store, &request)?;
240        check_scope(&authed, 0, Scope::QotRead)?;
241
242        // 拍下连接的 scope 集合 + key_id 用于 filter 决策;legacy 模式(authed=None)
243        // 给"全 scope" 让 filter 全放行
244        let (scopes, key_id) = match authed.as_ref() {
245            Some(rec) => (rec.scopes.clone(), rec.id.clone()),
246            None => (
247                [
248                    Scope::QotRead,
249                    Scope::AccRead,
250                    Scope::TradeSimulate,
251                    Scope::TradeReal,
252                ]
253                .into_iter()
254                .collect::<std::collections::HashSet<Scope>>(),
255                "<none>".to_string(),
256            ),
257        };
258
259        let (tx, rx) = mpsc::channel(256);
260        let mut push_rx = self.push_broadcaster.subscribe();
261
262        tracing::info!(key_id = %key_id, scopes = ?scopes, "gRPC client subscribed to push events");
263
264        // 后台任务:从 broadcast channel 读取推送 → 按 scope 过滤 → 发到 gRPC 流
265        tokio::spawn(async move {
266            loop {
267                match push_rx.recv().await {
268                    Ok(event) => {
269                        let needed = scope_for_event(&event.event_type);
270                        if !scopes.contains(&needed) {
271                            // 计 metrics counter:方便看"哪个 key 因 scope 不够丢了多少推送"
272                            futu_auth::metrics::bump_ws_filtered(&event.event_type, &key_id);
273                            continue;
274                        }
275                        if tx.send(Ok(event)).await.is_err() {
276                            break; // 客户端断开
277                        }
278                    }
279                    Err(broadcast::error::RecvError::Lagged(n)) => {
280                        tracing::warn!(skipped = n, "gRPC push client lagged, skipped events");
281                        // 继续接收,不断开
282                    }
283                    Err(broadcast::error::RecvError::Closed) => {
284                        break; // 广播器关闭
285                    }
286                }
287            }
288            tracing::info!("gRPC push stream ended");
289        });
290
291        Ok(Response::new(ReceiverStream::new(rx)))
292    }
293}
294
/// Maps a `Trd_Common.TrdMarket` enum value to its market string.
///
/// This is a copy of the equivalent REST helper, duplicated so futu-grpc does not need a
/// futu-rest dependency. Keep the two in sync.
///
/// Unknown values map to `""`. `grpc_handler_full_check` then leaves the order's
/// `symbol` empty.
///
/// NOTE(review): upstream Futu `Trd_Common.proto` appears to define `TrdMarket_JP = 15`
/// (with 7 unassigned). Confirm the `7 => "JP"` arm against the proto in use; an
/// unmapped JP market yields `""` here.
fn trd_market_str(i: i32) -> &'static str {
    match i {
        1 => "HK",
        2 => "US",
        3 => "CN",
        4 => "HKCC",
        5 => "FUTURES",
        6 => "SG",
        7 => "JP",
        _ => "",
    }
}
310
/// Maps a `Trd_Common.TrdSide` enum value to its direction string.
///
/// Values 1..=4 map to `BUY`, `SELL`, `SELL_SHORT`, `BUY_BACK`. Anything else, including
/// 0 and negatives, maps to `""`.
fn trd_side_str(i: i32) -> &'static str {
    // Slot k-1 holds the name for TrdSide value k.
    static NAMES: [&str; 4] = ["BUY", "SELL", "SELL_SHORT", "BUY_BACK"];
    usize::try_from(i)
        .ok()
        .and_then(|k| k.checked_sub(1))
        .and_then(|slot| NAMES.get(slot).copied())
        .unwrap_or("")
}
321
322/// gRPC handler 层 full CheckCtx —— 在 trade:real 通用闸门通过后调用。
323/// 按 proto_id 解码 body 提取 market/symbol/value/side,调
324/// `check_full_skip_rate`(rate 已在 auth 层 commit,这里不重复计)。
325///
326/// proto_id 不是 PlaceOrder/ModifyOrder(比如 UnlockTrade、ReconfirmOrder)
327/// 时直接放行 —— 这些请求没有 order params,只挂 rate/hours 全局闸门即可。
328/// body decode 失败时也放行,让下游 handler 报真正的 protobuf 错误。
329#[allow(clippy::result_large_err)] // tonic Status 与 RPC 签名一致
330fn grpc_handler_full_check(
331    counters: &RuntimeCounters,
332    rec: &KeyRecord,
333    proto_id: u32,
334    body: &[u8],
335) -> Result<(), Status> {
336    let ctx = match proto_id {
337        2202 => {
338            // TRD_PLACE_ORDER
339            let parsed = match trd_place_order::Request::decode(body) {
340                Ok(p) => p,
341                Err(_) => return Ok(()), // 让下游报 decode 错误
342            };
343            let c2s = &parsed.c2s;
344            let market = trd_market_str(c2s.header.trd_market);
345            let symbol = if market.is_empty() {
346                String::new()
347            } else {
348                format!("{market}.{}", c2s.code)
349            };
350            let trd_side = match trd_side_str(c2s.trd_side) {
351                "" => None,
352                s => Some(s.to_string()),
353            };
354            CheckCtx {
355                market: market.to_string(),
356                symbol,
357                order_value: c2s.price.map(|p| p * c2s.qty),
358                trd_side,
359            }
360        }
361        2205 => {
362            // TRD_MODIFY_ORDER —— 只有 order_id,能拿到 trd_market
363            let parsed = match trd_modify_order::Request::decode(body) {
364                Ok(p) => p,
365                Err(_) => return Ok(()),
366            };
367            CheckCtx {
368                market: trd_market_str(parsed.c2s.header.trd_market).to_string(),
369                symbol: String::new(),
370                order_value: None,
371                trd_side: None,
372            }
373        }
374        _ => return Ok(()),
375    };
376
377    let now = Utc::now();
378    if let LimitOutcome::Reject(reason) =
379        counters.check_full_skip_rate(&rec.id, &rec.limits(), &ctx, now)
380    {
381        let endpoint = format!("proto_id={proto_id}");
382        futu_auth::audit::reject("grpc", &endpoint, &rec.id, &format!("limit: {reason}"));
383        return Err(Status::resource_exhausted(format!(
384            "limit check failed: {reason}"
385        )));
386    }
387    Ok(())
388}
389
390fn scope_for_event(event_type: &str) -> Scope {
391    match event_type {
392        "trade" => Scope::AccRead, // 账户交易回报
393        _ => Scope::QotRead,       // quote / notify / 未知都按行情门槛
394    }
395}
396
397/// 构建 gRPC 服务(供外部调用 tonic Server 使用)
398pub fn build_service(
399    router: Arc<RequestRouter>,
400    push_broadcaster: Arc<GrpcPushBroadcaster>,
401) -> FutuOpenDServer<FutuGrpcService> {
402    FutuOpenDServer::new(FutuGrpcService::new(router, push_broadcaster))
403}
404
405/// 构建 gRPC 服务(带 KeyStore 鉴权 + 共享限额 counters)
406pub fn build_service_with_auth(
407    router: Arc<RequestRouter>,
408    push_broadcaster: Arc<GrpcPushBroadcaster>,
409    key_store: Arc<KeyStore>,
410    counters: Arc<RuntimeCounters>,
411) -> FutuOpenDServer<FutuGrpcService> {
412    FutuOpenDServer::new(FutuGrpcService::with_auth(
413        router,
414        push_broadcaster,
415        key_store,
416        counters,
417    ))
418}
419
420/// 启动 gRPC 服务
421pub async fn start(
422    listen_addr: &str,
423    router: Arc<RequestRouter>,
424    push_broadcaster: Arc<GrpcPushBroadcaster>,
425) -> Result<(), Box<dyn std::error::Error>> {
426    start_with_auth(
427        listen_addr,
428        router,
429        push_broadcaster,
430        Arc::new(KeyStore::empty()),
431        Arc::new(RuntimeCounters::new()),
432    )
433    .await
434}
435
436/// 启动 gRPC 服务(带 KeyStore 鉴权 + 共享限额 counters)
437pub async fn start_with_auth(
438    listen_addr: &str,
439    router: Arc<RequestRouter>,
440    push_broadcaster: Arc<GrpcPushBroadcaster>,
441    key_store: Arc<KeyStore>,
442    counters: Arc<RuntimeCounters>,
443) -> Result<(), Box<dyn std::error::Error>> {
444    let addr = listen_addr
445        .parse()
446        .map_err(|e| format!("invalid addr: {e}"))?;
447    if !key_store.is_configured() {
448        tracing::warn!(
449            "gRPC server running WITHOUT API key auth (legacy mode); \
450             all RPCs are open. Pass --grpc-keys-file to enable scope-based auth."
451        );
452    }
453    let service = build_service_with_auth(router, push_broadcaster, key_store, counters);
454    tracing::info!(addr = %listen_addr, "gRPC 服务已启动");
455    tonic::transport::Server::builder()
456        .add_service(service)
457        .serve(addr)
458        .await?;
459    Ok(())
460}