mirror of
https://github.com/dani-garcia/vaultwarden.git
synced 2025-09-09 18:25:58 +03:00
Update to rocket 0.5 and made code async, missing updating all db calls, that are currently blocking
This commit is contained in:
125
src/api/icons.rs
125
src/api/icons.rs
@@ -1,19 +1,19 @@
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
fs::{create_dir_all, remove_file, symlink_metadata, File},
|
||||
io::prelude::*,
|
||||
net::{IpAddr, ToSocketAddrs},
|
||||
sync::{Arc, RwLock},
|
||||
time::{Duration, SystemTime},
|
||||
};
|
||||
|
||||
use bytes::{Buf, Bytes, BytesMut};
|
||||
use futures::{stream::StreamExt, TryFutureExt};
|
||||
use once_cell::sync::Lazy;
|
||||
use regex::Regex;
|
||||
use reqwest::{blocking::Client, blocking::Response, header};
|
||||
use rocket::{
|
||||
http::ContentType,
|
||||
response::{Content, Redirect},
|
||||
Route,
|
||||
use reqwest::{header, Client, Response};
|
||||
use rocket::{http::ContentType, response::Redirect, Route};
|
||||
use tokio::{
|
||||
fs::{create_dir_all, remove_file, symlink_metadata, File},
|
||||
io::{AsyncReadExt, AsyncWriteExt},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
@@ -104,27 +104,23 @@ fn icon_google(domain: String) -> Option<Redirect> {
|
||||
}
|
||||
|
||||
#[get("/<domain>/icon.png")]
|
||||
fn icon_internal(domain: String) -> Cached<Content<Vec<u8>>> {
|
||||
async fn icon_internal(domain: String) -> Cached<(ContentType, Vec<u8>)> {
|
||||
const FALLBACK_ICON: &[u8] = include_bytes!("../static/images/fallback-icon.png");
|
||||
|
||||
if !is_valid_domain(&domain) {
|
||||
warn!("Invalid domain: {}", domain);
|
||||
return Cached::ttl(
|
||||
Content(ContentType::new("image", "png"), FALLBACK_ICON.to_vec()),
|
||||
(ContentType::new("image", "png"), FALLBACK_ICON.to_vec()),
|
||||
CONFIG.icon_cache_negttl(),
|
||||
true,
|
||||
);
|
||||
}
|
||||
|
||||
match get_icon(&domain) {
|
||||
match get_icon(&domain).await {
|
||||
Some((icon, icon_type)) => {
|
||||
Cached::ttl(Content(ContentType::new("image", icon_type), icon), CONFIG.icon_cache_ttl(), true)
|
||||
Cached::ttl((ContentType::new("image", icon_type), icon), CONFIG.icon_cache_ttl(), true)
|
||||
}
|
||||
_ => Cached::ttl(
|
||||
Content(ContentType::new("image", "png"), FALLBACK_ICON.to_vec()),
|
||||
CONFIG.icon_cache_negttl(),
|
||||
true,
|
||||
),
|
||||
_ => Cached::ttl((ContentType::new("image", "png"), FALLBACK_ICON.to_vec()), CONFIG.icon_cache_negttl(), true),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -317,15 +313,15 @@ fn is_domain_blacklisted(domain: &str) -> bool {
|
||||
is_blacklisted
|
||||
}
|
||||
|
||||
fn get_icon(domain: &str) -> Option<(Vec<u8>, String)> {
|
||||
async fn get_icon(domain: &str) -> Option<(Vec<u8>, String)> {
|
||||
let path = format!("{}/{}.png", CONFIG.icon_cache_folder(), domain);
|
||||
|
||||
// Check for expiration of negatively cached copy
|
||||
if icon_is_negcached(&path) {
|
||||
if icon_is_negcached(&path).await {
|
||||
return None;
|
||||
}
|
||||
|
||||
if let Some(icon) = get_cached_icon(&path) {
|
||||
if let Some(icon) = get_cached_icon(&path).await {
|
||||
let icon_type = match get_icon_type(&icon) {
|
||||
Some(x) => x,
|
||||
_ => "x-icon",
|
||||
@@ -338,31 +334,31 @@ fn get_icon(domain: &str) -> Option<(Vec<u8>, String)> {
|
||||
}
|
||||
|
||||
// Get the icon, or None in case of error
|
||||
match download_icon(domain) {
|
||||
match download_icon(domain).await {
|
||||
Ok((icon, icon_type)) => {
|
||||
save_icon(&path, &icon);
|
||||
Some((icon, icon_type.unwrap_or("x-icon").to_string()))
|
||||
save_icon(&path, &icon).await;
|
||||
Some((icon.to_vec(), icon_type.unwrap_or("x-icon").to_string()))
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Unable to download icon: {:?}", e);
|
||||
let miss_indicator = path + ".miss";
|
||||
save_icon(&miss_indicator, &[]);
|
||||
save_icon(&miss_indicator, &[]).await;
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn get_cached_icon(path: &str) -> Option<Vec<u8>> {
|
||||
async fn get_cached_icon(path: &str) -> Option<Vec<u8>> {
|
||||
// Check for expiration of successfully cached copy
|
||||
if icon_is_expired(path) {
|
||||
if icon_is_expired(path).await {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Try to read the cached icon, and return it if it exists
|
||||
if let Ok(mut f) = File::open(path) {
|
||||
if let Ok(mut f) = File::open(path).await {
|
||||
let mut buffer = Vec::new();
|
||||
|
||||
if f.read_to_end(&mut buffer).is_ok() {
|
||||
if f.read_to_end(&mut buffer).await.is_ok() {
|
||||
return Some(buffer);
|
||||
}
|
||||
}
|
||||
@@ -370,22 +366,22 @@ fn get_cached_icon(path: &str) -> Option<Vec<u8>> {
|
||||
None
|
||||
}
|
||||
|
||||
fn file_is_expired(path: &str, ttl: u64) -> Result<bool, Error> {
|
||||
let meta = symlink_metadata(path)?;
|
||||
async fn file_is_expired(path: &str, ttl: u64) -> Result<bool, Error> {
|
||||
let meta = symlink_metadata(path).await?;
|
||||
let modified = meta.modified()?;
|
||||
let age = SystemTime::now().duration_since(modified)?;
|
||||
|
||||
Ok(ttl > 0 && ttl <= age.as_secs())
|
||||
}
|
||||
|
||||
fn icon_is_negcached(path: &str) -> bool {
|
||||
async fn icon_is_negcached(path: &str) -> bool {
|
||||
let miss_indicator = path.to_owned() + ".miss";
|
||||
let expired = file_is_expired(&miss_indicator, CONFIG.icon_cache_negttl());
|
||||
let expired = file_is_expired(&miss_indicator, CONFIG.icon_cache_negttl()).await;
|
||||
|
||||
match expired {
|
||||
// No longer negatively cached, drop the marker
|
||||
Ok(true) => {
|
||||
if let Err(e) = remove_file(&miss_indicator) {
|
||||
if let Err(e) = remove_file(&miss_indicator).await {
|
||||
error!("Could not remove negative cache indicator for icon {:?}: {:?}", path, e);
|
||||
}
|
||||
false
|
||||
@@ -397,8 +393,8 @@ fn icon_is_negcached(path: &str) -> bool {
|
||||
}
|
||||
}
|
||||
|
||||
fn icon_is_expired(path: &str) -> bool {
|
||||
let expired = file_is_expired(path, CONFIG.icon_cache_ttl());
|
||||
async fn icon_is_expired(path: &str) -> bool {
|
||||
let expired = file_is_expired(path, CONFIG.icon_cache_ttl()).await;
|
||||
expired.unwrap_or(true)
|
||||
}
|
||||
|
||||
@@ -521,13 +517,13 @@ struct IconUrlResult {
|
||||
/// let icon_result = get_icon_url("github.com")?;
|
||||
/// let icon_result = get_icon_url("vaultwarden.discourse.group")?;
|
||||
/// ```
|
||||
fn get_icon_url(domain: &str) -> Result<IconUrlResult, Error> {
|
||||
async fn get_icon_url(domain: &str) -> Result<IconUrlResult, Error> {
|
||||
// Default URL with secure and insecure schemes
|
||||
let ssldomain = format!("https://{}", domain);
|
||||
let httpdomain = format!("http://{}", domain);
|
||||
|
||||
// First check the domain as given during the request for both HTTPS and HTTP.
|
||||
let resp = match get_page(&ssldomain).or_else(|_| get_page(&httpdomain)) {
|
||||
let resp = match get_page(&ssldomain).or_else(|_| get_page(&httpdomain)).await {
|
||||
Ok(c) => Ok(c),
|
||||
Err(e) => {
|
||||
let mut sub_resp = Err(e);
|
||||
@@ -546,7 +542,7 @@ fn get_icon_url(domain: &str) -> Result<IconUrlResult, Error> {
|
||||
let httpbase = format!("http://{}", base_domain);
|
||||
debug!("[get_icon_url]: Trying without subdomains '{}'", base_domain);
|
||||
|
||||
sub_resp = get_page(&sslbase).or_else(|_| get_page(&httpbase));
|
||||
sub_resp = get_page(&sslbase).or_else(|_| get_page(&httpbase)).await;
|
||||
}
|
||||
|
||||
// When the domain is not an IP, and has less then 2 dots, try to add www. infront of it.
|
||||
@@ -557,7 +553,7 @@ fn get_icon_url(domain: &str) -> Result<IconUrlResult, Error> {
|
||||
let httpwww = format!("http://{}", www_domain);
|
||||
debug!("[get_icon_url]: Trying with www. prefix '{}'", www_domain);
|
||||
|
||||
sub_resp = get_page(&sslwww).or_else(|_| get_page(&httpwww));
|
||||
sub_resp = get_page(&sslwww).or_else(|_| get_page(&httpwww)).await;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -581,7 +577,7 @@ fn get_icon_url(domain: &str) -> Result<IconUrlResult, Error> {
|
||||
iconlist.push(Icon::new(35, String::from(url.join("/favicon.ico").unwrap())));
|
||||
|
||||
// 384KB should be more than enough for the HTML, though as we only really need the HTML header.
|
||||
let mut limited_reader = content.take(384 * 1024);
|
||||
let mut limited_reader = stream_to_bytes_limit(content, 384 * 1024).await?.reader();
|
||||
|
||||
use html5ever::tendril::TendrilSink;
|
||||
let dom = html5ever::parse_document(markup5ever_rcdom::RcDom::default(), Default::default())
|
||||
@@ -607,11 +603,11 @@ fn get_icon_url(domain: &str) -> Result<IconUrlResult, Error> {
|
||||
})
|
||||
}
|
||||
|
||||
fn get_page(url: &str) -> Result<Response, Error> {
|
||||
get_page_with_referer(url, "")
|
||||
async fn get_page(url: &str) -> Result<Response, Error> {
|
||||
get_page_with_referer(url, "").await
|
||||
}
|
||||
|
||||
fn get_page_with_referer(url: &str, referer: &str) -> Result<Response, Error> {
|
||||
async fn get_page_with_referer(url: &str, referer: &str) -> Result<Response, Error> {
|
||||
if is_domain_blacklisted(url::Url::parse(url).unwrap().host_str().unwrap_or_default()) {
|
||||
warn!("Favicon '{}' resolves to a blacklisted domain or IP!", url);
|
||||
}
|
||||
@@ -621,7 +617,7 @@ fn get_page_with_referer(url: &str, referer: &str) -> Result<Response, Error> {
|
||||
client = client.header("Referer", referer)
|
||||
}
|
||||
|
||||
match client.send() {
|
||||
match client.send().await {
|
||||
Ok(c) => c.error_for_status().map_err(Into::into),
|
||||
Err(e) => err_silent!(format!("{}", e)),
|
||||
}
|
||||
@@ -706,14 +702,14 @@ fn parse_sizes(sizes: Option<&str>) -> (u16, u16) {
|
||||
(width, height)
|
||||
}
|
||||
|
||||
fn download_icon(domain: &str) -> Result<(Vec<u8>, Option<&str>), Error> {
|
||||
async fn download_icon(domain: &str) -> Result<(Bytes, Option<&str>), Error> {
|
||||
if is_domain_blacklisted(domain) {
|
||||
err_silent!("Domain is blacklisted", domain)
|
||||
}
|
||||
|
||||
let icon_result = get_icon_url(domain)?;
|
||||
let icon_result = get_icon_url(domain).await?;
|
||||
|
||||
let mut buffer = Vec::new();
|
||||
let mut buffer = Bytes::new();
|
||||
let mut icon_type: Option<&str> = None;
|
||||
|
||||
use data_url::DataUrl;
|
||||
@@ -722,8 +718,12 @@ fn download_icon(domain: &str) -> Result<(Vec<u8>, Option<&str>), Error> {
|
||||
if icon.href.starts_with("data:image") {
|
||||
let datauri = DataUrl::process(&icon.href).unwrap();
|
||||
// Check if we are able to decode the data uri
|
||||
match datauri.decode_to_vec() {
|
||||
Ok((body, _fragment)) => {
|
||||
let mut body = BytesMut::new();
|
||||
match datauri.decode::<_, ()>(|bytes| {
|
||||
body.extend_from_slice(bytes);
|
||||
Ok(())
|
||||
}) {
|
||||
Ok(_) => {
|
||||
// Also check if the size is atleast 67 bytes, which seems to be the smallest png i could create
|
||||
if body.len() >= 67 {
|
||||
// Check if the icon type is allowed, else try an icon from the list.
|
||||
@@ -733,17 +733,17 @@ fn download_icon(domain: &str) -> Result<(Vec<u8>, Option<&str>), Error> {
|
||||
continue;
|
||||
}
|
||||
info!("Extracted icon from data:image uri for {}", domain);
|
||||
buffer = body;
|
||||
buffer = body.freeze();
|
||||
break;
|
||||
}
|
||||
}
|
||||
_ => debug!("Extracted icon from data:image uri is invalid"),
|
||||
};
|
||||
} else {
|
||||
match get_page_with_referer(&icon.href, &icon_result.referer) {
|
||||
Ok(mut res) => {
|
||||
res.copy_to(&mut buffer)?;
|
||||
// Check if the icon type is allowed, else try an icon from the list.
|
||||
match get_page_with_referer(&icon.href, &icon_result.referer).await {
|
||||
Ok(res) => {
|
||||
buffer = stream_to_bytes_limit(res, 512 * 1024).await?; // 512 KB for each icon max
|
||||
// Check if the icon type is allowed, else try an icon from the list.
|
||||
icon_type = get_icon_type(&buffer);
|
||||
if icon_type.is_none() {
|
||||
buffer.clear();
|
||||
@@ -765,13 +765,13 @@ fn download_icon(domain: &str) -> Result<(Vec<u8>, Option<&str>), Error> {
|
||||
Ok((buffer, icon_type))
|
||||
}
|
||||
|
||||
fn save_icon(path: &str, icon: &[u8]) {
|
||||
match File::create(path) {
|
||||
async fn save_icon(path: &str, icon: &[u8]) {
|
||||
match File::create(path).await {
|
||||
Ok(mut f) => {
|
||||
f.write_all(icon).expect("Error writing icon file");
|
||||
f.write_all(icon).await.expect("Error writing icon file");
|
||||
}
|
||||
Err(ref e) if e.kind() == std::io::ErrorKind::NotFound => {
|
||||
create_dir_all(&CONFIG.icon_cache_folder()).expect("Error creating icon cache folder");
|
||||
create_dir_all(&CONFIG.icon_cache_folder()).await.expect("Error creating icon cache folder");
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Unable to save icon: {:?}", e);
|
||||
@@ -820,8 +820,6 @@ impl reqwest::cookie::CookieStore for Jar {
|
||||
}
|
||||
|
||||
fn cookies(&self, url: &url::Url) -> Option<header::HeaderValue> {
|
||||
use bytes::Bytes;
|
||||
|
||||
let cookie_store = self.0.read().unwrap();
|
||||
let s = cookie_store
|
||||
.get_request_values(url)
|
||||
@@ -836,3 +834,12 @@ impl reqwest::cookie::CookieStore for Jar {
|
||||
header::HeaderValue::from_maybe_shared(Bytes::from(s)).ok()
|
||||
}
|
||||
}
|
||||
|
||||
async fn stream_to_bytes_limit(res: Response, max_size: usize) -> Result<Bytes, reqwest::Error> {
|
||||
let mut stream = res.bytes_stream().take(max_size);
|
||||
let mut buf = BytesMut::new();
|
||||
while let Some(chunk) = stream.next().await {
|
||||
buf.extend(chunk?);
|
||||
}
|
||||
Ok(buf.freeze())
|
||||
}
|
||||
|
Reference in New Issue
Block a user