Skip to main content

futu_mcp/
transport.rs

1//! Resilient stdio transport for MCP server (v1.4.90 P0-A).
2//!
3//! ## Why this exists
4//!
5//! `rmcp::transport::stdio()` (i.e. the default `(Stdin, Stdout)` adapter via
6//! `AsyncRwTransport` + `JsonRpcMessageCodec`) treats *any* JSON parse error
7//! as a fatal stream error. Concretely:
8//!
9//! 1. `JsonRpcMessageCodec::decode()` returns `Err(JsonRpcMessageCodecError::Serde(_))`
10//!    when a line is malformed (e.g. `{"price": Infinity}` — JSON spec forbids
11//!    `Infinity` / `NaN` literals, but LLM clients emit them occasionally).
12//! 2. `FramedRead` yields `Some(Err(_))`.
13//! 3. `AsyncRwTransport::receive()` does `next.await.and_then(|e| e.ok())` —
14//!    converting `Err` to `None`.
15//! 4. The rmcp service loop interprets `None` as "input stream closed" and
16//!    breaks with `QuitReason::Closed`, terminating the entire MCP server.
17//!
18//! Result: a *single* malformed JSON line silently kills the whole server,
19//! disconnecting every client (multi-version sweep proven across v1.4.47 →
20//! v1.4.86 — 11 versions all vulnerable).
21//!
22//! Per JSON-RPC 2.0 §5.1, the correct behavior is to return a `-32700 Parse
23//! error` response and keep the connection alive. This module implements that
24//! behavior as a drop-in replacement for `rmcp::transport::stdio()`.
25//!
26//! ## Design
27//!
28//! - `ResilientStdioTransport` implements `rmcp::transport::Transport<RoleServer>`.
29//! - A background **reader task** owns stdin, reads newline-delimited frames,
30//!   and parses each into `RxJsonRpcMessage<RoleServer>`. Successful parses go
31//!   into an inbound mpsc channel for `receive()`. Parse failures cause a
32//!   synthetic `JsonRpcError(-32700)` to be enqueued onto the **outbound**
33//!   channel directly (bypassing `receive()` so the service never sees an
34//!   error event), and the loop continues.
35//! - A background **writer task** owns stdout and drains the bounded outbound
36//!   channel, serialising messages as one-line JSON each.
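//! - `send()` enqueues onto the outbound channel; `receive()` polls the
//!   inbound channel; `close()` drops the transport's outbound sender and
//!   closes the inbound channel, after which both tasks wind down — the reader
//!   when its next forward fails (or on stdin EOF), the writer once the
//!   reader's sender clone is gone too.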
39//!
40//! ## What this is NOT
41//!
42//! This is a stdio-only fix. The HTTP transport (`StreamableHttpService`)
43//! has its own per-request HTTP body parsing — a malformed request there
44//! returns 4xx without killing the server, so it's not affected by this bug.
45//! Future work: upstream PR to rmcp so all transports share resilient parsing.
46
47use std::{collections::HashSet, io, sync::Arc, time::Duration};
48
49use rmcp::RoleServer;
50use rmcp::model::{ErrorCode, ErrorData, JsonRpcError, NumberOrString};
51use rmcp::service::{RxJsonRpcMessage, TxJsonRpcMessage};
52use rmcp::transport::Transport;
53use serde_json::Value;
54use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
55use tokio::sync::{Mutex, Notify, mpsc};
56
/// Capacity of the inbound (stdin -> rmcp) channel, in messages.
///
/// The rmcp service loop drains it promptly, so 64 is ample. Should it ever
/// fill up, the reader task simply waits for room, which in turn applies
/// backpressure to stdin.
const INBOUND_BUFFER: usize = 64;

/// Largest inbound stdio JSONL frame accepted, in bytes (the terminating
/// newline counts toward the limit).
///
/// Only client -> server lines are capped; server -> client output is
/// written normally. 10 MiB (`10 << 20`) is the same order of magnitude as
/// the REST strict JSON body limit: roomy enough for large tool arguments,
/// while stopping malformed or newline-less input from growing the line
/// buffer without bound.
const MAX_STDIO_JSONL_LINE_BYTES: usize = 10 << 20;

/// How long stdin EOF is hidden from rmcp while accepted requests still await
/// their response.
///
/// One-shot stdio clients often write `initialize` + a tool call as JSONL and
/// then close stdin. rmcp treats EOF as shutdown and drains in-flight
/// responses with its own short timeout, so a slow backend tool could lose
/// its result before it is written. The reader therefore waits until every
/// request it accepted has been answered, or until this grace period runs
/// out. Persistent MCP clients never take this path.
const EOF_PENDING_DRAIN_GRACE: Duration = Duration::from_secs(15);

/// Capacity of the outbound (rmcp -> stdout) channel, in messages.
///
/// Tool results can be large; a bounded queue makes a slow stdout reader
/// apply backpressure instead of letting response bodies pile up in memory.
const OUTBOUND_BUFFER: usize = 64;
84
85type OutboundTx = mpsc::Sender<TxJsonRpcMessage<RoleServer>>;
86type OutboundRx = mpsc::Receiver<TxJsonRpcMessage<RoleServer>>;
87type PendingRequestIds = Arc<PendingRequests>;
88
89#[derive(Default)]
90struct PendingRequests {
91    ids: Mutex<HashSet<NumberOrString>>,
92    notify: Notify,
93}
94
95impl PendingRequests {
96    async fn insert(&self, id: NumberOrString) {
97        self.ids.lock().await.insert(id);
98    }
99
100    async fn remove(&self, id: &NumberOrString) {
101        let mut ids = self.ids.lock().await;
102        if ids.remove(id) {
103            self.notify.notify_waiters();
104        }
105    }
106
107    async fn wait_empty_or_timeout(&self, timeout: Duration) {
108        let deadline = tokio::time::Instant::now() + timeout;
109        loop {
110            let notified = self.notify.notified();
111            tokio::pin!(notified);
112            notified.as_mut().enable();
113            if self.ids.lock().await.is_empty() {
114                return;
115            }
116            let now = tokio::time::Instant::now();
117            if now >= deadline {
118                return;
119            }
120            tokio::select! {
121                _ = &mut notified => {}
122                _ = tokio::time::sleep_until(deadline) => return,
123            }
124        }
125    }
126}
127
128/// Resilient stdio transport — see module docs.
129pub struct ResilientStdioTransport {
130    inbound_rx: mpsc::Receiver<RxJsonRpcMessage<RoleServer>>,
131    /// Wrapped in `Arc<Mutex<>>` so `send()` can return a `'static` future
132    /// per the `Transport` trait contract.
133    outbound_tx: Arc<Mutex<Option<OutboundTx>>>,
134}
135
136impl ResilientStdioTransport {
137    /// Spawn reader + writer tasks bound to the supplied I/O handles.
138    ///
139    /// Generic over `R` / `W` so tests can inject in-memory pipes; production
140    /// callers use [`resilient_stdio()`].
141    pub fn new<R, W>(read: R, write: W) -> Self
142    where
143        R: AsyncRead + Send + Unpin + 'static,
144        W: AsyncWrite + Send + Unpin + 'static,
145    {
146        let (inbound_tx, inbound_rx) =
147            mpsc::channel::<RxJsonRpcMessage<RoleServer>>(INBOUND_BUFFER);
148        let (outbound_tx, outbound_rx) =
149            mpsc::channel::<TxJsonRpcMessage<RoleServer>>(OUTBOUND_BUFFER);
150        let pending = Arc::new(PendingRequests::default());
151
152        // Reader task — owns stdin, parses lines, recovers from parse errors.
153        let outbound_tx_for_reader = outbound_tx.clone();
154        tokio::spawn(reader_task(
155            read,
156            inbound_tx,
157            outbound_tx_for_reader,
158            Arc::clone(&pending),
159        ));
160
161        // Writer task — owns stdout, drains outbound queue.
162        tokio::spawn(writer_task(write, outbound_rx, pending));
163
164        Self {
165            inbound_rx,
166            outbound_tx: Arc::new(Mutex::new(Some(outbound_tx))),
167        }
168    }
169}
170
171/// Drop-in replacement for `rmcp::transport::stdio()`. Returns a transport
172/// that survives malformed JSON instead of `exit(0)`-ing.
173pub fn resilient_stdio() -> ResilientStdioTransport {
174    ResilientStdioTransport::new(tokio::io::stdin(), tokio::io::stdout())
175}
176
177async fn enqueue_parse_error_response(
178    outbound_tx: &OutboundTx,
179    err_msg: JsonRpcError,
180    context: &'static str,
181) {
182    if outbound_tx
183        .send(rmcp::model::JsonRpcMessage::Error(err_msg))
184        .await
185        .is_err()
186    {
187        tracing::debug!(
188            context,
189            "parse error response dropped because writer is gone"
190        );
191    }
192}
193
194async fn enqueue_message_too_large_response(outbound_tx: &OutboundTx) {
195    let err_msg = JsonRpcError::new(
196        NumberOrString::Number(0),
197        ErrorData::new(
198            ErrorCode::INVALID_REQUEST,
199            format!("stdio JSON-RPC message too large: exceeds {MAX_STDIO_JSONL_LINE_BYTES} bytes"),
200            None,
201        ),
202    );
203    enqueue_parse_error_response(outbound_tx, err_msg, "message_too_large").await;
204}
205
/// Error returned by the [`Transport`] methods of [`ResilientStdioTransport`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransportError {
    /// The outbound path is gone, so the message could not be queued. Either
    /// `close()` has already taken the sender, or the writer task has exited
    /// (for example after a stdout write error).
    Closed,
}

impl std::fmt::Display for TransportError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Closed => f.write_str("transport closed"),
        }
    }
}

/// No underlying cause is carried, so `source()` is always `None`.
impl std::error::Error for TransportError {}
220
221impl Transport<RoleServer> for ResilientStdioTransport {
222    type Error = TransportError;
223
224    fn send(
225        &mut self,
226        item: TxJsonRpcMessage<RoleServer>,
227    ) -> impl Future<Output = Result<(), Self::Error>> + Send + 'static {
228        let lock = self.outbound_tx.clone();
229        async move {
230            let tx = { lock.lock().await.as_ref().cloned() };
231            match tx {
232                Some(tx) => tx.send(item).await.map_err(|_| TransportError::Closed),
233                None => Err(TransportError::Closed),
234            }
235        }
236    }
237
238    async fn receive(&mut self) -> Option<RxJsonRpcMessage<RoleServer>> {
239        self.inbound_rx.recv().await
240    }
241
242    async fn close(&mut self) -> Result<(), Self::Error> {
243        let mut guard = self.outbound_tx.lock().await;
244        // Dropping the sender signals the writer task to exit. The reader
245        // task exits naturally on stdin EOF (or when its inbound_tx half is
246        // dropped, which happens when this struct drops).
247        guard.take();
248        self.inbound_rx.close();
249        Ok(())
250    }
251}
252
253// `ResilientStdioTransport: Transport<RoleServer>` automatically gives us
254// `IntoTransport<RoleServer, TransportError, TransportAdapterIdentity>` via
255// rmcp's blanket impl, so no explicit `IntoTransport` impl is needed here.
256
257// ---------------------------------------------------------------------------
258// Reader / writer tasks
259// ---------------------------------------------------------------------------
260
/// Outcome of one [`read_until_newline_limited`] call.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum LimitedLineRead {
    /// A frame was appended to the buffer. It is either a complete
    /// `\n`-terminated line or trailing bytes cut off by end of input.
    Line,
    /// The frame exceeded the byte limit; the rest of it was discarded.
    TooLarge,
    /// End of input was reached before any byte of a new frame.
    Eof,
}
267
268async fn read_until_newline_limited<R>(
269    reader: &mut R,
270    buf: &mut Vec<u8>,
271    max_bytes: usize,
272) -> io::Result<LimitedLineRead>
273where
274    R: AsyncBufRead + Unpin,
275{
276    let mut saw_any = false;
277    loop {
278        let available = reader.fill_buf().await?;
279        if available.is_empty() {
280            return if saw_any {
281                Ok(LimitedLineRead::Line)
282            } else {
283                Ok(LimitedLineRead::Eof)
284            };
285        }
286
287        let newline_pos = available.iter().position(|b| *b == b'\n');
288        let found_newline = newline_pos.is_some();
289        let take = newline_pos.map_or(available.len(), |idx| idx + 1);
290
291        if buf.len().saturating_add(take) > max_bytes {
292            reader.consume(take);
293            if !found_newline {
294                discard_until_newline(reader).await?;
295            }
296            return Ok(LimitedLineRead::TooLarge);
297        }
298
299        buf.extend_from_slice(&available[..take]);
300        reader.consume(take);
301        saw_any = true;
302
303        if found_newline {
304            return Ok(LimitedLineRead::Line);
305        }
306    }
307}
308
309async fn discard_until_newline<R>(reader: &mut R) -> io::Result<()>
310where
311    R: AsyncBufRead + Unpin,
312{
313    loop {
314        let available = reader.fill_buf().await?;
315        if available.is_empty() {
316            return Ok(());
317        }
318
319        let newline_pos = available.iter().position(|b| *b == b'\n');
320        let found_newline = newline_pos.is_some();
321        let take = newline_pos.map_or(available.len(), |idx| idx + 1);
322        reader.consume(take);
323
324        if found_newline {
325            return Ok(());
326        }
327    }
328}
329
330async fn reader_task<R>(
331    read: R,
332    inbound_tx: mpsc::Sender<RxJsonRpcMessage<RoleServer>>,
333    outbound_tx: OutboundTx,
334    pending: PendingRequestIds,
335) where
336    R: AsyncRead + Send + Unpin + 'static,
337{
338    let mut reader = BufReader::new(read);
339    // v1.4.93 P0-3 (BUG-003): read raw bytes instead of String to gracefully
340    // handle UTF-8 invalid sequences (e.g. `\xfe\xfe`, UTF-16 BOM, mixed
341    // binary). `read_line` into String returns Err(InvalidData) on bad UTF-8
342    // and the previous code matched `Err => break;` -> server terminated
343    // mid-session. Per JSON-RPC 2.0 §5.1 we should return -32700 Parse error
344    // and keep the connection alive (same as v1.4.90 P0-A but for the
345    // pre-string-conversion stage).
346    let mut line_bytes = Vec::<u8>::new();
347
348    loop {
349        line_bytes.clear();
350        match read_until_newline_limited(&mut reader, &mut line_bytes, MAX_STDIO_JSONL_LINE_BYTES)
351            .await
352        {
353            Ok(LimitedLineRead::Eof) => {
354                // True EOF — stdin closed by client. Give already accepted
355                // requests a chance to emit their JSON-RPC response before
356                // rmcp observes EOF and starts shutdown/drain.
357                pending.wait_empty_or_timeout(EOF_PENDING_DRAIN_GRACE).await;
358                // Let the inbound channel close so the service loop sees
359                // `receive() -> None` and shuts down cleanly.
360                tracing::debug!("resilient stdio: stdin EOF, closing");
361                break;
362            }
363            Ok(LimitedLineRead::TooLarge) => {
364                tracing::warn!(
365                    max_bytes = MAX_STDIO_JSONL_LINE_BYTES,
366                    "resilient stdio: inbound message too large, returning -32600 (server stays alive)"
367                );
368                enqueue_message_too_large_response(&outbound_tx).await;
369                continue;
370            }
371            Ok(LimitedLineRead::Line) => {
372                // v1.4.93 P0-3: try UTF-8 conversion; on failure emit -32700
373                // and continue reading instead of terminating the reader task.
374                let line_str = match std::str::from_utf8(&line_bytes) {
375                    Ok(s) => s,
376                    Err(utf8_err) => {
377                        // Build a small ASCII preview of the offending bytes
378                        // for the error message (escape non-ASCII as `\xNN`).
379                        let preview_bytes: String = line_bytes
380                            .iter()
381                            .take(64)
382                            .map(|b| {
383                                if (0x20..=0x7e).contains(b) {
384                                    (*b as char).to_string()
385                                } else {
386                                    format!("\\x{b:02x}")
387                                }
388                            })
389                            .collect();
390                        let err_msg = JsonRpcError::new(
391                            recover_request_id(""), // no id recoverable from non-UTF8
392                            ErrorData::new(
393                                ErrorCode::PARSE_ERROR,
394                                format!(
395                                    "Parse error: invalid UTF-8 at byte {}: {}",
396                                    utf8_err.valid_up_to(),
397                                    utf8_err
398                                ),
399                                None,
400                            ),
401                        );
402                        tracing::warn!(
403                            error = %utf8_err,
404                            line_preview = %preview_bytes,
405                            "resilient stdio: invalid UTF-8 input, returning -32700 (server stays alive)"
406                        );
407                        enqueue_parse_error_response(&outbound_tx, err_msg, "invalid_utf8").await;
408                        continue;
409                    }
410                };
411                let trimmed =
412                    line_str.trim_matches(|c| c == '\n' || c == '\r' || c == ' ' || c == '\t');
413                if trimmed.is_empty() {
414                    continue;
415                }
416                match serde_json::from_str::<RxJsonRpcMessage<RoleServer>>(trimmed) {
417                    Ok(msg) => {
418                        let pending_id = request_id(&msg).cloned();
419                        if let Some(id) = pending_id.clone() {
420                            pending.insert(id).await;
421                        }
422                        if inbound_tx.send(msg).await.is_err() {
423                            if let Some(id) = pending_id {
424                                pending.remove(&id).await;
425                            }
426                            // Receiver dropped — transport is closing.
427                            break;
428                        }
429                    }
430                    Err(parse_err) => {
431                        // Per JSON-RPC 2.0 §5.1, return -32700 Parse error
432                        // and keep the connection alive. Try to extract the
433                        // request id from the malformed payload (best-effort
434                        // — the spec says id should be `null` if it can't
435                        // be determined, but rmcp's `JsonRpcError` requires
436                        // a `RequestId` so we synthesise one as 0 / "" when
437                        // missing).
438                        let id = recover_request_id(trimmed);
439                        let err_msg = JsonRpcError::new(
440                            id,
441                            ErrorData::new(
442                                ErrorCode::PARSE_ERROR,
443                                format!("Parse error: {parse_err}"),
444                                None,
445                            ),
446                        );
447                        tracing::warn!(
448                            error = %parse_err,
449                            line_preview = %preview(trimmed),
450                            "resilient stdio: parse error, returning -32700 (server stays alive)"
451                        );
452                        enqueue_parse_error_response(&outbound_tx, err_msg, "json_parse_error")
453                            .await;
454                    }
455                }
456            }
457            Err(io_err) => {
458                tracing::warn!(error = %io_err, "resilient stdio: read error, terminating");
459                break;
460            }
461        }
462    }
463
464    // Drop the inbound sender so receive() returns None and the service
465    // loop exits cleanly.
466    drop(inbound_tx);
467}
468
469async fn writer_task<W>(write: W, mut outbound_rx: OutboundRx, pending: PendingRequestIds)
470where
471    W: AsyncWrite + Send + Unpin + 'static,
472{
473    let mut write = write;
474    while let Some(msg) = outbound_rx.recv().await {
475        let response_id = response_id(&msg).cloned();
476        match serde_json::to_vec(&msg) {
477            Ok(mut bytes) => {
478                bytes.push(b'\n');
479                if let Err(io_err) = write.write_all(&bytes).await {
480                    if let Some(id) = response_id.as_ref() {
481                        pending.remove(id).await;
482                    }
483                    tracing::warn!(error = %io_err, "resilient stdio: write error, terminating");
484                    break;
485                }
486                if let Err(io_err) = write.flush().await {
487                    if let Some(id) = response_id.as_ref() {
488                        pending.remove(id).await;
489                    }
490                    tracing::warn!(error = %io_err, "resilient stdio: flush error, terminating");
491                    break;
492                }
493                if let Some(id) = response_id.as_ref() {
494                    pending.remove(id).await;
495                }
496            }
497            Err(serde_err) => {
498                // This should never happen — TxJsonRpcMessage<RoleServer>
499                // is always serialisable. If it does, log and skip.
500                tracing::error!(
501                    error = %serde_err,
502                    "resilient stdio: failed to serialise outbound message (BUG)"
503                );
504            }
505        }
506    }
507}
508
509fn request_id(msg: &RxJsonRpcMessage<RoleServer>) -> Option<&NumberOrString> {
510    match msg {
511        rmcp::model::JsonRpcMessage::Request(req) => Some(&req.id),
512        _ => None,
513    }
514}
515
516fn response_id(msg: &TxJsonRpcMessage<RoleServer>) -> Option<&NumberOrString> {
517    match msg {
518        rmcp::model::JsonRpcMessage::Response(resp) => Some(&resp.id),
519        rmcp::model::JsonRpcMessage::Error(err) => Some(&err.id),
520        _ => None,
521    }
522}
523
524/// Best-effort recovery of the `id` field from a malformed JSON-RPC payload.
525/// Falls back to `Number(0)` when extraction fails (the JSON-RPC spec says
526/// "null" is the canonical placeholder, but rmcp's `RequestId` doesn't admit
527/// a null variant — `Number(0)` is the closest match and round-trips cleanly).
528fn recover_request_id(line: &str) -> NumberOrString {
529    if let Ok(value) = serde_json::from_str::<Value>(line)
530        && let Some(id) = value.get("id")
531    {
532        if let Some(n) = id.as_i64() {
533            return NumberOrString::Number(n);
534        }
535        if let Some(s) = id.as_str() {
536            return NumberOrString::String(s.into());
537        }
538    }
539    NumberOrString::Number(0)
540}
541
/// Truncate a line for log output, so arbitrary client input is not dumped
/// into the audit log unbounded.
///
/// Lines of at most 200 bytes are returned unchanged. Longer lines are cut
/// to at most 200 bytes and suffixed with `…(+N bytes)`, where `N` is the
/// number of bytes omitted.
///
/// The cut backs off to the nearest UTF-8 character boundary. Slicing in the
/// middle of a multi-byte character would panic, and this runs on untrusted
/// input inside the reader task, where a panic would end the session.
fn preview(s: &str) -> String {
    const MAX: usize = 200;
    if s.len() <= MAX {
        return s.to_string();
    }
    // `is_char_boundary(0)` is always true, so this loop terminates.
    let mut end = MAX;
    while !s.is_char_boundary(end) {
        end -= 1;
    }
    format!("{}…(+{} bytes)", &s[..end], s.len() - end)
}
552
553// ---------------------------------------------------------------------------
554// Tests
555// ---------------------------------------------------------------------------
556
557#[cfg(test)]
558mod tests;