csskit_lsp/jsonrpc/
message.rs

1use httparse::{EMPTY_HEADER, parse_headers};
2use serde::{Deserialize, Serialize, de::DeserializeOwned};
3use serde_json::{Error, Value, from_value, to_string, to_value};
4use std::io;
5
6use crate::{Notification, Request, Response};
7
8use super::{ErrorCode, Id};
9
10/// JSON RPC Message
11/// This represents a single message coming in or going out, that is
12/// compliant with the [JSON-RPC 2.0 spec](https://www.jsonrpc.org/specification).
13/// It wraps the [`Request`], [`Response`] and [`Notification`] structs.
14#[derive(Debug, Clone, PartialEq, Eq, Hash, Deserialize, Serialize)]
15#[serde(untagged)]
16pub enum Message {
17	/// Wraps the [`Request`] object.
18	Request(Request),
19	/// Wraps the [`Response`] object.
20	Response(Response),
21	/// Wraps the [`Notification`] object.
22	Notification(Notification),
23}
24
25impl Message {
26	#[doc(hidden)]
27	#[inline]
28	pub fn is_exit_notification(&self) -> bool {
29		if let Message::Notification(notification) = self {
30			matches!(notification.method.as_str(), "exit")
31		} else {
32			false
33		}
34	}
35
36	#[doc(hidden)]
37	#[inline]
38	pub fn is_initialize_request(&self) -> bool {
39		if let Message::Request(request) = self { matches!(request.method.as_str(), "initialize") } else { false }
40	}
41
42	pub fn method(&self) -> Option<&str> {
43		match self {
44			Message::Request(Request { method, .. }) => Some(method),
45			Message::Notification(Notification { method, .. }) => Some(method),
46			_ => None,
47		}
48	}
49
50	pub fn id(&self) -> Option<Id> {
51		match self {
52			Message::Request(Request { id, .. }) => Some(id.clone()),
53			Message::Notification(_) => None,
54			Message::Response(Response::Ok(id, _)) => Some(id.clone()),
55			Message::Response(Response::Err(id, _, _, _)) => Some(id.clone()),
56		}
57	}
58
59	pub fn from_value<T: DeserializeOwned>(&self) -> Result<T, Error> {
60		match self {
61			Message::Request(Request { params, .. }) => from_value(params.clone()),
62			Message::Notification(Notification { params, .. }) => from_value(params.clone()),
63			Message::Response(Response::Ok(_, params)) => from_value(params.clone()),
64			Message::Response(Response::Err(_, _, _, params)) => from_value(params.clone()),
65		}
66	}
67
68	pub fn from_result(id: Id, result: Result<impl Serialize, ErrorCode>) -> Message {
69		match result {
70			Err(code) => Message::Response(Response::Err(id, code, "".into(), Value::Null)),
71			Ok(value) => Message::Response(Response::Ok(id, to_value(value).unwrap())),
72		}
73	}
74}
75
76impl Message {
77	pub fn read(r: &mut impl io::BufRead) -> Result<Option<Message>, ParseError> {
78		let mut buf = String::new();
79		// Consume all headers - either end of stream or "\r\n\r\n"
80		loop {
81			// No more content, therefore no message
82			if r.read_line(&mut buf)? == 0 {
83				return Ok(None);
84			}
85			if buf.ends_with("\r\n\r\n") {
86				break;
87			}
88		}
89		let mut headers = [EMPTY_HEADER; 2];
90		if let httparse::Status::Complete((size, _)) = parse_headers(buf.as_bytes(), &mut headers)? {
91			if size != buf.len() {
92				Err(ParseError::HeaderDecodeMismatch(size, buf.len()))?
93			}
94		}
95		let mut content_length = 0;
96		for header in &headers {
97			if header.name.eq_ignore_ascii_case("content-length") {
98				content_length = std::str::from_utf8(header.value)?.parse::<usize>()?;
99			} else if header.name.eq_ignore_ascii_case("content-type") {
100				// ¯\_(ツ)_/¯
101			} else if header != &EMPTY_HEADER {
102				Err(ParseError::InvalidHeader(header.name.to_owned()))?
103			}
104		}
105		if content_length == 0 {
106			Err(ParseError::NoLength)?
107		}
108		buf.clear();
109		let mut buf = buf.into_bytes();
110		buf.resize(content_length, 0);
111		r.read_exact(&mut buf)?;
112		let message: Message = serde_json::from_slice(buf.as_slice())?;
113		Ok(Some(message))
114	}
115
116	pub fn write(self, w: &mut impl io::Write) -> io::Result<()> {
117		#[derive(Serialize)]
118		struct JSONRPCMessage {
119			jsonrpc: &'static str,
120			#[serde(flatten)]
121			message: Message,
122		}
123		let msg = to_string(&JSONRPCMessage { jsonrpc: "2.0", message: self })?;
124		write!(w, "Content-Length: {}\r\n\r\n", msg.len())?;
125		w.write_all(msg.as_bytes())?;
126		w.flush()?;
127		Ok(())
128	}
129}
130
131#[derive(Debug)]
132pub enum ParseError {
133	NoLength,
134	CouldNotDecodeHeader,
135	HeaderDecodeMismatch(usize, usize),
136	InvalidHeader(String),
137	Encode(io::Error),
138	Utf8(std::str::Utf8Error),
139	InvalidContentLength(std::num::ParseIntError),
140	Headers(httparse::Error),
141	Body(serde_json::Error),
142}
143
144impl From<ParseError> for io::Error {
145	fn from(error: ParseError) -> Self {
146		match error {
147			ParseError::NoLength => io::Error::new(io::ErrorKind::InvalidData, "could not read content-length header"),
148			ParseError::CouldNotDecodeHeader => io::Error::new(io::ErrorKind::InvalidData, "could not decode headers"),
149			ParseError::HeaderDecodeMismatch(expected, actual) => io::Error::new(
150				io::ErrorKind::InvalidData,
151				format!("failed to fully parse headers, expected {expected} but parsing ended at {actual} bytes"),
152			),
153			ParseError::InvalidHeader(string) => {
154				io::Error::new(io::ErrorKind::InvalidData, format!("saw invalid header {string}"))
155			}
156			ParseError::Encode(e) => e,
157			ParseError::Utf8(e) => io::Error::new(io::ErrorKind::InvalidData, format!("Utf8 decode error: {e}")),
158			ParseError::InvalidContentLength(e) => {
159				io::Error::new(io::ErrorKind::InvalidData, format!("invalid content-length: {e}"))
160			}
161			ParseError::Headers(e) => io::Error::new(io::ErrorKind::InvalidData, format!("invalid headers: {e}")),
162			ParseError::Body(e) => io::Error::new(io::ErrorKind::InvalidData, format!("invalid body: {e}")),
163		}
164	}
165}
166
167impl From<io::Error> for ParseError {
168	fn from(error: io::Error) -> Self {
169		ParseError::Encode(error)
170	}
171}
172
173impl From<httparse::Error> for ParseError {
174	fn from(error: httparse::Error) -> Self {
175		ParseError::Headers(error)
176	}
177}
178
179impl From<serde_json::Error> for ParseError {
180	fn from(error: serde_json::Error) -> Self {
181		ParseError::Body(error)
182	}
183}
184
185impl From<std::num::ParseIntError> for ParseError {
186	fn from(error: std::num::ParseIntError) -> Self {
187		ParseError::InvalidContentLength(error)
188	}
189}
190
191impl From<std::str::Utf8Error> for ParseError {
192	fn from(error: std::str::Utf8Error) -> Self {
193		ParseError::Utf8(error)
194	}
195}
196
197impl From<Request> for Message {
198	fn from(request: Request) -> Message {
199		Message::Request(request)
200	}
201}
202
203impl From<Response> for Message {
204	fn from(response: Response) -> Message {
205		Message::Response(response)
206	}
207}
208
209impl From<Notification> for Message {
210	fn from(notification: Notification) -> Message {
211		Message::Notification(notification)
212	}
213}
214
215pub enum MessageError {
216	MethodNotFound,
217	MethodMistmatch(String, String),
218	JsonError(serde_json::Error),
219}
220
#[cfg(test)]
mod tests {
	use std::str::from_utf8;

