Redis cache

This commit is contained in:
Evann Regnault 2024-07-14 14:34:01 +02:00
parent 4754cead4a
commit fcd36ce099
4 changed files with 115 additions and 9 deletions

56
Cargo.lock generated
View file

@ -101,6 +101,16 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "combine"
version = "4.6.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba5a308b75df32fe02788e748662718f03fde005016435c444eea572398219fd"
dependencies = [
"bytes",
"memchr",
]
[[package]] [[package]]
name = "core-foundation" name = "core-foundation"
version = "0.9.4" version = "0.9.4"
@ -809,6 +819,8 @@ dependencies = [
"base64", "base64",
"hmac", "hmac",
"hyper 0.14.30", "hyper 0.14.30",
"redis",
"redis-macros",
"reqwest", "reqwest",
"serde", "serde",
"serde_json", "serde_json",
@ -856,6 +868,44 @@ dependencies = [
"getrandom", "getrandom",
] ]
[[package]]
name = "redis"
version = "0.25.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e0d7a6955c7511f60f3ba9e86c6d02b3c3f144f8c24b288d1f4e18074ab8bbec"
dependencies = [
"combine",
"itoa",
"percent-encoding",
"ryu",
"sha1_smol",
"socket2",
"url",
]
[[package]]
name = "redis-macros"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d8b5407866b6626d251b18c878f043d37f43124680f26a806595a61714ab049a"
dependencies = [
"redis",
"redis-macros-derive",
"serde",
"serde_json",
]
[[package]]
name = "redis-macros-derive"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8dfe1dc77e38e260bbd53e98d3aec64add3cdf5d773e38d344c63660196117f5"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]] [[package]]
name = "redox_syscall" name = "redox_syscall"
version = "0.5.2" version = "0.5.2"
@ -1080,6 +1130,12 @@ dependencies = [
"digest", "digest",
] ]
[[package]]
name = "sha1_smol"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ae1a47186c03a32177042e55dbc5fd5aee900b8e0069a8d70fba96a9375cd012"
[[package]] [[package]]
name = "signal-hook-registry" name = "signal-hook-registry"
version = "1.4.2" version = "1.4.2"

View file

@ -7,6 +7,8 @@ edition = "2021"
base64 = "0.22.1" base64 = "0.22.1"
hmac = "0.12.1" hmac = "0.12.1"
hyper = {version="0.14", features=["full"]} hyper = {version="0.14", features=["full"]}
redis = "0.25.4"
redis-macros = "0.3.0"
reqwest = "0.12.5" reqwest = "0.12.5"
serde = {version="1.0.204", features=["derive"]} serde = {version="1.0.204", features=["derive"]}
serde_json = "1.0.120" serde_json = "1.0.120"

View file

@ -8,3 +8,4 @@ This is a proxy to authenticate on CouchDB using Keycloak to provide the roles
- COUCHDB_HOST - COUCHDB_HOST
- COUCHDB_PORT - COUCHDB_PORT
- COUCHDB_SECRET - COUCHDB_SECRET
- REDIS_HOST

View file

