1use httparse::{EMPTY_HEADER, parse_headers};
2use serde::{Deserialize, Serialize, de::DeserializeOwned};
3use serde_json::{Error, Value, from_value, to_string, to_value};
4use std::io;
5
6use crate::{Notification, Request, Response};
7
8use super::{ErrorCode, Id};
9
10#[derive(Debug, Clone, PartialEq, Eq, Hash, Deserialize, Serialize)]
15#[serde(untagged)]
16pub enum Message {
17 Request(Request),
19 Response(Response),
21 Notification(Notification),
23}
24
25impl Message {
26 #[doc(hidden)]
27 #[inline]
28 pub fn is_exit_notification(&self) -> bool {
29 if let Message::Notification(notification) = self {
30 matches!(notification.method.as_str(), "exit")
31 } else {
32 false
33 }
34 }
35
36 #[doc(hidden)]
37 #[inline]
38 pub fn is_initialize_request(&self) -> bool {
39 if let Message::Request(request) = self { matches!(request.method.as_str(), "initialize") } else { false }
40 }
41
42 pub fn method(&self) -> Option<&str> {
43 match self {
44 Message::Request(Request { method, .. }) => Some(method),
45 Message::Notification(Notification { method, .. }) => Some(method),
46 _ => None,
47 }
48 }
49
50 pub fn id(&self) -> Option<Id> {
51 match self {
52 Message::Request(Request { id, .. }) => Some(id.clone()),
53 Message::Notification(_) => None,
54 Message::Response(Response::Ok(id, _)) => Some(id.clone()),
55 Message::Response(Response::Err(id, _, _, _)) => Some(id.clone()),
56 }
57 }
58
59 pub fn from_value<T: DeserializeOwned>(&self) -> Result<T, Error> {
60 match self {
61 Message::Request(Request { params, .. }) => from_value(params.clone()),
62 Message::Notification(Notification { params, .. }) => from_value(params.clone()),
63 Message::Response(Response::Ok(_, params)) => from_value(params.clone()),
64 Message::Response(Response::Err(_, _, _, params)) => from_value(params.clone()),
65 }
66 }
67
68 pub fn from_result(id: Id, result: Result<impl Serialize, ErrorCode>) -> Message {
69 match result {
70 Err(code) => Message::Response(Response::Err(id, code, "".into(), Value::Null)),
71 Ok(value) => Message::Response(Response::Ok(id, to_value(value).unwrap())),
72 }
73 }
74}
75
76impl Message {
77 pub fn read(r: &mut impl io::BufRead) -> Result<Option<Message>, ParseError> {
78 let mut buf = String::new();
79 loop {
81 if r.read_line(&mut buf)? == 0 {
83 return Ok(None);
84 }
85 if buf.ends_with("\r\n\r\n") {
86 break;
87 }
88 }
89 let mut headers = [EMPTY_HEADER; 2];
90 if let httparse::Status::Complete((size, _)) = parse_headers(buf.as_bytes(), &mut headers)? {
91 if size != buf.len() {
92 Err(ParseError::HeaderDecodeMismatch(size, buf.len()))?
93 }
94 }
95 let mut content_length = 0;
96 for header in &headers {
97 if header.name.eq_ignore_ascii_case("content-length") {
98 content_length = std::str::from_utf8(header.value)?.parse::<usize>()?;
99 } else if header.name.eq_ignore_ascii_case("content-type") {
100 } else if header != &EMPTY_HEADER {
102 Err(ParseError::InvalidHeader(header.name.to_owned()))?
103 }
104 }
105 if content_length == 0 {
106 Err(ParseError::NoLength)?
107 }
108 buf.clear();
109 let mut buf = buf.into_bytes();
110 buf.resize(content_length, 0);
111 r.read_exact(&mut buf)?;
112 let message: Message = serde_json::from_slice(buf.as_slice())?;
113 Ok(Some(message))
114 }
115
116 pub fn write(self, w: &mut impl io::Write) -> io::Result<()> {
117 #[derive(Serialize)]
118 struct JSONRPCMessage {
119 jsonrpc: &'static str,
120 #[serde(flatten)]
121 message: Message,
122 }
123 let msg = to_string(&JSONRPCMessage { jsonrpc: "2.0", message: self })?;
124 write!(w, "Content-Length: {}\r\n\r\n", msg.len())?;
125 w.write_all(msg.as_bytes())?;
126 w.flush()?;
127 Ok(())
128 }
129}
130
131#[derive(Debug)]
132pub enum ParseError {
133 NoLength,
134 CouldNotDecodeHeader,
135 HeaderDecodeMismatch(usize, usize),
136 InvalidHeader(String),
137 Encode(io::Error),
138 Utf8(std::str::Utf8Error),
139 InvalidContentLength(std::num::ParseIntError),
140 Headers(httparse::Error),
141 Body(serde_json::Error),
142}
143
144impl From<ParseError> for io::Error {
145 fn from(error: ParseError) -> Self {
146 match error {
147 ParseError::NoLength => io::Error::new(io::ErrorKind::InvalidData, "could not read content-length header"),
148 ParseError::CouldNotDecodeHeader => io::Error::new(io::ErrorKind::InvalidData, "could not decode headers"),
149 ParseError::HeaderDecodeMismatch(expected, actual) => io::Error::new(
150 io::ErrorKind::InvalidData,
151 format!("failed to fully parse headers, expected {expected} but parsing ended at {actual} bytes"),
152 ),
153 ParseError::InvalidHeader(string) => {
154 io::Error::new(io::ErrorKind::InvalidData, format!("saw invalid header {string}"))
155 }
156 ParseError::Encode(e) => e,
157 ParseError::Utf8(e) => io::Error::new(io::ErrorKind::InvalidData, format!("Utf8 decode error: {e}")),
158 ParseError::InvalidContentLength(e) => {
159 io::Error::new(io::ErrorKind::InvalidData, format!("invalid content-length: {e}"))
160 }
161 ParseError::Headers(e) => io::Error::new(io::ErrorKind::InvalidData, format!("invalid headers: {e}")),
162 ParseError::Body(e) => io::Error::new(io::ErrorKind::InvalidData, format!("invalid body: {e}")),
163 }
164 }
165}
166
167impl From<io::Error> for ParseError {
168 fn from(error: io::Error) -> Self {
169 ParseError::Encode(error)
170 }
171}
172
173impl From<httparse::Error> for ParseError {
174 fn from(error: httparse::Error) -> Self {
175 ParseError::Headers(error)
176 }
177}
178
179impl From<serde_json::Error> for ParseError {
180 fn from(error: serde_json::Error) -> Self {
181 ParseError::Body(error)
182 }
183}
184
185impl From<std::num::ParseIntError> for ParseError {
186 fn from(error: std::num::ParseIntError) -> Self {
187 ParseError::InvalidContentLength(error)
188 }
189}
190
191impl From<std::str::Utf8Error> for ParseError {
192 fn from(error: std::str::Utf8Error) -> Self {
193 ParseError::Utf8(error)
194 }
195}
196
197impl From<Request> for Message {
198 fn from(request: Request) -> Message {
199 Message::Request(request)
200 }
201}
202
203impl From<Response> for Message {
204 fn from(response: Response) -> Message {
205 Message::Response(response)
206 }
207}
208
209impl From<Notification> for Message {
210 fn from(notification: Notification) -> Message {
211 Message::Notification(notification)
212 }
213}
214
215pub enum MessageError {
216 MethodNotFound,
217 MethodMistmatch(String, String),
218 JsonError(serde_json::Error),
219}
220
#[cfg(test)]
mod tests {
    use std::str::from_utf8;

