1//! Server-side traits.
2
3use super::*;
4
5use std::cell::Cell;
6use std::marker::PhantomData;
7
/// Defines the server-side `HandleStore` and the `Encode` / `Decode` /
/// `DecodeMut` impls for every handle type listed by `with_api_handle_types!`.
///
/// - `'owned` types get a `handle::OwnedStore` field. Encoding calls `alloc`
///   on the store and encodes the returned handle. Decoding by value calls
///   `take`. Decoding as `&` / `&mut` indexes into the store by handle.
/// - `'interned` types get a `handle::InternedStore` field. Encoding calls
///   `alloc`, and decoding by value calls `copy`.
///
/// The handle itself is always decoded with a unit store (`&mut ()`).
macro_rules! define_server_handles {
    (
        'owned: $($oty:ident,)*
        'interned: $($ity:ident,)*
    ) => {
        // One store field per handle type, named after the type itself
        // (hence the `non_snake_case` allowance).
        #[allow(non_snake_case)]
        pub(super) struct HandleStore<S: Types> {
            $($oty: handle::OwnedStore<S::$oty>,)*
            $($ity: handle::InternedStore<S::$ity>,)*
        }

        impl<S: Types> HandleStore<S> {
            // Each store is built from the matching field of the client's
            // `HandleCounters`, which has one field per handle type.
            fn new(handle_counters: &'static client::HandleCounters) -> Self {
                HandleStore {
                    $($oty: handle::OwnedStore::new(&handle_counters.$oty),)*
                    $($ity: handle::InternedStore::new(&handle_counters.$ity),)*
                }
            }
        }

        $(
            // Sending an owned object: put it in the store, encode its handle.
            impl<S: Types> Encode<HandleStore<MarkedTypes<S>>> for Marked<S::$oty, client::$oty> {
                fn encode(self, w: &mut Writer, s: &mut HandleStore<MarkedTypes<S>>) {
                    s.$oty.alloc(self).encode(w, s);
                }
            }

            // Receiving an owned object by value: `take` it from the store
            // (presumably removing the entry — see `handle::OwnedStore`).
            impl<S: Types> DecodeMut<'_, '_, HandleStore<MarkedTypes<S>>>
                for Marked<S::$oty, client::$oty>
            {
                fn decode(r: &mut Reader<'_>, s: &mut HandleStore<MarkedTypes<S>>) -> Self {
                    s.$oty.take(handle::Handle::decode(r, &mut ()))
                }
            }

            // Receiving a shared reference: borrow the object in place.
            impl<'s, S: Types> Decode<'_, 's, HandleStore<MarkedTypes<S>>>
                for &'s Marked<S::$oty, client::$oty>
            {
                fn decode(r: &mut Reader<'_>, s: &'s HandleStore<MarkedTypes<S>>) -> Self {
                    &s.$oty[handle::Handle::decode(r, &mut ())]
                }
            }

            // Receiving a mutable reference: mutably borrow the object in place.
            impl<'s, S: Types> DecodeMut<'_, 's, HandleStore<MarkedTypes<S>>>
                for &'s mut Marked<S::$oty, client::$oty>
            {
                fn decode(
                    r: &mut Reader<'_>,
                    s: &'s mut HandleStore<MarkedTypes<S>>
                ) -> Self {
                    &mut s.$oty[handle::Handle::decode(r, &mut ())]
                }
            }
        )*

        $(
            // Sending an interned object: `alloc` a handle for it, encode it.
            impl<S: Types> Encode<HandleStore<MarkedTypes<S>>> for Marked<S::$ity, client::$ity> {
                fn encode(self, w: &mut Writer, s: &mut HandleStore<MarkedTypes<S>>) {
                    s.$ity.alloc(self).encode(w, s);
                }
            }

            // Receiving an interned object by value: `copy` it out of the store.
            impl<S: Types> DecodeMut<'_, '_, HandleStore<MarkedTypes<S>>>
                for Marked<S::$ity, client::$ity>
            {
                fn decode(r: &mut Reader<'_>, s: &mut HandleStore<MarkedTypes<S>>) -> Self {
                    s.$ity.copy(handle::Handle::decode(r, &mut ()))
                }
            }
        )*
    }
}
with_api_handle_types!(define_server_handles);
81
/// The server's concrete representations of the bridge types.
///
/// Each associated type is paired with its client-side counterpart through
/// `Marked<S::X, client::X>` when values cross the bridge.
pub trait Types {
    /// Server-side representation of `FreeFunctions`.
    type FreeFunctions: 'static;
    /// Server-side token stream; must be cloneable.
    type TokenStream: Clone + 'static;
    /// Server-side source file; must be cloneable.
    type SourceFile: Clone + 'static;
    /// Server-side span; a small copyable, hashable value.
    type Span: Copy + Eq + Hash + 'static;
    /// Server-side symbol.
    type Symbol: 'static;
}
89
/// Declare an associated fn of one of the traits below, adding necessary
/// default bodies.
///
/// Arms are tried in order:
/// - a one-argument `drop` with no return type gets a body that passes the
///   argument to `mem::drop`;
/// - a one-argument `clone` with a return type gets a body returning
///   `$arg.clone()`;
/// - every other signature is emitted as a bare declaration terminated by
///   `;`, so implementors must provide it.
macro_rules! associated_fn {
    (fn drop(&mut self, $arg:ident: $arg_ty:ty)) =>
        (fn drop(&mut self, $arg: $arg_ty) { mem::drop($arg) });

    (fn clone(&mut self, $arg:ident: $arg_ty:ty) -> $ret_ty:ty) =>
        (fn clone(&mut self, $arg: $arg_ty) -> $ret_ty { $arg.clone() });

    ($($item:tt)*) => ($($item)*;)
}
101
/// Generates one server trait per API group listed by `with_api!`, plus the
/// `Server` trait that requires all of them.
///
/// Every generated method takes `&mut self`. Default bodies for `drop` and
/// `clone` methods are supplied by `associated_fn!`.
macro_rules! declare_server_traits {
    ($($name:ident {
        $(fn $method:ident($($arg:ident: $arg_ty:ty),* $(,)?) $(-> $ret_ty:ty)?;)*
    }),* $(,)?) => {
        // One trait per API group; each has `Types` as a supertrait.
        $(pub trait $name: Types {
            $(associated_fn!(fn $method(&mut self, $($arg: $arg_ty),*) $(-> $ret_ty)?);)*
        })*

        // The complete server interface: every group trait plus the
        // functions below, which are not part of the generated API groups.
        pub trait Server: Types $(+ $name)* {
            // Expansion globals; `run_server` encodes them into the first
            // buffer sent to the client.
            fn globals(&mut self) -> ExpnGlobals<Self::Span>;

            /// Intern a symbol received from RPC
            fn intern_symbol(ident: &str) -> Self::Symbol;

            /// Recover the string value of a symbol, and invoke a callback with it.
            fn with_symbol_string(symbol: &Self::Symbol, f: impl FnOnce(&str));
        }
    }
}
with_api!(Self, self_, declare_server_traits);
122
123pub(super) struct MarkedTypes<S: Types>(S);
124
125impl<S: Server> Server for MarkedTypes<S> {
126 fn globals(&mut self) -> ExpnGlobals<Self::Span> {
127 <_>::mark(unmarked:Server::globals(&mut self.0))
128 }
129 fn intern_symbol(ident: &str) -> Self::Symbol {
130 <_>::mark(S::intern_symbol(ident))
131 }
132 fn with_symbol_string(symbol: &Self::Symbol, f: impl FnOnce(&str)) {
133 S::with_symbol_string(symbol:symbol.unmark(), f)
134 }
135}
136
/// Implements `Types` and every API group trait for `MarkedTypes<S>` by
/// delegating to the wrapped `S`.
///
/// Each associated type `X` becomes `Marked<S::X, client::X>`. Each method
/// unmarks its arguments, calls the same method on `self.0`, and marks the
/// result.
macro_rules! define_mark_types_impls {
    ($($name:ident {
        $(fn $method:ident($($arg:ident: $arg_ty:ty),* $(,)?) $(-> $ret_ty:ty)?;)*
    }),* $(,)?) => {
        // Uses each API group name as an associated type name, so every group
        // must correspond to an associated type of `Types`.
        impl<S: Types> Types for MarkedTypes<S> {
            $(type $name = Marked<S::$name, client::$name>;)*
        }

        $(impl<S: $name> $name for MarkedTypes<S> {
            $(fn $method(&mut self, $($arg: $arg_ty),*) $(-> $ret_ty)? {
                <_>::mark($name::$method(&mut self.0, $($arg.unmark()),*))
            })*
        })*
    }
}
with_api!(Self, self_, define_mark_types_impls);
153
154struct Dispatcher<S: Types> {
155 handle_store: HandleStore<S>,
156 server: S,
157}
158
/// Defines `DispatcherTrait` and implements it for
/// `Dispatcher<MarkedTypes<S>>`.
///
/// `dispatch` decodes a method tag, then the method's arguments, and calls
/// the method on the server (catching panics unless already panicking). It
/// then overwrites the same buffer with the encoded
/// `Result<_, PanicMessage>` and returns it.
macro_rules! define_dispatcher_impl {
    ($($name:ident {
        $(fn $method:ident($($arg:ident: $arg_ty:ty),* $(,)?) $(-> $ret_ty:ty)?;)*
    }),* $(,)?) => {
        // FIXME(eddyb) `pub` only for `ExecutionStrategy` below.
        pub trait DispatcherTrait {
            // HACK(eddyb) these are here to allow `Self::$name` to work below.
            $(type $name;)*

            fn dispatch(&mut self, buf: Buffer) -> Buffer;
        }

        impl<S: Server> DispatcherTrait for Dispatcher<MarkedTypes<S>> {
            $(type $name = <MarkedTypes<S> as Types>::$name;)*

            fn dispatch(&mut self, mut buf: Buffer) -> Buffer {
                let Dispatcher { handle_store, server } = self;

                // The request starts with the method tag; arguments follow.
                let mut reader = &buf[..];
                match api_tags::Method::decode(&mut reader, &mut ()) {
                    $(api_tags::Method::$name(m) => match m {
                        $(api_tags::$name::$method => {
                            let mut call_method = || {
                                // `reverse_decode!` decodes the arguments from
                                // `reader`; presumably in reverse of the order
                                // the client encoded them — see its definition.
                                reverse_decode!(reader, handle_store; $($arg: $arg_ty),*);
                                $name::$method(server, $($arg),*)
                            };
                            // HACK(eddyb) don't use `panic::catch_unwind` in a panic.
                            // If client and server happen to use the same `std`,
                            // `catch_unwind` asserts that the panic counter was 0,
                            // even when the closure passed to it didn't panic.
                            let r = if thread::panicking() {
                                Ok(call_method())
                            } else {
                                panic::catch_unwind(panic::AssertUnwindSafe(call_method))
                                    .map_err(PanicMessage::from)
                            };

                            // Reuse the request buffer for the reply.
                            buf.clear();
                            r.encode(&mut buf, handle_store);
                        })*
                    }),*
                }
                buf
            }
        }
    }
}
with_api!(Self, self_, define_dispatcher_impl);
207
/// A way of running a proc-macro client together with the server-side
/// dispatcher that answers its bridge requests.
pub trait ExecutionStrategy {
    /// Runs `run_client` on `input`, answers every bridge request through
    /// `dispatcher`, and returns the buffer produced by the client.
    ///
    /// `force_show_panics` is passed to the client through `BridgeConfig`.
    fn run_bridge_and_client(
        &self,
        dispatcher: &mut impl DispatcherTrait,
        input: Buffer,
        run_client: extern "C" fn(BridgeConfig<'_>) -> Buffer,
        force_show_panics: bool,
    ) -> Buffer;
}
217
thread_local! {
    /// `true` while this thread is running a proc-macro with the same-thread
    /// executor.
    ///
    /// While it is set, nested proc-macro invocations (e.g. due to
    /// `TokenStream::expand_expr`) are sent to a cross-thread executor
    /// instead. The thread-local state in the proc_macro client cannot handle
    /// being re-entered, and would invalidate all `Symbol`s on entry to the
    /// nested macro.
    static ALREADY_RUNNING_SAME_THREAD: Cell<bool> = const { Cell::new(false) };
}

/// Guard that keeps `ALREADY_RUNNING_SAME_THREAD` set to `true` for its
/// whole lifetime, preventing same-thread re-entrance.
struct RunningSameThreadGuard(());

impl RunningSameThreadGuard {
    /// Sets the flag for the current thread.
    ///
    /// # Panics
    ///
    /// Panics if the flag was already set, i.e. on same-thread nesting.
    fn new() -> Self {
        let was_running = ALREADY_RUNNING_SAME_THREAD.replace(true);
        if was_running {
            panic!("same-thread nesting (\"reentrance\") of proc macro executions is not supported");
        }
        Self(())
    }
}

impl Drop for RunningSameThreadGuard {
    /// Clears the flag again, also when unwinding.
    fn drop(&mut self) {
        ALREADY_RUNNING_SAME_THREAD.with(|running| running.set(false));
    }
}
249
250pub struct MaybeCrossThread<P> {
251 cross_thread: bool,
252 marker: PhantomData<P>,
253}
254
255impl<P> MaybeCrossThread<P> {
256 pub const fn new(cross_thread: bool) -> Self {
257 MaybeCrossThread { cross_thread, marker: PhantomData }
258 }
259}
260
261impl<P> ExecutionStrategy for MaybeCrossThread<P>
262where
263 P: MessagePipe<Buffer> + Send + 'static,
264{
265 fn run_bridge_and_client(
266 &self,
267 dispatcher: &mut impl DispatcherTrait,
268 input: Buffer,
269 run_client: extern "C" fn(BridgeConfig<'_>) -> Buffer,
270 force_show_panics: bool,
271 ) -> Buffer {
272 if self.cross_thread || ALREADY_RUNNING_SAME_THREAD.get() {
273 <CrossThread<P>>::new().run_bridge_and_client(
274 dispatcher,
275 input,
276 run_client,
277 force_show_panics,
278 )
279 } else {
280 SameThread.run_bridge_and_client(dispatcher, input, run_client, force_show_panics)
281 }
282 }
283}
284
285pub struct SameThread;
286
287impl ExecutionStrategy for SameThread {
288 fn run_bridge_and_client(
289 &self,
290 dispatcher: &mut impl DispatcherTrait,
291 input: Buffer,
292 run_client: extern "C" fn(BridgeConfig<'_>) -> Buffer,
293 force_show_panics: bool,
294 ) -> Buffer {
295 let _guard: RunningSameThreadGuard = RunningSameThreadGuard::new();
296
297 let mut dispatch: impl FnMut(Buffer) -> Buffer = |buf: Buffer| dispatcher.dispatch(buf);
298
299 run_client(BridgeConfig {
300 input,
301 dispatch: (&mut dispatch).into(),
302 force_show_panics,
303 _marker: marker::PhantomData,
304 })
305 }
306}
307
308pub struct CrossThread<P>(PhantomData<P>);
309
310impl<P> CrossThread<P> {
311 pub const fn new() -> Self {
312 CrossThread(PhantomData)
313 }
314}
315
316impl<P> ExecutionStrategy for CrossThread<P>
317where
318 P: MessagePipe<Buffer> + Send + 'static,
319{
320 fn run_bridge_and_client(
321 &self,
322 dispatcher: &mut impl DispatcherTrait,
323 input: Buffer,
324 run_client: extern "C" fn(BridgeConfig<'_>) -> Buffer,
325 force_show_panics: bool,
326 ) -> Buffer {
327 let (mut server, mut client) = P::new();
328
329 let join_handle = thread::spawn(move || {
330 let mut dispatch = |b: Buffer| -> Buffer {
331 client.send(b);
332 client.recv().expect("server died while client waiting for reply")
333 };
334
335 run_client(BridgeConfig {
336 input,
337 dispatch: (&mut dispatch).into(),
338 force_show_panics,
339 _marker: marker::PhantomData,
340 })
341 });
342
343 while let Some(b) = server.recv() {
344 server.send(dispatcher.dispatch(b));
345 }
346
347 join_handle.join().unwrap()
348 }
349}
350
/// A two-ended message pipe connecting the server thread and the client
/// thread. Each endpoint can both send and receive.
pub trait MessagePipe<T>
where
    Self: Sized,
{
    /// Creates a connected pair of endpoints.
    fn new() -> (Self, Self);

    /// Sends `value` to the opposite endpoint.
    fn send(&mut self, value: T);

    /// Receives the next message from the opposite endpoint.
    ///
    /// Returns `None` when the other endpoint has been destroyed and no
    /// message was received.
    fn recv(&mut self) -> Option<T>;
}
365
366fn run_server<
367 S: Server,
368 I: Encode<HandleStore<MarkedTypes<S>>>,
369 O: for<'a, 's> DecodeMut<'a, 's, HandleStore<MarkedTypes<S>>>,
370>(
371 strategy: &impl ExecutionStrategy,
372 handle_counters: &'static client::HandleCounters,
373 server: S,
374 input: I,
375 run_client: extern "C" fn(BridgeConfig<'_>) -> Buffer,
376 force_show_panics: bool,
377) -> Result<O, PanicMessage> {
378 let mut dispatcher: Dispatcher> =
379 Dispatcher { handle_store: HandleStore::new(handle_counters), server: MarkedTypes(server) };
380
381 let globals: ExpnGlobals::Span, …>> = dispatcher.server.globals();
382
383 let mut buf: Buffer = Buffer::new();
384 (globals, input).encode(&mut buf, &mut dispatcher.handle_store);
385
386 buf = strategy.run_bridge_and_client(&mut dispatcher, input:buf, run_client, force_show_panics);
387
388 Result::decode(&mut &buf[..], &mut dispatcher.handle_store)
389}
390
391impl client::Client<crate::TokenStream, crate::TokenStream> {
392 pub fn run<S>(
393 &self,
394 strategy: &impl ExecutionStrategy,
395 server: S,
396 input: S::TokenStream,
397 force_show_panics: bool,
398 ) -> Result<S::TokenStream, PanicMessage>
399 where
400 S: Server,
401 S::TokenStream: Default,
402 {
403 let client::Client { get_handle_counters: fn() -> &HandleCounters, run: fn(BridgeConfig<'_>) -> Buffer, _marker: PhantomData …> } = *self;
404 run_server(
405 strategy,
406 get_handle_counters(),
407 server,
408 <MarkedTypes<S> as Types>::TokenStream::mark(input),
409 run,
410 force_show_panics,
411 )
412 .map(|s: Option::TokenStream, …>>| <Option<<MarkedTypes<S> as Types>::TokenStream>>::unmark(self:s).unwrap_or_default())
413 }
414}
415
416impl client::Client<(crate::TokenStream, crate::TokenStream), crate::TokenStream> {
417 pub fn run<S>(
418 &self,
419 strategy: &impl ExecutionStrategy,
420 server: S,
421 input: S::TokenStream,
422 input2: S::TokenStream,
423 force_show_panics: bool,
424 ) -> Result<S::TokenStream, PanicMessage>
425 where
426 S: Server,
427 S::TokenStream: Default,
428 {
429 let client::Client { get_handle_counters, run, _marker } = *self;
430 run_server(
431 strategy,
432 get_handle_counters(),
433 server,
434 (
435 <MarkedTypes<S> as Types>::TokenStream::mark(input),
436 <MarkedTypes<S> as Types>::TokenStream::mark(input2),
437 ),
438 run,
439 force_show_panics,
440 )
441 .map(|s| <Option<<MarkedTypes<S> as Types>::TokenStream>>::unmark(s).unwrap_or_default())
442 }
443}
444