1 | //! Server-side traits. |
2 | |
3 | use std::cell::Cell; |
4 | use std::marker::PhantomData; |
5 | |
6 | use super::*; |
7 | |
// Generates the server-side `HandleStore` plus the `Encode`/`Decode`/`DecodeMut`
// impls that translate between handles on the wire and the values kept in the
// store. Invoked through `with_api_handle_types!`, which supplies the `'owned:`
// and `'interned:` lists of handle type names.
macro_rules! define_server_handles {
    (
        'owned: $($oty:ident,)*
        'interned: $($ity:ident,)*
    ) => {
        // One store per handle type. The fields are named after the handle
        // types themselves, hence the `non_snake_case` allowance.
        #[allow(non_snake_case)]
        pub(super) struct HandleStore<S: Types> {
            $($oty: handle::OwnedStore<S::$oty>,)*
            $($ity: handle::InternedStore<S::$ity>,)*
        }

        impl<S: Types> HandleStore<S> {
            // Builds every store from the matching field of the client's
            // `HandleCounters`.
            fn new(handle_counters: &'static client::HandleCounters) -> Self {
                HandleStore {
                    $($oty: handle::OwnedStore::new(&handle_counters.$oty),)*
                    $($ity: handle::InternedStore::new(&handle_counters.$ity),)*
                }
            }
        }

        // Impls for the `'owned` handle types.
        $(
            // Encoding a value: `alloc` it in the store and write the
            // resulting handle instead of the value itself.
            impl<S: Types> Encode<HandleStore<MarkedTypes<S>>> for Marked<S::$oty, client::$oty> {
                fn encode(self, w: &mut Writer, s: &mut HandleStore<MarkedTypes<S>>) {
                    s.$oty.alloc(self).encode(w, s);
                }
            }

            // Decoding by value: read a handle and `take` the value out of the store.
            impl<S: Types> DecodeMut<'_, '_, HandleStore<MarkedTypes<S>>>
                for Marked<S::$oty, client::$oty>
            {
                fn decode(r: &mut Reader<'_>, s: &mut HandleStore<MarkedTypes<S>>) -> Self {
                    s.$oty.take(handle::Handle::decode(r, &mut ()))
                }
            }

            // Decoding by shared reference: read a handle and index the store with it.
            impl<'s, S: Types> Decode<'_, 's, HandleStore<MarkedTypes<S>>>
                for &'s Marked<S::$oty, client::$oty>
            {
                fn decode(r: &mut Reader<'_>, s: &'s HandleStore<MarkedTypes<S>>) -> Self {
                    &s.$oty[handle::Handle::decode(r, &mut ())]
                }
            }

            // Decoding by mutable reference: as above, but with a mutable index.
            impl<'s, S: Types> DecodeMut<'_, 's, HandleStore<MarkedTypes<S>>>
                for &'s mut Marked<S::$oty, client::$oty>
            {
                fn decode(
                    r: &mut Reader<'_>,
                    s: &'s mut HandleStore<MarkedTypes<S>>
                ) -> Self {
                    &mut s.$oty[handle::Handle::decode(r, &mut ())]
                }
            }
        )*

        // Impls for the `'interned` handle types. Only by-value encode/decode
        // is generated for these.
        $(
            impl<S: Types> Encode<HandleStore<MarkedTypes<S>>> for Marked<S::$ity, client::$ity> {
                fn encode(self, w: &mut Writer, s: &mut HandleStore<MarkedTypes<S>>) {
                    s.$ity.alloc(self).encode(w, s);
                }
            }

            // Decoding uses `copy` rather than `take` on the interned store.
            impl<S: Types> DecodeMut<'_, '_, HandleStore<MarkedTypes<S>>>
                for Marked<S::$ity, client::$ity>
            {
                fn decode(r: &mut Reader<'_>, s: &mut HandleStore<MarkedTypes<S>>) -> Self {
                    s.$ity.copy(handle::Handle::decode(r, &mut ()))
                }
            }
        )*
    }
}
with_api_handle_types!(define_server_handles);
81 | |
/// The concrete types a server supplies.
///
/// `HandleStore` keeps values of these types, and `MarkedTypes` re-exposes
/// each of them wrapped in `Marked`.
pub trait Types {
    type FreeFunctions: 'static;
    type TokenStream: 'static + Clone;
    type Span: 'static + Copy + Eq + Hash;
    type Symbol: 'static;
}
88 | |
/// Declare an associated fn of one of the traits below, adding necessary
/// default bodies.
///
/// Arms are tried in order, so the `drop` and `clone` arms must stay ahead of
/// the catch-all arm.
macro_rules! associated_fn {
    // `drop` gets a default body that simply drops its argument.
    (fn drop(&mut self, $arg:ident: $arg_ty:ty)) =>
        (fn drop(&mut self, $arg: $arg_ty) { mem::drop($arg) });

    // `clone` gets a default body that calls `.clone()` on its argument.
    (fn clone(&mut self, $arg:ident: $arg_ty:ty) -> $ret_ty:ty) =>
        (fn clone(&mut self, $arg: $arg_ty) -> $ret_ty { $arg.clone() });

    // Anything else becomes a required method: the signature followed by `;`.
    ($($item:tt)*) => ($($item)*;)
}
100 | |
// Generates one `pub trait $name: Types` per API group passed in by
// `with_api!`, plus the umbrella `Server` trait that requires all of them.
macro_rules! declare_server_traits {
    ($($name:ident {
        $(fn $method:ident($($arg:ident: $arg_ty:ty),* $(,)?) $(-> $ret_ty:ty)?;)*
    }),* $(,)?) => {
        // Each method takes `&mut self` ahead of the declared arguments;
        // `associated_fn!` decides whether it gets a default body.
        $(pub trait $name: Types {
            $(associated_fn!(fn $method(&mut self, $($arg: $arg_ty),*) $(-> $ret_ty)?);)*
        })*

        pub trait Server: Types $(+ $name)* {
            fn globals(&mut self) -> ExpnGlobals<Self::Span>;

            /// Intern a symbol received from RPC
            fn intern_symbol(ident: &str) -> Self::Symbol;

            /// Recover the string value of a symbol, and invoke a callback with it.
            fn with_symbol_string(symbol: &Self::Symbol, f: impl FnOnce(&str));
        }
    }
}
with_api!(Self, self_, declare_server_traits);
121 | |
/// Newtype around a server `S`. Its `Types` impl (generated by
/// `define_mark_types_impls!`) exposes every associated type of `S` wrapped
/// in `Marked`.
pub(super) struct MarkedTypes<S: Types>(S);
123 | |
124 | impl<S: Server> Server for MarkedTypes<S> { |
125 | fn globals(&mut self) -> ExpnGlobals<Self::Span> { |
126 | <_>::mark(unmarked:Server::globals(&mut self.0)) |
127 | } |
128 | fn intern_symbol(ident: &str) -> Self::Symbol { |
129 | <_>::mark(S::intern_symbol(ident)) |
130 | } |
131 | fn with_symbol_string(symbol: &Self::Symbol, f: impl FnOnce(&str)) { |
132 | S::with_symbol_string(symbol.unmark(), f) |
133 | } |
134 | } |
135 | |
// Generates, for `MarkedTypes<S>`, the `Types` impl and one impl per API
// trait passed in by `with_api!`.
macro_rules! define_mark_types_impls {
    ($($name:ident {
        $(fn $method:ident($($arg:ident: $arg_ty:ty),* $(,)?) $(-> $ret_ty:ty)?;)*
    }),* $(,)?) => {
        // Every associated type of `S` is re-exposed wrapped in `Marked`,
        // tagged with the client-side type of the same name.
        impl<S: Types> Types for MarkedTypes<S> {
            $(type $name = Marked<S::$name, client::$name>;)*
        }

        // Each method unmarks all of its arguments, forwards to the wrapped
        // server, and marks the result.
        $(impl<S: $name> $name for MarkedTypes<S> {
            $(fn $method(&mut self, $($arg: $arg_ty),*) $(-> $ret_ty)? {
                <_>::mark($name::$method(&mut self.0, $($arg.unmark()),*))
            })*
        })*
    }
}
with_api!(Self, self_, define_mark_types_impls);
152 | |
/// Pairs a server with the handle store used while decoding requests for it
/// and encoding its replies.
struct Dispatcher<S: Types> {
    // Store consulted when arguments are decoded and results are encoded.
    handle_store: HandleStore<S>,
    // The server whose methods are invoked by `dispatch`.
    server: S,
}
157 | |
// Generates `DispatcherTrait` and its impl for `Dispatcher<MarkedTypes<S>>`.
// The generated `dispatch` has one match arm per API method passed in by
// `with_api!`.
macro_rules! define_dispatcher_impl {
    ($($name:ident {
        $(fn $method:ident($($arg:ident: $arg_ty:ty),* $(,)?) $(-> $ret_ty:ty)?;)*
    }),* $(,)?) => {
        // FIXME(eddyb) `pub` only for `ExecutionStrategy` below.
        pub trait DispatcherTrait {
            // HACK(eddyb) these are here to allow `Self::$name` to work below.
            $(type $name;)*

            // Handles one request held in `buf` and returns the buffer
            // holding the reply.
            fn dispatch(&mut self, buf: Buffer) -> Buffer;
        }

        impl<S: Server> DispatcherTrait for Dispatcher<MarkedTypes<S>> {
            $(type $name = <MarkedTypes<S> as Types>::$name;)*

            fn dispatch(&mut self, mut buf: Buffer) -> Buffer {
                let Dispatcher { handle_store, server } = self;

                // The method tag comes first in the request; `reader` then
                // points at the encoded arguments.
                let mut reader = &buf[..];
                match api_tags::Method::decode(&mut reader, &mut ()) {
                    $(api_tags::Method::$name(m) => match m {
                        $(api_tags::$name::$method => {
                            // Decode the arguments and invoke the server method.
                            let mut call_method = || {
                                reverse_decode!(reader, handle_store; $($arg: $arg_ty),*);
                                $name::$method(server, $($arg),*)
                            };
                            // HACK(eddyb) don't use `panic::catch_unwind` in a panic.
                            // If client and server happen to use the same `std`,
                            // `catch_unwind` asserts that the panic counter was 0,
                            // even when the closure passed to it didn't panic.
                            let r = if thread::panicking() {
                                Ok(call_method())
                            } else {
                                panic::catch_unwind(panic::AssertUnwindSafe(call_method))
                                    .map_err(PanicMessage::from)
                            };

                            // The request buffer is reused for the reply: the
                            // `Result` (value or `PanicMessage`) is encoded
                            // into it after clearing.
                            buf.clear();
                            r.encode(&mut buf, handle_store);
                        })*
                    }),*
                }
                buf
            }
        }
    }
}
with_api!(Self, self_, define_dispatcher_impl);
206 | |
/// A way of running the client function together with the server-side
/// dispatcher.
///
/// Implemented in this file by `SameThread`, `CrossThread` and
/// `MaybeCrossThread`.
pub trait ExecutionStrategy {
    /// Runs `run_client`, answering the requests it makes via `dispatcher`.
    ///
    /// `input` and `force_show_panics` are handed to the client through
    /// `BridgeConfig`; the returned `Buffer` is the one produced by
    /// `run_client`.
    fn run_bridge_and_client(
        &self,
        dispatcher: &mut impl DispatcherTrait,
        input: Buffer,
        run_client: extern "C" fn(BridgeConfig<'_>) -> Buffer,
        force_show_panics: bool,
    ) -> Buffer;
}
216 | |
thread_local! {
    /// While running a proc-macro with the same-thread executor, this flag will
    /// be set, forcing nested proc-macro invocations (e.g. due to
    /// `TokenStream::expand_expr`) to be run using a cross-thread executor.
    ///
    /// This is required as the thread-local state in the proc_macro client does
    /// not handle being re-entered, and will invalidate all `Symbol`s when
    /// entering a nested macro.
    static ALREADY_RUNNING_SAME_THREAD: Cell<bool> = const { Cell::new(false) };
}

/// Keep `ALREADY_RUNNING_SAME_THREAD` (see also its documentation)
/// set to `true`, preventing same-thread reentrance.
///
/// The flag is cleared again when the guard is dropped.
struct RunningSameThreadGuard(());

impl RunningSameThreadGuard {
    /// Sets the flag for the current thread and returns the guard.
    ///
    /// # Panics
    ///
    /// Panics if the flag was already set on this thread, i.e. if a
    /// same-thread execution is already in progress.
    fn new() -> Self {
        // `replace` sets the flag and reports its previous value in one step.
        let already_running = ALREADY_RUNNING_SAME_THREAD.replace(true);
        assert!(
            !already_running,
            "same-thread nesting (\"reentrance\") of proc macro executions is not supported"
        );
        RunningSameThreadGuard(())
    }
}

impl Drop for RunningSameThreadGuard {
    /// Clears the flag so that later same-thread executions are allowed again.
    fn drop(&mut self) {
        ALREADY_RUNNING_SAME_THREAD.set(false);
    }
}
248 | |
249 | pub struct MaybeCrossThread<P> { |
250 | cross_thread: bool, |
251 | marker: PhantomData<P>, |
252 | } |
253 | |
254 | impl<P> MaybeCrossThread<P> { |
255 | pub const fn new(cross_thread: bool) -> Self { |
256 | MaybeCrossThread { cross_thread, marker: PhantomData } |
257 | } |
258 | } |
259 | |
260 | impl<P> ExecutionStrategy for MaybeCrossThread<P> |
261 | where |
262 | P: MessagePipe<Buffer> + Send + 'static, |
263 | { |
264 | fn run_bridge_and_client( |
265 | &self, |
266 | dispatcher: &mut impl DispatcherTrait, |
267 | input: Buffer, |
268 | run_client: extern "C" fn(BridgeConfig<'_>) -> Buffer, |
269 | force_show_panics: bool, |
270 | ) -> Buffer { |
271 | if self.cross_thread || ALREADY_RUNNING_SAME_THREAD.get() { |
272 | <CrossThread<P>>::new().run_bridge_and_client( |
273 | dispatcher, |
274 | input, |
275 | run_client, |
276 | force_show_panics, |
277 | ) |
278 | } else { |
279 | SameThread.run_bridge_and_client(dispatcher, input, run_client, force_show_panics) |
280 | } |
281 | } |
282 | } |
283 | |
284 | pub struct SameThread; |
285 | |
286 | impl ExecutionStrategy for SameThread { |
287 | fn run_bridge_and_client( |
288 | &self, |
289 | dispatcher: &mut impl DispatcherTrait, |
290 | input: Buffer, |
291 | run_client: extern "C" fn(BridgeConfig<'_>) -> Buffer, |
292 | force_show_panics: bool, |
293 | ) -> Buffer { |
294 | let _guard: RunningSameThreadGuard = RunningSameThreadGuard::new(); |
295 | |
296 | let mut dispatch: impl FnMut(Buffer) -> Buffer = |buf: Buffer| dispatcher.dispatch(buf); |
297 | |
298 | run_client(BridgeConfig { |
299 | input, |
300 | dispatch: (&mut dispatch).into(), |
301 | force_show_panics, |
302 | _marker: marker::PhantomData, |
303 | }) |
304 | } |
305 | } |
306 | |
307 | pub struct CrossThread<P>(PhantomData<P>); |
308 | |
309 | impl<P> CrossThread<P> { |
310 | pub const fn new() -> Self { |
311 | CrossThread(PhantomData) |
312 | } |
313 | } |
314 | |
315 | impl<P> ExecutionStrategy for CrossThread<P> |
316 | where |
317 | P: MessagePipe<Buffer> + Send + 'static, |
318 | { |
319 | fn run_bridge_and_client( |
320 | &self, |
321 | dispatcher: &mut impl DispatcherTrait, |
322 | input: Buffer, |
323 | run_client: extern "C" fn(BridgeConfig<'_>) -> Buffer, |
324 | force_show_panics: bool, |
325 | ) -> Buffer { |
326 | let (mut server, mut client) = P::new(); |
327 | |
328 | let join_handle = thread::spawn(move || { |
329 | let mut dispatch = |b: Buffer| -> Buffer { |
330 | client.send(b); |
331 | client.recv().expect("server died while client waiting for reply" ) |
332 | }; |
333 | |
334 | run_client(BridgeConfig { |
335 | input, |
336 | dispatch: (&mut dispatch).into(), |
337 | force_show_panics, |
338 | _marker: marker::PhantomData, |
339 | }) |
340 | }); |
341 | |
342 | while let Some(b) = server.recv() { |
343 | server.send(dispatcher.dispatch(b)); |
344 | } |
345 | |
346 | join_handle.join().unwrap() |
347 | } |
348 | } |
349 | |
/// A message pipe used for communicating between server and client threads.
///
/// `CrossThread` keeps one endpoint on the calling thread and moves the other
/// into the spawned client thread.
pub trait MessagePipe<T>: Sized {
    /// Creates a new pair of endpoints for the message pipe.
    fn new() -> (Self, Self);

    /// Send a message to the other endpoint of this pipe.
    fn send(&mut self, value: T);

    /// Receive a message from the other endpoint of this pipe.
    ///
    /// Returns `None` if the other end of the pipe has been destroyed, and no
    /// message was received.
    fn recv(&mut self) -> Option<T>;
}
364 | |
365 | fn run_server< |
366 | S: Server, |
367 | I: Encode<HandleStore<MarkedTypes<S>>>, |
368 | O: for<'a, 's> dynDecodeMut<'a, 's, HandleStore<MarkedTypes<S>>>, |
369 | >( |
370 | strategy: &impl ExecutionStrategy, |
371 | handle_counters: &'static client::HandleCounters, |
372 | server: S, |
373 | input: I, |
374 | run_client: extern "C" fn(BridgeConfig<'_>) -> Buffer, |
375 | force_show_panics: bool, |
376 | ) -> Result<O, PanicMessage> { |
377 | let mut dispatcher: Dispatcher> = |
378 | Dispatcher { handle_store: HandleStore::new(handle_counters), server: MarkedTypes(server) }; |
379 | |
380 | let globals: ExpnGlobals::Span, …>> = dispatcher.server.globals(); |
381 | |
382 | let mut buf: Buffer = Buffer::new(); |
383 | (globals, input).encode(&mut buf, &mut dispatcher.handle_store); |
384 | |
385 | buf = strategy.run_bridge_and_client(&mut dispatcher, input:buf, run_client, force_show_panics); |
386 | |
387 | Result::decode(&mut &buf[..], &mut dispatcher.handle_store) |
388 | } |
389 | |
390 | impl client::Client<crate::TokenStream, crate::TokenStream> { |
391 | pub fn run<S>( |
392 | &self, |
393 | strategy: &impl ExecutionStrategy, |
394 | server: S, |
395 | input: S::TokenStream, |
396 | force_show_panics: bool, |
397 | ) -> Result<S::TokenStream, PanicMessage> |
398 | where |
399 | S: Server, |
400 | S::TokenStream: Default, |
401 | { |
402 | let client::Client { handle_counters: &'static HandleCounters, run: fn(BridgeConfig<'_>) -> Buffer, _marker: PhantomData …> } = *self; |
403 | run_server( |
404 | strategy, |
405 | handle_counters, |
406 | server, |
407 | <MarkedTypes<S> as Types>::TokenStream::mark(input), |
408 | run, |
409 | force_show_panics, |
410 | ) |
411 | .map(|s: Option>| <Option<<MarkedTypes<S> as Types>::TokenStream>>::unmark(self:s).unwrap_or_default()) |
412 | } |
413 | } |
414 | |
415 | impl client::Client<(crate::TokenStream, crate::TokenStream), crate::TokenStream> { |
416 | pub fn run<S>( |
417 | &self, |
418 | strategy: &impl ExecutionStrategy, |
419 | server: S, |
420 | input: S::TokenStream, |
421 | input2: S::TokenStream, |
422 | force_show_panics: bool, |
423 | ) -> Result<S::TokenStream, PanicMessage> |
424 | where |
425 | S: Server, |
426 | S::TokenStream: Default, |
427 | { |
428 | let client::Client { handle_counters, run, _marker } = *self; |
429 | run_server( |
430 | strategy, |
431 | handle_counters, |
432 | server, |
433 | ( |
434 | <MarkedTypes<S> as Types>::TokenStream::mark(input), |
435 | <MarkedTypes<S> as Types>::TokenStream::mark(input2), |
436 | ), |
437 | run, |
438 | force_show_panics, |
439 | ) |
440 | .map(|s| <Option<<MarkedTypes<S> as Types>::TokenStream>>::unmark(s).unwrap_or_default()) |
441 | } |
442 | } |
443 | |