1use crossbeam_channel::{Receiver, Sender, bounded};
10use lsp_types::SetTraceParams;
11use serde_json::from_value;
12use std::{
13 io,
14 sync::{Arc, RwLock},
15 thread::{Builder, JoinHandle},
16};
17use tracing::{level_filters::LevelFilter, trace, warn};
18
19use crate::{Notification, TracingLayer};
20
21use super::Message;
22
23mod handler;
24
25pub use handler::Handler;
26
/// Join handles for the two I/O threads spawned by `Server::listen_stdio`.
///
/// The field names follow each thread's role on its channel rather than its I/O direction:
/// `sender` is the stdin-reading thread (it *sends* decoded messages on the read channel) and
/// `receiver` is the stdout-writing thread (it *receives* outgoing messages from the write channel).
#[derive(Debug)]
pub struct ThreadConnection {
    /// Handle of the `LspReader` thread (stdin -> read channel).
    pub sender: JoinHandle<io::Result<()>>,
    /// Handle of the `LspWriter` thread (write channel -> stdout).
    pub receiver: JoinHandle<io::Result<()>>,
}

impl ThreadConnection {
    /// Returns `true` as soon as *either* I/O thread has finished.
    ///
    /// A single finished side is enough: from that point the connection can no longer carry
    /// traffic in both directions.
    pub fn is_finished(&self) -> bool {
        self.sender.is_finished() || self.receiver.is_finished()
    }
}
37
/// Message-routing core of the language server.
///
/// Holds both ends of the two zero-capacity channels created in `Server::new`: the *read*
/// channel (incoming messages -> `LspMessageHandler` thread) and the *write* channel
/// (outgoing messages -> whoever drains `write_receiver`, e.g. the `LspWriter` thread started by
/// `listen_stdio`).
pub struct Server {
    /// Outgoing side; clones go to the handler thread and to the `TracingLayer` built by `tracer()`.
    write_sender: Sender<Message>,
    /// Drained by the `LspWriter` thread (or directly by tests through `raw_channels`).
    write_receiver: Receiver<Message>,
    /// Incoming side; cloned into the `LspReader` thread (or handed to tests through `raw_channels`).
    read_sender: Sender<Message>,
    /// Handle of the `LspMessageHandler` thread. No method of `Server` reads it, hence the lint
    /// allowance; presumably kept so the thread can be inspected/joined later — TODO confirm.
    #[allow(dead_code)]
    request_handler: JoinHandle<Result<(), io::Error>>,
    /// Original receiving end of the read channel; the handler thread works on a clone of it.
    /// No method of `Server` reads it, hence the lint allowance.
    /// NOTE(review): while this endpoint (and `read_sender`) is alive the read channel never
    /// becomes disconnected, so a send after the handler thread has stopped (e.g. after `exit`)
    /// blocks on the zero-capacity channel instead of returning an error — confirm intended.
    #[allow(dead_code)]
    read_receiver: Receiver<Message>,
    /// Current trace filter; written by the handler thread on `$/setTrace` and shared with the
    /// `TracingLayer` returned by `tracer()`.
    trace_level: Arc<RwLock<LevelFilter>>,
}
49
50impl Server {
51 pub fn new<T: Handler>(handler: T) -> Self {
52 let (write_sender, write_receiver) = bounded::<Message>(0);
53 let (read_sender, read_receiver) = bounded::<Message>(0);
54
55 let handler_receiver = read_receiver.clone();
56 let handler_sender = write_sender.clone();
57 let trace_level = Arc::new(RwLock::new(LevelFilter::OFF));
58 let level_set = trace_level.clone();
59 let request_handler = Builder::new()
60 .name("LspMessageHandler".into())
61 .spawn(move || {
62 while let Ok(message) = handler_receiver.recv() {
63 trace!("LspMessageHandler -> {:#?}", &message);
64 if let Message::Notification(Notification { method, params }) = &message {
65 if method == "exit" {
66 break;
67 }
68 if method == "$/setTrace" {
69 let level = from_value::<SetTraceParams>(params.clone())
70 .map(|p| match p.value {
71 lsp_types::TraceValue::Off => LevelFilter::OFF,
72 lsp_types::TraceValue::Messages => LevelFilter::WARN,
73 lsp_types::TraceValue::Verbose => LevelFilter::TRACE,
74 })
75 .unwrap_or(LevelFilter::OFF);
76 trace!("Changing level to {:?}", level);
77 let mut level_set = level_set.write().unwrap();
78 *level_set = level;
79 }
80 }
81 let response = handler.handle(message);
82 if let Some(response) = response {
83 if let Err(e) = handler_sender.send(response) {
84 warn!("Handler failed to send response {:?}", &e);
85 return Err(io::Error::other(e));
86 }
87 }
88 }
89 warn!("LspMessageHandler closing, channel closed");
90 Ok(())
91 })
92 .expect("Failed to create Reader");
93 Server { write_sender, write_receiver, read_sender, read_receiver, request_handler, trace_level }
94 }
95
96 pub fn tracer(&self) -> TracingLayer {
97 TracingLayer::new(self.trace_level.clone(), self.write_sender.clone())
98 }
99
100 pub fn listen_stdio(&self) -> Result<ThreadConnection, io::Error> {
101 let write_receiver = self.write_receiver.clone();
102 let writer = Builder::new().name("LspWriter".into()).spawn(move || {
103 let mut stdout = io::stdout().lock();
104 while let Ok(message) = write_receiver.recv() {
105 trace!("{:#?}", message);
106 message.write(&mut stdout)?;
107 }
108 Ok(())
109 })?;
110 let read_sender = self.read_sender.clone();
111 let reader = Builder::new().name("LspReader".into()).spawn(move || {
112 let mut stdin = io::stdin().lock();
113 while let Some(message) = Message::read(&mut stdin)? {
114 if let Err(e) = read_sender.send(message) {
115 return Err(io::Error::other(e));
116 }
117 }
118 Ok(())
119 })?;
120 Ok(ThreadConnection { sender: reader, receiver: writer })
121 }
122
123 #[cfg(test)]
124 pub fn raw_channels(&self) -> (Sender<Message>, Receiver<Message>) {
125 (self.read_sender.clone(), self.write_receiver.clone())
126 }
127}
128
#[cfg(test)]
mod tests {
    use std::{
        sync::atomic::{AtomicBool, Ordering},
        time::{Duration, Instant},
    };