    use super::*;
    use lsp_types::request::{Initialize, Request as RequestTrait};
    use serde_json::{from_str, json};

    /// Frames `str` the way `Message::write` does: a `Content-Length` header
    /// with the byte length, a blank line, then the payload itself.
    pub fn into_http_bytes(str: &str) -> String {
        let length = str.len();
        format!("Content-Length: {length}\r\n\r\n{str}")
    }

    /// Each JSON shape deserializes into the matching `Message` variant.
    #[test]
    fn test_message_deserialize() {
        // Request with a numeric id and null params.
        let numeric_id: Message =
            from_str(r#"{"jsonrpc": "2.0","method": "initialize", "params": null, "id": 1}"#).unwrap();
        assert_eq!(
            numeric_id,
            Message::Request(Request { id: 1.into(), method: Initialize::METHOD.into(), params: json!(null) })
        );

        // Request with a string id and array params.
        let string_id: Message =
            from_str(r#"{"jsonrpc": "2.0","method": "initialize", "params": [1,2], "id": "a"}"#).unwrap();
        assert_eq!(
            string_id,
            Message::Request(Request { id: "a".into(), method: Initialize::METHOD.into(), params: json!([1, 2]) })
        );

        // Successful response.
        let response: Message = from_str(r#"{"jsonrpc": "2.0","result": "foo","id":8}"#).unwrap();
        assert_eq!(response, Message::Response(Response::Ok(8.into(), json!("foo"))));

        // Notification: a method but no id.
        let notification: Message = from_str(r#"{"jsonrpc": "2.0","method": "exit"}"#).unwrap();
        assert_eq!(
            notification,
            Message::Notification(Notification { method: "exit".into(), params: json!(null) })
        );
    }

    /// A framed request is read back as the expected `Message`.
    #[test]
    fn test_message_read_from_bufreader() {
        let framed = into_http_bytes(r#"{"jsonrpc": "2.0","method": "initialize", "params": null, "id": 1}"#);
        let mut reader = framed.as_bytes();
        let message = Message::read(&mut reader).unwrap();
        let expected = Message::Request(Request { id: 1.into(), method: "initialize".into(), params: json!(null) });
        assert_eq!(message, Some(expected));
    }

    /// Writing a request produces the framed JSON text.
    #[test]
    fn test_message_write_to_bufreader() {
        let request = Message::Request(Request { id: 1.into(), method: "initialize".into(), params: json!(null) });
        let mut bytes = vec![];
        request.write(&mut bytes).unwrap();
        let expected = into_http_bytes(r#"{"jsonrpc":"2.0","id":1,"method":"initialize"}"#);
        assert_eq!(from_utf8(&bytes).unwrap(), expected);
    }
}