mirror of
				https://github.com/dani-garcia/vaultwarden.git
				synced 2025-10-31 10:18:19 +02:00 
			
		
		
		
	Update to rocket 0.5 and made code async, missing updating all db calls, that are currently blocking
This commit is contained in:
		
							
								
								
									
										70
									
								
								src/util.rs
									
									
									
									
									
								
							
							
						
						
									
										70
									
								
								src/util.rs
									
									
									
									
									
								
							| @@ -5,10 +5,10 @@ use std::io::Cursor; | ||||
|  | ||||
| use rocket::{ | ||||
|     fairing::{Fairing, Info, Kind}, | ||||
|     http::{ContentType, Header, HeaderMap, Method, RawStr, Status}, | ||||
|     http::{ContentType, Header, HeaderMap, Method, Status}, | ||||
|     request::FromParam, | ||||
|     response::{self, Responder}, | ||||
|     Data, Request, Response, Rocket, | ||||
|     Data, Orbit, Request, Response, Rocket, | ||||
| }; | ||||
|  | ||||
| use std::thread::sleep; | ||||
| @@ -18,6 +18,7 @@ use crate::CONFIG; | ||||
|  | ||||
| pub struct AppHeaders(); | ||||
|  | ||||
| #[rocket::async_trait] | ||||
| impl Fairing for AppHeaders { | ||||
|     fn info(&self) -> Info { | ||||
|         Info { | ||||
| @@ -26,7 +27,7 @@ impl Fairing for AppHeaders { | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     fn on_response(&self, _req: &Request, res: &mut Response) { | ||||
|     async fn on_response<'r>(&self, _req: &'r Request<'_>, res: &mut Response<'r>) { | ||||
|         res.set_raw_header("Permissions-Policy", "accelerometer=(), ambient-light-sensor=(), autoplay=(), camera=(), encrypted-media=(), fullscreen=(), geolocation=(), gyroscope=(), magnetometer=(), microphone=(), midi=(), payment=(), picture-in-picture=(), sync-xhr=(self \"https://haveibeenpwned.com\" \"https://2fa.directory\"), usb=(), vr=()"); | ||||
|         res.set_raw_header("Referrer-Policy", "same-origin"); | ||||
|         res.set_raw_header("X-Frame-Options", "SAMEORIGIN"); | ||||
| @@ -72,6 +73,7 @@ impl Cors { | ||||
|     } | ||||
| } | ||||
|  | ||||
| #[rocket::async_trait] | ||||
| impl Fairing for Cors { | ||||
|     fn info(&self) -> Info { | ||||
|         Info { | ||||
| @@ -80,7 +82,7 @@ impl Fairing for Cors { | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     fn on_response(&self, request: &Request, response: &mut Response) { | ||||
|     async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response<'r>) { | ||||
|         let req_headers = request.headers(); | ||||
|  | ||||
|         if let Some(origin) = Cors::get_allowed_origin(req_headers) { | ||||
| @@ -97,7 +99,7 @@ impl Fairing for Cors { | ||||
|             response.set_header(Header::new("Access-Control-Allow-Credentials", "true")); | ||||
|             response.set_status(Status::Ok); | ||||
|             response.set_header(ContentType::Plain); | ||||
|             response.set_sized_body(Cursor::new("")); | ||||
|             response.set_sized_body(Some(0), Cursor::new("")); | ||||
|         } | ||||
|     } | ||||
| } | ||||
| @@ -134,25 +136,21 @@ impl<R> Cached<R> { | ||||
|     } | ||||
| } | ||||
|  | ||||
| impl<'r, R: Responder<'r>> Responder<'r> for Cached<R> { | ||||
|     fn respond_to(self, req: &Request) -> response::Result<'r> { | ||||
| impl<'r, R: 'r + Responder<'r, 'static> + Send> Responder<'r, 'static> for Cached<R> { | ||||
|     fn respond_to(self, request: &'r Request<'_>) -> response::Result<'static> { | ||||
|         let mut res = self.response.respond_to(request)?; | ||||
|  | ||||
|         let cache_control_header = if self.is_immutable { | ||||
|             format!("public, immutable, max-age={}", self.ttl) | ||||
|         } else { | ||||
|             format!("public, max-age={}", self.ttl) | ||||
|         }; | ||||
|         res.set_raw_header("Cache-Control", cache_control_header); | ||||
|  | ||||
|         let time_now = chrono::Local::now(); | ||||
|  | ||||
|         match self.response.respond_to(req) { | ||||
|             Ok(mut res) => { | ||||
|                 res.set_raw_header("Cache-Control", cache_control_header); | ||||
|                 let expiry_time = time_now + chrono::Duration::seconds(self.ttl.try_into().unwrap()); | ||||
|                 res.set_raw_header("Expires", format_datetime_http(&expiry_time)); | ||||
|                 Ok(res) | ||||
|             } | ||||
|             e @ Err(_) => e, | ||||
|         } | ||||
|         let expiry_time = time_now + chrono::Duration::seconds(self.ttl.try_into().unwrap()); | ||||
|         res.set_raw_header("Expires", format_datetime_http(&expiry_time)); | ||||
|         Ok(res) | ||||
|     } | ||||
| } | ||||
|  | ||||
| @@ -175,11 +173,9 @@ impl<'r> FromParam<'r> for SafeString { | ||||
|     type Error = (); | ||||
|  | ||||
|     #[inline(always)] | ||||
|     fn from_param(param: &'r RawStr) -> Result<Self, Self::Error> { | ||||
|         let s = param.percent_decode().map(|cow| cow.into_owned()).map_err(|_| ())?; | ||||
|  | ||||
|         if s.chars().all(|c| matches!(c, 'a'..='z' | 'A'..='Z' |'0'..='9' | '-')) { | ||||
|             Ok(SafeString(s)) | ||||
|     fn from_param(param: &'r str) -> Result<Self, Self::Error> { | ||||
|         if param.chars().all(|c| matches!(c, 'a'..='z' | 'A'..='Z' |'0'..='9' | '-')) { | ||||
|             Ok(SafeString(param.to_string())) | ||||
|         } else { | ||||
|             Err(()) | ||||
|         } | ||||
| @@ -193,15 +189,16 @@ const LOGGED_ROUTES: [&str; 6] = | ||||
|  | ||||
| // Boolean is extra debug, when true, we ignore the whitelist above and also print the mounts | ||||
| pub struct BetterLogging(pub bool); | ||||
| #[rocket::async_trait] | ||||
| impl Fairing for BetterLogging { | ||||
|     fn info(&self) -> Info { | ||||
|         Info { | ||||
|             name: "Better Logging", | ||||
|             kind: Kind::Launch | Kind::Request | Kind::Response, | ||||
|             kind: Kind::Liftoff | Kind::Request | Kind::Response, | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     fn on_launch(&self, rocket: &Rocket) { | ||||
|     async fn on_liftoff(&self, rocket: &Rocket<Orbit>) { | ||||
|         if self.0 { | ||||
|             info!(target: "routes", "Routes loaded:"); | ||||
|             let mut routes: Vec<_> = rocket.routes().collect(); | ||||
| @@ -225,34 +222,36 @@ impl Fairing for BetterLogging { | ||||
|         info!(target: "start", "Rocket has launched from {}", addr); | ||||
|     } | ||||
|  | ||||
|     fn on_request(&self, request: &mut Request<'_>, _data: &Data) { | ||||
|     async fn on_request(&self, request: &mut Request<'_>, _data: &mut Data<'_>) { | ||||
|         let method = request.method(); | ||||
|         if !self.0 && method == Method::Options { | ||||
|             return; | ||||
|         } | ||||
|         let uri = request.uri(); | ||||
|         let uri_path = uri.path(); | ||||
|         let uri_subpath = uri_path.strip_prefix(&CONFIG.domain_path()).unwrap_or(uri_path); | ||||
|         let uri_path_str = uri_path.url_decode_lossy(); | ||||
|         let uri_subpath = uri_path_str.strip_prefix(&CONFIG.domain_path()).unwrap_or(&uri_path_str); | ||||
|         if self.0 || LOGGED_ROUTES.iter().any(|r| uri_subpath.starts_with(r)) { | ||||
|             match uri.query() { | ||||
|                 Some(q) => info!(target: "request", "{} {}?{}", method, uri_path, &q[..q.len().min(30)]), | ||||
|                 None => info!(target: "request", "{} {}", method, uri_path), | ||||
|                 Some(q) => info!(target: "request", "{} {}?{}", method, uri_path_str, &q[..q.len().min(30)]), | ||||
|                 None => info!(target: "request", "{} {}", method, uri_path_str), | ||||
|             }; | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     fn on_response(&self, request: &Request, response: &mut Response) { | ||||
|     async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response<'r>) { | ||||
|         if !self.0 && request.method() == Method::Options { | ||||
|             return; | ||||
|         } | ||||
|         let uri_path = request.uri().path(); | ||||
|         let uri_subpath = uri_path.strip_prefix(&CONFIG.domain_path()).unwrap_or(uri_path); | ||||
|         let uri_path_str = uri_path.url_decode_lossy(); | ||||
|         let uri_subpath = uri_path_str.strip_prefix(&CONFIG.domain_path()).unwrap_or(&uri_path_str); | ||||
|         if self.0 || LOGGED_ROUTES.iter().any(|r| uri_subpath.starts_with(r)) { | ||||
|             let status = response.status(); | ||||
|             if let Some(route) = request.route() { | ||||
|                 info!(target: "response", "{} => {} {}", route, status.code, status.reason) | ||||
|             if let Some(ref route) = request.route() { | ||||
|                 info!(target: "response", "{} => {}", route, status) | ||||
|             } else { | ||||
|                 info!(target: "response", "{} {}", status.code, status.reason) | ||||
|                 info!(target: "response", "{}", status) | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| @@ -614,10 +613,7 @@ where | ||||
|     } | ||||
| } | ||||
|  | ||||
| use reqwest::{ | ||||
|     blocking::{Client, ClientBuilder}, | ||||
|     header, | ||||
| }; | ||||
| use reqwest::{header, Client, ClientBuilder}; | ||||
|  | ||||
| pub fn get_reqwest_client() -> Client { | ||||
|     get_reqwest_client_builder().build().expect("Failed to build client") | ||||
|   | ||||
		Reference in New Issue
	
	Block a user