    use crate::{ErrorCode, Notification, Request, Response};

    use super::*;
    use lsp_types::{
        InitializeParams, InitializeResult,
        request::{GotoDeclaration, Initialize, Request as RequestTrait},
    };
    use serde_json::{Value, json, to_value};
    use tracing::level_filters::LevelFilter;
    use tracing_subscriber::{Layer, fmt, layer::SubscriberExt, registry, util::SubscriberInitExt};

    /// Drives the handler thread through the raw channels.
    ///
    /// Checks that `initialize` is answered, that an unimplemented request gets
    /// `MethodNotFound`, and that `exit` stops the handler thread.
    #[test]
    fn smoke_test() {
        let stderr_log = fmt::layer().with_writer(io::stderr).with_filter(LevelFilter::TRACE);
        struct TestHandler {
            initialized: AtomicBool,
        }
        impl Handler for TestHandler {
            fn initialized(&self) -> bool {
                self.initialized.load(Ordering::SeqCst)
            }
            fn initialize(&self, _req: InitializeParams) -> Result<InitializeResult, ErrorCode> {
                self.initialized.store(true, Ordering::SeqCst);
                Ok(InitializeResult::default())
            }
        }

        let server = Server::new(TestHandler { initialized: AtomicBool::new(false) });
        // All unit tests of the crate share one process; `init()` would panic if another test
        // had already installed a global subscriber, so tolerate that case.
        let _ = registry().with(stderr_log).with(server.tracer()).try_init();
        let (sender, receiver) = server.raw_channels();
        sender
            .send(Message::Request(Request {
                id: 1.into(),
                method: Initialize::METHOD.into(),
                params: to_value(InitializeParams { ..Default::default() }).unwrap(),
            }))
            .unwrap();
        assert_eq!(receiver.recv(), Ok(Message::Response(Response::Ok(1.into(), json!({"capabilities": {}})))));
        sender
            .send(Message::Request(Request {
                id: 1.into(),
                method: GotoDeclaration::METHOD.into(),
                params: json!({
                    "textDocument": {
                        "uri": "foo/bar",
                    },
                    "position": {
                        "line": 1,
                        "character": 1
                    }
                }),
            }))
            .unwrap();
        assert_eq!(
            receiver.recv(),
            Ok(Message::Response(Response::Err(1.into(), ErrorCode::MethodNotFound, "".into(), Value::Null)))
        );
        sender.send(Message::Notification(Notification { method: "exit".into(), params: Value::Null })).unwrap();

        // `exit` must actually stop the handler thread; bound the wait so a regression fails
        // the test instead of hanging it.
        let deadline = Instant::now() + Duration::from_secs(5);
        while !server.request_handler.is_finished() {
            assert!(Instant::now() < deadline, "LspMessageHandler did not stop after `exit`");
            std::thread::sleep(Duration::from_millis(5));
        }
    }
}