	use super::*;
	use lsp_types::request::{Initialize, Request as RequestTrait};
	use serde_json::{from_str, json};

	/// Frames `body` behind a `Content-Length` header, as `Message::read` expects.
	pub fn into_http_bytes(body: &str) -> String {
		format!("Content-Length: {}\r\n\r\n{}", body.len(), body)
	}

	#[test]
	fn test_message_deserialize() {
		assert_eq!(
			from_str::<Message>(r#"{"jsonrpc": "2.0","method": "initialize", "params": null, "id": 1}"#).unwrap(),
			Message::Request(Request { id: 1.into(), method: Initialize::METHOD.into(), params: json!(null) })
		);
		assert_eq!(
			from_str::<Message>(r#"{"jsonrpc": "2.0","method": "initialize", "params": [1,2], "id": "a"}"#).unwrap(),
			Message::Request(Request { id: "a".into(), method: Initialize::METHOD.into(), params: json!([1, 2]) })
		);
		assert_eq!(
			from_str::<Message>(r#"{"jsonrpc": "2.0","result": "foo","id":8}"#).unwrap(),
			Message::Response(Response::Ok(8.into(), json!("foo")))
		);
		assert_eq!(
			from_str::<Message>(r#"{"jsonrpc": "2.0","method": "exit"}"#).unwrap(),
			Message::Notification(Notification { method: "exit".into(), params: json!(null) })
		);
	}

	#[test]
	fn test_message_read_from_bufreader() {
		let r = into_http_bytes(r#"{"jsonrpc": "2.0","method": "initialize", "params": null, "id": 1}"#);
		assert_eq!(
			Message::read(&mut r.as_bytes()).unwrap(),
			Some(Message::Request(Request { id: 1.into(), method: "initialize".into(), params: json!(null) }))
		);
	}

	#[test]
	fn test_message_read_end_of_stream() {
		assert_eq!(Message::read(&mut "".as_bytes()).unwrap(), None);
	}

	#[test]
	fn test_message_read_consecutive_messages() {
		let input = format!(
			"{}{}",
			into_http_bytes(r#"{"jsonrpc": "2.0","method": "initialize", "params": null, "id": 1}"#),
			into_http_bytes(r#"{"jsonrpc": "2.0","method": "exit"}"#),
		);
		let mut reader = input.as_bytes();
		assert_eq!(
			Message::read(&mut reader).unwrap(),
			Some(Message::Request(Request { id: 1.into(), method: "initialize".into(), params: json!(null) }))
		);
		assert_eq!(
			Message::read(&mut reader).unwrap(),
			Some(Message::Notification(Notification { method: "exit".into(), params: json!(null) }))
		);
		assert_eq!(Message::read(&mut reader).unwrap(), None);
	}

	#[test]
	fn test_message_read_accepts_content_type() {
		let body = r#"{"jsonrpc": "2.0","method": "exit"}"#;
		let input = format!(
			"Content-Length: {}\r\nContent-Type: application/vscode-jsonrpc; charset=utf-8\r\n\r\n{}",
			body.len(),
			body
		);
		assert_eq!(
			Message::read(&mut input.as_bytes()).unwrap(),
			Some(Message::Notification(Notification { method: "exit".into(), params: json!(null) }))
		);
	}

	#[test]
	fn test_message_read_requires_content_length() {
		let input = "Content-Type: application/json\r\n\r\n{}";
		assert!(matches!(Message::read(&mut input.as_bytes()), Err(ParseError::NoLength)));
	}

	#[test]
	fn test_message_read_rejects_unknown_header() {
		let input = "X-Foo: bar\r\nContent-Length: 2\r\n\r\n{}";
		assert!(matches!(
			Message::read(&mut input.as_bytes()),
			Err(ParseError::InvalidHeader(name)) if name == "X-Foo"
		));
	}

	#[test]
	fn test_message_accessors() {
		let request = Message::Request(Request { id: 1.into(), method: "initialize".into(), params: json!([1, 2]) });
		assert!(request.is_initialize_request());
		assert!(!request.is_exit_notification());
		assert_eq!(request.method(), Some("initialize"));
		assert_eq!(request.id(), Some(1.into()));
		assert_eq!(request.from_value::<Vec<u32>>().unwrap(), vec![1, 2]);

		let exit = Message::Notification(Notification { method: "exit".into(), params: json!(null) });
		assert!(exit.is_exit_notification());
		assert!(!exit.is_initialize_request());
		assert_eq!(exit.method(), Some("exit"));
		assert_eq!(exit.id(), None);

		let response = Message::Response(Response::Ok(8.into(), json!("foo")));
		assert_eq!(response.method(), None);
		assert_eq!(response.id(), Some(8.into()));
		assert_eq!(response.from_value::<String>().unwrap(), "foo");
	}

	#[test]
	fn test_message_from_result_ok() {
		assert_eq!(
			Message::from_result(3.into(), Ok::<_, ErrorCode>(json!({"a": 1}))),
			Message::Response(Response::Ok(3.into(), json!({"a": 1})))
		);
	}

	#[test]
	fn test_message_write_to_bufreader() {
		let mut bytes = vec![];
		Message::Request(Request { id: 1.into(), method: "initialize".into(), params: json!(null) })
			.write(&mut bytes)
			.unwrap();
		assert_eq!(from_utf8(&bytes).unwrap(), into_http_bytes(r#"{"jsonrpc":"2.0","id":1,"method":"initialize"}"#));
	}
}