Making a secure Axum route
How to build a secure Axum route using JWT with tests
In this blog series, I will demonstrate how to build an Axum route that is secured by a JWT token. I will also show how to write a black box test for the route to verify that it works.
Prerequisites
You need to have Rust installed. If you don't, please install it using rustup.
Versions
In case you are a time traveler living in the future: this post was written using Rust 1.74.0 and Axum 0.7.1. If things have moved a lot since the time of writing, you might need to make some manual tweaks.
Create base project
Create a new cargo project
cargo init secure-axum && cd secure-axum
Next, let's add axum and tokio to the project
cargo add axum
cargo add -F full tokio
Now let's add a route to src/main.rs
use axum::{routing::get, Router};
#[tokio::main]
async fn main() {
let router = Router::new().route("/", get(read));
// run it
let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
.await
.unwrap();
println!("listening on {}", listener.local_addr().unwrap());
axum::serve(listener, router).await.unwrap();
}
async fn read() -> &'static str {
"Hello world"
}
Our web server listens on 127.0.0.1:3000 and even tells us so when we run it. It simply returns the string "Hello world" when we send it a request.
So let's do that!
cargo run
Open a new terminal and send a request to the service.
>curl 127.0.0.1:3000
Hello world
Test harness
Great, it works! But let's improve our development experience. We don't want to fiddle around with manually testing that what we do actually works, as that will become really tedious once we introduce the authorization layer to the route.
So let's set up a test harness that will start our service and do the request for us.
💡 See my Rust tip on testing using error handling if you are wondering why we add anyhow here.
cargo add --dev anyhow
Now let's add the tests as a test module at the bottom of src/main.rs.
#[cfg(test)]
mod test {
#[tokio::test]
async fn hello_world_test() -> anyhow::Result<()> {
unimplemented!("We need to add the logic to the test here");
Ok(())
}
}
Let's see if we can try to implement our test now. We want to verify that the body returned from the request is "Hello world". Start by adding reqwest as we will use that when making HTTP calls.
cargo add --dev reqwest
#[tokio::test]
async fn hello_world_test() -> anyhow::Result<()> {
let client = reqwest::Client::new();
let url = ""; // TODO What is our url?
let response = client.get(url).send().await?.text().await?;
assert_eq!(response, "Hello world");
Ok(())
}
}
But we don't have our service running, so let's start it. We do this by creating a router with the same setup as we have in our service and then starting it using Axum.
mod test {
use super::*;
use std::net::SocketAddr;
use tokio::net::TcpListener;
#[tokio::test]
async fn hello_world_test() -> anyhow::Result<()> {
let router = Router::new().route("/", get(read));
let listener = TcpListener::bind("0.0.0.0:0".parse::<SocketAddr>()?)
.await?;
let addr = listener.local_addr()?;
tokio::spawn(async move {
// `?` cannot be used here because the spawned block does not return a Result.
axum::serve(listener, router).await.unwrap();
});
let client = reqwest::Client::new();
let url = format!("http://{addr}");
let response = client.get(url).send().await?.text().await?;
assert_eq!(response, "Hello world");
Ok(())
}
If we run our test now, it actually passes. It is important to note that we are actually testing our public API using a proper HTTP request with no knowledge of what happens inside the router. We are testing that a certain input gives us a certain output, a proper black box test.
But some things are a bit messy with that test. We are doing a lot of setup that is not related to the test itself. Let's take care of that. One thing about Rust's test system is that it does not have the setup/teardown functionality you might be used to from other languages. We can still get around this by leveraging Rust's own structures.
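To make the idea concrete, here is a minimal sketch of setup and teardown built from plain Rust: the constructor does the setup, and a Drop implementation does the teardown. The TempFileFixture type and the file name are made up for this illustration. The TestApp we build next only needs the setup half.

```rust
use std::fs;
use std::path::PathBuf;

/// Hypothetical fixture: setup happens in `new`, teardown in `Drop`.
struct TempFileFixture {
    path: PathBuf,
}

impl TempFileFixture {
    fn new(name: &str) -> std::io::Result<Self> {
        // Setup: create a scratch file in the system temp directory.
        let path = std::env::temp_dir().join(format!("{}-{}", std::process::id(), name));
        fs::write(&path, "fixture data")?;
        Ok(Self { path })
    }
}

impl Drop for TempFileFixture {
    fn drop(&mut self) {
        // Teardown: runs when the fixture goes out of scope,
        // also when a test panics and unwinds.
        let _ = fs::remove_file(&self.path);
    }
}

fn main() -> std::io::Result<()> {
    let path = {
        let fixture = TempFileFixture::new("secure-axum-demo")?;
        assert!(fixture.path.exists());
        fixture.path.clone()
    }; // `fixture` is dropped here, which removes the file.
    assert!(!path.exists());
    Ok(())
}
```

The nice part of this pattern is that the cleanup cannot be forgotten: it is tied to the lifetime of the value rather than to a line at the end of each test.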
TestApp
So let's create a TestApp inside the test module (just above hello_world_test). It will remember the address of the test server and also keep a copy of the reqwest client for reuse.
#[cfg(test)]
mod test {
use super::*;
use std::net::SocketAddr;
use tokio::net::TcpListener;
#[derive(Clone)]
pub struct TestApp {
pub client: reqwest::Client,
addr: String,
}
#[tokio::test]
async fn hello_world_test() -> anyhow::Result<()> {
[...]
Next, let's implement the new method on it and move the creation of the test server and the reqwest client inside it. We will also add a url helper method that constructs the URL for a given path, to make life easier.
#[cfg(test)]
mod test {
use super::*;
use std::net::SocketAddr;
use tokio::net::TcpListener;
use reqwest::Url;
#[derive(Clone)]
pub struct TestApp {
pub client: reqwest::Client,
addr: String,
}
impl TestApp {
pub async fn new() -> anyhow::Result<Self> {
let router = Router::new().route("/", get(read));
let listener = TcpListener::bind("0.0.0.0:0".parse::<SocketAddr>()?)
.await?;
let addr = listener.local_addr()?;
tokio::spawn(async move {
// `?` cannot be used here because the spawned block does not return a Result.
axum::serve(listener, router).await.unwrap();
});
let client = reqwest::Client::new();
Ok(Self {
client,
addr: addr.to_string(),
})
}
pub fn url<S: AsRef<str>>(&self, path: S) -> anyhow::Result<String> {
let base = Url::parse(format!("http://{}", self.addr).as_ref())?;
let url = base.join(path.as_ref())?;
Ok(url.as_str().to_string())
}
}
#[tokio::test]
async fn hello_world_test() -> anyhow::Result<()> {
[...]
Lastly, we update the test to use the TestApp.
#[tokio::test]
async fn hello_world_test() -> anyhow::Result<()> {
// Both `new` and `url` return a Result, so propagate errors with `?`.
let test_app = TestApp::new().await?;
let response = test_app
.client
.get(test_app.url("/")?)
.send()
.await?
.text()
.await?;
assert_eq!(response, "Hello world");
Ok(())
}
Now I would say that is a pretty neat test. If you don't fancy the send().await? chain we have to do, you can easily hide that in a function called get_root or similar.
Securing the route
Now with all of that foundational work done, we can get on with the real meat of the post: making our endpoint secure using Axum's handy middleware system.
So our goal is to have a way to take a JWT token as input, extract and validate it and only then be allowed to see our super secret greeting!
First, let's update the handler to this.
// Update the use statement to include the Extension
use axum::{routing::get, Router, Extension};
// Update the handler to look like this.
async fn read(Extension(claims): Extension<Authorized>) -> String {
let email = claims.0.email;
format!("Hello world {email}")
}
Well, that was not too bad, was it? We can see that we have an Extension that holds an Authorized type (which we will be implementing soon), and out of that we can get to the claims (the data contained inside a JWT is referred to as claims), which contain an email. Basically, an Extension is a way to share data with handlers in Axum, for example data that a middleware has attached to the request.
Let's create the Authorized structure above the main function.
#[derive(Debug, Clone)]
pub struct Authorized(pub Claims);
Basically it only holds the Claims structure, and we derive the Debug and Clone traits on it. Claims next!
cargo add -F serde chrono
cargo add -F derive serde
// At the top of the file add these lines
use serde::{Serialize, Deserialize};
use chrono::{serde::ts_seconds, DateTime, Utc};
// Add next to Authorized
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Claims {
pub email: String,
#[serde(with = "ts_seconds")]
pub exp: DateTime<Utc>,
}
A real claims structure holds more information, but for this exercise we only care about the email. The exp (expiry) field is required for the token to validate.
So now we have resolved all the errors that cargo check reports, so let's see if our test passes.
❯ cargo test
Compiling secure-axum v0.1.0 (/home/sedrik/devarea/secure-axum)
Finished test [unoptimized + debuginfo] target(s) in 1.89s
Running unittests src/main.rs (target/debug/deps/secure_axum-7f6e6c701d568cfa)
running 1 test
test test::hello_world_test ... FAILED
failures:
---- test::hello_world_test stdout ----
thread 'test::hello_world_test' panicked at src/main.rs:81:9:
assertion `left == right` failed
left: "Missing request extension: Extension of type `secure_axum::Authorized`
was not found. Perhaps you forgot to add it? See `axum::Extension`."
right: "Hello world"
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
failures:
test::hello_world_test
test result: FAILED. 0 passed; 1 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.02s
So what this is saying is that it can not read an Extension called Authorized. The bad part is that it is not caught at compile time (it would not be possible/feasible I believe), the good part is that we caught the issue in testing and that it was not one of our users that found it in production!
The way Extensions work in Axum is that you can add data to the system based on its type information, and that is then retrieved when we use an Extractor (Extension is an Extractor).
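To build some intuition for "data keyed by its type", here is a tiny sketch using only the standard library. It is not Axum's implementation; TypeMap is a name I made up, and Authorized is simplified to hold just a String so the example stands on its own.

```rust
use std::any::{Any, TypeId};
use std::collections::HashMap;

/// Hypothetical stand-in for a request's extension storage.
#[derive(Default)]
struct TypeMap {
    values: HashMap<TypeId, Box<dyn Any>>,
}

impl TypeMap {
    fn insert<T: Any>(&mut self, value: T) {
        // The type itself is the key, so each type has at most one value.
        self.values.insert(TypeId::of::<T>(), Box::new(value));
    }

    fn get<T: Any>(&self) -> Option<&T> {
        self.values
            .get(&TypeId::of::<T>())
            .and_then(|boxed| boxed.downcast_ref::<T>())
    }
}

/// Simplified version of the post's `Authorized` type.
#[derive(Debug, PartialEq)]
struct Authorized(String);

fn main() {
    let mut extensions = TypeMap::default();
    // Nothing has been inserted yet: this is the situation our failing test hit.
    assert!(extensions.get::<Authorized>().is_none());

    extensions.insert(Authorized("test@example.com".to_string()));
    assert_eq!(
        extensions.get::<Authorized>(),
        Some(&Authorized("test@example.com".to_string()))
    );
}
```

Because the lookup happens at runtime, the compiler cannot tell whether somebody inserted the value before the handler asks for it, which is why the missing extension only shows up when the test runs.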
But in order for us to create and add the Authorized type into the system, we need an AuthorizationMiddleware that can extract the JWT from the authorization header.
// Update our import statement for axum at the top of the file to this.
use axum::{
async_trait,
extract::FromRequestParts,
http::{request::Parts, StatusCode},
routing::get,
Extension, Router,
};
pub struct AuthorizationMiddleware;
#[async_trait]
impl<S> FromRequestParts<S> for AuthorizationMiddleware
where
S: Send + Sync,
{
type Rejection = StatusCode;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
unimplemented!("Add logic")
}
}
It might look scary, but this is basically the scaffolding that you can find in the documentation for extractors.
Next, let's start to add our logic. First, we need to extract the JWT token from the Authorization HTTP header. JWT tokens should be sent as bearer tokens, so the format of the header is authorization: bearer <token>, which is a bit messy to parse out if you, like me, subscribe to the no-regex-allowed magazine.
Luckily for us, there is something for us in the axum-extra library, namely TypedHeader. So let's add that library and enable the feature.
cargo add -F typed-header axum-extra
Now let's add the functionality for extracting the bearer token.
// Add to the top of the file
use axum_extra::{TypedHeader, headers::{Authorization, authorization::Bearer}};
pub struct AuthorizationMiddleware;
#[async_trait]
impl<S> FromRequestParts<S> for AuthorizationMiddleware
where
S: Send + Sync,
{
type Rejection = StatusCode;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let Ok(TypedHeader(Authorization(bearer))) =
TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, state).await
else {
eprintln!("Could not get Authorization header from the request");
return Err(StatusCode::UNAUTHORIZED);
};
unimplemented!("Add logic")
}
}
We use a fairly new feature here called let else (stabilized in Rust 1.65), which allows us to return an error if we are not able to extract the header properly.
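If let else is new to you, here is a small standalone sketch that uses it to do by hand roughly what TypedHeader does for us. The parse_bearer helper is hypothetical, and its exact rules (for example, accepting the scheme in any letter case) are my assumptions, not a description of what axum-extra does internally.

```rust
/// Hypothetical helper: pulls the token out of an `Authorization` header value.
fn parse_bearer(header_value: &str) -> Result<&str, &'static str> {
    // `let else`: bind on success, otherwise take the diverging `else` branch.
    let Some((scheme, token)) = header_value.split_once(' ') else {
        return Err("expected `<scheme> <token>`");
    };
    if !scheme.eq_ignore_ascii_case("bearer") {
        return Err("not a bearer token");
    }
    let token = token.trim();
    if token.is_empty() {
        return Err("empty token");
    }
    Ok(token)
}

fn main() {
    assert_eq!(parse_bearer("Bearer abc.def.ghi"), Ok("abc.def.ghi"));
    assert_eq!(parse_bearer("Basic dXNlcjpwYXNz"), Err("not a bearer token"));
    assert_eq!(parse_bearer("Bearer"), Err("expected `<scheme> <token>`"));
}
```

The happy path stays unindented and every failure returns early, which is the same shape our from_request_parts implementation has.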
Then we need to check the integrity of the JWT and add it to the axum system.
// Add to the top of the file
use axum::http::Method;
#[async_trait]
impl<S> FromRequestParts<S> for AuthorizationMiddleware
where
S: Send + Sync,
{
type Rejection = StatusCode;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
if parts.method == Method::OPTIONS {
// For options requests browsers will not send the authorization header.
return Ok(Self);
}
let Ok(TypedHeader(Authorization(bearer))) =
TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, state).await
else {
eprintln!("Could not get Authorization header from the request");
return Err(StatusCode::UNAUTHORIZED);
};
match check_auth(bearer) {
Ok(auth) => {
// Set `auth` as a request extension so it can be accessed by other
// services down the stack.
parts.extensions.insert(auth);
Ok(Self)
}
Err(error) => {
eprintln!("{error:?}");
Err(StatusCode::UNAUTHORIZED)
}
}
}
}
fn check_auth(bearer: Bearer) -> Result<Authorized, String> {
unimplemented!("TODO");
}
It looks daunting, but now that we have broken it down most of it should be understandable. We also support the browser's preflight requests by allowing OPTIONS requests to pass unauthenticated.
Validating the JWT
Next up is the check_auth method. We get a Bearer as input and expect it to authorize us or return an error, so let's start with that. Remember that we defined Claims above and that Authorized holds the Claims structure if we are successful.
fn check_auth(bearer: Bearer) -> Result<Authorized, String> {
Ok(Authorized(claims))
}
So how do we get the claims out of the token? Basically, we need to decode the token inside the bearer using the corresponding JWK (JSON Web Key). The JWT header can carry a reference (the key ID, kid) to the key needed to verify its signature, so if we only had the keys stored we could get the correct one and decode the token.
fn check_auth(bearer: Bearer, jwks: &Jwks) -> Result<Authorized, String> {
if let Some(jwk) = find_jwk(bearer.token(), &jwks.0) {
let claims = jwt_decode(bearer.token(), jwk)?;
Ok(Authorized(claims))
} else {
Err("JWK not found".to_string())
}
}
So how do we get the JWK storage? Why, let's store it in the Axum server as an extension. So add the following snippet to the from_request_parts implementation to extract it (don't worry, we will set this all up soon). And don't forget to pass it to the check_auth function.
let Some(jwks) = parts.extensions.get::<Jwks>()
else {
eprintln!("Could not find the JWK layer, did you forget to add it?");
return Err(StatusCode::UNAUTHORIZED);
};
match check_auth(bearer, &jwks) {
...
Now we have a few more pieces to implement. We need to
- Define the Jwks storage
- Implement find_jwk
- Implement jwt_decode
So let's start with the Jwks storage. For simplicity, we only store them in a list for now. We also define the Jwk structure itself, which is defined in RFC7517.
cargo add jsonwebtoken
// Add to the top of the file
use jsonwebtoken::Algorithm;
const ALGORITHM: Algorithm = jsonwebtoken::Algorithm::RS256;
#[derive(Clone, Debug)]
pub struct Jwks(pub Vec<Jwk>);
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct Jwk {
pub alg: Algorithm, // The signing algorithm the key is intended for.
pub e: String, // The exponent value for the RSA public key.
pub kid: String, // Key ID
pub kty: String, // Key type
pub n: String, // The modulus value for the RSA public key.
pub r#use: String, // The intended use for the public key.
}
So let's tackle find_jwk now. We are given the token and the list of keys that we know of, and we need to get the key ID from the token, look through our list and return the key that matches. If we don't find a key, we return None. We can use the jsonwebtoken library to decode the header of the token without the key, since the header is only encoded, not encrypted; the key is needed only to verify the signature.
pub fn find_jwk<'a>(token: &str, jwks: &'a [Jwk]) -> Option<&'a Jwk> {
// The token comes from the client, so a malformed one must give None, not a panic.
let header = jsonwebtoken::decode_header(token).ok()?;
// A token without a key ID cannot match any of our keys.
let kid = header.kid?;
jwks.iter().find(|jwk| jwk.kid == kid)
}
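If you are curious why the header can be read without any key, here is a sketch that peeks at it using only the standard library. A JWT is three base64url segments separated by dots, and the first one is the header as plain JSON. The base64url_decode and peek_header functions are hypothetical helpers written for this illustration, and the sample token is made up: its header segment encodes {"alg":"RS256","kid":"donaldduck"}, while the payload and signature are placeholders. In the real code we let jsonwebtoken::decode_header do this for us.

```rust
/// Decodes unpadded base64url (RFC 4648, section 5); returns None on invalid input.
fn base64url_decode(input: &str) -> Option<Vec<u8>> {
    let mut out = Vec::with_capacity(input.len() * 3 / 4);
    let mut buffer: u32 = 0;
    let mut bits: u32 = 0;
    for byte in input.bytes() {
        let value = match byte {
            b'A'..=b'Z' => byte - b'A',
            b'a'..=b'z' => byte - b'a' + 26,
            b'0'..=b'9' => byte - b'0' + 52,
            b'-' => 62,
            b'_' => 63,
            _ => return None,
        };
        // Collect 6 bits per character and emit a byte whenever 8 are available.
        buffer = (buffer << 6) | u32::from(value);
        bits += 6;
        if bits >= 8 {
            bits -= 8;
            out.push((buffer >> bits) as u8);
            buffer &= (1 << bits) - 1;
        }
    }
    Some(out)
}

/// Returns the JSON text of the token header without verifying anything.
fn peek_header(token: &str) -> Option<String> {
    let header_segment = token.split('.').next()?;
    let bytes = base64url_decode(header_segment)?;
    String::from_utf8(bytes).ok()
}

fn main() {
    // Header segment is real base64url; payload and signature are placeholders.
    let token = "eyJhbGciOiJSUzI1NiIsImtpZCI6ImRvbmFsZGR1Y2sifQ.payload.signature";
    let header = peek_header(token).expect("header should decode");
    assert_eq!(header, r#"{"alg":"RS256","kid":"donaldduck"}"#);
    println!("{header}");
}
```

Since anyone can read (and write) a header like this, nothing in it can be trusted until the signature has been verified with the matching key.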
For the jwt_decode function, we have the token and the correct key now.
// Update our jsonwebtoken use statement to this
use jsonwebtoken::{Algorithm, Validation, DecodingKey};
pub fn jwt_decode(token: &str, jwk: &Jwk) -> Result<Claims, String> {
let validation = Validation::new(ALGORITHM);
// Return errors instead of panicking: a bad or expired token must not crash the handler.
let decode_key = DecodingKey::from_rsa_components(&jwk.n, &jwk.e)
.map_err(|error| format!("Could not build the DecodingKey: {error}"))?;
let decoded = jsonwebtoken::decode::<Claims>(token, &decode_key, &validation)
.map_err(|error| format!("Could not decode the claims: {error}"))?;
Ok(decoded.claims)
}
So let's walk this one through. We create a Validation structure with the defined algorithm. We then construct a DecodingKey from values found inside the Jwk (I won't lie, the exact math behind all of this is beyond my understanding). Using this DecodingKey we are able to decode the token and finally get to the claims.
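As far as I know, the default Validation does more than check the signature: it also rejects tokens whose exp claim has passed, with a bit of leeway for clock differences. The sketch below shows the shape of such a check using only the standard library. The is_expired helper and the 60 second leeway are assumptions made for this illustration, not the library's code.

```rust
use std::time::{SystemTime, UNIX_EPOCH};

/// Hypothetical helper mirroring an `exp` check; all values are seconds since the Unix epoch.
fn is_expired(exp: u64, now: u64, leeway: u64) -> bool {
    // The token stays acceptable until `exp + leeway` to tolerate small clock differences.
    now > exp.saturating_add(leeway)
}

fn main() {
    let now = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("system clock is before 1970")
        .as_secs();
    let leeway = 60; // assumed leeway in seconds for this sketch

    // A token that expired an hour ago is rejected.
    assert!(is_expired(now - 3600, now, leeway));
    // A token that expired 10 seconds ago is still inside the leeway window.
    assert!(!is_expired(now - 10, now, leeway));
    // A token valid for another 52 weeks is accepted.
    assert!(!is_expired(now + 52 * 7 * 24 * 3600, now, leeway));
}
```

This is also why our Claims structure needs the exp field even though the handler never looks at it.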
Phew! This has been quite a journey, but now cargo check is green, so let's see what our tests say.
❯ cargo test
...
failures:
---- test::hello_world_test stdout ----
thread 'test::hello_world_test' panicked at src/main.rs:188:9:
assertion `left == right` failed
left: "Missing request extension: Extension of type `secure_axum::Authorized`
was not found. Perhaps you forgot to add it? See `axum::Extension`."
right: "Hello world"
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
Nope, Axum does not know where to find the Authorized type. Remember that we added an error string for just this case. Axum is not able to detect that we did not register the AuthorizationMiddleware at compile time, so we got an error at runtime. Luckily, we are testing our code with automated tests to be able to catch things like this.
Register our middleware
Our test currently fails because we have not added our middleware to the Axum system.
So add a new function, setup_routes_and_middlewares, that does the route setup and registers our middleware.
// Add to the top of the file
use axum::middleware::from_extractor;
fn setup_routes_and_middlewares(router: Router) -> Router {
router
.route("/", get(read))
.route_layer(from_extractor::<AuthorizationMiddleware>())
}
Then use it where we create our router, in both the main function and the TestApp, as shown below.
let router = Router::new();
let router = setup_routes_and_middlewares(router);
Now it is time to modify our test to check for an UNAUTHORIZED status code when we try to call it.
#[tokio::test]
async fn hello_world_unauthorized() -> anyhow::Result<()> {
// Both `new` and `url` return a Result, so propagate errors with `?`.
let test_app = TestApp::new().await?;
let response = test_app
.client
.get(test_app.url("/")?)
.send()
.await?;
let status = response.status();
assert_eq!(status, reqwest::StatusCode::UNAUTHORIZED);
Ok(())
}
Our tests are now passing, so let's add one where we expect the user to be authorized and get a response back.
So for a user to be authorized we need to send a valid token in the authorization header. Let's start there.
#[tokio::test]
async fn hello_world_authorized() -> anyhow::Result<()> {
// Both `new` and `url` return a Result, so propagate errors with `?`.
let test_app = TestApp::new().await?;
let token = "...";
let response = test_app
.client
.get(test_app.url("/")?)
.header("authorization", format!("Bearer {token}").as_str())
.send()
.await?;
let status = response.status();
assert_eq!(status, reqwest::StatusCode::OK);
let response_text = response
.text()
.await?;
assert_eq!(response_text, "Hello world");
Ok(())
}
So that is the skeleton of our new test, but how do we generate the token? Again, the jsonwebtoken library will do this for us.
So add a function called jwt_encode to our test module (I put it just below the TestApp) with the following content, and we will go through what it does afterward.
// At the top of the test module
use chrono::Duration;
const KID: &str = "donaldduck";
pub fn jwt_encode(
email: &str,
private_key: &[u8],
) -> anyhow::Result<String> {
let exp = Utc::now() + Duration::weeks(52);
let claims = Claims {
email: email.to_string(),
exp,
};
let mut header = jsonwebtoken::Header::new(ALGORITHM);
header.kid = Some(KID.to_string());
Ok(jsonwebtoken::encode(
&header,
&claims,
&EncodingKey::from_rsa_pem(private_key)?,
)?)
}
So given an email and a private key, we create a new Claims structure to be encoded. We have also defined a key ID (KID) that we will use when registering the public part of the key with the web server. Then we set up an EncodingKey and encode the token.
So before continuing, let's generate a key pair to use in our testing setup. Enter the following commands in a terminal and simply reply yes to all questions. We especially don't want to set a passphrase, as we will only be using it in our testing.
ssh-keygen -t rsa -b 4096 -m PEM -f test.key
openssl rsa -in test.key -pubout -outform PEM -out test.key.pub
You should now have two new files, test.key (private part) and test.key.pub (public part). Copy their contents into our test module inside src/main.rs and place them at the bottom of the file so that it follows the pattern shown below.
const KEY_PUB: &str = r#"-----BEGIN PUBLIC KEY-----
...
-----END PUBLIC KEY-----"#;
const KEY_PRIV: &str = r#"
-----BEGIN RSA PRIVATE KEY-----
...
-----END RSA PRIVATE KEY-----"#;
I want to stress that this is not how you should do it in the production part of the code, this setup is purely for the tests. Proper key management is needed in production.
Now that we have our RSA keys, we can finally generate the token for our hello_world_authorized test. Update the token generation to the following line.
let token = jwt_encode("test@example.com", KEY_PRIV.as_bytes())?;
Now we are almost ready to test it out, but first we must set up our JWK storage and register it with the tests.
So we have a public RSA key and need to generate a valid JWK that matches it. So let's add a new helper function that does that and place it inside our test module, as we are not using it outside our tests.
cargo add --dev base64
cargo add --dev openssl
// Add to the top of the test module
use base64::{Engine, engine::general_purpose};
use jsonwebtoken::EncodingKey;
use openssl::rsa::Rsa;
pub fn get_jwk(pub_key: &[u8]) -> anyhow::Result<Jwk> {
let rsa = Rsa::public_key_from_pem(pub_key)?;
Ok(Jwk {
alg: ALGORITHM,
kid: KID.to_string(),
kty: "RSA".to_string(),
r#use: "sig".to_string(),
n: general_purpose::URL_SAFE_NO_PAD.encode(rsa.n().to_vec()),
e: general_purpose::URL_SAFE_NO_PAD.encode(rsa.e().to_vec()),
})
}
So it takes the public key as a byte slice, creates an RSA structure from it and uses that to define a Jwk struct with all the needed values.
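The n and e fields end up as unpadded base64url text of the big-endian bytes of each number. To show what that encoding produces, here is a hand-rolled sketch using only the standard library; base64url_encode is a hypothetical helper for this illustration, and in the test module we keep using the base64 crate. The common RSA public exponent 65537 is the bytes 01 00 01, which encode to the "AQAB" you will see in many published JWKs.

```rust
const ALPHABET: &[u8; 64] =
    b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";

/// Encodes bytes as unpadded base64url (RFC 4648, section 5).
fn base64url_encode(bytes: &[u8]) -> String {
    let mut out = String::new();
    for chunk in bytes.chunks(3) {
        // Pack up to three bytes into a 24-bit group, filling missing bytes with zero.
        let b0 = u32::from(chunk[0]);
        let b1 = u32::from(*chunk.get(1).unwrap_or(&0));
        let b2 = u32::from(*chunk.get(2).unwrap_or(&0));
        let group = (b0 << 16) | (b1 << 8) | b2;
        // A chunk of k bytes yields k + 1 characters; padding is omitted.
        for i in 0..=chunk.len() {
            let index = (group >> (18 - 6 * i)) & 0x3F;
            out.push(ALPHABET[index as usize] as char);
        }
    }
    out
}

fn main() {
    // 65537 as big-endian bytes.
    let exponent = [0x01, 0x00, 0x01];
    assert_eq!(base64url_encode(&exponent), "AQAB");
    println!("e = {}", base64url_encode(&exponent));
}
```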
Now let's use it to generate our JWK storage.
In TestApp::new, change the call to setup_routes_and_middlewares to pass in the new key store.
// In Prod the JWKS should be fetched from service_auth
let jwks = Jwks(vec![get_jwk(KEY_PUB.as_bytes())?]);
let router = Router::new();
let router = setup_routes_and_middlewares(router, jwks);
We now also need to generate it and supply it in main.
// TODO: In production you need to populate this with the Jwk's you are using
let jwks = Jwks(vec![]);
let router = Router::new();
let router = setup_routes_and_middlewares(router, jwks);
And update the setup_routes_and_middlewares function to actually register the storage.
fn setup_routes_and_middlewares(router: Router, jwks: Jwks) -> Router {
router
.route("/", get(read))
.route_layer(from_extractor::<AuthorizationMiddleware>())
.layer(Extension(jwks))
}
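The order of the two layers matters here. As I understand the Axum middleware documentation, a layer added later wraps the ones added before it, so the Extension layer that we add last sees the request first and the Jwks are already in place when our extractor looks for them. The sketch below models that wrapping with plain closures. The Handler alias and the wrap and call_order functions are made up for this illustration, and nothing here uses Axum itself.

```rust
/// Hypothetical stand-in for a service: it records the order in which things run.
type Handler = Box<dyn Fn(&mut Vec<&'static str>)>;

/// Wraps `inner` so that `name` is recorded before the inner handler runs.
fn wrap(name: &'static str, inner: Handler) -> Handler {
    Box::new(move |trace: &mut Vec<&'static str>| {
        trace.push(name);
        inner(trace);
    })
}

fn call_order() -> Vec<&'static str> {
    let handler: Handler = Box::new(|trace: &mut Vec<&'static str>| trace.push("read handler"));
    // Same order as in setup_routes_and_middlewares:
    // the extractor layer is added first, the Extension layer last.
    let stack = wrap("AuthorizationMiddleware", handler);
    let stack = wrap("Extension(jwks)", stack);

    let mut trace = Vec::new();
    stack(&mut trace);
    trace
}

fn main() {
    // The layer added last is the outermost one, so it runs first.
    assert_eq!(
        call_order(),
        vec!["Extension(jwks)", "AuthorizationMiddleware", "read handler"]
    );
}
```

If the two calls were swapped, the extractor would run before the Jwks had been attached, and we would end up in the "Could not find the JWK layer" branch.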
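If you run the tests now, hello_world_authorized still fails on its last assertion, because the handler now includes the email in its response. So we have one final thing to do: update our hello_world_authorized test to validate that we actually get the email from inside the token back in the response!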
assert_eq!(response_text, "Hello world test@example.com");
Now our tests pass!
Future work
While we got something working (and that is the most important part), there are some things that could be improved.
- We could rewrite the from_request_parts implementation using the from_fn approach, which should simplify things. This code was written using Axum 0.5 and then migrated to Axum 0.7, which is why I have not changed it.
- For this blog post's sake everything is in one file, but of course it should be broken down and moved to separate modules if anyone were to pick it up and use it.
- There is no Jwk handling in the production part of the code. That needs to be added if anyone were to use it.
- Logging should be added. We use tracing and it is great. Really recommend it.
Final thoughts
We have actually successfully implemented a working authorization check using JWT.
So now we have a handler that only cares about the data it needs while still being authenticated. I think it is a pretty nice solution for how to manage this.
We also have a really nice test setup to build from when we add more routes, and there are still a lot of cool features to explore in Axum.
Acknowledgements
I would like to thank Magnus Markling for proofreading this article.