1use crate::{Error, Request, Response};
2
3/// Chained processing of request (and response).
4///
5/// # Middleware as `fn`
6///
7/// The middleware trait is implemented for all functions that have the signature
8///
9/// `Fn(Request, MiddlewareNext) -> Result<Response, Error>`
10///
11/// That means the easiest way to implement middleware is by providing a `fn`, like so
12///
13/// ```no_run
14/// # use ureq::{Request, Response, MiddlewareNext, Error};
15/// fn my_middleware(req: Request, next: MiddlewareNext) -> Result<Response, Error> {
16/// // do middleware things
17///
18/// // continue the middleware chain
19/// next.handle(req)
20/// }
21/// ```
22///
23/// # Adding headers
24///
25/// A common use case is to add headers to the outgoing request. Here an example of how.
26///
27/// ```no_run
28/// # #[cfg(feature = "json")]
29/// # fn main() -> Result<(), ureq::Error> {
30/// # use ureq::{Request, Response, MiddlewareNext, Error};
31/// # ureq::is_test(true);
32/// fn my_middleware(req: Request, next: MiddlewareNext) -> Result<Response, Error> {
33/// // set my bespoke header and continue the chain
34/// next.handle(req.set("X-My-Header", "value_42"))
35/// }
36///
37/// let agent = ureq::builder()
38/// .middleware(my_middleware)
39/// .build();
40///
41/// let result: serde_json::Value =
42/// agent.get("http://httpbin.org/headers").call()?.into_json()?;
43///
44/// assert_eq!(&result["headers"]["X-My-Header"], "value_42");
45///
46/// # Ok(()) }
47/// # #[cfg(not(feature = "json"))]
48/// # fn main() {}
49/// ```
50///
51/// # State
52///
53/// To maintain state between middleware invocations, we need to do something more elaborate than
54/// the simple `fn` and implement the `Middleware` trait directly.
55///
56/// ## Example with mutex lock
57///
58/// In the `examples` directory there is an additional example `count-bytes.rs` which uses
59/// a mutex lock like shown below.
60///
61/// ```no_run
62/// # use ureq::{Request, Response, Middleware, MiddlewareNext, Error};
63/// # use std::sync::{Arc, Mutex};
64/// struct MyState {
65/// // whatever is needed
66/// }
67///
68/// struct MyMiddleware(Arc<Mutex<MyState>>);
69///
70/// impl Middleware for MyMiddleware {
71/// fn handle(&self, request: Request, next: MiddlewareNext) -> Result<Response, Error> {
72/// // These extra brackets ensures we release the Mutex lock before continuing the
73/// // chain. There could also be scenarios where we want to maintain the lock through
74/// // the invocation, which would block other requests from proceeding concurrently
75/// // through the middleware.
76/// {
77/// let mut state = self.0.lock().unwrap();
78/// // do stuff with state
79/// }
80///
81/// // continue middleware chain
82/// next.handle(request)
83/// }
84/// }
85/// ```
86///
87/// ## Example with atomic
88///
89/// This example shows how we can increase a counter for each request going
90/// through the agent.
91///
92/// ```no_run
93/// # fn main() -> Result<(), ureq::Error> {
94/// # ureq::is_test(true);
95/// use ureq::{Request, Response, Middleware, MiddlewareNext, Error};
96/// use std::sync::atomic::{AtomicU64, Ordering};
97/// use std::sync::Arc;
98///
99/// // Middleware that stores a counter state. This example uses an AtomicU64
100/// // since the middleware is potentially shared by multiple threads running
101/// // requests at the same time.
102/// struct MyCounter(Arc<AtomicU64>);
103///
104/// impl Middleware for MyCounter {
105/// fn handle(&self, req: Request, next: MiddlewareNext) -> Result<Response, Error> {
106/// // increase the counter for each invocation
107/// self.0.fetch_add(1, Ordering::SeqCst);
108///
109/// // continue the middleware chain
110/// next.handle(req)
111/// }
112/// }
113///
114/// let shared_counter = Arc::new(AtomicU64::new(0));
115///
116/// let agent = ureq::builder()
117/// // Add our middleware
118/// .middleware(MyCounter(shared_counter.clone()))
119/// .build();
120///
121/// agent.get("http://httpbin.org/get").call()?;
122/// agent.get("http://httpbin.org/get").call()?;
123///
124/// // Check we did indeed increase the counter twice.
125/// assert_eq!(shared_counter.load(Ordering::SeqCst), 2);
126///
127/// # Ok(()) }
128/// ```
129pub trait Middleware: Send + Sync + 'static {
130 /// Handle of the middleware logic.
131 fn handle(&self, request: Request, next: MiddlewareNext) -> Result<Response, Error>;
132}
133
134/// Continuation of a [`Middleware`] chain.
135pub struct MiddlewareNext<'a> {
136 pub(crate) chain: &'a mut (dyn Iterator<Item = &'a dyn Middleware>),
137 // Since request_fn consumes the Payload<'a>, we must have an FnOnce.
138 //
139 // It's possible to get rid of this Box if we make MiddlewareNext generic
140 // over some type variable, i.e. MiddlewareNext<'a, R> where R: FnOnce...
141 // however that would "leak" to Middleware::handle introducing a complicated
142 // type signature that is totally irrelevant for someone implementing a middleware.
143 //
144 // So in the name of having a sane external API, we accept this Box.
145 pub(crate) request_fn: Box<dyn FnOnce(Request) -> Result<Response, Error> + 'a>,
146}
147
148impl<'a> MiddlewareNext<'a> {
149 /// Continue the middleware chain by providing (a possibly amended) [`Request`].
150 pub fn handle(self, request: Request) -> Result<Response, Error> {
151 if let Some(step: &dyn Middleware) = self.chain.next() {
152 step.handle(request, self)
153 } else {
154 (self.request_fn)(request)
155 }
156 }
157}
158
159impl<F> Middleware for F
160where
161 F: Fn(Request, MiddlewareNext) -> Result<Response, Error> + Send + Sync + 'static,
162{
163 fn handle(&self, request: Request, next: MiddlewareNext) -> Result<Response, Error> {
164 (self)(request, next)
165 }
166}
167