1 | /* |
2 | * |
3 | * Copyright 2018 gRPC authors. |
4 | * |
5 | * Licensed under the Apache License, Version 2.0 (the "License"); |
6 | * you may not use this file except in compliance with the License. |
7 | * You may obtain a copy of the License at |
8 | * |
9 | * http://www.apache.org/licenses/LICENSE-2.0 |
10 | * |
11 | * Unless required by applicable law or agreed to in writing, software |
12 | * distributed under the License is distributed on an "AS IS" BASIS, |
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
14 | * See the License for the specific language governing permissions and |
15 | * limitations under the License. |
16 | * |
17 | */ |
18 | |
19 | #ifndef GRPCPP_IMPL_CODEGEN_INTERCEPTOR_COMMON_H |
20 | #define GRPCPP_IMPL_CODEGEN_INTERCEPTOR_COMMON_H |
21 | |
22 | #include <array> |
23 | #include <functional> |
24 | |
25 | #include <grpcpp/impl/codegen/call.h> |
26 | #include <grpcpp/impl/codegen/call_op_set_interface.h> |
27 | #include <grpcpp/impl/codegen/client_interceptor.h> |
28 | #include <grpcpp/impl/codegen/intercepted_channel.h> |
29 | #include <grpcpp/impl/codegen/server_interceptor.h> |
30 | |
31 | #include <grpc/impl/codegen/grpc_types.h> |
32 | |
33 | namespace grpc { |
34 | namespace internal { |
35 | |
36 | class InterceptorBatchMethodsImpl |
37 | : public experimental::InterceptorBatchMethods { |
38 | public: |
39 | InterceptorBatchMethodsImpl() { |
40 | for (auto i = static_cast<experimental::InterceptionHookPoints>(0); |
41 | i < experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS; |
42 | i = static_cast<experimental::InterceptionHookPoints>( |
43 | static_cast<size_t>(i) + 1)) { |
44 | hooks_[static_cast<size_t>(i)] = false; |
45 | } |
46 | } |
47 | |
48 | ~InterceptorBatchMethodsImpl() {} |
49 | |
50 | bool QueryInterceptionHookPoint( |
51 | experimental::InterceptionHookPoints type) override { |
52 | return hooks_[static_cast<size_t>(type)]; |
53 | } |
54 | |
55 | void Proceed() override { |
56 | if (call_->client_rpc_info() != nullptr) { |
57 | return ProceedClient(); |
58 | } |
59 | GPR_CODEGEN_ASSERT(call_->server_rpc_info() != nullptr); |
60 | ProceedServer(); |
61 | } |
62 | |
63 | void Hijack() override { |
64 | // Only the client can hijack when sending down initial metadata |
65 | GPR_CODEGEN_ASSERT(!reverse_ && ops_ != nullptr && |
66 | call_->client_rpc_info() != nullptr); |
67 | // It is illegal to call Hijack twice |
68 | GPR_CODEGEN_ASSERT(!ran_hijacking_interceptor_); |
69 | auto* rpc_info = call_->client_rpc_info(); |
70 | rpc_info->hijacked_ = true; |
71 | rpc_info->hijacked_interceptor_ = current_interceptor_index_; |
72 | ClearHookPoints(); |
73 | ops_->SetHijackingState(); |
74 | ran_hijacking_interceptor_ = true; |
75 | rpc_info->RunInterceptor(interceptor_methods: this, pos: current_interceptor_index_); |
76 | } |
77 | |
78 | void AddInterceptionHookPoint(experimental::InterceptionHookPoints type) { |
79 | hooks_[static_cast<size_t>(type)] = true; |
80 | } |
81 | |
82 | ByteBuffer* GetSerializedSendMessage() override { |
83 | GPR_CODEGEN_ASSERT(orig_send_message_ != nullptr); |
84 | if (*orig_send_message_ != nullptr) { |
85 | GPR_CODEGEN_ASSERT(serializer_(*orig_send_message_).ok()); |
86 | *orig_send_message_ = nullptr; |
87 | } |
88 | return send_message_; |
89 | } |
90 | |
91 | const void* GetSendMessage() override { |
92 | GPR_CODEGEN_ASSERT(orig_send_message_ != nullptr); |
93 | return *orig_send_message_; |
94 | } |
95 | |
96 | void ModifySendMessage(const void* message) override { |
97 | GPR_CODEGEN_ASSERT(orig_send_message_ != nullptr); |
98 | *orig_send_message_ = message; |
99 | } |
100 | |
101 | bool GetSendMessageStatus() override { return !*fail_send_message_; } |
102 | |
103 | std::multimap<grpc::string, grpc::string>* GetSendInitialMetadata() override { |
104 | return send_initial_metadata_; |
105 | } |
106 | |
107 | Status GetSendStatus() override { |
108 | return Status(static_cast<StatusCode>(*code_), *error_message_, |
109 | *error_details_); |
110 | } |
111 | |
112 | void ModifySendStatus(const Status& status) override { |
113 | *code_ = static_cast<grpc_status_code>(status.error_code()); |
114 | *error_details_ = status.error_details(); |
115 | *error_message_ = status.error_message(); |
116 | } |
117 | |
118 | std::multimap<grpc::string, grpc::string>* GetSendTrailingMetadata() |
119 | override { |
120 | return send_trailing_metadata_; |
121 | } |
122 | |
123 | void* GetRecvMessage() override { return recv_message_; } |
124 | |
125 | std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvInitialMetadata() |
126 | override { |
127 | return recv_initial_metadata_->map(); |
128 | } |
129 | |
130 | Status* GetRecvStatus() override { return recv_status_; } |
131 | |
132 | void FailHijackedSendMessage() override { |
133 | GPR_CODEGEN_ASSERT(hooks_[static_cast<size_t>( |
134 | experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)]); |
135 | *fail_send_message_ = true; |
136 | } |
137 | |
138 | std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvTrailingMetadata() |
139 | override { |
140 | return recv_trailing_metadata_->map(); |
141 | } |
142 | |
143 | void SetSendMessage(ByteBuffer* buf, const void** msg, |
144 | bool* fail_send_message, |
145 | std::function<Status(const void*)> serializer) { |
146 | send_message_ = buf; |
147 | orig_send_message_ = msg; |
148 | fail_send_message_ = fail_send_message; |
149 | serializer_ = serializer; |
150 | } |
151 | |
152 | void SetSendInitialMetadata( |
153 | std::multimap<grpc::string, grpc::string>* metadata) { |
154 | send_initial_metadata_ = metadata; |
155 | } |
156 | |
157 | void SetSendStatus(grpc_status_code* code, grpc::string* error_details, |
158 | grpc::string* error_message) { |
159 | code_ = code; |
160 | error_details_ = error_details; |
161 | error_message_ = error_message; |
162 | } |
163 | |
164 | void SetSendTrailingMetadata( |
165 | std::multimap<grpc::string, grpc::string>* metadata) { |
166 | send_trailing_metadata_ = metadata; |
167 | } |
168 | |
169 | void SetRecvMessage(void* message, bool* hijacked_recv_message_failed) { |
170 | recv_message_ = message; |
171 | hijacked_recv_message_failed_ = hijacked_recv_message_failed; |
172 | } |
173 | |
174 | void SetRecvInitialMetadata(MetadataMap* map) { |
175 | recv_initial_metadata_ = map; |
176 | } |
177 | |
178 | void SetRecvStatus(Status* status) { recv_status_ = status; } |
179 | |
180 | void SetRecvTrailingMetadata(MetadataMap* map) { |
181 | recv_trailing_metadata_ = map; |
182 | } |
183 | |
184 | std::unique_ptr<ChannelInterface> GetInterceptedChannel() override { |
185 | auto* info = call_->client_rpc_info(); |
186 | if (info == nullptr) { |
187 | return std::unique_ptr<ChannelInterface>(nullptr); |
188 | } |
189 | // The intercepted channel starts from the interceptor just after the |
190 | // current interceptor |
191 | return std::unique_ptr<ChannelInterface>(new InterceptedChannel( |
192 | info->channel(), current_interceptor_index_ + 1)); |
193 | } |
194 | |
195 | void FailHijackedRecvMessage() override { |
196 | GPR_CODEGEN_ASSERT(hooks_[static_cast<size_t>( |
197 | experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)]); |
198 | *hijacked_recv_message_failed_ = true; |
199 | } |
200 | |
201 | // Clears all state |
202 | void ClearState() { |
203 | reverse_ = false; |
204 | ran_hijacking_interceptor_ = false; |
205 | ClearHookPoints(); |
206 | } |
207 | |
208 | // Prepares for Post_recv operations |
209 | void SetReverse() { |
210 | reverse_ = true; |
211 | ran_hijacking_interceptor_ = false; |
212 | ClearHookPoints(); |
213 | } |
214 | |
215 | // This needs to be set before interceptors are run |
216 | void SetCall(Call* call) { call_ = call; } |
217 | |
218 | // This needs to be set before interceptors are run using RunInterceptors(). |
219 | // Alternatively, RunInterceptors(std::function<void(void)> f) can be used. |
220 | void SetCallOpSetInterface(CallOpSetInterface* ops) { ops_ = ops; } |
221 | |
222 | // SetCall should have been called before this. |
223 | // Returns true if the interceptors list is empty |
224 | bool InterceptorsListEmpty() { |
225 | auto* client_rpc_info = call_->client_rpc_info(); |
226 | if (client_rpc_info != nullptr) { |
227 | if (client_rpc_info->interceptors_.size() == 0) { |
228 | return true; |
229 | } else { |
230 | return false; |
231 | } |
232 | } |
233 | |
234 | auto* server_rpc_info = call_->server_rpc_info(); |
235 | if (server_rpc_info == nullptr || |
236 | server_rpc_info->interceptors_.size() == 0) { |
237 | return true; |
238 | } |
239 | return false; |
240 | } |
241 | |
242 | // This should be used only by subclasses of CallOpSetInterface. SetCall and |
243 | // SetCallOpSetInterface should have been called before this. After all the |
244 | // interceptors are done running, either ContinueFillOpsAfterInterception or |
245 | // ContinueFinalizeOpsAfterInterception will be called. Note that neither of |
246 | // them is invoked if there were no interceptors registered. |
247 | bool RunInterceptors() { |
248 | GPR_CODEGEN_ASSERT(ops_); |
249 | auto* client_rpc_info = call_->client_rpc_info(); |
250 | if (client_rpc_info != nullptr) { |
251 | if (client_rpc_info->interceptors_.size() == 0) { |
252 | return true; |
253 | } else { |
254 | RunClientInterceptors(); |
255 | return false; |
256 | } |
257 | } |
258 | |
259 | auto* server_rpc_info = call_->server_rpc_info(); |
260 | if (server_rpc_info == nullptr || |
261 | server_rpc_info->interceptors_.size() == 0) { |
262 | return true; |
263 | } |
264 | RunServerInterceptors(); |
265 | return false; |
266 | } |
267 | |
268 | // Returns true if no interceptors are run. Returns false otherwise if there |
269 | // are interceptors registered. After the interceptors are done running \a f |
270 | // will be invoked. This is to be used only by BaseAsyncRequest and |
271 | // SyncRequest. |
272 | bool RunInterceptors(std::function<void(void)> f) { |
273 | // This is used only by the server for initial call request |
274 | GPR_CODEGEN_ASSERT(reverse_ == true); |
275 | GPR_CODEGEN_ASSERT(call_->client_rpc_info() == nullptr); |
276 | auto* server_rpc_info = call_->server_rpc_info(); |
277 | if (server_rpc_info == nullptr || |
278 | server_rpc_info->interceptors_.size() == 0) { |
279 | return true; |
280 | } |
281 | callback_ = std::move(f); |
282 | RunServerInterceptors(); |
283 | return false; |
284 | } |
285 | |
286 | private: |
287 | void RunClientInterceptors() { |
288 | auto* rpc_info = call_->client_rpc_info(); |
289 | if (!reverse_) { |
290 | current_interceptor_index_ = 0; |
291 | } else { |
292 | if (rpc_info->hijacked_) { |
293 | current_interceptor_index_ = rpc_info->hijacked_interceptor_; |
294 | } else { |
295 | current_interceptor_index_ = rpc_info->interceptors_.size() - 1; |
296 | } |
297 | } |
298 | rpc_info->RunInterceptor(interceptor_methods: this, pos: current_interceptor_index_); |
299 | } |
300 | |
301 | void RunServerInterceptors() { |
302 | auto* rpc_info = call_->server_rpc_info(); |
303 | if (!reverse_) { |
304 | current_interceptor_index_ = 0; |
305 | } else { |
306 | current_interceptor_index_ = rpc_info->interceptors_.size() - 1; |
307 | } |
308 | rpc_info->RunInterceptor(interceptor_methods: this, pos: current_interceptor_index_); |
309 | } |
310 | |
311 | void ProceedClient() { |
312 | auto* rpc_info = call_->client_rpc_info(); |
313 | if (rpc_info->hijacked_ && !reverse_ && |
314 | current_interceptor_index_ == rpc_info->hijacked_interceptor_ && |
315 | !ran_hijacking_interceptor_) { |
316 | // We now need to provide hijacked recv ops to this interceptor |
317 | ClearHookPoints(); |
318 | ops_->SetHijackingState(); |
319 | ran_hijacking_interceptor_ = true; |
320 | rpc_info->RunInterceptor(interceptor_methods: this, pos: current_interceptor_index_); |
321 | return; |
322 | } |
323 | if (!reverse_) { |
324 | current_interceptor_index_++; |
325 | // We are going down the stack of interceptors |
326 | if (current_interceptor_index_ < rpc_info->interceptors_.size()) { |
327 | if (rpc_info->hijacked_ && |
328 | current_interceptor_index_ > rpc_info->hijacked_interceptor_) { |
329 | // This is a hijacked RPC and we are done with hijacking |
330 | ops_->ContinueFillOpsAfterInterception(); |
331 | } else { |
332 | rpc_info->RunInterceptor(interceptor_methods: this, pos: current_interceptor_index_); |
333 | } |
334 | } else { |
335 | // we are done running all the interceptors without any hijacking |
336 | ops_->ContinueFillOpsAfterInterception(); |
337 | } |
338 | } else { |
339 | // We are going up the stack of interceptors |
340 | if (current_interceptor_index_ > 0) { |
341 | // Continue running interceptors |
342 | current_interceptor_index_--; |
343 | rpc_info->RunInterceptor(interceptor_methods: this, pos: current_interceptor_index_); |
344 | } else { |
345 | // we are done running all the interceptors without any hijacking |
346 | ops_->ContinueFinalizeResultAfterInterception(); |
347 | } |
348 | } |
349 | } |
350 | |
351 | void ProceedServer() { |
352 | auto* rpc_info = call_->server_rpc_info(); |
353 | if (!reverse_) { |
354 | current_interceptor_index_++; |
355 | if (current_interceptor_index_ < rpc_info->interceptors_.size()) { |
356 | return rpc_info->RunInterceptor(interceptor_methods: this, pos: current_interceptor_index_); |
357 | } else if (ops_) { |
358 | return ops_->ContinueFillOpsAfterInterception(); |
359 | } |
360 | } else { |
361 | // We are going up the stack of interceptors |
362 | if (current_interceptor_index_ > 0) { |
363 | // Continue running interceptors |
364 | current_interceptor_index_--; |
365 | return rpc_info->RunInterceptor(interceptor_methods: this, pos: current_interceptor_index_); |
366 | } else if (ops_) { |
367 | return ops_->ContinueFinalizeResultAfterInterception(); |
368 | } |
369 | } |
370 | GPR_CODEGEN_ASSERT(callback_); |
371 | callback_(); |
372 | } |
373 | |
374 | void ClearHookPoints() { |
375 | for (auto i = static_cast<experimental::InterceptionHookPoints>(0); |
376 | i < experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS; |
377 | i = static_cast<experimental::InterceptionHookPoints>( |
378 | static_cast<size_t>(i) + 1)) { |
379 | hooks_[static_cast<size_t>(i)] = false; |
380 | } |
381 | } |
382 | |
383 | std::array<bool, |
384 | static_cast<size_t>( |
385 | experimental::InterceptionHookPoints::NUM_INTERCEPTION_HOOKS)> |
386 | hooks_; |
387 | |
388 | size_t current_interceptor_index_ = 0; // Current iterator |
389 | bool reverse_ = false; |
390 | bool ran_hijacking_interceptor_ = false; |
391 | Call* call_ = nullptr; // The Call object is present along with CallOpSet |
392 | // object/callback |
393 | CallOpSetInterface* ops_ = nullptr; |
394 | std::function<void(void)> callback_; |
395 | |
396 | ByteBuffer* send_message_ = nullptr; |
397 | bool* fail_send_message_ = nullptr; |
398 | const void** orig_send_message_ = nullptr; |
399 | std::function<Status(const void*)> serializer_; |
400 | |
401 | std::multimap<grpc::string, grpc::string>* send_initial_metadata_; |
402 | |
403 | grpc_status_code* code_ = nullptr; |
404 | grpc::string* error_details_ = nullptr; |
405 | grpc::string* error_message_ = nullptr; |
406 | |
407 | std::multimap<grpc::string, grpc::string>* send_trailing_metadata_ = nullptr; |
408 | |
409 | void* recv_message_ = nullptr; |
410 | bool* hijacked_recv_message_failed_ = nullptr; |
411 | |
412 | MetadataMap* recv_initial_metadata_ = nullptr; |
413 | |
414 | Status* recv_status_ = nullptr; |
415 | |
416 | MetadataMap* recv_trailing_metadata_ = nullptr; |
417 | }; |
418 | |
419 | // A special implementation of InterceptorBatchMethods to send a Cancel |
420 | // notification down the interceptor stack |
421 | class CancelInterceptorBatchMethods |
422 | : public experimental::InterceptorBatchMethods { |
423 | public: |
424 | bool QueryInterceptionHookPoint( |
425 | experimental::InterceptionHookPoints type) override { |
426 | if (type == experimental::InterceptionHookPoints::PRE_SEND_CANCEL) { |
427 | return true; |
428 | } else { |
429 | return false; |
430 | } |
431 | } |
432 | |
433 | void Proceed() override { |
434 | // This is a no-op. For actual continuation of the RPC simply needs to |
435 | // return from the Intercept method |
436 | } |
437 | |
438 | void Hijack() override { |
439 | // Only the client can hijack when sending down initial metadata |
440 | GPR_CODEGEN_ASSERT(false && |
441 | "It is illegal to call Hijack on a method which has a " |
442 | "Cancel notification" ); |
443 | } |
444 | |
445 | ByteBuffer* GetSerializedSendMessage() override { |
446 | GPR_CODEGEN_ASSERT(false && |
447 | "It is illegal to call GetSendMessage on a method which " |
448 | "has a Cancel notification" ); |
449 | return nullptr; |
450 | } |
451 | |
452 | bool GetSendMessageStatus() override { |
453 | GPR_CODEGEN_ASSERT( |
454 | false && |
455 | "It is illegal to call GetSendMessageStatus on a method which " |
456 | "has a Cancel notification" ); |
457 | return false; |
458 | } |
459 | |
460 | const void* GetSendMessage() override { |
461 | GPR_CODEGEN_ASSERT( |
462 | false && |
463 | "It is illegal to call GetOriginalSendMessage on a method which " |
464 | "has a Cancel notification" ); |
465 | return nullptr; |
466 | } |
467 | |
468 | void ModifySendMessage(const void* /*message*/) override { |
469 | GPR_CODEGEN_ASSERT( |
470 | false && |
471 | "It is illegal to call ModifySendMessage on a method which " |
472 | "has a Cancel notification" ); |
473 | } |
474 | |
475 | std::multimap<grpc::string, grpc::string>* GetSendInitialMetadata() override { |
476 | GPR_CODEGEN_ASSERT(false && |
477 | "It is illegal to call GetSendInitialMetadata on a " |
478 | "method which has a Cancel notification" ); |
479 | return nullptr; |
480 | } |
481 | |
482 | Status GetSendStatus() override { |
483 | GPR_CODEGEN_ASSERT(false && |
484 | "It is illegal to call GetSendStatus on a method which " |
485 | "has a Cancel notification" ); |
486 | return Status(); |
487 | } |
488 | |
489 | void ModifySendStatus(const Status& /*status*/) override { |
490 | GPR_CODEGEN_ASSERT(false && |
491 | "It is illegal to call ModifySendStatus on a method " |
492 | "which has a Cancel notification" ); |
493 | return; |
494 | } |
495 | |
496 | std::multimap<grpc::string, grpc::string>* GetSendTrailingMetadata() |
497 | override { |
498 | GPR_CODEGEN_ASSERT(false && |
499 | "It is illegal to call GetSendTrailingMetadata on a " |
500 | "method which has a Cancel notification" ); |
501 | return nullptr; |
502 | } |
503 | |
504 | void* GetRecvMessage() override { |
505 | GPR_CODEGEN_ASSERT(false && |
506 | "It is illegal to call GetRecvMessage on a method which " |
507 | "has a Cancel notification" ); |
508 | return nullptr; |
509 | } |
510 | |
511 | std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvInitialMetadata() |
512 | override { |
513 | GPR_CODEGEN_ASSERT(false && |
514 | "It is illegal to call GetRecvInitialMetadata on a " |
515 | "method which has a Cancel notification" ); |
516 | return nullptr; |
517 | } |
518 | |
519 | Status* GetRecvStatus() override { |
520 | GPR_CODEGEN_ASSERT(false && |
521 | "It is illegal to call GetRecvStatus on a method which " |
522 | "has a Cancel notification" ); |
523 | return nullptr; |
524 | } |
525 | |
526 | std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvTrailingMetadata() |
527 | override { |
528 | GPR_CODEGEN_ASSERT(false && |
529 | "It is illegal to call GetRecvTrailingMetadata on a " |
530 | "method which has a Cancel notification" ); |
531 | return nullptr; |
532 | } |
533 | |
534 | std::unique_ptr<ChannelInterface> GetInterceptedChannel() override { |
535 | GPR_CODEGEN_ASSERT(false && |
536 | "It is illegal to call GetInterceptedChannel on a " |
537 | "method which has a Cancel notification" ); |
538 | return std::unique_ptr<ChannelInterface>(nullptr); |
539 | } |
540 | |
541 | void FailHijackedRecvMessage() override { |
542 | GPR_CODEGEN_ASSERT(false && |
543 | "It is illegal to call FailHijackedRecvMessage on a " |
544 | "method which has a Cancel notification" ); |
545 | } |
546 | |
547 | void FailHijackedSendMessage() override { |
548 | GPR_CODEGEN_ASSERT(false && |
549 | "It is illegal to call FailHijackedSendMessage on a " |
550 | "method which has a Cancel notification" ); |
551 | } |
552 | }; |
553 | } // namespace internal |
554 | } // namespace grpc |
555 | |
556 | #endif // GRPCPP_IMPL_CODEGEN_INTERCEPTOR_COMMON_H |
557 | |