csskit_lsp/
server.rs

//! A library for building LSP servers.
//!
//! This implements the core parts of the [LSP Specification](https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/),
//! which is a [JSON-RPC 2.0](https://www.jsonrpc.org/specification) based protocol.
//!
//! A [`Server`] is created with [`Server::new`] and connected to stdin/stdout with [`Server::listen_stdio`]
//! (in tests, `Server::raw_channels` gives direct access to the message channels instead).
//!
8
9use crossbeam_channel::{Receiver, Sender, bounded};
10use lsp_types::SetTraceParams;
11use serde_json::from_value;
12use std::{
13	io,
14	sync::{Arc, RwLock},
15	thread::{Builder, JoinHandle},
16};
17use tracing::{level_filters::LevelFilter, trace, warn};
18
19use crate::{Notification, TracingLayer};
20
21use super::Message;
22
23mod handler;
24
25pub use handler::Handler;
26
/// Join handles for the two I/O threads spawned by [`Server::listen_stdio`].
pub struct ThreadConnection {
	/// The thread pushing incoming client messages into the server (`LspReader`).
	pub sender: JoinHandle<io::Result<()>>,
	/// The thread draining outgoing server messages to the output (`LspWriter`).
	pub receiver: JoinHandle<io::Result<()>>,
}

impl ThreadConnection {
	/// Reports whether at least one of the two I/O threads has stopped running.
	///
	/// Either side exiting (for example, the reader loop ending or a write error) is treated as
	/// the whole connection being finished.
	pub fn is_finished(&self) -> bool {
		[&self.sender, &self.receiver].iter().any(|handle| handle.is_finished())
	}
}
37
38/// A set of Sender/Receiver objects for passing [`Message`s](Message) around.
39pub struct Server {
40	write_sender: Sender<Message>,
41	write_receiver: Receiver<Message>,
42	read_sender: Sender<Message>,
43	#[allow(dead_code)]
44	request_handler: JoinHandle<Result<(), io::Error>>,
45	#[allow(dead_code)]
46	read_receiver: Receiver<Message>,
47	trace_level: Arc<RwLock<LevelFilter>>,
48}
49
50impl Server {
51	pub fn new<T: Handler>(handler: T) -> Self {
52		let (write_sender, write_receiver) = bounded::<Message>(0);
53		let (read_sender, read_receiver) = bounded::<Message>(0);
54
55		let handler_receiver = read_receiver.clone();
56		let handler_sender = write_sender.clone();
57		let trace_level = Arc::new(RwLock::new(LevelFilter::OFF));
58		let level_set = trace_level.clone();
59		let request_handler = Builder::new()
60			.name("LspMessageHandler".into())
61			.spawn(move || {
62				while let Ok(message) = handler_receiver.recv() {
63					trace!("LspMessageHandler -> {:#?}", &message);
64					if let Message::Notification(Notification { method, params }) = &message {
65						if method == "exit" {
66							break;
67						}
68						if method == "$/setTrace" {
69							let level = from_value::<SetTraceParams>(params.clone())
70								.map(|p| match p.value {
71									lsp_types::TraceValue::Off => LevelFilter::OFF,
72									lsp_types::TraceValue::Messages => LevelFilter::WARN,
73									lsp_types::TraceValue::Verbose => LevelFilter::TRACE,
74								})
75								.unwrap_or(LevelFilter::OFF);
76							trace!("Changing level to {:?}", level);
77							let mut level_set = level_set.write().unwrap();
78							*level_set = level;
79						}
80					}
81					let response = handler.handle(message);
82					if let Some(response) = response {
83						if let Err(e) = handler_sender.send(response) {
84							warn!("Handler failed to send response {:?}", &e);
85							return Err(io::Error::other(e));
86						}
87					}
88				}
89				warn!("LspMessageHandler closing, channel closed");
90				Ok(())
91			})
92			.expect("Failed to create Reader");
93		Server { write_sender, write_receiver, read_sender, read_receiver, request_handler, trace_level }
94	}
95
96	pub fn tracer(&self) -> TracingLayer {
97		TracingLayer::new(self.trace_level.clone(), self.write_sender.clone())
98	}
99
100	pub fn listen_stdio(&self) -> Result<ThreadConnection, io::Error> {
101		let write_receiver = self.write_receiver.clone();
102		let writer = Builder::new().name("LspWriter".into()).spawn(move || {
103			let mut stdout = io::stdout().lock();
104			while let Ok(message) = write_receiver.recv() {
105				trace!("{:#?}", message);
106				message.write(&mut stdout)?;
107			}
108			Ok(())
109		})?;
110		let read_sender = self.read_sender.clone();
111		let reader = Builder::new().name("LspReader".into()).spawn(move || {
112			let mut stdin = io::stdin().lock();
113			while let Some(message) = Message::read(&mut stdin)? {
114				if let Err(e) = read_sender.send(message) {
115					return Err(io::Error::other(e));
116				}
117			}
118			Ok(())
119		})?;
120		Ok(ThreadConnection { sender: reader, receiver: writer })
121	}
122
123	#[cfg(test)]
124	pub fn raw_channels(&self) -> (Sender<Message>, Receiver<Message>) {
125		(self.read_sender.clone(), self.write_receiver.clone())
126	}
127}
128
#[cfg(test)]
mod tests {
	use std::sync::atomic::{AtomicBool, Ordering};

