mirror of
https://github.com/dani-garcia/vaultwarden.git
synced 2025-09-10 18:55:57 +03:00
Improve JWT key initialization and avoid saving public key (#4085)
This commit is contained in:
59
src/auth.rs
59
src/auth.rs
@@ -2,9 +2,10 @@
|
||||
//
|
||||
use chrono::{Duration, Utc};
|
||||
use num_traits::FromPrimitive;
|
||||
use once_cell::sync::Lazy;
|
||||
use once_cell::sync::{Lazy, OnceCell};
|
||||
|
||||
use jsonwebtoken::{self, errors::ErrorKind, Algorithm, DecodingKey, EncodingKey, Header};
|
||||
use openssl::rsa::Rsa;
|
||||
use serde::de::DeserializeOwned;
|
||||
use serde::ser::Serialize;
|
||||
|
||||
@@ -26,23 +27,45 @@ static JWT_SEND_ISSUER: Lazy<String> = Lazy::new(|| format!("{}|send", CONFIG.do
|
||||
static JWT_ORG_API_KEY_ISSUER: Lazy<String> = Lazy::new(|| format!("{}|api.organization", CONFIG.domain_origin()));
|
||||
static JWT_FILE_DOWNLOAD_ISSUER: Lazy<String> = Lazy::new(|| format!("{}|file_download", CONFIG.domain_origin()));
|
||||
|
||||
static PRIVATE_RSA_KEY: Lazy<EncodingKey> = Lazy::new(|| {
|
||||
let key =
|
||||
std::fs::read(CONFIG.private_rsa_key()).unwrap_or_else(|e| panic!("Error loading private RSA Key. \n{e}"));
|
||||
EncodingKey::from_rsa_pem(&key).unwrap_or_else(|e| panic!("Error decoding private RSA Key.\n{e}"))
|
||||
});
|
||||
static PUBLIC_RSA_KEY: Lazy<DecodingKey> = Lazy::new(|| {
|
||||
let key = std::fs::read(CONFIG.public_rsa_key()).unwrap_or_else(|e| panic!("Error loading public RSA Key. \n{e}"));
|
||||
DecodingKey::from_rsa_pem(&key).unwrap_or_else(|e| panic!("Error decoding public RSA Key.\n{e}"))
|
||||
});
|
||||
static PRIVATE_RSA_KEY: OnceCell<EncodingKey> = OnceCell::new();
|
||||
static PUBLIC_RSA_KEY: OnceCell<DecodingKey> = OnceCell::new();
|
||||
|
||||
pub fn load_keys() {
|
||||
Lazy::force(&PRIVATE_RSA_KEY);
|
||||
Lazy::force(&PUBLIC_RSA_KEY);
|
||||
pub fn initialize_keys() -> Result<(), crate::error::Error> {
|
||||
let mut priv_key_buffer = Vec::with_capacity(2048);
|
||||
|
||||
let priv_key = {
|
||||
let mut priv_key_file = File::options().create(true).read(true).write(true).open(CONFIG.private_rsa_key())?;
|
||||
|
||||
#[allow(clippy::verbose_file_reads)]
|
||||
let bytes_read = priv_key_file.read_to_end(&mut priv_key_buffer)?;
|
||||
|
||||
if bytes_read > 0 {
|
||||
Rsa::private_key_from_pem(&priv_key_buffer[..bytes_read])?
|
||||
} else {
|
||||
// Only create the key if the file doesn't exist or is empty
|
||||
let rsa_key = openssl::rsa::Rsa::generate(2048)?;
|
||||
priv_key_buffer = rsa_key.private_key_to_pem()?;
|
||||
priv_key_file.write_all(&priv_key_buffer)?;
|
||||
info!("Private key created correctly.");
|
||||
rsa_key
|
||||
}
|
||||
};
|
||||
|
||||
let pub_key_buffer = priv_key.public_key_to_pem()?;
|
||||
|
||||
let enc = EncodingKey::from_rsa_pem(&priv_key_buffer)?;
|
||||
let dec: DecodingKey = DecodingKey::from_rsa_pem(&pub_key_buffer)?;
|
||||
if PRIVATE_RSA_KEY.set(enc).is_err() {
|
||||
err!("PRIVATE_RSA_KEY must only be initialized once")
|
||||
}
|
||||
if PUBLIC_RSA_KEY.set(dec).is_err() {
|
||||
err!("PUBLIC_RSA_KEY must only be initialized once")
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn encode_jwt<T: Serialize>(claims: &T) -> String {
|
||||
match jsonwebtoken::encode(&JWT_HEADER, claims, &PRIVATE_RSA_KEY) {
|
||||
match jsonwebtoken::encode(&JWT_HEADER, claims, PRIVATE_RSA_KEY.wait()) {
|
||||
Ok(token) => token,
|
||||
Err(e) => panic!("Error encoding jwt {e}"),
|
||||
}
|
||||
@@ -56,7 +79,7 @@ fn decode_jwt<T: DeserializeOwned>(token: &str, issuer: String) -> Result<T, Err
|
||||
validation.set_issuer(&[issuer]);
|
||||
|
||||
let token = token.replace(char::is_whitespace, "");
|
||||
match jsonwebtoken::decode(&token, &PUBLIC_RSA_KEY, &validation) {
|
||||
match jsonwebtoken::decode(&token, PUBLIC_RSA_KEY.wait(), &validation) {
|
||||
Ok(d) => Ok(d.claims),
|
||||
Err(err) => match *err.kind() {
|
||||
ErrorKind::InvalidToken => err!("Token is invalid"),
|
||||
@@ -799,7 +822,11 @@ impl<'r> FromRequest<'r> for OwnerHeaders {
|
||||
//
|
||||
// Client IP address detection
|
||||
//
|
||||
use std::net::IpAddr;
|
||||
use std::{
|
||||
fs::File,
|
||||
io::{Read, Write},
|
||||
net::IpAddr,
|
||||
};
|
||||
|
||||
pub struct ClientIp {
|
||||
pub ip: IpAddr,
|
||||
|
Reference in New Issue
Block a user