Update to rocket 0.5 and made code async, missing updating all db calls, that are currently blocking

This commit is contained in:
Daniel García
2021-11-07 18:53:39 +01:00
parent 08f0de7b46
commit 2d5f172e77
30 changed files with 1314 additions and 1028 deletions

View File

@@ -3,13 +3,14 @@ use serde::de::DeserializeOwned;
use serde_json::Value;
use std::env;
use rocket::serde::json::Json;
use rocket::{
http::{Cookie, Cookies, SameSite, Status},
request::{self, FlashMessage, Form, FromRequest, Outcome, Request},
response::{content::Html, Flash, Redirect},
form::Form,
http::{Cookie, CookieJar, SameSite, Status},
request::{self, FlashMessage, FromRequest, Outcome, Request},
response::{content::RawHtml as Html, Flash, Redirect},
Route,
};
use rocket_contrib::json::Json;
use crate::{
api::{ApiResult, EmptyResult, JsonResult, NumberOrString},
@@ -85,10 +86,11 @@ fn admin_path() -> String {
struct Referer(Option<String>);
impl<'a, 'r> FromRequest<'a, 'r> for Referer {
#[rocket::async_trait]
impl<'r> FromRequest<'r> for Referer {
type Error = ();
fn from_request(request: &'a Request<'r>) -> request::Outcome<Self, Self::Error> {
async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, Self::Error> {
Outcome::Success(Referer(request.headers().get_one("Referer").map(str::to_string)))
}
}
@@ -96,10 +98,11 @@ impl<'a, 'r> FromRequest<'a, 'r> for Referer {
#[derive(Debug)]
struct IpHeader(Option<String>);
impl<'a, 'r> FromRequest<'a, 'r> for IpHeader {
#[rocket::async_trait]
impl<'r> FromRequest<'r> for IpHeader {
type Error = ();
fn from_request(req: &'a Request<'r>) -> Outcome<Self, Self::Error> {
async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
if req.headers().get_one(&CONFIG.ip_header()).is_some() {
Outcome::Success(IpHeader(Some(CONFIG.ip_header())))
} else if req.headers().get_one("X-Client-IP").is_some() {
@@ -138,7 +141,7 @@ fn admin_url(referer: Referer) -> String {
#[get("/", rank = 2)]
fn admin_login(flash: Option<FlashMessage>) -> ApiResult<Html<String>> {
// If there is an error, show it
let msg = flash.map(|msg| format!("{}: {}", msg.name(), msg.msg()));
let msg = flash.map(|msg| format!("{}: {}", msg.kind(), msg.message()));
let json = json!({
"page_content": "admin/login",
"version": VERSION,
@@ -159,7 +162,7 @@ struct LoginForm {
#[post("/", data = "<data>")]
fn post_admin_login(
data: Form<LoginForm>,
mut cookies: Cookies,
cookies: &CookieJar,
ip: ClientIp,
referer: Referer,
) -> Result<Redirect, Flash<Redirect>> {
@@ -180,7 +183,7 @@ fn post_admin_login(
let cookie = Cookie::build(COOKIE_NAME, jwt)
.path(admin_path())
.max_age(time::Duration::minutes(20))
.max_age(rocket::time::Duration::minutes(20))
.same_site(SameSite::Strict)
.http_only(true)
.finish();
@@ -297,7 +300,7 @@ fn test_smtp(data: Json<InviteData>, _token: AdminToken) -> EmptyResult {
}
#[get("/logout")]
fn logout(mut cookies: Cookies, referer: Referer) -> Redirect {
fn logout(cookies: &CookieJar, referer: Referer) -> Redirect {
cookies.remove(Cookie::named(COOKIE_NAME));
Redirect::to(admin_url(referer))
}
@@ -462,23 +465,23 @@ struct GitCommit {
sha: String,
}
fn get_github_api<T: DeserializeOwned>(url: &str) -> Result<T, Error> {
async fn get_github_api<T: DeserializeOwned>(url: &str) -> Result<T, Error> {
let github_api = get_reqwest_client();
Ok(github_api.get(url).send()?.error_for_status()?.json::<T>()?)
Ok(github_api.get(url).send().await?.error_for_status()?.json::<T>().await?)
}
fn has_http_access() -> bool {
async fn has_http_access() -> bool {
let http_access = get_reqwest_client();
match http_access.head("https://github.com/dani-garcia/vaultwarden").send() {
match http_access.head("https://github.com/dani-garcia/vaultwarden").send().await {
Ok(r) => r.status().is_success(),
_ => false,
}
}
#[get("/diagnostics")]
fn diagnostics(_token: AdminToken, ip_header: IpHeader, conn: DbConn) -> ApiResult<Html<String>> {
async fn diagnostics(_token: AdminToken, ip_header: IpHeader, conn: DbConn) -> ApiResult<Html<String>> {
use crate::util::read_file_string;
use chrono::prelude::*;
use std::net::ToSocketAddrs;
@@ -497,7 +500,7 @@ fn diagnostics(_token: AdminToken, ip_header: IpHeader, conn: DbConn) -> ApiResu
// Execute some environment checks
let running_within_docker = is_running_in_docker();
let has_http_access = has_http_access();
let has_http_access = has_http_access().await;
let uses_proxy = env::var_os("HTTP_PROXY").is_some()
|| env::var_os("http_proxy").is_some()
|| env::var_os("HTTPS_PROXY").is_some()
@@ -513,11 +516,14 @@ fn diagnostics(_token: AdminToken, ip_header: IpHeader, conn: DbConn) -> ApiResu
// TODO: Maybe we need to cache this using a LazyStatic or something. Github only allows 60 requests per hour, and we use 3 here already.
let (latest_release, latest_commit, latest_web_build) = if has_http_access {
(
match get_github_api::<GitRelease>("https://api.github.com/repos/dani-garcia/vaultwarden/releases/latest") {
match get_github_api::<GitRelease>("https://api.github.com/repos/dani-garcia/vaultwarden/releases/latest")
.await
{
Ok(r) => r.tag_name,
_ => "-".to_string(),
},
match get_github_api::<GitCommit>("https://api.github.com/repos/dani-garcia/vaultwarden/commits/main") {
match get_github_api::<GitCommit>("https://api.github.com/repos/dani-garcia/vaultwarden/commits/main").await
{
Ok(mut c) => {
c.sha.truncate(8);
c.sha
@@ -531,7 +537,9 @@ fn diagnostics(_token: AdminToken, ip_header: IpHeader, conn: DbConn) -> ApiResu
} else {
match get_github_api::<GitRelease>(
"https://api.github.com/repos/dani-garcia/bw_web_builds/releases/latest",
) {
)
.await
{
Ok(r) => r.tag_name.trim_start_matches('v').to_string(),
_ => "-".to_string(),
}
@@ -562,7 +570,7 @@ fn diagnostics(_token: AdminToken, ip_header: IpHeader, conn: DbConn) -> ApiResu
"ip_header_config": &CONFIG.ip_header(),
"uses_proxy": uses_proxy,
"db_type": *DB_TYPE,
"db_version": get_sql_server_version(&conn),
"db_version": get_sql_server_version(&conn).await,
"admin_url": format!("{}/diagnostics", admin_url(Referer(None))),
"overrides": &CONFIG.get_overrides().join(", "),
"server_time_local": Local::now().format("%Y-%m-%d %H:%M:%S %Z").to_string(),
@@ -591,9 +599,9 @@ fn delete_config(_token: AdminToken) -> EmptyResult {
}
#[post("/config/backup_db")]
fn backup_db(_token: AdminToken, conn: DbConn) -> EmptyResult {
async fn backup_db(_token: AdminToken, conn: DbConn) -> EmptyResult {
if *CAN_BACKUP {
backup_database(&conn)
backup_database(&conn).await
} else {
err!("Can't back up current DB (Only SQLite supports this feature)");
}
@@ -601,21 +609,22 @@ fn backup_db(_token: AdminToken, conn: DbConn) -> EmptyResult {
pub struct AdminToken {}
impl<'a, 'r> FromRequest<'a, 'r> for AdminToken {
#[rocket::async_trait]
impl<'r> FromRequest<'r> for AdminToken {
type Error = &'static str;
fn from_request(request: &'a Request<'r>) -> request::Outcome<Self, Self::Error> {
async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, Self::Error> {
if CONFIG.disable_admin_token() {
Outcome::Success(AdminToken {})
} else {
let mut cookies = request.cookies();
let cookies = request.cookies();
let access_token = match cookies.get(COOKIE_NAME) {
Some(cookie) => cookie.value(),
None => return Outcome::Forward(()), // If there is no cookie, redirect to login
};
let ip = match request.guard::<ClientIp>() {
let ip = match ClientIp::from_request(request).await {
Outcome::Success(ip) => ip.ip,
_ => err_handler!("Error getting Client IP"),
};

View File

@@ -1,5 +1,5 @@
use chrono::Utc;
use rocket_contrib::json::Json;
use rocket::serde::json::Json;
use serde_json::Value;
use crate::{

View File

@@ -1,13 +1,14 @@
use std::collections::{HashMap, HashSet};
use std::path::{Path, PathBuf};
use chrono::{NaiveDateTime, Utc};
use rocket::{http::ContentType, request::Form, Data, Route};
use rocket_contrib::json::Json;
use rocket::fs::TempFile;
use rocket::serde::json::Json;
use rocket::{
form::{Form, FromForm},
Route,
};
use serde_json::Value;
use multipart::server::{save::SavedData, Multipart, SaveResult};
use crate::{
api::{self, EmptyResult, JsonResult, JsonUpcase, Notify, PasswordData, UpdateType},
auth::Headers,
@@ -79,9 +80,9 @@ pub fn routes() -> Vec<Route> {
]
}
pub fn purge_trashed_ciphers(pool: DbPool) {
pub async fn purge_trashed_ciphers(pool: DbPool) {
debug!("Purging trashed ciphers");
if let Ok(conn) = pool.get() {
if let Ok(conn) = pool.get().await {
Cipher::purge_trash(&conn);
} else {
error!("Failed to get DB connection while purging trashed ciphers")
@@ -90,12 +91,12 @@ pub fn purge_trashed_ciphers(pool: DbPool) {
#[derive(FromForm, Default)]
struct SyncData {
#[form(field = "excludeDomains")]
#[field(name = "excludeDomains")]
exclude_domains: bool, // Default: 'false'
}
#[get("/sync?<data..>")]
fn sync(data: Form<SyncData>, headers: Headers, conn: DbConn) -> Json<Value> {
fn sync(data: SyncData, headers: Headers, conn: DbConn) -> Json<Value> {
let user_json = headers.user.to_json(&conn);
let folders = Folder::find_by_user(&headers.user.uuid, &conn);
@@ -828,6 +829,12 @@ fn post_attachment_v2(
})))
}
#[derive(FromForm)]
struct UploadData<'f> {
key: Option<String>,
data: TempFile<'f>,
}
/// Saves the data content of an attachment to a file. This is common code
/// shared between the v2 and legacy attachment APIs.
///
@@ -836,22 +843,21 @@ fn post_attachment_v2(
///
/// When used with the v2 API, post_attachment_v2() has already created the
/// database record, which is passed in as `attachment`.
fn save_attachment(
async fn save_attachment(
mut attachment: Option<Attachment>,
cipher_uuid: String,
data: Data,
content_type: &ContentType,
data: Form<UploadData<'_>>,
headers: &Headers,
conn: &DbConn,
nt: Notify,
) -> Result<Cipher, crate::error::Error> {
let cipher = match Cipher::find_by_uuid(&cipher_uuid, conn) {
conn: DbConn,
nt: Notify<'_>,
) -> Result<(Cipher, DbConn), crate::error::Error> {
let cipher = match Cipher::find_by_uuid(&cipher_uuid, &conn) {
Some(cipher) => cipher,
None => err_discard!("Cipher doesn't exist", data),
None => err!("Cipher doesn't exist"),
};
if !cipher.is_write_accessible_to_user(&headers.user.uuid, conn) {
err_discard!("Cipher is not write accessible", data)
if !cipher.is_write_accessible_to_user(&headers.user.uuid, &conn) {
err!("Cipher is not write accessible")
}
// In the v2 API, the attachment record has already been created,
@@ -863,11 +869,11 @@ fn save_attachment(
let size_limit = if let Some(ref user_uuid) = cipher.user_uuid {
match CONFIG.user_attachment_limit() {
Some(0) => err_discard!("Attachments are disabled", data),
Some(0) => err!("Attachments are disabled"),
Some(limit_kb) => {
let left = (limit_kb * 1024) - Attachment::size_by_user(user_uuid, conn) + size_adjust;
let left = (limit_kb * 1024) - Attachment::size_by_user(user_uuid, &conn) + size_adjust;
if left <= 0 {
err_discard!("Attachment storage limit reached! Delete some attachments to free up space", data)
err!("Attachment storage limit reached! Delete some attachments to free up space")
}
Some(left as u64)
}
@@ -875,130 +881,78 @@ fn save_attachment(
}
} else if let Some(ref org_uuid) = cipher.organization_uuid {
match CONFIG.org_attachment_limit() {
Some(0) => err_discard!("Attachments are disabled", data),
Some(0) => err!("Attachments are disabled"),
Some(limit_kb) => {
let left = (limit_kb * 1024) - Attachment::size_by_org(org_uuid, conn) + size_adjust;
let left = (limit_kb * 1024) - Attachment::size_by_org(org_uuid, &conn) + size_adjust;
if left <= 0 {
err_discard!("Attachment storage limit reached! Delete some attachments to free up space", data)
err!("Attachment storage limit reached! Delete some attachments to free up space")
}
Some(left as u64)
}
None => None,
}
} else {
err_discard!("Cipher is neither owned by a user nor an organization", data);
err!("Cipher is neither owned by a user nor an organization");
};
let mut params = content_type.params();
let boundary_pair = params.next().expect("No boundary provided");
let boundary = boundary_pair.1;
let mut data = data.into_inner();
let base_path = Path::new(&CONFIG.attachments_folder()).join(&cipher_uuid);
let mut path = PathBuf::new();
let mut attachment_key = None;
let mut error = None;
Multipart::with_body(data.open(), boundary)
.foreach_entry(|mut field| {
match &*field.headers.name {
"key" => {
use std::io::Read;
let mut key_buffer = String::new();
if field.data.read_to_string(&mut key_buffer).is_ok() {
attachment_key = Some(key_buffer);
}
}
"data" => {
// In the legacy API, this is the encrypted filename
// provided by the client, stored to the database as-is.
// In the v2 API, this value doesn't matter, as it was
// already provided and stored via an earlier API call.
let encrypted_filename = field.headers.filename;
// This random ID is used as the name of the file on disk.
// In the legacy API, we need to generate this value here.
// In the v2 API, we use the value from post_attachment_v2().
let file_id = match &attachment {
Some(attachment) => attachment.id.clone(), // v2 API
None => crypto::generate_attachment_id(), // Legacy API
};
path = base_path.join(&file_id);
let size =
match field.data.save().memory_threshold(0).size_limit(size_limit).with_path(path.clone()) {
SaveResult::Full(SavedData::File(_, size)) => size as i32,
SaveResult::Full(other) => {
error = Some(format!("Attachment is not a file: {:?}", other));
return;
}
SaveResult::Partial(_, reason) => {
error = Some(format!("Attachment storage limit exceeded with this file: {:?}", reason));
return;
}
SaveResult::Error(e) => {
error = Some(format!("Error: {:?}", e));
return;
}
};
if let Some(attachment) = &mut attachment {
// v2 API
// Check the actual size against the size initially provided by
// the client. Upstream allows +/- 1 MiB deviation from this
// size, but it's not clear when or why this is needed.
const LEEWAY: i32 = 1024 * 1024; // 1 MiB
let min_size = attachment.file_size - LEEWAY;
let max_size = attachment.file_size + LEEWAY;
if min_size <= size && size <= max_size {
if size != attachment.file_size {
// Update the attachment with the actual file size.
attachment.file_size = size;
attachment.save(conn).expect("Error updating attachment");
}
} else {
attachment.delete(conn).ok();
let err_msg = "Attachment size mismatch".to_string();
error!("{} (expected within [{}, {}], got {})", err_msg, min_size, max_size, size);
error = Some(err_msg);
}
} else {
// Legacy API
if encrypted_filename.is_none() {
error = Some("No filename provided".to_string());
return;
}
if attachment_key.is_none() {
error = Some("No attachment key provided".to_string());
return;
}
let attachment = Attachment::new(
file_id,
cipher_uuid.clone(),
encrypted_filename.unwrap(),
size,
attachment_key.clone(),
);
attachment.save(conn).expect("Error saving attachment");
}
}
_ => error!("Invalid multipart name"),
}
})
.expect("Error processing multipart data");
if let Some(ref e) = error {
std::fs::remove_file(path).ok();
err!(e);
if let Some(size_limit) = size_limit {
if data.data.len() > size_limit {
err!("Attachment storage limit exceeded with this file");
}
}
nt.send_cipher_update(UpdateType::CipherUpdate, &cipher, &cipher.update_users_revision(conn));
let file_id = match &attachment {
Some(attachment) => attachment.id.clone(), // v2 API
None => crypto::generate_attachment_id(), // Legacy API
};
Ok(cipher)
let folder_path = tokio::fs::canonicalize(&CONFIG.attachments_folder()).await?.join(&cipher_uuid);
let file_path = folder_path.join(&file_id);
tokio::fs::create_dir_all(&folder_path).await?;
let size = data.data.len() as i32;
if let Some(attachment) = &mut attachment {
// v2 API
// Check the actual size against the size initially provided by
// the client. Upstream allows +/- 1 MiB deviation from this
// size, but it's not clear when or why this is needed.
const LEEWAY: i32 = 1024 * 1024; // 1 MiB
let min_size = attachment.file_size - LEEWAY;
let max_size = attachment.file_size + LEEWAY;
if min_size <= size && size <= max_size {
if size != attachment.file_size {
// Update the attachment with the actual file size.
attachment.file_size = size;
attachment.save(&conn).expect("Error updating attachment");
}
} else {
attachment.delete(&conn).ok();
err!(format!("Attachment size mismatch (expected within [{}, {}], got {})", min_size, max_size, size));
}
} else {
// Legacy API
let encrypted_filename = data.data.raw_name().map(|s| s.dangerous_unsafe_unsanitized_raw().to_string());
if encrypted_filename.is_none() {
err!("No filename provided")
}
if data.key.is_none() {
err!("No attachment key provided")
}
let attachment = Attachment::new(file_id, cipher_uuid.clone(), encrypted_filename.unwrap(), size, data.key);
attachment.save(&conn).expect("Error saving attachment");
}
data.data.persist_to(file_path).await?;
nt.send_cipher_update(UpdateType::CipherUpdate, &cipher, &cipher.update_users_revision(&conn));
Ok((cipher, conn))
}
/// v2 API for uploading the actual data content of an attachment.
@@ -1006,14 +960,13 @@ fn save_attachment(
/// /ciphers/<uuid>/attachment/v2 route, which would otherwise conflict
/// with this one.
#[post("/ciphers/<uuid>/attachment/<attachment_id>", format = "multipart/form-data", data = "<data>", rank = 1)]
fn post_attachment_v2_data(
async fn post_attachment_v2_data(
uuid: String,
attachment_id: String,
data: Data,
content_type: &ContentType,
data: Form<UploadData<'_>>,
headers: Headers,
conn: DbConn,
nt: Notify,
nt: Notify<'_>,
) -> EmptyResult {
let attachment = match Attachment::find_by_id(&attachment_id, &conn) {
Some(attachment) if uuid == attachment.cipher_uuid => Some(attachment),
@@ -1021,54 +974,51 @@ fn post_attachment_v2_data(
None => err!("Attachment doesn't exist"),
};
save_attachment(attachment, uuid, data, content_type, &headers, &conn, nt)?;
save_attachment(attachment, uuid, data, &headers, conn, nt).await?;
Ok(())
}
/// Legacy API for creating an attachment associated with a cipher.
#[post("/ciphers/<uuid>/attachment", format = "multipart/form-data", data = "<data>")]
fn post_attachment(
async fn post_attachment(
uuid: String,
data: Data,
content_type: &ContentType,
data: Form<UploadData<'_>>,
headers: Headers,
conn: DbConn,
nt: Notify,
nt: Notify<'_>,
) -> JsonResult {
// Setting this as None signifies to save_attachment() that it should create
// the attachment database record as well as saving the data to disk.
let attachment = None;
let cipher = save_attachment(attachment, uuid, data, content_type, &headers, &conn, nt)?;
let (cipher, conn) = save_attachment(attachment, uuid, data, &headers, conn, nt).await?;
Ok(Json(cipher.to_json(&headers.host, &headers.user.uuid, &conn)))
}
#[post("/ciphers/<uuid>/attachment-admin", format = "multipart/form-data", data = "<data>")]
fn post_attachment_admin(
async fn post_attachment_admin(
uuid: String,
data: Data,
content_type: &ContentType,
data: Form<UploadData<'_>>,
headers: Headers,
conn: DbConn,
nt: Notify,
nt: Notify<'_>,
) -> JsonResult {
post_attachment(uuid, data, content_type, headers, conn, nt)
post_attachment(uuid, data, headers, conn, nt).await
}
#[post("/ciphers/<uuid>/attachment/<attachment_id>/share", format = "multipart/form-data", data = "<data>")]
fn post_attachment_share(
async fn post_attachment_share(
uuid: String,
attachment_id: String,
data: Data,
content_type: &ContentType,
data: Form<UploadData<'_>>,
headers: Headers,
conn: DbConn,
nt: Notify,
nt: Notify<'_>,
) -> JsonResult {
_delete_cipher_attachment_by_id(&uuid, &attachment_id, &headers, &conn, &nt)?;
post_attachment(uuid, data, content_type, headers, conn, nt)
post_attachment(uuid, data, headers, conn, nt).await
}
#[post("/ciphers/<uuid>/attachment/<attachment_id>/delete-admin")]
@@ -1248,13 +1198,13 @@ fn move_cipher_selected_put(
#[derive(FromForm)]
struct OrganizationId {
#[form(field = "organizationId")]
#[field(name = "organizationId")]
org_id: String,
}
#[post("/ciphers/purge?<organization..>", data = "<data>")]
fn delete_all(
organization: Option<Form<OrganizationId>>,
organization: Option<OrganizationId>,
data: JsonUpcase<PasswordData>,
headers: Headers,
conn: DbConn,

View File

@@ -1,6 +1,6 @@
use chrono::{Duration, Utc};
use rocket::serde::json::Json;
use rocket::Route;
use rocket_contrib::json::Json;
use serde_json::Value;
use std::borrow::Borrow;
@@ -709,13 +709,13 @@ fn check_emergency_access_allowed() -> EmptyResult {
Ok(())
}
pub fn emergency_request_timeout_job(pool: DbPool) {
pub async fn emergency_request_timeout_job(pool: DbPool) {
debug!("Start emergency_request_timeout_job");
if !CONFIG.emergency_access_allowed() {
return;
}
if let Ok(conn) = pool.get() {
if let Ok(conn) = pool.get().await {
let emergency_access_list = EmergencyAccess::find_all_recoveries(&conn);
if emergency_access_list.is_empty() {
@@ -756,13 +756,13 @@ pub fn emergency_request_timeout_job(pool: DbPool) {
}
}
pub fn emergency_notification_reminder_job(pool: DbPool) {
pub async fn emergency_notification_reminder_job(pool: DbPool) {
debug!("Start emergency_notification_reminder_job");
if !CONFIG.emergency_access_allowed() {
return;
}
if let Ok(conn) = pool.get() {
if let Ok(conn) = pool.get().await {
let emergency_access_list = EmergencyAccess::find_all_recoveries(&conn);
if emergency_access_list.is_empty() {

View File

@@ -1,4 +1,4 @@
use rocket_contrib::json::Json;
use rocket::serde::json::Json;
use serde_json::Value;
use crate::{

View File

@@ -31,8 +31,8 @@ pub fn routes() -> Vec<Route> {
//
// Move this somewhere else
//
use rocket::serde::json::Json;
use rocket::Route;
use rocket_contrib::json::Json;
use serde_json::Value;
use crate::{
@@ -144,7 +144,7 @@ fn put_eq_domains(data: JsonUpcase<EquivDomainData>, headers: Headers, conn: DbC
}
#[get("/hibp/breach?<username>")]
fn hibp_breach(username: String) -> JsonResult {
async fn hibp_breach(username: String) -> JsonResult {
let url = format!(
"https://haveibeenpwned.com/api/v3/breachedaccount/{}?truncateResponse=false&includeUnverified=false",
username
@@ -153,14 +153,14 @@ fn hibp_breach(username: String) -> JsonResult {
if let Some(api_key) = crate::CONFIG.hibp_api_key() {
let hibp_client = get_reqwest_client();
let res = hibp_client.get(&url).header("hibp-api-key", api_key).send()?;
let res = hibp_client.get(&url).header("hibp-api-key", api_key).send().await?;
// If we get a 404, return a 404, it means no breached accounts
if res.status() == 404 {
return Err(Error::empty().with_code(404));
}
let value: Value = res.error_for_status()?.json()?;
let value: Value = res.error_for_status()?.json().await?;
Ok(Json(value))
} else {
Ok(Json(json!([{

View File

@@ -1,6 +1,6 @@
use num_traits::FromPrimitive;
use rocket::{request::Form, Route};
use rocket_contrib::json::Json;
use rocket::serde::json::Json;
use rocket::Route;
use serde_json::Value;
use crate::{
@@ -469,12 +469,12 @@ fn put_collection_users(
#[derive(FromForm)]
struct OrgIdData {
#[form(field = "organizationId")]
#[field(name = "organizationId")]
organization_id: String,
}
#[get("/ciphers/organization-details?<data..>")]
fn get_org_details(data: Form<OrgIdData>, headers: Headers, conn: DbConn) -> Json<Value> {
fn get_org_details(data: OrgIdData, headers: Headers, conn: DbConn) -> Json<Value> {
let ciphers = Cipher::find_by_org(&data.organization_id, &conn);
let ciphers_json: Vec<Value> =
ciphers.iter().map(|c| c.to_json(&headers.host, &headers.user.uuid, &conn)).collect();
@@ -1097,14 +1097,14 @@ struct RelationsData {
#[post("/ciphers/import-organization?<query..>", data = "<data>")]
fn post_org_import(
query: Form<OrgIdData>,
query: OrgIdData,
data: JsonUpcase<ImportData>,
headers: AdminHeaders,
conn: DbConn,
nt: Notify,
) -> EmptyResult {
let data: ImportData = data.into_inner().data;
let org_id = query.into_inner().organization_id;
let org_id = query.organization_id;
// Read and create the collections
let collections: Vec<_> = data

View File

@@ -1,9 +1,10 @@
use std::{io::Read, path::Path};
use std::path::Path;
use chrono::{DateTime, Duration, Utc};
use multipart::server::{save::SavedData, Multipart, SaveResult};
use rocket::{http::ContentType, response::NamedFile, Data};
use rocket_contrib::json::Json;
use rocket::form::Form;
use rocket::fs::NamedFile;
use rocket::fs::TempFile;
use rocket::serde::json::Json;
use serde_json::Value;
use crate::{
@@ -31,9 +32,9 @@ pub fn routes() -> Vec<rocket::Route> {
]
}
pub fn purge_sends(pool: DbPool) {
pub async fn purge_sends(pool: DbPool) {
debug!("Purging sends");
if let Ok(conn) = pool.get() {
if let Ok(conn) = pool.get().await {
Send::purge(&conn);
} else {
error!("Failed to get DB connection while purging sends")
@@ -177,25 +178,23 @@ fn post_send(data: JsonUpcase<SendData>, headers: Headers, conn: DbConn, nt: Not
Ok(Json(send.to_json()))
}
#[derive(FromForm)]
struct UploadData<'f> {
model: Json<crate::util::UpCase<SendData>>,
data: TempFile<'f>,
}
#[post("/sends/file", format = "multipart/form-data", data = "<data>")]
fn post_send_file(data: Data, content_type: &ContentType, headers: Headers, conn: DbConn, nt: Notify) -> JsonResult {
async fn post_send_file(data: Form<UploadData<'_>>, headers: Headers, conn: DbConn, nt: Notify<'_>) -> JsonResult {
enforce_disable_send_policy(&headers, &conn)?;
let boundary = content_type.params().next().expect("No boundary provided").1;
let UploadData {
model,
mut data,
} = data.into_inner();
let model = model.into_inner().data;
let mut mpart = Multipart::with_body(data.open(), boundary);
// First entry is the SendData JSON
let mut model_entry = match mpart.read_entry()? {
Some(e) if &*e.headers.name == "model" => e,
Some(_) => err!("Invalid entry name"),
None => err!("No model entry present"),
};
let mut buf = String::new();
model_entry.data.read_to_string(&mut buf)?;
let data = serde_json::from_str::<crate::util::UpCase<SendData>>(&buf)?;
enforce_disable_hide_email_policy(&data.data, &headers, &conn)?;
enforce_disable_hide_email_policy(&model, &headers, &conn)?;
// Get the file length and add an extra 5% to avoid issues
const SIZE_525_MB: u64 = 550_502_400;
@@ -212,45 +211,27 @@ fn post_send_file(data: Data, content_type: &ContentType, headers: Headers, conn
None => SIZE_525_MB,
};
// Create the Send
let mut send = create_send(data.data, headers.user.uuid)?;
let file_id = crate::crypto::generate_send_id();
let mut send = create_send(model, headers.user.uuid)?;
if send.atype != SendType::File as i32 {
err!("Send content is not a file");
}
let file_path = Path::new(&CONFIG.sends_folder()).join(&send.uuid).join(&file_id);
let size = data.len();
if size > size_limit {
err!("Attachment storage limit exceeded with this file");
}
// Read the data entry and save the file
let mut data_entry = match mpart.read_entry()? {
Some(e) if &*e.headers.name == "data" => e,
Some(_) => err!("Invalid entry name"),
None => err!("No model entry present"),
};
let file_id = crate::crypto::generate_send_id();
let folder_path = tokio::fs::canonicalize(&CONFIG.sends_folder()).await?.join(&send.uuid);
let file_path = folder_path.join(&file_id);
tokio::fs::create_dir_all(&folder_path).await?;
data.persist_to(&file_path).await?;
let size = match data_entry.data.save().memory_threshold(0).size_limit(size_limit).with_path(&file_path) {
SaveResult::Full(SavedData::File(_, size)) => size as i32,
SaveResult::Full(other) => {
std::fs::remove_file(&file_path).ok();
err!(format!("Attachment is not a file: {:?}", other));
}
SaveResult::Partial(_, reason) => {
std::fs::remove_file(&file_path).ok();
err!(format!("Attachment storage limit exceeded with this file: {:?}", reason));
}
SaveResult::Error(e) => {
std::fs::remove_file(&file_path).ok();
err!(format!("Error: {:?}", e));
}
};
// Set ID and sizes
let mut data_value: Value = serde_json::from_str(&send.data)?;
if let Some(o) = data_value.as_object_mut() {
o.insert(String::from("Id"), Value::String(file_id));
o.insert(String::from("Size"), Value::Number(size.into()));
o.insert(String::from("SizeName"), Value::String(crate::util::get_display_size(size)));
o.insert(String::from("SizeName"), Value::String(crate::util::get_display_size(size as i32)));
}
send.data = serde_json::to_string(&data_value)?;
@@ -367,10 +348,10 @@ fn post_access_file(
}
#[get("/sends/<send_id>/<file_id>?<t>")]
fn download_send(send_id: SafeString, file_id: SafeString, t: String) -> Option<NamedFile> {
async fn download_send(send_id: SafeString, file_id: SafeString, t: String) -> Option<NamedFile> {
if let Ok(claims) = crate::auth::decode_send(&t) {
if claims.sub == format!("{}/{}", send_id, file_id) {
return NamedFile::open(Path::new(&CONFIG.sends_folder()).join(send_id).join(file_id)).ok();
return NamedFile::open(Path::new(&CONFIG.sends_folder()).join(send_id).join(file_id)).await.ok();
}
}
None

View File

@@ -1,6 +1,6 @@
use data_encoding::BASE32;
use rocket::serde::json::Json;
use rocket::Route;
use rocket_contrib::json::Json;
use crate::{
api::{

View File

@@ -1,7 +1,7 @@
use chrono::Utc;
use data_encoding::BASE64;
use rocket::serde::json::Json;
use rocket::Route;
use rocket_contrib::json::Json;
use crate::{
api::{core::two_factor::_generate_recover_code, ApiResult, EmptyResult, JsonResult, JsonUpcase, PasswordData},
@@ -152,7 +152,7 @@ fn check_duo_fields_custom(data: &EnableDuoData) -> bool {
}
#[post("/two-factor/duo", data = "<data>")]
fn activate_duo(data: JsonUpcase<EnableDuoData>, headers: Headers, conn: DbConn) -> JsonResult {
async fn activate_duo(data: JsonUpcase<EnableDuoData>, headers: Headers, conn: DbConn) -> JsonResult {
let data: EnableDuoData = data.into_inner().data;
let mut user = headers.user;
@@ -163,7 +163,7 @@ fn activate_duo(data: JsonUpcase<EnableDuoData>, headers: Headers, conn: DbConn)
let (data, data_str) = if check_duo_fields_custom(&data) {
let data_req: DuoData = data.into();
let data_str = serde_json::to_string(&data_req)?;
duo_api_request("GET", "/auth/v2/check", "", &data_req).map_res("Failed to validate Duo credentials")?;
duo_api_request("GET", "/auth/v2/check", "", &data_req).await.map_res("Failed to validate Duo credentials")?;
(data_req.obscure(), data_str)
} else {
(DuoData::secret(), String::new())
@@ -185,11 +185,11 @@ fn activate_duo(data: JsonUpcase<EnableDuoData>, headers: Headers, conn: DbConn)
}
#[put("/two-factor/duo", data = "<data>")]
fn activate_duo_put(data: JsonUpcase<EnableDuoData>, headers: Headers, conn: DbConn) -> JsonResult {
activate_duo(data, headers, conn)
async fn activate_duo_put(data: JsonUpcase<EnableDuoData>, headers: Headers, conn: DbConn) -> JsonResult {
activate_duo(data, headers, conn).await
}
fn duo_api_request(method: &str, path: &str, params: &str, data: &DuoData) -> EmptyResult {
async fn duo_api_request(method: &str, path: &str, params: &str, data: &DuoData) -> EmptyResult {
use reqwest::{header, Method};
use std::str::FromStr;
@@ -209,7 +209,8 @@ fn duo_api_request(method: &str, path: &str, params: &str, data: &DuoData) -> Em
.basic_auth(username, Some(password))
.header(header::USER_AGENT, "vaultwarden:Duo/1.0 (Rust)")
.header(header::DATE, date)
.send()?
.send()
.await?
.error_for_status()?;
Ok(())

View File

@@ -1,6 +1,6 @@
use chrono::{Duration, NaiveDateTime, Utc};
use rocket::serde::json::Json;
use rocket::Route;
use rocket_contrib::json::Json;
use crate::{
api::{core::two_factor::_generate_recover_code, EmptyResult, JsonResult, JsonUpcase, PasswordData},

View File

@@ -1,7 +1,7 @@
use chrono::{Duration, Utc};
use data_encoding::BASE32;
use rocket::serde::json::Json;
use rocket::Route;
use rocket_contrib::json::Json;
use serde_json::Value;
use crate::{
@@ -158,14 +158,14 @@ fn disable_twofactor_put(data: JsonUpcase<DisableTwoFactorData>, headers: Header
disable_twofactor(data, headers, conn)
}
pub fn send_incomplete_2fa_notifications(pool: DbPool) {
pub async fn send_incomplete_2fa_notifications(pool: DbPool) {
debug!("Sending notifications for incomplete 2FA logins");
if CONFIG.incomplete_2fa_time_limit() <= 0 || !CONFIG.mail_enabled() {
return;
}
let conn = match pool.get() {
let conn = match pool.get().await {
Ok(conn) => conn,
_ => {
error!("Failed to get DB connection in send_incomplete_2fa_notifications()");

View File

@@ -1,6 +1,6 @@
use once_cell::sync::Lazy;
use rocket::serde::json::Json;
use rocket::Route;
use rocket_contrib::json::Json;
use serde_json::Value;
use u2f::{
messages::{RegisterResponse, SignResponse, U2fSignRequest},

View File

@@ -1,5 +1,5 @@
use rocket::serde::json::Json;
use rocket::Route;
use rocket_contrib::json::Json;
use serde_json::Value;
use url::Url;
use webauthn_rs::{base64_data::Base64UrlSafeData, proto::*, AuthenticationState, RegistrationState, Webauthn};

View File

@@ -1,5 +1,5 @@
use rocket::serde::json::Json;
use rocket::Route;
use rocket_contrib::json::Json;
use serde_json::Value;
use yubico::{config::Config, verify};

View File

@@ -1,19 +1,19 @@
use std::{
collections::HashMap,
fs::{create_dir_all, remove_file, symlink_metadata, File},
io::prelude::*,
net::{IpAddr, ToSocketAddrs},
sync::{Arc, RwLock},
time::{Duration, SystemTime},
};
use bytes::{Buf, Bytes, BytesMut};
use futures::{stream::StreamExt, TryFutureExt};
use once_cell::sync::Lazy;
use regex::Regex;
use reqwest::{blocking::Client, blocking::Response, header};
use rocket::{
http::ContentType,
response::{Content, Redirect},
Route,
use reqwest::{header, Client, Response};
use rocket::{http::ContentType, response::Redirect, Route};
use tokio::{
fs::{create_dir_all, remove_file, symlink_metadata, File},
io::{AsyncReadExt, AsyncWriteExt},
};
use crate::{
@@ -104,27 +104,23 @@ fn icon_google(domain: String) -> Option<Redirect> {
}
#[get("/<domain>/icon.png")]
fn icon_internal(domain: String) -> Cached<Content<Vec<u8>>> {
async fn icon_internal(domain: String) -> Cached<(ContentType, Vec<u8>)> {
const FALLBACK_ICON: &[u8] = include_bytes!("../static/images/fallback-icon.png");
if !is_valid_domain(&domain) {
warn!("Invalid domain: {}", domain);
return Cached::ttl(
Content(ContentType::new("image", "png"), FALLBACK_ICON.to_vec()),
(ContentType::new("image", "png"), FALLBACK_ICON.to_vec()),
CONFIG.icon_cache_negttl(),
true,
);
}
match get_icon(&domain) {
match get_icon(&domain).await {
Some((icon, icon_type)) => {
Cached::ttl(Content(ContentType::new("image", icon_type), icon), CONFIG.icon_cache_ttl(), true)
Cached::ttl((ContentType::new("image", icon_type), icon), CONFIG.icon_cache_ttl(), true)
}
_ => Cached::ttl(
Content(ContentType::new("image", "png"), FALLBACK_ICON.to_vec()),
CONFIG.icon_cache_negttl(),
true,
),
_ => Cached::ttl((ContentType::new("image", "png"), FALLBACK_ICON.to_vec()), CONFIG.icon_cache_negttl(), true),
}
}
@@ -317,15 +313,15 @@ fn is_domain_blacklisted(domain: &str) -> bool {
is_blacklisted
}
fn get_icon(domain: &str) -> Option<(Vec<u8>, String)> {
async fn get_icon(domain: &str) -> Option<(Vec<u8>, String)> {
let path = format!("{}/{}.png", CONFIG.icon_cache_folder(), domain);
// Check for expiration of negatively cached copy
if icon_is_negcached(&path) {
if icon_is_negcached(&path).await {
return None;
}
if let Some(icon) = get_cached_icon(&path) {
if let Some(icon) = get_cached_icon(&path).await {
let icon_type = match get_icon_type(&icon) {
Some(x) => x,
_ => "x-icon",
@@ -338,31 +334,31 @@ fn get_icon(domain: &str) -> Option<(Vec<u8>, String)> {
}
// Get the icon, or None in case of error
match download_icon(domain) {
match download_icon(domain).await {
Ok((icon, icon_type)) => {
save_icon(&path, &icon);
Some((icon, icon_type.unwrap_or("x-icon").to_string()))
save_icon(&path, &icon).await;
Some((icon.to_vec(), icon_type.unwrap_or("x-icon").to_string()))
}
Err(e) => {
warn!("Unable to download icon: {:?}", e);
let miss_indicator = path + ".miss";
save_icon(&miss_indicator, &[]);
save_icon(&miss_indicator, &[]).await;
None
}
}
}
fn get_cached_icon(path: &str) -> Option<Vec<u8>> {
async fn get_cached_icon(path: &str) -> Option<Vec<u8>> {
// Check for expiration of successfully cached copy
if icon_is_expired(path) {
if icon_is_expired(path).await {
return None;
}
// Try to read the cached icon, and return it if it exists
if let Ok(mut f) = File::open(path) {
if let Ok(mut f) = File::open(path).await {
let mut buffer = Vec::new();
if f.read_to_end(&mut buffer).is_ok() {
if f.read_to_end(&mut buffer).await.is_ok() {
return Some(buffer);
}
}
@@ -370,22 +366,22 @@ fn get_cached_icon(path: &str) -> Option<Vec<u8>> {
None
}
fn file_is_expired(path: &str, ttl: u64) -> Result<bool, Error> {
let meta = symlink_metadata(path)?;
async fn file_is_expired(path: &str, ttl: u64) -> Result<bool, Error> {
let meta = symlink_metadata(path).await?;
let modified = meta.modified()?;
let age = SystemTime::now().duration_since(modified)?;
Ok(ttl > 0 && ttl <= age.as_secs())
}
fn icon_is_negcached(path: &str) -> bool {
async fn icon_is_negcached(path: &str) -> bool {
let miss_indicator = path.to_owned() + ".miss";
let expired = file_is_expired(&miss_indicator, CONFIG.icon_cache_negttl());
let expired = file_is_expired(&miss_indicator, CONFIG.icon_cache_negttl()).await;
match expired {
// No longer negatively cached, drop the marker
Ok(true) => {
if let Err(e) = remove_file(&miss_indicator) {
if let Err(e) = remove_file(&miss_indicator).await {
error!("Could not remove negative cache indicator for icon {:?}: {:?}", path, e);
}
false
@@ -397,8 +393,8 @@ fn icon_is_negcached(path: &str) -> bool {
}
}
fn icon_is_expired(path: &str) -> bool {
let expired = file_is_expired(path, CONFIG.icon_cache_ttl());
async fn icon_is_expired(path: &str) -> bool {
let expired = file_is_expired(path, CONFIG.icon_cache_ttl()).await;
expired.unwrap_or(true)
}
@@ -521,13 +517,13 @@ struct IconUrlResult {
/// let icon_result = get_icon_url("github.com")?;
/// let icon_result = get_icon_url("vaultwarden.discourse.group")?;
/// ```
fn get_icon_url(domain: &str) -> Result<IconUrlResult, Error> {
async fn get_icon_url(domain: &str) -> Result<IconUrlResult, Error> {
// Default URL with secure and insecure schemes
let ssldomain = format!("https://{}", domain);
let httpdomain = format!("http://{}", domain);
// First check the domain as given during the request for both HTTPS and HTTP.
let resp = match get_page(&ssldomain).or_else(|_| get_page(&httpdomain)) {
let resp = match get_page(&ssldomain).or_else(|_| get_page(&httpdomain)).await {
Ok(c) => Ok(c),
Err(e) => {
let mut sub_resp = Err(e);
@@ -546,7 +542,7 @@ fn get_icon_url(domain: &str) -> Result<IconUrlResult, Error> {
let httpbase = format!("http://{}", base_domain);
debug!("[get_icon_url]: Trying without subdomains '{}'", base_domain);
sub_resp = get_page(&sslbase).or_else(|_| get_page(&httpbase));
sub_resp = get_page(&sslbase).or_else(|_| get_page(&httpbase)).await;
}
// When the domain is not an IP, and has less then 2 dots, try to add www. infront of it.
@@ -557,7 +553,7 @@ fn get_icon_url(domain: &str) -> Result<IconUrlResult, Error> {
let httpwww = format!("http://{}", www_domain);
debug!("[get_icon_url]: Trying with www. prefix '{}'", www_domain);
sub_resp = get_page(&sslwww).or_else(|_| get_page(&httpwww));
sub_resp = get_page(&sslwww).or_else(|_| get_page(&httpwww)).await;
}
}
@@ -581,7 +577,7 @@ fn get_icon_url(domain: &str) -> Result<IconUrlResult, Error> {
iconlist.push(Icon::new(35, String::from(url.join("/favicon.ico").unwrap())));
// 384KB should be more than enough for the HTML, though as we only really need the HTML header.
let mut limited_reader = content.take(384 * 1024);
let mut limited_reader = stream_to_bytes_limit(content, 384 * 1024).await?.reader();
use html5ever::tendril::TendrilSink;
let dom = html5ever::parse_document(markup5ever_rcdom::RcDom::default(), Default::default())
@@ -607,11 +603,11 @@ fn get_icon_url(domain: &str) -> Result<IconUrlResult, Error> {
})
}
fn get_page(url: &str) -> Result<Response, Error> {
get_page_with_referer(url, "")
async fn get_page(url: &str) -> Result<Response, Error> {
get_page_with_referer(url, "").await
}
fn get_page_with_referer(url: &str, referer: &str) -> Result<Response, Error> {
async fn get_page_with_referer(url: &str, referer: &str) -> Result<Response, Error> {
if is_domain_blacklisted(url::Url::parse(url).unwrap().host_str().unwrap_or_default()) {
warn!("Favicon '{}' resolves to a blacklisted domain or IP!", url);
}
@@ -621,7 +617,7 @@ fn get_page_with_referer(url: &str, referer: &str) -> Result<Response, Error> {
client = client.header("Referer", referer)
}
match client.send() {
match client.send().await {
Ok(c) => c.error_for_status().map_err(Into::into),
Err(e) => err_silent!(format!("{}", e)),
}
@@ -706,14 +702,14 @@ fn parse_sizes(sizes: Option<&str>) -> (u16, u16) {
(width, height)
}
fn download_icon(domain: &str) -> Result<(Vec<u8>, Option<&str>), Error> {
async fn download_icon(domain: &str) -> Result<(Bytes, Option<&str>), Error> {
if is_domain_blacklisted(domain) {
err_silent!("Domain is blacklisted", domain)
}
let icon_result = get_icon_url(domain)?;
let icon_result = get_icon_url(domain).await?;
let mut buffer = Vec::new();
let mut buffer = Bytes::new();
let mut icon_type: Option<&str> = None;
use data_url::DataUrl;
@@ -722,8 +718,12 @@ fn download_icon(domain: &str) -> Result<(Vec<u8>, Option<&str>), Error> {
if icon.href.starts_with("data:image") {
let datauri = DataUrl::process(&icon.href).unwrap();
// Check if we are able to decode the data uri
match datauri.decode_to_vec() {
Ok((body, _fragment)) => {
let mut body = BytesMut::new();
match datauri.decode::<_, ()>(|bytes| {
body.extend_from_slice(bytes);
Ok(())
}) {
Ok(_) => {
// Also check if the size is atleast 67 bytes, which seems to be the smallest png i could create
if body.len() >= 67 {
// Check if the icon type is allowed, else try an icon from the list.
@@ -733,17 +733,17 @@ fn download_icon(domain: &str) -> Result<(Vec<u8>, Option<&str>), Error> {
continue;
}
info!("Extracted icon from data:image uri for {}", domain);
buffer = body;
buffer = body.freeze();
break;
}
}
_ => debug!("Extracted icon from data:image uri is invalid"),
};
} else {
match get_page_with_referer(&icon.href, &icon_result.referer) {
Ok(mut res) => {
res.copy_to(&mut buffer)?;
// Check if the icon type is allowed, else try an icon from the list.
match get_page_with_referer(&icon.href, &icon_result.referer).await {
Ok(res) => {
buffer = stream_to_bytes_limit(res, 512 * 1024).await?; // 512 KB for each icon max
// Check if the icon type is allowed, else try an icon from the list.
icon_type = get_icon_type(&buffer);
if icon_type.is_none() {
buffer.clear();
@@ -765,13 +765,13 @@ fn download_icon(domain: &str) -> Result<(Vec<u8>, Option<&str>), Error> {
Ok((buffer, icon_type))
}
fn save_icon(path: &str, icon: &[u8]) {
match File::create(path) {
async fn save_icon(path: &str, icon: &[u8]) {
match File::create(path).await {
Ok(mut f) => {
f.write_all(icon).expect("Error writing icon file");
f.write_all(icon).await.expect("Error writing icon file");
}
Err(ref e) if e.kind() == std::io::ErrorKind::NotFound => {
create_dir_all(&CONFIG.icon_cache_folder()).expect("Error creating icon cache folder");
create_dir_all(&CONFIG.icon_cache_folder()).await.expect("Error creating icon cache folder");
}
Err(e) => {
warn!("Unable to save icon: {:?}", e);
@@ -820,8 +820,6 @@ impl reqwest::cookie::CookieStore for Jar {
}
fn cookies(&self, url: &url::Url) -> Option<header::HeaderValue> {
use bytes::Bytes;
let cookie_store = self.0.read().unwrap();
let s = cookie_store
.get_request_values(url)
@@ -836,3 +834,12 @@ impl reqwest::cookie::CookieStore for Jar {
header::HeaderValue::from_maybe_shared(Bytes::from(s)).ok()
}
}
async fn stream_to_bytes_limit(res: Response, max_size: usize) -> Result<Bytes, reqwest::Error> {
let mut stream = res.bytes_stream().take(max_size);
let mut buf = BytesMut::new();
while let Some(chunk) = stream.next().await {
buf.extend(chunk?);
}
Ok(buf.freeze())
}

View File

@@ -1,10 +1,10 @@
use chrono::Utc;
use num_traits::FromPrimitive;
use rocket::serde::json::Json;
use rocket::{
request::{Form, FormItems, FromForm},
form::{Form, FromForm},
Route,
};
use rocket_contrib::json::Json;
use serde_json::Value;
use crate::{
@@ -455,66 +455,57 @@ fn _json_err_twofactor(providers: &[i32], user_uuid: &str, conn: &DbConn) -> Api
// https://github.com/bitwarden/jslib/blob/master/common/src/models/request/tokenRequest.ts
// https://github.com/bitwarden/mobile/blob/master/src/Core/Models/Request/TokenRequest.cs
#[derive(Debug, Clone, Default)]
#[derive(Debug, Clone, Default, FromForm)]
#[allow(non_snake_case)]
struct ConnectData {
// refresh_token, password, client_credentials (API key)
grant_type: String,
#[field(name = uncased("grant_type"))]
#[field(name = uncased("granttype"))]
grant_type: String, // refresh_token, password, client_credentials (API key)
// Needed for grant_type="refresh_token"
#[field(name = uncased("refresh_token"))]
#[field(name = uncased("refreshtoken"))]
refresh_token: Option<String>,
// Needed for grant_type = "password" | "client_credentials"
client_id: Option<String>, // web, cli, desktop, browser, mobile
client_secret: Option<String>, // API key login (cli only)
#[field(name = uncased("client_id"))]
#[field(name = uncased("clientid"))]
client_id: Option<String>, // web, cli, desktop, browser, mobile
#[field(name = uncased("client_secret"))]
#[field(name = uncased("clientsecret"))]
client_secret: Option<String>,
#[field(name = uncased("password"))]
password: Option<String>,
#[field(name = uncased("scope"))]
scope: Option<String>,
#[field(name = uncased("username"))]
username: Option<String>,
#[field(name = uncased("device_identifier"))]
#[field(name = uncased("deviceidentifier"))]
device_identifier: Option<String>,
#[field(name = uncased("device_name"))]
#[field(name = uncased("devicename"))]
device_name: Option<String>,
#[field(name = uncased("device_type"))]
#[field(name = uncased("devicetype"))]
device_type: Option<String>,
#[field(name = uncased("device_push_token"))]
#[field(name = uncased("devicepushtoken"))]
device_push_token: Option<String>, // Unused; mobile device push not yet supported.
// Needed for two-factor auth
#[field(name = uncased("two_factor_provider"))]
#[field(name = uncased("twofactorprovider"))]
two_factor_provider: Option<i32>,
#[field(name = uncased("two_factor_token"))]
#[field(name = uncased("twofactortoken"))]
two_factor_token: Option<String>,
#[field(name = uncased("two_factor_remember"))]
#[field(name = uncased("twofactorremember"))]
two_factor_remember: Option<i32>,
}
impl<'f> FromForm<'f> for ConnectData {
type Error = String;
fn from_form(items: &mut FormItems<'f>, _strict: bool) -> Result<Self, Self::Error> {
let mut form = Self::default();
for item in items {
let (key, value) = item.key_value_decoded();
let mut normalized_key = key.to_lowercase();
normalized_key.retain(|c| c != '_'); // Remove '_'
match normalized_key.as_ref() {
"granttype" => form.grant_type = value,
"refreshtoken" => form.refresh_token = Some(value),
"clientid" => form.client_id = Some(value),
"clientsecret" => form.client_secret = Some(value),
"password" => form.password = Some(value),
"scope" => form.scope = Some(value),
"username" => form.username = Some(value),
"deviceidentifier" => form.device_identifier = Some(value),
"devicename" => form.device_name = Some(value),
"devicetype" => form.device_type = Some(value),
"devicepushtoken" => form.device_push_token = Some(value),
"twofactorprovider" => form.two_factor_provider = value.parse().ok(),
"twofactortoken" => form.two_factor_token = Some(value),
"twofactorremember" => form.two_factor_remember = value.parse().ok(),
key => warn!("Detected unexpected parameter during login: {}", key),
}
}
Ok(form)
}
}
fn _check_is_some<T>(value: &Option<T>, msg: &str) -> EmptyResult {
if value.is_none() {
err!(msg)

View File

@@ -5,7 +5,7 @@ mod identity;
mod notifications;
mod web;
use rocket_contrib::json::Json;
use rocket::serde::json::Json;
use serde_json::Value;
pub use crate::api::{

View File

@@ -1,7 +1,7 @@
use std::sync::atomic::{AtomicBool, Ordering};
use rocket::serde::json::Json;
use rocket::Route;
use rocket_contrib::json::Json;
use serde_json::Value as JsonValue;
use crate::{api::EmptyResult, auth::Headers, Error, CONFIG};
@@ -417,7 +417,7 @@ pub enum UpdateType {
}
use rocket::State;
pub type Notify<'a> = State<'a, WebSocketUsers>;
pub type Notify<'a> = &'a State<WebSocketUsers>;
pub fn start_notification_server() -> WebSocketUsers {
let factory = WsFactory::init();
@@ -430,12 +430,11 @@ pub fn start_notification_server() -> WebSocketUsers {
settings.queue_size = 2;
settings.panic_on_internal = false;
ws::Builder::new()
.with_settings(settings)
.build(factory)
.unwrap()
.listen((CONFIG.websocket_address().as_str(), CONFIG.websocket_port()))
.unwrap();
let ws = ws::Builder::new().with_settings(settings).build(factory).unwrap();
CONFIG.set_ws_shutdown_handle(ws.broadcaster());
ws.listen((CONFIG.websocket_address().as_str(), CONFIG.websocket_port())).unwrap();
warn!("WS Server stopped!");
});
}

View File

@@ -1,7 +1,7 @@
use std::path::{Path, PathBuf};
use rocket::{http::ContentType, response::content::Content, response::NamedFile, Route};
use rocket_contrib::json::Json;
use rocket::serde::json::Json;
use rocket::{fs::NamedFile, http::ContentType, Route};
use serde_json::Value;
use crate::{
@@ -21,16 +21,16 @@ pub fn routes() -> Vec<Route> {
}
#[get("/")]
fn web_index() -> Cached<Option<NamedFile>> {
Cached::short(NamedFile::open(Path::new(&CONFIG.web_vault_folder()).join("index.html")).ok(), false)
async fn web_index() -> Cached<Option<NamedFile>> {
Cached::short(NamedFile::open(Path::new(&CONFIG.web_vault_folder()).join("index.html")).await.ok(), false)
}
#[get("/app-id.json")]
fn app_id() -> Cached<Content<Json<Value>>> {
fn app_id() -> Cached<(ContentType, Json<Value>)> {
let content_type = ContentType::new("application", "fido.trusted-apps+json");
Cached::long(
Content(
(
content_type,
Json(json!({
"trustedFacets": [
@@ -58,13 +58,13 @@ fn app_id() -> Cached<Content<Json<Value>>> {
}
#[get("/<p..>", rank = 10)] // Only match this if the other routes don't match
fn web_files(p: PathBuf) -> Cached<Option<NamedFile>> {
Cached::long(NamedFile::open(Path::new(&CONFIG.web_vault_folder()).join(p)).ok(), true)
async fn web_files(p: PathBuf) -> Cached<Option<NamedFile>> {
Cached::long(NamedFile::open(Path::new(&CONFIG.web_vault_folder()).join(p)).await.ok(), true)
}
#[get("/attachments/<uuid>/<file_id>")]
fn attachments(uuid: SafeString, file_id: SafeString) -> Option<NamedFile> {
NamedFile::open(Path::new(&CONFIG.attachments_folder()).join(uuid).join(file_id)).ok()
async fn attachments(uuid: SafeString, file_id: SafeString) -> Option<NamedFile> {
NamedFile::open(Path::new(&CONFIG.attachments_folder()).join(uuid).join(file_id)).await.ok()
}
// We use DbConn here to let the alive healthcheck also verify the database connection.
@@ -78,25 +78,20 @@ fn alive(_conn: DbConn) -> Json<String> {
}
#[get("/vw_static/<filename>")]
fn static_files(filename: String) -> Result<Content<&'static [u8]>, Error> {
fn static_files(filename: String) -> Result<(ContentType, &'static [u8]), Error> {
match filename.as_ref() {
"mail-github.png" => Ok(Content(ContentType::PNG, include_bytes!("../static/images/mail-github.png"))),
"logo-gray.png" => Ok(Content(ContentType::PNG, include_bytes!("../static/images/logo-gray.png"))),
"error-x.svg" => Ok(Content(ContentType::SVG, include_bytes!("../static/images/error-x.svg"))),
"hibp.png" => Ok(Content(ContentType::PNG, include_bytes!("../static/images/hibp.png"))),
"vaultwarden-icon.png" => {
Ok(Content(ContentType::PNG, include_bytes!("../static/images/vaultwarden-icon.png")))
}
"bootstrap.css" => Ok(Content(ContentType::CSS, include_bytes!("../static/scripts/bootstrap.css"))),
"bootstrap-native.js" => {
Ok(Content(ContentType::JavaScript, include_bytes!("../static/scripts/bootstrap-native.js")))
}
"identicon.js" => Ok(Content(ContentType::JavaScript, include_bytes!("../static/scripts/identicon.js"))),
"datatables.js" => Ok(Content(ContentType::JavaScript, include_bytes!("../static/scripts/datatables.js"))),
"datatables.css" => Ok(Content(ContentType::CSS, include_bytes!("../static/scripts/datatables.css"))),
"mail-github.png" => Ok((ContentType::PNG, include_bytes!("../static/images/mail-github.png"))),
"logo-gray.png" => Ok((ContentType::PNG, include_bytes!("../static/images/logo-gray.png"))),
"error-x.svg" => Ok((ContentType::SVG, include_bytes!("../static/images/error-x.svg"))),
"hibp.png" => Ok((ContentType::PNG, include_bytes!("../static/images/hibp.png"))),
"vaultwarden-icon.png" => Ok((ContentType::PNG, include_bytes!("../static/images/vaultwarden-icon.png"))),
"bootstrap.css" => Ok((ContentType::CSS, include_bytes!("../static/scripts/bootstrap.css"))),
"bootstrap-native.js" => Ok((ContentType::JavaScript, include_bytes!("../static/scripts/bootstrap-native.js"))),
"identicon.js" => Ok((ContentType::JavaScript, include_bytes!("../static/scripts/identicon.js"))),
"datatables.js" => Ok((ContentType::JavaScript, include_bytes!("../static/scripts/datatables.js"))),
"datatables.css" => Ok((ContentType::CSS, include_bytes!("../static/scripts/datatables.css"))),
"jquery-3.6.0.slim.js" => {
Ok(Content(ContentType::JavaScript, include_bytes!("../static/scripts/jquery-3.6.0.slim.js")))
Ok((ContentType::JavaScript, include_bytes!("../static/scripts/jquery-3.6.0.slim.js")))
}
_ => err!(format!("Static file not found: {}", filename)),
}