	use crate::{ErrorCode, Notification, Request, Response};

	use super::*;
	use lsp_types::{
		InitializeParams, InitializeResult,
		request::{GotoDeclaration, Initialize, Request as RequestTrait},
	};
	use serde_json::{Value, json, to_value};
	use tracing::level_filters::LevelFilter;
	use tracing_subscriber::{Layer, fmt, layer::SubscriberExt, registry, util::SubscriberInitExt};

	/// Minimal handler: only `initialize` is implemented, and it records that it ran.
	struct TestHandler {
		initialized: AtomicBool,
	}

	impl Handler for TestHandler {
		fn initialized(&self) -> bool {
			self.initialized.load(Ordering::SeqCst)
		}

		fn initialize(&self, _req: InitializeParams) -> Result<InitializeResult, ErrorCode> {
			self.initialized.store(true, Ordering::SeqCst);
			Ok(InitializeResult::default())
		}
	}

	/// Drives the handler thread through `raw_channels`: `initialize` must succeed, an unimplemented
	/// request must be rejected with `MethodNotFound`, and `exit` is sent last.
	#[test]
	fn smoke_test() {
		let stderr_log = fmt::layer().with_writer(io::stderr).with_filter(LevelFilter::TRACE);
		let server = Server::new(TestHandler { initialized: AtomicBool::new(false) });
		registry().with(stderr_log).with(server.tracer()).init();
		let (to_server, from_server) = server.raw_channels();

		let initialize = Request {
			id: 1.into(),
			method: Initialize::METHOD.into(),
			params: to_value(InitializeParams::default()).unwrap(),
		};
		to_server.send(Message::Request(initialize)).unwrap();
		let initialized = Message::Response(Response::Ok(1.into(), json!({"capabilities": {}})));
		assert_eq!(from_server.recv(), Ok(initialized));

		let goto_declaration = Request {
			id: 1.into(),
			method: GotoDeclaration::METHOD.into(),
			params: json!({
				"textDocument": {
					"uri": "foo/bar",
				},
				"position": {
					"line": 1,
					"character": 1
				}
			}),
		};
		to_server.send(Message::Request(goto_declaration)).unwrap();
		let not_found = Message::Response(Response::Err(1.into(), ErrorCode::MethodNotFound, "".into(), Value::Null));
		assert_eq!(from_server.recv(), Ok(not_found));

		let exit = Notification { method: "exit".into(), params: Value::Null };
		to_server.send(Message::Notification(exit)).unwrap();
	}
}
192}