Making a secure Axum route

How to build a secure Axum route using JWT with tests

Fredrik Park published on
20 min, 3925 words

Categories: Axum

In this blog series, I will demonstrate how to build an Axum route that is secured by a JWT token. I will also show how to write a black box test for the route to verify that it works.

Prerequisites

You need to have Rust installed. If you don't, install it using rustup.

Versions

In case you are a time traveler living in the future: this post was written using Rust 1.74.0 and Axum 0.7.1. If things have moved a lot since then, you might need to make some manual tweaks.

Create base project

Create a new cargo project

cargo init secure-axum && cd secure-axum

Next, let's add axum and tokio to the project

cargo add axum
cargo add -F full tokio

Now let's add a route to src/main.rs

use axum::{routing::get, Router};

#[tokio::main]
async fn main() {
    let router = Router::new().route("/", get(read));

    // run it
    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .unwrap();
    println!("listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, router).await.unwrap();
}

async fn read() -> &'static str {
    "Hello world"
}

Our web server listens on 127.0.0.1:3000 and even tells us so when we run it. If we try it out, it simply returns the string "Hello world". So let's do that!

cargo run

Open a new terminal and send a request to the service.

>curl 127.0.0.1:3000
Hello world

Test harness

Great, it works! But let's improve our development experience. We don't want to keep checking by hand that our changes actually work, as that will become really tedious once we introduce the authorization layer to the route.

So let's set up a test harness that will start our service and do the request for us.

💡 See my Rust tip on testing using error handling if you are wondering why we add anyhow here.

cargo add --dev anyhow

Now let's add the tests as a test module at the bottom of src/main.rs.

#[cfg(test)]
mod test {
    #[tokio::test]
    async fn hello_world_test() -> anyhow::Result<()> {
        unimplemented!("We need to add the logic to the test here");

        Ok(())
    }
}

Now let's implement our test. We want to verify that the body returned from the request is "Hello world". Start by adding reqwest, as we will use it to make the HTTP calls.

cargo add --dev reqwest

    #[tokio::test]
    async fn hello_world_test() -> anyhow::Result<()> {
        let client = reqwest::Client::new();

        let url = ""; // TODO What is our url?

        let response = client.get(url).send().await?.text().await?;
        assert_eq!(response, "Hello world");

        Ok(())
    }
}

But we don't have our service running, so let's start it. We do this by creating a router with the same setup as we have in our service and then starting it using Axum.

mod test {
    use super::*;
    use std::net::SocketAddr;
    use tokio::net::TcpListener;

    #[tokio::test]
    async fn hello_world_test() -> anyhow::Result<()> {
        let router = Router::new().route("/", get(read));

        let listener = TcpListener::bind("0.0.0.0:0".parse::<SocketAddr>()?)
            .await?;
        let addr = listener.local_addr()?;

        tokio::spawn(async move {
            // The spawned block returns (), so `?` cannot be used here.
            axum::serve(listener, router).await.unwrap();
        });

        let client = reqwest::Client::new();

        let url = format!("http://{addr}");

        let response = client.get(url).send().await?.text().await?;
        assert_eq!(response, "Hello world");

        Ok(())
    }

If we run our test now, it actually passes. It is important to note that we are actually testing our public API using a proper HTTP request with no knowledge of what happens inside the router. We are testing that a certain input gives us a certain output, a proper black box test.
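The test binds to port 0 (the :0 in "0.0.0.0:0"). This asks the operating system for any free port, and local_addr then tells us which port we got. Every test therefore gets its own server, so tests can run in parallel without fighting over 127.0.0.1:3000. Here is a small std-only sketch of the idea. bind_ephemeral is a hypothetical helper, not part of the project, and it binds to 127.0.0.1 rather than 0.0.0.0.

```rust
use std::net::{SocketAddr, TcpListener};

/// Hypothetical helper: bind to port 0 so the OS assigns a free port,
/// then ask the listener which address it actually got.
fn bind_ephemeral() -> std::io::Result<(TcpListener, SocketAddr)> {
    let listener = TcpListener::bind("127.0.0.1:0")?;
    let addr = listener.local_addr()?;
    Ok((listener, addr))
}
```

While both listeners are kept alive, they are guaranteed to hold different ports. The same holds for two concurrently running tests.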

But some things about that test are a bit messy: we do a lot of setup that is not related to the test itself. Let's take care of that. Rust's test system does not have the setup/teardown functionality you might be used to from other languages, but we can get around this by leveraging Rust's own structures.
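One such structure is a plain struct. Its constructor can do the setup, and if teardown is needed, a Drop implementation runs it when the value goes out of scope. The sketch below shows the pattern with a hypothetical Fixture type that only counts teardowns.

Our TestApp will not need a Drop impl. #[tokio::test] gives every test its own runtime, and the spawned server task is dropped when that runtime shuts down at the end of the test.

```rust
use std::cell::Cell;
use std::rc::Rc;

/// Hypothetical test fixture: `new` does the setup, `Drop` does the teardown.
struct Fixture {
    // Shared counter so the test can observe that teardown happened.
    teardowns: Rc<Cell<u32>>,
}

impl Fixture {
    fn new(teardowns: Rc<Cell<u32>>) -> Self {
        // Setup work (starting a server, creating a temp dir, ...) would go here.
        Fixture { teardowns }
    }
}

impl Drop for Fixture {
    fn drop(&mut self) {
        // Runs automatically when the fixture goes out of scope,
        // including when a test returns early through `?`.
        self.teardowns.set(self.teardowns.get() + 1);
    }
}
```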

TestApp

So let's create a TestApp inside the test module (just above hello_world_test). It will remember the address of the test server and also keep a copy of the reqwest client for re-use.

#[cfg(test)]
mod test {
    use super::*;
    use std::net::SocketAddr;
    use tokio::net::TcpListener;

    #[derive(Clone)]
    pub struct TestApp {
        pub client: reqwest::Client,
        addr: String,
    }

    #[tokio::test]
    async fn hello_world_test() -> anyhow::Result<()> {

    [...]

Next, let's implement the new method on it and move the creation of the test server and the reqwest client inside it. We will also add a url helper method that constructs the URL for a given path, to make life easier.

#[cfg(test)]
mod test {
    use super::*;
    use std::net::SocketAddr;
    use tokio::net::TcpListener;
    use reqwest::Url;

    #[derive(Clone)]
    pub struct TestApp {
        pub client: reqwest::Client,
        addr: String,
    }

    impl TestApp {
        pub async fn new() -> anyhow::Result<Self> {
            let router = Router::new().route("/", get(read));

            let listener = TcpListener::bind("0.0.0.0:0".parse::<SocketAddr>()?)
                .await?;
            let addr = listener.local_addr()?;

            tokio::spawn(async move {
                // The spawned block returns (), so `?` cannot be used here.
                axum::serve(listener, router).await.unwrap();
            });

            let client = reqwest::Client::new();

            Ok(Self {
                client,
                addr: addr.to_string(),
            })
        }

        pub fn url<S: AsRef<str>>(&self, path: S) -> anyhow::Result<String> {
            let base = Url::parse(format!("http://{}", self.addr).as_ref())?;
            let url = base.join(path.as_ref())?;
            Ok(url.as_str().to_string())
        }
    }

    #[tokio::test]
    async fn hello_world_test() -> anyhow::Result<()> {

    [...]

Lastly, we update the test to use the TestApp.

    #[tokio::test]
    async fn hello_world_test() -> anyhow::Result<()> {
        let test_app = TestApp::new().await?;

        let response = test_app
            .client
            .get(test_app.url("/")?)
            .send()
            .await?
            .text()
            .await?;
        assert_eq!(response, "Hello world");

        Ok(())
    }

Now I would say that is a pretty neat test. If you don't fancy the send().await? chain we have to do, you can easily hide that in a function called get_root or similar.

Securing the route

Now, with all of that foundational work done, we can get on with the real meat of the post: making our endpoint secure using Axum's handy middleware system.

So our goal is to take a JWT as input, extract and validate it, and only then let the caller see our super secret greeting!

First, let's update the handler to this.

// Update the use statement to include the Extension
use axum::{routing::get, Router, Extension};

// Update the handler to look like this.
async fn read(Extension(claims): Extension<Authorized>) -> String {
    let email = claims.0.email;
    format!("Hello world {email}")
}

Well, that was not too bad, was it? The handler now takes an Extension holding an Authorized type (which we will implement soon). From it we get the claims (the data contained inside a JWT is referred to as claims), which include an email. Basically, an Extension is a way to make data available to handlers in Axum. Here, a middleware will attach the Authorized value to the request so that the handler can extract it.

Let's create the Authorized structure above the main function.

#[derive(Debug, Clone)]
pub struct Authorized(pub Claims);

Basically it only holds the Claims structure, and we derive the Debug and Clone traits on it. Claims next!

cargo add -F serde chrono
cargo add -F derive serde

// At the top of the file add these lines
use serde::{Serialize, Deserialize};
use chrono::{serde::ts_seconds, DateTime, Utc};

// Add next to Authorized
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Claims {
    pub email: String,
    #[serde(with = "ts_seconds")]
    pub exp: DateTime<Utc>,
}

A real claims structure holds more information, but for this exercise we only care about the email. The exp (expiry) field is required, because jsonwebtoken validates it by default when decoding.

Now cargo check no longer reports any errors, so let's see if our test passes.

 cargo test
   Compiling secure-axum v0.1.0 (/home/sedrik/devarea/secure-axum)
    Finished test [unoptimized + debuginfo] target(s) in 1.89s
     Running unittests src/main.rs (target/debug/deps/secure_axum-7f6e6c701d568cfa)

running 1 test
test test::hello_world_test ... FAILED

failures:

---- test::hello_world_test stdout ----
thread 'test::hello_world_test' panicked at src/main.rs:81:9:
assertion `left == right` failed
  left: "Missing request extension: Extension of type `secure_axum::Authorized`
            was not found. Perhaps you forgot to add it? See `axum::Extension`."
 right: "Hello world"
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace


failures:
    test::hello_world_test

test result: FAILED. 0 passed; 1 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.02s

So what this is saying is that the handler cannot find an Extension of type Authorized. The bad part is that this is not caught at compile time (I don't believe that would be possible, or at least not feasible). The good part is that we caught the issue in testing, rather than one of our users finding it in production!

The way Extensions work in Axum is that you can add data to the system based on its type information, and that is then retrieved when we use an Extractor (Extension is an Extractor).
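To make the idea of type-keyed data concrete, here is a simplified model of such a type map using only the standard library. It is an illustration under my own assumptions, not the actual implementation in axum or the http crate. To keep it self-contained, this Authorized holds a plain String instead of Claims.

```rust
use std::any::{Any, TypeId};
use std::collections::HashMap;

/// Simplified model of request extensions: at most one value per type,
/// keyed by that type's TypeId. Not axum's actual implementation.
#[derive(Default)]
struct TypeMap {
    map: HashMap<TypeId, Box<dyn Any>>,
}

impl TypeMap {
    /// Stores `value` under its type, returning any previous value of the same type.
    fn insert<T: 'static>(&mut self, value: T) -> Option<T> {
        self.map
            .insert(TypeId::of::<T>(), Box::new(value))
            .and_then(|old| old.downcast::<T>().ok())
            .map(|boxed| *boxed)
    }

    /// Looks a value up by its type alone, the way an extractor would.
    fn get<T: 'static>(&self) -> Option<&T> {
        self.map
            .get(&TypeId::of::<T>())
            .and_then(|boxed| boxed.downcast_ref::<T>())
    }
}

/// Stand-in for the post's Authorized type (holds a String instead of Claims).
#[derive(Debug, PartialEq)]
struct Authorized(String);
```

Only the type is used as the key, so asking for a type nobody inserted simply finds nothing. That is exactly the situation behind the "Missing request extension" error above.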

But in order to create the Authorized value and add it to the request, we need an AuthorizationMiddleware that can extract the JWT from the authorization header.

// Update our import statement for axum at the top of the file to this.
use axum::{
    async_trait,
    extract::FromRequestParts,
    http::{request::Parts, StatusCode},
    routing::get,
    Extension, Router,
};

pub struct AuthorizationMiddleware;

#[async_trait]
impl<S> FromRequestParts<S> for AuthorizationMiddleware
where
    S: Send + Sync,
{
    type Rejection = StatusCode;

    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        unimplemented!("Add logic")
    }
}

It might look scary, but this is basically the scaffolding that you can find in the documentation for extractors.

Next, let's start adding our logic. First, we need to extract the JWT from the Authorization HTTP header. JWTs are sent as bearer tokens, so the header has the format Authorization: Bearer <token>. That is a bit messy to parse if you, like me, subscribe to the no-regex-allowed magazine.

Luckily for us, the axum-extra library has just the thing, namely TypedHeader. So let's add that library with the typed-header feature enabled.

cargo add -F typed-header axum-extra

Now let's add the functionality for extracting the bearer token.

// Add to the top of the file
use axum_extra::{TypedHeader, headers::{Authorization, authorization::Bearer}};

pub struct AuthorizationMiddleware;

#[async_trait]
impl<S> FromRequestParts<S> for AuthorizationMiddleware
where
    S: Send + Sync,
{
    type Rejection = StatusCode;

    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {

        let Ok(TypedHeader(Authorization(bearer))) =
            TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, state).await
        else {
            eprintln!("Could not get Authorization header from the request");
            return Err(StatusCode::UNAUTHORIZED);
        };

        unimplemented!("Add logic")
    }
}

We use a fairly new feature here called let else (stabilized in Rust 1.65). It lets us bind the header when it can be extracted and return an error early when it cannot.
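For the curious, here is roughly what pulling a bearer token out of the header value involves, written with let else and without any regex. bearer_token is a hypothetical std-only helper for illustration. In the post we rely on TypedHeader instead, and this is not how axum-extra implements it.

```rust
/// Hypothetical helper: pull the token out of an Authorization header value
/// such as "Bearer abc.def.ghi". The scheme name is compared case-insensitively.
fn bearer_token(header_value: &str) -> Result<&str, &'static str> {
    // let else: bind scheme and token, or bail out early with an error.
    let Some((scheme, token)) = header_value.split_once(' ') else {
        return Err("expected `<scheme> <token>`");
    };
    if !scheme.eq_ignore_ascii_case("bearer") {
        return Err("not a bearer token");
    }
    let token = token.trim();
    if token.is_empty() {
        return Err("empty token");
    }
    Ok(token)
}
```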

Then we need to check the integrity of the JWT and add it to the axum system.

// Add to the top of the file
use axum::http::Method;

#[async_trait]
impl<S> FromRequestParts<S> for AuthorizationMiddleware
where
    S: Send + Sync,
{
    type Rejection = StatusCode;

    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        if parts.method == Method::OPTIONS {
            // For options requests browsers will not send the authorization header.
            return Ok(Self);
        }

        let Ok(TypedHeader(Authorization(bearer))) =
            TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, state).await
        else {
            eprintln!("Could not get Authorization header from the request");
            return Err(StatusCode::UNAUTHORIZED);
        };

        match check_auth(bearer) {
            Ok(auth) => {
                // Set `auth` as a request extension so it can be accessed by other
                // services down the stack.
                parts.extensions.insert(auth);

                Ok(Self)
            }
            Err(error) => {
                eprintln!("{error:?}");
                Err(StatusCode::UNAUTHORIZED)
            }
        }
    }
}

fn check_auth(bearer: Bearer) -> Result<Authorized, String> {
    unimplemented!("TODO");
}

It looks daunting, but now that we have broken it down, most of it should be understandable. We also support browsers' preflight requests by letting OPTIONS requests through unauthenticated.

Validating the JWT

Next up is the check_auth method.

We get a Bearer as input and expect it to either authorize us or return an error, so let's start with that. Remember that we defined Claims above and that Authorized holds the Claims structure if we are successful.

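fn check_auth(bearer: Bearer) -> Result<Authorized, String> {
    // Placeholder: we still need to get the claims out of bearer.token().
    let claims: Claims = todo!("decode the claims from the token");
    Ok(Authorized(claims))
}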

So how do we get the claims out of the token? We need to verify and decode the token inside the bearer using the corresponding JWK (JSON Web Key). The header of each JWT can carry a key ID (kid) saying which key was used to sign it. So if we have the public keys stored, we can pick the correct one and decode the token.

fn check_auth(bearer: Bearer, jwks: &Jwks) -> Result<Authorized, String> {
    if let Some(jwk) = find_jwk(bearer.token(), &jwks.0) {
        let claims = jwt_decode(bearer.token(), jwk)?;
        Ok(Authorized(claims))
    } else {
        Err("JWK not found".to_string())
    }
}

So where do we get the JWK storage from? Why, let's store it in the Axum server as an extension. Add the following snippet to the from_request_parts implementation to get hold of it (don't worry, we will set this all up soon). It borrows parts, so place it after the TypedHeader extraction. And don't forget to pass it to the check_auth function.

        let Some(jwks) = parts.extensions.get::<Jwks>()
        else {
            eprintln!("Could not find the JWK layer, did you forget to add it?");
            return Err(StatusCode::UNAUTHORIZED);
        };

        match check_auth(bearer, &jwks) {
            ...

Now we have a few more pieces to implement. We need to

  • Define the Jwks storage
  • Implement find_jwk
  • Implement jwt_decode

So let's start with the Jwks storage. For simplicity, we store the keys in a plain list for now. We also define the Jwk structure itself, which is specified in RFC 7517.

cargo add jsonwebtoken

// Add to the top of the file
use jsonwebtoken::Algorithm;

const ALGORITHM: Algorithm = jsonwebtoken::Algorithm::RS256;

#[derive(Clone, Debug)]
pub struct Jwks(pub Vec<Jwk>);

#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct Jwk {
    pub alg: Algorithm, // The algorithm used to sign the token.
    pub e: String,      // The exponent value for the RSA public key.
    pub kid: String,    // Key ID
    pub kty: String,    // Key type
    pub n: String,      // The modulus value for the RSA public key.
    pub r#use: String,  // The intended use for the public key.
}

So let's tackle find_jwk now. We are given the token and the list of keys we know of. We need to get the key ID from the token, look through our list, and return the matching key. If we don't find one, we return None. We can use the jsonwebtoken library to decode the token's header without the key, because the header is only base64url-encoded, not encrypted. (The same goes for the payload: the key is needed to verify the signature, not to read the claims.)

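pub fn find_jwk<'a>(token: &str, jwks: &'a [Jwk]) -> Option<&'a Jwk> {
    // A header that cannot be decoded has no usable key ID, so return None
    // instead of panicking on input that comes straight from the client.
    let headers = jsonwebtoken::decode_header(token).ok()?;
    // Tokens without a kid cannot be matched against our key list.
    let kid = headers.kid?;
    jwks.iter().find(|jwk| jwk.kid == kid)
}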
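Why can find_jwk read the kid without any key? A compact JWT is three base64url-encoded segments separated by dots: header, payload and signature. The std-only sketch below splits a token and decodes the header JSON. The helpers are hypothetical; in the real code jsonwebtoken::decode_header does this for us. The sample header is the widely used example header {"alg":"HS256","typ":"JWT"}. Note that nothing here is verified, which is exactly why we must not trust anything in the token until jwt_decode has checked the signature.

```rust
/// Decodes unpadded base64url (RFC 4648 section 5), the encoding JWT segments use.
fn base64url_decode(input: &str) -> Option<Vec<u8>> {
    let mut out = Vec::with_capacity(input.len() * 3 / 4);
    let mut buffer: u32 = 0; // pending bits not yet emitted as a byte
    let mut bits: u32 = 0; // how many pending bits `buffer` holds
    for c in input.bytes() {
        let value = match c {
            b'A'..=b'Z' => c - b'A',
            b'a'..=b'z' => c - b'a' + 26,
            b'0'..=b'9' => c - b'0' + 52,
            b'-' => 62,
            b'_' => 63,
            _ => return None, // not a base64url character
        };
        buffer = (buffer << 6) | u32::from(value);
        bits += 6;
        if bits >= 8 {
            bits -= 8;
            out.push((buffer >> bits) as u8);
            buffer &= (1 << bits) - 1; // keep only the bits not yet emitted
        }
    }
    Some(out)
}

/// Hypothetical helper: split a compact JWT into header.payload.signature
/// and return the decoded header JSON. Performs no verification at all.
fn jwt_header_json(token: &str) -> Option<String> {
    let mut parts = token.split('.');
    let (Some(header), Some(_payload), Some(_signature), None) =
        (parts.next(), parts.next(), parts.next(), parts.next())
    else {
        return None; // not exactly three segments
    };
    String::from_utf8(base64url_decode(header)?).ok()
}
```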

For the jwt_decode function, we have the token and the correct key now.

// Update our jsonwebtoken use statement to this
use jsonwebtoken::{Algorithm, Validation, DecodingKey};

pub fn jwt_decode(token: &str, jwk: &Jwk) -> Result<Claims, String> {
    // Validation::new also checks the exp claim by default.
    let validation = Validation::new(ALGORITHM);

    // Return errors instead of panicking, so a bad token ends in 401 UNAUTHORIZED.
    let decode_key = DecodingKey::from_rsa_components(&jwk.n, &jwk.e)
        .map_err(|error| format!("Could not build the DecodingKey: {error}"))?;
    let decoded = jsonwebtoken::decode::<Claims>(token, &decode_key, &validation)
        .map_err(|error| format!("Could not decode the claims: {error}"))?;

    Ok(decoded.claims)
}

So let's walk this one through. We create a Validation structure with the defined algorithm; by default it also checks that the token has not expired. We then construct a DecodingKey from the modulus (n) and exponent (e) found inside the Jwk (I won't lie, the exact math behind all of this is beyond my understanding). Using this DecodingKey, we verify the token's signature and finally get to the claims.
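The expiry part of the validation boils down to comparing the exp claim with the current time, allowing some leeway for clock skew between servers. Here is a std-only sketch of that comparison. It assumes exp is whole seconds since the Unix epoch, which is what the ts_seconds serializer on Claims produces. is_expired and its leeway parameter are illustrative, not jsonwebtoken's API.

```rust
use std::time::{Duration, SystemTime, UNIX_EPOCH};

/// Sketch of an expiry check: the token counts as expired once `now`
/// is past `exp + leeway`. `exp` is in whole seconds since the Unix epoch.
fn is_expired(exp: u64, now: SystemTime, leeway: Duration) -> bool {
    // A clock set before 1970 is treated as time zero.
    let now_secs = now
        .duration_since(UNIX_EPOCH)
        .map(|elapsed| elapsed.as_secs())
        .unwrap_or(0);
    now_secs > exp.saturating_add(leeway.as_secs())
}
```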

Phew! This has been quite a journey, but now cargo check is green. Let's see what our test says.

 cargo test
   ...

failures:

---- test::hello_world_test stdout ----
thread 'test::hello_world_test' panicked at src/main.rs:188:9:
assertion `left == right` failed
  left: "Missing request extension: Extension of type `secure_axum::Authorized`
           was not found. Perhaps you forgot to add it? See `axum::Extension`."
 right: "Hello world"
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace

Nope, the handler still cannot find the Authorized type. Axum is not able to detect at compile time that we never registered the AuthorizationMiddleware, so we get an error at runtime. Luckily, we are testing our code with automated tests, so we catch things like this.

Register our middleware

Our test currently fails because we have not added our middleware to the Axum router yet.

So add a new function setup_routes_and_middlewares that does the route setup and registers our middleware.

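// Add to the top of the file
use axum::middleware::from_extractor;

fn setup_routes_and_middlewares(router: Router) -> Router {
    router
        .route("/", get(read))
        // route_layer only runs the middleware for requests that match a route.
        .route_layer(from_extractor::<AuthorizationMiddleware>())
}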

Then use it where we create our router, in both the main function and TestApp, as shown below.

    let router = Router::new();
    let router = setup_routes_and_middlewares(router);

Now it is time to change our test so that it checks for an UNAUTHORIZED status code when we call the route without a token.

    #[tokio::test]
    async fn hello_world_unauthorized() -> anyhow::Result<()> {
        let test_app = TestApp::new().await?;

        let response = test_app
            .client
            .get(test_app.url("/")?)
            .send()
            .await?;
        let status = response.status();
        assert_eq!(status, reqwest::StatusCode::UNAUTHORIZED);

        Ok(())
    }

Our tests are now passing, so let's add one where we expect the user to be authorized and get a response back.

For a user to be authorized, we need to send a valid token in the authorization header, so let's start there.

    #[tokio::test]
    async fn hello_world_authorized() -> anyhow::Result<()> {
        let test_app = TestApp::new().await?;

        let token = "...";

        let response = test_app
            .client
            .get(test_app.url("/")?)
            .header("authorization", format!("Bearer {token}").as_str())
            .send()
            .await?;
        let status = response.status();
        assert_eq!(status, reqwest::StatusCode::OK);

        let response_text = response
            .text()
            .await?;
        assert_eq!(response_text, "Hello world");

        Ok(())
    }

So that is the skeleton of our new test, but how do we generate the token? Again, the jsonwebtoken library will do this for us.

So add a function called jwt_encode to our test module (I put it just below the TestApp) with the following content, and we will go through what it does afterward.

// At the top of the test module
use chrono::Duration;
const KID: &str = "donaldduck";

pub fn jwt_encode(
    email: &str,
    private_key: &[u8],
) -> anyhow::Result<String> {
    let exp = Utc::now() + Duration::weeks(52);
    let claims = Claims {
        email: email.to_string(),
        exp,
    };

    let mut header = jsonwebtoken::Header::new(ALGORITHM);
    header.kid = Some(KID.to_string());
    Ok(jsonwebtoken::encode(
        &header,
        &claims,
        &EncodingKey::from_rsa_pem(private_key)?,
    )?)
}

So, given an email and a private key, we create a new Claims structure that expires 52 weeks from now. We also define a key ID (KID) that we will use when registering the public part of the key with the web server. Then we set up an EncodingKey from the PEM-encoded private key and encode the token.

So before continuing, let's generate a key pair to use in our test setup. Enter the following commands in a terminal and just press Enter at the prompts. In particular, we don't want to set a passphrase, as the key will only be used in our tests.

ssh-keygen -t rsa -b 4096 -m PEM -f test.key
openssl rsa -in test.key -pubout -outform PEM -out test.key.pub

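You should now have two new files: test.key (the private part) and test.key.pub (the public part). ssh-keygen first writes test.key.pub in OpenSSH format, and the openssl command then overwrites it with a PEM version. Copy their contents into two constants at the bottom of the test module in src/main.rs, following the pattern shown below.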

const KEY_PUB: &str = r#"-----BEGIN PUBLIC KEY-----
...
-----END PUBLIC KEY-----"#;

const KEY_PRIV: &str = r#"-----BEGIN RSA PRIVATE KEY-----
...
-----END RSA PRIVATE KEY-----"#;

I want to stress that this is not how you should do it in the production part of the code, this setup is purely for the tests. Proper key management is needed in production.

Now that we have our RSA keys, we can finally generate the token for our hello_world_authorized test. Update the token generation to the following line.

    let token = jwt_encode("test@example.com", KEY_PRIV.as_bytes())?;

Now we are almost ready to test it out, but first we must set up our JWK storage and register it with the tests.

So we have a public RSA key and need to generate a valid JWK that matches it. So let's add a new helper function that does that and place it inside our test module, as we are not using it outside our tests.

cargo add --dev base64
cargo add --dev openssl

// Add to the top of the test module
use base64::{Engine, engine::general_purpose};
use jsonwebtoken::EncodingKey;
use openssl::rsa::Rsa;

pub fn get_jwk(pub_key: &[u8]) -> anyhow::Result<Jwk> {
    let rsa = Rsa::public_key_from_pem(pub_key)?;
    Ok(Jwk {
        alg: ALGORITHM,
        kid: KID.to_string(),
        kty: "RSA".to_string(),
        r#use: "sig".to_string(),
        n: general_purpose::URL_SAFE_NO_PAD.encode(rsa.n().to_vec()),
        e: general_purpose::URL_SAFE_NO_PAD.encode(rsa.e().to_vec()),
    })
}

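It takes the PEM-encoded public key as bytes, creates an Rsa structure from it, and uses that to fill in a Jwk with all the needed values. The modulus (n) and exponent (e) are stored as unpadded base64url strings, which is what RFC 7518 prescribes for RSA keys.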
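To see what ends up in the n and e fields, here is a std-only sketch of unpadded base64url encoding. It stands in for general_purpose::URL_SAFE_NO_PAD, and base64url_no_pad and be_bytes_minimal are hypothetical helpers. The common RSA public exponent 65537 is the three bytes 01 00 01 in minimal big-endian form, which encodes to "AQAB". You will see that value in almost every RSA JWK.

```rust
/// Minimal unpadded base64url encoder (RFC 4648 section 5).
fn base64url_no_pad(bytes: &[u8]) -> String {
    const ALPHABET: &[u8; 64] =
        b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
    let mut out = String::with_capacity((bytes.len() * 4 + 2) / 3);
    for chunk in bytes.chunks(3) {
        // Pack up to three bytes into a 24-bit group, big-endian, zero-filled.
        let b = [chunk[0], *chunk.get(1).unwrap_or(&0), *chunk.get(2).unwrap_or(&0)];
        let group = (u32::from(b[0]) << 16) | (u32::from(b[1]) << 8) | u32::from(b[2]);
        // 1 byte -> 2 chars, 2 bytes -> 3 chars, 3 bytes -> 4 chars (no '=' padding).
        for i in 0..chunk.len() + 1 {
            let index = (group >> (18 - 6 * i)) & 0x3F;
            out.push(ALPHABET[index as usize] as char);
        }
    }
    out
}

/// Hypothetical helper: big-endian bytes of `value` with leading zero bytes stripped.
fn be_bytes_minimal(value: u32) -> Vec<u8> {
    let bytes = value.to_be_bytes();
    let first = bytes.iter().position(|&b| b != 0).unwrap_or(bytes.len() - 1);
    bytes[first..].to_vec()
}
```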

Now let's use it to generate our JWK storage.

In TestApp::new, build the key store and pass it to setup_routes_and_middlewares:

            // In Prod the JWKS should be fetched from service_auth
            let jwks = Jwks(vec![get_jwk(KEY_PUB.as_bytes())?]);

            let router = Router::new();
            let router = setup_routes_and_middlewares(router, jwks);

We now also need to generate it and supply it in main.

    // TODO: In production you need to populate this with the JWKs you are using.
    // With an empty list, every request is rejected with 401 UNAUTHORIZED.
    let jwks = Jwks(vec![]);

    let router = Router::new();
    let router = setup_routes_and_middlewares(router, jwks);

And update the setup_routes_and_middlewares function to actually register the storage.

fn setup_routes_and_middlewares(router: Router, jwks: Jwks) -> Router {
    router
        .route("/", get(read))
        .route_layer(from_extractor::<AuthorizationMiddleware>())
        .layer(Extension(jwks))
}
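The order of the two calls matters. Each call to layer wraps everything added before it, so the last layer added is the outermost one and sees the request first. That is why Extension(jwks) can insert the key store before AuthorizationMiddleware looks for it. Using route_layer for the auth check also means it only runs for requests that match a route, so unknown paths still get a 404 rather than a 401.

The toy model below mimics the ordering with plain closures. It is a hypothetical illustration, not how tower or axum are implemented.

```rust
/// Toy model: a "service" takes a log of visited layers and returns a response.
type Service = Box<dyn Fn(&mut Vec<&'static str>) -> &'static str>;

/// Wraps `inner` in a new service that records `name` before delegating,
/// mimicking how each layer wraps everything added before it.
fn layer(name: &'static str, inner: Service) -> Service {
    Box::new(move |log: &mut Vec<&'static str>| {
        log.push(name);
        inner(log)
    })
}
```

Wrapping a handler as layer("Extension(jwks)", layer("AuthorizationMiddleware", handler)) records the Extension layer first, then the middleware, and the handler last.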

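If you run the tests now, hello_world_authorized fails, because the handler now appends the email to the greeting. So we have one final thing to do: update the assertion in hello_world_authorized to validate that we actually get the email from inside the token back in the response!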

    assert_eq!(response_text, "Hello world test@example.com");

Now our tests pass!

Future work

While we got something working (and that is the most important part), there are some things that could be improved.

  • We could rewrite the from_request_parts implementation as a middleware built with axum::middleware::from_fn, which should simplify things. This code was originally written for Axum 0.5 and later migrated to Axum 0.7, which is why I have not changed it.

  • For the sake of this blog post, everything is in one file, but it should of course be broken down into separate modules if anyone were to pick it up and use it.

  • There is no JWK handling in the production part of the code. That needs to be added if anyone were going to use it.

  • Logging should be added in place of the eprintln! calls. We use tracing, and it is great. I really recommend it.

Final thoughts

We have actually successfully implemented a working authorization check using JWT.

So now we have a handler that only cares about the data it needs, while the route is still protected by authentication. I think it is a pretty nice way to manage this.

We also have a really nice test setup to build from when we add more routes, and there are still a lot of cool features to explore in Axum.

Acknowledgements

I would like to thank Magnus Markling for proofreading this article.