Update to rocket 0.5 and made code async, missing updating all db calls, that are currently blocking

This commit is contained in:
Daniel García
2021-11-07 18:53:39 +01:00
parent 08f0de7b46
commit 2d5f172e77
30 changed files with 1314 additions and 1028 deletions

View File

@@ -1,19 +1,19 @@
use std::{
collections::HashMap,
fs::{create_dir_all, remove_file, symlink_metadata, File},
io::prelude::*,
net::{IpAddr, ToSocketAddrs},
sync::{Arc, RwLock},
time::{Duration, SystemTime},
};
use bytes::{Buf, Bytes, BytesMut};
use futures::{stream::StreamExt, TryFutureExt};
use once_cell::sync::Lazy;
use regex::Regex;
use reqwest::{blocking::Client, blocking::Response, header};
use rocket::{
http::ContentType,
response::{Content, Redirect},
Route,
use reqwest::{header, Client, Response};
use rocket::{http::ContentType, response::Redirect, Route};
use tokio::{
fs::{create_dir_all, remove_file, symlink_metadata, File},
io::{AsyncReadExt, AsyncWriteExt},
};
use crate::{
@@ -104,27 +104,23 @@ fn icon_google(domain: String) -> Option<Redirect> {
}
#[get("/<domain>/icon.png")]
fn icon_internal(domain: String) -> Cached<Content<Vec<u8>>> {
async fn icon_internal(domain: String) -> Cached<(ContentType, Vec<u8>)> {
const FALLBACK_ICON: &[u8] = include_bytes!("../static/images/fallback-icon.png");
if !is_valid_domain(&domain) {
warn!("Invalid domain: {}", domain);
return Cached::ttl(
Content(ContentType::new("image", "png"), FALLBACK_ICON.to_vec()),
(ContentType::new("image", "png"), FALLBACK_ICON.to_vec()),
CONFIG.icon_cache_negttl(),
true,
);
}
match get_icon(&domain) {
match get_icon(&domain).await {
Some((icon, icon_type)) => {
Cached::ttl(Content(ContentType::new("image", icon_type), icon), CONFIG.icon_cache_ttl(), true)
Cached::ttl((ContentType::new("image", icon_type), icon), CONFIG.icon_cache_ttl(), true)
}
_ => Cached::ttl(
Content(ContentType::new("image", "png"), FALLBACK_ICON.to_vec()),
CONFIG.icon_cache_negttl(),
true,
),
_ => Cached::ttl((ContentType::new("image", "png"), FALLBACK_ICON.to_vec()), CONFIG.icon_cache_negttl(), true),
}
}
@@ -317,15 +313,15 @@ fn is_domain_blacklisted(domain: &str) -> bool {
is_blacklisted
}
fn get_icon(domain: &str) -> Option<(Vec<u8>, String)> {
async fn get_icon(domain: &str) -> Option<(Vec<u8>, String)> {
let path = format!("{}/{}.png", CONFIG.icon_cache_folder(), domain);
// Check for expiration of negatively cached copy
if icon_is_negcached(&path) {
if icon_is_negcached(&path).await {
return None;
}
if let Some(icon) = get_cached_icon(&path) {
if let Some(icon) = get_cached_icon(&path).await {
let icon_type = match get_icon_type(&icon) {
Some(x) => x,
_ => "x-icon",
@@ -338,31 +334,31 @@ fn get_icon(domain: &str) -> Option<(Vec<u8>, String)> {
}
// Get the icon, or None in case of error
match download_icon(domain) {
match download_icon(domain).await {
Ok((icon, icon_type)) => {
save_icon(&path, &icon);
Some((icon, icon_type.unwrap_or("x-icon").to_string()))
save_icon(&path, &icon).await;
Some((icon.to_vec(), icon_type.unwrap_or("x-icon").to_string()))
}
Err(e) => {
warn!("Unable to download icon: {:?}", e);
let miss_indicator = path + ".miss";
save_icon(&miss_indicator, &[]);
save_icon(&miss_indicator, &[]).await;
None
}
}
}
fn get_cached_icon(path: &str) -> Option<Vec<u8>> {
async fn get_cached_icon(path: &str) -> Option<Vec<u8>> {
// Check for expiration of successfully cached copy
if icon_is_expired(path) {
if icon_is_expired(path).await {
return None;
}
// Try to read the cached icon, and return it if it exists
if let Ok(mut f) = File::open(path) {
if let Ok(mut f) = File::open(path).await {
let mut buffer = Vec::new();
if f.read_to_end(&mut buffer).is_ok() {
if f.read_to_end(&mut buffer).await.is_ok() {
return Some(buffer);
}
}
@@ -370,22 +366,22 @@ fn get_cached_icon(path: &str) -> Option<Vec<u8>> {
None
}
fn file_is_expired(path: &str, ttl: u64) -> Result<bool, Error> {
let meta = symlink_metadata(path)?;
async fn file_is_expired(path: &str, ttl: u64) -> Result<bool, Error> {
let meta = symlink_metadata(path).await?;
let modified = meta.modified()?;
let age = SystemTime::now().duration_since(modified)?;
Ok(ttl > 0 && ttl <= age.as_secs())
}
fn icon_is_negcached(path: &str) -> bool {
async fn icon_is_negcached(path: &str) -> bool {
let miss_indicator = path.to_owned() + ".miss";
let expired = file_is_expired(&miss_indicator, CONFIG.icon_cache_negttl());
let expired = file_is_expired(&miss_indicator, CONFIG.icon_cache_negttl()).await;
match expired {
// No longer negatively cached, drop the marker
Ok(true) => {
if let Err(e) = remove_file(&miss_indicator) {
if let Err(e) = remove_file(&miss_indicator).await {
error!("Could not remove negative cache indicator for icon {:?}: {:?}", path, e);
}
false
@@ -397,8 +393,8 @@ fn icon_is_negcached(path: &str) -> bool {
}
}
fn icon_is_expired(path: &str) -> bool {
let expired = file_is_expired(path, CONFIG.icon_cache_ttl());
async fn icon_is_expired(path: &str) -> bool {
let expired = file_is_expired(path, CONFIG.icon_cache_ttl()).await;
expired.unwrap_or(true)
}
@@ -521,13 +517,13 @@ struct IconUrlResult {
/// let icon_result = get_icon_url("github.com")?;
/// let icon_result = get_icon_url("vaultwarden.discourse.group")?;
/// ```
fn get_icon_url(domain: &str) -> Result<IconUrlResult, Error> {
async fn get_icon_url(domain: &str) -> Result<IconUrlResult, Error> {
// Default URL with secure and insecure schemes
let ssldomain = format!("https://{}", domain);
let httpdomain = format!("http://{}", domain);
// First check the domain as given during the request for both HTTPS and HTTP.
let resp = match get_page(&ssldomain).or_else(|_| get_page(&httpdomain)) {
let resp = match get_page(&ssldomain).or_else(|_| get_page(&httpdomain)).await {
Ok(c) => Ok(c),
Err(e) => {
let mut sub_resp = Err(e);
@@ -546,7 +542,7 @@ fn get_icon_url(domain: &str) -> Result<IconUrlResult, Error> {
let httpbase = format!("http://{}", base_domain);
debug!("[get_icon_url]: Trying without subdomains '{}'", base_domain);
sub_resp = get_page(&sslbase).or_else(|_| get_page(&httpbase));
sub_resp = get_page(&sslbase).or_else(|_| get_page(&httpbase)).await;
}
// When the domain is not an IP, and has less then 2 dots, try to add www. infront of it.
@@ -557,7 +553,7 @@ fn get_icon_url(domain: &str) -> Result<IconUrlResult, Error> {
let httpwww = format!("http://{}", www_domain);
debug!("[get_icon_url]: Trying with www. prefix '{}'", www_domain);
sub_resp = get_page(&sslwww).or_else(|_| get_page(&httpwww));
sub_resp = get_page(&sslwww).or_else(|_| get_page(&httpwww)).await;
}
}
@@ -581,7 +577,7 @@ fn get_icon_url(domain: &str) -> Result<IconUrlResult, Error> {
iconlist.push(Icon::new(35, String::from(url.join("/favicon.ico").unwrap())));
// 384KB should be more than enough for the HTML, though as we only really need the HTML header.
let mut limited_reader = content.take(384 * 1024);
let mut limited_reader = stream_to_bytes_limit(content, 384 * 1024).await?.reader();
use html5ever::tendril::TendrilSink;
let dom = html5ever::parse_document(markup5ever_rcdom::RcDom::default(), Default::default())
@@ -607,11 +603,11 @@ fn get_icon_url(domain: &str) -> Result<IconUrlResult, Error> {
})
}
fn get_page(url: &str) -> Result<Response, Error> {
get_page_with_referer(url, "")
async fn get_page(url: &str) -> Result<Response, Error> {
get_page_with_referer(url, "").await
}
fn get_page_with_referer(url: &str, referer: &str) -> Result<Response, Error> {
async fn get_page_with_referer(url: &str, referer: &str) -> Result<Response, Error> {
if is_domain_blacklisted(url::Url::parse(url).unwrap().host_str().unwrap_or_default()) {
warn!("Favicon '{}' resolves to a blacklisted domain or IP!", url);
}
@@ -621,7 +617,7 @@ fn get_page_with_referer(url: &str, referer: &str) -> Result<Response, Error> {
client = client.header("Referer", referer)
}
match client.send() {
match client.send().await {
Ok(c) => c.error_for_status().map_err(Into::into),
Err(e) => err_silent!(format!("{}", e)),
}
@@ -706,14 +702,14 @@ fn parse_sizes(sizes: Option<&str>) -> (u16, u16) {
(width, height)
}
fn download_icon(domain: &str) -> Result<(Vec<u8>, Option<&str>), Error> {
async fn download_icon(domain: &str) -> Result<(Bytes, Option<&str>), Error> {
if is_domain_blacklisted(domain) {
err_silent!("Domain is blacklisted", domain)
}
let icon_result = get_icon_url(domain)?;
let icon_result = get_icon_url(domain).await?;
let mut buffer = Vec::new();
let mut buffer = Bytes::new();
let mut icon_type: Option<&str> = None;
use data_url::DataUrl;
@@ -722,8 +718,12 @@ fn download_icon(domain: &str) -> Result<(Vec<u8>, Option<&str>), Error> {
if icon.href.starts_with("data:image") {
let datauri = DataUrl::process(&icon.href).unwrap();
// Check if we are able to decode the data uri
match datauri.decode_to_vec() {
Ok((body, _fragment)) => {
let mut body = BytesMut::new();
match datauri.decode::<_, ()>(|bytes| {
body.extend_from_slice(bytes);
Ok(())
}) {
Ok(_) => {
// Also check if the size is atleast 67 bytes, which seems to be the smallest png i could create
if body.len() >= 67 {
// Check if the icon type is allowed, else try an icon from the list.
@@ -733,17 +733,17 @@ fn download_icon(domain: &str) -> Result<(Vec<u8>, Option<&str>), Error> {
continue;
}
info!("Extracted icon from data:image uri for {}", domain);
buffer = body;
buffer = body.freeze();
break;
}
}
_ => debug!("Extracted icon from data:image uri is invalid"),
};
} else {
match get_page_with_referer(&icon.href, &icon_result.referer) {
Ok(mut res) => {
res.copy_to(&mut buffer)?;
// Check if the icon type is allowed, else try an icon from the list.
match get_page_with_referer(&icon.href, &icon_result.referer).await {
Ok(res) => {
buffer = stream_to_bytes_limit(res, 512 * 1024).await?; // 512 KB for each icon max
// Check if the icon type is allowed, else try an icon from the list.
icon_type = get_icon_type(&buffer);
if icon_type.is_none() {
buffer.clear();
@@ -765,13 +765,13 @@ fn download_icon(domain: &str) -> Result<(Vec<u8>, Option<&str>), Error> {
Ok((buffer, icon_type))
}
fn save_icon(path: &str, icon: &[u8]) {
match File::create(path) {
async fn save_icon(path: &str, icon: &[u8]) {
match File::create(path).await {
Ok(mut f) => {
f.write_all(icon).expect("Error writing icon file");
f.write_all(icon).await.expect("Error writing icon file");
}
Err(ref e) if e.kind() == std::io::ErrorKind::NotFound => {
create_dir_all(&CONFIG.icon_cache_folder()).expect("Error creating icon cache folder");
create_dir_all(&CONFIG.icon_cache_folder()).await.expect("Error creating icon cache folder");
}
Err(e) => {
warn!("Unable to save icon: {:?}", e);
@@ -820,8 +820,6 @@ impl reqwest::cookie::CookieStore for Jar {
}
fn cookies(&self, url: &url::Url) -> Option<header::HeaderValue> {
use bytes::Bytes;
let cookie_store = self.0.read().unwrap();
let s = cookie_store
.get_request_values(url)
@@ -836,3 +834,12 @@ impl reqwest::cookie::CookieStore for Jar {
header::HeaderValue::from_maybe_shared(Bytes::from(s)).ok()
}
}
async fn stream_to_bytes_limit(res: Response, max_size: usize) -> Result<Bytes, reqwest::Error> {
let mut stream = res.bytes_stream().take(max_size);
let mut buf = BytesMut::new();
while let Some(chunk) = stream.next().await {
buf.extend(chunk?);
}
Ok(buf.freeze())
}