futu_rest/
adapter.rs

//! 通用适配器:JSON ↔ Protobuf 转换 + 请求分发
//!
//! 核心思路:
//! 1. HTTP 请求带 JSON body → 反序列化为 prost Message → encode 为 bytes
//! 2. 构造 IncomingRequest { proto_id, body } → 调用 RequestRouter::dispatch
//! 3. 响应 bytes → decode 为 prost Message → 序列化为 JSON 返回
7
8use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
9use std::sync::Arc;
10
11use axum::http::StatusCode;
12use axum::response::Json;
13use bytes::Bytes;
14use prost::Message;
15use serde_json::Value;
16
17use futu_auth::KeyStore;
18use futu_codec::header::ProtoFmtType;
19use futu_server::conn::IncomingRequest;
20use futu_server::router::RequestRouter;
21
22use crate::ws::WsBroadcaster;
23
24/// REST 服务共享状态
25#[derive(Clone)]
26pub struct RestState {
27    /// 复用 OpenD 的请求路由器
28    pub router: Arc<RequestRouter>,
29    /// WebSocket 推送广播器
30    pub ws_broadcaster: Arc<WsBroadcaster>,
31    /// Bearer Token KeyStore(未配置 → is_configured() == false,WS 握手也放行)
32    pub key_store: Arc<KeyStore>,
33    /// v1.2:handler 层 full CheckCtx 用的限额计数器(与 auth middleware
34    /// 共享同一份 Arc,rate window 跨 auth/handler 一致)
35    pub counters: Arc<futu_auth::RuntimeCounters>,
36    /// 虚拟连接 ID 分配器(REST 请求从 10_000_000 开始)
37    conn_id_counter: Arc<AtomicU64>,
38    /// 序列号分配器
39    serial_counter: Arc<AtomicU32>,
40}
41
42impl RestState {
43    pub fn new(router: Arc<RequestRouter>, ws_broadcaster: Arc<WsBroadcaster>) -> Self {
44        Self::with_key_store(router, ws_broadcaster, Arc::new(KeyStore::empty()))
45    }
46
47    pub fn with_key_store(
48        router: Arc<RequestRouter>,
49        ws_broadcaster: Arc<WsBroadcaster>,
50        key_store: Arc<KeyStore>,
51    ) -> Self {
52        Self::with_auth(
53            router,
54            ws_broadcaster,
55            key_store,
56            Arc::new(futu_auth::RuntimeCounters::new()),
57        )
58    }
59
60    /// v1.2 推荐入口:同时接 KeyStore + 共享 RuntimeCounters
61    pub fn with_auth(
62        router: Arc<RequestRouter>,
63        ws_broadcaster: Arc<WsBroadcaster>,
64        key_store: Arc<KeyStore>,
65        counters: Arc<futu_auth::RuntimeCounters>,
66    ) -> Self {
67        Self {
68            router,
69            ws_broadcaster,
70            key_store,
71            counters,
72            conn_id_counter: Arc::new(AtomicU64::new(10_000_000)),
73            serial_counter: Arc::new(AtomicU32::new(1)),
74        }
75    }
76
77    /// 分配虚拟连接 ID
78    pub fn next_conn_id(&self) -> u64 {
79        self.conn_id_counter.fetch_add(1, Ordering::Relaxed)
80    }
81
82    /// 分配序列号
83    fn next_serial(&self) -> u32 {
84        self.serial_counter.fetch_add(1, Ordering::Relaxed)
85    }
86}
87
88/// 通用 protobuf 请求-响应适配器
89///
90/// 泛型参数:
91/// - `Req`: protobuf 请求类型 (prost::Message + serde::Deserialize)
92/// - `Rsp`: protobuf 响应类型 (prost::Message + serde::Serialize)
93///
94/// 流程: JSON → Req → encode → dispatch(proto_id) → decode → Rsp → JSON
95pub async fn proto_request<Req, Rsp>(
96    state: &RestState,
97    proto_id: u32,
98    json_body: Option<Value>,
99) -> Result<Json<Value>, (StatusCode, Json<Value>)>
100where
101    Req: Message + Default + serde::de::DeserializeOwned,
102    Rsp: Message + Default + serde::Serialize,
103{
104    // 1. JSON → protobuf 请求
105    let req_msg: Req = if let Some(body) = json_body {
106        serde_json::from_value(body).map_err(|e| {
107            (
108                StatusCode::BAD_REQUEST,
109                Json(serde_json::json!({
110                    "error": format!("invalid request body: {e}")
111                })),
112            )
113        })?
114    } else {
115        Req::default()
116    };
117
118    // 2. encode 为 protobuf bytes
119    let body = Bytes::from(req_msg.encode_to_vec());
120
121    // 3. 构造 IncomingRequest 调用现有 handler
122    let incoming = IncomingRequest {
123        conn_id: state.next_conn_id(),
124        proto_id,
125        serial_no: state.next_serial(),
126        proto_fmt_type: ProtoFmtType::Protobuf,
127        body,
128    };
129
130    let resp_bytes = state
131        .router
132        .dispatch(incoming.conn_id, &incoming)
133        .await
134        .ok_or_else(|| {
135            (
136                StatusCode::INTERNAL_SERVER_ERROR,
137                Json(serde_json::json!({
138                    "error": "handler returned no response"
139                })),
140            )
141        })?;
142
143    // 4. decode protobuf 响应
144    let rsp_msg = Rsp::decode(Bytes::from(resp_bytes)).map_err(|e| {
145        (
146            StatusCode::INTERNAL_SERVER_ERROR,
147            Json(serde_json::json!({
148                "error": format!("failed to decode response: {e}")
149            })),
150        )
151    })?;
152
153    // 5. 序列化为 JSON
154    let json_rsp = serde_json::to_value(&rsp_msg).map_err(|e| {
155        (
156            StatusCode::INTERNAL_SERVER_ERROR,
157            Json(serde_json::json!({
158                "error": format!("failed to serialize response: {e}")
159            })),
160        )
161    })?;
162
163    Ok(Json(json_rsp))
164}
165
166/// 通用响应格式(包装 proto 响应,加 ret_type/ret_msg)
167#[derive(serde::Serialize)]
168pub struct ApiResponse<T: serde::Serialize> {
169    pub ret_type: i32,
170    #[serde(skip_serializing_if = "Option::is_none")]
171    pub ret_msg: Option<String>,
172    #[serde(skip_serializing_if = "Option::is_none")]
173    pub data: Option<T>,
174}