@ -1,17 +1,19 @@
use std::{convert::Infallible, env, error::Error, net::SocketAddr, str::FromStr}; use std::{convert::Infallible, env, error::Error, net::SocketAddr, str::FromStr};
use redis::Commands;
use base64::prelude::*; use base64::prelude::*;
use hmac::{digest::generic_array::functional::FunctionalSequence, Hmac, Mac}; use hmac::{digest::generic_array::functional::FunctionalSequence, Hmac, Mac};
use hyper::{ use hyper::{
service::{make_service_fn, service_fn}, Body, Client, HeaderMap, Method, Request, Response, Server, Uri service::{make_service_fn, service_fn}, Body, Client, HeaderMap, Method, Request, Response, Server, Uri
}; };
use serde::Deserialize; use redis_macros::{FromRedisValue, ToRedisArgs};
use serde::{Deserialize, Serialize};
use sha1::Sha1; use sha1::Sha1;
// STRUCTS // STRUCTS
/// USER /// USER
#[derive(ToRedisArgs, FromRedisValue, Serialize, Deserialize)]
struct User { struct User {
name: String, name: String,
roles: Vec<String>, roles: Vec<String>,
@ -46,6 +48,7 @@ pub struct Couchdb {
// IMPLEMENTATIONS // IMPLEMENTATIONS
impl User { impl User {
/// Generates the x-auth-token for the proxy auth
fn couchdb_token(&self) -> String { fn couchdb_token(&self) -> String {
let hmac_secret = env::var("COUCHDB_SECRET").unwrap(); let hmac_secret = env::var("COUCHDB_SECRET").unwrap();
let mut hmac: Hmac<Sha1> = let mut hmac: Hmac<Sha1> =
@ -56,6 +59,7 @@ impl User {
.fold(String::new(), |acc, b| format!("{}{:02x}", acc, b)) .fold(String::new(), |acc, b| format!("{}{:02x}", acc, b))
} }
/// Sets the required headers for the proxy auth
pub fn set_headers(&self, headers: &mut HeaderMap) { pub fn set_headers(&self, headers: &mut HeaderMap) {
headers.insert("X-Auth-CouchDB-UserName", self.name.parse().unwrap()); headers.insert("X-Auth-CouchDB-UserName", self.name.parse().unwrap());
headers.insert( headers.insert(
@ -86,6 +90,7 @@ impl KeycloakToken {
} }
impl DecodedJWT { impl DecodedJWT {
/// Creates the user struct from the JWT Token
fn get_user(self, name: String) -> User { fn get_user(self, name: String) -> User {
let roles = match self.resource_access { let roles = match self.resource_access {
None => vec![], None => vec![],
@ -106,13 +111,47 @@ impl DecodedJWT {
// FUNCTIONS // FUNCTIONS
fn get_redis_connection() -> Result<redis::Connection, redis::RedisError> {
let client = redis::Client::open(format!("redis://{}/", env::var("REDIS_HOST").unwrap())).unwrap();
client.get_connection()
}
fn create_user_key(username: String, password: String) -> String {
let mut hasher : Hmac<Sha1> = Mac::new_from_slice(username.as_bytes()).unwrap();
hasher.update(password.as_bytes());
hasher.finalize().into_bytes().fold(String::new(), |acc, b| format!("{}{:02x}", acc, b))
}
fn get_user_from_cache(username: String, password: String) -> Option<User> {
let connection = get_redis_connection();
if connection.is_err() {
return None
}
let mut connection = connection.unwrap();
let key = create_user_key(username, password);
match connection.get::<String, User>(key) {
Ok(user) => Some(user),
Err(_) => None
}
}
fn set_user_to_cache(password: String, user: User) {
let connection = get_redis_connection();
if connection.is_err() {
return;
}
let mut connection = connection.unwrap();
let key = create_user_key(user.name.clone(), password);
let _ : () = connection.set_ex(key, user, 60*5).expect("Cannot set user");
}
async fn authenticate_keycloak(username: String, password: String) -> Result<User, (u16, String)> { async fn authenticate_keycloak(username: String, password: String) -> Result<User, (u16, String)> {
let client_id = env::var("KEYCLOAK_CLIENT_ID").unwrap(); let client_id = env::var("KEYCLOAK_CLIENT_ID").unwrap();
let client_secret = env::var("KEYCLOAK_CLIENT_SECRET").unwrap(); let client_secret = env::var("KEYCLOAK_CLIENT_SECRET").unwrap();
// Authenticate on keycloak using the password grant type
let mut headers = reqwest::header::HeaderMap::new(); let mut headers = reqwest::header::HeaderMap::new();
headers.insert("Content-Type", "application/x-www-form-urlencoded".parse().unwrap()); headers.insert("Content-Type", "application/x-www-form-urlencoded".parse().unwrap());
let client = reqwest::Client::builder() let client = reqwest::Client::builder()
.redirect(reqwest::redirect::Policy::none()) .redirect(reqwest::redirect::Policy::none())
.build() .build()
@ -121,7 +160,6 @@ async fn authenticate_keycloak(username: String, password: String) -> Result<Use
.headers(headers) .headers(headers)
.body(format!("grant_type=password&client_id={}&client_secret={}&username={}&password={}",client_id, client_secret, username, password)) .body(format!("grant_type=password&client_id={}&client_secret={}&username={}&password={}",client_id, client_secret, username, password))
.send().await; .send().await;
if res.is_err() { if res.is_err() {
return Err((400, "Cannot Connect to Keycloak".to_string())) return Err((400, "Cannot Connect to Keycloak".to_string()))
} }
@ -136,7 +174,7 @@ async fn authenticate_keycloak(username: String, password: String) -> Result<Use
return Err((400, e.to_string())) return Err((400, e.to_string()))
} }
// Decode the JWT Token and get the User struct
let keycloak_token = KeycloakToken::from_json_text(&res_text.unwrap()); let keycloak_token = KeycloakToken::from_json_text(&res_text.unwrap());
if let Err(e) = keycloak_token { if let Err(e) = keycloak_token {
return Err((400, e.to_string())) return Err((400, e.to_string()))
@ -150,6 +188,7 @@ async fn authenticate_keycloak(username: String, password: String) -> Result<Use
Ok(decoded_jwt.unwrap().get_user(username)) Ok(decoded_jwt.unwrap().get_user(username))
} }
/// Retrieves the username and password from a Basic Auth
fn extract_creds_from_request(req: &Request<Body>) -> Option<(String, String)> { fn extract_creds_from_request(req: &Request<Body>) -> Option<(String, String)> {
let auth_value = req.headers().get("authorization")?; let auth_value = req.headers().get("authorization")?;
let b64_auth = auth_value.to_str().unwrap_or("").split(" ").skip(1).next()?; let b64_auth = auth_value.to_str().unwrap_or("").split(" ").skip(1).next()?;
@ -194,14 +233,20 @@ async fn handle(old_req: Request<Body>) -> Result<Response<Body>, hyper::Error>
return Ok(forbidden_response()); return Ok(forbidden_response());
} }
let (username, password) = creds.unwrap(); let (username, password) = creds.unwrap();
let user : Result<User, (u16, String)> = match get_user_from_cache(username.clone(), password.clone()) {
Some(x) => Ok(x),
None => {
authenticate_keycloak(username.clone(), password.clone()).await
}
};
// Keycloak authentication // Keycloak authentication
let user = authenticate_keycloak(username, password).await;
if let Err(e) = user { if let Err(e) = user {
return Ok(error_response(e)) return Ok(error_response(e))
} }
user.unwrap().set_headers(req.headers_mut()); let user = user.unwrap();
user.set_headers(req.headers_mut());
set_user_to_cache(password, user);
// Execute request // Execute request
match client.request(req).await { match client.request(req).await {
Ok(mut res) => { Ok(mut res) => {
@ -219,6 +264,8 @@ async fn main() {
let addr = SocketAddr::from(([0, 0, 0, 0], 8080)); let addr = SocketAddr::from(([0, 0, 0, 0], 8080));
let server = Server::bind(&addr).serve(make_service); let server = Server::bind(&addr).serve(make_service);
println!("Server started !");
if let Err(e) = server.await { if let Err(e) = server.await {
println!("Error: {}", e.message()); println!("Error: {}", e.message());
} }