Compare commits

..

1 Commits

Author SHA1 Message Date
Mathijs van Veluw 0a77776a8b Revert "Disable deployments for release env (#7033)"
This reverts commit 8f0e99b875.
2026-04-12 16:57:48 +02:00
107 changed files with 4142 additions and 5481 deletions
+4 -5
View File
@@ -50,11 +50,10 @@
#########################
## Database URL
## When using SQLite, this should use the sqlite:// scheme followed by the path
## to the DB file. It defaults to sqlite://%DATA_FOLDER%/db.sqlite3.
## Bare paths without the sqlite:// scheme are supported for backwards compatibility,
## but only if the database file already exists.
# DATABASE_URL=sqlite://data/db.sqlite3
## When using SQLite, this is the path to the DB file, and it defaults to
## %DATA_FOLDER%/db.sqlite3. If DATA_FOLDER is set to an external location, this
## must be set to a local sqlite3 file path.
# DATABASE_URL=data/db.sqlite3
## When using MySQL, specify an appropriate connection URI.
## Details: https://docs.diesel.rs/2.1.x/diesel/mysql/struct.MysqlConnection.html
# DATABASE_URL=mysql://user:password@host[:port]/database_name
+1
View File
@@ -1,2 +1,3 @@
# Ignore vendored scripts in GitHub stats
src/static/scripts/* linguist-vendored
+2 -6
View File
@@ -38,9 +38,7 @@ jobs:
docker-build:
name: Build Vaultwarden containers
if: ${{ github.repository == 'dani-garcia/vaultwarden' }}
environment:
name: release
deployment: false
environment: release
permissions:
packages: write # Needed to upload packages and artifacts
contents: read
@@ -251,9 +249,7 @@ jobs:
name: Merge manifests
runs-on: ubuntu-latest
needs: docker-build
environment:
name: release
deployment: false
environment: release
permissions:
packages: write # Needed to upload packages and artifacts
attestations: write # Needed to generate an artifact attestation for a build
+2 -2
View File
@@ -38,7 +38,7 @@ jobs:
persist-credentials: false
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@ed142fd0673e97e23eac54620cfb913e5ce36c25 # v0.36.0
uses: aquasecurity/trivy-action@57a97c7e7821a5776cebc9bb87c984fa69cba8f1 # v0.35.0
env:
TRIVY_DB_REPOSITORY: docker.io/aquasec/trivy-db:2,public.ecr.aws/aquasecurity/trivy-db:2,ghcr.io/aquasecurity/trivy-db:2
TRIVY_JAVA_DB_REPOSITORY: docker.io/aquasec/trivy-java-db:1,public.ecr.aws/aquasecurity/trivy-java-db:1,ghcr.io/aquasecurity/trivy-java-db:1
@@ -50,6 +50,6 @@ jobs:
severity: CRITICAL,HIGH
- name: Upload Trivy scan results to GitHub Security tab
uses: github/codeql-action/upload-sarif@9e0d7b8d25671d64c341c19c0152d693099fb5ba # v4.35.5
uses: github/codeql-action/upload-sarif@c10b8064de6f491fea524254123dbe5e09572f13 # v4.35.1
with:
sarif_file: 'trivy-results.sarif'
+1 -1
View File
@@ -23,4 +23,4 @@ jobs:
# When this version is updated, do not forget to update this in `.pre-commit-config.yaml` too
- name: Spell Check Repo
uses: crate-ci/typos@5374cbf686e897b15713110e233094e2874de7ef # v1.46.1
uses: crate-ci/typos@02ea592e44b3a53c302f697cddca7641cd051c3d # v1.45.0
+1 -1
View File
@@ -24,7 +24,7 @@ jobs:
persist-credentials: false
- name: Run zizmor
uses: zizmorcore/zizmor-action@b572f7b1a1c2d41efaab43d504f68d215c3cd727 # v0.5.4
uses: zizmorcore/zizmor-action@71321a20a9ded102f6e9ce5718a2fcec2c4f70d8 # v0.5.2
with:
# intentionally not scanning the entire repository,
# since it contains integration tests.
+53 -55
View File
@@ -1,60 +1,58 @@
---
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: 3e8a8703264a2f4a69428a0aa4dcb512790b2c8c # v6.0.0
hooks:
- id: check-yaml
- id: check-json
- id: check-toml
- id: mixed-line-ending
args: [ "--fix=no" ]
- id: end-of-file-fixer
exclude: "(.*js$|.*css$)"
- id: check-case-conflict
- id: check-merge-conflict
- id: detect-private-key
- id: check-symlinks
- id: forbid-submodules
# When this version is updated, do not forget to update this in `.github/workflows/typos.yaml` too
- repo: https://github.com/crate-ci/typos
rev: 5374cbf686e897b15713110e233094e2874de7ef # v1.46.1
- id: check-yaml
- id: check-json
- id: check-toml
- id: mixed-line-ending
args: ["--fix=no"]
- id: end-of-file-fixer
exclude: "(.*js$|.*css$)"
- id: check-case-conflict
- id: check-merge-conflict
- id: detect-private-key
- id: check-symlinks
- id: forbid-submodules
- repo: local
hooks:
- id: typos
- repo: local
hooks:
- id: fmt
name: fmt
description: Format files with cargo fmt.
entry: cargo fmt
language: system
always_run: true
pass_filenames: false
args: [ "--", "--check" ]
- id: cargo-test
name: cargo test
description: Test the package for errors.
entry: cargo test
language: system
args: [ "--features", "sqlite,mysql,postgresql", "--" ]
types_or: [ rust, file ]
files: (Cargo.toml|Cargo.lock|rust-toolchain.toml|rustfmt.toml|.*\.rs$)
pass_filenames: false
- id: cargo-clippy
name: cargo clippy
description: Lint Rust sources
entry: cargo clippy
language: system
args: [ "--features", "sqlite,mysql,postgresql", "--", "-D", "warnings" ]
types_or: [ rust, file ]
files: (Cargo.toml|Cargo.lock|rust-toolchain.toml|rustfmt.toml|.*\.rs$)
pass_filenames: false
- id: check-docker-templates
name: check docker templates
description: Check if the Docker templates are updated
language: system
entry: sh
args:
- "-c"
- "cd docker && make"
- id: fmt
name: fmt
description: Format files with cargo fmt.
entry: cargo fmt
language: system
always_run: true
pass_filenames: false
args: ["--", "--check"]
- id: cargo-test
name: cargo test
description: Test the package for errors.
entry: cargo test
language: system
args: ["--features", "sqlite,mysql,postgresql", "--"]
types_or: [rust, file]
files: (Cargo.toml|Cargo.lock|rust-toolchain.toml|rustfmt.toml|.*\.rs$)
pass_filenames: false
- id: cargo-clippy
name: cargo clippy
description: Lint Rust sources
entry: cargo clippy
language: system
args: ["--features", "sqlite,mysql,postgresql", "--", "-D", "warnings"]
types_or: [rust, file]
files: (Cargo.toml|Cargo.lock|rust-toolchain.toml|rustfmt.toml|.*\.rs$)
pass_filenames: false
- id: check-docker-templates
name: check docker templates
description: Check if the Docker templates are updated
language: system
entry: sh
args:
- "-c"
- "cd docker && make"
# When this version is updated, do not forget to update this in `.github/workflows/typos.yaml` too
- repo: https://github.com/crate-ci/typos
rev: 02ea592e44b3a53c302f697cddca7641cd051c3d # v1.45.0
hooks:
- id: typos
-2
View File
@@ -23,6 +23,4 @@ extend-ignore-re = [
# https://github.com/bitwarden/server/blob/dff9f1cf538198819911cf2c20f8cda3307701c5/src/Notifications/HubHelpers.cs#L86
# https://github.com/bitwarden/clients/blob/9612a4ac45063e372a6fbe87eb253c7cb3c588fb/libs/common/src/auth/services/anonymous-hub.service.ts#L45
"AuthRequestResponseRecieved",
# Ignore Punycode/IDN tests
"xn--.+"
]
Generated
+540 -587
View File
File diff suppressed because it is too large Load Diff
+69 -133
View File
@@ -1,6 +1,6 @@
[workspace.package]
edition = "2024"
rust-version = "1.93.0"
edition = "2021"
rust-version = "1.92.0"
license = "AGPL-3.0-only"
repository = "https://github.com/dani-garcia/vaultwarden"
publish = false
@@ -24,31 +24,20 @@ publish.workspace = true
[features]
default = [
# "sqlite",
# "sqlite_system",
# "mysql",
# "postgresql",
]
# Empty to keep compatibility, prefer to set USE_SYSLOG=true
enable_syslog = []
# Please enable at least one of these DB backends.
mysql = ["diesel/mysql", "diesel_migrations/mysql"]
postgresql = ["diesel/postgres", "diesel_migrations/postgres"]
sqlite_system = ["diesel/sqlite", "diesel_migrations/sqlite"] # Dynamically link SQLite
sqlite = ["sqlite_system", "libsqlite3-sys/bundled"] # Statically link SQLite into the binary instead of dynamically.
sqlite = ["diesel/sqlite", "diesel_migrations/sqlite", "dep:libsqlite3-sys"]
# Enable to use a vendored and statically linked openssl
vendored_openssl = ["openssl/vendored"]
# Enable MiMalloc memory allocator to replace the default malloc
# This can improve performance for Alpine builds
enable_mimalloc = ["dep:mimalloc"]
s3 = [
"opendal/services-s3",
"dep:aws-config",
"dep:aws-credential-types",
"dep:aws-smithy-runtime-api",
"dep:http",
"dep:reqsign-aws-v4",
"dep:reqsign-core",
]
s3 = ["opendal/services-s3", "dep:aws-config", "dep:aws-credential-types", "dep:aws-smithy-runtime-api", "dep:anyhow", "dep:http", "dep:reqsign"]
# OIDC specific features
oidc-accept-rfc3339-timestamps = ["openidconnect/accept-rfc3339-timestamps"]
@@ -68,8 +57,7 @@ macros = { path = "./macros" }
# Logging
log = "0.4.29"
fern = { version = "0.7.1", features = ["syslog-7", "reopen-1"] }
# We need the `log` feature for `tracing` to enable logging for several crates to work, like lettre or webauthn-rs
tracing = { version = "0.1.44", features = ["log"] }
tracing = { version = "0.1.44", features = ["log"] } # Needed to have lettre and webauthn-rs trace logging to work
# A `dotenv` implementation for Rust
dotenvy = { version = "0.15.7", default-features = false }
@@ -80,8 +68,8 @@ num-derive = "0.4.2"
bigdecimal = "0.4.10"
# Web framework
rocket = { version = "0.5.1", default-features = false, features = ["json", "tls"] }
rocket_ws = { version = "0.1.1" }
rocket = { version = "0.5.1", features = ["tls", "json"], default-features = false }
rocket_ws = { version ="0.1.1" }
# WebSockets libraries
rmpv = "1.3.1" # MessagePack library
@@ -91,48 +79,34 @@ dashmap = "6.1.0"
# Async futures
futures = "0.3.32"
tokio = { version = "1.52.3", features = [
"fs",
"io-util",
"net",
"parking_lot",
"rt-multi-thread",
"signal",
"time",
] }
tokio-util = { version = "0.7.18", features = ["compat"] }
tokio = { version = "1.51.1", features = ["rt-multi-thread", "fs", "io-util", "parking_lot", "time", "signal", "net"] }
tokio-util = { version = "0.7.18", features = ["compat"]}
# A generic serialization/deserialization framework
serde = { version = "1.0.228", features = ["derive"] }
serde_json = "1.0.149"
# A safe, extensible ORM and Query builder
diesel = { version = "2.3.9", features = ["chrono", "r2d2", "numeric"] }
diesel_migrations = "2.3.2"
# Currently pinned diesel to v2.3.3 as newer version break MySQL/MariaDB compatibility
diesel = { version = "2.3.7", features = ["chrono", "r2d2", "numeric"] }
diesel_migrations = "2.3.1"
derive_more = { version = "2.1.1", features = [
"as_ref",
"deref",
"display",
"from",
"into",
] }
derive_more = { version = "2.1.1", features = ["from", "into", "as_ref", "deref", "display"] }
diesel-derive-newtype = "2.1.2"
# SQLite, statically bundled unless the `sqlite_system` feature is enabled
libsqlite3-sys = { version = "0.37.0", optional = true }
# Bundled/Static SQLite
libsqlite3-sys = { version = "0.36.0", features = ["bundled"], optional = true }
# Crypto-related libraries
rand = "0.10.1"
ring = "0.17.14"
rustls = { version = "0.23.40", features = ["ring", "std"], default-features = false }
subtle = "2.6.1"
# UUID generation
uuid = { version = "1.23.1", features = ["v4"] }
uuid = { version = "1.23.0", features = ["v4"] }
# Date and time libraries
chrono = { version = "0.4.44", default-features = false, features = ["clock", "serde"] }
chrono = { version = "0.4.44", features = ["clock", "serde"], default-features = false }
chrono-tz = "0.10.4"
time = "0.3.47"
@@ -140,42 +114,29 @@ time = "0.3.47"
job_scheduler_ng = "2.4.0"
# Data encoding library Hex/Base32/Base64
data-encoding = "2.11.0"
data-encoding = "2.10.0"
# JWT library
jsonwebtoken = { version = "10.4.0", default-features = false, features = ["rust_crypto", "use_pem"] }
jsonwebtoken = { version = "10.3.0", features = ["use_pem", "rust_crypto"], default-features = false }
# TOTP library
totp-lite = "2.0.1"
# Yubico Library
yubico = { package = "yubico_ng", version = "0.15.0", default-features = false, features = ["online-tokio"] }
yubico = { package = "yubico_ng", version = "0.14.1", features = ["online-tokio"], default-features = false }
# WebAuthn libraries
# danger-allow-state-serialisation is needed to save the state in the db
# danger-credential-internals is needed to support U2F to Webauthn migration
webauthn-rs = { version = "0.5.5", features = ["danger-allow-state-serialisation", "danger-credential-internals"] }
webauthn-rs-proto = "0.5.5"
webauthn-rs-core = "0.5.5"
webauthn-rs = { version = "0.5.4", features = ["danger-allow-state-serialisation", "danger-credential-internals"] }
webauthn-rs-proto = "0.5.4"
webauthn-rs-core = "0.5.4"
# Handling of URL's for WebAuthn and favicons
url = "2.5.8"
# Email libraries
lettre = { version = "0.11.22", default-features = false, features = [
# Misc
"tracing",
"serde",
"builder",
"hostname",
# TLS/Security
"ring",
"rustls-native-certs",
"tokio1-rustls",
# Transport
"smtp-transport",
"sendmail-transport",
] }
lettre = { version = "0.11.21", features = ["smtp-transport", "sendmail-transport", "builder", "serde", "hostname", "tracing", "tokio1-rustls", "ring", "rustls-native-certs"], default-features = false }
percent-encoding = "2.3.2" # URL encoding library used for URL's in the emails
email_address = "0.2.9"
@@ -183,33 +144,12 @@ email_address = "0.2.9"
handlebars = { version = "6.4.0", features = ["dir_source"] }
# HTTP client (Used for favicons, version check, DUO and HIBP API)
reqwest = { version = "0.13.3", default-features = false, features = [
# Misc
"charset",
"cookies",
"http2",
"json",
"form",
"rustls-no-provider",
"stream",
# Compression
"brotli",
"deflate",
"gzip",
"zstd",
# Proxy
"socks",
"system-proxy",
] }
hickory-resolver = "0.26.1"
reqwest = { version = "0.12.28", features = ["rustls-tls", "rustls-tls-native-roots", "stream", "json", "deflate", "gzip", "brotli", "zstd", "socks", "cookies", "charset", "http2", "system-proxy"], default-features = false}
hickory-resolver = "0.25.2"
# Favicon extraction libraries
html5gum = "0.8.3"
regex = { version = "1.12.3", default-features = false, features = [
"perf",
"std",
"unicode-perl",
] }
regex = { version = "1.12.3", features = ["std", "perf", "unicode-perl"], default-features = false }
data-url = "0.3.2"
bytes = "1.11.1"
svg-hush = "0.9.6"
@@ -222,17 +162,17 @@ cookie = "0.18.1"
cookie_store = "0.22.1"
# Used by U2F, JWT and PostgreSQL
openssl = "0.10.79"
openssl = "0.10.76"
# CLI argument parsing
pico-args = "0.5.0"
# Macro ident concatenation
pastey = "0.2.2"
pastey = "0.2.1"
governor = "0.10.4"
# OIDC for SSO
openidconnect = { version = "4.0.1", default-features = false }
openidconnect = { version = "4.0.1", features = ["reqwest", "rustls-tls"] }
moka = { version = "0.12.15", features = ["future"] }
# Check client versions for specific features.
@@ -240,7 +180,7 @@ semver = "1.0.28"
# Allow overriding the default memory allocator
# Mainly used for the musl builds, since the default musl malloc is very slow
mimalloc = { version = "0.1.50", optional = true, default-features = false, features = ["secure"] }
mimalloc = { version = "0.1.48", features = ["secure"], default-features = false, optional = true }
which = "8.0.2"
@@ -248,26 +188,21 @@ which = "8.0.2"
argon2 = "0.5.3"
# Reading a password from the cli for generating the Argon2id ADMIN_TOKEN
rpassword = "7.5.2"
rpassword = "7.4.0"
# Loading a dynamic CSS Stylesheet
grass_compiler = { version = "0.13.4", default-features = false }
# File are accessed through Apache OpenDAL
opendal = { version = "0.56.0", default-features = false, features = ["services-fs"] }
opendal = { version = "0.55.0", features = ["services-fs"], default-features = false }
# For retrieving AWS credentials, including temporary SSO credentials
aws-config = { version = "1.8.16", optional = true, default-features = false, features = [
"behavior-version-latest",
"credentials-process",
"rt-tokio",
"sso",
] }
anyhow = { version = "1.0.102", optional = true }
aws-config = { version = "1.8.15", features = ["behavior-version-latest", "rt-tokio", "credentials-process", "sso"], default-features = false, optional = true }
aws-credential-types = { version = "1.2.14", optional = true }
aws-smithy-runtime-api = { version = "1.12.0", optional = true }
aws-smithy-runtime-api = { version = "1.11.6", optional = true }
http = { version = "1.4.0", optional = true }
reqsign-aws-v4 = { version = "3.0.0", optional = true }
reqsign-core = { version = "3.0.0", optional = true }
reqsign = { version = "0.16.5", optional = true }
# Strip debuginfo from the release builds
# The debug symbols are to provide better panic traces
@@ -327,74 +262,75 @@ unsafe_code = "forbid"
non_ascii_idents = "forbid"
# Deny
warnings = "deny" # Explicitly deny all warnings since we deny all warnings in the end
# Deny lint groups
deprecated_in_future = "deny"
deprecated_safe = { level = "deny", priority = -1 }
future_incompatible = { level = "deny", priority = -1 }
keyword_idents = { level = "deny", priority = -1 }
let_underscore = { level = "deny", priority = -1 }
nonstandard_style = { level = "deny", priority = -1 }
noop_method_call = "deny"
refining_impl_trait = { level = "deny", priority = -1 }
rust_2018_idioms = { level = "deny", priority = -1 }
rust_2021_compatibility = { level = "deny", priority = -1 }
rust_2024_compatibility = { level = "deny", priority = -1 }
unused = { level = "deny", priority = -1 }
# Deny individual lints
closure_returning_async_block = "deny"
deprecated_in_future = "deny"
single_use_lifetimes = "deny"
trivial_casts = "deny"
trivial_numeric_casts = "deny"
unused = { level = "deny", priority = -1 }
unused_import_braces = "deny"
unused_lifetimes = "deny"
unused_qualifications = "deny"
variant_size_differences = "deny"
# Allow the following lints since these cause issues with Rust v1.84.0 or newer
# Building Vaultwarden with Rust v1.85.0 with edition 2024 also works without issues
edition_2024_expr_fragment_specifier = "allow" # Once changed to Rust 2024 this should be removed and macro's should be validated again
if_let_rescope = "allow"
tail_expr_drop_order = "allow"
# https://rust-lang.github.io/rust-clippy/stable/index.html
[workspace.lints.clippy]
# Warn only so you can still use these during development, but not in the final code
# Warn
dbg_macro = "warn"
todo = "warn"
# Ignore/Allow
result_large_err = "allow"
# Warn on these lint group (Some might be warn by default already though)
# Will be denied during CI!
complexity = { level = "warn", priority = -1 }
pedantic = { level = "warn", priority = -1 }
perf = { level = "warn", priority = -1 }
style = { level = "warn", priority = -1 }
suspicious = { level = "warn", priority = -1 }
# Deny individual lints
# Deny
branches_sharing_code = "deny"
case_sensitive_file_extension_comparisons = "deny"
cast_lossless = "deny"
clone_on_ref_ptr = "deny"
equatable_if_let = "deny"
excessive_precision = "deny"
filter_map_next = "deny"
float_cmp_const = "deny"
implicit_clone = "deny"
inefficient_to_string = "deny"
iter_on_empty_collections = "deny"
iter_on_single_items = "deny"
linkedlist = "deny"
macro_use_imports = "deny"
manual_assert = "deny"
manual_instant_elapsed = "deny"
manual_string_new = "deny"
match_wildcard_for_single_variants = "deny"
mem_forget = "deny"
needless_borrow = "deny"
needless_collect = "deny"
needless_continue = "deny"
needless_lifetimes = "deny"
option_option = "deny"
redundant_clone = "deny"
string_add_assign = "deny"
unnecessary_join = "deny"
unnecessary_self_imports = "deny"
unnested_or_patterns = "deny"
unused_async = "deny"
unused_self = "deny"
useless_let_if_seq = "deny"
verbose_file_reads = "deny"
str_to_string = "deny"
# Pedantic Opt-Outs
inline_always = "allow" # We use this sparsely
struct_field_names = "allow" # Noisy and some items are Bitwarden controlled
large_futures = "allow" # Causes a fail in some Rocket macro's, since we experience no issues, allow it
too_many_lines = "allow" # For now, allow this, good to enable in the future and see if we can refactor
unnecessary_wraps = "allow" # Too much false positives because of Rocket integrations
# We do not use these doc items
doc_link_with_quotes = "allow"
doc_markdown = "allow"
missing_errors_doc = "allow"
missing_panics_doc = "allow"
zero_sized_map_values = "deny"
[lints]
workspace = true
+2 -3
View File
@@ -59,9 +59,8 @@ A nearly complete implementation of the Bitwarden Client API is provided, includ
## Usage
> [!IMPORTANT]
> The web-vault requires the use of HTTPS and a secure context for the [Web Crypto API](https://developer.mozilla.org/en-US/docs/Web/API/Web_Crypto_API). <br>
> That means it will only work if you [enable HTTPS](https://github.com/dani-garcia/vaultwarden/wiki/Enabling-HTTPS). <br>
> We also suggest to use a [reverse proxy](https://github.com/dani-garcia/vaultwarden/wiki/Proxy-examples).
> The web-vault requires the use a secure context for the [Web Crypto API](https://developer.mozilla.org/en-US/docs/Web/API/Web_Crypto_API).
> That means it will only work via `http://localhost:8000` (using the port from the example below) or if you [enable HTTPS](https://github.com/dani-garcia/vaultwarden/wiki/Enabling-HTTPS).
The recommended way to install and use Vaultwarden is via our container images which are published to [ghcr.io](https://github.com/dani-garcia/vaultwarden/pkgs/container/vaultwarden), [docker.io](https://hub.docker.com/r/vaultwarden/server) and [quay.io](https://quay.io/repository/vaultwarden/server).
See [which container image to use](https://github.com/dani-garcia/vaultwarden/wiki/Which-container-image-to-use) for an explanation of the provided tags.
+12 -10
View File
@@ -1,21 +1,22 @@
use std::{env, io::Error, process::Command};
use std::env;
use std::process::Command;
fn main() {
// These allow using e.g. #[cfg(mysql)] instead of #[cfg(feature = "mysql")], which helps when trying to add them through macros
#[cfg(feature = "sqlite_system")] // The `sqlite` feature implies this one.
// This allow using #[cfg(sqlite)] instead of #[cfg(feature = "sqlite")], which helps when trying to add them through macros
#[cfg(feature = "sqlite")]
println!("cargo:rustc-cfg=sqlite");
#[cfg(feature = "mysql")]
println!("cargo:rustc-cfg=mysql");
#[cfg(feature = "postgresql")]
println!("cargo:rustc-cfg=postgresql");
#[cfg(not(any(feature = "sqlite_system", feature = "mysql", feature = "postgresql")))]
#[cfg(feature = "s3")]
println!("cargo:rustc-cfg=s3");
#[cfg(not(any(feature = "sqlite", feature = "mysql", feature = "postgresql")))]
compile_error!(
"You need to enable one DB backend. To build with previous defaults do: cargo build --features sqlite"
);
#[cfg(feature = "s3")]
println!("cargo:rustc-cfg=s3");
// Use check-cfg to let cargo know which cfg's we define,
// and avoid warnings when they are used in the code.
println!("cargo::rustc-check-cfg=cfg(sqlite)");
@@ -41,12 +42,13 @@ fn main() {
}
}
fn run(args: &[&str]) -> Result<String, Error> {
fn run(args: &[&str]) -> Result<String, std::io::Error> {
let out = Command::new(args[0]).args(&args[1..]).output()?;
if !out.status.success() {
use std::io::Error;
return Err(Error::other("Command not successful"));
}
Ok(String::from_utf8(out.stdout).unwrap().trim().to_owned())
Ok(String::from_utf8(out.stdout).unwrap().trim().to_string())
}
/// This method reads info from Git, namely tags, branch, and revision
@@ -56,7 +58,7 @@ fn run(args: &[&str]) -> Result<String, Error> {
/// - `env!("GIT_BRANCH")`
/// - `env!("GIT_REV")`
/// - `env!("VW_VERSION")`
fn version_from_git_info() -> Result<String, Error> {
fn version_from_git_info() -> Result<String, std::io::Error> {
// The exact tag for the current commit, can be empty when
// the current commit doesn't have an associated tag
let exact_tag = run(&["git", "describe", "--abbrev=0", "--tags", "--exact-match"]).ok();
+1 -1
View File
@@ -2,4 +2,4 @@
# see diesel.rs/guides/configuring-diesel-cli
[print_schema]
file = "src/db/schema.rs"
file = "src/db/schema.rs"
+3 -3
View File
@@ -1,11 +1,11 @@
---
vault_version: "v2026.4.1"
vault_image_digest: "sha256:ca2a4251c4e63c9ad428262b4dd452789a1b9f6fce71da351e93dceed0d2edbe"
vault_version: "v2026.2.0"
vault_image_digest: "sha256:37c8661fa59dcdfbd3baa8366b6e950ef292b15adfeff1f57812b075c1fd3447"
# Cross Compile Docker Helper Scripts v1.9.0
# We use the linux/amd64 platform shell scripts since there is no difference between the different platform scripts
# https://github.com/tonistiigi/xx | https://hub.docker.com/r/tonistiigi/xx/tags
xx_image_digest: "sha256:c64defb9ed5a91eacb37f96ccc3d4cd72521c4bd18d5442905b95e2226b0e707"
rust_version: 1.95.0 # Rust version to be used
rust_version: 1.94.1 # Rust version to be used
debian_version: trixie # Debian release name to be used
alpine_version: "3.23" # Alpine version to be used
# For which platforms/architectures will we try to build images
+11 -10
View File
@@ -19,23 +19,23 @@
# - From https://hub.docker.com/r/vaultwarden/web-vault/tags,
# click the tag name to view the digest of the image it currently points to.
# - From the command line:
# $ docker pull docker.io/vaultwarden/web-vault:v2026.4.1
# $ docker image inspect --format "{{.RepoDigests}}" docker.io/vaultwarden/web-vault:v2026.4.1
# [docker.io/vaultwarden/web-vault@sha256:ca2a4251c4e63c9ad428262b4dd452789a1b9f6fce71da351e93dceed0d2edbe]
# $ docker pull docker.io/vaultwarden/web-vault:v2026.2.0
# $ docker image inspect --format "{{.RepoDigests}}" docker.io/vaultwarden/web-vault:v2026.2.0
# [docker.io/vaultwarden/web-vault@sha256:37c8661fa59dcdfbd3baa8366b6e950ef292b15adfeff1f57812b075c1fd3447]
#
# - Conversely, to get the tag name from the digest:
# $ docker image inspect --format "{{.RepoTags}}" docker.io/vaultwarden/web-vault@sha256:ca2a4251c4e63c9ad428262b4dd452789a1b9f6fce71da351e93dceed0d2edbe
# [docker.io/vaultwarden/web-vault:v2026.4.1]
# $ docker image inspect --format "{{.RepoTags}}" docker.io/vaultwarden/web-vault@sha256:37c8661fa59dcdfbd3baa8366b6e950ef292b15adfeff1f57812b075c1fd3447
# [docker.io/vaultwarden/web-vault:v2026.2.0]
#
FROM --platform=linux/amd64 docker.io/vaultwarden/web-vault@sha256:ca2a4251c4e63c9ad428262b4dd452789a1b9f6fce71da351e93dceed0d2edbe AS vault
FROM --platform=linux/amd64 docker.io/vaultwarden/web-vault@sha256:37c8661fa59dcdfbd3baa8366b6e950ef292b15adfeff1f57812b075c1fd3447 AS vault
########################## ALPINE BUILD IMAGES ##########################
## NOTE: The Alpine Base Images do not support other platforms then linux/amd64 and linux/arm64
## And for Alpine we define all build images here, they will only be loaded when actually used
FROM --platform=$BUILDPLATFORM ghcr.io/blackdex/rust-musl:x86_64-musl-stable-1.95.0 AS build_amd64
FROM --platform=$BUILDPLATFORM ghcr.io/blackdex/rust-musl:aarch64-musl-stable-1.95.0 AS build_arm64
FROM --platform=$BUILDPLATFORM ghcr.io/blackdex/rust-musl:armv7-musleabihf-stable-1.95.0 AS build_armv7
FROM --platform=$BUILDPLATFORM ghcr.io/blackdex/rust-musl:arm-musleabi-stable-1.95.0 AS build_armv6
FROM --platform=$BUILDPLATFORM ghcr.io/blackdex/rust-musl:x86_64-musl-stable-1.94.1 AS build_amd64
FROM --platform=$BUILDPLATFORM ghcr.io/blackdex/rust-musl:aarch64-musl-stable-1.94.1 AS build_arm64
FROM --platform=$BUILDPLATFORM ghcr.io/blackdex/rust-musl:armv7-musleabihf-stable-1.94.1 AS build_armv7
FROM --platform=$BUILDPLATFORM ghcr.io/blackdex/rust-musl:arm-musleabi-stable-1.94.1 AS build_armv6
########################## BUILD IMAGE ##########################
# hadolint ignore=DL3006
@@ -57,6 +57,7 @@ ENV DEBIAN_FRONTEND=noninteractive \
# Debian Trixie uses libpq v17
PQ_LIB_DIR="/usr/local/musl/pq17/lib"
# Create CARGO_HOME folder and don't download rust docs
RUN mkdir -pv "${CARGO_HOME}" && \
rustup set profile minimal
+40 -18
View File
@@ -19,15 +19,15 @@
# - From https://hub.docker.com/r/vaultwarden/web-vault/tags,
# click the tag name to view the digest of the image it currently points to.
# - From the command line:
# $ docker pull docker.io/vaultwarden/web-vault:v2026.4.1
# $ docker image inspect --format "{{.RepoDigests}}" docker.io/vaultwarden/web-vault:v2026.4.1
# [docker.io/vaultwarden/web-vault@sha256:ca2a4251c4e63c9ad428262b4dd452789a1b9f6fce71da351e93dceed0d2edbe]
# $ docker pull docker.io/vaultwarden/web-vault:v2026.2.0
# $ docker image inspect --format "{{.RepoDigests}}" docker.io/vaultwarden/web-vault:v2026.2.0
# [docker.io/vaultwarden/web-vault@sha256:37c8661fa59dcdfbd3baa8366b6e950ef292b15adfeff1f57812b075c1fd3447]
#
# - Conversely, to get the tag name from the digest:
# $ docker image inspect --format "{{.RepoTags}}" docker.io/vaultwarden/web-vault@sha256:ca2a4251c4e63c9ad428262b4dd452789a1b9f6fce71da351e93dceed0d2edbe
# [docker.io/vaultwarden/web-vault:v2026.4.1]
# $ docker image inspect --format "{{.RepoTags}}" docker.io/vaultwarden/web-vault@sha256:37c8661fa59dcdfbd3baa8366b6e950ef292b15adfeff1f57812b075c1fd3447
# [docker.io/vaultwarden/web-vault:v2026.2.0]
#
FROM --platform=linux/amd64 docker.io/vaultwarden/web-vault@sha256:ca2a4251c4e63c9ad428262b4dd452789a1b9f6fce71da351e93dceed0d2edbe AS vault
FROM --platform=linux/amd64 docker.io/vaultwarden/web-vault@sha256:37c8661fa59dcdfbd3baa8366b6e950ef292b15adfeff1f57812b075c1fd3447 AS vault
########################## Cross Compile Docker Helper Scripts ##########################
## We use the linux/amd64 no matter which Build Platform, since these are all bash scripts
@@ -36,7 +36,7 @@ FROM --platform=linux/amd64 docker.io/tonistiigi/xx@sha256:c64defb9ed5a91eacb37f
########################## BUILD IMAGE ##########################
# hadolint ignore=DL3006
FROM --platform=$BUILDPLATFORM docker.io/library/rust:1.95.0-slim-trixie AS build
FROM --platform=$BUILDPLATFORM docker.io/library/rust:1.94.1-slim-trixie AS build
COPY --from=xx / /
ARG TARGETARCH
ARG TARGETVARIANT
@@ -51,7 +51,7 @@ ENV DEBIAN_FRONTEND=noninteractive \
TERM=xterm-256color \
CARGO_HOME="/root/.cargo" \
USER="root"
# Install clang && xx-c-essentials to get `xx-cargo` working
# Install clang to get `xx-cargo` working
# Install pkg-config to allow amd64 builds to find all libraries
# Install git so build.rs can determine the correct version
# Install the libc cross packages based upon the debian-arch
@@ -59,16 +59,19 @@ RUN apt-get update && \
apt-get install -y \
--no-install-recommends \
clang \
git && \
pkg-config \
git \
"libc6-$(xx-info debian-arch)-cross" \
"libc6-dev-$(xx-info debian-arch)-cross" \
"linux-libc-dev-$(xx-info debian-arch)-cross" && \
xx-apt-get install -y \
--no-install-recommends \
gcc \
libpq-dev \
libpq5 \
libssl-dev \
libmariadb-dev \
pkg-config \
zlib1g-dev \
xx-c-essentials && \
zlib1g-dev && \
# Run xx-cargo early, since it sometimes seems to break when run at a later stage
echo "export CARGO_TARGET=$(xx-cargo --print-target-triple)" >> /env-cargo
@@ -80,6 +83,29 @@ RUN mkdir -pv "${CARGO_HOME}" && \
RUN USER=root cargo new --bin /app
WORKDIR /app
# Environment variables for Cargo on Debian based builds
ARG TARGET_PKG_CONFIG_PATH
RUN source /env-cargo && \
if xx-info is-cross ; then \
# We can't use xx-cargo since that uses clang, which doesn't work for our libraries.
# Because of this we generate the needed environment variables here which we can load in the needed steps.
echo "export CC_$(echo "${CARGO_TARGET}" | tr '[:upper:]' '[:lower:]' | tr - _)=/usr/bin/$(xx-info)-gcc" >> /env-cargo && \
echo "export CARGO_TARGET_$(echo "${CARGO_TARGET}" | tr '[:lower:]' '[:upper:]' | tr - _)_LINKER=/usr/bin/$(xx-info)-gcc" >> /env-cargo && \
echo "export CROSS_COMPILE=1" >> /env-cargo && \
echo "export PKG_CONFIG_ALLOW_CROSS=1" >> /env-cargo && \
# For some architectures `xx-info` returns a triple which doesn't matches the path on disk
# In those cases you can override this by setting the `TARGET_PKG_CONFIG_PATH` build-arg
if [[ -n "${TARGET_PKG_CONFIG_PATH}" ]]; then \
echo "export TARGET_PKG_CONFIG_PATH=${TARGET_PKG_CONFIG_PATH}" >> /env-cargo ; \
else \
echo "export PKG_CONFIG_PATH=/usr/lib/$(xx-info)/pkgconfig" >> /env-cargo ; \
fi && \
echo "# End of env-cargo" >> /env-cargo ; \
fi && \
# Output the current contents of the file
cat /env-cargo
RUN source /env-cargo && \
rustup target add "${CARGO_TARGET}"
@@ -96,9 +122,7 @@ ARG DB=sqlite,mysql,postgresql
# dummy project, except the target folder
# This folder contains the compiled dependencies
RUN source /env-cargo && \
# Workaround for xx related build issues
# https://github.com/tonistiigi/xx/pull/108#issuecomment-3700635977
PKG_CONFIG="$(command -v "$(xx-info)-pkg-config")" xx-cargo build --features ${DB} --profile "${CARGO_PROFILE}" && \
cargo build --features ${DB} --profile "${CARGO_PROFILE}" --target="${CARGO_TARGET}" && \
find . -not -path "./target*" -delete
# Copies the complete project
@@ -113,9 +137,7 @@ RUN source /env-cargo && \
# Also do this for build.rs to ensure the version is rechecked
touch build.rs src/main.rs && \
# Create a symlink to the binary target folder to easy copy the binary in the final stage
# Workaround for xx related build issues
# https://github.com/tonistiigi/xx/pull/108#issuecomment-3700635977
PKG_CONFIG="$(command -v "$(xx-info)-pkg-config")" xx-cargo build --features ${DB} --profile "${CARGO_PROFILE}" && \
cargo build --features ${DB} --profile "${CARGO_PROFILE}" --target="${CARGO_TARGET}" && \
if [[ "${CARGO_PROFILE}" == "dev" ]] ; then \
ln -vfsr "/app/target/${CARGO_TARGET}/debug" /app/target/final ; \
else \
+34 -20
View File
@@ -27,11 +27,6 @@
# $ docker image inspect --format "{{ '{{' }}.RepoTags}}" docker.io/vaultwarden/web-vault@{{ vault_image_digest }}
# [docker.io/vaultwarden/web-vault:{{ vault_version | replace('+', '_') }}]
#
{% macro xx_cargo_config() -%}
# Workaround for xx related build issues
# https://github.com/tonistiigi/xx/pull/108#issuecomment-3700635977
PKG_CONFIG="$(command -v "$(xx-info)-pkg-config")" xx-cargo build --features ${DB} --profile "${CARGO_PROFILE}"
{%- endmacro %}
FROM --platform=linux/amd64 docker.io/vaultwarden/web-vault@{{ vault_image_digest }} AS vault
{% if base == "debian" %}
@@ -71,10 +66,10 @@ ENV DEBIAN_FRONTEND=noninteractive \
# Use PostgreSQL v17 during Alpine/MUSL builds instead of the default v16
# Debian Trixie uses libpq v17
PQ_LIB_DIR="/usr/local/musl/pq17/lib"
{%- endif %}
{% endif %}
{% if base == "debian" %}
# Install clang && xx-c-essentials to get `xx-cargo` working
# Install clang to get `xx-cargo` working
# Install pkg-config to allow amd64 builds to find all libraries
# Install git so build.rs can determine the correct version
# Install the libc cross packages based upon the debian-arch
@@ -82,16 +77,19 @@ RUN apt-get update && \
apt-get install -y \
--no-install-recommends \
clang \
git && \
pkg-config \
git \
"libc6-$(xx-info debian-arch)-cross" \
"libc6-dev-$(xx-info debian-arch)-cross" \
"linux-libc-dev-$(xx-info debian-arch)-cross" && \
xx-apt-get install -y \
--no-install-recommends \
gcc \
libpq-dev \
libpq5 \
libssl-dev \
libmariadb-dev \
pkg-config \
zlib1g-dev \
xx-c-essentials && \
zlib1g-dev && \
# Run xx-cargo early, since it sometimes seems to break when run at a later stage
echo "export CARGO_TARGET=$(xx-cargo --print-target-triple)" >> /env-cargo
{% endif %}
@@ -104,7 +102,31 @@ RUN mkdir -pv "${CARGO_HOME}" && \
RUN USER=root cargo new --bin /app
WORKDIR /app
{% if base == "alpine" %}
{% if base == "debian" %}
# Environment variables for Cargo on Debian based builds
ARG TARGET_PKG_CONFIG_PATH
RUN source /env-cargo && \
if xx-info is-cross ; then \
# We can't use xx-cargo since that uses clang, which doesn't work for our libraries.
# Because of this we generate the needed environment variables here which we can load in the needed steps.
echo "export CC_$(echo "${CARGO_TARGET}" | tr '[:upper:]' '[:lower:]' | tr - _)=/usr/bin/$(xx-info)-gcc" >> /env-cargo && \
echo "export CARGO_TARGET_$(echo "${CARGO_TARGET}" | tr '[:lower:]' '[:upper:]' | tr - _)_LINKER=/usr/bin/$(xx-info)-gcc" >> /env-cargo && \
echo "export CROSS_COMPILE=1" >> /env-cargo && \
echo "export PKG_CONFIG_ALLOW_CROSS=1" >> /env-cargo && \
# For some architectures `xx-info` returns a triple which doesn't matches the path on disk
# In those cases you can override this by setting the `TARGET_PKG_CONFIG_PATH` build-arg
if [[ -n "${TARGET_PKG_CONFIG_PATH}" ]]; then \
echo "export TARGET_PKG_CONFIG_PATH=${TARGET_PKG_CONFIG_PATH}" >> /env-cargo ; \
else \
echo "export PKG_CONFIG_PATH=/usr/lib/$(xx-info)/pkgconfig" >> /env-cargo ; \
fi && \
echo "# End of env-cargo" >> /env-cargo ; \
fi && \
# Output the current contents of the file
cat /env-cargo
{% elif base == "alpine" %}
# Environment variables for Cargo on Alpine based builds
RUN echo "export CARGO_TARGET=${RUST_MUSL_CROSS_TARGET}" >> /env-cargo && \
# Output the current contents of the file
@@ -132,11 +154,7 @@ ARG DB=sqlite,mysql,postgresql,enable_mimalloc
# dummy project, except the target folder
# This folder contains the compiled dependencies
RUN source /env-cargo && \
{% if base == "debian" %}
{{ xx_cargo_config() }} && \
{% elif base == "alpine" %}
cargo build --features ${DB} --profile "${CARGO_PROFILE}" --target="${CARGO_TARGET}" && \
{% endif %}
find . -not -path "./target*" -delete
# Copies the complete project
@@ -151,11 +169,7 @@ RUN source /env-cargo && \
# Also do this for build.rs to ensure the version is rechecked
touch build.rs src/main.rs && \
# Create a symlink to the binary target folder to easy copy the binary in the final stage
{% if base == "debian" %}
{{ xx_cargo_config() }} && \
{% elif base == "alpine" %}
cargo build --features ${DB} --profile "${CARGO_PROFILE}" --target="${CARGO_TARGET}" && \
{% endif %}
if [[ "${CARGO_PROFILE}" == "dev" ]] ; then \
ln -vfsr "/app/target/${CARGO_TARGET}/debug" /app/target/final ; \
else \
+4 -5
View File
@@ -1,15 +1,14 @@
use proc_macro::TokenStream;
use quote::quote;
use syn::{DeriveInput, parse_macro_input};
#[proc_macro_derive(UuidFromParam)]
pub fn derive_uuid_from_param(input: TokenStream) -> TokenStream {
let ast = parse_macro_input!(input as DeriveInput);
let ast = syn::parse(input).unwrap();
impl_derive_uuid_macro(&ast)
}
fn impl_derive_uuid_macro(ast: &DeriveInput) -> TokenStream {
fn impl_derive_uuid_macro(ast: &syn::DeriveInput) -> TokenStream {
let name = &ast.ident;
let gen_derive = quote! {
#[automatically_derived]
@@ -31,12 +30,12 @@ fn impl_derive_uuid_macro(ast: &DeriveInput) -> TokenStream {
#[proc_macro_derive(IdFromParam)]
pub fn derive_id_from_param(input: TokenStream) -> TokenStream {
let ast = parse_macro_input!(input as DeriveInput);
let ast = syn::parse(input).unwrap();
impl_derive_safestring_macro(&ast)
}
fn impl_derive_safestring_macro(ast: &DeriveInput) -> TokenStream {
fn impl_derive_safestring_macro(ast: &syn::DeriveInput) -> TokenStream {
let name = &ast.ident;
let gen_derive = quote! {
#[automatically_derived]
@@ -1 +0,0 @@
DROP TABLE IF EXISTS archives;
@@ -1,10 +0,0 @@
DROP TABLE IF EXISTS archives;
CREATE TABLE archives (
user_uuid CHAR(36) NOT NULL,
cipher_uuid CHAR(36) NOT NULL,
archived_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (user_uuid, cipher_uuid),
FOREIGN KEY (user_uuid) REFERENCES users (uuid) ON DELETE CASCADE,
FOREIGN KEY (cipher_uuid) REFERENCES ciphers (uuid) ON DELETE CASCADE
);
@@ -1 +0,0 @@
ALTER TABLE sso_auth DROP COLUMN binding_hash;
@@ -1 +0,0 @@
ALTER TABLE sso_auth ADD COLUMN binding_hash TEXT;
@@ -1 +0,0 @@
ALTER TABLE sso_auth DROP COLUMN code_response_error;
@@ -1 +0,0 @@
ALTER TABLE sso_auth ADD COLUMN code_response_error TEXT;
@@ -1 +0,0 @@
DROP TABLE IF EXISTS archives;
@@ -1,8 +0,0 @@
DROP TABLE IF EXISTS archives;
CREATE TABLE archives (
user_uuid CHAR(36) NOT NULL REFERENCES users (uuid) ON DELETE CASCADE,
cipher_uuid CHAR(36) NOT NULL REFERENCES ciphers (uuid) ON DELETE CASCADE,
archived_at TIMESTAMP NOT NULL DEFAULT now(),
PRIMARY KEY (user_uuid, cipher_uuid)
);
@@ -1 +0,0 @@
ALTER TABLE sso_auth DROP COLUMN binding_hash;
@@ -1 +0,0 @@
ALTER TABLE sso_auth ADD COLUMN binding_hash TEXT;
@@ -1 +0,0 @@
ALTER TABLE sso_auth DROP COLUMN IF EXISTS code_response_error;
@@ -1 +0,0 @@
ALTER TABLE sso_auth ADD COLUMN IF NOT EXISTS code_response_error TEXT;
@@ -1 +0,0 @@
DROP TABLE IF EXISTS archives;
@@ -1,8 +0,0 @@
DROP TABLE IF EXISTS archives;
CREATE TABLE archives (
user_uuid CHAR(36) NOT NULL REFERENCES users (uuid) ON DELETE CASCADE,
cipher_uuid CHAR(36) NOT NULL REFERENCES ciphers (uuid) ON DELETE CASCADE,
archived_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (user_uuid, cipher_uuid)
);
@@ -1 +0,0 @@
ALTER TABLE sso_auth DROP COLUMN binding_hash;
@@ -1 +0,0 @@
ALTER TABLE sso_auth ADD COLUMN binding_hash TEXT;
@@ -1 +0,0 @@
ALTER TABLE sso_auth DROP COLUMN code_response_error;
@@ -1 +0,0 @@
ALTER TABLE sso_auth ADD COLUMN code_response_error TEXT;
+1 -1
View File
@@ -1,4 +1,4 @@
[toolchain]
channel = "1.95.0"
channel = "1.94.1"
components = [ "rustfmt", "clippy" ]
profile = "minimal"
+1 -1
View File
@@ -1,4 +1,4 @@
edition = "2024"
edition = "2021"
max_width = 120
newline_style = "Unix"
use_small_heuristics = "Off"
+73 -77
View File
@@ -2,40 +2,39 @@ use std::{env, sync::LazyLock};
use reqwest::Method;
use rocket::{
Catcher, Route,
form::Form,
http::{Cookie, CookieJar, MediaType, SameSite, Status},
request::{FromRequest, Outcome, Request},
response::{Redirect, content::RawHtml as Html},
response::{content::RawHtml as Html, Redirect},
serde::json::Json,
Catcher, Route,
};
use serde::de::DeserializeOwned;
use serde_json::Value;
use crate::{
CONFIG, VERSION,
api::{
ApiResult, EmptyResult, JsonResult, Notify,
core::{log_event, two_factor},
unregister_push_device,
unregister_push_device, ApiResult, EmptyResult, JsonResult, Notify,
},
auth::{ClientIp, Secure, decode_admin, encode_jwt, generate_admin_claims},
auth::{decode_admin, encode_jwt, generate_admin_claims, ClientIp, Secure},
config::ConfigBuilder,
db::{
ACTIVE_DB_TYPE, DbConn, DbConnType, backup_sqlite, get_sql_server_version,
backup_sqlite, get_sql_server_version,
models::{
Attachment, Cipher, Collection, Device, Event, EventType, Group, Invitation, Membership, MembershipId,
MembershipType, OrgPolicy, Organization, OrganizationId, SsoUser, TwoFactor, User, UserId,
},
DbConn, DbConnType, ACTIVE_DB_TYPE,
},
error::{Error, MapResult},
http_client::make_http_request,
mail,
sso::FAKE_SSO_IDENTIFIER,
util::{
FeatureFlagFilter, NumberOrString, container_base_image, format_naive_datetime_local, get_active_web_release,
get_display_size, is_running_in_container, parse_experimental_client_feature_flags,
container_base_image, format_naive_datetime_local, get_active_web_release, get_display_size,
is_running_in_container, parse_experimental_client_feature_flags, FeatureFlagFilter, NumberOrString,
},
CONFIG, VERSION,
};
pub fn routes() -> Vec<Route> {
@@ -93,7 +92,8 @@ static DB_TYPE: LazyLock<&str> = LazyLock::new(|| match ACTIVE_DB_TYPE.get() {
});
#[cfg(sqlite)]
static CAN_BACKUP: LazyLock<bool> = LazyLock::new(|| ACTIVE_DB_TYPE.get().is_some_and(|t| *t == DbConnType::Sqlite));
static CAN_BACKUP: LazyLock<bool> =
LazyLock::new(|| ACTIVE_DB_TYPE.get().map(|t| *t == DbConnType::Sqlite).unwrap_or(false));
#[cfg(not(sqlite))]
static CAN_BACKUP: LazyLock<bool> = LazyLock::new(|| false);
@@ -199,7 +199,13 @@ fn post_admin_login(
}
// If the token is invalid, redirect to login page
if validate_token(&data.token) {
if !_validate_token(&data.token) {
error!("Invalid admin token. IP: {}", ip.ip);
Err(AdminResponse::Unauthorized(render_admin_login(
Some("Invalid admin token, please try again."),
redirect.as_deref(),
)))
} else {
// If the token received is valid, generate JWT and save it as a cookie
let claims = generate_admin_claims();
let jwt = encode_jwt(&claims);
@@ -217,16 +223,10 @@ fn post_admin_login(
} else {
Err(AdminResponse::Ok(render_admin_page()))
}
} else {
error!("Invalid admin token. IP: {}", ip.ip);
Err(AdminResponse::Unauthorized(render_admin_login(
Some("Invalid admin token, please try again."),
redirect.as_deref(),
)))
}
}
fn validate_token(token: &str) -> bool {
fn _validate_token(token: &str) -> bool {
match CONFIG.admin_token().as_ref() {
None => false,
Some(t) if t.starts_with("$argon2") => {
@@ -306,21 +306,6 @@ async fn get_user_or_404(user_id: &UserId, conn: &DbConn) -> ApiResult<User> {
#[post("/invite", format = "application/json", data = "<data>")]
async fn invite_user(data: Json<InviteData>, _token: AdminToken, conn: DbConn) -> JsonResult {
async fn generate_invite(user: &User, conn: &DbConn) -> EmptyResult {
if CONFIG.mail_enabled() {
let org_id: OrganizationId = if CONFIG.sso_enabled() {
FAKE_SSO_IDENTIFIER.into()
} else {
FAKE_ADMIN_UUID.into()
};
let member_id: MembershipId = FAKE_ADMIN_UUID.to_owned().into();
mail::send_invite(user, org_id, member_id, &CONFIG.invitation_org_name(), None).await
} else {
let invitation = Invitation::new(&user.email);
invitation.save(conn).await
}
}
let data: InviteData = data.into_inner();
if User::find_by_mail(&data.email, &conn).await.is_some() {
err_code!("User already exists", Status::Conflict.code)
@@ -328,7 +313,18 @@ async fn invite_user(data: Json<InviteData>, _token: AdminToken, conn: DbConn) -
let mut user = User::new(&data.email, None);
generate_invite(&user, &conn).await.map_err(|e| e.with_code(Status::InternalServerError.code))?;
async fn _generate_invite(user: &User, conn: &DbConn) -> EmptyResult {
if CONFIG.mail_enabled() {
let org_id: OrganizationId = FAKE_ADMIN_UUID.to_string().into();
let member_id: MembershipId = FAKE_ADMIN_UUID.to_string().into();
mail::send_invite(user, org_id, member_id, &CONFIG.invitation_org_name(), None).await
} else {
let invitation = Invitation::new(&user.email);
invitation.save(conn).await
}
}
_generate_invite(&user, &conn).await.map_err(|e| e.with_code(Status::InternalServerError.code))?;
user.save(&conn).await.map_err(|e| e.with_code(Status::InternalServerError.code))?;
Ok(Json(user.to_json(&conn).await))
@@ -385,7 +381,7 @@ async fn users_overview(_token: AdminToken, conn: DbConn) -> ApiResult<Html<Stri
None => json!("Never"),
};
usr["sso_identifier"] = json!(sso_u.map_or(String::new(), |u| u.identifier.to_string()));
usr["sso_identifier"] = json!(sso_u.map(|u| u.identifier.to_string()).unwrap_or(String::new()));
users_json.push(usr);
}
@@ -468,10 +464,10 @@ async fn deauth_user(user_id: UserId, _token: AdminToken, conn: DbConn, nt: Noti
if CONFIG.push_enabled() {
for device in Device::find_push_devices_by_user(&user.uuid, &conn).await {
match unregister_push_device(device.push_uuid.as_ref()).await {
match unregister_push_device(&device.push_uuid).await {
Ok(r) => r,
Err(e) => error!("Unable to unregister devices from Bitwarden server: {e}"),
}
};
}
}
@@ -522,12 +518,8 @@ async fn resend_user_invite(user_id: UserId, _token: AdminToken, conn: DbConn) -
}
if CONFIG.mail_enabled() {
let org_id: OrganizationId = if CONFIG.sso_enabled() {
FAKE_SSO_IDENTIFIER.into()
} else {
FAKE_ADMIN_UUID.into()
};
let member_id: MembershipId = FAKE_ADMIN_UUID.to_owned().into();
let org_id: OrganizationId = FAKE_ADMIN_UUID.to_string().into();
let member_id: MembershipId = FAKE_ADMIN_UUID.to_string().into();
mail::send_invite(&user, org_id, member_id, &CONFIG.invitation_org_name(), None).await
} else {
Ok(())
@@ -553,10 +545,9 @@ async fn update_membership_type(data: Json<MembershipTypeData>, token: AdminToke
err!("The specified user isn't member of the organization")
};
let new_type = if let Some(new_type) = MembershipType::from_str(&data.user_type.into_string()) {
new_type as i32
} else {
err!("Invalid type")
let new_type = match MembershipType::from_str(&data.user_type.into_string()) {
Some(new_type) => new_type as i32,
None => err!("Invalid type"),
};
if member_to_edit.atype == MembershipType::Owner && new_type != MembershipType::Owner {
@@ -656,40 +647,42 @@ async fn get_release_info(has_http_access: bool) -> (String, String, String) {
.await
{
Ok(r) => r.tag_name,
_ => "-".to_owned(),
_ => "-".to_string(),
},
match get_json_api::<GitCommit>("https://api.github.com/repos/dani-garcia/vaultwarden/commits/main").await {
Ok(mut c) => {
c.sha.truncate(8);
c.sha
}
_ => "-".to_owned(),
_ => "-".to_string(),
},
// Do not fetch the web-vault version when running within a container
// The web-vault version is embedded within the container it self, and should not be updated manually
match get_json_api::<GitRelease>("https://api.github.com/repos/dani-garcia/bw_web_builds/releases/latest")
.await
{
Ok(r) => r.tag_name.trim_start_matches('v').to_owned(),
_ => "-".to_owned(),
Ok(r) => r.tag_name.trim_start_matches('v').to_string(),
_ => "-".to_string(),
},
)
} else {
("-".to_owned(), "-".to_owned(), "-".to_owned())
("-".to_string(), "-".to_string(), "-".to_string())
}
}
async fn get_ntp_time(has_http_access: bool) -> String {
if has_http_access && let Ok(cf_trace) = get_text_api("https://cloudflare.com/cdn-cgi/trace").await {
for line in cf_trace.lines() {
if let Some((key, value)) = line.split_once('=')
&& key == "ts"
{
let ts = value.split_once('.').map_or(value, |(s, _)| s);
if let Ok(dt) = chrono::DateTime::parse_from_str(ts, "%s") {
return dt.format("%Y-%m-%d %H:%M:%S UTC").to_string();
if has_http_access {
if let Ok(cf_trace) = get_text_api("https://cloudflare.com/cdn-cgi/trace").await {
for line in cf_trace.lines() {
if let Some((key, value)) = line.split_once('=') {
if key == "ts" {
let ts = value.split_once('.').map_or(value, |(s, _)| s);
if let Ok(dt) = chrono::DateTime::parse_from_str(ts, "%s") {
return dt.format("%Y-%m-%d %H:%M:%S UTC").to_string();
}
break;
}
}
break;
}
}
}
@@ -732,7 +725,7 @@ async fn diagnostics(_token: AdminToken, ip_header: IpHeader, conn: DbConn) -> A
// Check if we are able to resolve DNS entries
let dns_resolved = match ("github.com", 0).to_socket_addrs().map(|mut i| i.next()) {
Ok(Some(a)) => a.ip().to_string(),
_ => "Unable to resolve domain name.".to_owned(),
_ => "Unable to resolve domain name.".to_string(),
};
let (latest_vw_release, latest_vw_commit, latest_web_release) = get_release_info(has_http_access).await;
@@ -743,7 +736,7 @@ async fn diagnostics(_token: AdminToken, ip_header: IpHeader, conn: DbConn) -> A
let invalid_feature_flags: Vec<String> = parse_experimental_client_feature_flags(
&CONFIG.experimental_client_feature_flags(),
&FeatureFlagFilter::InvalidOnly,
FeatureFlagFilter::InvalidOnly,
)
.into_keys()
.collect();
@@ -832,30 +825,33 @@ impl<'r> FromRequest<'r> for AdminToken {
type Error = &'static str;
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
let Outcome::Success(ip) = ClientIp::from_request(request).await else {
err_handler!("Error getting Client IP")
let ip = match ClientIp::from_request(request).await {
Outcome::Success(ip) => ip,
_ => err_handler!("Error getting Client IP"),
};
if !CONFIG.disable_admin_token() {
let cookies = request.cookies();
let access_token = if let Some(cookie) = cookies.get(COOKIE_NAME) {
cookie.value()
} else {
let requested_page =
request.segments::<std::path::PathBuf>(0..).unwrap_or_default().display().to_string();
// When the requested page is empty, it is `/admin`, in that case, Forward, so it will render the login page
// Else, return a 401 failure, which will be caught
if requested_page.is_empty() {
return Outcome::Forward(Status::Unauthorized);
let access_token = match cookies.get(COOKIE_NAME) {
Some(cookie) => cookie.value(),
None => {
let requested_page =
request.segments::<std::path::PathBuf>(0..).unwrap_or_default().display().to_string();
// When the requested page is empty, it is `/admin`, in that case, Forward, so it will render the login page
// Else, return a 401 failure, which will be caught
if requested_page.is_empty() {
return Outcome::Forward(Status::Unauthorized);
} else {
return Outcome::Error((Status::Unauthorized, "Unauthorized"));
}
}
return Outcome::Error((Status::Unauthorized, "Unauthorized"));
};
if decode_admin(access_token).is_err() {
// Remove admin cookie
cookies.remove(Cookie::build(COOKIE_NAME).path(admin_path()));
error!("Invalid or expired admin JWT. IP: {}.", ip.ip);
error!("Invalid or expired admin JWT. IP: {}.", &ip.ip);
return Outcome::Error((Status::Unauthorized, "Session expired"));
}
}
+86 -82
View File
@@ -1,37 +1,34 @@
use std::collections::HashSet;
use crate::db::DbPool;
use chrono::Utc;
use rocket::{
http::Status,
request::{FromRequest, Outcome, Request},
serde::json::Json,
};
use rocket::serde::json::Json;
use serde_json::Value;
use crate::{
CONFIG,
api::{
AnonymousNotify, ApiResult, EmptyResult, JsonResult, Notify, PasswordOrOtpData, UpdateType,
core::{accept_org_invite, log_user_event, two_factor::email},
master_password_policy, register_push_device, unregister_push_device,
master_password_policy, register_push_device, unregister_push_device, AnonymousNotify, ApiResult, EmptyResult,
JsonResult, Notify, PasswordOrOtpData, UpdateType,
},
auth::{ClientHeaders, Headers, decode_delete, decode_invite, decode_verify_email},
auth::{decode_delete, decode_invite, decode_verify_email, ClientHeaders, Headers},
crypto,
db::{
DbConn, DbPool,
models::{
AuthRequest, AuthRequestId, Cipher, CipherId, Device, DeviceId, DeviceType, DeviceWithAuthRequest,
EmergencyAccess, EmergencyAccessId, EventType, Folder, FolderId, Invitation, Membership, MembershipId,
OrgPolicy, OrgPolicyType, Organization, OrganizationId, Send, SendId, User, UserId, UserKdfType,
AuthRequest, AuthRequestId, Cipher, CipherId, Device, DeviceId, DeviceType, EmergencyAccess,
EmergencyAccessId, EventType, Folder, FolderId, Invitation, Membership, MembershipId, OrgPolicy,
OrgPolicyType, Organization, OrganizationId, Send, SendId, User, UserId, UserKdfType,
},
DbConn,
},
mail,
util::{NumberOrString, deser_opt_nonempty_str, format_date},
util::{deser_opt_nonempty_str, format_date, NumberOrString},
CONFIG,
};
use super::{
ciphers::{CipherData, update_cipher_from_data},
sends::{SendData, update_send_from_data},
use rocket::{
http::Status,
request::{FromRequest, Outcome, Request},
};
pub fn routes() -> Vec<rocket::Route> {
@@ -57,9 +54,9 @@ pub fn routes() -> Vec<rocket::Route> {
delete_account,
revision_date,
password_hint,
post_prelogin,
prelogin,
verify_password,
post_api_key,
api_key,
rotate_api_key,
get_known_device,
get_all_devices,
@@ -140,17 +137,17 @@ struct KeysData {
}
/// Trims whitespace from password hints, and converts blank password hints to `None`.
fn clean_password_hint(password_hint: Option<&String>) -> Option<String> {
fn clean_password_hint(password_hint: &Option<String>) -> Option<String> {
match password_hint {
None => None,
Some(h) => match h.trim() {
"" => None,
ht => Some(ht.to_owned()),
ht => Some(ht.to_string()),
},
}
}
fn enforce_password_hint_setting(password_hint: Option<&String>) -> EmptyResult {
fn enforce_password_hint_setting(password_hint: &Option<String>) -> EmptyResult {
if password_hint.is_some() && !CONFIG.password_hints_allowed() {
err!("Password hints have been disabled by the administrator. Remove the hint and try again.");
}
@@ -169,7 +166,7 @@ async fn is_email_2fa_required(member_id: Option<MembershipId>, conn: &DbConn) -
false
}
pub async fn register(data: Json<RegisterData>, email_verification: bool, conn: DbConn) -> JsonResult {
pub async fn _register(data: Json<RegisterData>, email_verification: bool, conn: DbConn) -> JsonResult {
let mut data: RegisterData = data.into_inner();
let email = data.email.to_lowercase();
@@ -240,16 +237,16 @@ pub async fn register(data: Json<RegisterData>, email_verification: bool, conn:
// Check if the length of the username exceeds 50 characters (Same is Upstream Bitwarden)
// This also prevents issues with very long usernames causing to large JWT's. See #2419
if let Some(ref name) = data.name
&& name.len() > 50
{
err!("The field Name must be a string with a maximum length of 50.");
if let Some(ref name) = data.name {
if name.len() > 50 {
err!("The field Name must be a string with a maximum length of 50.");
}
}
// Check against the password hint setting here so if it fails, the user
// can retry without losing their invitation below.
let password_hint = clean_password_hint(data.master_password_hint.as_ref());
enforce_password_hint_setting(password_hint.as_ref())?;
let password_hint = clean_password_hint(&data.master_password_hint);
enforce_password_hint_setting(&password_hint)?;
let mut user = match User::find_by_mail(&email, &conn).await {
Some(user) => {
@@ -356,8 +353,8 @@ async fn post_set_password(data: Json<SetPasswordData>, headers: Headers, conn:
// Check against the password hint setting here so if it fails,
// the user can retry without losing their invitation below.
let password_hint = clean_password_hint(data.master_password_hint.as_ref());
enforce_password_hint_setting(password_hint.as_ref())?;
let password_hint = clean_password_hint(&data.master_password_hint);
enforce_password_hint_setting(&password_hint)?;
set_kdf_data(&mut user, &data.kdf)?;
@@ -376,19 +373,18 @@ async fn post_set_password(data: Json<SetPasswordData>, headers: Headers, conn:
user.public_key = Some(keys.public_key);
}
if let Some(identifier) = data.org_identifier
&& identifier != crate::sso::FAKE_SSO_IDENTIFIER
&& identifier != crate::api::admin::FAKE_ADMIN_UUID
{
let Some(org) = Organization::find_by_uuid(&identifier.into(), &conn).await else {
err!("Failed to retrieve the associated organization")
};
if let Some(identifier) = data.org_identifier {
if identifier != crate::sso::FAKE_IDENTIFIER && identifier != crate::api::admin::FAKE_ADMIN_UUID {
let Some(org) = Organization::find_by_uuid(&identifier.into(), &conn).await else {
err!("Failed to retrieve the associated organization")
};
let Some(membership) = Membership::find_by_user_and_org(&user.uuid, &org.uuid, &conn).await else {
err!("Failed to retrieve the invitation")
};
let Some(membership) = Membership::find_by_user_and_org(&user.uuid, &org.uuid, &conn).await else {
err!("Failed to retrieve the invitation")
};
accept_org_invite(&user, membership, None, &conn).await?;
accept_org_invite(&user, membership, None, &conn).await?;
}
}
if CONFIG.mail_enabled() {
@@ -455,10 +451,10 @@ async fn put_avatar(data: Json<AvatarData>, headers: Headers, conn: DbConn) -> J
// It looks like it only supports the 6 hex color format.
// If you try to add the short value it will not show that color.
// Check and force 7 chars, including the #.
if let Some(color) = &data.avatar_color
&& color.len() != 7
{
err!("The field AvatarColor must be a HTML/Hex color code with a length of 7 characters")
if let Some(color) = &data.avatar_color {
if color.len() != 7 {
err!("The field AvatarColor must be a HTML/Hex color code with a length of 7 characters")
}
}
let mut user = headers.user;
@@ -519,8 +515,8 @@ async fn post_password(data: Json<ChangePassData>, headers: Headers, conn: DbCon
err!("Invalid password")
}
user.password_hint = clean_password_hint(data.master_password_hint.as_ref());
enforce_password_hint_setting(user.password_hint.as_ref())?;
user.password_hint = clean_password_hint(&data.master_password_hint);
enforce_password_hint_setting(&user.password_hint)?;
log_user_event(EventType::UserChangedPassword as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &conn)
.await;
@@ -672,6 +668,9 @@ struct UpdateResetPasswordData {
reset_password_key: String,
}
use super::ciphers::CipherData;
use super::sends::{update_send_from_data, SendData};
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct KeyData {
@@ -841,7 +840,7 @@ async fn post_rotatekey(data: Json<KeyData>, headers: Headers, conn: DbConn, nt:
};
saved_folder.name = folder_data.name;
saved_folder.save(&conn).await?;
saved_folder.save(&conn).await?
}
}
@@ -854,7 +853,7 @@ async fn post_rotatekey(data: Json<KeyData>, headers: Headers, conn: DbConn, nt:
};
saved_emergency_access.key_encrypted = Some(emergency_access_data.key_encrypted);
saved_emergency_access.save(&conn).await?;
saved_emergency_access.save(&conn).await?
}
// Update reset password data
@@ -866,7 +865,7 @@ async fn post_rotatekey(data: Json<KeyData>, headers: Headers, conn: DbConn, nt:
};
membership.reset_password_key = Some(reset_password_data.reset_password_key);
membership.save(&conn).await?;
membership.save(&conn).await?
}
// Update send data
@@ -879,6 +878,8 @@ async fn post_rotatekey(data: Json<KeyData>, headers: Headers, conn: DbConn, nt:
}
// Update cipher data
use super::ciphers::update_cipher_from_data;
for cipher_data in data.account_data.ciphers {
if cipher_data.organization_id.is_none() {
let Some(saved_cipher) = existing_ciphers.iter_mut().find(|c| &c.uuid == cipher_data.id.as_ref().unwrap())
@@ -889,7 +890,7 @@ async fn post_rotatekey(data: Json<KeyData>, headers: Headers, conn: DbConn, nt:
// Prevent triggering cipher updates via WebSockets by settings UpdateType::None
// The user sessions are invalidated because all the ciphers were re-encrypted and thus triggering an update could cause issues.
// We force the users to logout after the user has been saved to try and prevent these issues.
update_cipher_from_data(saved_cipher, cipher_data, &headers, None, &conn, &nt, UpdateType::None).await?;
update_cipher_from_data(saved_cipher, cipher_data, &headers, None, &conn, &nt, UpdateType::None).await?
}
}
@@ -1019,22 +1020,24 @@ async fn post_email(data: Json<ChangeEmailData>, headers: Headers, conn: DbConn,
err!("Email already in use");
}
if let Some(ref val) = user.email_new {
if val != &data.new_email {
err!("Email change mismatch");
match user.email_new {
Some(ref val) => {
if val != &data.new_email {
err!("Email change mismatch");
}
}
} else {
err!("No email change pending")
None => err!("No email change pending"),
}
if CONFIG.mail_enabled() {
// Only check the token if we sent out an email...
if let Some(ref val) = user.email_new_token {
if *val != data.token.into_string() {
err!("Token mismatch");
match user.email_new_token {
Some(ref val) => {
if *val != data.token.into_string() {
err!("Token mismatch");
}
}
} else {
err!("No email change pending")
None => err!("No email change pending"),
}
user.verified_at = Some(Utc::now().naive_utc());
} else {
@@ -1111,10 +1114,10 @@ async fn post_delete_recover(data: Json<DeleteRecoverData>, conn: DbConn) -> Emp
let data: DeleteRecoverData = data.into_inner();
if CONFIG.mail_enabled() {
if let Some(user) = User::find_by_mail(&data.email, &conn).await
&& let Err(e) = mail::send_delete_account(&user.email, &user.uuid).await
{
error!("Error sending delete account email: {e:#?}");
if let Some(user) = User::find_by_mail(&data.email, &conn).await {
if let Err(e) = mail::send_delete_account(&user.email, &user.uuid).await {
error!("Error sending delete account email: {e:#?}");
}
}
Ok(())
} else {
@@ -1166,7 +1169,6 @@ async fn delete_account(data: Json<PasswordOrOtpData>, headers: Headers, conn: D
user.delete(&conn).await
}
#[expect(clippy::needless_pass_by_value, reason = "Not beneficial for Headers")]
#[get("/accounts/revision-date")]
fn revision_date(headers: Headers) -> JsonResult {
let revision_date = headers.user.updated_at.and_utc().timestamp_millis();
@@ -1181,12 +1183,12 @@ struct PasswordHintData {
#[post("/accounts/password-hint", data = "<data>")]
async fn password_hint(data: Json<PasswordHintData>, conn: DbConn) -> EmptyResult {
const NO_HINT: &str = "Sorry, you have no password hint...";
if !CONFIG.password_hints_allowed() || (!CONFIG.mail_enabled() && !CONFIG.show_password_hint()) {
err!("This server is not configured to provide password hints.");
}
const NO_HINT: &str = "Sorry, you have no password hint...";
let data: PasswordHintData = data.into_inner();
let email = &data.email;
@@ -1197,9 +1199,9 @@ async fn password_hint(data: Json<PasswordHintData>, conn: DbConn) -> EmptyResul
// There is still a timing side channel here in that the code
// paths that send mail take noticeably longer than ones that
// don't. Add a randomized sleep to mitigate this somewhat.
use rand::{RngExt, rngs::SmallRng};
use rand::{rngs::SmallRng, RngExt};
let mut rng: SmallRng = rand::make_rng();
let sleep_ms: u64 = rng.random_range(900..=1100);
let sleep_ms = rng.random_range(900..=1100) as u64;
tokio::time::sleep(tokio::time::Duration::from_millis(sleep_ms)).await;
Ok(())
} else {
@@ -1227,11 +1229,11 @@ pub struct PreloginData {
}
#[post("/accounts/prelogin", data = "<data>")]
async fn post_prelogin(data: Json<PreloginData>, conn: DbConn) -> Json<Value> {
prelogin(data, conn).await
async fn prelogin(data: Json<PreloginData>, conn: DbConn) -> Json<Value> {
_prelogin(data, conn).await
}
pub async fn prelogin(data: Json<PreloginData>, conn: DbConn) -> Json<Value> {
pub async fn _prelogin(data: Json<PreloginData>, conn: DbConn) -> Json<Value> {
let data: PreloginData = data.into_inner();
let (kdf_type, kdf_iter, kdf_mem, kdf_para) = match User::find_by_mail(&data.email, &conn).await {
@@ -1281,7 +1283,9 @@ async fn verify_password(data: Json<SecretVerificationRequest>, headers: Headers
Ok(Json(master_password_policy(&user, &conn).await))
}
async fn update_api_key(data: Json<PasswordOrOtpData>, rotate: bool, headers: Headers, conn: DbConn) -> JsonResult {
async fn _api_key(data: Json<PasswordOrOtpData>, rotate: bool, headers: Headers, conn: DbConn) -> JsonResult {
use crate::util::format_date;
let data: PasswordOrOtpData = data.into_inner();
let mut user = headers.user;
@@ -1300,13 +1304,13 @@ async fn update_api_key(data: Json<PasswordOrOtpData>, rotate: bool, headers: He
}
#[post("/accounts/api-key", data = "<data>")]
async fn post_api_key(data: Json<PasswordOrOtpData>, headers: Headers, conn: DbConn) -> JsonResult {
update_api_key(data, false, headers, conn).await
async fn api_key(data: Json<PasswordOrOtpData>, headers: Headers, conn: DbConn) -> JsonResult {
_api_key(data, false, headers, conn).await
}
#[post("/accounts/rotate-api-key", data = "<data>")]
async fn rotate_api_key(data: Json<PasswordOrOtpData>, headers: Headers, conn: DbConn) -> JsonResult {
update_api_key(data, true, headers, conn).await
_api_key(data, true, headers, conn).await
}
#[get("/devices/knowndevice")]
@@ -1349,7 +1353,7 @@ impl<'r> FromRequest<'r> for KnownDevice {
};
let uuid = if let Some(uuid) = req.headers().get_one("X-Device-Identifier") {
uuid.to_owned().into()
uuid.to_string().into()
} else {
return Outcome::Error((Status::BadRequest, "X-Device-Identifier value is required"));
};
@@ -1364,7 +1368,7 @@ impl<'r> FromRequest<'r> for KnownDevice {
#[get("/devices")]
async fn get_all_devices(headers: Headers, conn: DbConn) -> JsonResult {
let devices = Device::find_with_auth_request_by_user(&headers.user.uuid, &conn).await;
let devices = devices.iter().map(DeviceWithAuthRequest::to_json).collect::<Vec<Value>>();
let devices = devices.iter().map(|device| device.to_json()).collect::<Vec<Value>>();
Ok(Json(json!({
"data": devices,
@@ -1434,7 +1438,7 @@ async fn put_clear_device_token(device_id: DeviceId, conn: DbConn) -> EmptyResul
if let Some(device) = Device::find_by_uuid(&device_id, &conn).await {
Device::clear_push_token_by_uuid(&device_id, &conn).await?;
unregister_push_device(device.push_uuid.as_ref()).await?;
unregister_push_device(&device.push_uuid).await?;
}
Ok(())
@@ -1704,6 +1708,6 @@ pub async fn purge_auth_requests(pool: DbPool) {
if let Ok(conn) = pool.get().await {
AuthRequest::purge_expired_auth_requests(&conn).await;
} else {
error!("Failed to get DB connection while purging auth requests");
error!("Failed to get DB connection while purging auth requests")
}
}
+129 -296
View File
@@ -2,30 +2,30 @@ use std::collections::{HashMap, HashSet};
use chrono::{NaiveDateTime, Utc};
use num_traits::ToPrimitive;
use rocket::fs::TempFile;
use rocket::serde::json::Json;
use rocket::{
Route,
form::{Form, FromForm},
fs::TempFile,
serde::json::Json,
Route,
};
use serde_json::Value;
use crate::auth::ClientVersion;
use crate::util::{deser_opt_nonempty_str, save_temp_file, NumberOrString};
use crate::{
CONFIG,
api::{self, EmptyResult, JsonResult, Notify, PasswordOrOtpData, UpdateType, core::log_event},
auth::ClientVersion,
api::{self, core::log_event, EmptyResult, JsonResult, Notify, PasswordOrOtpData, UpdateType},
auth::{Headers, OrgIdGuard, OwnerHeaders},
config::PathType,
crypto,
db::{
DbConn, DbPool,
models::{
Archive, Attachment, AttachmentId, Cipher, CipherId, Collection, CollectionCipher, CollectionGroup,
CollectionId, CollectionUser, EventType, Favorite, Folder, FolderCipher, FolderId, Group, Membership,
MembershipType, OrgPolicy, OrgPolicyType, OrganizationId, RepromptType, Send, UserId,
Attachment, AttachmentId, Cipher, CipherId, Collection, CollectionCipher, CollectionGroup, CollectionId,
CollectionUser, EventType, Favorite, Folder, FolderCipher, FolderId, Group, Membership, MembershipType,
OrgPolicy, OrgPolicyType, OrganizationId, RepromptType, Send, UserId,
},
DbConn, DbPool,
},
util::{NumberOrString, deser_opt_nonempty_str, save_temp_file},
CONFIG,
};
use super::folders::FolderData;
@@ -96,10 +96,6 @@ pub fn routes() -> Vec<Route> {
post_collections_update,
post_collections_admin,
put_collections_admin,
archive_cipher_put,
archive_cipher_selected,
unarchive_cipher_put,
unarchive_cipher_selected,
]
}
@@ -108,7 +104,7 @@ pub async fn purge_trashed_ciphers(pool: DbPool) {
if let Ok(conn) = pool.get().await {
Cipher::purge_trash(&conn).await;
} else {
error!("Failed to get DB connection while purging trashed ciphers");
error!("Failed to get DB connection while purging trashed ciphers")
}
}
@@ -164,7 +160,7 @@ async fn sync(data: SyncData, headers: Headers, client_version: Option<ClientVer
let domains_json = if data.exclude_domains {
Value::Null
} else {
api::core::get_eq_domains(&headers, true).into_inner()
api::core::_get_eq_domains(&headers, true).into_inner()
};
// This is very similar to the the userDecryptionOptions sent in connect/token,
@@ -297,7 +293,6 @@ pub struct CipherData {
// when using older client versions, or if the operation doesn't involve
// updating an existing cipher.
last_known_revision_date: Option<String>,
archived_date: Option<String>,
}
#[derive(Debug, Deserialize)]
@@ -401,34 +396,20 @@ pub async fn update_cipher_from_data(
nt: &Notify<'_>,
ut: UpdateType,
) -> EmptyResult {
// Cleanup cipher data, like removing the 'Response' key.
// This key is somewhere generated during Javascript so no way for us this fix this.
// Also, upstream only retrieves keys they actually want to store, and thus skip the 'Response' key.
// We do not mind which data is in it, the keep our model more flexible when there are upstream changes.
// But, we at least know we do not need to store and return this specific key.
fn clean_cipher_data(mut json_data: Value) -> Value {
if json_data.is_array() {
json_data.as_array_mut().unwrap().iter_mut().for_each(|ref mut f| {
f.as_object_mut().unwrap().remove("response");
});
}
json_data
}
enforce_personal_ownership_policy(Some(&data), headers, conn).await?;
// Check that the client isn't updating an existing cipher with stale data.
// And only perform this check when not importing ciphers, else the date/time check will fail.
if ut != UpdateType::None
&& let Some(dt) = data.last_known_revision_date
{
match NaiveDateTime::parse_from_str(&dt, "%+") {
// ISO 8601 format
Err(err) => warn!("Error parsing LastKnownRevisionDate '{dt}': {err}"),
Ok(dt) if cipher.updated_at.signed_duration_since(dt).num_seconds() > 1 => {
err!("The client copy of this cipher is out of date. Resync the client and try again.")
if ut != UpdateType::None {
if let Some(dt) = data.last_known_revision_date {
match NaiveDateTime::parse_from_str(&dt, "%+") {
// ISO 8601 format
Err(err) => warn!("Error parsing LastKnownRevisionDate '{dt}': {err}"),
Ok(dt) if cipher.updated_at.signed_duration_since(dt).num_seconds() > 1 => {
err!("The client copy of this cipher is out of date. Resync the client and try again.")
}
Ok(_) => (),
}
Ok(_) => (),
}
}
@@ -470,22 +451,25 @@ pub async fn update_cipher_from_data(
cipher.user_uuid = Some(headers.user.uuid.clone());
}
if let Some(ref folder_id) = data.folder_id
&& Folder::find_by_uuid_and_user(folder_id, &headers.user.uuid, conn).await.is_none()
{
err!("Invalid folder", "Folder does not exist or belongs to another user");
if let Some(ref folder_id) = data.folder_id {
if Folder::find_by_uuid_and_user(folder_id, &headers.user.uuid, conn).await.is_none() {
err!("Invalid folder", "Folder does not exist or belongs to another user");
}
}
// Modify attachments name and keys when rotating
if let Some(attachments) = data.attachments2 {
for (id, attachment) in attachments {
let Some(mut saved_att) = Attachment::find_by_id(&id, conn).await else {
// Warn and continue here.
// A missing attachment means it was removed via an other client.
// Also the Desktop Client supports removing attachments and save an update afterwards.
// Bitwarden it self ignores these mismatches server side.
warn!("Attachment {id} doesn't exist");
continue;
let mut saved_att = match Attachment::find_by_id(&id, conn).await {
Some(att) => att,
None => {
// Warn and continue here.
// A missing attachment means it was removed via an other client.
// Also the Desktop Client supports removing attachments and save an update afterwards.
// Bitwarden it self ignores these mismatches server side.
warn!("Attachment {id} doesn't exist");
continue;
}
};
if saved_att.cipher_uuid != cipher.uuid {
@@ -502,6 +486,20 @@ pub async fn update_cipher_from_data(
}
}
// Cleanup cipher data, like removing the 'Response' key.
// This key is somewhere generated during Javascript so no way for us this fix this.
// Also, upstream only retrieves keys they actually want to store, and thus skip the 'Response' key.
// We do not mind which data is in it, the keep our model more flexible when there are upstream changes.
// But, we at least know we do not need to store and return this specific key.
fn _clean_cipher_data(mut json_data: Value) -> Value {
if json_data.is_array() {
json_data.as_array_mut().unwrap().iter_mut().for_each(|ref mut f| {
f.as_object_mut().unwrap().remove("response");
});
};
json_data
}
let type_data_opt = match data.r#type {
1 => data.login,
2 => data.secure_note,
@@ -511,22 +509,23 @@ pub async fn update_cipher_from_data(
_ => err!("Invalid type"),
};
let type_data = if let Some(mut data) = type_data_opt {
// Remove the 'Response' key from the base object.
data.as_object_mut().unwrap().remove("response");
// Remove the 'Response' key from every Uri.
if data["uris"].is_array() {
data["uris"] = clean_cipher_data(data["uris"].clone());
let type_data = match type_data_opt {
Some(mut data) => {
// Remove the 'Response' key from the base object.
data.as_object_mut().unwrap().remove("response");
// Remove the 'Response' key from every Uri.
if data["uris"].is_array() {
data["uris"] = _clean_cipher_data(data["uris"].clone());
}
data
}
data
} else {
err!("Data missing")
None => err!("Data missing"),
};
cipher.key = data.key;
cipher.name = data.name;
cipher.notes = data.notes;
cipher.fields = data.fields.map(|f| clean_cipher_data(f).to_string());
cipher.fields = data.fields.map(|f| _clean_cipher_data(f).to_string());
cipher.data = type_data.to_string();
cipher.password_history = data.password_history.map(|f| f.to_string());
cipher.reprompt = data.reprompt.filter(|r| *r == RepromptType::None as i32 || *r == RepromptType::Password as i32);
@@ -535,13 +534,6 @@ pub async fn update_cipher_from_data(
cipher.move_to_folder(data.folder_id, &headers.user.uuid, conn).await?;
cipher.set_favorite(data.favorite, &headers.user.uuid, conn).await?;
if let Some(dt_str) = data.archived_date {
match NaiveDateTime::parse_from_str(&dt_str, "%+") {
Ok(dt) => cipher.set_archived_at(dt, &headers.user.uuid, conn).await?,
Err(err) => warn!("Error parsing ArchivedDate '{dt_str}': {err}"),
}
}
if ut != UpdateType::None {
// Only log events for organizational ciphers
if let Some(org_id) = &cipher.organization_uuid {
@@ -608,7 +600,7 @@ async fn post_ciphers_import(data: Json<ImportData>, headers: Headers, conn: DbC
let existing_folders: HashSet<Option<FolderId>> =
Folder::find_by_user(&headers.user.uuid, &conn).await.into_iter().map(|f| Some(f.uuid)).collect();
let mut folders: Vec<FolderId> = Vec::with_capacity(data.folders.len());
for folder in data.folders {
for folder in data.folders.into_iter() {
let folder_id = if existing_folders.contains(&folder.id) {
folder.id.unwrap()
} else {
@@ -638,7 +630,7 @@ async fn post_ciphers_import(data: Json<ImportData>, headers: Headers, conn: DbC
let mut user = headers.user;
user.update_revision(&conn).await?;
nt.send_user_update(UpdateType::SyncVault, &user, headers.device.push_uuid.as_ref(), &conn).await;
nt.send_user_update(UpdateType::SyncVault, &user, &headers.device.push_uuid, &conn).await;
Ok(())
}
@@ -733,10 +725,10 @@ async fn put_cipher_partial(
err!("Cipher does not exist", "Cipher is not accessible for the current user")
}
if let Some(ref folder_id) = data.folder_id
&& Folder::find_by_uuid_and_user(folder_id, &headers.user.uuid, &conn).await.is_none()
{
err!("Invalid folder", "Folder does not exist or belongs to another user");
if let Some(ref folder_id) = data.folder_id {
if Folder::find_by_uuid_and_user(folder_id, &headers.user.uuid, &conn).await.is_none() {
err!("Invalid folder", "Folder does not exist or belongs to another user");
}
}
// Move cipher
@@ -810,16 +802,12 @@ async fn post_collections_update(
err!("Collection cannot be changed")
}
let Some(ref org_uuid) = cipher.organization_uuid else {
err!("Cipher is not owned by an organization")
};
let posted_collections = HashSet::<CollectionId>::from_iter(data.collection_ids);
let current_collections =
HashSet::<CollectionId>::from_iter(cipher.get_collections(headers.user.uuid.clone(), &conn).await);
for collection in posted_collections.symmetric_difference(&current_collections) {
match Collection::find_by_uuid_and_org(collection, org_uuid, &conn).await {
match Collection::find_by_uuid_and_org(collection, cipher.organization_uuid.as_ref().unwrap(), &conn).await {
None => err!("Invalid collection ID provided"),
Some(collection) => {
if collection.is_writable_by_user(&headers.user.uuid, &conn).await {
@@ -850,7 +838,7 @@ async fn post_collections_update(
log_event(
EventType::CipherUpdatedCollections as i32,
&cipher.uuid,
org_uuid,
&cipher.organization_uuid.clone().unwrap(),
&headers.user.uuid,
headers.device.atype,
&headers.ip.ip,
@@ -890,16 +878,12 @@ async fn post_collections_admin(
err!("Collection cannot be changed")
}
let Some(ref org_uuid) = cipher.organization_uuid else {
err!("Cipher is not owned by an organization")
};
let posted_collections = HashSet::<CollectionId>::from_iter(data.collection_ids);
let current_collections =
HashSet::<CollectionId>::from_iter(cipher.get_admin_collections(headers.user.uuid.clone(), &conn).await);
for collection in posted_collections.symmetric_difference(&current_collections) {
match Collection::find_by_uuid_and_org(collection, org_uuid, &conn).await {
match Collection::find_by_uuid_and_org(collection, cipher.organization_uuid.as_ref().unwrap(), &conn).await {
None => err!("Invalid collection ID provided"),
Some(collection) => {
if collection.is_writable_by_user(&headers.user.uuid, &conn).await {
@@ -930,7 +914,7 @@ async fn post_collections_admin(
log_event(
EventType::CipherUpdatedCollections as i32,
&cipher.uuid,
org_uuid,
&cipher.organization_uuid.unwrap(),
&headers.user.uuid,
headers.device.atype,
&headers.ip.ip,
@@ -1000,7 +984,7 @@ async fn put_cipher_share_selected(
err!("You must select at least one collection.")
}
for cipher in &data.ciphers {
for cipher in data.ciphers.iter() {
if cipher.id.is_none() {
err!("Request missing ids field")
}
@@ -1012,15 +996,16 @@ async fn put_cipher_share_selected(
collection_ids: data.collection_ids.clone(),
};
if let Some(id) = shared_cipher_data.cipher.id.take() {
share_cipher_by_uuid(&id, shared_cipher_data, &headers, &conn, &nt, Some(UpdateType::None)).await?
} else {
err!("Request missing ids field")
match shared_cipher_data.cipher.id.take() {
Some(id) => {
share_cipher_by_uuid(&id, shared_cipher_data, &headers, &conn, &nt, Some(UpdateType::None)).await?
}
None => err!("Request missing ids field"),
};
}
// Multi share actions do not send out a push for each cipher, we need to send a general sync here
nt.send_user_update(UpdateType::SyncCiphers, &headers.user, headers.device.push_uuid.as_ref(), &conn).await;
nt.send_user_update(UpdateType::SyncCiphers, &headers.user, &headers.device.push_uuid, &conn).await;
Ok(())
}
@@ -1033,14 +1018,15 @@ async fn share_cipher_by_uuid(
nt: &Notify<'_>,
override_ut: Option<UpdateType>,
) -> JsonResult {
let mut cipher = if let Some(cipher) = Cipher::find_by_uuid(cipher_id, conn).await {
if cipher.is_write_accessible_to_user(&headers.user.uuid, conn).await {
cipher
} else {
err!("Cipher is not write accessible")
let mut cipher = match Cipher::find_by_uuid(cipher_id, conn).await {
Some(cipher) => {
if cipher.is_write_accessible_to_user(&headers.user.uuid, conn).await {
cipher
} else {
err!("Cipher is not write accessible")
}
}
} else {
err!("Cipher doesn't exist")
None => err!("Cipher doesn't exist"),
};
let mut shared_to_collections = vec![];
@@ -1059,7 +1045,7 @@ async fn share_cipher_by_uuid(
}
}
}
}
};
// When LastKnownRevisionDate is None, it is a new cipher, so send CipherCreate.
// If there is an override, like when handling multiple items, we want to prevent a push notification for every single item
@@ -1257,10 +1243,10 @@ async fn save_attachment(
err!("Cipher is neither owned by a user nor an organization");
};
if let Some(size_limit) = size_limit
&& size > size_limit
{
err!("Attachment storage limit exceeded with this file");
if let Some(size_limit) = size_limit {
if size > size_limit {
err!("Attachment storage limit exceeded with this file");
}
}
let file_id = match &attachment {
@@ -1402,7 +1388,7 @@ async fn post_attachment_share(
conn: DbConn,
nt: Notify<'_>,
) -> JsonResult {
delete_cipher_attachment_by_id(&cipher_id, &attachment_id, &headers, &conn, &nt).await?;
_delete_cipher_attachment_by_id(&cipher_id, &attachment_id, &headers, &conn, &nt).await?;
post_attachment(cipher_id, data, headers, conn, nt).await
}
@@ -1436,7 +1422,7 @@ async fn delete_attachment(
conn: DbConn,
nt: Notify<'_>,
) -> JsonResult {
delete_cipher_attachment_by_id(&cipher_id, &attachment_id, &headers, &conn, &nt).await
_delete_cipher_attachment_by_id(&cipher_id, &attachment_id, &headers, &conn, &nt).await
}
#[delete("/ciphers/<cipher_id>/attachment/<attachment_id>/admin")]
@@ -1447,42 +1433,42 @@ async fn delete_attachment_admin(
conn: DbConn,
nt: Notify<'_>,
) -> JsonResult {
delete_cipher_attachment_by_id(&cipher_id, &attachment_id, &headers, &conn, &nt).await
_delete_cipher_attachment_by_id(&cipher_id, &attachment_id, &headers, &conn, &nt).await
}
#[post("/ciphers/<cipher_id>/delete")]
async fn delete_cipher_post(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::HardSingle, &nt).await
_delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::HardSingle, &nt).await
// permanent delete
}
#[post("/ciphers/<cipher_id>/delete-admin")]
async fn delete_cipher_post_admin(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::HardSingle, &nt).await
_delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::HardSingle, &nt).await
// permanent delete
}
#[put("/ciphers/<cipher_id>/delete")]
async fn delete_cipher_put(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::SoftSingle, &nt).await
_delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::SoftSingle, &nt).await
// soft delete
}
#[put("/ciphers/<cipher_id>/delete-admin")]
async fn delete_cipher_put_admin(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::SoftSingle, &nt).await
_delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::SoftSingle, &nt).await
// soft delete
}
#[delete("/ciphers/<cipher_id>")]
async fn delete_cipher(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::HardSingle, &nt).await
_delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::HardSingle, &nt).await
// permanent delete
}
#[delete("/ciphers/<cipher_id>/admin")]
async fn delete_cipher_admin(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::HardSingle, &nt).await
_delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::HardSingle, &nt).await
// permanent delete
}
@@ -1493,7 +1479,7 @@ async fn delete_cipher_selected(
conn: DbConn,
nt: Notify<'_>,
) -> EmptyResult {
delete_multiple_ciphers(data, headers, conn, CipherDeleteOptions::HardMulti, nt).await
_delete_multiple_ciphers(data, headers, conn, CipherDeleteOptions::HardMulti, nt).await
// permanent delete
}
@@ -1504,7 +1490,7 @@ async fn delete_cipher_selected_post(
conn: DbConn,
nt: Notify<'_>,
) -> EmptyResult {
delete_multiple_ciphers(data, headers, conn, CipherDeleteOptions::HardMulti, nt).await
_delete_multiple_ciphers(data, headers, conn, CipherDeleteOptions::HardMulti, nt).await
// permanent delete
}
@@ -1515,7 +1501,7 @@ async fn delete_cipher_selected_put(
conn: DbConn,
nt: Notify<'_>,
) -> EmptyResult {
delete_multiple_ciphers(data, headers, conn, CipherDeleteOptions::SoftMulti, nt).await
_delete_multiple_ciphers(data, headers, conn, CipherDeleteOptions::SoftMulti, nt).await
// soft delete
}
@@ -1526,7 +1512,7 @@ async fn delete_cipher_selected_admin(
conn: DbConn,
nt: Notify<'_>,
) -> EmptyResult {
delete_multiple_ciphers(data, headers, conn, CipherDeleteOptions::HardMulti, nt).await
_delete_multiple_ciphers(data, headers, conn, CipherDeleteOptions::HardMulti, nt).await
// permanent delete
}
@@ -1537,7 +1523,7 @@ async fn delete_cipher_selected_post_admin(
conn: DbConn,
nt: Notify<'_>,
) -> EmptyResult {
delete_multiple_ciphers(data, headers, conn, CipherDeleteOptions::HardMulti, nt).await
_delete_multiple_ciphers(data, headers, conn, CipherDeleteOptions::HardMulti, nt).await
// permanent delete
}
@@ -1548,18 +1534,18 @@ async fn delete_cipher_selected_put_admin(
conn: DbConn,
nt: Notify<'_>,
) -> EmptyResult {
delete_multiple_ciphers(data, headers, conn, CipherDeleteOptions::SoftMulti, nt).await
_delete_multiple_ciphers(data, headers, conn, CipherDeleteOptions::SoftMulti, nt).await
// soft delete
}
#[put("/ciphers/<cipher_id>/restore")]
async fn restore_cipher_put(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> JsonResult {
restore_cipher_by_uuid(&cipher_id, &headers, false, &conn, &nt).await
_restore_cipher_by_uuid(&cipher_id, &headers, false, &conn, &nt).await
}
#[put("/ciphers/<cipher_id>/restore-admin")]
async fn restore_cipher_put_admin(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> JsonResult {
restore_cipher_by_uuid(&cipher_id, &headers, false, &conn, &nt).await
_restore_cipher_by_uuid(&cipher_id, &headers, false, &conn, &nt).await
}
#[put("/ciphers/restore-admin", data = "<data>")]
@@ -1569,7 +1555,7 @@ async fn restore_cipher_selected_admin(
conn: DbConn,
nt: Notify<'_>,
) -> JsonResult {
restore_multiple_ciphers(data, &headers, &conn, &nt).await
_restore_multiple_ciphers(data, &headers, &conn, &nt).await
}
#[put("/ciphers/restore", data = "<data>")]
@@ -1579,7 +1565,7 @@ async fn restore_cipher_selected(
conn: DbConn,
nt: Notify<'_>,
) -> JsonResult {
restore_multiple_ciphers(data, &headers, &conn, &nt).await
_restore_multiple_ciphers(data, &headers, &conn, &nt).await
}
#[derive(Deserialize)]
@@ -1600,10 +1586,10 @@ async fn move_cipher_selected(
let data = data.into_inner();
let user_id = &headers.user.uuid;
if let Some(ref folder_id) = data.folder_id
&& Folder::find_by_uuid_and_user(folder_id, user_id, &conn).await.is_none()
{
err!("Invalid folder", "Folder does not exist or belongs to another user");
if let Some(ref folder_id) = data.folder_id {
if Folder::find_by_uuid_and_user(folder_id, user_id, &conn).await.is_none() {
err!("Invalid folder", "Folder does not exist or belongs to another user");
}
}
let cipher_count = data.ids.len();
@@ -1632,7 +1618,7 @@ async fn move_cipher_selected(
.await;
} else {
// Multi move actions do not send out a push for each cipher, we need to send a general sync here
nt.send_user_update(UpdateType::SyncCiphers, &headers.user, headers.device.push_uuid.as_ref(), &conn).await;
nt.send_user_update(UpdateType::SyncCiphers, &headers.user, &headers.device.push_uuid, &conn).await;
}
if cipher_count != accessible_ciphers_count {
@@ -1684,7 +1670,7 @@ async fn purge_org_vault(
match Membership::find_confirmed_by_user_and_org(&user.uuid, &organization.org_id, &conn).await {
Some(member) if member.atype == MembershipType::Owner => {
Cipher::delete_all_by_organization(&organization.org_id, &conn).await?;
nt.send_user_update(UpdateType::SyncVault, &user, headers.device.push_uuid.as_ref(), &conn).await;
nt.send_user_update(UpdateType::SyncVault, &user, &headers.device.push_uuid, &conn).await;
log_event(
EventType::OrganizationPurgedVault as i32,
@@ -1724,41 +1710,11 @@ async fn purge_personal_vault(
}
user.update_revision(&conn).await?;
nt.send_user_update(UpdateType::SyncVault, &user, headers.device.push_uuid.as_ref(), &conn).await;
nt.send_user_update(UpdateType::SyncVault, &user, &headers.device.push_uuid, &conn).await;
Ok(())
}
#[put("/ciphers/<cipher_id>/archive")]
async fn archive_cipher_put(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> JsonResult {
archive_cipher(&cipher_id, &headers, false, &conn, &nt).await
}
#[put("/ciphers/archive", data = "<data>")]
async fn archive_cipher_selected(
data: Json<CipherIdsData>,
headers: Headers,
conn: DbConn,
nt: Notify<'_>,
) -> JsonResult {
archive_multiple_ciphers(data, &headers, &conn, &nt).await
}
#[put("/ciphers/<cipher_id>/unarchive")]
async fn unarchive_cipher_put(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> JsonResult {
unarchive_cipher(&cipher_id, &headers, false, &conn, &nt).await
}
#[put("/ciphers/unarchive", data = "<data>")]
async fn unarchive_cipher_selected(
data: Json<CipherIdsData>,
headers: Headers,
conn: DbConn,
nt: Notify<'_>,
) -> JsonResult {
unarchive_multiple_ciphers(data, &headers, &conn, &nt).await
}
#[derive(PartialEq)]
pub enum CipherDeleteOptions {
SoftSingle,
@@ -1767,7 +1723,7 @@ pub enum CipherDeleteOptions {
HardMulti,
}
async fn delete_cipher_by_uuid(
async fn _delete_cipher_by_uuid(
cipher_id: &CipherId,
headers: &Headers,
conn: &DbConn,
@@ -1833,7 +1789,7 @@ struct CipherIdsData {
ids: Vec<CipherId>,
}
async fn delete_multiple_ciphers(
async fn _delete_multiple_ciphers(
data: Json<CipherIdsData>,
headers: Headers,
conn: DbConn,
@@ -1843,18 +1799,18 @@ async fn delete_multiple_ciphers(
let data = data.into_inner();
for cipher_id in data.ids {
if let error @ Err(_) = delete_cipher_by_uuid(&cipher_id, &headers, &conn, &delete_options, &nt).await {
if let error @ Err(_) = _delete_cipher_by_uuid(&cipher_id, &headers, &conn, &delete_options, &nt).await {
return error;
}
};
}
// Multi delete actions do not send out a push for each cipher, we need to send a general sync here
nt.send_user_update(UpdateType::SyncCiphers, &headers.user, headers.device.push_uuid.as_ref(), &conn).await;
nt.send_user_update(UpdateType::SyncCiphers, &headers.user, &headers.device.push_uuid, &conn).await;
Ok(())
}
async fn restore_cipher_by_uuid(
async fn _restore_cipher_by_uuid(
cipher_id: &CipherId,
headers: &Headers,
multi_restore: bool,
@@ -1900,7 +1856,7 @@ async fn restore_cipher_by_uuid(
Ok(Json(cipher.to_json(&headers.host, &headers.user.uuid, None, CipherSyncType::User, conn).await?))
}
async fn restore_multiple_ciphers(
async fn _restore_multiple_ciphers(
data: Json<CipherIdsData>,
headers: &Headers,
conn: &DbConn,
@@ -1910,14 +1866,14 @@ async fn restore_multiple_ciphers(
let mut ciphers: Vec<Value> = Vec::new();
for cipher_id in data.ids {
match restore_cipher_by_uuid(&cipher_id, headers, true, conn, nt).await {
match _restore_cipher_by_uuid(&cipher_id, headers, true, conn, nt).await {
Ok(json) => ciphers.push(json.into_inner()),
err => return err,
}
}
// Multi move actions do not send out a push for each cipher, we need to send a general sync here
nt.send_user_update(UpdateType::SyncCiphers, &headers.user, headers.device.push_uuid.as_ref(), conn).await;
nt.send_user_update(UpdateType::SyncCiphers, &headers.user, &headers.device.push_uuid, conn).await;
Ok(Json(json!({
"data": ciphers,
@@ -1926,7 +1882,7 @@ async fn restore_multiple_ciphers(
})))
}
async fn delete_cipher_attachment_by_id(
async fn _delete_cipher_attachment_by_id(
cipher_id: &CipherId,
attachment_id: &AttachmentId,
headers: &Headers,
@@ -1977,122 +1933,6 @@ async fn delete_cipher_attachment_by_id(
Ok(Json(json!({"cipher":cipher_json})))
}
async fn archive_cipher(
cipher_id: &CipherId,
headers: &Headers,
multi_archive: bool,
conn: &DbConn,
nt: &Notify<'_>,
) -> JsonResult {
let Some(cipher) = Cipher::find_by_uuid(cipher_id, conn).await else {
err!("Cipher doesn't exist")
};
if !cipher.is_accessible_to_user(&headers.user.uuid, conn).await {
err!("Cipher is not accessible for the current user")
}
cipher.set_archived_at(Utc::now().naive_utc(), &headers.user.uuid, conn).await?;
if !multi_archive {
nt.send_cipher_update(
UpdateType::SyncCipherUpdate,
&cipher,
&cipher.update_users_revision(conn).await,
&headers.device,
None,
conn,
)
.await;
}
Ok(Json(cipher.to_json(&headers.host, &headers.user.uuid, None, CipherSyncType::User, conn).await?))
}
async fn unarchive_cipher(
cipher_id: &CipherId,
headers: &Headers,
multi_unarchive: bool,
conn: &DbConn,
nt: &Notify<'_>,
) -> JsonResult {
let Some(cipher) = Cipher::find_by_uuid(cipher_id, conn).await else {
err!("Cipher doesn't exist")
};
if !cipher.is_accessible_to_user(&headers.user.uuid, conn).await {
err!("Cipher is not accessible for the current user")
}
cipher.unarchive(&headers.user.uuid, conn).await?;
if !multi_unarchive {
nt.send_cipher_update(
UpdateType::SyncCipherUpdate,
&cipher,
&cipher.update_users_revision(conn).await,
&headers.device,
None,
conn,
)
.await;
}
Ok(Json(cipher.to_json(&headers.host, &headers.user.uuid, None, CipherSyncType::User, conn).await?))
}
async fn archive_multiple_ciphers(
data: Json<CipherIdsData>,
headers: &Headers,
conn: &DbConn,
nt: &Notify<'_>,
) -> JsonResult {
let data = data.into_inner();
let mut ciphers: Vec<Value> = Vec::new();
for cipher_id in data.ids {
match archive_cipher(&cipher_id, headers, true, conn, nt).await {
Ok(json) => ciphers.push(json.into_inner()),
err => return err,
}
}
// Multi archive does not send out a push for each cipher, we need to send a general sync here
nt.send_user_update(UpdateType::SyncCiphers, &headers.user, headers.device.push_uuid.as_ref(), conn).await;
Ok(Json(json!({
"data": ciphers,
"object": "list",
"continuationToken": null
})))
}
async fn unarchive_multiple_ciphers(
data: Json<CipherIdsData>,
headers: &Headers,
conn: &DbConn,
nt: &Notify<'_>,
) -> JsonResult {
let data = data.into_inner();
let mut ciphers: Vec<Value> = Vec::new();
for cipher_id in data.ids {
match unarchive_cipher(&cipher_id, headers, true, conn, nt).await {
Ok(json) => ciphers.push(json.into_inner()),
err => return err,
}
}
// Multi unarchive does not send out a push for each cipher, we need to send a general sync here
nt.send_user_update(UpdateType::SyncCiphers, &headers.user, headers.device.push_uuid.as_ref(), conn).await;
Ok(Json(json!({
"data": ciphers,
"object": "list",
"continuationToken": null
})))
}
/// This will hold all the necessary data to improve a full sync of all the ciphers
/// It can be used during the `Cipher::to_json()` call.
/// It will prevent the so called N+1 SQL issue by running just a few queries which will hold all the data needed.
@@ -2102,7 +1942,6 @@ pub struct CipherSyncData {
pub cipher_folders: HashMap<CipherId, FolderId>,
pub cipher_favorites: HashSet<CipherId>,
pub cipher_collections: HashMap<CipherId, Vec<CollectionId>>,
pub cipher_archives: HashMap<CipherId, NaiveDateTime>,
pub members: HashMap<OrganizationId, Membership>,
pub user_collections: HashMap<CollectionId, CollectionUser>,
pub user_collections_groups: HashMap<CollectionId, CollectionGroup>,
@@ -2119,25 +1958,20 @@ impl CipherSyncData {
pub async fn new(user_id: &UserId, sync_type: CipherSyncType, conn: &DbConn) -> Self {
let cipher_folders: HashMap<CipherId, FolderId>;
let cipher_favorites: HashSet<CipherId>;
let cipher_archives: HashMap<CipherId, NaiveDateTime>;
match sync_type {
// User Sync supports Folders, Favorites, and Archives
// User Sync supports Folders and Favorites
CipherSyncType::User => {
// Generate a HashMap with the Cipher UUID as key and the Folder UUID as value
cipher_folders = FolderCipher::find_by_user(user_id, conn).await.into_iter().collect();
// Generate a HashSet of all the Cipher UUID's which are marked as favorite
cipher_favorites = Favorite::get_all_cipher_uuid_by_user(user_id, conn).await.into_iter().collect();
// Generate a HashMap with the Cipher UUID as key and the archived date time as value
cipher_archives = Archive::find_by_user(user_id, conn).await.into_iter().collect();
}
// Organization Sync does not support Folders, Favorites, or Archives.
// Organization Sync does not support Folders and Favorites.
// If these are set, it will cause issues in the web-vault.
CipherSyncType::Organization => {
cipher_folders = HashMap::with_capacity(0);
cipher_favorites = HashSet::with_capacity(0);
cipher_archives = HashMap::with_capacity(0);
}
}
@@ -2204,7 +2038,6 @@ impl CipherSyncData {
cipher_folders,
cipher_favorites,
cipher_collections,
cipher_archives,
members,
user_collections,
user_collections_groups,
+24 -28
View File
@@ -1,23 +1,23 @@
use chrono::{TimeDelta, Utc};
use rocket::{Route, serde::json::Json};
use rocket::{serde::json::Json, Route};
use serde_json::Value;
use crate::{
CONFIG,
api::{
EmptyResult, JsonResult,
core::{CipherSyncData, CipherSyncType},
EmptyResult, JsonResult,
},
auth::{Headers, decode_emergency_access_invite},
auth::{decode_emergency_access_invite, Headers},
db::{
DbConn, DbPool,
models::{
Cipher, EmergencyAccess, EmergencyAccessId, EmergencyAccessStatus, EmergencyAccessType, Invitation,
Membership, MembershipType, OrgPolicy, TwoFactor, User, UserId,
},
DbConn, DbPool,
},
mail,
util::NumberOrString,
CONFIG,
};
pub fn routes() -> Vec<Route> {
@@ -55,7 +55,7 @@ async fn get_contacts(headers: Headers, conn: DbConn) -> Json<Value> {
let mut emergency_access_list_json = Vec::with_capacity(emergency_access_list.len());
for ea in emergency_access_list {
if let Some(grantee) = ea.to_json_grantee_details(&conn).await {
emergency_access_list_json.push(grantee);
emergency_access_list_json.push(grantee)
}
}
@@ -89,14 +89,11 @@ async fn get_grantees(headers: Headers, conn: DbConn) -> Json<Value> {
async fn get_emergency_access(emer_id: EmergencyAccessId, headers: Headers, conn: DbConn) -> JsonResult {
check_emergency_access_enabled()?;
if let Some(emergency_access) =
EmergencyAccess::find_by_uuid_and_grantor_uuid(&emer_id, &headers.user.uuid, &conn).await
{
Ok(Json(
match EmergencyAccess::find_by_uuid_and_grantor_uuid(&emer_id, &headers.user.uuid, &conn).await {
Some(emergency_access) => Ok(Json(
emergency_access.to_json_grantee_details(&conn).await.expect("Grantee user should exist but does not!"),
))
} else {
err!("Emergency access not valid.")
)),
None => err!("Emergency access not valid."),
}
}
@@ -139,10 +136,9 @@ async fn post_emergency_access(
err!("Emergency access not valid.")
};
let new_type = if let Some(new_type) = EmergencyAccessType::from_str(&data.r#type.into_string()) {
new_type as i32
} else {
err!("Invalid emergency access type.")
let new_type = match EmergencyAccessType::from_str(&data.r#type.into_string()) {
Some(new_type) => new_type as i32,
None => err!("Invalid emergency access type."),
};
emergency_access.atype = new_type;
@@ -209,10 +205,9 @@ async fn send_invite(data: Json<EmergencyAccessInviteData>, headers: Headers, co
let emergency_access_status = EmergencyAccessStatus::Invited as i32;
let new_type = if let Some(new_type) = EmergencyAccessType::from_str(&data.r#type.into_string()) {
new_type as i32
} else {
err!("Invalid emergency access type.")
let new_type = match EmergencyAccessType::from_str(&data.r#type.into_string()) {
Some(new_type) => new_type as i32,
None => err!("Invalid emergency access type."),
};
let grantor_user = headers.user;
@@ -347,11 +342,12 @@ async fn accept_invite(
err!("Claim email does not match current users email")
}
let grantee_user = if let Some(user) = User::find_by_mail(&claims.email, &conn).await {
Invitation::take(&claims.email, &conn).await;
user
} else {
err!("Invited user not found")
let grantee_user = match User::find_by_mail(&claims.email, &conn).await {
Some(user) => {
Invitation::take(&claims.email, &conn).await;
user
}
None => err!("Invited user not found"),
};
// We need to search for the uuid in combination with the email, since we do not yet store the uuid of the grantee in the database.
@@ -770,7 +766,7 @@ pub async fn emergency_request_timeout_job(pool: DbPool) {
}
}
} else {
error!("Failed to get DB connection while searching emergency request timed out");
error!("Failed to get DB connection while searching emergency request timed out")
}
}
@@ -829,6 +825,6 @@ pub async fn emergency_notification_reminder_job(pool: DbPool) {
}
}
} else {
error!("Failed to get DB connection while searching emergency notification reminder");
error!("Failed to get DB connection while searching emergency notification reminder")
}
}
+60 -54
View File
@@ -1,18 +1,18 @@
use std::net::IpAddr;
use chrono::NaiveDateTime;
use rocket::{Route, form::FromForm, serde::json::Json};
use rocket::{form::FromForm, serde::json::Json, Route};
use serde_json::Value;
use crate::{
CONFIG,
api::{EmptyResult, JsonResult},
auth::{AdminHeaders, Headers},
db::{
DbConn, DbPool,
models::{Cipher, CipherId, Event, Membership, MembershipId, OrganizationId, UserId},
DbConn, DbPool,
},
util::parse_date,
CONFIG,
};
/// ###############################################################################################################
@@ -38,7 +38,9 @@ async fn get_org_events(org_id: OrganizationId, data: EventRange, headers: Admin
// Return an empty vec when we org events are disabled.
// This prevents client errors
let events_json: Vec<Value> = if CONFIG.org_events_enabled() {
let events_json: Vec<Value> = if !CONFIG.org_events_enabled() {
Vec::with_capacity(0)
} else {
let start_date = parse_date(&data.start);
let end_date = if let Some(before_date) = &data.continuation_token {
parse_date(before_date)
@@ -49,10 +51,8 @@ async fn get_org_events(org_id: OrganizationId, data: EventRange, headers: Admin
Event::find_by_organization_uuid(&org_id, &start_date, &end_date, &conn)
.await
.iter()
.map(Event::to_json)
.map(|e| e.to_json())
.collect()
} else {
Vec::with_capacity(0)
};
Ok(Json(json!({
@@ -64,21 +64,27 @@ async fn get_org_events(org_id: OrganizationId, data: EventRange, headers: Admin
#[get("/ciphers/<cipher_id>/events?<data..>")]
async fn get_cipher_events(cipher_id: CipherId, data: EventRange, headers: Headers, conn: DbConn) -> JsonResult {
// Return an empty vec when org events are disabled.
// Return an empty vec when we org events are disabled.
// This prevents client errors
let events_json: Vec<Value> = if CONFIG.org_events_enabled()
&& Membership::user_has_ge_admin_access_to_cipher(&headers.user.uuid, &cipher_id, &conn).await
{
let start_date = parse_date(&data.start);
let end_date = if let Some(before_date) = &data.continuation_token {
parse_date(before_date)
} else {
parse_date(&data.end)
};
Event::find_by_cipher_uuid(&cipher_id, &start_date, &end_date, &conn).await.iter().map(Event::to_json).collect()
} else {
let events_json: Vec<Value> = if !CONFIG.org_events_enabled() {
Vec::with_capacity(0)
} else {
let mut events_json = Vec::with_capacity(0);
if Membership::user_has_ge_admin_access_to_cipher(&headers.user.uuid, &cipher_id, &conn).await {
let start_date = parse_date(&data.start);
let end_date = if let Some(before_date) = &data.continuation_token {
parse_date(before_date)
} else {
parse_date(&data.end)
};
events_json = Event::find_by_cipher_uuid(&cipher_id, &start_date, &end_date, &conn)
.await
.iter()
.map(|e| e.to_json())
.collect()
}
events_json
};
Ok(Json(json!({
@@ -101,7 +107,9 @@ async fn get_user_events(
}
// Return an empty vec when we org events are disabled.
// This prevents client errors
let events_json: Vec<Value> = if CONFIG.org_events_enabled() {
let events_json: Vec<Value> = if !CONFIG.org_events_enabled() {
Vec::with_capacity(0)
} else {
let start_date = parse_date(&data.start);
let end_date = if let Some(before_date) = &data.continuation_token {
parse_date(before_date)
@@ -112,10 +120,8 @@ async fn get_user_events(
Event::find_by_org_and_member(&org_id, &member_id, &start_date, &end_date, &conn)
.await
.iter()
.map(Event::to_json)
.map(|e| e.to_json())
.collect()
} else {
Vec::with_capacity(0)
};
Ok(Json(json!({
@@ -128,8 +134,7 @@ async fn get_user_events(
fn get_continuation_token(events_json: &[Value]) -> Option<&str> {
// When the length of the vec equals the max page_size there probably is more data
// When it is less, then all events are loaded.
#[expect(clippy::cast_possible_truncation, reason = "PAGE_SIZE fits within usize")]
if events_json.len() == Event::PAGE_SIZE as usize {
if events_json.len() as i64 == Event::PAGE_SIZE {
if let Some(last_event) = events_json.last() {
last_event["date"].as_str()
} else {
@@ -171,7 +176,7 @@ async fn post_events_collect(data: Json<Vec<EventCollection>>, headers: Headers,
let event_date = parse_date(&event.date);
match event.r#type {
1000..=1099 => {
log_user_event_impl(
_log_user_event(
event.r#type,
&headers.user.uuid,
headers.device.atype,
@@ -183,7 +188,7 @@ async fn post_events_collect(data: Json<Vec<EventCollection>>, headers: Headers,
}
1600..=1699 => {
if let Some(org_id) = &event.organization_id {
log_event_impl(
_log_event(
event.r#type,
org_id,
org_id,
@@ -197,21 +202,22 @@ async fn post_events_collect(data: Json<Vec<EventCollection>>, headers: Headers,
}
}
_ => {
if let Some(cipher_uuid) = &event.cipher_id
&& let Some(cipher) = Cipher::find_by_uuid(cipher_uuid, &conn).await
&& let Some(org_id) = cipher.organization_uuid
{
log_event_impl(
event.r#type,
cipher_uuid,
&org_id,
&headers.user.uuid,
headers.device.atype,
Some(event_date),
&headers.ip.ip,
&conn,
)
.await;
if let Some(cipher_uuid) = &event.cipher_id {
if let Some(cipher) = Cipher::find_by_uuid(cipher_uuid, &conn).await {
if let Some(org_id) = cipher.organization_uuid {
_log_event(
event.r#type,
cipher_uuid,
&org_id,
&headers.user.uuid,
headers.device.atype,
Some(event_date),
&headers.ip.ip,
&conn,
)
.await;
}
}
}
}
}
@@ -223,10 +229,10 @@ pub async fn log_user_event(event_type: i32, user_id: &UserId, device_type: i32,
if !CONFIG.org_events_enabled() {
return;
}
log_user_event_impl(event_type, user_id, device_type, None, ip, conn).await;
_log_user_event(event_type, user_id, device_type, None, ip, conn).await;
}
async fn log_user_event_impl(
async fn _log_user_event(
event_type: i32,
user_id: &UserId,
device_type: i32,
@@ -272,11 +278,11 @@ pub async fn log_event(
if !CONFIG.org_events_enabled() {
return;
}
log_event_impl(event_type, source_uuid, org_id, act_user_id, device_type, None, ip, conn).await;
_log_event(event_type, source_uuid, org_id, act_user_id, device_type, None, ip, conn).await;
}
#[expect(clippy::too_many_arguments)]
async fn log_event_impl(
#[allow(clippy::too_many_arguments)]
async fn _log_event(
event_type: i32,
source_uuid: &str,
org_id: &OrganizationId,
@@ -292,24 +298,24 @@ async fn log_event_impl(
// 1000..=1099 Are user events, they need to be logged via log_user_event()
// Cipher Events
1100..=1199 => {
event.cipher_uuid = Some(source_uuid.to_owned().into());
event.cipher_uuid = Some(source_uuid.to_string().into());
}
// Collection Events
1300..=1399 => {
event.collection_uuid = Some(source_uuid.to_owned().into());
event.collection_uuid = Some(source_uuid.to_string().into());
}
// Group Events
1400..=1499 => {
event.group_uuid = Some(source_uuid.to_owned().into());
event.group_uuid = Some(source_uuid.to_string().into());
}
// Org User Events
1500..=1599 => {
event.org_user_uuid = Some(source_uuid.to_owned().into());
event.org_user_uuid = Some(source_uuid.to_string().into());
}
// 1600..=1699 Are organizational events, and they do not need the source_uuid
// Policy Events
1700..=1799 => {
event.policy_uuid = Some(source_uuid.to_owned().into());
event.policy_uuid = Some(source_uuid.to_string().into());
}
// Ignore others
_ => {}
@@ -332,6 +338,6 @@ pub async fn event_cleanup_job(pool: DbPool) {
if let Ok(conn) = pool.get().await {
Event::clean_events(&conn).await.ok();
} else {
error!("Failed to get DB connection while trying to cleanup the events table");
error!("Failed to get DB connection while trying to cleanup the events table")
}
}
+4 -5
View File
@@ -5,8 +5,8 @@ use crate::{
api::{EmptyResult, JsonResult, Notify, UpdateType},
auth::Headers,
db::{
DbConn,
models::{Folder, FolderId},
DbConn,
},
util::deser_opt_nonempty_str,
};
@@ -29,10 +29,9 @@ async fn get_folders(headers: Headers, conn: DbConn) -> Json<Value> {
#[get("/folders/<folder_id>")]
async fn get_folder(folder_id: FolderId, headers: Headers, conn: DbConn) -> JsonResult {
if let Some(folder) = Folder::find_by_uuid_and_user(&folder_id, &headers.user.uuid, &conn).await {
Ok(Json(folder.to_json()))
} else {
err!("Invalid folder", "Folder does not exist or belongs to another user")
match Folder::find_by_uuid_and_user(&folder_id, &headers.user.uuid, &conn).await {
Some(folder) => Ok(Json(folder.to_json())),
_ => err!("Invalid folder", "Folder does not exist or belongs to another user"),
}
}
+40 -50
View File
@@ -1,6 +1,4 @@
pub mod accounts;
pub mod two_factor;
mod ciphers;
mod emergency_access;
mod events;
@@ -8,32 +6,17 @@ mod folders;
mod organizations;
mod public;
mod sends;
pub mod two_factor;
pub use accounts::purge_auth_requests;
pub use ciphers::{CipherData, CipherSyncData, CipherSyncType, purge_trashed_ciphers};
pub use ciphers::{purge_trashed_ciphers, CipherData, CipherSyncData, CipherSyncType};
pub use emergency_access::{emergency_notification_reminder_job, emergency_request_timeout_job};
pub use events::{event_cleanup_job, log_event, log_user_event};
use reqwest::Method;
pub use sends::purge_sends;
use reqwest::Method;
use rocket::{Catcher, Route, serde::json::Json, serde::json::Value};
use crate::{
CONFIG,
api::{EmptyResult, JsonResult, Notify, UpdateType},
auth::Headers,
db::{
DbConn,
models::{Membership, MembershipStatus, OrgPolicy, Organization, User},
},
error::Error,
http_client::make_http_request,
mail,
util::{FeatureFlagFilter, parse_experimental_client_feature_flags},
};
pub fn routes() -> Vec<Route> {
let mut eq_domains_routes = routes![get_settings_domains, post_settings_domains, put_settings_domains];
let mut eq_domains_routes = routes![get_eq_domains, post_eq_domains, put_eq_domains];
let mut hibp_routes = routes![hibp_breach];
let mut meta_routes = routes![alive, now, version, config, get_api_webauthn];
@@ -61,6 +44,25 @@ pub fn events_routes() -> Vec<Route> {
routes
}
//
// Move this somewhere else
//
use rocket::{serde::json::Json, serde::json::Value, Catcher, Route};
use crate::{
api::{EmptyResult, JsonResult, Notify, UpdateType},
auth::Headers,
db::{
models::{Membership, MembershipStatus, OrgPolicy, Organization, User},
DbConn,
},
error::Error,
http_client::make_http_request,
mail,
util::{parse_experimental_client_feature_flags, FeatureFlagFilter},
CONFIG,
};
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GlobalDomain {
@@ -71,16 +73,14 @@ struct GlobalDomain {
const GLOBAL_DOMAINS: &str = include_str!("../../static/global_domains.json");
#[expect(clippy::needless_pass_by_value, reason = "Not beneficial for Headers")]
#[get("/settings/domains")]
fn get_settings_domains(headers: Headers) -> Json<Value> {
get_eq_domains(&headers, false)
fn get_eq_domains(headers: Headers) -> Json<Value> {
_get_eq_domains(&headers, false)
}
fn get_eq_domains(headers: &Headers, no_excluded: bool) -> Json<Value> {
use serde_json::from_str;
fn _get_eq_domains(headers: &Headers, no_excluded: bool) -> Json<Value> {
let user = &headers.user;
use serde_json::from_str;
let equivalent_domains: Vec<Vec<String>> = from_str(&user.equivalent_domains).unwrap();
let excluded_globals: Vec<i32> = from_str(&user.excluded_globals).unwrap();
@@ -110,39 +110,28 @@ struct EquivDomainData {
}
#[post("/settings/domains", data = "<data>")]
async fn post_settings_domains(
data: Json<EquivDomainData>,
headers: Headers,
conn: DbConn,
nt: Notify<'_>,
) -> JsonResult {
use serde_json::to_string;
async fn post_eq_domains(data: Json<EquivDomainData>, headers: Headers, conn: DbConn, nt: Notify<'_>) -> JsonResult {
let data: EquivDomainData = data.into_inner();
let excluded_globals = data.excluded_global_equivalent_domains.unwrap_or_default();
let equivalent_domains = data.equivalent_domains.unwrap_or_default();
let mut user = headers.user;
use serde_json::to_string;
user.excluded_globals = to_string(&excluded_globals).unwrap_or_else(|_| "[]".to_owned());
user.equivalent_domains = to_string(&equivalent_domains).unwrap_or_else(|_| "[]".to_owned());
user.excluded_globals = to_string(&excluded_globals).unwrap_or_else(|_| "[]".to_string());
user.equivalent_domains = to_string(&equivalent_domains).unwrap_or_else(|_| "[]".to_string());
user.save(&conn).await?;
nt.send_user_update(UpdateType::SyncSettings, &user, headers.device.push_uuid.as_ref(), &conn).await;
nt.send_user_update(UpdateType::SyncSettings, &user, &headers.device.push_uuid, &conn).await;
Ok(Json(json!({})))
}
#[put("/settings/domains", data = "<data>")]
async fn put_settings_domains(
data: Json<EquivDomainData>,
headers: Headers,
conn: DbConn,
nt: Notify<'_>,
) -> JsonResult {
post_settings_domains(data, headers, conn, nt).await
async fn put_eq_domains(data: Json<EquivDomainData>, headers: Headers, conn: DbConn, nt: Notify<'_>) -> JsonResult {
post_eq_domains(data, headers, conn, nt).await
}
#[get("/hibp/breach?<username>")]
@@ -215,11 +204,11 @@ fn config() -> Json<Value> {
// Client (v2026.2.1): https://github.com/bitwarden/clients/blob/f96380c3138291a028bdd2c7a5fee540d5c98ba5/libs/common/src/enums/feature-flag.enum.ts#L12
// Android (v2026.2.1): https://github.com/bitwarden/android/blob/6902c19c0093fa476bbf74ccaa70c9f14afbb82f/core/src/main/kotlin/com/bitwarden/core/data/manager/model/FlagKey.kt#L31
// iOS (v2026.2.1): https://github.com/bitwarden/ios/blob/cdd9ba1770ca2ffc098d02d12cc3208e3a830454/BitwardenShared/Core/Platform/Models/Enum/FeatureFlag.swift#L7
let mut feature_states = parse_experimental_client_feature_flags(
let feature_states = parse_experimental_client_feature_flags(
&CONFIG.experimental_client_feature_flags(),
&FeatureFlagFilter::ValidOnly,
FeatureFlagFilter::ValidOnly,
);
feature_states.insert("pm-19148-innovation-archive".to_owned(), true);
// Add default feature_states here if needed, currently no features are needed by default.
Json(json!({
// Note: The clients use this version to handle backwards compatibility concerns
@@ -289,8 +278,9 @@ async fn accept_org_invite(
member.save(conn).await?;
if CONFIG.mail_enabled() {
let Some(org) = Organization::find_by_uuid(&member.org_uuid, conn).await else {
err!("Organization not found.")
let org = match Organization::find_by_uuid(&member.org_uuid, conn).await {
Some(org) => org,
None => err!("Organization not found."),
};
// User was invited to an organization, so they must be confirmed manually after acceptance
mail::send_invite_accepted(&user.email, &member.invited_by_email.unwrap_or(org.billing_email), &org.name)
+168 -187
View File
@@ -1,28 +1,27 @@
use num_traits::FromPrimitive;
use rocket::serde::json::Json;
use rocket::Route;
use serde_json::Value;
use std::collections::{HashMap, HashSet};
use num_traits::FromPrimitive;
use rocket::{Route, serde::json::Json};
use serde_json::Value;
use crate::api::admin::FAKE_ADMIN_UUID;
use crate::{
CONFIG,
api::admin::FAKE_ADMIN_UUID,
api::{
core::{accept_org_invite, log_event, two_factor, CipherSyncData, CipherSyncType},
EmptyResult, JsonResult, Notify, PasswordOrOtpData, UpdateType,
core::{CipherSyncData, CipherSyncType, accept_org_invite, log_event, two_factor},
},
auth::{AdminHeaders, Headers, ManagerHeaders, ManagerHeadersLoose, OrgMemberHeaders, OwnerHeaders, decode_invite},
auth::{decode_invite, AdminHeaders, Headers, ManagerHeaders, ManagerHeadersLoose, OrgMemberHeaders, OwnerHeaders},
db::{
DbConn,
models::{
Cipher, CipherId, Collection, CollectionCipher, CollectionGroup, CollectionId, CollectionUser, EventType,
Group, GroupId, GroupUser, Invitation, Membership, MembershipId, MembershipStatus, MembershipType,
OrgPolicy, OrgPolicyType, Organization, OrganizationApiKey, OrganizationId, User, UserId,
},
DbConn,
},
mail,
sso::FAKE_SSO_IDENTIFIER,
util::{NumberOrString, convert_json_key_lcase_first},
util::{convert_json_key_lcase_first, get_uuid, NumberOrString},
CONFIG,
};
pub fn routes() -> Vec<Route> {
@@ -65,7 +64,6 @@ pub fn routes() -> Vec<Route> {
post_org_import,
list_policies,
list_policies_token,
get_dummy_master_password_policy,
get_master_password_policy,
get_policy,
put_policy,
@@ -78,7 +76,6 @@ pub fn routes() -> Vec<Route> {
revoke_member,
bulk_revoke_members,
restore_member,
restore_member_vnext,
bulk_restore_members,
get_groups,
get_groups_details,
@@ -97,12 +94,11 @@ pub fn routes() -> Vec<Route> {
get_reset_password_details,
put_reset_password,
get_org_export,
post_api_key,
api_key,
rotate_api_key,
get_billing_metadata,
get_billing_warnings,
get_auto_enroll_status,
get_self_host_billing_metadata,
]
}
@@ -286,10 +282,9 @@ async fn get_organization(org_id: OrganizationId, headers: OwnerHeaders, conn: D
if org_id != headers.org_id {
err!("Organization not found", "Organization id's do not match");
}
if let Some(organization) = Organization::find_by_uuid(&org_id, &conn).await {
Ok(Json(organization.to_json()))
} else {
err!("Can't find organization details")
match Organization::find_by_uuid(&org_id, &conn).await {
Some(organization) => Ok(Json(organization.to_json())),
None => err!("Can't find organization details"),
}
}
@@ -358,7 +353,7 @@ async fn get_user_collections(headers: Headers, conn: DbConn) -> Json<Value> {
// The returned `Id` will then be passed to `get_master_password_policy` which will mainly ignore it
#[get("/organizations/<identifier>/auto-enroll-status")]
async fn get_auto_enroll_status(identifier: &str, headers: Headers, conn: DbConn) -> JsonResult {
let org = if identifier == FAKE_SSO_IDENTIFIER {
let org = if identifier == crate::sso::FAKE_IDENTIFIER {
match Membership::find_main_user_org(&headers.user.uuid, &conn).await {
Some(member) => Organization::find_by_uuid(&member.org_uuid, &conn).await,
None => None,
@@ -368,7 +363,7 @@ async fn get_auto_enroll_status(identifier: &str, headers: Headers, conn: DbConn
};
let (id, identifier, rp_auto_enroll) = match org {
None => (identifier.to_owned(), identifier.to_owned(), false),
None => (get_uuid(), identifier.to_string(), false),
Some(org) => (
org.uuid.to_string(),
org.uuid.to_string(),
@@ -394,7 +389,7 @@ async fn get_org_collections(org_id: OrganizationId, headers: ManagerHeadersLoos
}
Ok(Json(json!({
"data": get_org_collections_impl(&org_id, &conn).await,
"data": _get_org_collections(&org_id, &conn).await,
"object": "list",
"continuationToken": null,
})))
@@ -466,7 +461,7 @@ async fn get_org_collections_details(org_id: OrganizationId, headers: ManagerHea
CollectionGroup::find_by_collection(&col.uuid, &conn)
.await
.iter()
.map(CollectionGroup::to_json_details_for_group)
.map(|collection_group| collection_group.to_json_details_for_group())
.collect()
} else {
Vec::with_capacity(0)
@@ -478,7 +473,7 @@ async fn get_org_collections_details(org_id: OrganizationId, headers: ManagerHea
json_object["groups"] = json!(groups);
json_object["object"] = json!("collectionAccessDetails");
json_object["unmanaged"] = json!(false);
data.push(json_object);
data.push(json_object)
}
Ok(Json(json!({
@@ -488,7 +483,7 @@ async fn get_org_collections_details(org_id: OrganizationId, headers: ManagerHea
})))
}
async fn get_org_collections_impl(org_id: &OrganizationId, conn: &DbConn) -> Value {
async fn _get_org_collections(org_id: &OrganizationId, conn: &DbConn) -> Value {
Collection::find_by_organization(org_id, conn).await.iter().map(Collection::to_json).collect::<Value>()
}
@@ -574,7 +569,7 @@ async fn post_bulk_access_collections(
if Organization::find_by_uuid(&org_id, &conn).await.is_none() {
err!("Can't find organization details")
}
};
for col_id in data.collection_ids {
let Some(collection) = Collection::find_by_uuid_and_org(&col_id, &org_id, &conn).await else {
@@ -651,7 +646,7 @@ async fn post_organization_collection_update(
if Organization::find_by_uuid(&org_id, &conn).await.is_none() {
err!("Can't find organization details")
}
};
let Some(mut collection) = Collection::find_by_uuid_and_org(&col_id, &org_id, &conn).await else {
err!("Collection not found")
@@ -702,7 +697,7 @@ async fn post_organization_collection_update(
Ok(Json(collection.to_json_details(&headers.user.uuid, None, &conn).await))
}
async fn delete_organization_collection_impl(
async fn _delete_organization_collection(
org_id: &OrganizationId,
col_id: &CollectionId,
headers: &ManagerHeaders,
@@ -734,7 +729,7 @@ async fn delete_organization_collection(
headers: ManagerHeaders,
conn: DbConn,
) -> EmptyResult {
delete_organization_collection_impl(&org_id, &col_id, &headers, &conn).await
_delete_organization_collection(&org_id, &col_id, &headers, &conn).await
}
#[post("/organizations/<org_id>/collections/<col_id>/delete")]
@@ -744,7 +739,7 @@ async fn post_organization_collection_delete(
headers: ManagerHeaders,
conn: DbConn,
) -> EmptyResult {
delete_organization_collection_impl(&org_id, &col_id, &headers, &conn).await
_delete_organization_collection(&org_id, &col_id, &headers, &conn).await
}
#[derive(Deserialize, Debug)]
@@ -770,7 +765,7 @@ async fn bulk_delete_organization_collections(
let headers = ManagerHeaders::from_loose(headers, &collections, &conn).await?;
for col_id in collections {
delete_organization_collection_impl(&org_id, &col_id, &headers, &conn).await?;
_delete_organization_collection(&org_id, &col_id, &headers, &conn).await?
}
Ok(())
}
@@ -800,7 +795,7 @@ async fn get_org_collection_detail(
CollectionGroup::find_by_collection(&collection.uuid, &conn)
.await
.iter()
.map(CollectionGroup::to_json_details_for_group)
.map(|collection_group| collection_group.to_json_details_for_group())
.collect()
} else {
// The Bitwarden clients seem to call this API regardless of whether groups are enabled,
@@ -887,13 +882,13 @@ async fn get_org_details(data: OrgIdData, headers: ManagerHeadersLoose, conn: Db
}
Ok(Json(json!({
"data": get_org_details_impl(&data.organization_id, &headers.host, &headers.user.uuid, &conn).await?,
"data": _get_org_details(&data.organization_id, &headers.host, &headers.user.uuid, &conn).await?,
"object": "list",
"continuationToken": null,
})))
}
async fn get_org_details_impl(
async fn _get_org_details(
org_id: &OrganizationId,
host: &str,
user_id: &UserId,
@@ -909,21 +904,36 @@ async fn get_org_details_impl(
Ok(json!(ciphers_json))
}
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct OrgDomainDetails {
email: String,
}
// Returning a Domain/Organization here allow to prefill it and prevent prompting the user
// So we return a dummy value, since we only support a single SSO integration, and do not use the response anywhere
// So we either return an Org name associated to the user or a dummy value.
// In use since `v2025.6.0`, appears to use only the first `organizationIdentifier`
#[post("/organizations/domain/sso/verified")]
fn get_org_domain_sso_verified() -> JsonResult {
// Always return a dummy value, no matter if SSO is enabled or not
#[post("/organizations/domain/sso/verified", data = "<data>")]
async fn get_org_domain_sso_verified(data: Json<OrgDomainDetails>, conn: DbConn) -> JsonResult {
let data: OrgDomainDetails = data.into_inner();
let identifiers = match Organization::find_org_user_email(&data.email, &conn)
.await
.into_iter()
.map(|o| (o.name, o.uuid.to_string()))
.collect::<Vec<(String, String)>>()
{
v if !v.is_empty() => v,
_ => vec![(crate::sso::FAKE_IDENTIFIER.to_string(), crate::sso::FAKE_IDENTIFIER.to_string())],
};
Ok(Json(json!({
"object": "list",
"data": [{
"organizationIdentifier": FAKE_SSO_IDENTIFIER,
// These appear to be unused
"organizationName": FAKE_SSO_IDENTIFIER,
"domainName": CONFIG.domain()
}],
"continuationToken": null
"data": identifiers.into_iter().map(|(name, identifier)| json!({
"organizationName": name, // appear unused
"organizationIdentifier": identifier,
"domainName": CONFIG.domain(), // appear unused
})).collect::<Vec<Value>>()
})))
}
@@ -976,13 +986,14 @@ async fn post_org_keys(
}
let data: OrgKeyData = data.into_inner();
let mut org = if let Some(organization) = Organization::find_by_uuid(&org_id, &conn).await {
if organization.private_key.is_some() && organization.public_key.is_some() {
err!("Organization Keys already exist")
let mut org = match Organization::find_by_uuid(&org_id, &conn).await {
Some(organization) => {
if organization.private_key.is_some() && organization.public_key.is_some() {
err!("Organization Keys already exist")
}
organization
}
organization
} else {
err!("Can't find organization details")
None => err!("Can't find organization details"),
};
org.private_key = Some(data.encrypted_private_key);
@@ -1043,10 +1054,9 @@ async fn send_invite(
// The from_str() will convert the custom role type into a manager role type
let raw_type = &data.r#type.into_string();
// Membership::from_str will convert custom (4) to manager (3)
let new_type = if let Some(new_type) = MembershipType::from_str(raw_type) {
new_type as i32
} else {
err!("Invalid type")
let new_type = match MembershipType::from_str(raw_type) {
Some(new_type) => new_type as i32,
None => err!("Invalid type"),
};
if new_type != MembershipType::User && headers.membership_type != MembershipType::Owner {
@@ -1063,7 +1073,7 @@ async fn send_invite(
&& data.permissions.get("createNewCollections") == Some(&json!(true)));
let mut user_created: bool = false;
for email in &data.emails {
for email in data.emails.iter() {
let mut member_status = MembershipStatus::Invited as i32;
let user = match User::find_by_mail(email, &conn).await {
None => {
@@ -1087,13 +1097,13 @@ async fn send_invite(
Some(user) => {
if Membership::find_by_user_and_org(&user.uuid, &org_id, &conn).await.is_some() {
err!(format!("User already in organization: {email}"))
} else {
// automatically accept existing users if mail is disabled
if !CONFIG.mail_enabled() && !user.password_hash.is_empty() {
member_status = MembershipStatus::Accepted as i32;
}
user
}
// automatically accept existing users if mail is disabled
if !CONFIG.mail_enabled() && !user.password_hash.is_empty() {
member_status = MembershipStatus::Accepted as i32;
}
user
}
};
@@ -1104,10 +1114,9 @@ async fn send_invite(
new_member.save(&conn).await?;
if CONFIG.mail_enabled() {
let org_name = if let Some(org) = Organization::find_by_uuid(&org_id, &conn).await {
org.name
} else {
err!("Error looking up organization")
let org_name = match Organization::find_by_uuid(&org_id, &conn).await {
Some(org) => org.name,
None => err!("Error looking up organization"),
};
if let Err(e) = mail::send_invite(
@@ -1161,7 +1170,7 @@ async fn send_invite(
}
}
for group_id in &data.groups {
for group_id in data.groups.iter() {
let mut group_entry = GroupUser::new(group_id.clone(), new_member.uuid.clone());
group_entry.save(&conn).await?;
}
@@ -1184,8 +1193,8 @@ async fn bulk_reinvite_members(
let mut bulk_response = Vec::new();
for member_id in data.ids {
let err_msg = match reinvite_member_impl(&org_id, &member_id, &headers.user.email, &conn).await {
Ok(()) => String::new(),
let err_msg = match _reinvite_member(&org_id, &member_id, &headers.user.email, &conn).await {
Ok(_) => String::new(),
Err(e) => format!("{e:?}"),
};
@@ -1195,7 +1204,7 @@ async fn bulk_reinvite_members(
"id": member_id,
"error": err_msg
}
));
))
}
Ok(Json(json!({
@@ -1215,10 +1224,10 @@ async fn reinvite_member(
if org_id != headers.org_id {
err!("Organization not found", "Organization id's do not match");
}
reinvite_member_impl(&org_id, &member_id, &headers.user.email, &conn).await
_reinvite_member(&org_id, &member_id, &headers.user.email, &conn).await
}
async fn reinvite_member_impl(
async fn _reinvite_member(
org_id: &OrganizationId,
member_id: &MembershipId,
invited_by_email: &str,
@@ -1240,14 +1249,13 @@ async fn reinvite_member_impl(
err!("Invitations are not allowed.")
}
let org_name = if let Some(org) = Organization::find_by_uuid(org_id, conn).await {
org.name
} else {
err!("Error looking up organization.")
let org_name = match Organization::find_by_uuid(org_id, conn).await {
Some(org) => org.name,
None => err!("Error looking up organization."),
};
if CONFIG.mail_enabled() {
mail::send_invite(&user, org_id.clone(), member.uuid, &org_name, Some(invited_by_email.to_owned())).await?;
mail::send_invite(&user, org_id.clone(), member.uuid, &org_name, Some(invited_by_email.to_string())).await?;
} else if user.password_hash.is_empty() {
let invitation = Invitation::new(&user.email);
invitation.save(conn).await?;
@@ -1355,8 +1363,8 @@ async fn bulk_confirm_invite(
for invite in keys {
let member_id = invite.id.unwrap();
let user_key = invite.key.unwrap_or_default();
let err_msg = match confirm_invite_impl(&org_id, &member_id, &user_key, &headers, &conn, &nt).await {
Ok(()) => String::new(),
let err_msg = match _confirm_invite(&org_id, &member_id, &user_key, &headers, &conn, &nt).await {
Ok(_) => String::new(),
Err(e) => format!("{e:?}"),
};
@@ -1390,10 +1398,10 @@ async fn confirm_invite(
) -> EmptyResult {
let data = data.into_inner();
let user_key = data.key.unwrap_or_default();
confirm_invite_impl(&org_id, &member_id, &user_key, &headers, &conn, &nt).await
_confirm_invite(&org_id, &member_id, &user_key, &headers, &conn, &nt).await
}
async fn confirm_invite_impl(
async fn _confirm_invite(
org_id: &OrganizationId,
member_id: &MembershipId,
key: &str,
@@ -1421,7 +1429,7 @@ async fn confirm_invite_impl(
}
member_to_confirm.status = MembershipStatus::Confirmed as i32;
member_to_confirm.akey = key.to_owned();
member_to_confirm.akey = key.to_string();
// This check is also done at accept_invite, _confirm_invite, _activate_member, edit_member, admin::update_membership_type
OrgPolicy::check_user_allowed(&member_to_confirm, "confirm", conn).await?;
@@ -1438,15 +1446,13 @@ async fn confirm_invite_impl(
.await;
if CONFIG.mail_enabled() {
let org_name = if let Some(org) = Organization::find_by_uuid(org_id, conn).await {
org.name
} else {
err!("Error looking up organization.")
let org_name = match Organization::find_by_uuid(org_id, conn).await {
Some(org) => org.name,
None => err!("Error looking up organization."),
};
let address = if let Some(user) = User::find_by_uuid(&member_to_confirm.user_uuid, conn).await {
user.email
} else {
err!("Error looking up user.")
let address = match User::find_by_uuid(&member_to_confirm.user_uuid, conn).await {
Some(user) => user.email,
None => err!("Error looking up user."),
};
mail::send_invite_confirmed(&address, &org_name).await?;
}
@@ -1454,7 +1460,7 @@ async fn confirm_invite_impl(
let save_result = member_to_confirm.save(conn).await;
if let Some(user) = User::find_by_uuid(&member_to_confirm.user_uuid, conn).await {
nt.send_user_update(UpdateType::SyncOrgKeys, &user, headers.device.push_uuid.as_ref(), conn).await;
nt.send_user_update(UpdateType::SyncOrgKeys, &user, &headers.device.push_uuid, conn).await;
}
save_result
@@ -1642,8 +1648,8 @@ async fn bulk_delete_member(
let mut bulk_response = Vec::new();
for member_id in data.ids {
let err_msg = match delete_member_impl(&org_id, &member_id, &headers, &conn, &nt).await {
Ok(()) => String::new(),
let err_msg = match _delete_member(&org_id, &member_id, &headers, &conn, &nt).await {
Ok(_) => String::new(),
Err(e) => format!("{e:?}"),
};
@@ -1653,7 +1659,7 @@ async fn bulk_delete_member(
"id": member_id,
"error": err_msg
}
));
))
}
Ok(Json(json!({
@@ -1671,10 +1677,10 @@ async fn delete_member(
conn: DbConn,
nt: Notify<'_>,
) -> EmptyResult {
delete_member_impl(&org_id, &member_id, &headers, &conn, &nt).await
_delete_member(&org_id, &member_id, &headers, &conn, &nt).await
}
async fn delete_member_impl(
async fn _delete_member(
org_id: &OrganizationId,
member_id: &MembershipId,
headers: &AdminHeaders,
@@ -1712,7 +1718,7 @@ async fn delete_member_impl(
.await;
if let Some(user) = User::find_by_uuid(&member_to_delete.user_uuid, conn).await {
nt.send_user_update(UpdateType::SyncOrgKeys, &user, headers.device.push_uuid.as_ref(), conn).await;
nt.send_user_update(UpdateType::SyncOrgKeys, &user, &headers.device.push_uuid, conn).await;
}
member_to_delete.delete(conn).await
@@ -1758,8 +1764,8 @@ async fn bulk_public_keys(
})))
}
use super::ciphers::CipherData;
use super::ciphers::update_cipher_from_data;
use super::ciphers::CipherData;
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
@@ -1907,24 +1913,24 @@ async fn post_bulk_collections(data: Json<BulkCollectionsData>, headers: Headers
}
}
for cipher_id in &data.cipher_ids {
for cipher_id in data.cipher_ids.iter() {
// Only act on existing cipher uuid's
// Do not abort the operation just ignore it, it could be a cipher was just deleted for example
if let Some(cipher) = Cipher::find_by_uuid_and_org(cipher_id, &data.organization_id, &conn).await
&& cipher.is_write_accessible_to_user(&headers.user.uuid, &conn).await
{
// When selecting a specific collection from the left filter list, and use the bulk option, you can remove an item from that collection
// In these cases the client will call this endpoint twice, once for adding the new collections and a second for deleting.
if data.remove_collections {
for collection in &data.collection_ids {
CollectionCipher::delete(&cipher.uuid, collection, &conn).await?;
}
} else {
for collection in &data.collection_ids {
CollectionCipher::save(&cipher.uuid, collection, &conn).await?;
if let Some(cipher) = Cipher::find_by_uuid_and_org(cipher_id, &data.organization_id, &conn).await {
if cipher.is_write_accessible_to_user(&headers.user.uuid, &conn).await {
// When selecting a specific collection from the left filter list, and use the bulk option, you can remove an item from that collection
// In these cases the client will call this endpoint twice, once for adding the new collections and a second for deleting.
if data.remove_collections {
for collection in &data.collection_ids {
CollectionCipher::delete(&cipher.uuid, collection, &conn).await?;
}
} else {
for collection in &data.collection_ids {
CollectionCipher::save(&cipher.uuid, collection, &conn).await?;
}
}
}
}
};
}
Ok(())
@@ -1969,25 +1975,15 @@ async fn list_policies_token(org_id: OrganizationId, token: &str, conn: DbConn)
})))
}
// Called during the SSO enrollment return the default policy
#[get("/organizations/00000000-01DC-01DC-01DC-000000000000/policies/master-password", rank = 1)]
fn get_dummy_master_password_policy() -> JsonResult {
let (enabled, data) = match CONFIG.sso_master_password_policy_value() {
Some(policy) if CONFIG.sso_enabled() => (true, policy.to_string()),
_ => (false, "null".to_owned()),
};
let policy = OrgPolicy::new(FAKE_SSO_IDENTIFIER.into(), OrgPolicyType::MasterPassword, enabled, data);
Ok(Json(policy.to_json()))
}
// Called during the SSO enrollment return the org policy if it exists
#[get("/organizations/<org_id>/policies/master-password", rank = 2)]
// Called during the SSO enrollment.
// Return the org policy if it exists, otherwise use the default one.
#[get("/organizations/<org_id>/policies/master-password", rank = 1)]
async fn get_master_password_policy(org_id: OrganizationId, _headers: OrgMemberHeaders, conn: DbConn) -> JsonResult {
let policy =
OrgPolicy::find_by_org_and_type(&org_id, OrgPolicyType::MasterPassword, &conn).await.unwrap_or_else(|| {
let (enabled, data) = match CONFIG.sso_master_password_policy_value() {
Some(policy) if CONFIG.sso_enabled() => (true, policy.to_string()),
_ => (false, "null".to_owned()),
_ => (false, "null".to_string()),
};
OrgPolicy::new(org_id, OrgPolicyType::MasterPassword, enabled, data)
@@ -1996,7 +1992,7 @@ async fn get_master_password_policy(org_id: OrganizationId, _headers: OrgMemberH
Ok(Json(policy.to_json()))
}
#[get("/organizations/<org_id>/policies/<pol_type>", rank = 3)]
#[get("/organizations/<org_id>/policies/<pol_type>", rank = 2)]
async fn get_policy(org_id: OrganizationId, pol_type: i32, headers: AdminHeaders, conn: DbConn) -> JsonResult {
if org_id != headers.org_id {
err!("Organization not found", "Organization id's do not match");
@@ -2008,7 +2004,7 @@ async fn get_policy(org_id: OrganizationId, pol_type: i32, headers: AdminHeaders
let policy = match OrgPolicy::find_by_org_and_type(&org_id, pol_type_enum, &conn).await {
Some(p) => p,
None => OrgPolicy::new(org_id.clone(), pol_type_enum, false, "null".to_owned()),
None => OrgPolicy::new(org_id.clone(), pol_type_enum, false, "null".to_string()),
};
Ok(Json(policy.to_json()))
@@ -2083,7 +2079,7 @@ async fn put_policy(
// When enabling the SingleOrg policy, remove this org's members that are members of other orgs
if pol_type_enum == OrgPolicyType::SingleOrg && data.enabled {
for mut member in Membership::find_by_org(&org_id, &conn).await {
for mut member in Membership::find_by_org(&org_id, &conn).await.into_iter() {
// Policy only applies to non-Owner/non-Admin members who have accepted joining the org
// Exclude invited and revoked users when checking for this policy.
// Those users will not be allowed to accept or be activated because of the policy checks done there.
@@ -2118,7 +2114,7 @@ async fn put_policy(
let mut policy = match OrgPolicy::find_by_org_and_type(&org_id, pol_type_enum, &conn).await {
Some(p) => p,
None => OrgPolicy::new(org_id.clone(), pol_type_enum, false, "{}".to_owned()),
None => OrgPolicy::new(org_id.clone(), pol_type_enum, false, "{}".to_string()),
};
policy.enabled = data.enabled;
@@ -2192,7 +2188,7 @@ fn get_plans() -> Json<Value> {
#[get("/organizations/<_org_id>/billing/metadata")]
fn get_billing_metadata(_org_id: OrganizationId, _headers: OrgMemberHeaders) -> Json<Value> {
// Prevent a 404 error, which also causes Javascript errors.
Json(empty_data_json())
Json(_empty_data_json())
}
#[get("/organizations/<_org_id>/billing/vnext/warnings")]
@@ -2205,16 +2201,7 @@ fn get_billing_warnings(_org_id: OrganizationId, _headers: OrgMemberHeaders) ->
}))
}
#[get("/organizations/<_org_id>/billing/vnext/self-host/metadata")]
fn get_self_host_billing_metadata(_org_id: OrganizationId, _headers: OrgMemberHeaders) -> Json<Value> {
// Prevent a 404 error, which also causes Javascript errors.
Json(json!({
"isOnSecretsManagerStandalone": false, // Secrets Manager is not supported by Vaultwarden
"organizationOccupiedSeats": 0 // Vaultwarden does not count seats
}))
}
fn empty_data_json() -> Value {
fn _empty_data_json() -> Value {
json!({
"object": "list",
"data": [],
@@ -2235,7 +2222,7 @@ async fn revoke_member(
headers: AdminHeaders,
conn: DbConn,
) -> EmptyResult {
revoke_member_impl(&org_id, &member_id, &headers, &conn).await
_revoke_member(&org_id, &member_id, &headers, &conn).await
}
#[put("/organizations/<org_id>/users/revoke", data = "<data>")]
@@ -2254,8 +2241,8 @@ async fn bulk_revoke_members(
match data.ids {
Some(members) => {
for member_id in members {
let err_msg = match revoke_member_impl(&org_id, &member_id, &headers, &conn).await {
Ok(()) => String::new(),
let err_msg = match _revoke_member(&org_id, &member_id, &headers, &conn).await {
Ok(_) => String::new(),
Err(e) => format!("{e:?}"),
};
@@ -2278,7 +2265,7 @@ async fn bulk_revoke_members(
})))
}
async fn revoke_member_impl(
async fn _revoke_member(
org_id: &OrganizationId,
member_id: &MembershipId,
headers: &AdminHeaders,
@@ -2321,18 +2308,6 @@ async fn revoke_member_impl(
Ok(())
}
#[put("/organizations/<org_id>/users/<member_id>/restore/vnext")]
async fn restore_member_vnext(
org_id: OrganizationId,
member_id: MembershipId,
headers: AdminHeaders,
conn: DbConn,
) -> EmptyResult {
// Vaultwarden does not (yet) support the per User Collection linked to the `Enforce organization data ownership` policy.
// Therefor we ignore the `defaultUserCollectionName` data sent and just call restore_member
restore_member_impl(&org_id, &member_id, &headers, &conn).await
}
#[put("/organizations/<org_id>/users/<member_id>/restore")]
async fn restore_member(
org_id: OrganizationId,
@@ -2340,7 +2315,7 @@ async fn restore_member(
headers: AdminHeaders,
conn: DbConn,
) -> EmptyResult {
restore_member_impl(&org_id, &member_id, &headers, &conn).await
_restore_member(&org_id, &member_id, &headers, &conn).await
}
#[put("/organizations/<org_id>/users/restore", data = "<data>")]
@@ -2357,8 +2332,8 @@ async fn bulk_restore_members(
let mut bulk_response = Vec::new();
for member_id in data.ids {
let err_msg = match restore_member_impl(&org_id, &member_id, &headers, &conn).await {
Ok(()) => String::new(),
let err_msg = match _restore_member(&org_id, &member_id, &headers, &conn).await {
Ok(_) => String::new(),
Err(e) => format!("{e:?}"),
};
@@ -2378,7 +2353,7 @@ async fn bulk_restore_members(
})))
}
async fn restore_member_impl(
async fn _restore_member(
org_id: &OrganizationId,
member_id: &MembershipId,
headers: &AdminHeaders,
@@ -2434,11 +2409,11 @@ async fn get_groups_data(
if details {
for g in groups {
groups_json.push(g.to_json_details(&conn).await);
groups_json.push(g.to_json_details(&conn).await)
}
} else {
for g in groups {
groups_json.push(g.to_json());
groups_json.push(g.to_json())
}
}
groups_json
@@ -2677,15 +2652,15 @@ async fn post_delete_group(
headers: AdminHeaders,
conn: DbConn,
) -> EmptyResult {
delete_group_impl(&org_id, &group_id, &headers, &conn).await
_delete_group(&org_id, &group_id, &headers, &conn).await
}
#[delete("/organizations/<org_id>/groups/<group_id>")]
async fn delete_group(org_id: OrganizationId, group_id: GroupId, headers: AdminHeaders, conn: DbConn) -> EmptyResult {
delete_group_impl(&org_id, &group_id, &headers, &conn).await
_delete_group(&org_id, &group_id, &headers, &conn).await
}
async fn delete_group_impl(
async fn _delete_group(
org_id: &OrganizationId,
group_id: &GroupId,
headers: &AdminHeaders,
@@ -2733,7 +2708,7 @@ async fn bulk_delete_groups(
let data: BulkGroupIds = data.into_inner();
for group_id in data.ids {
delete_group_impl(&org_id, &group_id, &headers, &conn).await?;
_delete_group(&org_id, &group_id, &headers, &conn).await?
}
Ok(())
}
@@ -2770,7 +2745,7 @@ async fn get_group_members(
if Group::find_by_uuid_and_org(&group_id, &org_id, &conn).await.is_none() {
err!("Group could not be found!", "Group uuid is invalid or does not belong to the organization")
}
};
let group_members: Vec<MembershipId> = GroupUser::find_by_group(&group_id, &org_id, &conn)
.await
@@ -2798,7 +2773,7 @@ async fn put_group_members(
if Group::find_by_uuid_and_org(&group_id, &org_id, &conn).await.is_none() {
err!("Group could not be found!", "Group uuid is invalid or does not belong to the organization")
}
};
let assigned_members = data.into_inner();
@@ -3052,7 +3027,10 @@ async fn put_reset_password_enrollment(
err!("User to enroll isn't member of required organization", "The user_id and acting user do not match");
}
let mut membership = headers.membership;
let Some(mut membership) = Membership::find_confirmed_by_user_and_org(&headers.user.uuid, &org_id, &conn).await
else {
err!("User to enroll isn't member of required organization")
};
check_reset_password_applicable(&org_id, &conn).await?;
@@ -3105,12 +3083,12 @@ async fn get_org_export(org_id: OrganizationId, headers: AdminHeaders, conn: DbC
}
Ok(Json(json!({
"collections": convert_json_key_lcase_first(get_org_collections_impl(&org_id, &conn).await),
"ciphers": convert_json_key_lcase_first(get_org_details_impl(&org_id, &headers.host, &headers.user.uuid, &conn).await?),
"collections": convert_json_key_lcase_first(_get_org_collections(&org_id, &conn).await),
"ciphers": convert_json_key_lcase_first(_get_org_details(&org_id, &headers.host, &headers.user.uuid, &conn).await?),
})))
}
async fn api_key(
async fn _api_key(
org_id: &OrganizationId,
data: Json<PasswordOrOtpData>,
rotate: bool,
@@ -3126,18 +3104,21 @@ async fn api_key(
// Validate the admin users password/otp
data.validate(&user, true, &conn).await?;
let org_api_key = if let Some(mut org_api_key) = OrganizationApiKey::find_by_org_uuid(org_id, &conn).await {
if rotate {
org_api_key.api_key = crate::crypto::generate_api_key();
org_api_key.revision_date = chrono::Utc::now().naive_utc();
org_api_key.save(&conn).await.expect("Error rotating organization API Key");
let org_api_key = match OrganizationApiKey::find_by_org_uuid(org_id, &conn).await {
Some(mut org_api_key) => {
if rotate {
org_api_key.api_key = crate::crypto::generate_api_key();
org_api_key.revision_date = chrono::Utc::now().naive_utc();
org_api_key.save(&conn).await.expect("Error rotating organization API Key");
}
org_api_key
}
None => {
let api_key = crate::crypto::generate_api_key();
let new_org_api_key = OrganizationApiKey::new(org_id.clone(), api_key);
new_org_api_key.save(&conn).await.expect("Error creating organization API Key");
new_org_api_key
}
org_api_key
} else {
let api_key = crate::crypto::generate_api_key();
let new_org_api_key = OrganizationApiKey::new(org_id.clone(), api_key);
new_org_api_key.save(&conn).await.expect("Error creating organization API Key");
new_org_api_key
};
Ok(Json(json!({
@@ -3148,13 +3129,13 @@ async fn api_key(
}
#[post("/organizations/<org_id>/api-key", data = "<data>")]
async fn post_api_key(
async fn api_key(
org_id: OrganizationId,
data: Json<PasswordOrOtpData>,
headers: AdminHeaders,
conn: DbConn,
) -> JsonResult {
api_key(&org_id, data, false, headers, conn).await
_api_key(&org_id, data, false, headers, conn).await
}
#[post("/organizations/<org_id>/rotate-api-key", data = "<data>")]
@@ -3164,5 +3145,5 @@ async fn rotate_api_key(
headers: AdminHeaders,
conn: DbConn,
) -> JsonResult {
api_key(&org_id, data, true, headers, conn).await
_api_key(&org_id, data, true, headers, conn).await
}
+62 -59
View File
@@ -1,24 +1,23 @@
use std::collections::HashSet;
use chrono::Utc;
use rocket::{
Request, Route,
request::{FromRequest, Outcome},
serde::json::Json,
Request, Route,
};
use std::collections::HashSet;
use crate::{
CONFIG,
api::EmptyResult,
auth,
db::{
DbConn,
models::{
Group, GroupUser, Invitation, Membership, MembershipStatus, MembershipType, Organization,
OrganizationApiKey, OrganizationId, User,
},
DbConn,
},
mail,
mail, CONFIG,
};
pub fn routes() -> Vec<Route> {
@@ -91,18 +90,19 @@ async fn ldap_import(data: Json<OrgImportData>, token: PublicToken, conn: DbConn
}
} else {
// If user is not part of the organization
let user = if let Some(user) = User::find_by_mail(&user_data.email, &conn).await {
user
} else {
// User does not exist yet
let mut new_user = User::new(&user_data.email, None);
new_user.save(&conn).await?;
let user = match User::find_by_mail(&user_data.email, &conn).await {
Some(user) => user, // exists in vaultwarden
None => {
// User does not exist yet
let mut new_user = User::new(&user_data.email, None);
new_user.save(&conn).await?;
if !CONFIG.mail_enabled() {
Invitation::new(&new_user.email).save(&conn).await?;
if !CONFIG.mail_enabled() {
Invitation::new(&new_user.email).save(&conn).await?;
}
user_created = true;
new_user
}
user_created = true;
new_user
};
let member_status = if CONFIG.mail_enabled() || user.password_hash.is_empty() {
MembershipStatus::Invited as i32
@@ -110,10 +110,9 @@ async fn ldap_import(data: Json<OrgImportData>, token: PublicToken, conn: DbConn
MembershipStatus::Accepted as i32 // Automatically mark user as accepted if no email invites
};
let (org_name, org_email) = if let Some(org) = Organization::find_by_uuid(&org_id, &conn).await {
(org.name, org.billing_email)
} else {
err!("Error looking up organization")
let (org_name, org_email) = match Organization::find_by_uuid(&org_id, &conn).await {
Some(org) => (org.name, org.billing_email),
None => err!("Error looking up organization"),
};
let mut new_member = Membership::new(user.uuid.clone(), org_id.clone(), Some(org_email.clone()));
@@ -124,33 +123,37 @@ async fn ldap_import(data: Json<OrgImportData>, token: PublicToken, conn: DbConn
new_member.save(&conn).await?;
if CONFIG.mail_enabled()
&& let Err(e) =
if CONFIG.mail_enabled() {
if let Err(e) =
mail::send_invite(&user, org_id.clone(), new_member.uuid.clone(), &org_name, Some(org_email)).await
{
// Upon error delete the user, invite and org member records when needed
if user_created {
user.delete(&conn).await?;
} else {
new_member.delete(&conn).await?;
}
{
// Upon error delete the user, invite and org member records when needed
if user_created {
user.delete(&conn).await?;
} else {
new_member.delete(&conn).await?;
}
err!(format!("Error sending invite: {e:?} "));
err!(format!("Error sending invite: {e:?} "));
}
}
}
}
if CONFIG.org_groups_enabled() {
for group_data in &data.groups {
let group_uuid = if let Some(group) =
Group::find_by_external_id_and_org(&group_data.external_id, &org_id, &conn).await
{
group.uuid
} else {
let mut group =
Group::new(org_id.clone(), group_data.name.clone(), false, Some(group_data.external_id.clone()));
group.save(&conn).await?;
group.uuid
let group_uuid = match Group::find_by_external_id_and_org(&group_data.external_id, &org_id, &conn).await {
Some(group) => group.uuid,
None => {
let mut group = Group::new(
org_id.clone(),
group_data.name.clone(),
false,
Some(group_data.external_id.clone()),
);
group.save(&conn).await?;
group.uuid
}
};
GroupUser::delete_all_by_group(&group_uuid, &org_id, &conn).await?;
@@ -171,17 +174,18 @@ async fn ldap_import(data: Json<OrgImportData>, token: PublicToken, conn: DbConn
// Generate a HashSet to quickly verify if a member is listed or not.
let sync_members: HashSet<String> = data.members.into_iter().map(|m| m.external_id).collect();
for member in Membership::find_by_org(&org_id, &conn).await {
if let Some(ref user_external_id) = member.external_id
&& !sync_members.contains(user_external_id)
{
if member.atype == MembershipType::Owner && member.status == MembershipStatus::Confirmed as i32 {
// Removing owner, check that there is at least one other confirmed owner
if Membership::count_confirmed_by_org_and_type(&org_id, MembershipType::Owner, &conn).await <= 1 {
warn!("Can't delete the last owner");
continue;
if let Some(ref user_external_id) = member.external_id {
if !sync_members.contains(user_external_id) {
if member.atype == MembershipType::Owner && member.status == MembershipStatus::Confirmed as i32 {
// Removing owner, check that there is at least one other confirmed owner
if Membership::count_confirmed_by_org_and_type(&org_id, MembershipType::Owner, &conn).await <= 1
{
warn!("Can't delete the last owner");
continue;
}
}
member.delete(&conn).await?;
}
member.delete(&conn).await?;
}
}
}
@@ -198,14 +202,12 @@ impl<'r> FromRequest<'r> for PublicToken {
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
let headers = request.headers();
// Get access_token
let access_token: &str = if let Some(a) = headers.get_one("Authorization") {
if let Some(split) = a.rsplit("Bearer ").next() {
split
} else {
err_handler!("No access token provided")
}
} else {
err_handler!("No access token provided")
let access_token: &str = match headers.get_one("Authorization") {
Some(a) => match a.rsplit("Bearer ").next() {
Some(split) => split,
None => err_handler!("No access token provided"),
},
None => err_handler!("No access token provided"),
};
// Check JWT token is valid and get device and user from it
let Ok(claims) = auth::decode_api_org(access_token) else {
@@ -227,13 +229,14 @@ impl<'r> FromRequest<'r> for PublicToken {
// Check if claims.sub is org_api_key.uuid
// Check if claims.client_sub is org_api_key.org_uuid
let Outcome::Success(conn) = DbConn::from_request(request).await else {
err_handler!("Error getting DB")
let conn = match DbConn::from_request(request).await {
Outcome::Success(conn) => conn,
_ => err_handler!("Error getting DB"),
};
let Some(org_id) = claims.client_id.strip_prefix("organization.") else {
err_handler!("Malformed client_id")
};
let org_id: OrganizationId = org_id.to_owned().into();
let org_id: OrganizationId = org_id.to_string().into();
let Some(org_api_key) = OrganizationApiKey::find_by_org_uuid(&org_id, &conn).await else {
err_handler!("Invalid client_id")
};
+34 -36
View File
@@ -10,15 +10,15 @@ use rocket::{
use serde_json::Value;
use crate::{
CONFIG,
api::{ApiResult, EmptyResult, JsonResult, Notify, UpdateType},
auth::{ClientIp, Headers, Host},
config::PathType,
db::{
DbConn, DbPool,
models::{Device, OrgPolicy, OrgPolicyType, Send, SendFileId, SendId, SendType, UserId},
DbConn, DbPool,
},
util::{NumberOrString, save_temp_file},
util::{save_temp_file, NumberOrString},
CONFIG,
};
const SEND_INACCESSIBLE_MSG: &str = "Send does not exist or is no longer available";
@@ -63,7 +63,7 @@ pub async fn purge_sends(pool: DbPool) {
if let Ok(conn) = pool.get().await {
Send::purge(&conn).await;
} else {
error!("Failed to get DB connection while purging sends");
error!("Failed to get DB connection while purging sends")
}
}
@@ -168,7 +168,7 @@ fn create_send(data: SendData, user_id: UserId) -> ApiResult<Send> {
#[get("/sends")]
async fn get_sends(headers: Headers, conn: DbConn) -> Json<Value> {
let sends = Send::find_by_user(&headers.user.uuid, &conn);
let sends_json: Vec<Value> = sends.await.iter().map(Send::to_json).collect();
let sends_json: Vec<Value> = sends.await.iter().map(|s| s.to_json()).collect();
Json(json!({
"data": sends_json,
@@ -179,10 +179,9 @@ async fn get_sends(headers: Headers, conn: DbConn) -> Json<Value> {
#[get("/sends/<send_id>")]
async fn get_send(send_id: SendId, headers: Headers, conn: DbConn) -> JsonResult {
if let Some(send) = Send::find_by_uuid_and_user(&send_id, &headers.user.uuid, &conn).await {
Ok(Json(send.to_json()))
} else {
err!("Send not found", "Invalid send uuid or does not belong to user")
match Send::find_by_uuid_and_user(&send_id, &headers.user.uuid, &conn).await {
Some(send) => Ok(Json(send.to_json())),
None => err!("Send not found", "Invalid send uuid or does not belong to user"),
}
}
@@ -311,10 +310,9 @@ async fn post_send_file_v2(data: Json<SendData>, headers: Headers, conn: DbConn)
enforce_disable_hide_email_policy(&data, &headers, &conn).await?;
let file_length = if let Some(m) = &data.file_length {
m.into_i64()?
} else {
err!("Invalid send length")
let file_length = match &data.file_length {
Some(m) => m.into_i64()?,
_ => err!("Invalid send length"),
};
if file_length < 0 {
err!("Send size can't be negative")
@@ -459,16 +457,16 @@ async fn post_access(
err_code!(SEND_INACCESSIBLE_MSG, 404)
};
if let Some(max_access_count) = send.max_access_count
&& send.access_count >= max_access_count
{
err_code!(SEND_INACCESSIBLE_MSG, 404);
if let Some(max_access_count) = send.max_access_count {
if send.access_count >= max_access_count {
err_code!(SEND_INACCESSIBLE_MSG, 404);
}
}
if let Some(expiration) = send.expiration_date
&& Utc::now().naive_utc() >= expiration
{
err_code!(SEND_INACCESSIBLE_MSG, 404)
if let Some(expiration) = send.expiration_date {
if Utc::now().naive_utc() >= expiration {
err_code!(SEND_INACCESSIBLE_MSG, 404)
}
}
if Utc::now().naive_utc() >= send.deletion_date {
@@ -519,16 +517,16 @@ async fn post_access_file(
err_code!(SEND_INACCESSIBLE_MSG, 404)
};
if let Some(max_access_count) = send.max_access_count
&& send.access_count >= max_access_count
{
err_code!(SEND_INACCESSIBLE_MSG, 404)
if let Some(max_access_count) = send.max_access_count {
if send.access_count >= max_access_count {
err_code!(SEND_INACCESSIBLE_MSG, 404)
}
}
if let Some(expiration) = send.expiration_date
&& Utc::now().naive_utc() >= expiration
{
err_code!(SEND_INACCESSIBLE_MSG, 404)
if let Some(expiration) = send.expiration_date {
if Utc::now().naive_utc() >= expiration {
err_code!(SEND_INACCESSIBLE_MSG, 404)
}
}
if Utc::now().naive_utc() >= send.deletion_date {
@@ -570,22 +568,22 @@ async fn post_access_file(
async fn download_url(host: &Host, send_id: &SendId, file_id: &SendFileId) -> Result<String, crate::Error> {
let operator = CONFIG.opendal_operator_for_path_type(&PathType::Sends)?;
if crate::storage::is_fs_operator(&operator) {
if operator.info().scheme() == <&'static str>::from(opendal::Scheme::Fs) {
let token_claims = crate::auth::generate_send_claims(send_id, file_id);
let token = crate::auth::encode_jwt(&token_claims);
Ok(format!("{}/api/sends/{send_id}/{file_id}?t={token}", host.host))
Ok(format!("{}/api/sends/{send_id}/{file_id}?t={token}", &host.host))
} else {
Ok(operator.presign_read(&format!("{send_id}/{file_id}"), Duration::from_mins(5)).await?.uri().to_string())
Ok(operator.presign_read(&format!("{send_id}/{file_id}"), Duration::from_secs(5 * 60)).await?.uri().to_string())
}
}
#[get("/sends/<send_id>/<file_id>?<t>")]
async fn download_send(send_id: SendId, file_id: SendFileId, t: &str) -> Option<NamedFile> {
if let Ok(claims) = crate::auth::decode_send(t)
&& claims.sub == format!("{send_id}/{file_id}")
{
return NamedFile::open(Path::new(&CONFIG.sends_folder()).join(send_id).join(file_id)).await.ok();
if let Ok(claims) = crate::auth::decode_send(t) {
if claims.sub == format!("{send_id}/{file_id}") {
return NamedFile::open(Path::new(&CONFIG.sends_folder()).join(send_id).join(file_id)).await.ok();
}
}
None
}
+11 -11
View File
@@ -1,13 +1,14 @@
use data_encoding::BASE32;
use rocket::{Route, serde::json::Json};
use rocket::serde::json::Json;
use rocket::Route;
use crate::{
api::{EmptyResult, JsonResult, PasswordOrOtpData, core::log_user_event, core::two_factor::generate_recover_code},
api::{core::log_user_event, core::two_factor::_generate_recover_code, EmptyResult, JsonResult, PasswordOrOtpData},
auth::{ClientIp, Headers},
crypto,
db::{
DbConn,
models::{EventType, TwoFactor, TwoFactorType, UserId},
DbConn,
},
util::NumberOrString,
};
@@ -69,10 +70,9 @@ async fn activate_authenticator(data: Json<EnableAuthenticatorData>, headers: He
.await?;
// Validate key as base32 and 20 bytes length
let decoded_key: Vec<u8> = if let Ok(decoded) = BASE32.decode(key.as_bytes()) {
decoded
} else {
err!("Invalid totp secret")
let decoded_key: Vec<u8> = match BASE32.decode(key.as_bytes()) {
Ok(decoded) => decoded,
_ => err!("Invalid totp secret"),
};
if decoded_key.len() != 20 {
@@ -82,7 +82,7 @@ async fn activate_authenticator(data: Json<EnableAuthenticatorData>, headers: He
// Validate the token provided with the key, and save new twofactor
validate_totp_code(&user.uuid, &token, &key.to_uppercase(), &headers.ip, &conn).await?;
generate_recover_code(&mut user, &conn).await;
_generate_recover_code(&mut user, &conn).await;
log_user_event(EventType::UserUpdated2fa as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &conn).await;
@@ -119,7 +119,7 @@ pub async fn validate_totp_code(
ip: &ClientIp,
conn: &DbConn,
) -> EmptyResult {
use totp_lite::{Sha1, totp_custom};
use totp_lite::{totp_custom, Sha1};
let Ok(decoded_secret) = BASE32.decode(secret.as_bytes()) else {
err!("Invalid TOTP secret")
@@ -128,7 +128,7 @@ pub async fn validate_totp_code(
let mut twofactor = match TwoFactor::find_by_user_and_type(user_id, TwoFactorType::Authenticator as i32, conn).await
{
Some(tf) => tf,
_ => TwoFactor::new(user_id.clone(), TwoFactorType::Authenticator, secret.to_owned()),
_ => TwoFactor::new(user_id.clone(), TwoFactorType::Authenticator, secret.to_string()),
};
// The amount of steps back and forward in time
@@ -145,7 +145,7 @@ pub async fn validate_totp_code(
// We need to calculate the time offsite and cast it as an u64.
// Since we only have times into the future and the totp generator needs an u64 instead of the default i64.
let time: u64 = (current_timestamp + step * 30i64).cast_unsigned();
let time = (current_timestamp + step * 30i64) as u64;
let generated = totp_custom::<Sha1>(30, 6, &decoded_secret, time);
// Check the given code equals the generated and if the time_step is larger then the one last used.
+17 -16
View File
@@ -1,21 +1,22 @@
use chrono::Utc;
use data_encoding::BASE64;
use rocket::{Route, serde::json::Json};
use rocket::serde::json::Json;
use rocket::Route;
use crate::{
CONFIG,
api::{
ApiResult, EmptyResult, JsonResult, PasswordOrOtpData, core::log_user_event,
core::two_factor::generate_recover_code,
core::log_user_event, core::two_factor::_generate_recover_code, ApiResult, EmptyResult, JsonResult,
PasswordOrOtpData,
},
auth::Headers,
crypto,
db::{
DbConn,
models::{EventType, TwoFactor, TwoFactorType, User, UserId},
DbConn,
},
error::MapResult,
http_client::make_http_request,
CONFIG,
};
pub fn routes() -> Vec<Route> {
@@ -81,7 +82,8 @@ enum DuoStatus {
impl DuoStatus {
fn data(self) -> Option<DuoData> {
match self {
DuoStatus::Global(data) | DuoStatus::User(data) => Some(data),
DuoStatus::Global(data) => Some(data),
DuoStatus::User(data) => Some(data),
DuoStatus::Disabled(_) => None,
}
}
@@ -180,7 +182,7 @@ async fn activate_duo(data: Json<EnableDuoData>, headers: Headers, conn: DbConn)
let twofactor = TwoFactor::new(user.uuid.clone(), type_, data_str);
twofactor.save(&conn).await?;
generate_recover_code(&mut user, &conn).await;
_generate_recover_code(&mut user, &conn).await;
log_user_event(EventType::UserUpdated2fa as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &conn).await;
@@ -199,14 +201,14 @@ async fn activate_duo_put(data: Json<EnableDuoData>, headers: Headers, conn: DbC
}
async fn duo_api_request(method: &str, path: &str, params: &str, data: &DuoData) -> EmptyResult {
use reqwest::{Method, header};
use reqwest::{header, Method};
use std::str::FromStr;
// https://duo.com/docs/authapi#api-details
let url = format!("https://{}{path}", data.host);
let dt = Utc::now().to_rfc2822();
let url = format!("https://{}{path}", &data.host);
let date = Utc::now().to_rfc2822();
let username = &data.ik;
let fields = [&dt, method, &data.host, path, params];
let fields = [&date, method, &data.host, path, params];
let password = crypto::hmac_sign(&data.sk, &fields.join("\n"));
let m = Method::from_str(method).unwrap_or_default();
@@ -214,7 +216,7 @@ async fn duo_api_request(method: &str, path: &str, params: &str, data: &DuoData)
make_http_request(m, &url)?
.basic_auth(username, Some(password))
.header(header::USER_AGENT, "vaultwarden:Duo/1.0 (Rust)")
.header(header::DATE, dt)
.header(header::DATE, date)
.send()
.await?
.error_for_status()?;
@@ -354,10 +356,9 @@ fn parse_duo_values(key: &str, val: &str, ikey: &str, prefix: &str, time: i64) -
err!("Invalid ikey")
}
let expire: i64 = if let Ok(e) = expire.parse() {
e
} else {
err!("Invalid expire time")
let expire: i64 = match expire.parse() {
Ok(e) => e,
Err(_) => err!("Invalid expire time"),
};
if time >= expire {
+23 -21
View File
@@ -1,24 +1,23 @@
use std::collections::HashMap;
use chrono::Utc;
use data_encoding::HEXLOWER;
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation};
use reqwest::{StatusCode, header};
use ring::digest::{Digest, SHA512_256, digest};
use reqwest::{header, StatusCode};
use ring::digest::{digest, Digest, SHA512_256};
use serde::Serialize;
use url::Url;
use std::collections::HashMap;
use crate::{
CONFIG,
api::{EmptyResult, core::two_factor::duo::get_duo_keys_email},
api::{core::two_factor::duo::get_duo_keys_email, EmptyResult},
crypto,
db::{
DbConn, DbPool,
models::{DeviceId, EventType, TwoFactorDuoContext},
DbConn, DbPool,
},
error::Error,
http_client::make_http_request,
CONFIG,
};
use url::Url;
// The location on this service that Duo should redirect users to. For us, this is a bridge
// built in to the Bitwarden clients.
@@ -125,7 +124,7 @@ impl DuoClient {
ClientAssertion {
iss: self.client_id.clone(),
sub: self.client_id.clone(),
aud: url.to_owned(),
aud: url.to_string(),
exp: now + JWT_VALIDITY_SECS,
jti: jwt_id,
iat: now,
@@ -303,7 +302,7 @@ impl DuoClient {
if !(matching_nonces && matching_usernames) {
err!("Error validating Duo authorization, nonce or username mismatch.")
}
};
Ok(())
}
@@ -348,7 +347,7 @@ pub async fn purge_duo_contexts(pool: DbPool) {
if let Ok(conn) = pool.get().await {
TwoFactorDuoContext::purge_expired_duo_contexts(&conn).await;
} else {
error!("Failed to get DB connection while purging expired Duo authentications");
error!("Failed to get DB connection while purging expired Duo authentications")
}
}
@@ -395,7 +394,7 @@ pub async fn get_duo_auth_url(
match client.health_check().await {
Ok(()) => {}
Err(e) => return Err(e),
}
};
// Generate random OAuth2 state and OIDC Nonce
let state: String = crypto::get_random_string_alphanum(STATE_LENGTH);
@@ -439,13 +438,16 @@ pub async fn validate_duo_login(
// Get the context by the state reported by the client. If we don't have one,
// it means the context is either missing or expired.
let Some(ctx) = extract_context(state, conn).await else {
err!(
"Error validating duo authentication",
ErrorEvent {
event: EventType::UserFailedLogIn2fa
}
)
let ctx = match extract_context(state, conn).await {
Some(c) => c,
None => {
err!(
"Error validating duo authentication",
ErrorEvent {
event: EventType::UserFailedLogIn2fa
}
)
}
};
// Context validation steps
@@ -474,13 +476,13 @@ pub async fn validate_duo_login(
match client.health_check().await {
Ok(()) => {}
Err(e) => return Err(e),
}
};
let d: Digest = digest(&SHA512_256, format!("{}{device_identifier}", ctx.nonce).as_bytes());
let hash: String = HEXLOWER.encode(d.as_ref());
match client.exchange_authz_code_for_result(code, email, hash.as_str()).await {
Ok(()) => Ok(()),
Ok(_) => Ok(()),
Err(_) => {
err!(
"Error validating duo authentication",
+22 -25
View File
@@ -1,20 +1,20 @@
use chrono::{DateTime, TimeDelta, Utc};
use rocket::{Route, serde::json::Json};
use rocket::serde::json::Json;
use rocket::Route;
use crate::{
CONFIG,
api::{
core::{log_user_event, two_factor::_generate_recover_code},
EmptyResult, JsonResult, PasswordOrOtpData,
core::{log_user_event, two_factor::generate_recover_code},
},
auth::{ClientHeaders, Headers},
crypto,
db::{
DbConn,
models::{AuthRequest, AuthRequestId, DeviceId, EventType, TwoFactor, TwoFactorType, User, UserId},
DbConn,
},
error::{Error, MapResult},
mail,
mail, CONFIG,
};
pub fn routes() -> Vec<Route> {
@@ -25,7 +25,7 @@ pub fn routes() -> Vec<Route> {
#[serde(rename_all = "camelCase")]
struct SendEmailLoginData {
#[serde(alias = "DeviceIdentifier")]
device_identifier: Option<DeviceId>,
device_identifier: DeviceId,
#[serde(alias = "Email")]
email: Option<String>,
#[serde(alias = "MasterPasswordHash")]
@@ -91,11 +91,8 @@ async fn send_email_login(data: Json<SendEmailLoginData>, client_headers: Client
user
} else {
let Some(device_identifier) = &data.device_identifier else {
err!("No device identifier has been submitted.")
};
// SSO login only sends device id, so we get the user by the most recently used device
let Some(user) = User::find_by_device_for_email2fa(device_identifier, &conn).await else {
let Some(user) = User::find_by_device_for_email2fa(&data.device_identifier, &conn).await else {
err!("Username or password is incorrect. Try again.")
};
@@ -232,7 +229,7 @@ async fn email(data: Json<EmailData>, headers: Headers, conn: DbConn) -> JsonRes
twofactor.data = email_data.to_json();
twofactor.save(&conn).await?;
generate_recover_code(&mut user, &conn).await;
_generate_recover_code(&mut user, &conn).await;
log_user_event(EventType::UserUpdated2fa as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &conn).await;
@@ -284,9 +281,9 @@ pub async fn validate_email_code_str(
twofactor.data = email_data.to_json();
twofactor.save(conn).await?;
let dt = DateTime::from_timestamp(email_data.token_sent, 0).expect("Email token timestamp invalid.").naive_utc();
let max_time = CONFIG.email_expiration_time().cast_signed();
if dt + TimeDelta::try_seconds(max_time).unwrap() < Utc::now().naive_utc() {
let date = DateTime::from_timestamp(email_data.token_sent, 0).expect("Email token timestamp invalid.").naive_utc();
let max_time = CONFIG.email_expiration_time() as i64;
if date + TimeDelta::try_seconds(max_time).unwrap() < Utc::now().naive_utc() {
err!(
"Token has expired",
ErrorEvent {
@@ -342,10 +339,9 @@ impl EmailTokenData {
pub fn from_json(string: &str) -> Result<EmailTokenData, Error> {
let res: Result<EmailTokenData, serde_json::Error> = serde_json::from_str(string);
if let Ok(x) = res {
Ok(x)
} else {
err!("Could not decode EmailTokenData from string")
match res {
Ok(x) => Ok(x),
Err(_) => err!("Could not decode EmailTokenData from string"),
}
}
}
@@ -363,17 +359,18 @@ pub async fn activate_email_2fa(user: &User, conn: &DbConn) -> EmptyResult {
pub fn obscure_email(email: &str) -> String {
let split: Vec<&str> = email.rsplitn(2, '@').collect();
let mut name = split[1].to_owned();
let mut name = split[1].to_string();
let domain = &split[0];
let name_size = name.chars().count();
let new_name = if let 1..=3 = name_size {
"*".repeat(name_size)
} else {
let stars = "*".repeat(name_size - 2);
name.truncate(2);
format!("{name}{stars}")
let new_name = match name_size {
1..=3 => "*".repeat(name_size),
_ => {
let stars = "*".repeat(name_size - 2);
name.truncate(2);
format!("{name}{stars}")
}
};
format!("{new_name}@{domain}")
+21 -14
View File
@@ -1,27 +1,28 @@
use chrono::{TimeDelta, Utc};
use data_encoding::BASE32;
use num_traits::FromPrimitive;
use rocket::{Route, serde::json::Json};
use rocket::serde::json::Json;
use rocket::Route;
use serde::Deserialize;
use serde_json::Value;
use crate::{
CONFIG,
api::{
EmptyResult, JsonResult, PasswordOrOtpData,
core::{log_event, log_user_event},
EmptyResult, JsonResult, PasswordOrOtpData,
},
auth::Headers,
crypto,
db::{
DbConn, DbPool,
models::{
DeviceType, EventType, Membership, MembershipType, OrgPolicyType, Organization, OrganizationId, TwoFactor,
TwoFactorIncomplete, TwoFactorType, User, UserId,
},
DbConn, DbPool,
},
mail,
util::NumberOrString,
CONFIG,
};
pub mod authenticator;
@@ -36,7 +37,7 @@ fn has_global_duo_credentials() -> bool {
CONFIG._enable_duo() && CONFIG.duo_host().is_some() && CONFIG.duo_ikey().is_some() && CONFIG.duo_skey().is_some()
}
pub fn is_twofactor_provider_usable(provider_type: &TwoFactorType, provider_data: Option<&str>) -> bool {
pub fn is_twofactor_provider_usable(provider_type: TwoFactorType, provider_data: Option<&str>) -> bool {
#[derive(Deserialize)]
struct DuoProviderData {
host: String,
@@ -45,7 +46,7 @@ pub fn is_twofactor_provider_usable(provider_type: &TwoFactorType, provider_data
}
match provider_type {
TwoFactorType::Authenticator | TwoFactorType::RecoveryCode => true,
TwoFactorType::Authenticator => true,
TwoFactorType::Email => CONFIG._enable_email_2fa(),
TwoFactorType::Duo | TwoFactorType::OrganizationDuo => {
provider_data
@@ -58,6 +59,7 @@ pub fn is_twofactor_provider_usable(provider_type: &TwoFactorType, provider_data
}
TwoFactorType::Webauthn => CONFIG.is_webauthn_2fa_supported(),
TwoFactorType::Remember => !CONFIG.disable_2fa_remember(),
TwoFactorType::RecoveryCode => true,
TwoFactorType::U2f
| TwoFactorType::U2fRegisterChallenge
| TwoFactorType::U2fLoginChallenge
@@ -94,7 +96,7 @@ async fn get_twofactor(headers: Headers, conn: DbConn) -> Json<Value> {
.iter()
.filter_map(|tf| {
let provider_type = TwoFactorType::from_i32(tf.atype)?;
is_twofactor_provider_usable(&provider_type, Some(&tf.data)).then(|| TwoFactor::to_json_provider(tf))
is_twofactor_provider_usable(provider_type, Some(&tf.data)).then(|| TwoFactor::to_json_provider(tf))
})
.collect();
@@ -118,7 +120,7 @@ async fn get_recover(data: Json<PasswordOrOtpData>, headers: Headers, conn: DbCo
})))
}
async fn generate_recover_code(user: &mut User, conn: &DbConn) {
async fn _generate_recover_code(user: &mut User, conn: &DbConn) {
if user.totp_recover.is_none() {
let totp_recover = crypto::encode_random_bytes::<20>(&BASE32);
user.totp_recover = Some(totp_recover);
@@ -178,7 +180,9 @@ pub async fn enforce_2fa_policy(
ip: &std::net::IpAddr,
conn: &DbConn,
) -> EmptyResult {
for member in Membership::find_by_user_and_policy(&user.uuid, OrgPolicyType::TwoFactorAuthentication, conn).await {
for member in
Membership::find_by_user_and_policy(&user.uuid, OrgPolicyType::TwoFactorAuthentication, conn).await.into_iter()
{
// Policy only applies to non-Owner/non-Admin members who have accepted joining the org
if member.atype < MembershipType::Admin {
if CONFIG.mail_enabled() {
@@ -213,7 +217,7 @@ pub async fn enforce_2fa_policy_for_org(
conn: &DbConn,
) -> EmptyResult {
let org = Organization::find_by_uuid(org_id, conn).await.unwrap();
for member in Membership::find_confirmed_by_org(org_id, conn).await {
for member in Membership::find_confirmed_by_org(org_id, conn).await.into_iter() {
// Don't enforce the policy for Admins and Owners.
if member.atype < MembershipType::Admin && TwoFactor::find_by_user(&member.user_uuid, conn).await.is_empty() {
if CONFIG.mail_enabled() {
@@ -247,9 +251,12 @@ pub async fn send_incomplete_2fa_notifications(pool: DbPool) {
return;
}
let Ok(conn) = pool.get().await else {
error!("Failed to get DB connection in send_incomplete_2fa_notifications()");
return;
let conn = match pool.get().await {
Ok(conn) => conn,
_ => {
error!("Failed to get DB connection in send_incomplete_2fa_notifications()");
return;
}
};
let now = Utc::now().naive_utc();
@@ -271,7 +278,7 @@ pub async fn send_incomplete_2fa_notifications(pool: DbPool) {
)
.await
{
Ok(()) => {
Ok(_) => {
if let Err(e) = login.delete(&conn).await {
error!("Error deleting incomplete 2FA record: {e:#?}");
}
+10 -16
View File
@@ -1,17 +1,16 @@
use chrono::{NaiveDateTime, TimeDelta, Utc, naive::serde::ts_seconds};
use rocket::{Route, serde::json::Json};
use chrono::{naive::serde::ts_seconds, NaiveDateTime, TimeDelta, Utc};
use rocket::{serde::json::Json, Route};
use crate::{
CONFIG,
api::EmptyResult,
auth::Headers,
crypto,
db::{
DbConn,
models::{TwoFactor, TwoFactorType, UserId},
DbConn,
},
error::{Error, MapResult},
mail,
mail, CONFIG,
};
pub fn routes() -> Vec<Route> {
@@ -45,10 +44,9 @@ impl ProtectedActionData {
pub fn from_json(string: &str) -> Result<Self, Error> {
let res: Result<Self, serde_json::Error> = serde_json::from_str(string);
if let Ok(x) = res {
Ok(x)
} else {
err!("Could not decode ProtectedActionData from string")
match res {
Ok(x) => Ok(x),
Err(_) => err!("Could not decode ProtectedActionData from string"),
}
}
@@ -64,9 +62,7 @@ impl ProtectedActionData {
#[post("/accounts/request-otp")]
async fn request_otp(headers: Headers, conn: DbConn) -> EmptyResult {
if !CONFIG.mail_enabled() {
err!(
"Email is disabled for this server. Either enable email or login using your master password instead of login via device."
);
err!("Email is disabled for this server. Either enable email or login using your master password instead of login via device.");
}
let user = headers.user;
@@ -106,9 +102,7 @@ struct ProtectedActionVerify {
#[post("/accounts/verify-otp", data = "<data>")]
async fn verify_otp(data: Json<ProtectedActionVerify>, headers: Headers, conn: DbConn) -> EmptyResult {
if !CONFIG.mail_enabled() {
err!(
"Email is disabled for this server. Either enable email or login using your master password instead of login via device."
);
err!("Email is disabled for this server. Either enable email or login using your master password instead of login via device.");
}
let user = headers.user;
@@ -139,7 +133,7 @@ pub async fn validate_protected_action_otp(
}
// Check if the token has expired (Using the email 2fa expiration time)
let max_time = CONFIG.email_expiration_time().cast_signed();
let max_time = CONFIG.email_expiration_time() as i64;
if pa_data.time_since_sent().num_seconds() > max_time {
pa.delete(conn).await?;
err!("Token has expired")
+41 -41
View File
@@ -1,33 +1,32 @@
use std::{str::FromStr, sync::LazyLock, time::Duration};
use rocket::{Route, serde::json::Json};
use serde_json::Value;
use url::Url;
use uuid::Uuid;
use webauthn_rs::{
Webauthn, WebauthnBuilder,
prelude::{Base64UrlSafeData, Credential, Passkey, PasskeyAuthentication, PasskeyRegistration},
};
use webauthn_rs_proto::{
AuthenticationExtensionsClientOutputs, AuthenticatorAssertionResponseRaw, AuthenticatorAttestationResponseRaw,
PublicKeyCredential, RegisterPublicKeyCredential, RegistrationExtensionsClientOutputs,
RequestAuthenticationExtensions, UserVerificationPolicy,
};
use crate::{
CONFIG,
api::{
core::{log_user_event, two_factor::_generate_recover_code},
EmptyResult, JsonResult, PasswordOrOtpData,
core::{log_user_event, two_factor::generate_recover_code},
},
auth::Headers,
crypto::ct_eq,
db::{
DbConn,
models::{EventType, TwoFactor, TwoFactorType, UserId},
DbConn,
},
error::Error,
util::NumberOrString,
CONFIG,
};
use rocket::serde::json::Json;
use rocket::Route;
use serde_json::Value;
use std::str::FromStr;
use std::sync::LazyLock;
use std::time::Duration;
use url::Url;
use uuid::Uuid;
use webauthn_rs::prelude::{Base64UrlSafeData, Credential, Passkey, PasskeyAuthentication, PasskeyRegistration};
use webauthn_rs::{Webauthn, WebauthnBuilder};
use webauthn_rs_proto::{
AuthenticationExtensionsClientOutputs, AuthenticatorAssertionResponseRaw, AuthenticatorAttestationResponseRaw,
PublicKeyCredential, RegisterPublicKeyCredential, RegistrationExtensionsClientOutputs,
RequestAuthenticationExtensions, UserVerificationPolicy,
};
static WEBAUTHN: LazyLock<Webauthn> = LazyLock::new(|| {
@@ -39,7 +38,7 @@ static WEBAUTHN: LazyLock<Webauthn> = LazyLock::new(|| {
let webauthn = WebauthnBuilder::new(&rp_id, &rp_origin)
.expect("Creating WebauthnBuilder failed")
.rp_name(&domain)
.timeout(Duration::from_mins(1));
.timeout(Duration::from_millis(60000));
webauthn.build().expect("Building Webauthn failed")
});
@@ -150,7 +149,7 @@ async fn generate_webauthn_challenge(data: Json<PasswordOrOtpData>, headers: Hea
)?;
let mut state = serde_json::to_value(&state)?;
state["rs"]["policy"] = Value::String("discouraged".to_owned());
state["rs"]["policy"] = Value::String("discouraged".to_string());
state["rs"]["extensions"].as_object_mut().unwrap().clear();
let type_ = TwoFactorType::WebauthnRegisterChallenge;
@@ -266,12 +265,13 @@ async fn activate_webauthn(data: Json<EnableWebauthnData>, headers: Headers, con
// Retrieve and delete the saved challenge state
let type_ = TwoFactorType::WebauthnRegisterChallenge as i32;
let state = if let Some(tf) = TwoFactor::find_by_user_and_type(&user.uuid, type_, &conn).await {
let state: PasskeyRegistration = serde_json::from_str(&tf.data)?;
tf.delete(&conn).await?;
state
} else {
err!("Can't recover challenge")
let state = match TwoFactor::find_by_user_and_type(&user.uuid, type_, &conn).await {
Some(tf) => {
let state: PasskeyRegistration = serde_json::from_str(&tf.data)?;
tf.delete(&conn).await?;
state
}
None => err!("Can't recover challenge"),
};
// Verify the credentials with the saved state
@@ -291,7 +291,7 @@ async fn activate_webauthn(data: Json<EnableWebauthnData>, headers: Headers, con
TwoFactor::new(user.uuid.clone(), TwoFactorType::Webauthn, serde_json::to_string(&registrations)?)
.save(&conn)
.await?;
generate_recover_code(&mut user, &conn).await;
_generate_recover_code(&mut user, &conn).await;
log_user_event(EventType::UserUpdated2fa as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &conn).await;
@@ -342,10 +342,9 @@ async fn delete_webauthn(data: Json<DeleteU2FData>, headers: Headers, conn: DbCo
// If entry is migrated from u2f, delete the u2f entry as well
if let Some(mut u2f) = TwoFactor::find_by_user_and_type(&headers.user.uuid, TwoFactorType::U2f as i32, &conn).await
{
let mut data: Vec<U2FRegistration> = if let Ok(d) = serde_json::from_str(&u2f.data) {
d
} else {
err!("Error parsing U2F data")
let mut data: Vec<U2FRegistration> = match serde_json::from_str(&u2f.data) {
Ok(d) => d,
Err(_) => err!("Error parsing U2F data"),
};
data.retain(|r| r.reg.key_handle != removed_item.credential.cred_id().as_slice());
@@ -389,10 +388,10 @@ pub async fn generate_webauthn_login(user_id: &UserId, conn: &DbConn) -> JsonRes
// Modify to discourage user verification
let mut state = serde_json::to_value(&state)?;
state["ast"]["policy"] = Value::String("discouraged".to_owned());
state["ast"]["policy"] = Value::String("discouraged".to_string());
// Add appid, this is only needed for U2F compatibility, so maybe it can be removed as well
let app_id = format!("{}/app-id.json", CONFIG.domain());
let app_id = format!("{}/app-id.json", &CONFIG.domain());
state["ast"]["appid"] = Value::String(app_id.clone());
response.public_key.user_verification = UserVerificationPolicy::Discouraged_DO_NOT_USE;
@@ -417,17 +416,18 @@ pub async fn generate_webauthn_login(user_id: &UserId, conn: &DbConn) -> JsonRes
pub async fn validate_webauthn_login(user_id: &UserId, response: &str, conn: &DbConn) -> EmptyResult {
let type_ = TwoFactorType::WebauthnLoginChallenge as i32;
let mut state = if let Some(tf) = TwoFactor::find_by_user_and_type(user_id, type_, conn).await {
let state: PasskeyAuthentication = serde_json::from_str(&tf.data)?;
tf.delete(conn).await?;
state
} else {
err!(
let mut state = match TwoFactor::find_by_user_and_type(user_id, type_, conn).await {
Some(tf) => {
let state: PasskeyAuthentication = serde_json::from_str(&tf.data)?;
tf.delete(conn).await?;
state
}
None => err!(
"Can't recover login challenge",
ErrorEvent {
event: EventType::UserFailedLogIn2fa
}
)
),
};
let rsp: PublicKeyCredentialCopy = serde_json::from_str(response)?;
+10 -10
View File
@@ -1,19 +1,20 @@
use rocket::{Route, serde::json::Json};
use rocket::serde::json::Json;
use rocket::Route;
use serde_json::Value;
use yubico::{config::Config, verify_async};
use crate::{
CONFIG,
api::{
core::{log_user_event, two_factor::_generate_recover_code},
EmptyResult, JsonResult, PasswordOrOtpData,
core::{log_user_event, two_factor::generate_recover_code},
},
auth::Headers,
db::{
DbConn,
models::{EventType, TwoFactor, TwoFactorType},
DbConn,
},
error::{Error, MapResult},
CONFIG,
};
pub fn routes() -> Vec<Route> {
@@ -45,7 +46,7 @@ pub struct YubikeyMetadata {
fn parse_yubikeys(data: &EnableYubikeyData) -> Vec<String> {
let data_keys = [&data.key1, &data.key2, &data.key3, &data.key4, &data.key5];
data_keys.into_iter().flatten().cloned().collect()
data_keys.iter().filter_map(|e| e.as_ref().cloned()).collect()
}
fn jsonify_yubikeys(yubikeys: Vec<String>) -> Value {
@@ -63,10 +64,9 @@ fn get_yubico_credentials() -> Result<(String, String), Error> {
err!("Yubico support is disabled");
}
if let (Some(id), Some(secret)) = (CONFIG.yubico_client_id(), CONFIG.yubico_secret_key()) {
Ok((id, secret))
} else {
err!("`YUBICO_CLIENT_ID` or `YUBICO_SECRET_KEY` environment variable is not set. Yubikey OTP Disabled")
match (CONFIG.yubico_client_id(), CONFIG.yubico_secret_key()) {
(Some(id), Some(secret)) => Ok((id, secret)),
_ => err!("`YUBICO_CLIENT_ID` or `YUBICO_SECRET_KEY` environment variable is not set. Yubikey OTP Disabled"),
}
}
@@ -162,7 +162,7 @@ async fn activate_yubikey(data: Json<EnableYubikeyData>, headers: Headers, conn:
yubikey_data.data = serde_json::to_string(&yubikey_metadata).unwrap();
yubikey_data.save(&conn).await?;
generate_recover_code(&mut user, &conn).await;
_generate_recover_code(&mut user, &conn).await;
log_user_event(EventType::UserUpdated2fa as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &conn).await;
+110 -89
View File
@@ -6,29 +6,28 @@ use std::{
};
use bytes::{Bytes, BytesMut};
use futures::{TryFutureExt, stream::StreamExt};
use futures::{stream::StreamExt, TryFutureExt};
use html5gum::{Emitter, HtmlString, Readable, StringReader, Tokenizer};
use regex::Regex;
use reqwest::{
Client, Response,
header::{self, HeaderMap, HeaderValue},
Client, Response,
};
use rocket::{Route, http::ContentType, response::Redirect};
use svg_hush::{Filter, data_url_filter};
use rocket::{http::ContentType, response::Redirect, Route};
use svg_hush::{data_url_filter, Filter};
use crate::{
CONFIG,
config::PathType,
error::Error,
http_client::{CustomHttpClientError, get_reqwest_client_builder, get_valid_host, should_block_host},
http_client::{get_reqwest_client_builder, should_block_address, CustomHttpClientError},
util::Cached,
CONFIG,
};
pub fn routes() -> Vec<Route> {
if CONFIG.icon_service().as_str() == "internal" {
routes![icon_internal]
} else {
routes![icon_external]
match CONFIG.icon_service().as_str() {
"internal" => routes![icon_internal],
_ => routes![icon_external],
}
}
@@ -82,19 +81,19 @@ static ICON_SIZE_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"(?x)(\d+
// The function name `icon_external` is checked in the `on_response` function in `AppHeaders`
// It is used to prevent sending a specific header which breaks icon downloads.
// If this function needs to be renamed, also adjust the code in `util.rs`
#[get("/<host>/icon.png")]
fn icon_external(host: &str) -> Cached<Option<Redirect>> {
let Ok(host) = get_valid_host(host) else {
warn!("Invalid host: {host}");
return Cached::ttl(None, CONFIG.icon_cache_negttl(), true);
};
if should_block_host(&host).is_err() {
warn!("Blocked address: {host}");
#[get("/<domain>/icon.png")]
fn icon_external(domain: &str) -> Cached<Option<Redirect>> {
if !is_valid_domain(domain) {
warn!("Invalid domain: {domain}");
return Cached::ttl(None, CONFIG.icon_cache_negttl(), true);
}
let url = CONFIG._icon_service_url().replace("{}", &host.to_string());
if should_block_address(domain) {
warn!("Blocked address: {domain}");
return Cached::ttl(None, CONFIG.icon_cache_negttl(), true);
}
let url = CONFIG._icon_service_url().replace("{}", domain);
let redir = match CONFIG.icon_redirect_code() {
301 => Some(Redirect::moved(url)), // legacy permanent redirect
302 => Some(Redirect::found(url)), // legacy temporary redirect
@@ -108,21 +107,12 @@ fn icon_external(host: &str) -> Cached<Option<Redirect>> {
Cached::ttl(redir, CONFIG.icon_cache_ttl(), true)
}
#[get("/<host>/icon.png")]
async fn icon_internal(host: &str) -> Cached<(ContentType, Vec<u8>)> {
#[get("/<domain>/icon.png")]
async fn icon_internal(domain: &str) -> Cached<(ContentType, Vec<u8>)> {
const FALLBACK_ICON: &[u8] = include_bytes!("../static/images/fallback-icon.png");
let Ok(host) = get_valid_host(host) else {
warn!("Invalid host: {host}");
return Cached::ttl(
(ContentType::new("image", "png"), FALLBACK_ICON.to_vec()),
CONFIG.icon_cache_negttl(),
true,
);
};
if should_block_host(&host).is_err() {
warn!("Blocked address: {host}");
if !is_valid_domain(domain) {
warn!("Invalid domain: {domain}");
return Cached::ttl(
(ContentType::new("image", "png"), FALLBACK_ICON.to_vec()),
CONFIG.icon_cache_negttl(),
@@ -130,7 +120,16 @@ async fn icon_internal(host: &str) -> Cached<(ContentType, Vec<u8>)> {
);
}
match get_icon(&host.to_string()).await {
if should_block_address(domain) {
warn!("Blocked address: {domain}");
return Cached::ttl(
(ContentType::new("image", "png"), FALLBACK_ICON.to_vec()),
CONFIG.icon_cache_negttl(),
true,
);
}
match get_icon(domain).await {
Some((icon, icon_type)) => {
Cached::ttl((ContentType::new("image", icon_type), icon), CONFIG.icon_cache_ttl(), true)
}
@@ -138,6 +137,42 @@ async fn icon_internal(host: &str) -> Cached<(ContentType, Vec<u8>)> {
}
}
/// Returns if the domain provided is valid or not.
///
/// This does some manual checks and makes use of Url to do some basic checking.
/// domains can't be larger then 63 characters (not counting multiple subdomains) according to the RFC's, but we limit the total size to 255.
fn is_valid_domain(domain: &str) -> bool {
const ALLOWED_CHARS: &str = "-.";
// If parsing the domain fails using Url, it will not work with reqwest.
if let Err(parse_error) = url::Url::parse(format!("https://{domain}").as_str()) {
debug!("Domain parse error: '{domain}' - {parse_error:?}");
return false;
} else if domain.is_empty()
|| domain.contains("..")
|| domain.starts_with('.')
|| domain.starts_with('-')
|| domain.ends_with('-')
{
debug!(
"Domain validation error: '{domain}' is either empty, contains '..', starts with an '.', starts or ends with a '-'"
);
return false;
} else if domain.len() > 255 {
debug!("Domain validation error: '{domain}' exceeds 255 characters");
return false;
}
for c in domain.chars() {
if !c.is_alphanumeric() && !ALLOWED_CHARS.contains(c) {
debug!("Domain validation error: '{domain}' contains an invalid character '{c}'");
return false;
}
}
true
}
async fn get_icon(domain: &str) -> Option<(Vec<u8>, String)> {
let path = format!("{domain}.png");
@@ -148,7 +183,7 @@ async fn get_icon(domain: &str) -> Option<(Vec<u8>, String)> {
if let Some(icon) = get_cached_icon(&path).await {
let icon_type = get_icon_type(&icon).unwrap_or("x-icon");
return Some((icon, icon_type.to_owned()));
return Some((icon, icon_type.to_string()));
}
if CONFIG.disable_icon_download() {
@@ -159,7 +194,7 @@ async fn get_icon(domain: &str) -> Option<(Vec<u8>, String)> {
match download_icon(domain).await {
Ok((icon, icon_type)) => {
save_icon(&path, icon.to_vec()).await;
Some((icon.to_vec(), icon_type.unwrap_or("x-icon").to_owned()))
Some((icon.to_vec(), icon_type.unwrap_or("x-icon").to_string()))
}
Err(e) => {
// If this error comes from the custom resolver, this means this is a blocked domain
@@ -184,10 +219,10 @@ async fn get_cached_icon(path: &str) -> Option<Vec<u8>> {
}
// Try to read the cached icon, and return it if it exists
if let Ok(operator) = CONFIG.opendal_operator_for_path_type(&PathType::IconCache)
&& let Ok(buf) = operator.read(path).await
{
return Some(buf.to_vec());
if let Ok(operator) = CONFIG.opendal_operator_for_path_type(&PathType::IconCache) {
if let Ok(buf) = operator.read(path).await {
return Some(buf.to_vec());
}
}
None
@@ -281,17 +316,17 @@ fn get_favicons_node(dom: Tokenizer<StringReader<'_>, FaviconEmitter>, icons: &m
}
for icon_tag in icon_tags {
if let Some(icon_href) = icon_tag.attributes.get(ATTR_HREF)
&& let Ok(full_href) = base_url.join(std::str::from_utf8(icon_href).unwrap_or_default())
{
let sizes = if let Some(v) = icon_tag.attributes.get(ATTR_SIZES) {
std::str::from_utf8(v).unwrap_or_default()
} else {
""
};
let priority = get_icon_priority(full_href.as_str(), sizes);
icons.push(Icon::new(priority, full_href.to_string()));
}
if let Some(icon_href) = icon_tag.attributes.get(ATTR_HREF) {
if let Ok(full_href) = base_url.join(std::str::from_utf8(icon_href).unwrap_or_default()) {
let sizes = if let Some(v) = icon_tag.attributes.get(ATTR_SIZES) {
std::str::from_utf8(v).unwrap_or_default()
} else {
""
};
let priority = get_icon_priority(full_href.as_str(), sizes);
icons.push(Icon::new(priority, full_href.to_string()));
}
};
}
}
@@ -332,7 +367,7 @@ async fn get_icon_url(domain: &str) -> Result<IconUrlResult, Error> {
tld = domain_parts.next_back().unwrap(),
base = domain_parts.next_back().unwrap()
);
if get_valid_host(&base_domain).is_ok() {
if is_valid_domain(&base_domain) {
let sslbase = format!("https://{base_domain}");
let httpbase = format!("http://{base_domain}");
debug!("[get_icon_url]: Trying without subdomains '{base_domain}'");
@@ -343,7 +378,7 @@ async fn get_icon_url(domain: &str) -> Result<IconUrlResult, Error> {
// When the domain is not an IP, and has less then 2 dots, try to add www. infront of it.
} else if is_ip.is_err() && domain.matches('.').count() < 2 {
let www_domain = format!("www.{domain}");
if get_valid_host(&www_domain).is_ok() {
if is_valid_domain(&www_domain) {
let sslwww = format!("https://{www_domain}");
let httpwww = format!("http://{www_domain}");
debug!("[get_icon_url]: Trying with www. prefix '{www_domain}'");
@@ -407,7 +442,7 @@ async fn get_page(url: &str) -> Result<Response, Error> {
async fn get_page_with_referer(url: &str, referer: &str) -> Result<Response, Error> {
let mut client = CLIENT.get(url);
if !referer.is_empty() {
client = client.header("Referer", referer);
client = client.header("Referer", referer)
}
Ok(client.send().await?.error_for_status()?)
@@ -495,10 +530,11 @@ async fn download_icon(domain: &str) -> Result<(Bytes, Option<&str>), Error> {
let mut buffer = Bytes::new();
let mut icon_type: Option<&str> = None;
let mut icons = icon_result.iconlist.iter().take(5).peekable();
while let Some(icon) = icons.next() {
use data_url::DataUrl;
for icon in icon_result.iconlist.iter().take(5) {
if icon.href.starts_with("data:image") {
let Ok(datauri) = data_url::DataUrl::process(&icon.href) else {
let Ok(datauri) = DataUrl::process(&icon.href) else {
continue;
};
// Check if we are able to decode the data uri
@@ -522,25 +558,13 @@ async fn download_icon(domain: &str) -> Result<(Bytes, Option<&str>), Error> {
}
}
_ => debug!("Extracted icon from data:image uri is invalid"),
}
} else {
debug!("Trying {}", icon.href);
// Make sure all icons are checked before returning error
let res = match get_page_with_referer(&icon.href, &icon_result.referer).await {
Ok(r) => r,
Err(e) if icons.peek().is_none() => return Err(e),
Err(e) if CustomHttpClientError::downcast_ref(&e).is_some() => return Err(e), // If blacklisted stop immediately instead of checking the rest of the icons. see explanation and actual handling inside get_icon()
Err(e) => {
warn!("Unable to download icon: {e:?}");
// Continue to next icon
continue;
}
};
} else {
let res = get_page_with_referer(&icon.href, &icon_result.referer).await?;
buffer = stream_to_bytes_limit(res, 5120 * 1024).await?; // 5120KB/5MB for each icon max (Same as icons.bitwarden.net)
// Check if the icon type is allowed, else try another icon from the list.
// Check if the icon type is allowed, else try an icon from the list.
icon_type = get_icon_type(&buffer);
if icon_type.is_none() {
buffer.clear();
@@ -586,25 +610,22 @@ async fn save_icon(path: &str, icon: Vec<u8>) {
fn get_icon_type(bytes: &[u8]) -> Option<&'static str> {
fn check_svg_after_xml_declaration(bytes: &[u8]) -> Option<&'static str> {
// Look for SVG tag within the first 1KB
if let Ok(content) = std::str::from_utf8(&bytes[..bytes.len().min(1024)])
&& (content.contains("<svg") || content.contains("<SVG"))
{
return Some("svg+xml");
if let Ok(content) = std::str::from_utf8(&bytes[..bytes.len().min(1024)]) {
if content.contains("<svg") || content.contains("<SVG") {
return Some("svg+xml");
}
}
None
}
// Some details can be found here:
// - https://www.garykessler.net/library/file_sigs_GCK_latest.html
// - https://en.wikipedia.org/wiki/List_of_file_signatures
match bytes {
[137, 80, 78, 71, 13, 10, 26, 10, ..] => Some("png"),
[0, 0, 1, 0, n1, n2, ..] if u16::from_le_bytes([*n1, *n2]) > 0 => Some("x-icon"), // https://en.wikipedia.org/wiki/ICO_(file_format)
[82, 73, 70, 70, _, _, _, _, 87, 69, 66, 80, ..] => Some("webp"), // Only match WebP Images
[255, 216, 255, b, ..] if *b >= 0xC0 => Some("jpeg"),
[71, 73, 70, 56, 55 | 57, 97, ..] => Some("gif"),
[66, 77, _, _, _, _, 0, 0, 0, 0, ..] => Some("bmp"), // https://en.wikipedia.org/wiki/BMP_file_format
[60, 115, 118, 103, ..] => Some("svg+xml"), // Normal svg
[137, 80, 78, 71, ..] => Some("png"),
[0, 0, 1, 0, ..] => Some("x-icon"),
[82, 73, 70, 70, ..] => Some("webp"),
[255, 216, 255, ..] => Some("jpeg"),
[71, 73, 70, 56, ..] => Some("gif"),
[66, 77, ..] => Some("bmp"),
[60, 115, 118, 103, ..] => Some("svg+xml"), // Normal svg
[60, 63, 120, 109, 108, ..] => check_svg_after_xml_declaration(bytes), // An svg starting with <?xml
_ => None,
}
@@ -732,7 +753,7 @@ impl FaviconEmitter {
let rel_value =
std::str::from_utf8(token.tag.attributes.get(ATTR_REL).unwrap()).unwrap_or_default();
if rel_value.contains("icon") && !rel_value.contains("mask-icon") {
self.emit_token = true;
self.emit_token = true
}
}
_ => (),
@@ -805,13 +826,13 @@ impl Emitter for FaviconEmitter {
fn push_attribute_name(&mut self, s: &[u8]) {
if let Some(attr) = &mut self.current_attribute {
attr.0.extend(s);
attr.0.extend(s)
}
}
fn push_attribute_value(&mut self, s: &[u8]) {
if let Some(attr) = &mut self.current_attribute {
attr.1.extend(s);
attr.1.extend(s)
}
}
+167 -240
View File
@@ -1,20 +1,18 @@
use chrono::Utc;
use num_traits::FromPrimitive;
use rocket::{
Route,
form::{Form, FromForm},
http::{Cookie, CookieJar, SameSite},
http::Status,
response::Redirect,
serde::json::Json,
Route,
};
use serde_json::Value;
use crate::{
CONFIG,
api::{
ApiResult, EmptyResult, JsonResult,
core::{
accounts::{PreloginData, RegisterData, kdf_upgrade, prelogin, register},
accounts::{PreloginData, RegisterData, _prelogin, _register, kdf_upgrade},
log_user_event,
two_factor::{
authenticator, duo, duo_oidc, email, enforce_2fa_policy, is_twofactor_provider_usable, webauthn,
@@ -23,29 +21,27 @@ use crate::{
},
master_password_policy,
push::register_push_device,
ApiResult, EmptyResult, JsonResult,
},
auth,
auth::{AuthMethod, ClientHeaders, ClientIp, ClientVersion, Secure, generate_organization_api_key_login_claims},
crypto,
auth::{generate_organization_api_key_login_claims, AuthMethod, ClientHeaders, ClientIp, ClientVersion},
db::{
DbConn,
models::{
AuthRequest, AuthRequestId, Device, DeviceId, EventType, Invitation, OIDCCodeResponseError,
OrganizationApiKey, OrganizationId, SsoAuth, SsoUser, TwoFactor, TwoFactorIncomplete, TwoFactorType, User,
UserId,
AuthRequest, AuthRequestId, Device, DeviceId, EventType, Invitation, OIDCCodeWrapper, OrganizationApiKey,
OrganizationId, SsoAuth, SsoUser, TwoFactor, TwoFactorIncomplete, TwoFactorType, User, UserId,
},
DbConn,
},
error::MapResult,
mail, sso,
sso::{OIDCCode, OIDCCodeChallenge, OIDCCodeVerifier, OIDCState},
util,
util, CONFIG,
};
pub fn routes() -> Vec<Route> {
routes![
login,
post_prelogin,
prelogin_password,
prelogin,
identity_register,
register_verification_email,
register_finish,
@@ -69,43 +65,43 @@ async fn login(
let login_result = match data.grant_type.as_ref() {
"refresh_token" => {
check_is_some(data.refresh_token.as_ref(), "refresh_token cannot be blank")?;
refresh_login(data, &conn, &client_header.ip).await
_check_is_some(&data.refresh_token, "refresh_token cannot be blank")?;
_refresh_login(data, &conn, &client_header.ip).await
}
"password" if CONFIG.sso_enabled() && CONFIG.sso_only() => err!("SSO sign-in is required"),
"password" => {
check_is_some(data.client_id.as_ref(), "client_id cannot be blank")?;
check_is_some(data.password.as_ref(), "password cannot be blank")?;
check_is_some(data.scope.as_ref(), "scope cannot be blank")?;
check_is_some(data.username.as_ref(), "username cannot be blank")?;
_check_is_some(&data.client_id, "client_id cannot be blank")?;
_check_is_some(&data.password, "password cannot be blank")?;
_check_is_some(&data.scope, "scope cannot be blank")?;
_check_is_some(&data.username, "username cannot be blank")?;
check_is_some(data.device_identifier.as_ref(), "device_identifier cannot be blank")?;
check_is_some(data.device_name.as_ref(), "device_name cannot be blank")?;
check_is_some(data.device_type.as_ref(), "device_type cannot be blank")?;
_check_is_some(&data.device_identifier, "device_identifier cannot be blank")?;
_check_is_some(&data.device_name, "device_name cannot be blank")?;
_check_is_some(&data.device_type, "device_type cannot be blank")?;
password_login(data, &mut user_id, &conn, &client_header.ip, client_version.as_ref()).await
_password_login(data, &mut user_id, &conn, &client_header.ip, &client_version).await
}
"client_credentials" => {
check_is_some(data.client_id.as_ref(), "client_id cannot be blank")?;
check_is_some(data.client_secret.as_ref(), "client_secret cannot be blank")?;
check_is_some(data.scope.as_ref(), "scope cannot be blank")?;
_check_is_some(&data.client_id, "client_id cannot be blank")?;
_check_is_some(&data.client_secret, "client_secret cannot be blank")?;
_check_is_some(&data.scope, "scope cannot be blank")?;
check_is_some(data.device_identifier.as_ref(), "device_identifier cannot be blank")?;
check_is_some(data.device_name.as_ref(), "device_name cannot be blank")?;
check_is_some(data.device_type.as_ref(), "device_type cannot be blank")?;
_check_is_some(&data.device_identifier, "device_identifier cannot be blank")?;
_check_is_some(&data.device_name, "device_name cannot be blank")?;
_check_is_some(&data.device_type, "device_type cannot be blank")?;
api_key_login(data, &mut user_id, &conn, &client_header.ip).await
_api_key_login(data, &mut user_id, &conn, &client_header.ip).await
}
"authorization_code" if CONFIG.sso_enabled() => {
check_is_some(data.client_id.as_ref(), "client_id cannot be blank")?;
check_is_some(data.code.as_ref(), "code cannot be blank")?;
check_is_some(data.code_verifier.as_ref(), "code verifier cannot be blank")?;
_check_is_some(&data.client_id, "client_id cannot be blank")?;
_check_is_some(&data.code, "code cannot be blank")?;
_check_is_some(&data.code_verifier, "code verifier cannot be blank")?;
check_is_some(data.device_identifier.as_ref(), "device_identifier cannot be blank")?;
check_is_some(data.device_name.as_ref(), "device_name cannot be blank")?;
check_is_some(data.device_type.as_ref(), "device_type cannot be blank")?;
_check_is_some(&data.device_identifier, "device_identifier cannot be blank")?;
_check_is_some(&data.device_name, "device_name cannot be blank")?;
_check_is_some(&data.device_type, "device_type cannot be blank")?;
sso_login(data, &mut user_id, &conn, &client_header.ip, client_version.as_ref()).await
_sso_login(data, &mut user_id, &conn, &client_header.ip, &client_version).await
}
"authorization_code" => err!("SSO sign-in is not available"),
t => err!("Invalid type", t),
@@ -126,7 +122,7 @@ async fn login(
Err(e) => {
if let Some(ev) = e.get_event() {
log_user_event(ev.event as i32, &user_id, client_header.device_type, &client_header.ip.ip, &conn)
.await;
.await
}
}
}
@@ -135,14 +131,12 @@ async fn login(
login_result
}
async fn refresh_login(data: ConnectData, conn: &DbConn, ip: &ClientIp) -> JsonResult {
// When a refresh token is invalid or missing we need to respond with an HTTP BadRequest (400)
// It also needs to return a json which holds at least a key `error` with the value `invalid_grant`
// See the link below for details
// https://github.com/bitwarden/clients/blob/2ee158e720a5e7dbe3641caf80b569e97a1dd91b/libs/common/src/services/api.service.ts#L1786-L1797
let Some(refresh_token) = data.refresh_token else {
err_json!(json!({"error": "invalid_grant"}), "Missing refresh_token")
// Return Status::Unauthorized to trigger logout
async fn _refresh_login(data: ConnectData, conn: &DbConn, ip: &ClientIp) -> JsonResult {
// Extract token
let refresh_token = match data.refresh_token {
Some(token) => token,
None => err_code!("Missing refresh_token", Status::Unauthorized.code),
};
// ---
@@ -153,10 +147,7 @@ async fn refresh_login(data: ConnectData, conn: &DbConn, ip: &ClientIp) -> JsonR
// let members = Membership::find_confirmed_by_user(&user.uuid, conn).await;
match auth::refresh_tokens(ip, &refresh_token, data.client_id, conn).await {
Err(err) => {
err_json!(
json!({"error": "invalid_grant"}),
format!("Unable to refresh login credentials: {}", err.message())
)
err_code!(format!("Unable to refresh login credentials: {}", err.message()), Status::Unauthorized.code)
}
Ok((mut device, auth_tokens)) => {
// Save to update `device.updated_at` to track usage and toggle new status
@@ -176,19 +167,19 @@ async fn refresh_login(data: ConnectData, conn: &DbConn, ip: &ClientIp) -> JsonR
}
// After exchanging the code we need to check first if 2FA is needed before continuing
async fn sso_login(
async fn _sso_login(
data: ConnectData,
user_id: &mut Option<UserId>,
conn: &DbConn,
ip: &ClientIp,
client_version: Option<&ClientVersion>,
client_version: &Option<ClientVersion>,
) -> JsonResult {
AuthMethod::Sso.check_scope(data.scope.as_ref())?;
// Ratelimit the login
crate::ratelimit::check_limit_login(&ip.ip)?;
let (code, code_verifier) = match (data.code.as_ref(), data.code_verifier.as_ref()) {
let (state, code_verifier) = match (data.code.as_ref(), data.code_verifier.as_ref()) {
(None, _) => err!(
"Got no code in OIDC data",
ErrorEvent {
@@ -204,7 +195,7 @@ async fn sso_login(
(Some(code), Some(code_verifier)) => (code, code_verifier.clone()),
};
let (sso_auth, user_infos) = sso::exchange_code(code, code_verifier, conn).await?;
let (sso_auth, user_infos) = sso::exchange_code(state, code_verifier, conn).await?;
let user_with_sso = match SsoUser::find_by_identifier(&user_infos.identifier, conn).await {
None => match SsoUser::find_by_mail(&user_infos.email, conn).await {
None => None,
@@ -232,33 +223,7 @@ async fn sso_login(
}
)
}
Some((user, None)) => match user_infos.email_verified {
None if !CONFIG.sso_allow_unknown_email_verification() => {
error!(
"Login failure ({}), existing non SSO user ({}) with same email ({}) and email verification status is unknown",
user_infos.identifier, user.uuid, user.email
);
err_silent!(
"Email verification status is unknown",
ErrorEvent {
event: EventType::UserFailedLogIn
}
)
}
Some(false) => {
error!(
"Login failure ({}), existing non SSO user ({}) with same email ({}) and email is not verified",
user_infos.identifier, user.uuid, user.email
);
err_silent!(
"Email is not verified by the SSO provider",
ErrorEvent {
event: EventType::UserFailedLogIn
}
)
}
_ => Some((user, None)),
},
Some((user, None)) => Some((user, None)),
},
Some((user, sso_user)) => Some((user, Some(sso_user))),
};
@@ -345,12 +310,12 @@ async fn sso_login(
authenticated_response(&user, &mut device, auth_tokens, twofactor_token, conn, ip).await
}
async fn password_login(
async fn _password_login(
data: ConnectData,
user_id: &mut Option<UserId>,
conn: &DbConn,
ip: &ClientIp,
client_version: Option<&ClientVersion>,
client_version: &Option<ClientVersion>,
) -> JsonResult {
// Validate scope
AuthMethod::Password.check_scope(data.scope.as_ref())?;
@@ -429,9 +394,9 @@ async fn password_login(
if user.verified_at.is_none() && CONFIG.mail_enabled() && CONFIG.signups_verify() {
if user.last_verifying_at.is_none()
|| now.signed_duration_since(user.last_verifying_at.unwrap()).num_seconds()
> CONFIG.signups_verify_resend_time().cast_signed()
> CONFIG.signups_verify_resend_time() as i64
{
let resend_limit = CONFIG.signups_verify_resend_limit().cast_signed();
let resend_limit = CONFIG.signups_verify_resend_limit() as i32;
if resend_limit == 0 || user.login_verify_count < resend_limit {
// We want to send another email verification if we require signups to verify
// their email address, and we haven't sent them a reminder in a while...
@@ -567,19 +532,19 @@ async fn authenticated_response(
Ok(Json(result))
}
async fn api_key_login(data: ConnectData, user_id: &mut Option<UserId>, conn: &DbConn, ip: &ClientIp) -> JsonResult {
async fn _api_key_login(data: ConnectData, user_id: &mut Option<UserId>, conn: &DbConn, ip: &ClientIp) -> JsonResult {
// Ratelimit the login
crate::ratelimit::check_limit_login(&ip.ip)?;
// Validate scope
match data.scope.as_ref() {
Some(scope) if scope == &AuthMethod::UserApiKey.scope() => user_api_key_login(data, user_id, conn, ip).await,
Some(scope) if scope == &AuthMethod::OrgApiKey.scope() => organization_api_key_login(data, conn, ip).await,
Some(scope) if scope == &AuthMethod::UserApiKey.scope() => _user_api_key_login(data, user_id, conn, ip).await,
Some(scope) if scope == &AuthMethod::OrgApiKey.scope() => _organization_api_key_login(data, conn, ip).await,
_ => err!("Scope not supported"),
}
}
async fn user_api_key_login(
async fn _user_api_key_login(
data: ConnectData,
user_id: &mut Option<UserId>,
conn: &DbConn,
@@ -711,13 +676,13 @@ async fn user_api_key_login(
Ok(Json(result))
}
async fn organization_api_key_login(data: ConnectData, conn: &DbConn, ip: &ClientIp) -> JsonResult {
async fn _organization_api_key_login(data: ConnectData, conn: &DbConn, ip: &ClientIp) -> JsonResult {
// Get the org via the client_id
let client_id = data.client_id.as_ref().unwrap();
let Some(org_id) = client_id.strip_prefix("organization.") else {
err!("Malformed client_id", format!("IP: {}.", ip.ip))
};
let org_id: OrganizationId = org_id.to_owned().into();
let org_id: OrganizationId = org_id.to_string().into();
let Some(org_api_key) = OrganizationApiKey::find_by_org_uuid(&org_id, conn).await else {
err!("Invalid client_id", format!("IP: {}.", ip.ip))
};
@@ -748,13 +713,14 @@ async fn get_device(data: &ConnectData, conn: &DbConn, user: &User) -> ApiResult
let device_name = data.device_name.clone().expect("No device name provided");
// Find device or create new
if let Some(device) = Device::find_by_uuid_and_user(&device_id, &user.uuid, conn).await {
Ok(device)
} else {
let mut device = Device::new(device_id, user.uuid.clone(), device_name, device_type);
// save device without updating `device.updated_at`
device.save(false, conn).await?;
Ok(device)
match Device::find_by_uuid_and_user(&device_id, &user.uuid, conn).await {
Some(device) => Ok(device),
None => {
let mut device = Device::new(device_id, user.uuid.clone(), device_name, device_type);
// save device without updating `device.updated_at`
device.save(false, conn).await?;
Ok(device)
}
}
}
@@ -763,7 +729,7 @@ async fn twofactor_auth(
data: &ConnectData,
device: &mut Device,
ip: &ClientIp,
client_version: Option<&ClientVersion>,
client_version: &Option<ClientVersion>,
conn: &DbConn,
) -> ApiResult<Option<String>> {
let twofactors = TwoFactor::find_by_user(&user.uuid, conn).await;
@@ -780,7 +746,7 @@ async fn twofactor_auth(
.iter()
.filter_map(|tf| {
let provider_type = TwoFactorType::from_i32(tf.atype)?;
(tf.enabled && is_twofactor_provider_usable(&provider_type, Some(&tf.data))).then_some(tf.atype)
(tf.enabled && is_twofactor_provider_usable(provider_type, Some(&tf.data))).then_some(tf.atype)
})
.collect();
if twofactor_ids.is_empty() {
@@ -788,51 +754,56 @@ async fn twofactor_auth(
}
let selected_id = data.two_factor_provider.unwrap_or(twofactor_ids[0]); // If we aren't given a two factor provider, assume the first one
// Ignore Remember and RecoveryCode Types during this check, these are special
if ![TwoFactorType::Remember as i32, TwoFactorType::RecoveryCode as i32].contains(&selected_id)
&& !twofactor_ids.contains(&selected_id)
{
if !twofactor_ids.contains(&selected_id) {
err_json!(
json_err_twofactor(&twofactor_ids, &user.uuid, data, client_version, conn).await?,
_json_err_twofactor(&twofactor_ids, &user.uuid, data, client_version, conn).await?,
"Invalid two factor provider"
)
}
let Some(ref twofactor_code) = data.two_factor_token else {
err_json!(
json_err_twofactor(&twofactor_ids, &user.uuid, data, client_version, conn).await?,
"2FA token not provided"
)
let twofactor_code = match data.two_factor_token {
Some(ref code) => code,
None => {
err_json!(
_json_err_twofactor(&twofactor_ids, &user.uuid, data, client_version, conn).await?,
"2FA token not provided"
)
}
};
let selected_twofactor = twofactors.into_iter().find(|tf| tf.atype == selected_id && tf.enabled);
let selected_data = selected_data(selected_twofactor);
use crate::crypto::ct_eq;
let selected_data = _selected_data(selected_twofactor);
match TwoFactorType::from_i32(selected_id) {
Some(TwoFactorType::Authenticator) => {
authenticator::validate_totp_code_str(&user.uuid, twofactor_code, &selected_data?, ip, conn).await?;
authenticator::validate_totp_code_str(&user.uuid, twofactor_code, &selected_data?, ip, conn).await?
}
Some(TwoFactorType::Webauthn) => webauthn::validate_webauthn_login(&user.uuid, twofactor_code, conn).await?,
Some(TwoFactorType::YubiKey) => yubikey::validate_yubikey_login(twofactor_code, &selected_data?).await?,
Some(TwoFactorType::Duo) => {
if CONFIG.duo_use_iframe() {
// Legacy iframe prompt flow
duo::validate_duo_login(&user.email, twofactor_code, conn).await?;
} else {
// OIDC based flow
duo_oidc::validate_duo_login(
&user.email,
twofactor_code,
data.client_id.as_ref().unwrap(),
data.device_identifier.as_ref().unwrap(),
conn,
)
.await?;
match CONFIG.duo_use_iframe() {
true => {
// Legacy iframe prompt flow
duo::validate_duo_login(&user.email, twofactor_code, conn).await?
}
false => {
// OIDC based flow
duo_oidc::validate_duo_login(
&user.email,
twofactor_code,
data.client_id.as_ref().unwrap(),
data.device_identifier.as_ref().unwrap(),
conn,
)
.await?
}
}
}
Some(TwoFactorType::Email) => {
email::validate_email_code_str(&user.uuid, twofactor_code, &selected_data?, &ip.ip, conn).await?;
email::validate_email_code_str(&user.uuid, twofactor_code, &selected_data?, &ip.ip, conn).await?
}
Some(TwoFactorType::Remember) => {
match device.twofactor_remember {
@@ -840,7 +811,7 @@ async fn twofactor_auth(
// If it is invalid we need to trigger the 2FA Login prompt
Some(ref token)
if !CONFIG.disable_2fa_remember()
&& (crypto::ct_eq(token, twofactor_code)
&& (ct_eq(token, twofactor_code)
&& auth::decode_2fa_remember(twofactor_code)
.is_ok_and(|t| t.sub == device.uuid && t.user_uuid == user.uuid)) => {}
_ => {
@@ -851,7 +822,7 @@ async fn twofactor_auth(
device.save(true, conn).await?;
}
err_json!(
json_err_twofactor(&twofactor_ids, &user.uuid, data, client_version, conn).await?,
_json_err_twofactor(&twofactor_ids, &user.uuid, data, client_version, conn).await?,
"2FA Remember token not provided or expired"
)
}
@@ -892,15 +863,15 @@ async fn twofactor_auth(
Ok(two_factor)
}
fn selected_data(tf: Option<TwoFactor>) -> ApiResult<String> {
fn _selected_data(tf: Option<TwoFactor>) -> ApiResult<String> {
tf.map(|t| t.data).map_res("Two factor doesn't exist")
}
async fn json_err_twofactor(
async fn _json_err_twofactor(
providers: &[i32],
user_id: &UserId,
data: &ConnectData,
client_version: Option<&ClientVersion>,
client_version: &Option<ClientVersion>,
conn: &DbConn,
) -> ApiResult<Value> {
let mut result = json!({
@@ -917,38 +888,42 @@ async fn json_err_twofactor(
result["TwoFactorProviders2"][provider.to_string()] = Value::Null;
match TwoFactorType::from_i32(*provider) {
Some(TwoFactorType::Authenticator) => { /* Nothing to do for TOTP */ }
Some(TwoFactorType::Webauthn) if CONFIG.is_webauthn_2fa_supported() => {
let request = webauthn::generate_webauthn_login(user_id, conn).await?;
result["TwoFactorProviders2"][provider.to_string()] = request.0;
}
Some(TwoFactorType::Duo) => {
let email = if let Some(u) = User::find_by_uuid(user_id, conn).await {
u.email
} else {
err!("User does not exist")
let email = match User::find_by_uuid(user_id, conn).await {
Some(u) => u.email,
None => err!("User does not exist"),
};
if CONFIG.duo_use_iframe() {
// Legacy iframe prompt flow
let (signature, host) = duo::generate_duo_signature(&email, conn).await?;
result["TwoFactorProviders2"][provider.to_string()] = json!({
"Host": host,
"Signature": signature,
});
} else {
// OIDC based flow
let auth_url = duo_oidc::get_duo_auth_url(
&email,
data.client_id.as_ref().unwrap(),
data.device_identifier.as_ref().unwrap(),
conn,
)
.await?;
match CONFIG.duo_use_iframe() {
true => {
// Legacy iframe prompt flow
let (signature, host) = duo::generate_duo_signature(&email, conn).await?;
result["TwoFactorProviders2"][provider.to_string()] = json!({
"Host": host,
"Signature": signature,
})
}
false => {
// OIDC based flow
let auth_url = duo_oidc::get_duo_auth_url(
&email,
data.client_id.as_ref().unwrap(),
data.device_identifier.as_ref().unwrap(),
conn,
)
.await?;
result["TwoFactorProviders2"][provider.to_string()] = json!({
"AuthUrl": auth_url,
});
result["TwoFactorProviders2"][provider.to_string()] = json!({
"AuthUrl": auth_url,
})
}
}
}
@@ -961,7 +936,7 @@ async fn json_err_twofactor(
result["TwoFactorProviders2"][provider.to_string()] = json!({
"Nfc": yubikey_metadata.nfc,
});
})
}
Some(tf_type @ TwoFactorType::Email) => {
@@ -979,30 +954,16 @@ async fn json_err_twofactor(
// Send email immediately if email is the only 2FA option.
if providers.len() == 1 && !disabled_send {
email::send_token(user_id, conn).await?;
email::send_token(user_id, conn).await?
}
let email_data = email::EmailTokenData::from_json(&twofactor.data)?;
result["TwoFactorProviders2"][provider.to_string()] = json!({
"Email": email::obscure_email(&email_data.email),
});
})
}
None
| Some(
TwoFactorType::Authenticator
| TwoFactorType::EmailVerificationChallenge
| TwoFactorType::OrganizationDuo
| TwoFactorType::ProtectedActions
| TwoFactorType::RecoveryCode
| TwoFactorType::Remember
| TwoFactorType::U2f
| TwoFactorType::U2fLoginChallenge
| TwoFactorType::U2fRegisterChallenge
| TwoFactorType::Webauthn
| TwoFactorType::WebauthnLoginChallenge
| TwoFactorType::WebauthnRegisterChallenge,
) => { /* Nothing special to do for these providers */ }
_ => {}
}
}
@@ -1010,18 +971,13 @@ async fn json_err_twofactor(
}
#[post("/accounts/prelogin", data = "<data>")]
async fn post_prelogin(data: Json<PreloginData>, conn: DbConn) -> Json<Value> {
prelogin(data, conn).await
}
#[post("/accounts/prelogin/password", data = "<data>")]
async fn prelogin_password(data: Json<PreloginData>, conn: DbConn) -> Json<Value> {
prelogin(data, conn).await
async fn prelogin(data: Json<PreloginData>, conn: DbConn) -> Json<Value> {
_prelogin(data, conn).await
}
#[post("/accounts/register", data = "<data>")]
async fn identity_register(data: Json<RegisterData>, conn: DbConn) -> JsonResult {
register(data, false, conn).await
_register(data, false, conn).await
}
#[derive(Debug, Deserialize)]
@@ -1060,13 +1016,13 @@ async fn register_verification_email(
if should_send_mail {
let user = User::find_by_mail(&data.email, &conn).await;
if user.as_ref().is_some_and(|u| u.private_key.is_some()) {
if user.filter(|u| u.private_key.is_some()).is_some() {
// There is still a timing side channel here in that the code
// paths that send mail take noticeably longer than ones that don't.
// Add a randomized sleep to mitigate this somewhat.
use rand::{RngExt, rngs::SmallRng};
use rand::{rngs::SmallRng, RngExt};
let mut rng: SmallRng = rand::make_rng();
let sleep_ms: u64 = rng.random_range(900..=1100);
let sleep_ms = rng.random_range(900..=1100) as u64;
tokio::time::sleep(tokio::time::Duration::from_millis(sleep_ms)).await;
} else {
mail::send_register_verify_email(&data.email, &token).await?;
@@ -1082,7 +1038,7 @@ async fn register_verification_email(
#[post("/accounts/register/finish", data = "<data>")]
async fn register_finish(data: Json<RegisterData>, conn: DbConn) -> JsonResult {
register(data, true, conn).await
_register(data, true, conn).await
}
// https://github.com/bitwarden/jslib/blob/master/common/src/models/request/tokenRequest.ts
@@ -1141,11 +1097,11 @@ struct ConnectData {
// Needed for authorization code
#[field(name = uncased("code"))]
code: Option<OIDCCode>,
code: Option<OIDCState>,
#[field(name = uncased("code_verifier"))]
code_verifier: Option<OIDCCodeVerifier>,
}
fn check_is_some<T>(value: Option<&T>, msg: &str) -> EmptyResult {
fn _check_is_some<T>(value: &Option<T>, msg: &str) -> EmptyResult {
if value.is_none() {
err!(msg)
}
@@ -1164,32 +1120,33 @@ fn prevalidate() -> JsonResult {
}
}
const SSO_BINDING_COOKIE: &str = "VW_SSO_BINDING";
#[get("/connect/oidc-signin?<code>&<state>", rank = 1)]
async fn oidcsignin(code: OIDCCode, state: String, cookies: &CookieJar<'_>, mut conn: DbConn) -> ApiResult<Redirect> {
oidcsignin_redirect(state, code, None, cookies, &mut conn).await
async fn oidcsignin(code: OIDCCode, state: String, mut conn: DbConn) -> ApiResult<Redirect> {
_oidcsignin_redirect(
state,
OIDCCodeWrapper::Ok {
code,
},
&mut conn,
)
.await
}
// Bitwarden client appear to only care for code and state
// We save the error in the database and set the encoded state as the code to be able to retrieve them later on
// cf: https://github.com/bitwarden/clients/blob/afd36d290ce18fb0048e0575e7d5a8f78b5dbffc/libs/auth/src/angular/sso/sso.component.ts#L156
// Bitwarden client appear to only care for code and state so we pipe it through
// cf: https://github.com/bitwarden/clients/blob/80b74b3300e15b4ae414dc06044cc9b02b6c10a6/libs/auth/src/angular/sso/sso.component.ts#L141
#[get("/connect/oidc-signin?<state>&<error>&<error_description>", rank = 2)]
async fn oidcsignin_error(
state: String,
error: String,
error_description: Option<String>,
cookies: &CookieJar<'_>,
mut conn: DbConn,
) -> ApiResult<Redirect> {
oidcsignin_redirect(
state.clone(),
state.into(),
Some(OIDCCodeResponseError {
_oidcsignin_redirect(
state,
OIDCCodeWrapper::Error {
error,
error_description,
}),
cookies,
},
&mut conn,
)
.await
@@ -1198,32 +1155,18 @@ async fn oidcsignin_error(
// The state was encoded using Base64 to ensure no issue with providers.
// iss and scope parameters are needed for redirection to work on IOS.
// We pass the state as the code to get it back later on.
async fn oidcsignin_redirect(
async fn _oidcsignin_redirect(
base64_state: String,
code: OIDCCode,
error: Option<OIDCCodeResponseError>,
cookies: &CookieJar<'_>,
code_response: OIDCCodeWrapper,
conn: &mut DbConn,
) -> ApiResult<Redirect> {
let state = sso::decode_state(&base64_state)?;
let Some(mut sso_auth) = SsoAuth::find(&state, conn).await else {
err!(format!("Cannot retrieve sso_auth for {state}"))
let mut sso_auth = match SsoAuth::find(&state, conn).await {
None => err!(format!("Cannot retrieve sso_auth for {state}")),
Some(sso_auth) => sso_auth,
};
// Browser-binding check
// The cookie was set on /connect/authorize and must come from the same browser that initiated the flow.
let cookie_value = cookies.get(SSO_BINDING_COOKIE).map(|c| c.value().to_owned());
let provided_hash = cookie_value.as_deref().map(|v| crypto::sha256_hex(v.as_bytes()));
match (sso_auth.binding_hash.as_deref(), provided_hash.as_deref()) {
(Some(expected), Some(actual)) if crypto::ct_eq(expected, actual) => {}
_ => err!(format!("SSO session binding mismatch for {state}")),
}
cookies
.remove(Cookie::build(SSO_BINDING_COOKIE).path(format!("{}/identity/connect/", CONFIG.domain_path())).build());
sso_auth.code_response = Some(code.clone());
sso_auth.code_response_error = error;
sso_auth.code_response = Some(code_response);
sso_auth.updated_at = Utc::now().naive_utc();
sso_auth.save(conn).await?;
@@ -1233,7 +1176,7 @@ async fn oidcsignin_redirect(
};
url.query_pairs_mut()
.append_pair("code", &code)
.append_pair("code", &state)
.append_pair("state", &state)
.append_pair("scope", &AuthMethod::Sso.scope())
.append_pair("iss", &CONFIG.domain());
@@ -1269,7 +1212,7 @@ struct AuthorizeData {
// The `redirect_uri` will change depending of the client (web, android, ios ..)
#[get("/connect/authorize?<data..>")]
async fn authorize(data: AuthorizeData, cookies: &CookieJar<'_>, secure: Secure, conn: DbConn) -> ApiResult<Redirect> {
async fn authorize(data: AuthorizeData, conn: DbConn) -> ApiResult<Redirect> {
let AuthorizeData {
client_id,
redirect_uri,
@@ -1283,23 +1226,7 @@ async fn authorize(data: AuthorizeData, cookies: &CookieJar<'_>, secure: Secure,
err!("Unsupported code challenge method");
}
// Generate browser-binding token. Stored hashed in DB; raw value handed to the browser as a cookie.
// Validated on /connect/oidc-signin
let binding_token = data_encoding::BASE64URL_NOPAD.encode(&crypto::get_random_bytes::<32>());
let binding_hash = crypto::sha256_hex(binding_token.as_bytes());
let auth_url =
sso::authorize_url(state, code_challenge, &client_id, &redirect_uri, Some(binding_hash), conn).await?;
cookies.add(
Cookie::build((SSO_BINDING_COOKIE, binding_token))
.path(format!("{}/identity/connect/", CONFIG.domain_path()))
.max_age(time::Duration::seconds(sso::SSO_AUTH_EXPIRATION.num_seconds()))
.same_site(SameSite::Lax) // Lax is needed because the IdP runs on a different FQDN
.http_only(true)
.secure(secure.https)
.build(),
);
let auth_url = sso::authorize_url(state, code_challenge, &client_id, &redirect_uri, conn).await?;
Ok(Redirect::temporary(String::from(auth_url)))
}
+4 -7
View File
@@ -32,13 +32,11 @@ pub use crate::api::{
web::routes as web_routes,
web::static_files,
};
use crate::{
CONFIG,
db::{
DbConn,
models::{OrgPolicy, OrgPolicyType, User},
},
use crate::db::{
models::{OrgPolicy, OrgPolicyType, User},
DbConn,
};
use crate::CONFIG;
// Type aliases for API methods results
pub type ApiResult<T> = Result<T, crate::error::Error>;
@@ -76,7 +74,6 @@ impl PasswordOrOtpData {
}
}
#[expect(clippy::struct_excessive_bools, reason = "Bitwarden clients expect the data in this specific format")]
#[derive(Debug, Default, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct MasterPasswordPolicy {
+19 -20
View File
@@ -6,22 +6,17 @@ use std::{
use chrono::{NaiveDateTime, Utc};
use rmpv::Value;
use rocket::{Route, futures::StreamExt};
use rocket::{futures::StreamExt, Route};
use rocket_ws::{Message, WebSocket};
use tokio::sync::mpsc::Sender;
use crate::{
CONFIG, Error,
auth::{ClientIp, WsAccessTokenHeader},
db::{
DbConn,
models::{AuthRequestId, Cipher, CollectionId, Device, DeviceId, Folder, PushId, Send as DbSend, User, UserId},
DbConn,
},
};
use super::{
push::push_auth_request, push::push_auth_response, push_cipher_update, push_folder_update, push_logout,
push_send_update, push_user_update,
Error, CONFIG,
};
pub static WS_USERS: LazyLock<Arc<WebSocketUsers>> = LazyLock::new(|| {
@@ -36,6 +31,11 @@ pub static WS_ANONYMOUS_SUBSCRIPTIONS: LazyLock<Arc<AnonymousWebSocketSubscripti
})
});
use super::{
push::push_auth_request, push::push_auth_response, push_cipher_update, push_folder_update, push_logout,
push_send_update, push_user_update,
};
static NOTIFICATIONS_DISABLED: LazyLock<bool> = LazyLock::new(|| !CONFIG.enable_websocket() && !CONFIG.push_enabled());
pub fn routes() -> Vec<Route> {
@@ -102,7 +102,7 @@ impl Drop for WSAnonymousEntryMapGuard {
}
}
#[expect(tail_expr_drop_order)]
#[allow(tail_expr_drop_order)]
#[get("/hub?<data..>")]
fn websockets_hub<'r>(
ws: WebSocket,
@@ -186,7 +186,7 @@ fn websockets_hub<'r>(
})
}
#[expect(tail_expr_drop_order)]
#[allow(tail_expr_drop_order)]
#[get("/anonymous-hub?<token..>")]
fn anonymous_websockets_hub<'r>(ws: WebSocket, token: String, ip: ClientIp) -> Result<rocket_ws::Stream!['r], Error> {
info!("Accepting Anonymous Rocket WS connection from {}", ip.ip);
@@ -268,15 +268,14 @@ fn serialize(val: &Value) -> Vec<u8> {
let mut len_buf: Vec<u8> = Vec::new();
loop {
#[expect(clippy::cast_possible_truncation, reason = "masked to 7 bits, fits u8")]
let mut size_part = (size & 0x7f) as u8;
let mut size_part = size & 0x7f;
size >>= 7;
if size > 0 {
size_part |= 0x80;
}
len_buf.push(size_part);
len_buf.push(size_part as u8);
if size == 0 {
break;
@@ -330,7 +329,7 @@ pub struct WebSocketUsers {
impl WebSocketUsers {
async fn send_update(&self, user_id: &UserId, data: &[u8]) {
if let Some(user) = self.map.get(user_id.as_ref()).map(|v| v.clone()) {
for (_, sender) in &user {
for (_, sender) in user.iter() {
if let Err(e) = sender.send(Message::binary(data)).await {
error!("Error sending WS update {e}");
}
@@ -339,7 +338,7 @@ impl WebSocketUsers {
}
// NOTE: The last modified date needs to be updated before calling these methods
pub async fn send_user_update(&self, ut: UpdateType, user: &User, push_uuid: Option<&PushId>, conn: &DbConn) {
pub async fn send_user_update(&self, ut: UpdateType, user: &User, push_uuid: &Option<PushId>, conn: &DbConn) {
// Skip any processing if both WebSockets and Push are not active
if *NOTIFICATIONS_DISABLED {
return;
@@ -539,10 +538,10 @@ pub struct AnonymousWebSocketSubscriptions {
impl AnonymousWebSocketSubscriptions {
async fn send_update(&self, token: &str, data: &[u8]) {
if let Some(sender) = self.map.get(token).map(|v| v.clone())
&& let Err(e) = sender.send(Message::binary(data)).await
{
error!("Error sending WS update {e}");
if let Some(sender) = self.map.get(token).map(|v| v.clone()) {
if let Err(e) = sender.send(Message::binary(data)).await {
error!("Error sending WS update {e}");
}
}
}
@@ -583,7 +582,7 @@ fn create_update(payload: Vec<(Value, Value)>, ut: UpdateType, acting_device_id:
V::Nil,
"ReceiveMessage".into(),
V::Array(vec![V::Map(vec![
("ContextId".into(), acting_device_id.map_or(V::Nil, |v| v.to_string().into())),
("ContextId".into(), acting_device_id.map(|v| v.to_string().into()).unwrap_or_else(|| V::Nil)),
("Type".into(), (ut as i32).into()),
("Payload".into(), payload.into()),
])]),
+26 -26
View File
@@ -4,21 +4,21 @@ use std::{
};
use reqwest::{
Method,
header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE},
Method,
};
use serde_json::Value;
use tokio::sync::RwLock;
use crate::{
CONFIG,
api::{ApiResult, EmptyResult, UpdateType},
db::{
DbConn,
models::{AuthRequestId, Cipher, Device, Folder, PushId, Send, User, UserId},
DbConn,
},
http_client::make_http_request,
util::{format_date, get_uuid},
CONFIG,
};
#[derive(Deserialize)]
@@ -74,9 +74,9 @@ async fn get_auth_api_token() -> ApiResult<String> {
};
let mut api_token = API_TOKEN.write().await;
// Token valid for half the specified time
let half_expires_in = u64::from((json_pushtoken.expires_in / 2).max(0).cast_unsigned());
api_token.valid_until = Instant::now().checked_add(Duration::from_secs(half_expires_in)).unwrap();
api_token.valid_until = Instant::now()
.checked_add(Duration::new((json_pushtoken.expires_in / 2) as u64, 0)) // Token valid for half the specified time
.unwrap();
api_token.access_token = json_pushtoken.access_token;
@@ -135,7 +135,7 @@ pub async fn register_push_device(device: &mut Device, conn: &DbConn) -> EmptyRe
Ok(())
}
pub async fn unregister_push_device(push_id: Option<&PushId>) -> EmptyResult {
pub async fn unregister_push_device(push_id: &Option<PushId>) -> EmptyResult {
if !CONFIG.push_enabled() || push_id.is_none() {
return Ok(());
}
@@ -161,7 +161,7 @@ pub async fn push_cipher_update(ut: UpdateType, cipher: &Cipher, device: &Device
// We shouldn't send a push notification on cipher update if the cipher belongs to an organization, this isn't implemented in the upstream server too.
if cipher.organization_uuid.is_some() {
return;
}
};
let Some(user_id) = &cipher.user_uuid else {
debug!("Cipher has no uuid");
return;
@@ -206,7 +206,7 @@ pub async fn push_logout(user: &User, acting_device: Option<&Device>, conn: &DbC
}
}
pub async fn push_user_update(ut: UpdateType, user: &User, push_uuid: Option<&PushId>, conn: &DbConn) {
pub async fn push_user_update(ut: UpdateType, user: &User, push_uuid: &Option<PushId>, conn: &DbConn) {
if Device::check_user_has_push_device(&user.uuid, conn).await {
tokio::task::spawn(send_to_push_relay(json!({
"userId": user.uuid,
@@ -244,23 +244,23 @@ pub async fn push_folder_update(ut: UpdateType, folder: &Folder, device: &Device
}
pub async fn push_send_update(ut: UpdateType, send: &Send, device: &Device, conn: &DbConn) {
if let Some(s) = &send.user_uuid
&& Device::check_user_has_push_device(s, conn).await
{
tokio::task::spawn(send_to_push_relay(json!({
"userId": send.user_uuid,
"organizationId": null,
"deviceId": device.push_uuid, // Should be the records unique uuid of the acting device (unique uuid per user/device)
"identifier": device.uuid, // Should be the acting device id (aka uuid per device/app)
"type": ut as i32,
"payload": {
"id": send.uuid,
if let Some(s) = &send.user_uuid {
if Device::check_user_has_push_device(s, conn).await {
tokio::task::spawn(send_to_push_relay(json!({
"userId": send.user_uuid,
"revisionDate": format_date(&send.revision_date)
},
"clientType": null,
"installationId": null
})));
"organizationId": null,
"deviceId": device.push_uuid, // Should be the records unique uuid of the acting device (unique uuid per user/device)
"identifier": device.uuid, // Should be the acting device id (aka uuid per device/app)
"type": ut as i32,
"payload": {
"id": send.uuid,
"userId": send.user_uuid,
"revisionDate": format_date(&send.revision_date)
},
"clientType": null,
"installationId": null
})));
}
}
}
@@ -296,7 +296,7 @@ async fn send_to_push_relay(notification_data: Value) {
.await
{
error!("An error occurred while sending a send update to the push relay: {e}");
}
};
}
pub async fn push_auth_request(user_id: &UserId, auth_request_id: &str, device: &Device, conn: &DbConn) {
+10 -38
View File
@@ -1,24 +1,21 @@
use std::path::{Path, PathBuf};
use rocket::{
Catcher, Route,
fs::NamedFile,
http::ContentType,
response::{Redirect, content::RawCss as Css, content::RawHtml as Html},
response::{content::RawCss as Css, content::RawHtml as Html, Redirect},
serde::json::Json,
Catcher, Route,
};
use serde_json::Value;
use crate::{
CONFIG,
api::{ApiResult, EmptyResult, core::now},
api::{core::now, ApiResult, EmptyResult},
auth::decode_file_download,
db::{
DbConn,
models::{AttachmentId, CipherId},
},
db::models::{AttachmentId, CipherId},
error::Error,
util::Cached,
CONFIG,
};
pub fn routes() -> Vec<Route> {
@@ -26,20 +23,12 @@ pub fn routes() -> Vec<Route> {
// crate::utils::LOGGED_ROUTES to make sure they appear in the log
let mut routes = routes![attachments, alive, alive_head, static_files];
if CONFIG.web_vault_enabled() {
routes.append(&mut routes![
web_index,
web_index_direct,
web_index_head,
app_id,
apple_app_site_association,
web_files,
vaultwarden_css
]);
routes.append(&mut routes![web_index, web_index_direct, web_index_head, app_id, web_files, vaultwarden_css]);
}
#[cfg(debug_assertions)]
if CONFIG.reload_templates() {
routes.append(&mut routes![static_files_dev]);
routes.append(&mut routes![_static_files_dev]);
}
routes
@@ -171,24 +160,6 @@ fn app_id() -> Cached<(ContentType, Json<Value>)> {
)
}
#[get("/.well-known/apple-app-site-association")]
fn apple_app_site_association() -> Cached<(ContentType, Json<Value>)> {
Cached::long(
(
ContentType::JSON,
Json(json!({
"webcredentials": {
"apps": [
"LTZ2PFU5D6.com.8bit.bitwarden",
"LTZ2PFU5D6.com.8bit.bitwarden.beta"
]
}
})),
),
true,
)
}
#[get("/<p..>", rank = 10)] // Only match this if the other routes don't match
async fn web_files(p: PathBuf) -> Cached<Option<NamedFile>> {
Cached::long(NamedFile::open(Path::new(&CONFIG.web_vault_folder()).join(p)).await.ok(), true)
@@ -207,6 +178,7 @@ async fn attachments(cipher_id: CipherId, file_id: AttachmentId, token: String)
}
// We use DbConn here to let the alive healthcheck also verify the database connection.
use crate::db::DbConn;
#[get("/alive")]
fn alive(_conn: DbConn) -> Json<String> {
now()
@@ -225,7 +197,7 @@ fn alive_head(_conn: DbConn) -> EmptyResult {
// NOTE: Do not forget to add any new files added to the `static_files` function below!
#[cfg(debug_assertions)]
#[get("/vw_static/<filename>", rank = 1)]
pub async fn static_files_dev(filename: PathBuf) -> Option<NamedFile> {
pub async fn _static_files_dev(filename: PathBuf) -> Option<NamedFile> {
warn!("LOADING STATIC FILES FROM DISK");
let file = filename.to_str().unwrap_or_default();
let ext = filename.extension().unwrap_or_default();
@@ -238,7 +210,7 @@ pub async fn static_files_dev(filename: PathBuf) -> Option<NamedFile> {
if let Ok(path) = path {
return NamedFile::open(path).await.ok();
}
};
None
}
+81 -73
View File
@@ -5,30 +5,21 @@ use std::{
};
use chrono::{DateTime, TimeDelta, Utc};
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, errors::ErrorKind};
use jsonwebtoken::{errors::ErrorKind, Algorithm, DecodingKey, EncodingKey, Header};
use num_traits::FromPrimitive;
use openssl::rsa::Rsa;
use serde::{de::DeserializeOwned, ser::Serialize};
use rocket::{
outcome::try_outcome,
request::{FromRequest, Outcome, Request},
};
use serde::de::DeserializeOwned;
use serde::ser::Serialize;
use crate::{
CONFIG,
api::ApiResult,
config::PathType,
db::{
DbConn,
models::{
AttachmentId, CipherId, Collection, CollectionId, Device, DeviceId, DeviceType, EmergencyAccessId,
Membership, MembershipId, MembershipStatus, MembershipType, OrgApiKeyId, OrganizationId, SendFileId,
SendId, User, UserId, UserStampException,
},
db::models::{
AttachmentId, CipherId, CollectionId, DeviceId, DeviceType, EmergencyAccessId, MembershipId, OrgApiKeyId,
OrganizationId, SendFileId, SendId, UserId,
},
error::Error,
sso,
sso, CONFIG,
};
const JWT_ALGORITHM: Algorithm = Algorithm::RS256;
@@ -61,12 +52,16 @@ static PRIVATE_RSA_KEY: OnceLock<EncodingKey> = OnceLock::new();
static PUBLIC_RSA_KEY: OnceLock<DecodingKey> = OnceLock::new();
pub async fn initialize_keys() -> Result<(), Error> {
use std::io::Error as IoError;
use std::io::Error;
let rsa_key_filename = crate::storage::file_name(&CONFIG.private_rsa_key())
.ok_or_else(|| IoError::other("Private RSA key path missing filename"))?;
let rsa_key_filename = std::path::PathBuf::from(CONFIG.private_rsa_key())
.file_name()
.ok_or_else(|| Error::other("Private RSA key path missing filename"))?
.to_str()
.ok_or_else(|| Error::other("Private RSA key path filename is not valid UTF-8"))?
.to_string();
let operator = CONFIG.opendal_operator_for_path_type(&PathType::RsaKey).map_err(IoError::other)?;
let operator = CONFIG.opendal_operator_for_path_type(&PathType::RsaKey).map_err(Error::other)?;
let priv_key_buffer = match operator.read(&rsa_key_filename).await {
Ok(buffer) => Some(buffer),
@@ -235,7 +230,7 @@ impl LoginJwtClaims {
// let orgmanager: Vec<_> = orgs.iter().filter(|o| o.atype == 3).map(|o| o.org_uuid.clone()).collect();
if exp <= (now + *BW_EXPIRATION).timestamp() {
warn!("Raise access_token lifetime to more than 5min.");
warn!("Raise access_token lifetime to more than 5min.")
}
// Create the JWT claims struct, to send to the client
@@ -262,7 +257,7 @@ impl LoginJwtClaims {
sstamp: user.security_stamp.clone(),
device: device.uuid.clone(),
devicetype: DeviceType::from_i32(device.atype).to_string(),
client_id: client_id.unwrap_or("undefined".to_owned()),
client_id: client_id.unwrap_or("undefined".to_string()),
scope,
amr: vec!["Application".into()],
}
@@ -515,7 +510,7 @@ pub fn generate_admin_claims() -> BasicJwtClaims {
nbf: time_now.timestamp(),
exp: (time_now + TimeDelta::try_minutes(CONFIG.admin_session_lifetime()).unwrap()).timestamp(),
iss: JWT_ADMIN_ISSUER.to_string(),
sub: "admin_panel".to_owned(),
sub: "admin_panel".to_string(),
}
}
@@ -532,6 +527,16 @@ pub fn generate_send_claims(send_id: &SendId, file_id: &SendFileId) -> BasicJwtC
//
// Bearer token authentication
//
use rocket::{
outcome::try_outcome,
request::{FromRequest, Outcome, Request},
};
use crate::db::{
models::{Collection, Device, Membership, MembershipStatus, MembershipType, User, UserStampException},
DbConn,
};
pub struct Host {
pub host: String,
}
@@ -547,7 +552,7 @@ impl<'r> FromRequest<'r> for Host {
let host = if CONFIG.domain_set() {
CONFIG.domain()
} else if let Some(referer) = headers.get_one("Referer") {
referer.to_owned()
referer.to_string()
} else {
// Try to guess from the headers
let protocol = if let Some(proto) = headers.get_one("X-Forwarded-Proto") {
@@ -583,15 +588,13 @@ impl<'r> FromRequest<'r> for ClientHeaders {
type Error = &'static str;
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
let Outcome::Success(ip) = ClientIp::from_request(request).await else {
err_handler!("Error getting Client IP")
let ip = match ClientIp::from_request(request).await {
Outcome::Success(ip) => ip,
_ => err_handler!("Error getting Client IP"),
};
// When unknown or unable to parse, return 'UnknownBrowser'
let device_type: i32 = request
.headers()
.get_one("device-type")
.and_then(|d| d.parse().ok())
.unwrap_or(DeviceType::UnknownBrowser as i32);
// When unknown or unable to parse, return 14, which is 'Unknown Browser'
let device_type: i32 =
request.headers().get_one("device-type").map(|d| d.parse().unwrap_or(14)).unwrap_or_else(|| 14);
Outcome::Success(ClientHeaders {
device_type,
@@ -615,19 +618,18 @@ impl<'r> FromRequest<'r> for Headers {
let headers = request.headers();
let host = try_outcome!(Host::from_request(request).await).host;
let Outcome::Success(ip) = ClientIp::from_request(request).await else {
err_handler!("Error getting Client IP")
let ip = match ClientIp::from_request(request).await {
Outcome::Success(ip) => ip,
_ => err_handler!("Error getting Client IP"),
};
// Get access_token
let access_token: &str = if let Some(a) = headers.get_one("Authorization") {
if let Some(split) = a.rsplit("Bearer ").next() {
split
} else {
err_handler!("No access token provided")
}
} else {
err_handler!("No access token provided")
let access_token: &str = match headers.get_one("Authorization") {
Some(a) => match a.rsplit("Bearer ").next() {
Some(split) => split,
None => err_handler!("No access token provided"),
},
None => err_handler!("No access token provided"),
};
// Check JWT token is valid and get device and user from it
@@ -638,8 +640,9 @@ impl<'r> FromRequest<'r> for Headers {
let device_id = claims.device;
let user_id = claims.sub;
let Outcome::Success(conn) = DbConn::from_request(request).await else {
err_handler!("Error getting DB")
let conn = match DbConn::from_request(request).await {
Outcome::Success(conn) => conn,
_ => err_handler!("Error getting DB"),
};
let Some(device) = Device::find_by_uuid_and_user(&device_id, &user_id, &conn).await else {
@@ -670,7 +673,7 @@ impl<'r> FromRequest<'r> for Headers {
error!("Error updating user: {e:#?}");
}
err_handler!("Stamp exception is expired")
} else if !stamp_exception.routes.contains(&current_route.to_owned()) {
} else if !stamp_exception.routes.contains(&current_route.to_string()) {
err_handler!("Invalid security stamp: Current route and exception route do not match")
} else if stamp_exception.security_stamp != claims.sstamp {
err_handler!("Invalid security stamp for matched stamp exception")
@@ -758,8 +761,9 @@ impl<'r> FromRequest<'r> for OrgHeaders {
match url_org_id {
Some(org_id) if uuid::Uuid::parse_str(&org_id).is_ok() => {
let Outcome::Success(conn) = DbConn::from_request(request).await else {
err_handler!("Error getting DB")
let conn = match DbConn::from_request(request).await {
Outcome::Success(conn) => conn,
_ => err_handler!("Error getting DB"),
};
let user = headers.user;
@@ -831,16 +835,16 @@ impl<'r> FromRequest<'r> for AdminHeaders {
// but there could be cases where it is a query value.
// First check the path, if this is not a valid uuid, try the query values.
fn get_col_id(request: &Request<'_>) -> Option<CollectionId> {
if let Some(Ok(col_id)) = request.param::<String>(3)
&& uuid::Uuid::parse_str(&col_id).is_ok()
{
return Some(col_id.into());
if let Some(Ok(col_id)) = request.param::<String>(3) {
if uuid::Uuid::parse_str(&col_id).is_ok() {
return Some(col_id.into());
}
}
if let Some(Ok(col_id)) = request.query_value::<String>("collectionId")
&& uuid::Uuid::parse_str(&col_id).is_ok()
{
return Some(col_id.into());
if let Some(Ok(col_id)) = request.query_value::<String>("collectionId") {
if uuid::Uuid::parse_str(&col_id).is_ok() {
return Some(col_id.into());
}
}
None
@@ -864,16 +868,18 @@ impl<'r> FromRequest<'r> for ManagerHeaders {
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
let headers = try_outcome!(OrgHeaders::from_request(request).await);
if headers.is_confirmed_and_manager() {
if let Some(col_id) = get_col_id(request) {
let Outcome::Success(conn) = DbConn::from_request(request).await else {
err_handler!("Error getting DB")
};
match get_col_id(request) {
Some(col_id) => {
let conn = match DbConn::from_request(request).await {
Outcome::Success(conn) => conn,
_ => err_handler!("Error getting DB"),
};
if !Collection::is_coll_manageable_by_user(&col_id, &headers.membership.user_uuid, &conn).await {
err_handler!("The current user isn't a manager for this collection")
if !Collection::is_coll_manageable_by_user(&col_id, &headers.membership.user_uuid, &conn).await {
err_handler!("The current user isn't a manager for this collection")
}
}
} else {
err_handler!("Error getting the collection id")
_ => err_handler!("Error getting the collection id"),
}
Outcome::Success(Self {
@@ -1034,7 +1040,7 @@ impl From<OrgMemberHeaders> for Headers {
//
// Client IP address detection
//
#[derive(Copy, Clone)]
pub struct ClientIp {
pub ip: IpAddr,
}
@@ -1066,7 +1072,6 @@ impl<'r> FromRequest<'r> for ClientIp {
}
}
#[derive(Copy, Clone)]
pub struct Secure {
pub https: bool,
}
@@ -1152,14 +1157,15 @@ pub enum AuthMethod {
impl AuthMethod {
pub fn scope(&self) -> String {
match self {
AuthMethod::OrgApiKey => "api.organization".to_owned(),
AuthMethod::UserApiKey => "api".to_owned(),
AuthMethod::Password | AuthMethod::Sso => "api offline_access".to_owned(),
AuthMethod::OrgApiKey => "api.organization".to_string(),
AuthMethod::Password => "api offline_access".to_string(),
AuthMethod::Sso => "api offline_access".to_string(),
AuthMethod::UserApiKey => "api".to_string(),
}
}
pub fn scope_vec(&self) -> Vec<String> {
self.scope().split_whitespace().map(str::to_owned).collect()
self.scope().split_whitespace().map(str::to_string).collect()
}
pub fn check_scope(&self, scope: Option<&String>) -> ApiResult<String> {
@@ -1272,15 +1278,17 @@ pub async fn refresh_tokens(
};
// Get device by refresh token
let Some(mut device) = Device::find_by_refresh_token(&refresh_claims.device_token, conn).await else {
err!("Invalid refresh token")
let mut device = match Device::find_by_refresh_token(&refresh_claims.device_token, conn).await {
None => err!("Invalid refresh token"),
Some(device) => device,
};
// Save to update `updated_at`.
device.save(true, conn).await?;
let Some(user) = User::find_by_uuid(&device.user_uuid, conn).await else {
err!("Impossible to find user")
let user = match User::find_by_uuid(&device.user_uuid, conn).await {
None => err!("Impossible to find user"),
Some(user) => user,
};
let auth_tokens = match refresh_claims.sub {
+234 -157
View File
@@ -3,8 +3,8 @@ use std::{
fmt,
process::exit,
sync::{
LazyLock, RwLock,
atomic::{AtomicBool, Ordering},
LazyLock, RwLock,
},
};
@@ -14,23 +14,26 @@ use serde::de::{self, Deserialize, Deserializer, MapAccess, Visitor};
use crate::{
error::Error,
storage,
util::{
FeatureFlagFilter, get_active_web_release, get_env, get_env_bool, is_valid_email,
parse_experimental_client_feature_flags,
get_active_web_release, get_env, get_env_bool, is_valid_email, parse_experimental_client_feature_flags,
FeatureFlagFilter,
},
};
static CONFIG_FILE: LazyLock<String> = LazyLock::new(|| {
let data_folder = get_env("DATA_FOLDER").unwrap_or_else(|| String::from("data"));
get_env("CONFIG_FILE").unwrap_or_else(|| storage::join_path(&data_folder, "config.json"))
get_env("CONFIG_FILE").unwrap_or_else(|| format!("{data_folder}/config.json"))
});
static CONFIG_FILE_PARENT_DIR: LazyLock<String> =
LazyLock::new(|| storage::parent(&CONFIG_FILE).unwrap_or_else(|| "data".to_owned()));
static CONFIG_FILE_PARENT_DIR: LazyLock<String> = LazyLock::new(|| {
let path = std::path::PathBuf::from(&*CONFIG_FILE);
path.parent().unwrap_or(std::path::Path::new("data")).to_str().unwrap_or("data").to_string()
});
static CONFIG_FILENAME: LazyLock<String> =
LazyLock::new(|| storage::file_name(&CONFIG_FILE).unwrap_or_else(|| "config.json".to_owned()));
static CONFIG_FILENAME: LazyLock<String> = LazyLock::new(|| {
let path = std::path::PathBuf::from(&*CONFIG_FILE);
path.file_name().unwrap_or(std::ffi::OsStr::new("config.json")).to_str().unwrap_or("config.json").to_string()
});
pub static SKIP_CONFIG_VALIDATION: AtomicBool = AtomicBool::new(false);
@@ -260,7 +263,7 @@ macro_rules! make_config {
}
async fn from_file() -> Result<Self, Error> {
let operator = storage::operator_for_path(&CONFIG_FILE_PARENT_DIR)?;
let operator = opendal_operator_for_path(&CONFIG_FILE_PARENT_DIR)?;
let config_bytes = operator.read(&CONFIG_FILENAME).await?;
println!("[INFO] Using saved config from `{}` for configuration.\n", *CONFIG_FILE);
serde_json::from_slice(&config_bytes.to_vec()).map_err(Into::into)
@@ -360,7 +363,13 @@ macro_rules! make_config {
)+)+
pub fn prepare_json(&self) -> serde_json::Value {
fn get_form_type(rust_type: &'static str) -> &'static str {
let (def, cfg, overridden) = {
// Lock the inner as short as possible and clone what is needed to prevent deadlocks
let inner = &self.inner.read().unwrap();
(inner._env.build(), inner.config.clone(), inner._overrides.clone())
};
fn _get_form_type(rust_type: &'static str) -> &'static str {
match rust_type {
"Pass" => "password",
"String" => "text",
@@ -369,7 +378,7 @@ macro_rules! make_config {
}
}
fn get_doc(doc_str: &'static str) -> ElementDoc {
fn _get_doc(doc_str: &'static str) -> ElementDoc {
let mut split = doc_str.split("|>").map(str::trim);
ElementDoc {
name: split.next().unwrap_or_default(),
@@ -377,12 +386,6 @@ macro_rules! make_config {
}
}
let (def, cfg, overridden) = {
// Lock the inner as short as possible and clone what is needed to prevent deadlocks
let inner = &self.inner.read().unwrap();
(inner._env.build(), inner.config.clone(), inner._overrides.clone())
};
let data: Vec<GroupData> = vec![
$( // This repetition is for each group
GroupData {
@@ -397,8 +400,8 @@ macro_rules! make_config {
name: stringify!($name),
value: serde_json::to_value(&cfg.$name).unwrap_or_default(),
default: serde_json::to_value(&def.$name).unwrap_or_default(),
r#type: get_form_type(stringify!($ty)),
doc: get_doc(concat!($($doc),+)),
r#type: _get_form_type(stringify!($ty)),
doc: _get_doc(concat!($($doc),+)),
overridden: overridden.contains(&pastey::paste!(stringify!([<$name:upper>]))),
},
)+], // End of elements repetition
@@ -408,31 +411,9 @@ macro_rules! make_config {
}
pub fn get_support_json(&self) -> serde_json::Value {
/// We map over the string and remove all alphanumeric, _ and - characters.
/// This is the fastest way (within micro-seconds) instead of using a regex (which takes mili-seconds)
fn privacy_mask(value: &str) -> String {
let mut n: u16 = 0;
let mut colon_match = false;
value
.chars()
.map(|c| {
n += 1;
match c {
':' if n <= 11 => {
colon_match = true;
c
}
'/' if n <= 13 && colon_match => c,
',' => c,
_ => '*',
}
})
.collect::<String>()
}
// Define which config keys need to be masked.
// Pass types will always be masked and no need to put them in the list.
// Besides Pass, only String types will be masked via privacy_mask.
// Besides Pass, only String types will be masked via _privacy_mask.
const PRIVACY_CONFIG: &[&str] = &[
"allowed_connect_src",
"allowed_iframe_ancestors",
@@ -459,6 +440,28 @@ macro_rules! make_config {
inner.config.clone()
};
/// We map over the string and remove all alphanumeric, _ and - characters.
/// This is the fastest way (within micro-seconds) instead of using a regex (which takes mili-seconds)
fn _privacy_mask(value: &str) -> String {
let mut n: u16 = 0;
let mut colon_match = false;
value
.chars()
.map(|c| {
n += 1;
match c {
':' if n <= 11 => {
colon_match = true;
c
}
'/' if n <= 13 && colon_match => c,
',' => c,
_ => '*',
}
})
.collect::<String>()
}
serde_json::Value::Object({
let mut json = serde_json::Map::new();
$($(
@@ -468,7 +471,7 @@ macro_rules! make_config {
for mask_key in PRIVACY_CONFIG {
if let Some(value) = json.get_mut(*mask_key) {
if let Some(s) = value.as_str() {
*value = privacy_mask(s).into();
*value = _privacy_mask(s).into();
}
}
}
@@ -502,23 +505,23 @@ macro_rules! make_config {
make_config! {
folders {
/// Data folder |> Main data folder
data_folder: String, false, def, "data".to_owned();
data_folder: String, false, def, "data".to_string();
/// Database URL
database_url: String, false, auto, |c| format!("sqlite://{}", storage::join_path(&c.data_folder, "db.sqlite3"));
database_url: String, false, auto, |c| format!("{}/db.sqlite3", c.data_folder);
/// Icon cache folder
icon_cache_folder: String, false, auto, |c| storage::join_path(&c.data_folder, "icon_cache");
icon_cache_folder: String, false, auto, |c| format!("{}/icon_cache", c.data_folder);
/// Attachments folder
attachments_folder: String, false, auto, |c| storage::join_path(&c.data_folder, "attachments");
attachments_folder: String, false, auto, |c| format!("{}/attachments", c.data_folder);
/// Sends folder
sends_folder: String, false, auto, |c| storage::join_path(&c.data_folder, "sends");
sends_folder: String, false, auto, |c| format!("{}/sends", c.data_folder);
/// Temp folder |> Used for storing temporary file uploads
tmp_folder: String, false, auto, |c| storage::join_path(&c.data_folder, "tmp");
tmp_folder: String, false, auto, |c| format!("{}/tmp", c.data_folder);
/// Templates folder
templates_folder: String, false, auto, |c| storage::join_path(&c.data_folder, "templates");
templates_folder: String, false, auto, |c| format!("{}/templates", c.data_folder);
/// Session JWT key
rsa_key_filename: String, false, auto, |c| storage::join_path(&c.data_folder, "rsa_key");
rsa_key_filename: String, false, auto, |c| format!("{}/rsa_key", c.data_folder);
/// Web vault folder
web_vault_folder: String, false, def, "web-vault/".to_owned();
web_vault_folder: String, false, def, "web-vault/".to_string();
},
ws {
/// Enable websocket notifications
@@ -528,9 +531,9 @@ make_config! {
/// Enable push notifications
push_enabled: bool, false, def, false;
/// Push relay uri
push_relay_uri: String, false, def, "https://push.bitwarden.com".to_owned();
push_relay_uri: String, false, def, "https://push.bitwarden.com".to_string();
/// Push identity uri
push_identity_uri: String, false, def, "https://identity.bitwarden.com".to_owned();
push_identity_uri: String, false, def, "https://identity.bitwarden.com".to_string();
/// Installation id |> The installation id from https://bitwarden.com/host
push_installation_id: Pass, false, def, String::new();
/// Installation key |> The installation key from https://bitwarden.com/host
@@ -542,38 +545,38 @@ make_config! {
job_poll_interval_ms: u64, false, def, 30_000;
/// Send purge schedule |> Cron schedule of the job that checks for Sends past their deletion date.
/// Defaults to hourly. Set blank to disable this job.
send_purge_schedule: String, false, def, "0 5 * * * *".to_owned();
send_purge_schedule: String, false, def, "0 5 * * * *".to_string();
/// Trash purge schedule |> Cron schedule of the job that checks for trashed items to delete permanently.
/// Defaults to daily. Set blank to disable this job.
trash_purge_schedule: String, false, def, "0 5 0 * * *".to_owned();
trash_purge_schedule: String, false, def, "0 5 0 * * *".to_string();
/// Incomplete 2FA login schedule |> Cron schedule of the job that checks for incomplete 2FA logins.
/// Defaults to once every minute. Set blank to disable this job.
incomplete_2fa_schedule: String, false, def, "30 * * * * *".to_owned();
incomplete_2fa_schedule: String, false, def, "30 * * * * *".to_string();
/// Emergency notification reminder schedule |> Cron schedule of the job that sends expiration reminders to emergency access grantors.
/// Defaults to hourly. (3 minutes after the hour) Set blank to disable this job.
emergency_notification_reminder_schedule: String, false, def, "0 3 * * * *".to_owned();
emergency_notification_reminder_schedule: String, false, def, "0 3 * * * *".to_string();
/// Emergency request timeout schedule |> Cron schedule of the job that grants emergency access requests that have met the required wait time.
/// Defaults to hourly. (7 minutes after the hour) Set blank to disable this job.
emergency_request_timeout_schedule: String, false, def, "0 7 * * * *".to_owned();
emergency_request_timeout_schedule: String, false, def, "0 7 * * * *".to_string();
/// Event cleanup schedule |> Cron schedule of the job that cleans old events from the event table.
/// Defaults to daily. Set blank to disable this job.
event_cleanup_schedule: String, false, def, "0 10 0 * * *".to_owned();
event_cleanup_schedule: String, false, def, "0 10 0 * * *".to_string();
/// Auth Request cleanup schedule |> Cron schedule of the job that cleans old auth requests from the auth request.
/// Defaults to every minute. Set blank to disable this job.
auth_request_purge_schedule: String, false, def, "30 * * * * *".to_owned();
auth_request_purge_schedule: String, false, def, "30 * * * * *".to_string();
/// Duo Auth context cleanup schedule |> Cron schedule of the job that cleans expired Duo contexts from the database. Does nothing if Duo MFA is disabled or set to use the legacy iframe prompt.
/// Defaults to once every minute. Set blank to disable this job.
duo_context_purge_schedule: String, false, def, "30 * * * * *".to_owned();
duo_context_purge_schedule: String, false, def, "30 * * * * *".to_string();
/// Purge incomplete SSO auth. |> Cron schedule of the job that cleans leftover auth in db due to incomplete SSO login.
/// Defaults to daily. Set blank to disable this job.
purge_incomplete_sso_auth: String, false, def, "0 20 0 * * *".to_owned();
purge_incomplete_sso_auth: String, false, def, "0 20 0 * * *".to_string();
},
/// General settings
settings {
/// Domain URL |> This needs to be set to the URL used to access the server, including 'http[s]://'
/// and port, if it's different than the default. Some server functions don't work correctly without this value
domain: String, true, def, "http://localhost".to_owned();
domain: String, true, def, "http://localhost".to_string();
/// Domain Set |> Indicates if the domain is set by the admin. Otherwise the default will be used.
domain_set: bool, false, def, false;
/// Domain origin |> Domain URL origin (in https://example.com:8443/path, https://example.com:8443 is the origin)
@@ -653,7 +656,7 @@ make_config! {
admin_token: Pass, true, option;
/// Invitation organization name |> Name shown in the invitation emails that don't come from a specific organization
invitation_org_name: String, true, def, "Vaultwarden".to_owned();
invitation_org_name: String, true, def, "Vaultwarden".to_string();
/// Events days retain |> Number of days to retain events stored in the database. If unset, events are kept indefinitely.
events_days_retain: i64, false, option;
@@ -663,7 +666,7 @@ make_config! {
advanced {
/// Client IP header |> If not present, the remote IP is used.
/// Set to the string "none" (without quotes), to disable any headers and just use the remote IP
ip_header: String, true, def, "X-Real-IP".to_owned();
ip_header: String, true, def, "X-Real-IP".to_string();
/// Internal IP header property, used to avoid recomputing each time
_ip_header_enabled: bool, false, generated, |c| &c.ip_header.trim().to_lowercase() != "none";
/// Icon service |> The predefined icon services are: internal, bitwarden, duckduckgo, google.
@@ -672,7 +675,7 @@ make_config! {
/// `internal` refers to Vaultwarden's built-in icon fetching implementation. If an external
/// service is set, an icon request to Vaultwarden will return an HTTP redirect to the
/// corresponding icon at the external service.
icon_service: String, false, def, "internal".to_owned();
icon_service: String, false, def, "internal".to_string();
/// _icon_service_url
_icon_service_url: String, false, generated, |c| generate_icon_service_url(&c.icon_service);
/// _icon_service_csp
@@ -723,14 +726,14 @@ make_config! {
/// Enable extended logging
extended_logging: bool, false, def, true;
/// Log timestamp format
log_timestamp_format: String, true, def, "%Y-%m-%d %H:%M:%S.%3f".to_owned();
log_timestamp_format: String, true, def, "%Y-%m-%d %H:%M:%S.%3f".to_string();
/// Enable the log to output to Syslog
use_syslog: bool, false, def, false;
/// Log file path
log_file: String, false, option;
/// Log level |> Valid values are "trace", "debug", "info", "warn", "error" and "off"
/// For a specific module append it as a comma separated value "info,path::to::module=debug"
log_level: String, false, def, "info".to_owned();
log_level: String, false, def, "info".to_string();
/// Enable DB WAL |> Turning this off might lead to worse performance, but might help if using vaultwarden on some exotic filesystems,
/// that do not support WAL. Please make sure you read project wiki on the topic before changing this setting.
@@ -812,7 +815,7 @@ make_config! {
/// Authority Server |> Base url of the OIDC provider discovery endpoint (without `/.well-known/openid-configuration`)
sso_authority: String, true, def, String::new();
/// Authorization request scopes |> List the of the needed scope (`openid` is implicit)
sso_scopes: String, true, def, "email profile".to_owned();
sso_scopes: String, true, def, "email profile".to_string();
/// Authorization request extra parameters
sso_authorize_extra_params: String, true, def, String::new();
/// Use PKCE during Authorization flow
@@ -880,7 +883,7 @@ make_config! {
/// From Address
smtp_from: String, true, def, String::new();
/// From Name
smtp_from_name: String, true, def, "Vaultwarden".to_owned();
smtp_from_name: String, true, def, "Vaultwarden".to_string();
/// Username
smtp_username: String, true, option;
/// Password
@@ -926,13 +929,10 @@ fn validate_config(cfg: &ConfigItems, on_update: bool) -> Result<(), Error> {
{
use crate::db::DbConnType;
let url = &cfg.database_url;
if DbConnType::from_url(url)? == DbConnType::Sqlite {
let file_path = url.strip_prefix("sqlite://").unwrap_or(url);
if file_path.contains('/') {
let path = std::path::Path::new(file_path);
if let Some(parent) = path.parent()
&& !parent.is_dir()
{
if DbConnType::from_url(url)? == DbConnType::Sqlite && url.contains('/') {
let path = std::path::Path::new(&url);
if let Some(parent) = path.parent() {
if !parent.is_dir() {
err!(format!(
"SQLite database directory `{}` does not exist or is not a directory",
parent.display()
@@ -959,10 +959,10 @@ fn validate_config(cfg: &ConfigItems, on_update: bool) -> Result<(), Error> {
err!(format!("`DATABASE_MIN_CONNS` must be smaller than or equal to `DATABASE_MAX_CONNS`.",));
}
if let Some(log_file) = &cfg.log_file
&& std::fs::OpenOptions::new().append(true).create(true).open(log_file).is_err()
{
err!("Unable to write to log file", log_file);
if let Some(log_file) = &cfg.log_file {
if std::fs::OpenOptions::new().append(true).create(true).open(log_file).is_err() {
err!("Unable to write to log file", log_file);
}
}
let dom = cfg.domain.to_lowercase();
@@ -975,9 +975,7 @@ fn validate_config(cfg: &ConfigItems, on_update: bool) -> Result<(), Error> {
let connect_src = cfg.allowed_connect_src.to_lowercase();
for url in connect_src.split_whitespace() {
if !url.starts_with("https://") || Url::parse(url).is_err() {
err!(
"ALLOWED_CONNECT_SRC variable contains one or more invalid URLs. Only FQDN's starting with https are allowed"
);
err!("ALLOWED_CONNECT_SRC variable contains one or more invalid URLs. Only FQDN's starting with https are allowed");
}
}
@@ -993,12 +991,11 @@ fn validate_config(cfg: &ConfigItems, on_update: bool) -> Result<(), Error> {
err!("`ORG_CREATION_USERS` contains invalid email addresses");
}
if let Some(ref token) = cfg.admin_token
&& token.trim().is_empty()
&& !cfg.disable_admin_token
{
println!("[WARNING] `ADMIN_TOKEN` is enabled but has an empty value, so the admin page will be disabled.");
println!("[WARNING] To enable the admin page without a token, use `DISABLE_ADMIN_TOKEN`.");
if let Some(ref token) = cfg.admin_token {
if token.trim().is_empty() && !cfg.disable_admin_token {
println!("[WARNING] `ADMIN_TOKEN` is enabled but has an empty value, so the admin page will be disabled.");
println!("[WARNING] To enable the admin page without a token, use `DISABLE_ADMIN_TOKEN`.");
}
}
if cfg.push_enabled && (cfg.push_installation_id == String::new() || cfg.push_installation_key == String::new()) {
@@ -1032,41 +1029,37 @@ fn validate_config(cfg: &ConfigItems, on_update: bool) -> Result<(), Error> {
}
}
let invalid_flags = parse_experimental_client_feature_flags(
&cfg.experimental_client_feature_flags,
&FeatureFlagFilter::InvalidOnly,
);
let invalid_flags =
parse_experimental_client_feature_flags(&cfg.experimental_client_feature_flags, FeatureFlagFilter::InvalidOnly);
if !invalid_flags.is_empty() {
let feature_flags_error = format!(
"Unrecognized experimental client feature flags: {invalid_flags:?}.\n\
let feature_flags_error = format!("Unrecognized experimental client feature flags: {:?}.\n\
Please ensure all feature flags are spelled correctly and that they are supported in this version.\n\
Supported flags: {SUPPORTED_FEATURE_FLAGS:?}\n"
);
Supported flags: {:?}\n", invalid_flags, SUPPORTED_FEATURE_FLAGS);
if on_update {
err!(feature_flags_error);
} else {
println!("[WARNING] {feature_flags_error}");
}
println!("[WARNING] {feature_flags_error}");
}
#[expect(clippy::items_after_statements, reason = "Keep this close to where it is used")]
const MAX_FILESIZE_KB: i64 = i64::MAX >> 10;
if let Some(limit) = cfg.user_attachment_limit
&& !(0i64..=MAX_FILESIZE_KB).contains(&limit)
{
err!("`USER_ATTACHMENT_LIMIT` is out of bounds");
if let Some(limit) = cfg.user_attachment_limit {
if !(0i64..=MAX_FILESIZE_KB).contains(&limit) {
err!("`USER_ATTACHMENT_LIMIT` is out of bounds");
}
}
if let Some(limit) = cfg.org_attachment_limit
&& !(0i64..=MAX_FILESIZE_KB).contains(&limit)
{
err!("`ORG_ATTACHMENT_LIMIT` is out of bounds");
if let Some(limit) = cfg.org_attachment_limit {
if !(0i64..=MAX_FILESIZE_KB).contains(&limit) {
err!("`ORG_ATTACHMENT_LIMIT` is out of bounds");
}
}
if let Some(limit) = cfg.user_send_limit
&& !(0i64..=MAX_FILESIZE_KB).contains(&limit)
{
err!("`USER_SEND_LIMIT` is out of bounds");
if let Some(limit) = cfg.user_send_limit {
if !(0i64..=MAX_FILESIZE_KB).contains(&limit) {
err!("`USER_SEND_LIMIT` is out of bounds");
}
}
if cfg._enable_duo
@@ -1083,7 +1076,7 @@ fn validate_config(cfg: &ConfigItems, on_update: bool) -> Result<(), Error> {
validate_internal_sso_issuer_url(&cfg.sso_authority)?;
validate_internal_sso_redirect_url(&cfg.sso_callback_path)?;
validate_sso_master_password_policy(cfg.sso_master_password_policy.as_ref())?;
validate_sso_master_password_policy(&cfg.sso_master_password_policy)?;
}
if cfg._enable_yubico {
@@ -1094,9 +1087,7 @@ fn validate_config(cfg: &ConfigItems, on_update: bool) -> Result<(), Error> {
if let Some(yubico_server) = &cfg.yubico_server {
let yubico_server = yubico_server.to_lowercase();
if !yubico_server.starts_with("https://") {
err!(
"`YUBICO_SERVER` must be a valid URL and start with 'https://'. Either unset this variable or provide a valid URL."
)
err!("`YUBICO_SERVER` must be a valid URL and start with 'https://'. Either unset this variable or provide a valid URL.")
}
}
}
@@ -1148,9 +1139,7 @@ fn validate_config(cfg: &ConfigItems, on_update: bool) -> Result<(), Error> {
}
if cfg.smtp_username.is_some() != cfg.smtp_password.is_some() {
err!(
"Both `SMTP_USERNAME` and `SMTP_PASSWORD` need to be set to enable email authentication without `USE_SENDMAIL`"
)
err!("Both `SMTP_USERNAME` and `SMTP_PASSWORD` need to be set to enable email authentication without `USE_SENDMAIL`")
}
}
@@ -1282,7 +1271,7 @@ fn validate_internal_sso_redirect_url(sso_callback_path: &String) -> Result<open
}
fn validate_sso_master_password_policy(
sso_master_password_policy: Option<&String>,
sso_master_password_policy: &Option<String>,
) -> Result<Option<serde_json::Value>, Error> {
let policy = sso_master_password_policy.as_ref().map(|mpp| serde_json::from_str::<serde_json::Value>(mpp));
@@ -1311,7 +1300,7 @@ fn extract_url_origin(url: &str) -> String {
/// All trailing '/' chars are trimmed, even if the path is a lone '/'.
fn extract_url_path(url: &str) -> String {
match Url::parse(url) {
Ok(u) => u.path().trim_end_matches('/').to_owned(),
Ok(u) => u.path().trim_end_matches('/').to_string(),
Err(_) => {
// We already print it in the method above, no need to do it again
String::new()
@@ -1321,7 +1310,7 @@ fn extract_url_path(url: &str) -> String {
fn generate_smtp_img_src(embed_images: bool, domain: &str) -> String {
if embed_images {
"cid:".to_owned()
"cid:".to_string()
} else {
// normalize base_url
let base_url = domain.trim_end_matches('/');
@@ -1340,10 +1329,10 @@ fn generate_sso_callback_path(domain: &str) -> String {
fn generate_icon_service_url(icon_service: &str) -> String {
match icon_service {
"internal" => String::new(),
"bitwarden" => "https://icons.bitwarden.net/{}/icon.png".to_owned(),
"duckduckgo" => "https://icons.duckduckgo.com/ip3/{}.ico".to_owned(),
"google" => "https://www.google.com/s2/favicons?domain={}&sz=32".to_owned(),
_ => icon_service.to_owned(),
"bitwarden" => "https://icons.bitwarden.net/{}/icon.png".to_string(),
"duckduckgo" => "https://icons.duckduckgo.com/ip3/{}.ico".to_string(),
"google" => "https://www.google.com/s2/favicons?domain={}&sz=32".to_string(),
_ => icon_service.to_string(),
}
}
@@ -1352,7 +1341,7 @@ fn generate_icon_service_csp(icon_service: &str, icon_service_url: &str) -> Stri
// We split on the first '{', since that is the variable delimiter for an icon service URL.
// Everything up until the first '{' should be fixed and can be used as an CSP string.
let csp_string = match icon_service_url.split_once('{') {
Some((c, _)) => c.to_owned(),
Some((c, _)) => c.to_string(),
None => String::new(),
};
@@ -1369,12 +1358,96 @@ fn smtp_convert_deprecated_ssl_options(smtp_ssl: Option<bool>, smtp_explicit_tls
println!("[DEPRECATED]: `SMTP_SSL` or `SMTP_EXPLICIT_TLS` is set. Please use `SMTP_SECURITY` instead.");
}
if smtp_explicit_tls.is_some() && smtp_explicit_tls.unwrap() {
return "force_tls".to_owned();
return "force_tls".to_string();
} else if smtp_ssl.is_some() && !smtp_ssl.unwrap() {
return "off".to_owned();
return "off".to_string();
}
// Return the default `starttls` in all other cases
"starttls".to_owned()
"starttls".to_string()
}
fn opendal_operator_for_path(path: &str) -> Result<opendal::Operator, Error> {
// Cache of previously built operators by path
static OPERATORS_BY_PATH: LazyLock<dashmap::DashMap<String, opendal::Operator>> =
LazyLock::new(dashmap::DashMap::new);
if let Some(operator) = OPERATORS_BY_PATH.get(path) {
return Ok(operator.clone());
}
let operator = if path.starts_with("s3://") {
#[cfg(not(s3))]
return Err(opendal::Error::new(opendal::ErrorKind::ConfigInvalid, "S3 support is not enabled").into());
#[cfg(s3)]
opendal_s3_operator_for_path(path)?
} else {
let builder = opendal::services::Fs::default().root(path);
opendal::Operator::new(builder)?.finish()
};
OPERATORS_BY_PATH.insert(path.to_string(), operator.clone());
Ok(operator)
}
#[cfg(s3)]
fn opendal_s3_operator_for_path(path: &str) -> Result<opendal::Operator, Error> {
use crate::http_client::aws::AwsReqwestConnector;
use aws_config::{default_provider::credentials::DefaultCredentialsChain, provider_config::ProviderConfig};
// This is a custom AWS credential loader that uses the official AWS Rust
// SDK config crate to load credentials. This ensures maximum compatibility
// with AWS credential configurations. For example, OpenDAL doesn't support
// AWS SSO temporary credentials yet.
struct OpenDALS3CredentialLoader {}
#[async_trait]
impl reqsign::AwsCredentialLoad for OpenDALS3CredentialLoader {
async fn load_credential(&self, _client: reqwest::Client) -> anyhow::Result<Option<reqsign::AwsCredential>> {
use aws_credential_types::provider::ProvideCredentials as _;
use tokio::sync::OnceCell;
static DEFAULT_CREDENTIAL_CHAIN: OnceCell<DefaultCredentialsChain> = OnceCell::const_new();
let chain = DEFAULT_CREDENTIAL_CHAIN
.get_or_init(|| {
let reqwest_client = reqwest::Client::builder().build().unwrap();
let connector = AwsReqwestConnector {
client: reqwest_client,
};
let conf = ProviderConfig::default().with_http_client(connector);
DefaultCredentialsChain::builder().configure(conf).build()
})
.await;
let creds = chain.provide_credentials().await?;
Ok(Some(reqsign::AwsCredential {
access_key_id: creds.access_key_id().to_string(),
secret_access_key: creds.secret_access_key().to_string(),
session_token: creds.session_token().map(|s| s.to_string()),
expires_in: creds.expiry().map(|expiration| expiration.into()),
}))
}
}
const OPEN_DAL_S3_CREDENTIAL_LOADER: OpenDALS3CredentialLoader = OpenDALS3CredentialLoader {};
let url = Url::parse(path).map_err(|e| format!("Invalid path S3 URL path {path:?}: {e}"))?;
let bucket = url.host_str().ok_or_else(|| format!("Missing Bucket name in data folder S3 URL {path:?}"))?;
let builder = opendal::services::S3::default()
.customized_credential_load(Box::new(OPEN_DAL_S3_CREDENTIAL_LOADER))
.enable_virtual_host_style()
.bucket(bucket)
.root(url.path())
.default_storage_class("INTELLIGENT_TIERING");
Ok(opendal::Operator::new(builder)?.finish())
}
pub enum PathType {
@@ -1417,12 +1490,12 @@ pub const SUPPORTED_FEATURE_FLAGS: &[&str] = &[
impl Config {
pub async fn load() -> Result<Self, Error> {
// Loading from env and file
let env = ConfigBuilder::from_env();
let usr = ConfigBuilder::from_file().await.unwrap_or_default();
let _env = ConfigBuilder::from_env();
let _usr = ConfigBuilder::from_file().await.unwrap_or_default();
// Create merged config, config file overwrites env
let mut overrides = Vec::new();
let builder = env.merge(&usr, true, &mut overrides);
let mut _overrides = Vec::new();
let builder = _env.merge(&_usr, true, &mut _overrides);
// Fill any missing with defaults
let config = builder.build();
@@ -1435,9 +1508,9 @@ impl Config {
rocket_shutdown_handle: None,
templates: load_templates(&config.templates_folder),
config,
_env: env,
_usr: usr,
_overrides: overrides,
_env,
_usr,
_overrides,
}),
})
}
@@ -1474,7 +1547,7 @@ impl Config {
}
//Save to file
let operator = storage::operator_for_path(&CONFIG_FILE_PARENT_DIR)?;
let operator = opendal_operator_for_path(&CONFIG_FILE_PARENT_DIR)?;
operator.write(&CONFIG_FILENAME, config_str).await?;
Ok(())
@@ -1483,8 +1556,8 @@ impl Config {
async fn update_config_partial(&self, other: ConfigBuilder) -> Result<(), Error> {
let builder = {
let usr = &self.inner.read().unwrap()._usr;
let mut overrides = Vec::new();
usr.merge(&other, false, &mut overrides)
let mut _overrides = Vec::new();
usr.merge(&other, false, &mut _overrides)
};
self.update_config(builder, false).await
}
@@ -1507,11 +1580,11 @@ impl Config {
/// Tests whether signup is allowed for an email address, taking into
/// account the signups_allowed and signups_domains_whitelist settings.
pub fn is_signup_allowed(&self, email: &str) -> bool {
if self.signups_domains_whitelist().is_empty() {
self.signups_allowed()
} else {
if !self.signups_domains_whitelist().is_empty() {
// The whitelist setting overrides the signups_allowed setting.
self.is_email_domain_allowed(email)
} else {
self.signups_allowed()
}
}
@@ -1539,7 +1612,7 @@ impl Config {
}
pub async fn delete_user_config(&self) -> Result<(), Error> {
let operator = storage::operator_for_path(&CONFIG_FILE_PARENT_DIR)?;
let operator = opendal_operator_for_path(&CONFIG_FILE_PARENT_DIR)?;
operator.delete(&CONFIG_FILENAME).await?;
// Empty user config
@@ -1563,7 +1636,7 @@ impl Config {
}
pub fn private_rsa_key(&self) -> String {
storage::with_extension(&self.rsa_key_filename(), "pem")
format!("{}.pem", self.rsa_key_filename())
}
pub fn mail_enabled(&self) -> bool {
let inner = &self.inner.read().unwrap().config;
@@ -1604,11 +1677,15 @@ impl Config {
PathType::IconCache => self.icon_cache_folder(),
PathType::Attachments => self.attachments_folder(),
PathType::Sends => self.sends_folder(),
PathType::RsaKey => storage::parent(&self.private_rsa_key())
.ok_or_else(|| std::io::Error::other("Failed to get directory of RSA key file"))?,
PathType::RsaKey => std::path::Path::new(&self.rsa_key_filename())
.parent()
.ok_or_else(|| std::io::Error::other("Failed to get directory of RSA key file"))?
.to_str()
.ok_or_else(|| std::io::Error::other("Failed to convert RSA key file directory to UTF-8 string"))?
.to_string(),
};
storage::operator_for_path(&path)
opendal_operator_for_path(&path)
}
pub fn render_template<T: serde::ser::Serialize>(&self, name: &str, data: &T) -> Result<String, Error> {
@@ -1632,10 +1709,10 @@ impl Config {
}
pub fn shutdown(&self) {
if let Ok(mut c) = self.inner.write()
&& let Some(handle) = c.rocket_shutdown_handle.take()
{
handle.notify();
if let Ok(mut c) = self.inner.write() {
if let Some(handle) = c.rocket_shutdown_handle.take() {
handle.notify();
}
}
}
@@ -1648,11 +1725,11 @@ impl Config {
}
pub fn sso_master_password_policy_value(&self) -> Option<serde_json::Value> {
validate_sso_master_password_policy(self.sso_master_password_policy().as_ref()).ok().flatten()
validate_sso_master_password_policy(&self.sso_master_password_policy()).ok().flatten()
}
pub fn sso_scopes_vec(&self) -> Vec<String> {
self.sso_scopes().split_whitespace().map(str::to_owned).collect()
self.sso_scopes().split_whitespace().map(str::to_string).collect()
}
pub fn sso_authorize_extra_params_vec(&self) -> Vec<(String, String)> {
@@ -1762,7 +1839,7 @@ fn case_helper<'reg, 'rc>(
let value = param.value().clone();
if h.params().iter().skip(1).any(|x| x.value() == &value) {
h.template().map_or(Ok(()), |t| t.render(r, ctx, rc, out))
h.template().map(|t| t.render(r, ctx, rc, out)).unwrap_or_else(|| Ok(()))
} else {
Ok(())
}
-7
View File
@@ -113,10 +113,3 @@ pub fn ct_eq<T: AsRef<[u8]>, U: AsRef<[u8]>>(a: T, b: U) -> bool {
use subtle::ConstantTimeEq;
a.as_ref().ct_eq(b.as_ref()).into()
}
//
// SHA256
//
pub fn sha256_hex(data: &[u8]) -> String {
HEXLOWER.encode(digest::digest(&digest::SHA256, data).as_ref())
}
+21 -38
View File
@@ -6,23 +6,25 @@ use std::{
};
use diesel::{
Connection, RunQueryDsl,
connection::SimpleConnection,
r2d2::{CustomizeConnection, Pool, PooledConnection},
Connection, RunQueryDsl,
};
use rocket::{
Request,
http::Status,
request::{FromRequest, Outcome},
Request,
};
use tokio::{
sync::{Mutex, OwnedSemaphorePermit, Semaphore},
time::timeout,
};
use crate::{
CONFIG,
error::{Error, MapResult},
CONFIG,
};
// These changes are based on Rocket 0.5-rc wrapper of Diesel: https://github.com/SergioBenitez/Rocket/blob/v0.5-rc/contrib/sync_db_pools
@@ -60,7 +62,7 @@ pub struct DbConnManager {
impl DbConnManager {
pub fn new(database_url: &str) -> Self {
Self {
database_url: database_url.to_owned(),
database_url: database_url.to_string(),
}
}
@@ -222,7 +224,7 @@ impl DbPool {
// Set a global to determine the database more easily throughout the rest of the code
if ACTIVE_DB_TYPE.set(conn_type).is_err() {
error!("Tried to set the active database connection type more than once.");
error!("Tried to set the active database connection type more than once.")
}
Ok(DbPool {
@@ -270,40 +272,22 @@ impl DbConnType {
#[cfg(not(postgresql))]
err!("`DATABASE_URL` is a PostgreSQL URL, but the 'postgresql' feature is not enabled")
// Sqlite (explicit)
} else if url.len() > 7 && &url[..7] == "sqlite:" {
//Sqlite
} else {
#[cfg(sqlite)]
return Ok(DbConnType::Sqlite);
#[cfg(not(sqlite))]
err!("`DATABASE_URL` is a SQLite URL, but the 'sqlite' feature is not enabled")
err!("`DATABASE_URL` looks like a SQLite URL, but 'sqlite' feature is not enabled")
}
// No recognized scheme — assume legacy bare-path SQLite, but the database file must already exist.
// This prevents misconfigured URLs (typos, quoted strings) from silently creating a new empty SQLite database.
#[cfg(sqlite)]
{
if std::path::Path::new(url).exists() {
return Ok(DbConnType::Sqlite);
}
err!(format!(
"`DATABASE_URL` does not match any known database scheme (mysql://, postgresql://, sqlite://) \
and no existing SQLite database was found at '{url}'. \
If you intend to use SQLite, use an explicit `sqlite://` scheme in your `DATABASE_URL`. \
Otherwise, check your DATABASE_URL for typos or quoting issues."
))
}
#[cfg(not(sqlite))]
err!("`DATABASE_URL` does not match any known database scheme (mysql://, postgresql://, sqlite://)")
}
pub fn get_init_stmts(&self) -> String {
let init_stmts = CONFIG.database_conn_init();
if init_stmts.is_empty() {
self.default_init_stmts()
} else {
if !init_stmts.is_empty() {
init_stmts
} else {
self.default_init_stmts()
}
}
@@ -314,7 +298,7 @@ impl DbConnType {
#[cfg(postgresql)]
Self::Postgresql => String::new(),
#[cfg(sqlite)]
Self::Sqlite => "PRAGMA busy_timeout = 5000; PRAGMA synchronous = NORMAL;".to_owned(),
Self::Sqlite => "PRAGMA busy_timeout = 5000; PRAGMA synchronous = NORMAL;".to_string(),
}
}
}
@@ -405,13 +389,12 @@ pub fn backup_sqlite() -> Result<String, Error> {
use diesel::Connection;
let db_url = CONFIG.database_url();
if DbConnType::from_url(&CONFIG.database_url()).is_ok_and(|t| t == DbConnType::Sqlite) {
// Strip the sqlite:// prefix if present to get the raw file path
let file_path = db_url.strip_prefix("sqlite://").unwrap_or(&db_url);
// Open a read-only connection for the backup
let mut conn = diesel::sqlite::SqliteConnection::establish(&format!("sqlite://{file_path}?mode=ro"))?;
if DbConnType::from_url(&CONFIG.database_url()).map(|t| t == DbConnType::Sqlite).unwrap_or(false) {
// Since we do not allow any schema for sqlite database_url's like `file:` or `sqlite:` to be set, we can assume here it isn't
// This way we can set a readonly flag on the opening mode without issues.
let mut conn = diesel::sqlite::SqliteConnection::establish(&format!("sqlite://{db_url}?mode=ro"))?;
let db_path = std::path::Path::new(file_path).parent().unwrap();
let db_path = std::path::Path::new(&db_url).parent().unwrap();
let backup_file = db_path
.join(format!("db_{}.sqlite3", chrono::Utc::now().format("%Y%m%d_%H%M%S")))
.to_string_lossy()
@@ -440,12 +423,12 @@ pub async fn get_sql_server_version(conn: &DbConn) -> String {
postgresql,mysql {
diesel::select(diesel::dsl::sql::<diesel::sql_types::Text>("version();"))
.get_result::<String>(conn)
.unwrap_or_else(|_| "Unknown".to_owned())
.unwrap_or_else(|_| "Unknown".to_string())
}
sqlite {
diesel::select(diesel::dsl::sql::<diesel::sql_types::Text>("sqlite_version();"))
.get_result::<String>(conn)
.unwrap_or_else(|_| "Unknown".to_owned())
.unwrap_or_else(|_| "Unknown".to_string())
}
}
}
-95
View File
@@ -1,95 +0,0 @@
use chrono::NaiveDateTime;
use diesel::prelude::*;
use crate::{
api::EmptyResult,
db::{DbConn, schema::archives},
error::MapResult,
};
use super::{CipherId, User, UserId};
#[derive(Identifiable, Queryable, Insertable)]
#[diesel(table_name = archives)]
#[diesel(primary_key(user_uuid, cipher_uuid))]
pub struct Archive {
pub user_uuid: UserId,
pub cipher_uuid: CipherId,
pub archived_at: NaiveDateTime,
}
impl Archive {
// Returns the date the specified cipher was archived
pub async fn get_archived_at(cipher_uuid: &CipherId, user_uuid: &UserId, conn: &DbConn) -> Option<NaiveDateTime> {
conn.run(move |conn| {
archives::table
.filter(archives::cipher_uuid.eq(cipher_uuid))
.filter(archives::user_uuid.eq(user_uuid))
.select(archives::archived_at)
.first::<NaiveDateTime>(conn)
.ok()
})
.await
}
// Saves (inserts or updates) an archive record with the provided timestamp
pub async fn save(
user_uuid: &UserId,
cipher_uuid: &CipherId,
archived_at: NaiveDateTime,
conn: &DbConn,
) -> EmptyResult {
User::update_uuid_revision(user_uuid, conn).await;
db_run! { conn:
sqlite, mysql {
diesel::replace_into(archives::table)
.values((
archives::user_uuid.eq(user_uuid),
archives::cipher_uuid.eq(cipher_uuid),
archives::archived_at.eq(archived_at),
))
.execute(conn)
.map_res("Error saving archive")
}
postgresql {
diesel::insert_into(archives::table)
.values((
archives::user_uuid.eq(user_uuid),
archives::cipher_uuid.eq(cipher_uuid),
archives::archived_at.eq(archived_at),
))
.on_conflict((archives::user_uuid, archives::cipher_uuid))
.do_update()
.set(archives::archived_at.eq(archived_at))
.execute(conn)
.map_res("Error saving archive")
}
}
}
// Deletes an archive record for a specific cipher
pub async fn delete_by_cipher(user_uuid: &UserId, cipher_uuid: &CipherId, conn: &DbConn) -> EmptyResult {
User::update_uuid_revision(user_uuid, conn).await;
conn.run(move |conn| {
diesel::delete(
archives::table.filter(archives::user_uuid.eq(user_uuid)).filter(archives::cipher_uuid.eq(cipher_uuid)),
)
.execute(conn)
.map_res("Error deleting archive")
})
.await
}
/// Return a vec with (cipher_uuid, archived_at)
/// This is used during a full sync so we only need one query for all archive matches
pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<(CipherId, NaiveDateTime)> {
conn.run(move |conn| {
archives::table
.filter(archives::user_uuid.eq(user_uuid))
.select((archives::cipher_uuid, archives::archived_at))
.load::<(CipherId, NaiveDateTime)>(conn)
.unwrap_or_default()
})
.await
}
}
+37 -44
View File
@@ -1,24 +1,13 @@
use std::time::Duration;
use bigdecimal::{BigDecimal, ToPrimitive};
use derive_more::{AsRef, Deref, Display};
use diesel::prelude::*;
use serde_json::Value;
use crate::{
CONFIG,
api::EmptyResult,
auth::{encode_jwt, generate_file_download_claims},
config::PathType,
db::{
DbConn,
schema::{attachments, ciphers},
},
error::MapResult,
};
use macros::IdFromParam;
use std::time::Duration;
use super::{CipherId, OrganizationId, UserId};
use crate::db::schema::{attachments, ciphers};
use crate::{config::PathType, CONFIG};
use macros::IdFromParam;
#[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = attachments)]
@@ -57,11 +46,11 @@ impl Attachment {
pub async fn get_url(&self, host: &str) -> Result<String, crate::Error> {
let operator = CONFIG.opendal_operator_for_path_type(&PathType::Attachments)?;
if crate::storage::is_fs_operator(&operator) {
if operator.info().scheme() == <&'static str>::from(opendal::Scheme::Fs) {
let token = encode_jwt(&generate_file_download_claims(self.cipher_uuid.clone(), self.id.clone()));
Ok(format!("{host}/attachments/{}/{}?token={token}", self.cipher_uuid, self.id))
} else {
Ok(operator.presign_read(&self.get_file_path(), Duration::from_mins(5)).await?.uri().to_string())
Ok(operator.presign_read(&self.get_file_path(), Duration::from_secs(5 * 60)).await?.uri().to_string())
}
}
@@ -78,6 +67,12 @@ impl Attachment {
}
}
use crate::auth::{encode_jwt, generate_file_download_claims};
use crate::db::DbConn;
use crate::api::EmptyResult;
use crate::error::MapResult;
/// Database methods
impl Attachment {
pub async fn save(&self, conn: &DbConn) -> EmptyResult {
@@ -112,15 +107,15 @@ impl Attachment {
}
pub async fn delete(&self, conn: &DbConn) -> EmptyResult {
conn.run(move |conn| {
crate::util::retry(
|| diesel::delete(attachments::table.filter(attachments::id.eq(&self.id))).execute(conn),
db_run! { conn: {
crate::util::retry(||
diesel::delete(attachments::table.filter(attachments::id.eq(&self.id)))
.execute(conn),
10,
)
.map(|_| ())
.map_res("Error deleting attachment")
})
.await?;
}}?;
let operator = CONFIG.opendal_operator_for_path_type(&PathType::Attachments)?;
let file_path = self.get_file_path();
@@ -144,22 +139,25 @@ impl Attachment {
}
pub async fn find_by_id(id: &AttachmentId, conn: &DbConn) -> Option<Self> {
conn.run(move |conn| attachments::table.filter(attachments::id.eq(id.to_lowercase())).first::<Self>(conn).ok())
.await
db_run! { conn: {
attachments::table
.filter(attachments::id.eq(id.to_lowercase()))
.first::<Self>(conn)
.ok()
}}
}
pub async fn find_by_cipher(cipher_uuid: &CipherId, conn: &DbConn) -> Vec<Self> {
conn.run(move |conn| {
db_run! { conn: {
attachments::table
.filter(attachments::cipher_uuid.eq(cipher_uuid))
.load::<Self>(conn)
.expect("Error loading attachments")
})
.await
}}
}
pub async fn size_by_user(user_uuid: &UserId, conn: &DbConn) -> i64 {
conn.run(move |conn| {
db_run! { conn: {
let result: Option<BigDecimal> = attachments::table
.left_join(ciphers::table.on(ciphers::uuid.eq(attachments::cipher_uuid)))
.filter(ciphers::user_uuid.eq(user_uuid))
@@ -170,26 +168,24 @@ impl Attachment {
match result.map(|r| r.to_i64()) {
Some(Some(r)) => r,
Some(None) => i64::MAX,
None => 0,
None => 0
}
})
.await
}}
}
pub async fn count_by_user(user_uuid: &UserId, conn: &DbConn) -> i64 {
conn.run(move |conn| {
db_run! { conn: {
attachments::table
.left_join(ciphers::table.on(ciphers::uuid.eq(attachments::cipher_uuid)))
.filter(ciphers::user_uuid.eq(user_uuid))
.count()
.first(conn)
.unwrap_or(0)
})
.await
}}
}
pub async fn size_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> i64 {
conn.run(move |conn| {
db_run! { conn: {
let result: Option<BigDecimal> = attachments::table
.left_join(ciphers::table.on(ciphers::uuid.eq(attachments::cipher_uuid)))
.filter(ciphers::organization_uuid.eq(org_uuid))
@@ -200,22 +196,20 @@ impl Attachment {
match result.map(|r| r.to_i64()) {
Some(Some(r)) => r,
Some(None) => i64::MAX,
None => 0,
None => 0
}
})
.await
}}
}
pub async fn count_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> i64 {
conn.run(move |conn| {
db_run! { conn: {
attachments::table
.left_join(ciphers::table.on(ciphers::uuid.eq(attachments::cipher_uuid)))
.filter(ciphers::organization_uuid.eq(org_uuid))
.count()
.first(conn)
.unwrap_or(0)
})
.await
}}
}
// This will return all attachments linked to the user or org
@@ -226,7 +220,7 @@ impl Attachment {
org_uuids: &Vec<OrganizationId>,
conn: &DbConn,
) -> Vec<Self> {
conn.run(move |conn| {
db_run! { conn: {
attachments::table
.left_join(ciphers::table.on(ciphers::uuid.eq(attachments::cipher_uuid)))
.filter(ciphers::user_uuid.eq(user_uuid))
@@ -234,8 +228,7 @@ impl Attachment {
.select(attachments::all_columns)
.load::<Self>(conn)
.expect("Error loading attachments")
})
.await
}}
}
}
+25 -27
View File
@@ -1,18 +1,11 @@
use super::{DeviceId, OrganizationId, UserId};
use crate::db::schema::auth_requests;
use crate::{crypto::ct_eq, util::format_date};
use chrono::{NaiveDateTime, Utc};
use derive_more::{AsRef, Deref, Display, From};
use diesel::prelude::*;
use serde_json::Value;
use crate::{
api::EmptyResult,
crypto::ct_eq,
db::{DbConn, schema::auth_requests},
error::MapResult,
util::format_date,
};
use macros::UuidFromParam;
use super::{DeviceId, OrganizationId, UserId};
use serde_json::Value;
#[derive(Identifiable, Queryable, Insertable, AsChangeset, Deserialize, Serialize)]
#[diesel(table_name = auth_requests)]
@@ -81,6 +74,11 @@ impl AuthRequest {
}
}
use crate::db::DbConn;
use crate::api::EmptyResult;
use crate::error::MapResult;
impl AuthRequest {
pub async fn save(&mut self, conn: &DbConn) -> EmptyResult {
db_run! { conn:
@@ -114,28 +112,31 @@ impl AuthRequest {
}
pub async fn find_by_uuid(uuid: &AuthRequestId, conn: &DbConn) -> Option<Self> {
conn.run(move |conn| auth_requests::table.filter(auth_requests::uuid.eq(uuid)).first::<Self>(conn).ok()).await
db_run! { conn: {
auth_requests::table
.filter(auth_requests::uuid.eq(uuid))
.first::<Self>(conn)
.ok()
}}
}
pub async fn find_by_uuid_and_user(uuid: &AuthRequestId, user_uuid: &UserId, conn: &DbConn) -> Option<Self> {
conn.run(move |conn| {
db_run! { conn: {
auth_requests::table
.filter(auth_requests::uuid.eq(uuid))
.filter(auth_requests::user_uuid.eq(user_uuid))
.first::<Self>(conn)
.ok()
})
.await
}}
}
pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
conn.run(move |conn| {
db_run! { conn: {
auth_requests::table
.filter(auth_requests::user_uuid.eq(user_uuid))
.load::<Self>(conn)
.expect("Error loading auth_requests")
})
.await
}}
}
pub async fn find_by_user_and_requested_device(
@@ -143,7 +144,7 @@ impl AuthRequest {
device_uuid: &DeviceId,
conn: &DbConn,
) -> Option<Self> {
conn.run(move |conn| {
db_run! { conn: {
auth_requests::table
.filter(auth_requests::user_uuid.eq(user_uuid))
.filter(auth_requests::request_device_identifier.eq(device_uuid))
@@ -151,27 +152,24 @@ impl AuthRequest {
.order_by(auth_requests::creation_date.desc())
.first::<Self>(conn)
.ok()
})
.await
}}
}
pub async fn find_created_before(dt: &NaiveDateTime, conn: &DbConn) -> Vec<Self> {
conn.run(move |conn| {
db_run! { conn: {
auth_requests::table
.filter(auth_requests::creation_date.lt(dt))
.load::<Self>(conn)
.expect("Error loading auth_requests")
})
.await
}}
}
pub async fn delete(&self, conn: &DbConn) -> EmptyResult {
conn.run(move |conn| {
db_run! { conn: {
diesel::delete(auth_requests::table.filter(auth_requests::uuid.eq(&self.uuid)))
.execute(conn)
.map_res("Error deleting auth request")
})
.await
}}
}
pub fn check_access_code(&self, access_code: &str) -> bool {
+341 -364
View File
File diff suppressed because it is too large Load Diff
+309 -402
View File
@@ -1,25 +1,16 @@
use derive_more::{AsRef, Deref, Display, From};
use diesel::prelude::*;
use serde_json::Value;
use crate::{
CONFIG,
api::EmptyResult,
db::{
DbConn,
schema::{
ciphers_collections, collections, collections_groups, groups, groups_users, users_collections,
users_organizations,
},
},
error::MapResult,
};
use macros::UuidFromParam;
use super::{
CipherId, CollectionGroup, GroupUser, Membership, MembershipId, MembershipStatus, MembershipType, OrganizationId,
User, UserId,
};
use crate::db::schema::{
ciphers_collections, collections, collections_groups, groups, groups_users, users_collections, users_organizations,
};
use crate::CONFIG;
use diesel::prelude::*;
use macros::UuidFromParam;
#[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = collections)]
@@ -83,7 +74,7 @@ impl Collection {
if external_id.is_empty() {
self.external_id = None;
} else {
self.external_id = Some(external_id);
self.external_id = Some(external_id)
}
}
None => self.external_id = None,
@@ -156,6 +147,11 @@ impl Collection {
}
}
use crate::db::DbConn;
use crate::api::EmptyResult;
use crate::error::MapResult;
/// Database methods
impl Collection {
pub async fn save(&self, conn: &DbConn) -> EmptyResult {
@@ -197,12 +193,11 @@ impl Collection {
CollectionUser::delete_all_by_collection(&self.uuid, conn).await?;
CollectionGroup::delete_all_by_collection(&self.uuid, &self.org_uuid, conn).await?;
conn.run(move |conn| {
db_run! { conn: {
diesel::delete(collections::table.filter(collections::uuid.eq(self.uuid)))
.execute(conn)
.map_res("Error deleting collection")
})
.await
}}
}
pub async fn delete_all_by_organization(org_uuid: &OrganizationId, conn: &DbConn) -> EmptyResult {
@@ -213,90 +208,90 @@ impl Collection {
}
pub async fn update_users_revision(&self, conn: &DbConn) {
for member in &Membership::find_by_collection_and_org(&self.uuid, &self.org_uuid, conn).await {
for member in Membership::find_by_collection_and_org(&self.uuid, &self.org_uuid, conn).await.iter() {
User::update_uuid_revision(&member.user_uuid, conn).await;
}
}
pub async fn find_by_uuid(uuid: &CollectionId, conn: &DbConn) -> Option<Self> {
conn.run(move |conn| collections::table.filter(collections::uuid.eq(uuid)).first::<Self>(conn).ok()).await
db_run! { conn: {
collections::table
.filter(collections::uuid.eq(uuid))
.first::<Self>(conn)
.ok()
}}
}
pub async fn find_by_user_uuid(user_uuid: UserId, conn: &DbConn) -> Vec<Self> {
if CONFIG.org_groups_enabled() {
conn.run(move |conn| {
db_run! { conn: {
collections::table
.left_join(
users_collections::table.on(users_collections::collection_uuid
.eq(collections::uuid)
.and(users_collections::user_uuid.eq(user_uuid.clone()))),
.left_join(users_collections::table.on(
users_collections::collection_uuid.eq(collections::uuid).and(
users_collections::user_uuid.eq(user_uuid.clone())
)
.left_join(
users_organizations::table.on(collections::org_uuid
.eq(users_organizations::org_uuid)
.and(users_organizations::user_uuid.eq(user_uuid.clone()))),
))
.left_join(users_organizations::table.on(
collections::org_uuid.eq(users_organizations::org_uuid).and(
users_organizations::user_uuid.eq(user_uuid.clone())
)
.left_join(
groups_users::table.on(groups_users::users_organizations_uuid.eq(users_organizations::uuid)),
))
.left_join(groups_users::table.on(
groups_users::users_organizations_uuid.eq(users_organizations::uuid)
))
.left_join(groups::table.on(groups::uuid.eq(groups_users::groups_uuid)
.and(groups::organizations_uuid.eq(users_organizations::org_uuid))
))
.left_join(collections_groups::table.on(
collections_groups::groups_uuid.eq(groups_users::groups_uuid).and(
collections_groups::collections_uuid.eq(collections::uuid)
)
.left_join(
groups::table.on(groups::uuid
.eq(groups_users::groups_uuid)
.and(groups::organizations_uuid.eq(users_organizations::org_uuid))),
))
.filter(
users_organizations::status.eq(MembershipStatus::Confirmed as i32)
)
.filter(
users_collections::user_uuid.eq(user_uuid).or( // Directly accessed collection
users_organizations::access_all.eq(true) // access_all in Organization
).or(
groups::access_all.eq(true) // access_all in groups
).or( // access via groups
groups_users::users_organizations_uuid.eq(users_organizations::uuid).and(
collections_groups::collections_uuid.is_not_null()
)
)
.left_join(
collections_groups::table.on(collections_groups::groups_uuid
.eq(groups_users::groups_uuid)
.and(collections_groups::collections_uuid.eq(collections::uuid))),
)
.filter(users_organizations::status.eq(MembershipStatus::Confirmed as i32))
.filter(
users_collections::user_uuid
.eq(user_uuid)
.or(
// Directly accessed collection
users_organizations::access_all.eq(true), // access_all in Organization
)
.or(
groups::access_all.eq(true), // access_all in groups
)
.or(
// access via groups
groups_users::users_organizations_uuid
.eq(users_organizations::uuid)
.and(collections_groups::collections_uuid.is_not_null()),
),
)
.select(collections::all_columns)
.distinct()
.load::<Self>(conn)
.expect("Error loading collections")
})
.await
)
.select(collections::all_columns)
.distinct()
.load::<Self>(conn)
.expect("Error loading collections")
}}
} else {
conn.run(move |conn| {
db_run! { conn: {
collections::table
.left_join(
users_collections::table.on(users_collections::collection_uuid
.eq(collections::uuid)
.and(users_collections::user_uuid.eq(user_uuid.clone()))),
.left_join(users_collections::table.on(
users_collections::collection_uuid.eq(collections::uuid).and(
users_collections::user_uuid.eq(user_uuid.clone())
)
.left_join(
users_organizations::table.on(collections::org_uuid
.eq(users_organizations::org_uuid)
.and(users_organizations::user_uuid.eq(user_uuid.clone()))),
))
.left_join(users_organizations::table.on(
collections::org_uuid.eq(users_organizations::org_uuid).and(
users_organizations::user_uuid.eq(user_uuid.clone())
)
.filter(users_organizations::status.eq(MembershipStatus::Confirmed as i32))
.filter(users_collections::user_uuid.eq(user_uuid).or(
// Directly accessed collection
users_organizations::access_all.eq(true), // access_all in Organization
))
.select(collections::all_columns)
.distinct()
.load::<Self>(conn)
.expect("Error loading collections")
})
.await
))
.filter(
users_organizations::status.eq(MembershipStatus::Confirmed as i32)
)
.filter(
users_collections::user_uuid.eq(user_uuid).or( // Directly accessed collection
users_organizations::access_all.eq(true) // access_all in Organization
)
)
.select(collections::all_columns)
.distinct()
.load::<Self>(conn)
.expect("Error loading collections")
}}
}
}
@@ -313,311 +308,256 @@ impl Collection {
}
pub async fn find_by_organization(org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> {
conn.run(move |conn| {
db_run! { conn: {
collections::table
.filter(collections::org_uuid.eq(org_uuid))
.load::<Self>(conn)
.expect("Error loading collections")
})
.await
}}
}
pub async fn count_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> i64 {
conn.run(move |conn| {
collections::table.filter(collections::org_uuid.eq(org_uuid)).count().first::<i64>(conn).ok().unwrap_or(0)
})
.await
db_run! { conn: {
collections::table
.filter(collections::org_uuid.eq(org_uuid))
.count()
.first::<i64>(conn)
.ok()
.unwrap_or(0)
}}
}
pub async fn find_by_uuid_and_org(uuid: &CollectionId, org_uuid: &OrganizationId, conn: &DbConn) -> Option<Self> {
conn.run(move |conn| {
db_run! { conn: {
collections::table
.filter(collections::uuid.eq(uuid))
.filter(collections::org_uuid.eq(org_uuid))
.select(collections::all_columns)
.first::<Self>(conn)
.ok()
})
.await
}}
}
pub async fn find_by_uuid_and_user(uuid: &CollectionId, user_uuid: UserId, conn: &DbConn) -> Option<Self> {
if CONFIG.org_groups_enabled() {
conn.run(move |conn| {
db_run! { conn: {
collections::table
.left_join(
users_collections::table.on(users_collections::collection_uuid
.eq(collections::uuid)
.and(users_collections::user_uuid.eq(user_uuid.clone()))),
.left_join(users_collections::table.on(
users_collections::collection_uuid.eq(collections::uuid).and(
users_collections::user_uuid.eq(user_uuid.clone())
)
.left_join(
users_organizations::table.on(collections::org_uuid
.eq(users_organizations::org_uuid)
.and(users_organizations::user_uuid.eq(user_uuid))),
))
.left_join(users_organizations::table.on(
collections::org_uuid.eq(users_organizations::org_uuid).and(
users_organizations::user_uuid.eq(user_uuid)
)
.left_join(
groups_users::table.on(groups_users::users_organizations_uuid.eq(users_organizations::uuid)),
))
.left_join(groups_users::table.on(
groups_users::users_organizations_uuid.eq(users_organizations::uuid)
))
.left_join(groups::table.on(groups::uuid.eq(groups_users::groups_uuid)
.and(groups::organizations_uuid.eq(users_organizations::org_uuid))
))
.left_join(collections_groups::table.on(
collections_groups::groups_uuid.eq(groups_users::groups_uuid).and(
collections_groups::collections_uuid.eq(collections::uuid)
)
.left_join(
groups::table.on(groups::uuid
.eq(groups_users::groups_uuid)
.and(groups::organizations_uuid.eq(users_organizations::org_uuid))),
))
.filter(collections::uuid.eq(uuid))
.filter(
users_collections::collection_uuid.eq(uuid).or( // Directly accessed collection
users_organizations::access_all.eq(true).or( // access_all in Organization
users_organizations::atype.le(MembershipType::Admin as i32) // Org admin or owner
)).or(
groups::access_all.eq(true) // access_all in groups
).or( // access via groups
groups_users::users_organizations_uuid.eq(users_organizations::uuid).and(
collections_groups::collections_uuid.is_not_null()
)
)
.left_join(
collections_groups::table.on(collections_groups::groups_uuid
.eq(groups_users::groups_uuid)
.and(collections_groups::collections_uuid.eq(collections::uuid))),
)
.filter(collections::uuid.eq(uuid))
.filter(
users_collections::collection_uuid
.eq(uuid)
.or(
// Directly accessed collection
users_organizations::access_all.eq(true).or(
// access_all in Organization
users_organizations::atype.le(MembershipType::Admin as i32), // Org admin or owner
),
)
.or(
groups::access_all.eq(true), // access_all in groups
)
.or(
// access via groups
groups_users::users_organizations_uuid
.eq(users_organizations::uuid)
.and(collections_groups::collections_uuid.is_not_null()),
),
)
.select(collections::all_columns)
.first::<Self>(conn)
.ok()
})
.await
).select(collections::all_columns)
.first::<Self>(conn)
.ok()
}}
} else {
conn.run(move |conn| {
db_run! { conn: {
collections::table
.left_join(
users_collections::table.on(users_collections::collection_uuid
.eq(collections::uuid)
.and(users_collections::user_uuid.eq(user_uuid.clone()))),
.left_join(users_collections::table.on(
users_collections::collection_uuid.eq(collections::uuid).and(
users_collections::user_uuid.eq(user_uuid.clone())
)
.left_join(
users_organizations::table.on(collections::org_uuid
.eq(users_organizations::org_uuid)
.and(users_organizations::user_uuid.eq(user_uuid))),
))
.left_join(users_organizations::table.on(
collections::org_uuid.eq(users_organizations::org_uuid).and(
users_organizations::user_uuid.eq(user_uuid)
)
.filter(collections::uuid.eq(uuid))
.filter(users_collections::collection_uuid.eq(uuid).or(
// Directly accessed collection
users_organizations::access_all.eq(true).or(
// access_all in Organization
users_organizations::atype.le(MembershipType::Admin as i32), // Org admin or owner
),
))
.filter(collections::uuid.eq(uuid))
.filter(
users_collections::collection_uuid.eq(uuid).or( // Directly accessed collection
users_organizations::access_all.eq(true).or( // access_all in Organization
users_organizations::atype.le(MembershipType::Admin as i32) // Org admin or owner
))
.select(collections::all_columns)
.first::<Self>(conn)
.ok()
})
.await
).select(collections::all_columns)
.first::<Self>(conn)
.ok()
}}
}
}
pub async fn is_writable_by_user(&self, user_uuid: &UserId, conn: &DbConn) -> bool {
let user_uuid = user_uuid.to_string();
if CONFIG.org_groups_enabled() {
conn.run(move |conn| {
db_run! { conn: {
collections::table
.filter(collections::uuid.eq(&self.uuid))
.inner_join(
users_organizations::table.on(collections::org_uuid
.eq(users_organizations::org_uuid)
.and(users_organizations::user_uuid.eq(user_uuid.clone()))),
)
.left_join(
users_collections::table.on(users_collections::collection_uuid
.eq(collections::uuid)
.and(users_collections::user_uuid.eq(user_uuid))),
)
.left_join(
groups_users::table.on(groups_users::users_organizations_uuid.eq(users_organizations::uuid)),
)
.left_join(
groups::table.on(groups::uuid
.eq(groups_users::groups_uuid)
.and(groups::organizations_uuid.eq(users_organizations::org_uuid))),
)
.left_join(
collections_groups::table.on(collections_groups::groups_uuid
.eq(groups_users::groups_uuid)
.and(collections_groups::collections_uuid.eq(collections::uuid))),
)
.filter(
users_organizations::atype
.le(MembershipType::Admin as i32) // Org admin or owner
.or(users_organizations::access_all.eq(true)) // access_all via membership
.or(users_collections::collection_uuid
.eq(&self.uuid) // write access given to collection
.and(users_collections::read_only.eq(false)))
.or(groups::access_all.eq(true)) // access_all via group
.or(collections_groups::collections_uuid
.is_not_null() // write access given via group
.and(collections_groups::read_only.eq(false))),
.inner_join(users_organizations::table.on(
collections::org_uuid.eq(users_organizations::org_uuid)
.and(users_organizations::user_uuid.eq(user_uuid.clone()))
))
.left_join(users_collections::table.on(
users_collections::collection_uuid.eq(collections::uuid)
.and(users_collections::user_uuid.eq(user_uuid))
))
.left_join(groups_users::table.on(
groups_users::users_organizations_uuid.eq(users_organizations::uuid)
))
.left_join(groups::table.on(groups::uuid.eq(groups_users::groups_uuid)
.and(groups::organizations_uuid.eq(users_organizations::org_uuid))
))
.left_join(collections_groups::table.on(
collections_groups::groups_uuid.eq(groups_users::groups_uuid)
.and(collections_groups::collections_uuid.eq(collections::uuid))
))
.filter(users_organizations::atype.le(MembershipType::Admin as i32) // Org admin or owner
.or(users_organizations::access_all.eq(true)) // access_all via membership
.or(users_collections::collection_uuid.eq(&self.uuid) // write access given to collection
.and(users_collections::read_only.eq(false)))
.or(groups::access_all.eq(true)) // access_all via group
.or(collections_groups::collections_uuid.is_not_null() // write access given via group
.and(collections_groups::read_only.eq(false)))
)
.count()
.first::<i64>(conn)
.ok()
.unwrap_or(0)
!= 0
})
.await
.unwrap_or(0) != 0
}}
} else {
conn.run(move |conn| {
db_run! { conn: {
collections::table
.filter(collections::uuid.eq(&self.uuid))
.inner_join(
users_organizations::table.on(collections::org_uuid
.eq(users_organizations::org_uuid)
.and(users_organizations::user_uuid.eq(user_uuid.clone()))),
)
.left_join(
users_collections::table.on(users_collections::collection_uuid
.eq(collections::uuid)
.and(users_collections::user_uuid.eq(user_uuid))),
)
.filter(
users_organizations::atype
.le(MembershipType::Admin as i32) // Org admin or owner
.or(users_organizations::access_all.eq(true)) // access_all via membership
.or(users_collections::collection_uuid
.eq(&self.uuid) // write access given to collection
.and(users_collections::read_only.eq(false))),
.inner_join(users_organizations::table.on(
collections::org_uuid.eq(users_organizations::org_uuid)
.and(users_organizations::user_uuid.eq(user_uuid.clone()))
))
.left_join(users_collections::table.on(
users_collections::collection_uuid.eq(collections::uuid)
.and(users_collections::user_uuid.eq(user_uuid))
))
.filter(users_organizations::atype.le(MembershipType::Admin as i32) // Org admin or owner
.or(users_organizations::access_all.eq(true)) // access_all via membership
.or(users_collections::collection_uuid.eq(&self.uuid) // write access given to collection
.and(users_collections::read_only.eq(false)))
)
.count()
.first::<i64>(conn)
.ok()
.unwrap_or(0)
!= 0
})
.await
.unwrap_or(0) != 0
}}
}
}
pub async fn hide_passwords_for_user(&self, user_uuid: &UserId, conn: &DbConn) -> bool {
let user_uuid = user_uuid.to_string();
conn.run(move |conn| {
db_run! { conn: {
collections::table
.left_join(
users_collections::table.on(users_collections::collection_uuid
.eq(collections::uuid)
.and(users_collections::user_uuid.eq(user_uuid.clone()))),
.left_join(users_collections::table.on(
users_collections::collection_uuid.eq(collections::uuid).and(
users_collections::user_uuid.eq(user_uuid.clone())
)
.left_join(
users_organizations::table.on(collections::org_uuid
.eq(users_organizations::org_uuid)
.and(users_organizations::user_uuid.eq(user_uuid))),
))
.left_join(users_organizations::table.on(
collections::org_uuid.eq(users_organizations::org_uuid).and(
users_organizations::user_uuid.eq(user_uuid)
)
.left_join(groups_users::table.on(groups_users::users_organizations_uuid.eq(users_organizations::uuid)))
.left_join(
groups::table.on(groups::uuid
.eq(groups_users::groups_uuid)
.and(groups::organizations_uuid.eq(users_organizations::org_uuid))),
))
.left_join(groups_users::table.on(
groups_users::users_organizations_uuid.eq(users_organizations::uuid)
))
.left_join(groups::table.on(groups::uuid.eq(groups_users::groups_uuid)
.and(groups::organizations_uuid.eq(users_organizations::org_uuid))
))
.left_join(collections_groups::table.on(
collections_groups::groups_uuid.eq(groups_users::groups_uuid).and(
collections_groups::collections_uuid.eq(collections::uuid)
)
.left_join(
collections_groups::table.on(collections_groups::groups_uuid
.eq(groups_users::groups_uuid)
.and(collections_groups::collections_uuid.eq(collections::uuid))),
))
.filter(collections::uuid.eq(&self.uuid))
.filter(
users_collections::collection_uuid.eq(&self.uuid).and(users_collections::hide_passwords.eq(true)).or(// Directly accessed collection
users_organizations::access_all.eq(true).or( // access_all in Organization
users_organizations::atype.le(MembershipType::Admin as i32) // Org admin or owner
)).or(
groups::access_all.eq(true) // access_all in groups
).or( // access via groups
groups_users::users_organizations_uuid.eq(users_organizations::uuid).and(
collections_groups::collections_uuid.is_not_null().and(
collections_groups::hide_passwords.eq(true))
)
)
.filter(collections::uuid.eq(&self.uuid))
.filter(
users_collections::collection_uuid
.eq(&self.uuid)
.and(users_collections::hide_passwords.eq(true))
.or(
// Directly accessed collection
users_organizations::access_all.eq(true).or(
// access_all in Organization
users_organizations::atype.le(MembershipType::Admin as i32), // Org admin or owner
),
)
.or(
groups::access_all.eq(true), // access_all in groups
)
.or(
// access via groups
groups_users::users_organizations_uuid.eq(users_organizations::uuid).and(
collections_groups::collections_uuid
.is_not_null()
.and(collections_groups::hide_passwords.eq(true)),
),
),
)
.count()
.first::<i64>(conn)
.ok()
.unwrap_or(0)
!= 0
})
.await
)
.count()
.first::<i64>(conn)
.ok()
.unwrap_or(0) != 0
}}
}
pub async fn is_coll_manageable_by_user(uuid: &CollectionId, user_uuid: &UserId, conn: &DbConn) -> bool {
let uuid = uuid.to_string();
let user_uuid = user_uuid.to_string();
conn.run(move |conn| {
db_run! { conn: {
collections::table
.left_join(
users_collections::table.on(users_collections::collection_uuid
.eq(collections::uuid)
.and(users_collections::user_uuid.eq(user_uuid.clone()))),
.left_join(users_collections::table.on(
users_collections::collection_uuid.eq(collections::uuid).and(
users_collections::user_uuid.eq(user_uuid.clone())
)
.left_join(
users_organizations::table.on(collections::org_uuid
.eq(users_organizations::org_uuid)
.and(users_organizations::user_uuid.eq(user_uuid))),
))
.left_join(users_organizations::table.on(
collections::org_uuid.eq(users_organizations::org_uuid).and(
users_organizations::user_uuid.eq(user_uuid)
)
.left_join(groups_users::table.on(groups_users::users_organizations_uuid.eq(users_organizations::uuid)))
.left_join(
groups::table.on(groups::uuid
.eq(groups_users::groups_uuid)
.and(groups::organizations_uuid.eq(users_organizations::org_uuid))),
))
.left_join(groups_users::table.on(
groups_users::users_organizations_uuid.eq(users_organizations::uuid)
))
.left_join(groups::table.on(groups::uuid.eq(groups_users::groups_uuid)
.and(groups::organizations_uuid.eq(users_organizations::org_uuid))
))
.left_join(collections_groups::table.on(
collections_groups::groups_uuid.eq(groups_users::groups_uuid).and(
collections_groups::collections_uuid.eq(collections::uuid)
)
.left_join(
collections_groups::table.on(collections_groups::groups_uuid
.eq(groups_users::groups_uuid)
.and(collections_groups::collections_uuid.eq(collections::uuid))),
))
.filter(collections::uuid.eq(&uuid))
.filter(
users_collections::collection_uuid.eq(&uuid).and(users_collections::manage.eq(true)).or(// Directly accessed collection
users_organizations::access_all.eq(true).or( // access_all in Organization
users_organizations::atype.le(MembershipType::Admin as i32) // Org admin or owner
)).or(
groups::access_all.eq(true) // access_all in groups
).or( // access via groups
groups_users::users_organizations_uuid.eq(users_organizations::uuid).and(
collections_groups::collections_uuid.is_not_null().and(
collections_groups::manage.eq(true))
)
)
.filter(collections::uuid.eq(&uuid))
.filter(
users_collections::collection_uuid
.eq(&uuid)
.and(users_collections::manage.eq(true))
.or(
// Directly accessed collection
users_organizations::access_all.eq(true).or(
// access_all in Organization
users_organizations::atype.le(MembershipType::Admin as i32), // Org admin or owner
),
)
.or(
groups::access_all.eq(true), // access_all in groups
)
.or(
// access via groups
groups_users::users_organizations_uuid.eq(users_organizations::uuid).and(
collections_groups::collections_uuid
.is_not_null()
.and(collections_groups::manage.eq(true)),
),
),
)
.count()
.first::<i64>(conn)
.ok()
.unwrap_or(0)
!= 0
})
.await
)
.count()
.first::<i64>(conn)
.ok()
.unwrap_or(0) != 0
}}
}
pub async fn is_manageable_by_user(&self, user_uuid: &UserId, conn: &DbConn) -> bool {
@@ -632,7 +572,7 @@ impl CollectionUser {
user_uuid: &UserId,
conn: &DbConn,
) -> Vec<Self> {
conn.run(move |conn| {
db_run! { conn: {
users_collections::table
.filter(users_collections::user_uuid.eq(user_uuid))
.inner_join(collections::table.on(collections::uuid.eq(users_collections::collection_uuid)))
@@ -640,35 +580,24 @@ impl CollectionUser {
.select(users_collections::all_columns)
.load::<Self>(conn)
.expect("Error loading users_collections")
})
.await
}}
}
pub async fn find_by_organization_swap_user_uuid_with_member_uuid(
org_uuid: &OrganizationId,
conn: &DbConn,
) -> Vec<CollectionMembership> {
let col_users = conn
.run(move |conn| {
users_collections::table
.inner_join(collections::table.on(collections::uuid.eq(users_collections::collection_uuid)))
.filter(collections::org_uuid.eq(org_uuid))
.inner_join(
users_organizations::table.on(users_organizations::user_uuid.eq(users_collections::user_uuid)),
)
.filter(users_organizations::org_uuid.eq(org_uuid))
.select((
users_organizations::uuid,
users_collections::collection_uuid,
users_collections::read_only,
users_collections::hide_passwords,
users_collections::manage,
))
.load::<Self>(conn)
.expect("Error loading users_collections")
})
.await;
col_users.into_iter().map(Into::into).collect()
let col_users = db_run! { conn: {
users_collections::table
.inner_join(collections::table.on(collections::uuid.eq(users_collections::collection_uuid)))
.filter(collections::org_uuid.eq(org_uuid))
.inner_join(users_organizations::table.on(users_organizations::user_uuid.eq(users_collections::user_uuid)))
.filter(users_organizations::org_uuid.eq(org_uuid))
.select((users_organizations::uuid, users_collections::collection_uuid, users_collections::read_only, users_collections::hide_passwords, users_collections::manage))
.load::<Self>(conn)
.expect("Error loading users_collections")
}};
col_users.into_iter().map(|c| c.into()).collect()
}
pub async fn save(
@@ -737,7 +666,7 @@ impl CollectionUser {
pub async fn delete(self, conn: &DbConn) -> EmptyResult {
User::update_uuid_revision(&self.user_uuid, conn).await;
conn.run(move |conn| {
db_run! { conn: {
diesel::delete(
users_collections::table
.filter(users_collections::user_uuid.eq(&self.user_uuid))
@@ -745,19 +674,17 @@ impl CollectionUser {
)
.execute(conn)
.map_res("Error removing user from collection")
})
.await
}}
}
pub async fn find_by_collection(collection_uuid: &CollectionId, conn: &DbConn) -> Vec<Self> {
conn.run(move |conn| {
db_run! { conn: {
users_collections::table
.filter(users_collections::collection_uuid.eq(collection_uuid))
.select(users_collections::all_columns)
.load::<Self>(conn)
.expect("Error loading users_collections")
})
.await
}}
}
pub async fn find_by_org_and_coll_swap_user_uuid_with_member_uuid(
@@ -765,26 +692,16 @@ impl CollectionUser {
collection_uuid: &CollectionId,
conn: &DbConn,
) -> Vec<CollectionMembership> {
let col_users = conn
.run(move |conn| {
users_collections::table
.filter(users_collections::collection_uuid.eq(collection_uuid))
.filter(users_organizations::org_uuid.eq(org_uuid))
.inner_join(
users_organizations::table.on(users_organizations::user_uuid.eq(users_collections::user_uuid)),
)
.select((
users_organizations::uuid,
users_collections::collection_uuid,
users_collections::read_only,
users_collections::hide_passwords,
users_collections::manage,
))
.load::<Self>(conn)
.expect("Error loading users_collections")
})
.await;
col_users.into_iter().map(Into::into).collect()
let col_users = db_run! { conn: {
users_collections::table
.filter(users_collections::collection_uuid.eq(collection_uuid))
.filter(users_organizations::org_uuid.eq(org_uuid))
.inner_join(users_organizations::table.on(users_organizations::user_uuid.eq(users_collections::user_uuid)))
.select((users_organizations::uuid, users_collections::collection_uuid, users_collections::read_only, users_collections::hide_passwords, users_collections::manage))
.load::<Self>(conn)
.expect("Error loading users_collections")
}};
col_users.into_iter().map(|c| c.into()).collect()
}
pub async fn find_by_collection_and_user(
@@ -792,39 +709,36 @@ impl CollectionUser {
user_uuid: &UserId,
conn: &DbConn,
) -> Option<Self> {
conn.run(move |conn| {
db_run! { conn: {
users_collections::table
.filter(users_collections::collection_uuid.eq(collection_uuid))
.filter(users_collections::user_uuid.eq(user_uuid))
.select(users_collections::all_columns)
.first::<Self>(conn)
.ok()
})
.await
}}
}
pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
conn.run(move |conn| {
db_run! { conn: {
users_collections::table
.filter(users_collections::user_uuid.eq(user_uuid))
.select(users_collections::all_columns)
.load::<Self>(conn)
.expect("Error loading users_collections")
})
.await
}}
}
pub async fn delete_all_by_collection(collection_uuid: &CollectionId, conn: &DbConn) -> EmptyResult {
for collection in &CollectionUser::find_by_collection(collection_uuid, conn).await {
for collection in CollectionUser::find_by_collection(collection_uuid, conn).await.iter() {
User::update_uuid_revision(&collection.user_uuid, conn).await;
}
conn.run(move |conn| {
db_run! { conn: {
diesel::delete(users_collections::table.filter(users_collections::collection_uuid.eq(collection_uuid)))
.execute(conn)
.map_res("Error deleting users from collection")
})
.await
}}
}
pub async fn delete_all_by_user_and_org(
@@ -834,21 +748,17 @@ impl CollectionUser {
) -> EmptyResult {
let collectionusers = Self::find_by_organization_and_user_uuid(org_uuid, user_uuid, conn).await;
conn.run(move |conn| {
db_run! { conn: {
for user in collectionusers {
let _: () = diesel::delete(
users_collections::table.filter(
users_collections::user_uuid
.eq(user_uuid)
.and(users_collections::collection_uuid.eq(user.collection_uuid)),
),
)
.execute(conn)
.map_res("Error removing user from collections")?;
let _: () = diesel::delete(users_collections::table.filter(
users_collections::user_uuid.eq(user_uuid)
.and(users_collections::collection_uuid.eq(user.collection_uuid))
))
.execute(conn)
.map_res("Error removing user from collections")?;
}
Ok(())
})
.await
}}
}
pub async fn has_access_to_collection_by_user(col_id: &CollectionId, user_uuid: &UserId, conn: &DbConn) -> bool {
@@ -891,7 +801,7 @@ impl CollectionCipher {
pub async fn delete(cipher_uuid: &CipherId, collection_uuid: &CollectionId, conn: &DbConn) -> EmptyResult {
Self::update_users_revision(collection_uuid, conn).await;
conn.run(move |conn| {
db_run! { conn: {
diesel::delete(
ciphers_collections::table
.filter(ciphers_collections::cipher_uuid.eq(cipher_uuid))
@@ -899,26 +809,23 @@ impl CollectionCipher {
)
.execute(conn)
.map_res("Error deleting cipher from collection")
})
.await
}}
}
pub async fn delete_all_by_cipher(cipher_uuid: &CipherId, conn: &DbConn) -> EmptyResult {
conn.run(move |conn| {
db_run! { conn: {
diesel::delete(ciphers_collections::table.filter(ciphers_collections::cipher_uuid.eq(cipher_uuid)))
.execute(conn)
.map_res("Error removing cipher from collections")
})
.await
}}
}
pub async fn delete_all_by_collection(collection_uuid: &CollectionId, conn: &DbConn) -> EmptyResult {
conn.run(move |conn| {
db_run! { conn: {
diesel::delete(ciphers_collections::table.filter(ciphers_collections::collection_uuid.eq(collection_uuid)))
.execute(conn)
.map_res("Error removing ciphers from collection")
})
.await
}}
}
pub async fn update_users_revision(collection_uuid: &CollectionId, conn: &DbConn) {
+45 -43
View File
@@ -1,20 +1,18 @@
use chrono::{NaiveDateTime, Utc};
use data_encoding::BASE64URL;
use derive_more::{Display, From};
use diesel::prelude::*;
use serde_json::Value;
use super::{AuthRequest, UserId};
use crate::db::schema::devices;
use crate::{
api::EmptyResult,
crypto,
db::{DbConn, schema::devices},
error::MapResult,
util::{format_date, get_uuid},
};
use diesel::prelude::*;
use macros::{IdFromParam, UuidFromParam};
use super::{AuthRequest, UserId};
#[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = devices)]
#[diesel(treat_none_as_null = true)]
@@ -27,7 +25,7 @@ pub struct Device {
pub user_uuid: UserId,
pub name: String,
pub atype: i32, // https://github.com/bitwarden/server/blob/8d547dcc280babab70dd4a3c94ced6a34b12dfbf/src/Core/Enums/DeviceType.cs
pub atype: i32, // https://github.com/bitwarden/server/blob/9ebe16587175b1c0e9208f84397bb75d0d595510/src/Core/Enums/DeviceType.cs
pub push_uuid: Option<PushId>,
pub push_token: Option<String>,
@@ -137,6 +135,10 @@ impl DeviceWithAuthRequest {
}
}
}
use crate::db::DbConn;
use crate::api::EmptyResult;
use crate::error::MapResult;
/// Database methods
impl Device {
@@ -169,23 +171,21 @@ impl Device {
}
pub async fn delete_all_by_user(user_uuid: &UserId, conn: &DbConn) -> EmptyResult {
conn.run(move |conn| {
db_run! { conn: {
diesel::delete(devices::table.filter(devices::user_uuid.eq(user_uuid)))
.execute(conn)
.map_res("Error removing devices for user")
})
.await
}}
}
pub async fn find_by_uuid_and_user(uuid: &DeviceId, user_uuid: &UserId, conn: &DbConn) -> Option<Self> {
conn.run(move |conn| {
db_run! { conn: {
devices::table
.filter(devices::uuid.eq(uuid))
.filter(devices::user_uuid.eq(user_uuid))
.first::<Self>(conn)
.ok()
})
.await
}}
}
pub async fn find_with_auth_request_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<DeviceWithAuthRequest> {
@@ -199,65 +199,71 @@ impl Device {
}
pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
conn.run(move |conn| {
devices::table.filter(devices::user_uuid.eq(user_uuid)).load::<Self>(conn).expect("Error loading devices")
})
.await
db_run! { conn: {
devices::table
.filter(devices::user_uuid.eq(user_uuid))
.load::<Self>(conn)
.expect("Error loading devices")
}}
}
pub async fn find_by_uuid(uuid: &DeviceId, conn: &DbConn) -> Option<Self> {
conn.run(move |conn| devices::table.filter(devices::uuid.eq(uuid)).first::<Self>(conn).ok()).await
db_run! { conn: {
devices::table
.filter(devices::uuid.eq(uuid))
.first::<Self>(conn)
.ok()
}}
}
pub async fn clear_push_token_by_uuid(uuid: &DeviceId, conn: &DbConn) -> EmptyResult {
conn.run(move |conn| {
db_run! { conn: {
diesel::update(devices::table)
.filter(devices::uuid.eq(uuid))
.set(devices::push_token.eq::<Option<String>>(None))
.execute(conn)
.map_res("Error removing push token")
})
.await
}}
}
pub async fn find_by_refresh_token(refresh_token: &str, conn: &DbConn) -> Option<Self> {
conn.run(move |conn| devices::table.filter(devices::refresh_token.eq(refresh_token)).first::<Self>(conn).ok())
.await
db_run! { conn: {
devices::table
.filter(devices::refresh_token.eq(refresh_token))
.first::<Self>(conn)
.ok()
}}
}
pub async fn find_latest_active_by_user(user_uuid: &UserId, conn: &DbConn) -> Option<Self> {
conn.run(move |conn| {
db_run! { conn: {
devices::table
.filter(devices::user_uuid.eq(user_uuid))
.order(devices::updated_at.desc())
.first::<Self>(conn)
.ok()
})
.await
}}
}
pub async fn find_push_devices_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
conn.run(move |conn| {
db_run! { conn: {
devices::table
.filter(devices::user_uuid.eq(user_uuid))
.filter(devices::push_token.is_not_null())
.load::<Self>(conn)
.expect("Error loading push devices")
})
.await
}}
}
pub async fn check_user_has_push_device(user_uuid: &UserId, conn: &DbConn) -> bool {
conn.run(move |conn| {
db_run! { conn: {
devices::table
.filter(devices::user_uuid.eq(user_uuid))
.filter(devices::push_token.is_not_null())
.count()
.first::<i64>(conn)
.ok()
.unwrap_or(0)
!= 0
})
.await
.filter(devices::user_uuid.eq(user_uuid))
.filter(devices::push_token.is_not_null())
.count()
.first::<i64>(conn)
.ok()
.unwrap_or(0) != 0
}}
}
pub async fn rotate_refresh_tokens_by_user(user_uuid: &UserId, conn: &DbConn) -> EmptyResult {
@@ -326,12 +332,9 @@ pub enum DeviceType {
MacOsCLI = 24,
#[display("Linux CLI")]
LinuxCLI = 25,
#[display("DuckDuckGo")]
DuckDuckGoBrowser = 26,
}
impl DeviceType {
#[expect(clippy::match_same_arms, reason = "Specifically define 14 and have a fallback for new types")]
pub fn from_i32(value: i32) -> DeviceType {
match value {
0 => DeviceType::Android,
@@ -360,7 +363,6 @@ impl DeviceType {
23 => DeviceType::WindowsCLI,
24 => DeviceType::MacOsCLI,
25 => DeviceType::LinuxCLI,
26 => DeviceType::DuckDuckGoBrowser,
_ => DeviceType::UnknownBrowser,
}
}
+50 -71
View File
@@ -1,16 +1,12 @@
use chrono::{NaiveDateTime, Utc};
use derive_more::{AsRef, Deref, Display, From};
use diesel::prelude::*;
use serde_json::Value;
use crate::{
api::EmptyResult,
db::{DbConn, schema::emergency_access},
error::MapResult,
};
use macros::UuidFromParam;
use super::{User, UserId};
use crate::db::schema::emergency_access;
use crate::{api::EmptyResult, db::DbConn, error::MapResult};
use diesel::prelude::*;
use macros::UuidFromParam;
#[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = emergency_access)]
@@ -89,15 +85,17 @@ impl EmergencyAccess {
pub async fn to_json_grantee_details(&self, conn: &DbConn) -> Option<Value> {
let grantee_user = if let Some(grantee_uuid) = &self.grantee_uuid {
User::find_by_uuid(grantee_uuid, conn).await.expect("Grantee user not found.")
} else {
let email = self.email.as_deref()?;
if let Some(user) = User::find_by_mail(email, conn).await {
user
} else {
// remove outstanding invitations which should not exist
Self::delete_all_by_grantee_email(email, conn).await.ok();
return None;
} else if let Some(email) = self.email.as_deref() {
match User::find_by_mail(email, conn).await {
Some(user) => user,
None => {
// remove outstanding invitations which should not exist
Self::delete_all_by_grantee_email(email, conn).await.ok();
return None;
}
}
} else {
return None;
};
Some(json!({
@@ -186,36 +184,28 @@ impl EmergencyAccess {
self.status = status;
date.clone_into(&mut self.updated_at);
conn.run(move |conn| {
crate::util::retry(
|| {
diesel::update(emergency_access::table.filter(emergency_access::uuid.eq(&self.uuid)))
.set((emergency_access::status.eq(status), emergency_access::updated_at.eq(date)))
.execute(conn)
},
10,
)
db_run! { conn: {
crate::util::retry(|| {
diesel::update(emergency_access::table.filter(emergency_access::uuid.eq(&self.uuid)))
.set((emergency_access::status.eq(status), emergency_access::updated_at.eq(date)))
.execute(conn)
}, 10)
.map_res("Error updating emergency access status")
})
.await
}}
}
pub async fn update_last_notification_date_and_save(&mut self, date: &NaiveDateTime, conn: &DbConn) -> EmptyResult {
self.last_notification_at = Some(date.to_owned());
date.clone_into(&mut self.updated_at);
conn.run(move |conn| {
crate::util::retry(
|| {
diesel::update(emergency_access::table.filter(emergency_access::uuid.eq(&self.uuid)))
.set((emergency_access::last_notification_at.eq(date), emergency_access::updated_at.eq(date)))
.execute(conn)
},
10,
)
db_run! { conn: {
crate::util::retry(|| {
diesel::update(emergency_access::table.filter(emergency_access::uuid.eq(&self.uuid)))
.set((emergency_access::last_notification_at.eq(date), emergency_access::updated_at.eq(date)))
.execute(conn)
}, 10)
.map_res("Error updating emergency access status")
})
.await
}}
}
pub async fn delete_all_by_user(user_uuid: &UserId, conn: &DbConn) -> EmptyResult {
@@ -238,12 +228,11 @@ impl EmergencyAccess {
pub async fn delete(self, conn: &DbConn) -> EmptyResult {
User::update_uuid_revision(&self.grantor_uuid, conn).await;
conn.run(move |conn| {
db_run! { conn: {
diesel::delete(emergency_access::table.filter(emergency_access::uuid.eq(self.uuid)))
.execute(conn)
.map_res("Error removing user from emergency access")
})
.await
}}
}
pub async fn find_by_grantor_uuid_and_grantee_uuid_or_email(
@@ -252,25 +241,23 @@ impl EmergencyAccess {
email: &str,
conn: &DbConn,
) -> Option<Self> {
conn.run(move |conn| {
db_run! { conn: {
emergency_access::table
.filter(emergency_access::grantor_uuid.eq(grantor_uuid))
.filter(emergency_access::grantee_uuid.eq(grantee_uuid).or(emergency_access::email.eq(email)))
.first::<Self>(conn)
.ok()
})
.await
}}
}
pub async fn find_all_recoveries_initiated(conn: &DbConn) -> Vec<Self> {
conn.run(move |conn| {
db_run! { conn: {
emergency_access::table
.filter(emergency_access::status.eq(EmergencyAccessStatus::RecoveryInitiated as i32))
.filter(emergency_access::recovery_initiated_at.is_not_null())
.load::<Self>(conn)
.expect("Error loading emergency_access")
})
.await
}}
}
pub async fn find_by_uuid_and_grantor_uuid(
@@ -278,14 +265,13 @@ impl EmergencyAccess {
grantor_uuid: &UserId,
conn: &DbConn,
) -> Option<Self> {
conn.run(move |conn| {
db_run! { conn: {
emergency_access::table
.filter(emergency_access::uuid.eq(uuid))
.filter(emergency_access::grantor_uuid.eq(grantor_uuid))
.first::<Self>(conn)
.ok()
})
.await
}}
}
pub async fn find_by_uuid_and_grantee_uuid(
@@ -293,14 +279,13 @@ impl EmergencyAccess {
grantee_uuid: &UserId,
conn: &DbConn,
) -> Option<Self> {
conn.run(move |conn| {
db_run! { conn: {
emergency_access::table
.filter(emergency_access::uuid.eq(uuid))
.filter(emergency_access::grantee_uuid.eq(grantee_uuid))
.first::<Self>(conn)
.ok()
})
.await
}}
}
pub async fn find_by_uuid_and_grantee_email(
@@ -308,67 +293,61 @@ impl EmergencyAccess {
grantee_email: &str,
conn: &DbConn,
) -> Option<Self> {
conn.run(move |conn| {
db_run! { conn: {
emergency_access::table
.filter(emergency_access::uuid.eq(uuid))
.filter(emergency_access::email.eq(grantee_email))
.first::<Self>(conn)
.ok()
})
.await
}}
}
pub async fn find_all_by_grantee_uuid(grantee_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
conn.run(move |conn| {
db_run! { conn: {
emergency_access::table
.filter(emergency_access::grantee_uuid.eq(grantee_uuid))
.load::<Self>(conn)
.expect("Error loading emergency_access")
})
.await
}}
}
pub async fn find_invited_by_grantee_email(grantee_email: &str, conn: &DbConn) -> Option<Self> {
conn.run(move |conn| {
db_run! { conn: {
emergency_access::table
.filter(emergency_access::email.eq(grantee_email))
.filter(emergency_access::status.eq(EmergencyAccessStatus::Invited as i32))
.first::<Self>(conn)
.ok()
})
.await
}}
}
pub async fn find_all_invited_by_grantee_email(grantee_email: &str, conn: &DbConn) -> Vec<Self> {
conn.run(move |conn| {
db_run! { conn: {
emergency_access::table
.filter(emergency_access::email.eq(grantee_email))
.filter(emergency_access::status.eq(EmergencyAccessStatus::Invited as i32))
.load::<Self>(conn)
.expect("Error loading emergency_access")
})
.await
}}
}
pub async fn find_all_by_grantor_uuid(grantor_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
conn.run(move |conn| {
db_run! { conn: {
emergency_access::table
.filter(emergency_access::grantor_uuid.eq(grantor_uuid))
.load::<Self>(conn)
.expect("Error loading emergency_access")
})
.await
}}
}
pub async fn find_all_confirmed_by_grantor_uuid(grantor_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
conn.run(move |conn| {
db_run! { conn: {
emergency_access::table
.filter(emergency_access::grantor_uuid.eq(grantor_uuid))
.filter(emergency_access::status.ge(EmergencyAccessStatus::Confirmed as i32))
.load::<Self>(conn)
.expect("Error loading emergency_access")
})
.await
}}
}
pub async fn accept_invite(&mut self, grantee_uuid: &UserId, grantee_email: &str, conn: &DbConn) -> EmptyResult {
+28 -38
View File
@@ -1,18 +1,11 @@
use chrono::{NaiveDateTime, TimeDelta, Utc};
use diesel::prelude::*;
//use derive_more::{AsRef, Deref, Display, From};
use serde_json::Value;
use crate::{
CONFIG,
api::EmptyResult,
db::{
DbConn,
schema::{event, users_organizations},
},
error::MapResult,
};
use super::{CipherId, CollectionId, GroupId, MembershipId, OrgPolicyId, OrganizationId, UserId};
use crate::db::schema::{event, users_organizations};
use crate::{api::EmptyResult, db::DbConn, error::MapResult, CONFIG};
use diesel::prelude::*;
// https://bitwarden.com/help/event-logs/
@@ -256,10 +249,11 @@ impl Event {
}
pub async fn delete(self, conn: &DbConn) -> EmptyResult {
conn.run(move |conn| {
diesel::delete(event::table.filter(event::uuid.eq(self.uuid))).execute(conn).map_res("Error deleting event")
})
.await
db_run! { conn: {
diesel::delete(event::table.filter(event::uuid.eq(self.uuid)))
.execute(conn)
.map_res("Error deleting event")
}}
}
/// ##############
@@ -270,7 +264,7 @@ impl Event {
end: &NaiveDateTime,
conn: &DbConn,
) -> Vec<Self> {
conn.run(move |conn| {
db_run! { conn: {
event::table
.filter(event::org_uuid.eq(org_uuid))
.filter(event::event_date.between(start, end))
@@ -278,15 +272,18 @@ impl Event {
.limit(Self::PAGE_SIZE)
.load::<Self>(conn)
.expect("Error filtering events")
})
.await
}}
}
pub async fn count_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> i64 {
conn.run(move |conn| {
event::table.filter(event::org_uuid.eq(org_uuid)).count().first::<i64>(conn).ok().unwrap_or(0)
})
.await
db_run! { conn: {
event::table
.filter(event::org_uuid.eq(org_uuid))
.count()
.first::<i64>(conn)
.ok()
.unwrap_or(0)
}}
}
pub async fn find_by_org_and_member(
@@ -296,23 +293,18 @@ impl Event {
end: &NaiveDateTime,
conn: &DbConn,
) -> Vec<Self> {
conn.run(move |conn| {
db_run! { conn: {
event::table
.inner_join(users_organizations::table.on(users_organizations::uuid.eq(member_uuid)))
.filter(event::org_uuid.eq(org_uuid))
.filter(event::event_date.between(start, end))
.filter(
event::user_uuid
.eq(users_organizations::user_uuid.nullable())
.or(event::act_user_uuid.eq(users_organizations::user_uuid.nullable())),
)
.filter(event::user_uuid.eq(users_organizations::user_uuid.nullable()).or(event::act_user_uuid.eq(users_organizations::user_uuid.nullable())))
.select(event::all_columns)
.order_by(event::event_date.desc())
.limit(Self::PAGE_SIZE)
.load::<Self>(conn)
.expect("Error filtering events")
})
.await
}}
}
pub async fn find_by_cipher_uuid(
@@ -321,7 +313,7 @@ impl Event {
end: &NaiveDateTime,
conn: &DbConn,
) -> Vec<Self> {
conn.run(move |conn| {
db_run! { conn: {
event::table
.filter(event::cipher_uuid.eq(cipher_uuid))
.filter(event::event_date.between(start, end))
@@ -329,19 +321,17 @@ impl Event {
.limit(Self::PAGE_SIZE)
.load::<Self>(conn)
.expect("Error filtering events")
})
.await
}}
}
pub async fn clean_events(conn: &DbConn) -> EmptyResult {
if let Some(days_to_retain) = CONFIG.events_days_retain() {
let dt = Utc::now().naive_utc() - TimeDelta::try_days(days_to_retain).unwrap();
conn.run(move |conn| {
db_run! { conn: {
diesel::delete(event::table.filter(event::event_date.lt(dt)))
.execute(conn)
.map_res("Error cleaning old events")
})
.await
.execute(conn)
.map_res("Error cleaning old events")
}}
} else {
Ok(())
}
+30 -32
View File
@@ -1,12 +1,6 @@
use diesel::prelude::*;
use crate::{
api::EmptyResult,
db::{DbConn, schema::favorites},
error::MapResult,
};
use super::{CipherId, User, UserId};
use crate::db::schema::favorites;
use diesel::prelude::*;
#[derive(Identifiable, Queryable, Insertable)]
#[diesel(table_name = favorites)]
@@ -16,18 +10,24 @@ pub struct Favorite {
pub cipher_uuid: CipherId,
}
use crate::db::DbConn;
use crate::api::EmptyResult;
use crate::error::MapResult;
impl Favorite {
// Returns whether the specified cipher is a favorite of the specified user.
pub async fn is_favorite(cipher_uuid: &CipherId, user_uuid: &UserId, conn: &DbConn) -> bool {
conn.run(move |conn| {
db_run! { conn: {
let query = favorites::table
.filter(favorites::cipher_uuid.eq(cipher_uuid))
.filter(favorites::user_uuid.eq(user_uuid))
.count();
query.first::<i64>(conn).ok().unwrap_or(0) != 0
})
.await
query.first::<i64>(conn)
.ok()
.unwrap_or(0) != 0
}}
}
// Sets whether the specified cipher is a favorite of the specified user.
@@ -41,26 +41,27 @@ impl Favorite {
match (old, new) {
(false, true) => {
User::update_uuid_revision(user_uuid, conn).await;
conn.run(move |conn| {
diesel::insert_into(favorites::table)
.values((favorites::user_uuid.eq(user_uuid), favorites::cipher_uuid.eq(cipher_uuid)))
.execute(conn)
.map_res("Error adding favorite")
})
.await
db_run! { conn: {
diesel::insert_into(favorites::table)
.values((
favorites::user_uuid.eq(user_uuid),
favorites::cipher_uuid.eq(cipher_uuid),
))
.execute(conn)
.map_res("Error adding favorite")
}}
}
(true, false) => {
User::update_uuid_revision(user_uuid, conn).await;
conn.run(move |conn| {
db_run! { conn: {
diesel::delete(
favorites::table
.filter(favorites::user_uuid.eq(user_uuid))
.filter(favorites::cipher_uuid.eq(cipher_uuid)),
.filter(favorites::cipher_uuid.eq(cipher_uuid))
)
.execute(conn)
.map_res("Error removing favorite")
})
.await
}}
}
// Otherwise, the favorite status is already what it should be.
_ => Ok(()),
@@ -69,34 +70,31 @@ impl Favorite {
// Delete all favorite entries associated with the specified cipher.
pub async fn delete_all_by_cipher(cipher_uuid: &CipherId, conn: &DbConn) -> EmptyResult {
conn.run(move |conn| {
db_run! { conn: {
diesel::delete(favorites::table.filter(favorites::cipher_uuid.eq(cipher_uuid)))
.execute(conn)
.map_res("Error removing favorites by cipher")
})
.await
}}
}
// Delete all favorite entries associated with the specified user.
pub async fn delete_all_by_user(user_uuid: &UserId, conn: &DbConn) -> EmptyResult {
conn.run(move |conn| {
db_run! { conn: {
diesel::delete(favorites::table.filter(favorites::user_uuid.eq(user_uuid)))
.execute(conn)
.map_res("Error removing favorites by user")
})
.await
}}
}
/// Return a vec with (cipher_uuid) this will only contain favorite flagged ciphers
/// This is used during a full sync so we only need one query for all favorite cipher matches.
pub async fn get_all_cipher_uuid_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<CipherId> {
conn.run(move |conn| {
db_run! { conn: {
favorites::table
.filter(favorites::user_uuid.eq(user_uuid))
.select(favorites::cipher_uuid)
.load::<CipherId>(conn)
.unwrap_or_default()
})
.await
}}
}
}
+31 -40
View File
@@ -1,19 +1,11 @@
use chrono::{NaiveDateTime, Utc};
use derive_more::{AsRef, Deref, Display, From};
use diesel::prelude::*;
use serde_json::Value;
use crate::{
api::EmptyResult,
db::{
DbConn,
schema::{folders, folders_ciphers},
},
error::MapResult,
};
use macros::UuidFromParam;
use super::{CipherId, User, UserId};
use crate::db::schema::{folders, folders_ciphers};
use diesel::prelude::*;
use macros::UuidFromParam;
#[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = folders)]
@@ -64,12 +56,17 @@ impl Folder {
impl FolderCipher {
pub fn new(folder_uuid: FolderId, cipher_uuid: CipherId) -> Self {
Self {
cipher_uuid,
folder_uuid,
cipher_uuid,
}
}
}
use crate::db::DbConn;
use crate::api::EmptyResult;
use crate::error::MapResult;
/// Database methods
impl Folder {
pub async fn save(&mut self, conn: &DbConn) -> EmptyResult {
@@ -110,12 +107,11 @@ impl Folder {
User::update_uuid_revision(&self.user_uuid, conn).await;
FolderCipher::delete_all_by_folder(&self.uuid, conn).await?;
conn.run(move |conn| {
db_run! { conn: {
diesel::delete(folders::table.filter(folders::uuid.eq(&self.uuid)))
.execute(conn)
.map_res("Error deleting folder")
})
.await
}}
}
pub async fn delete_all_by_user(user_uuid: &UserId, conn: &DbConn) -> EmptyResult {
@@ -126,21 +122,22 @@ impl Folder {
}
pub async fn find_by_uuid_and_user(uuid: &FolderId, user_uuid: &UserId, conn: &DbConn) -> Option<Self> {
conn.run(move |conn| {
db_run! { conn: {
folders::table
.filter(folders::uuid.eq(uuid))
.filter(folders::user_uuid.eq(user_uuid))
.first::<Self>(conn)
.ok()
})
.await
}}
}
pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
conn.run(move |conn| {
folders::table.filter(folders::user_uuid.eq(user_uuid)).load::<Self>(conn).expect("Error loading folders")
})
.await
db_run! { conn: {
folders::table
.filter(folders::user_uuid.eq(user_uuid))
.load::<Self>(conn)
.expect("Error loading folders")
}}
}
}
@@ -168,7 +165,7 @@ impl FolderCipher {
}
pub async fn delete(self, conn: &DbConn) -> EmptyResult {
conn.run(move |conn| {
db_run! { conn: {
diesel::delete(
folders_ciphers::table
.filter(folders_ciphers::cipher_uuid.eq(self.cipher_uuid))
@@ -176,26 +173,23 @@ impl FolderCipher {
)
.execute(conn)
.map_res("Error removing cipher from folder")
})
.await
}}
}
pub async fn delete_all_by_cipher(cipher_uuid: &CipherId, conn: &DbConn) -> EmptyResult {
conn.run(move |conn| {
db_run! { conn: {
diesel::delete(folders_ciphers::table.filter(folders_ciphers::cipher_uuid.eq(cipher_uuid)))
.execute(conn)
.map_res("Error removing cipher from folders")
})
.await
}}
}
pub async fn delete_all_by_folder(folder_uuid: &FolderId, conn: &DbConn) -> EmptyResult {
conn.run(move |conn| {
db_run! { conn: {
diesel::delete(folders_ciphers::table.filter(folders_ciphers::folder_uuid.eq(folder_uuid)))
.execute(conn)
.map_res("Error removing ciphers from folder")
})
.await
}}
}
pub async fn find_by_folder_and_cipher(
@@ -203,38 +197,35 @@ impl FolderCipher {
cipher_uuid: &CipherId,
conn: &DbConn,
) -> Option<Self> {
conn.run(move |conn| {
db_run! { conn: {
folders_ciphers::table
.filter(folders_ciphers::folder_uuid.eq(folder_uuid))
.filter(folders_ciphers::cipher_uuid.eq(cipher_uuid))
.first::<Self>(conn)
.ok()
})
.await
}}
}
pub async fn find_by_folder(folder_uuid: &FolderId, conn: &DbConn) -> Vec<Self> {
conn.run(move |conn| {
db_run! { conn: {
folders_ciphers::table
.filter(folders_ciphers::folder_uuid.eq(folder_uuid))
.load::<Self>(conn)
.expect("Error loading folders")
})
.await
}}
}
/// Return a vec with (cipher_uuid, folder_uuid)
/// This is used during a full sync so we only need one query for all folder matches.
pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<(CipherId, FolderId)> {
conn.run(move |conn| {
db_run! { conn: {
folders_ciphers::table
.inner_join(folders::table)
.filter(folders::user_uuid.eq(user_uuid))
.select(folders_ciphers::all_columns)
.load::<(CipherId, FolderId)>(conn)
.unwrap_or_default()
})
.await
}}
}
}
+122 -143
View File
@@ -1,19 +1,13 @@
use super::{CollectionId, Membership, MembershipId, OrganizationId, User, UserId};
use crate::api::EmptyResult;
use crate::db::schema::{collections, collections_groups, groups, groups_users, users_organizations};
use crate::db::DbConn;
use crate::error::MapResult;
use chrono::{NaiveDateTime, Utc};
use derive_more::{AsRef, Deref, Display, From};
use diesel::prelude::*;
use serde_json::Value;
use crate::{
api::EmptyResult,
db::{
DbConn,
schema::{collections, collections_groups, groups, groups_users, users_organizations},
},
error::MapResult,
};
use macros::UuidFromParam;
use super::{CollectionId, Membership, MembershipId, OrganizationId, User, UserId};
use serde_json::Value;
#[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = groups)]
@@ -203,31 +197,33 @@ impl Group {
}
pub async fn find_by_organization(org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> {
conn.run(move |conn| {
db_run! { conn: {
groups::table
.filter(groups::organizations_uuid.eq(org_uuid))
.load::<Self>(conn)
.expect("Error loading groups")
})
.await
}}
}
pub async fn count_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> i64 {
conn.run(move |conn| {
groups::table.filter(groups::organizations_uuid.eq(org_uuid)).count().first::<i64>(conn).ok().unwrap_or(0)
})
.await
db_run! { conn: {
groups::table
.filter(groups::organizations_uuid.eq(org_uuid))
.count()
.first::<i64>(conn)
.ok()
.unwrap_or(0)
}}
}
pub async fn find_by_uuid_and_org(uuid: &GroupId, org_uuid: &OrganizationId, conn: &DbConn) -> Option<Self> {
conn.run(move |conn| {
db_run! { conn: {
groups::table
.filter(groups::uuid.eq(uuid))
.filter(groups::organizations_uuid.eq(org_uuid))
.first::<Self>(conn)
.ok()
})
.await
}}
}
pub async fn find_by_external_id_and_org(
@@ -235,85 +231,77 @@ impl Group {
org_uuid: &OrganizationId,
conn: &DbConn,
) -> Option<Self> {
conn.run(move |conn| {
db_run! { conn: {
groups::table
.filter(groups::external_id.eq(external_id))
.filter(groups::organizations_uuid.eq(org_uuid))
.first::<Self>(conn)
.ok()
})
.await
}}
}
//Returns all organizations the user has full access to
pub async fn get_orgs_by_user_with_full_access(user_uuid: &UserId, conn: &DbConn) -> Vec<OrganizationId> {
conn.run(move |conn| {
db_run! { conn: {
groups_users::table
.inner_join(
users_organizations::table.on(users_organizations::uuid.eq(groups_users::users_organizations_uuid)),
)
.inner_join(
groups::table.on(groups::uuid
.eq(groups_users::groups_uuid)
.and(groups::organizations_uuid.eq(users_organizations::org_uuid))),
)
.inner_join(users_organizations::table.on(
users_organizations::uuid.eq(groups_users::users_organizations_uuid)
))
.inner_join(groups::table.on(groups::uuid.eq(groups_users::groups_uuid)
.and(groups::organizations_uuid.eq(users_organizations::org_uuid))
))
.filter(users_organizations::user_uuid.eq(user_uuid))
.filter(groups::access_all.eq(true))
.select(groups::organizations_uuid)
.distinct()
.load::<OrganizationId>(conn)
.expect("Error loading organization group full access information for user")
})
.await
}}
}
pub async fn is_in_full_access_group(user_uuid: &UserId, org_uuid: &OrganizationId, conn: &DbConn) -> bool {
conn.run(move |conn| {
db_run! { conn: {
groups::table
.inner_join(groups_users::table.on(groups_users::groups_uuid.eq(groups::uuid)))
.inner_join(
users_organizations::table.on(users_organizations::uuid.eq(groups_users::users_organizations_uuid)),
)
.inner_join(groups_users::table.on(
groups_users::groups_uuid.eq(groups::uuid)
))
.inner_join(users_organizations::table.on(
users_organizations::uuid.eq(groups_users::users_organizations_uuid)
))
.filter(users_organizations::user_uuid.eq(user_uuid))
.filter(groups::organizations_uuid.eq(org_uuid))
.filter(groups::access_all.eq(true))
.select(groups::access_all)
.first::<bool>(conn)
.unwrap_or_default()
})
.await
}}
}
pub async fn delete(&self, org_uuid: &OrganizationId, conn: &DbConn) -> EmptyResult {
CollectionGroup::delete_all_by_group(&self.uuid, org_uuid, conn).await?;
GroupUser::delete_all_by_group(&self.uuid, org_uuid, conn).await?;
conn.run(move |conn| {
db_run! { conn: {
diesel::delete(groups::table.filter(groups::uuid.eq(&self.uuid)))
.execute(conn)
.map_res("Error deleting group")
})
.await
}}
}
pub async fn update_revision(uuid: &GroupId, conn: &DbConn) {
if let Err(e) = Self::update_revision_impl(uuid, &Utc::now().naive_utc(), conn).await {
if let Err(e) = Self::_update_revision(uuid, &Utc::now().naive_utc(), conn).await {
warn!("Failed to update revision for {uuid}: {e:#?}");
}
}
async fn update_revision_impl(uuid: &GroupId, date: &NaiveDateTime, conn: &DbConn) -> EmptyResult {
conn.run(move |conn| {
crate::util::retry(
|| {
diesel::update(groups::table.filter(groups::uuid.eq(uuid)))
.set(groups::revision_date.eq(date))
.execute(conn)
},
10,
)
async fn _update_revision(uuid: &GroupId, date: &NaiveDateTime, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
crate::util::retry(|| {
diesel::update(groups::table.filter(groups::uuid.eq(uuid)))
.set(groups::revision_date.eq(date))
.execute(conn)
}, 10)
.map_res("Error updating group revision")
})
.await
}}
}
}
@@ -378,63 +366,60 @@ impl CollectionGroup {
}
pub async fn find_by_group(group_uuid: &GroupId, org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> {
conn.run(move |conn| {
db_run! { conn: {
collections_groups::table
.inner_join(groups::table.on(groups::uuid.eq(collections_groups::groups_uuid)))
.inner_join(
collections::table.on(collections::uuid
.eq(collections_groups::collections_uuid)
.and(collections::org_uuid.eq(groups::organizations_uuid))),
)
.inner_join(groups::table.on(
groups::uuid.eq(collections_groups::groups_uuid)
))
.inner_join(collections::table.on(
collections::uuid.eq(collections_groups::collections_uuid)
.and(collections::org_uuid.eq(groups::organizations_uuid))
))
.filter(collections_groups::groups_uuid.eq(group_uuid))
.filter(collections::org_uuid.eq(org_uuid))
.select(collections_groups::all_columns)
.load::<Self>(conn)
.expect("Error loading collection groups")
})
.await
}}
}
pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
conn.run(move |conn| {
db_run! { conn: {
collections_groups::table
.inner_join(groups_users::table.on(groups_users::groups_uuid.eq(collections_groups::groups_uuid)))
.inner_join(
users_organizations::table.on(users_organizations::uuid.eq(groups_users::users_organizations_uuid)),
)
.inner_join(
groups::table.on(groups::uuid
.eq(collections_groups::groups_uuid)
.and(groups::organizations_uuid.eq(users_organizations::org_uuid))),
)
.inner_join(
collections::table.on(collections::uuid
.eq(collections_groups::collections_uuid)
.and(collections::org_uuid.eq(groups::organizations_uuid))),
)
.inner_join(groups_users::table.on(
groups_users::groups_uuid.eq(collections_groups::groups_uuid)
))
.inner_join(users_organizations::table.on(
users_organizations::uuid.eq(groups_users::users_organizations_uuid)
))
.inner_join(groups::table.on(groups::uuid.eq(collections_groups::groups_uuid)
.and(groups::organizations_uuid.eq(users_organizations::org_uuid))
))
.inner_join(collections::table.on(
collections::uuid.eq(collections_groups::collections_uuid)
.and(collections::org_uuid.eq(groups::organizations_uuid))
))
.filter(users_organizations::user_uuid.eq(user_uuid))
.select(collections_groups::all_columns)
.load::<Self>(conn)
.expect("Error loading user collection groups")
})
.await
}}
}
pub async fn find_by_collection(collection_uuid: &CollectionId, conn: &DbConn) -> Vec<Self> {
conn.run(move |conn| {
db_run! { conn: {
collections_groups::table
.filter(collections_groups::collections_uuid.eq(collection_uuid))
.inner_join(collections::table.on(collections::uuid.eq(collections_groups::collections_uuid)))
.inner_join(
groups::table.on(groups::uuid
.eq(collections_groups::groups_uuid)
.and(groups::organizations_uuid.eq(collections::org_uuid))),
)
.inner_join(collections::table.on(
collections::uuid.eq(collections_groups::collections_uuid)
))
.inner_join(groups::table.on(groups::uuid.eq(collections_groups::groups_uuid)
.and(groups::organizations_uuid.eq(collections::org_uuid))
))
.select(collections_groups::all_columns)
.load::<Self>(conn)
.expect("Error loading collection groups")
})
.await
}}
}
pub async fn delete(&self, org_uuid: &OrganizationId, conn: &DbConn) -> EmptyResult {
@@ -443,14 +428,13 @@ impl CollectionGroup {
group_user.update_user_revision(conn).await;
}
conn.run(move |conn| {
db_run! { conn: {
diesel::delete(collections_groups::table)
.filter(collections_groups::collections_uuid.eq(&self.collections_uuid))
.filter(collections_groups::groups_uuid.eq(&self.groups_uuid))
.execute(conn)
.map_res("Error deleting collection group")
})
.await
}}
}
pub async fn delete_all_by_group(group_uuid: &GroupId, org_uuid: &OrganizationId, conn: &DbConn) -> EmptyResult {
@@ -459,13 +443,12 @@ impl CollectionGroup {
group_user.update_user_revision(conn).await;
}
conn.run(move |conn| {
db_run! { conn: {
diesel::delete(collections_groups::table)
.filter(collections_groups::groups_uuid.eq(group_uuid))
.execute(conn)
.map_res("Error deleting collection group")
})
.await
}}
}
pub async fn delete_all_by_collection(
@@ -481,13 +464,12 @@ impl CollectionGroup {
}
}
conn.run(move |conn| {
db_run! { conn: {
diesel::delete(collections_groups::table)
.filter(collections_groups::collections_uuid.eq(collection_uuid))
.execute(conn)
.map_res("Error deleting collection group")
})
.await
}}
}
}
@@ -539,31 +521,30 @@ impl GroupUser {
}
pub async fn find_by_group(group_uuid: &GroupId, org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> {
conn.run(move |conn| {
db_run! { conn: {
groups_users::table
.inner_join(groups::table.on(groups::uuid.eq(groups_users::groups_uuid)))
.inner_join(
users_organizations::table.on(users_organizations::uuid
.eq(groups_users::users_organizations_uuid)
.and(users_organizations::org_uuid.eq(groups::organizations_uuid))),
)
.inner_join(groups::table.on(
groups::uuid.eq(groups_users::groups_uuid)
))
.inner_join(users_organizations::table.on(
users_organizations::uuid.eq(groups_users::users_organizations_uuid)
.and(users_organizations::org_uuid.eq(groups::organizations_uuid))
))
.filter(groups_users::groups_uuid.eq(group_uuid))
.filter(groups::organizations_uuid.eq(org_uuid))
.select(groups_users::all_columns)
.load::<Self>(conn)
.expect("Error loading group users")
})
.await
}}
}
pub async fn find_by_member(member_uuid: &MembershipId, conn: &DbConn) -> Vec<Self> {
conn.run(move |conn| {
db_run! { conn: {
groups_users::table
.filter(groups_users::users_organizations_uuid.eq(member_uuid))
.load::<Self>(conn)
.expect("Error loading groups for user")
})
.await
}}
}
pub async fn has_access_to_collection_by_member(
@@ -571,23 +552,24 @@ impl GroupUser {
member_uuid: &MembershipId,
conn: &DbConn,
) -> bool {
conn.run(move |conn| {
db_run! { conn: {
groups_users::table
.inner_join(collections_groups::table.on(collections_groups::groups_uuid.eq(groups_users::groups_uuid)))
.inner_join(groups::table.on(groups::uuid.eq(groups_users::groups_uuid)))
.inner_join(
collections::table.on(collections::uuid
.eq(collections_groups::collections_uuid)
.and(collections::org_uuid.eq(groups::organizations_uuid))),
)
.inner_join(collections_groups::table.on(
collections_groups::groups_uuid.eq(groups_users::groups_uuid)
))
.inner_join(groups::table.on(
groups::uuid.eq(groups_users::groups_uuid)
))
.inner_join(collections::table.on(
collections::uuid.eq(collections_groups::collections_uuid)
.and(collections::org_uuid.eq(groups::organizations_uuid))
))
.filter(collections_groups::collections_uuid.eq(collection_uuid))
.filter(groups_users::users_organizations_uuid.eq(member_uuid))
.count()
.first::<i64>(conn)
.unwrap_or(0)
!= 0
})
.await
.unwrap_or(0) != 0
}}
}
pub async fn has_full_access_by_member(
@@ -595,18 +577,18 @@ impl GroupUser {
member_uuid: &MembershipId,
conn: &DbConn,
) -> bool {
conn.run(move |conn| {
db_run! { conn: {
groups_users::table
.inner_join(groups::table.on(groups::uuid.eq(groups_users::groups_uuid)))
.inner_join(groups::table.on(
groups::uuid.eq(groups_users::groups_uuid)
))
.filter(groups::organizations_uuid.eq(org_uuid))
.filter(groups::access_all.eq(true))
.filter(groups_users::users_organizations_uuid.eq(member_uuid))
.count()
.first::<i64>(conn)
.unwrap_or(0)
!= 0
})
.await
.unwrap_or(0) != 0
}}
}
pub async fn update_user_revision(&self, conn: &DbConn) {
@@ -624,16 +606,15 @@ impl GroupUser {
match Membership::find_by_uuid(member_uuid, conn).await {
Some(member) => User::update_uuid_revision(&member.user_uuid, conn).await,
None => warn!("Member could not be found!"),
}
};
conn.run(move |conn| {
db_run! { conn: {
diesel::delete(groups_users::table)
.filter(groups_users::groups_uuid.eq(group_uuid))
.filter(groups_users::users_organizations_uuid.eq(member_uuid))
.execute(conn)
.map_res("Error deleting group users")
})
.await
}}
}
pub async fn delete_all_by_group(group_uuid: &GroupId, org_uuid: &OrganizationId, conn: &DbConn) -> EmptyResult {
@@ -642,13 +623,12 @@ impl GroupUser {
group_user.update_user_revision(conn).await;
}
conn.run(move |conn| {
db_run! { conn: {
diesel::delete(groups_users::table)
.filter(groups_users::groups_uuid.eq(group_uuid))
.execute(conn)
.map_res("Error deleting group users")
})
.await
}}
}
pub async fn delete_all_by_member(member_uuid: &MembershipId, conn: &DbConn) -> EmptyResult {
@@ -657,13 +637,12 @@ impl GroupUser {
None => warn!("Member could not be found!"),
}
conn.run(move |conn| {
db_run! { conn: {
diesel::delete(groups_users::table)
.filter(groups_users::users_organizations_uuid.eq(member_uuid))
.execute(conn)
.map_res("Error deleting user groups")
})
.await
}}
}
}
+3 -5
View File
@@ -1,4 +1,3 @@
mod archive;
mod attachment;
mod auth_request;
mod cipher;
@@ -18,12 +17,11 @@ mod two_factor_duo_context;
mod two_factor_incomplete;
mod user;
pub use self::archive::Archive;
pub use self::attachment::{Attachment, AttachmentId};
pub use self::auth_request::{AuthRequest, AuthRequestId};
pub use self::cipher::{Cipher, CipherId, RepromptType};
pub use self::collection::{Collection, CollectionCipher, CollectionId, CollectionUser};
pub use self::device::{Device, DeviceId, DeviceType, DeviceWithAuthRequest, PushId};
pub use self::device::{Device, DeviceId, DeviceType, PushId};
pub use self::emergency_access::{EmergencyAccess, EmergencyAccessId, EmergencyAccessStatus, EmergencyAccessType};
pub use self::event::{Event, EventType};
pub use self::favorite::Favorite;
@@ -35,10 +33,10 @@ pub use self::organization::{
OrganizationId,
};
pub use self::send::{
Send, SendType,
id::{SendFileId, SendId},
Send, SendType,
};
pub use self::sso_auth::{OIDCAuthenticatedUser, OIDCCodeResponseError, SsoAuth};
pub use self::sso_auth::{OIDCAuthenticatedUser, OIDCCodeWrapper, SsoAuth};
pub use self::two_factor::{TwoFactor, TwoFactorType};
pub use self::two_factor_duo_context::TwoFactorDuoContext;
pub use self::two_factor_incomplete::TwoFactorIncomplete;
+70 -74
View File
@@ -1,17 +1,14 @@
use derive_more::{AsRef, From};
use diesel::prelude::*;
use serde::Deserialize;
use serde_json::Value;
use crate::{
CONFIG,
api::{EmptyResult, core::two_factor},
db::{
DbConn,
schema::{org_policies, users_organizations},
},
error::MapResult,
};
use crate::api::core::two_factor;
use crate::api::EmptyResult;
use crate::db::schema::{org_policies, users_organizations};
use crate::db::DbConn;
use crate::error::MapResult;
use crate::CONFIG;
use diesel::prelude::*;
use super::{Membership, MembershipId, MembershipStatus, MembershipType, OrganizationId, TwoFactor, UserId};
@@ -151,38 +148,37 @@ impl OrgPolicy {
}
pub async fn delete(self, conn: &DbConn) -> EmptyResult {
conn.run(move |conn| {
db_run! { conn: {
diesel::delete(org_policies::table.filter(org_policies::uuid.eq(self.uuid)))
.execute(conn)
.map_res("Error deleting org_policy")
})
.await
}}
}
pub async fn find_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> {
conn.run(move |conn| {
db_run! { conn: {
org_policies::table
.filter(org_policies::org_uuid.eq(org_uuid))
.load::<Self>(conn)
.expect("Error loading org_policy")
})
.await
}}
}
pub async fn find_confirmed_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
conn.run(move |conn| {
db_run! { conn: {
org_policies::table
.inner_join(
users_organizations::table.on(users_organizations::org_uuid
.eq(org_policies::org_uuid)
.and(users_organizations::user_uuid.eq(user_uuid))),
users_organizations::table.on(
users_organizations::org_uuid.eq(org_policies::org_uuid)
.and(users_organizations::user_uuid.eq(user_uuid)))
)
.filter(
users_organizations::status.eq(MembershipStatus::Confirmed as i32)
)
.filter(users_organizations::status.eq(MembershipStatus::Confirmed as i32))
.select(org_policies::all_columns)
.load::<Self>(conn)
.expect("Error loading org_policy")
})
.await
}}
}
pub async fn find_by_org_and_type(
@@ -190,23 +186,21 @@ impl OrgPolicy {
policy_type: OrgPolicyType,
conn: &DbConn,
) -> Option<Self> {
conn.run(move |conn| {
db_run! { conn: {
org_policies::table
.filter(org_policies::org_uuid.eq(org_uuid))
.filter(org_policies::atype.eq(policy_type as i32))
.first::<Self>(conn)
.ok()
})
.await
}}
}
pub async fn delete_all_by_organization(org_uuid: &OrganizationId, conn: &DbConn) -> EmptyResult {
conn.run(move |conn| {
db_run! { conn: {
diesel::delete(org_policies::table.filter(org_policies::org_uuid.eq(org_uuid)))
.execute(conn)
.map_res("Error deleting org_policy")
})
.await
}}
}
pub async fn find_accepted_and_confirmed_by_user_and_active_policy(
@@ -214,22 +208,25 @@ impl OrgPolicy {
policy_type: OrgPolicyType,
conn: &DbConn,
) -> Vec<Self> {
conn.run(move |conn| {
db_run! { conn: {
org_policies::table
.inner_join(
users_organizations::table.on(users_organizations::org_uuid
.eq(org_policies::org_uuid)
.and(users_organizations::user_uuid.eq(user_uuid))),
users_organizations::table.on(
users_organizations::org_uuid.eq(org_policies::org_uuid)
.and(users_organizations::user_uuid.eq(user_uuid)))
)
.filter(
users_organizations::status.eq(MembershipStatus::Accepted as i32)
)
.or_filter(
users_organizations::status.eq(MembershipStatus::Confirmed as i32)
)
.filter(users_organizations::status.eq(MembershipStatus::Accepted as i32))
.or_filter(users_organizations::status.eq(MembershipStatus::Confirmed as i32))
.filter(org_policies::atype.eq(policy_type as i32))
.filter(org_policies::enabled.eq(true))
.select(org_policies::all_columns)
.load::<Self>(conn)
.expect("Error loading org_policy")
})
.await
}}
}
pub async fn find_confirmed_by_user_and_active_policy(
@@ -237,21 +234,22 @@ impl OrgPolicy {
policy_type: OrgPolicyType,
conn: &DbConn,
) -> Vec<Self> {
conn.run(move |conn| {
db_run! { conn: {
org_policies::table
.inner_join(
users_organizations::table.on(users_organizations::org_uuid
.eq(org_policies::org_uuid)
.and(users_organizations::user_uuid.eq(user_uuid))),
users_organizations::table.on(
users_organizations::org_uuid.eq(org_policies::org_uuid)
.and(users_organizations::user_uuid.eq(user_uuid)))
)
.filter(
users_organizations::status.eq(MembershipStatus::Confirmed as i32)
)
.filter(users_organizations::status.eq(MembershipStatus::Confirmed as i32))
.filter(org_policies::atype.eq(policy_type as i32))
.filter(org_policies::enabled.eq(true))
.select(org_policies::all_columns)
.load::<Self>(conn)
.expect("Error loading org_policy")
})
.await
}}
}
/// Returns true if the user belongs to an org that has enabled the specified policy type,
@@ -271,10 +269,10 @@ impl OrgPolicy {
continue;
}
if let Some(user) = Membership::find_confirmed_by_user_and_org(user_uuid, &policy.org_uuid, conn).await
&& user.atype < MembershipType::Admin
{
return true;
if let Some(user) = Membership::find_confirmed_by_user_and_org(user_uuid, &policy.org_uuid, conn).await {
if user.atype < MembershipType::Admin {
return true;
}
}
}
false
@@ -284,13 +282,13 @@ impl OrgPolicy {
if m.atype < MembershipType::Admin && m.status > (MembershipStatus::Invited as i32) {
// Enforce TwoFactor/TwoStep login
if let Some(p) = Self::find_by_org_and_type(&m.org_uuid, OrgPolicyType::TwoFactorAuthentication, conn).await
&& p.enabled
&& TwoFactor::find_by_user(&m.user_uuid, conn).await.is_empty()
{
if CONFIG.email_2fa_auto_fallback() {
two_factor::email::find_and_activate_email_2fa(&m.user_uuid, conn).await?;
} else {
err!(format!("Cannot {} because 2FA is required (membership {})", action, m.uuid));
if p.enabled && TwoFactor::find_by_user(&m.user_uuid, conn).await.is_empty() {
if CONFIG.email_2fa_auto_fallback() {
two_factor::email::find_and_activate_email_2fa(&m.user_uuid, conn).await?;
} else {
err!(format!("Cannot {} because 2FA is required (membership {})", action, m.uuid));
}
}
}
@@ -302,14 +300,12 @@ impl OrgPolicy {
));
}
if let Some(p) = Self::find_by_org_and_type(&m.org_uuid, OrgPolicyType::SingleOrg, conn).await
&& p.enabled
&& Membership::count_accepted_and_confirmed_by_user(&m.user_uuid, &m.org_uuid, conn).await > 0
{
err!(format!(
"Cannot {} because the organization policy forbids being part of other organization (membership {})",
action, m.uuid
));
if let Some(p) = Self::find_by_org_and_type(&m.org_uuid, OrgPolicyType::SingleOrg, conn).await {
if p.enabled
&& Membership::count_accepted_and_confirmed_by_user(&m.user_uuid, &m.org_uuid, conn).await > 0
{
err!(format!("Cannot {} because the organization policy forbids being part of other organization (membership {})", action, m.uuid));
}
}
}
@@ -336,16 +332,16 @@ impl OrgPolicy {
for policy in
OrgPolicy::find_confirmed_by_user_and_active_policy(user_uuid, OrgPolicyType::SendOptions, conn).await
{
if let Some(user) = Membership::find_confirmed_by_user_and_org(user_uuid, &policy.org_uuid, conn).await
&& user.atype < MembershipType::Admin
{
match serde_json::from_str::<SendOptionsPolicyData>(&policy.data) {
Ok(opts) => {
if opts.disable_hide_email {
return true;
if let Some(user) = Membership::find_confirmed_by_user_and_org(user_uuid, &policy.org_uuid, conn).await {
if user.atype < MembershipType::Admin {
match serde_json::from_str::<SendOptionsPolicyData>(&policy.data) {
Ok(opts) => {
if opts.disable_hide_email {
return true;
}
}
_ => error!("Failed to deserialize SendOptionsPolicyData: {}", policy.data),
}
_ => error!("Failed to deserialize SendOptionsPolicyData: {}", policy.data),
}
}
}
@@ -353,10 +349,10 @@ impl OrgPolicy {
}
pub async fn is_enabled_for_member(member_uuid: &MembershipId, policy_type: OrgPolicyType, conn: &DbConn) -> bool {
if let Some(member) = Membership::find_by_uuid(member_uuid, conn).await
&& let Some(policy) = OrgPolicy::find_by_org_and_type(&member.org_uuid, policy_type, conn).await
{
return policy.enabled;
if let Some(member) = Membership::find_by_uuid(member_uuid, conn).await {
if let Some(policy) = OrgPolicy::find_by_org_and_type(&member.org_uuid, policy_type, conn).await {
return policy.enabled;
}
}
false
}
+183 -200
View File
@@ -1,32 +1,23 @@
use std::{
cmp::Ordering,
collections::{HashMap, HashSet},
};
use chrono::{NaiveDateTime, Utc};
use derive_more::{AsRef, Deref, Display, From};
use diesel::prelude::*;
use num_traits::FromPrimitive;
use serde_json::Value;
use crate::{
CONFIG,
api::EmptyResult,
db::{
DbConn,
schema::{
ciphers, ciphers_collections, collections_groups, groups, groups_users, org_policies, organization_api_key,
organizations, users, users_collections, users_organizations,
},
},
error::MapResult,
use std::{
cmp::Ordering,
collections::{HashMap, HashSet},
};
use macros::UuidFromParam;
use super::{
Cipher, CipherId, Collection, CollectionGroup, CollectionId, CollectionUser, Group, GroupId, GroupUser, OrgPolicy,
CipherId, Collection, CollectionGroup, CollectionId, CollectionUser, Group, GroupId, GroupUser, OrgPolicy,
OrgPolicyType, TwoFactor, User, UserId,
};
use crate::db::schema::{
ciphers, ciphers_collections, collections_groups, groups, groups_users, org_policies, organization_api_key,
organizations, users, users_collections, users_organizations,
};
use crate::CONFIG;
use macros::UuidFromParam;
#[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = organizations)]
@@ -102,10 +93,6 @@ pub enum MembershipType {
impl MembershipType {
pub fn from_str(s: &str) -> Option<Self> {
#[expect(
clippy::match_same_arms,
reason = "Specifically define `4|Custom` since this is a hack, not a default"
)]
match s {
"0" | "Owner" => Some(MembershipType::Owner),
"1" | "Admin" => Some(MembershipType::Admin),
@@ -334,6 +321,11 @@ impl OrganizationApiKey {
}
}
use crate::db::DbConn;
use crate::api::EmptyResult;
use crate::error::MapResult;
/// Database methods
impl Organization {
pub async fn save(&self, conn: &DbConn) -> EmptyResult {
@@ -341,7 +333,7 @@ impl Organization {
err!(format!("BillingEmail {} is not a valid email address", self.billing_email))
}
for member in &Membership::find_by_org(&self.uuid, conn).await {
for member in Membership::find_by_org(&self.uuid, conn).await.iter() {
User::update_uuid_revision(&member.user_uuid, conn).await;
}
@@ -377,6 +369,8 @@ impl Organization {
}
pub async fn delete(self, conn: &DbConn) -> EmptyResult {
use super::{Cipher, Collection};
Cipher::delete_all_by_organization(&self.uuid, conn).await?;
Collection::delete_all_by_organization(&self.uuid, conn).await?;
Membership::delete_all_by_organization(&self.uuid, conn).await?;
@@ -384,30 +378,43 @@ impl Organization {
Group::delete_all_by_organization(&self.uuid, conn).await?;
OrganizationApiKey::delete_all_by_organization(&self.uuid, conn).await?;
conn.run(move |conn| {
db_run! { conn: {
diesel::delete(organizations::table.filter(organizations::uuid.eq(self.uuid)))
.execute(conn)
.map_res("Error saving organization")
})
.await
}}
}
pub async fn find_by_uuid(uuid: &OrganizationId, conn: &DbConn) -> Option<Self> {
conn.run(move |conn| organizations::table.filter(organizations::uuid.eq(uuid)).first::<Self>(conn).ok()).await
db_run! { conn: {
organizations::table
.filter(organizations::uuid.eq(uuid))
.first::<Self>(conn)
.ok()
}}
}
pub async fn find_by_name(name: &str, conn: &DbConn) -> Option<Self> {
conn.run(move |conn| organizations::table.filter(organizations::name.eq(name)).first::<Self>(conn).ok()).await
db_run! { conn: {
organizations::table
.filter(organizations::name.eq(name))
.first::<Self>(conn)
.ok()
}}
}
pub async fn get_all(conn: &DbConn) -> Vec<Self> {
conn.run(move |conn| organizations::table.load::<Self>(conn).expect("Error loading organizations")).await
db_run! { conn: {
organizations::table
.load::<Self>(conn)
.expect("Error loading organizations")
}}
}
pub async fn find_main_org_user_email(user_email: &str, conn: &DbConn) -> Option<Self> {
let lower_mail = user_email.to_lowercase();
conn.run(move |conn| {
db_run! { conn: {
organizations::table
.inner_join(users_organizations::table.on(users_organizations::org_uuid.eq(organizations::uuid)))
.inner_join(users::table.on(users::uuid.eq(users_organizations::user_uuid)))
@@ -417,14 +424,13 @@ impl Organization {
.select(organizations::all_columns)
.first::<Self>(conn)
.ok()
})
.await
}}
}
pub async fn find_org_user_email(user_email: &str, conn: &DbConn) -> Vec<Self> {
let lower_mail = user_email.to_lowercase();
conn.run(move |conn| {
db_run! { conn: {
organizations::table
.inner_join(users_organizations::table.on(users_organizations::org_uuid.eq(organizations::uuid)))
.inner_join(users::table.on(users::uuid.eq(users_organizations::user_uuid)))
@@ -434,8 +440,7 @@ impl Organization {
.select(organizations::all_columns)
.load::<Self>(conn)
.expect("Error loading user orgs")
})
.await
}}
}
}
@@ -775,12 +780,11 @@ impl Membership {
CollectionUser::delete_all_by_user_and_org(&self.user_uuid, &self.org_uuid, conn).await?;
GroupUser::delete_all_by_member(&self.uuid, conn).await?;
conn.run(move |conn| {
db_run! { conn: {
diesel::delete(users_organizations::table.filter(users_organizations::uuid.eq(self.uuid)))
.execute(conn)
.map_res("Error removing user from organization")
})
.await
}}
}
pub async fn delete_all_by_organization(org_uuid: &OrganizationId, conn: &DbConn) -> EmptyResult {
@@ -798,10 +802,10 @@ impl Membership {
}
pub async fn find_by_email_and_org(email: &str, org_uuid: &OrganizationId, conn: &DbConn) -> Option<Membership> {
if let Some(user) = User::find_by_mail(email, conn).await
&& let Some(member) = Membership::find_by_user_and_org(&user.uuid, org_uuid, conn).await
{
return Some(member);
if let Some(user) = User::find_by_mail(email, conn).await {
if let Some(member) = Membership::find_by_user_and_org(&user.uuid, org_uuid, conn).await {
return Some(member);
}
}
None
@@ -820,67 +824,64 @@ impl Membership {
}
pub async fn find_by_uuid(uuid: &MembershipId, conn: &DbConn) -> Option<Self> {
conn.run(move |conn| {
users_organizations::table.filter(users_organizations::uuid.eq(uuid)).first::<Self>(conn).ok()
})
.await
db_run! { conn: {
users_organizations::table
.filter(users_organizations::uuid.eq(uuid))
.first::<Self>(conn)
.ok()
}}
}
pub async fn find_by_uuid_and_org(uuid: &MembershipId, org_uuid: &OrganizationId, conn: &DbConn) -> Option<Self> {
conn.run(move |conn| {
db_run! { conn: {
users_organizations::table
.filter(users_organizations::uuid.eq(uuid))
.filter(users_organizations::org_uuid.eq(org_uuid))
.first::<Self>(conn)
.ok()
})
.await
}}
}
pub async fn find_confirmed_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
conn.run(move |conn| {
db_run! { conn: {
users_organizations::table
.filter(users_organizations::user_uuid.eq(user_uuid))
.filter(users_organizations::status.eq(MembershipStatus::Confirmed as i32))
.load::<Self>(conn)
.unwrap_or_default()
})
.await
}}
}
pub async fn find_invited_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
conn.run(move |conn| {
db_run! { conn: {
users_organizations::table
.filter(users_organizations::user_uuid.eq(user_uuid))
.filter(users_organizations::status.eq(MembershipStatus::Invited as i32))
.load::<Self>(conn)
.unwrap_or_default()
})
.await
}}
}
// Should be used only when email are disabled.
// In Organizations::send_invite status is set to Accepted only if the user has a password.
pub async fn accept_user_invitations(user_uuid: &UserId, conn: &DbConn) -> EmptyResult {
conn.run(move |conn| {
db_run! { conn: {
diesel::update(users_organizations::table)
.filter(users_organizations::user_uuid.eq(user_uuid))
.filter(users_organizations::status.eq(MembershipStatus::Invited as i32))
.set(users_organizations::status.eq(MembershipStatus::Accepted as i32))
.execute(conn)
.map_res("Error confirming invitations")
})
.await
}}
}
pub async fn find_any_state_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
conn.run(move |conn| {
db_run! { conn: {
users_organizations::table
.filter(users_organizations::user_uuid.eq(user_uuid))
.load::<Self>(conn)
.unwrap_or_default()
})
.await
}}
}
pub async fn count_accepted_and_confirmed_by_user(
@@ -888,83 +889,70 @@ impl Membership {
excluded_org: &OrganizationId,
conn: &DbConn,
) -> i64 {
conn.run(move |conn| {
db_run! { conn: {
users_organizations::table
.filter(users_organizations::user_uuid.eq(user_uuid))
.filter(users_organizations::org_uuid.ne(excluded_org))
.filter(
users_organizations::status
.eq(MembershipStatus::Accepted as i32)
.or(users_organizations::status.eq(MembershipStatus::Confirmed as i32)),
)
.filter(users_organizations::status.eq(MembershipStatus::Accepted as i32).or(users_organizations::status.eq(MembershipStatus::Confirmed as i32)))
.count()
.first::<i64>(conn)
.unwrap_or(0)
})
.await
}}
}
pub async fn find_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> {
conn.run(move |conn| {
db_run! { conn: {
users_organizations::table
.filter(users_organizations::org_uuid.eq(org_uuid))
.load::<Self>(conn)
.expect("Error loading user organizations")
})
.await
}}
}
pub async fn find_confirmed_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> {
conn.run(move |conn| {
db_run! { conn: {
users_organizations::table
.filter(users_organizations::org_uuid.eq(org_uuid))
.filter(users_organizations::status.eq(MembershipStatus::Confirmed as i32))
.load::<Self>(conn)
.unwrap_or_default()
})
.await
}}
}
// Get all users which are either owner or admin, or a manager which can manage/access all
pub async fn find_confirmed_and_manage_all_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> {
conn.run(move |conn| {
db_run! { conn: {
users_organizations::table
.filter(users_organizations::org_uuid.eq(org_uuid))
.filter(users_organizations::status.eq(MembershipStatus::Confirmed as i32))
.filter(
users_organizations::atype
.eq_any(vec![MembershipType::Owner as i32, MembershipType::Admin as i32])
.or(users_organizations::atype
.eq(MembershipType::Manager as i32)
.and(users_organizations::access_all.eq(true))),
users_organizations::atype.eq_any(vec![MembershipType::Owner as i32, MembershipType::Admin as i32])
.or(users_organizations::atype.eq(MembershipType::Manager as i32).and(users_organizations::access_all.eq(true)))
)
.load::<Self>(conn)
.unwrap_or_default()
})
.await
}}
}
pub async fn count_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> i64 {
conn.run(move |conn| {
db_run! { conn: {
users_organizations::table
.filter(users_organizations::org_uuid.eq(org_uuid))
.count()
.first::<i64>(conn)
.ok()
.unwrap_or(0)
})
.await
}}
}
pub async fn find_by_org_and_type(org_uuid: &OrganizationId, atype: MembershipType, conn: &DbConn) -> Vec<Self> {
conn.run(move |conn| {
db_run! { conn: {
users_organizations::table
.filter(users_organizations::org_uuid.eq(org_uuid))
.filter(users_organizations::atype.eq(atype as i32))
.load::<Self>(conn)
.expect("Error loading user organizations")
})
.await
}}
}
pub async fn count_confirmed_by_org_and_type(
@@ -972,7 +960,7 @@ impl Membership {
atype: MembershipType,
conn: &DbConn,
) -> i64 {
conn.run(move |conn| {
db_run! { conn: {
users_organizations::table
.filter(users_organizations::org_uuid.eq(org_uuid))
.filter(users_organizations::atype.eq(atype as i32))
@@ -980,19 +968,17 @@ impl Membership {
.count()
.first::<i64>(conn)
.unwrap_or(0)
})
.await
}}
}
pub async fn find_by_user_and_org(user_uuid: &UserId, org_uuid: &OrganizationId, conn: &DbConn) -> Option<Self> {
conn.run(move |conn| {
db_run! { conn: {
users_organizations::table
.filter(users_organizations::user_uuid.eq(user_uuid))
.filter(users_organizations::org_uuid.eq(org_uuid))
.first::<Self>(conn)
.ok()
})
.await
}}
}
pub async fn find_confirmed_by_user_and_org(
@@ -1000,76 +986,78 @@ impl Membership {
org_uuid: &OrganizationId,
conn: &DbConn,
) -> Option<Self> {
conn.run(move |conn| {
db_run! { conn: {
users_organizations::table
.filter(users_organizations::user_uuid.eq(user_uuid))
.filter(users_organizations::org_uuid.eq(org_uuid))
.filter(users_organizations::status.eq(MembershipStatus::Confirmed as i32))
.filter(
users_organizations::status.eq(MembershipStatus::Confirmed as i32)
)
.first::<Self>(conn)
.ok()
})
.await
}}
}
pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
conn.run(move |conn| {
db_run! { conn: {
users_organizations::table
.filter(users_organizations::user_uuid.eq(user_uuid))
.load::<Self>(conn)
.expect("Error loading user organizations")
})
.await
}}
}
pub async fn get_orgs_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<OrganizationId> {
conn.run(move |conn| {
db_run! { conn: {
users_organizations::table
.filter(users_organizations::user_uuid.eq(user_uuid))
.select(users_organizations::org_uuid)
.load::<OrganizationId>(conn)
.unwrap_or_default()
})
.await
}}
}
pub async fn find_by_user_and_policy(user_uuid: &UserId, policy_type: OrgPolicyType, conn: &DbConn) -> Vec<Self> {
conn.run(move |conn| {
db_run! { conn: {
users_organizations::table
.inner_join(
org_policies::table.on(org_policies::org_uuid
.eq(users_organizations::org_uuid)
.and(users_organizations::user_uuid.eq(user_uuid))
.and(org_policies::atype.eq(policy_type as i32))
.and(org_policies::enabled.eq(true))),
org_policies::table.on(
org_policies::org_uuid.eq(users_organizations::org_uuid)
.and(users_organizations::user_uuid.eq(user_uuid))
.and(org_policies::atype.eq(policy_type as i32))
.and(org_policies::enabled.eq(true)))
)
.filter(
users_organizations::status.eq(MembershipStatus::Confirmed as i32)
)
.filter(users_organizations::status.eq(MembershipStatus::Confirmed as i32))
.select(users_organizations::all_columns)
.load::<Self>(conn)
.unwrap_or_default()
})
.await
}}
}
pub async fn find_by_cipher_and_org(cipher_uuid: &CipherId, org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> {
conn.run(move |conn| {
db_run! { conn: {
users_organizations::table
.filter(users_organizations::org_uuid.eq(org_uuid))
.left_join(users_collections::table.on(users_collections::user_uuid.eq(users_organizations::user_uuid)))
.left_join(
ciphers_collections::table.on(ciphers_collections::collection_uuid
.eq(users_collections::collection_uuid)
.and(ciphers_collections::cipher_uuid.eq(&cipher_uuid))),
.filter(users_organizations::org_uuid.eq(org_uuid))
.left_join(users_collections::table.on(
users_collections::user_uuid.eq(users_organizations::user_uuid)
))
.left_join(ciphers_collections::table.on(
ciphers_collections::collection_uuid.eq(users_collections::collection_uuid).and(
ciphers_collections::cipher_uuid.eq(&cipher_uuid)
)
.filter(users_organizations::access_all.eq(true).or(
// AccessAll..
ciphers_collections::cipher_uuid.eq(&cipher_uuid), // ..or access to collection with cipher
))
.select(users_organizations::all_columns)
.distinct()
.load::<Self>(conn)
.expect("Error loading user organizations")
})
.await
))
.filter(
users_organizations::access_all.eq(true).or( // AccessAll..
ciphers_collections::cipher_uuid.eq(&cipher_uuid) // ..or access to collection with cipher
)
)
.select(users_organizations::all_columns)
.distinct()
.load::<Self>(conn)
.expect("Error loading user organizations")
}}
}
pub async fn find_by_cipher_and_org_with_group(
@@ -1077,54 +1065,45 @@ impl Membership {
org_uuid: &OrganizationId,
conn: &DbConn,
) -> Vec<Self> {
conn.run(move |conn| {
db_run! { conn: {
users_organizations::table
.filter(users_organizations::org_uuid.eq(org_uuid))
.inner_join(
groups_users::table.on(groups_users::users_organizations_uuid.eq(users_organizations::uuid)),
)
.left_join(collections_groups::table.on(collections_groups::groups_uuid.eq(groups_users::groups_uuid)))
.left_join(
groups::table.on(groups::uuid
.eq(groups_users::groups_uuid)
.and(groups::organizations_uuid.eq(users_organizations::org_uuid))),
)
.left_join(
ciphers_collections::table.on(ciphers_collections::collection_uuid
.eq(collections_groups::collections_uuid)
.and(ciphers_collections::cipher_uuid.eq(&cipher_uuid))),
)
.filter(groups::access_all.eq(true).or(
// AccessAll via groups
ciphers_collections::cipher_uuid.eq(&cipher_uuid), // ..or access to collection via group
.filter(users_organizations::org_uuid.eq(org_uuid))
.inner_join(groups_users::table.on(
groups_users::users_organizations_uuid.eq(users_organizations::uuid)
))
.left_join(collections_groups::table.on(
collections_groups::groups_uuid.eq(groups_users::groups_uuid)
))
.left_join(groups::table.on(groups::uuid.eq(groups_users::groups_uuid)
.and(groups::organizations_uuid.eq(users_organizations::org_uuid))
))
.left_join(ciphers_collections::table.on(
ciphers_collections::collection_uuid.eq(collections_groups::collections_uuid).and(ciphers_collections::cipher_uuid.eq(&cipher_uuid))
))
.filter(
groups::access_all.eq(true).or( // AccessAll via groups
ciphers_collections::cipher_uuid.eq(&cipher_uuid) // ..or access to collection via group
)
)
.select(users_organizations::all_columns)
.distinct()
.load::<Self>(conn)
.expect("Error loading user organizations with groups")
})
.await
.load::<Self>(conn)
.expect("Error loading user organizations with groups")
}}
}
pub async fn user_has_ge_admin_access_to_cipher(user_uuid: &UserId, cipher_uuid: &CipherId, conn: &DbConn) -> bool {
conn.run(move |conn| {
db_run! { conn: {
users_organizations::table
.inner_join(
ciphers::table.on(ciphers::uuid
.eq(cipher_uuid)
.and(ciphers::organization_uuid.eq(users_organizations::org_uuid.nullable()))),
)
.filter(users_organizations::user_uuid.eq(user_uuid))
.filter(
users_organizations::atype.eq_any(vec![MembershipType::Owner as i32, MembershipType::Admin as i32]),
)
.count()
.first::<i64>(conn)
.ok()
.unwrap_or(0)
!= 0
})
.await
.inner_join(ciphers::table.on(ciphers::uuid.eq(cipher_uuid).and(ciphers::organization_uuid.eq(users_organizations::org_uuid.nullable()))))
.filter(users_organizations::user_uuid.eq(user_uuid))
.filter(users_organizations::atype.eq_any(vec![MembershipType::Owner as i32, MembershipType::Admin as i32]))
.count()
.first::<i64>(conn)
.ok()
.unwrap_or(0) != 0
}}
}
pub async fn find_by_collection_and_org(
@@ -1132,41 +1111,44 @@ impl Membership {
org_uuid: &OrganizationId,
conn: &DbConn,
) -> Vec<Self> {
conn.run(move |conn| {
db_run! { conn: {
users_organizations::table
.filter(users_organizations::org_uuid.eq(org_uuid))
.left_join(users_collections::table.on(users_collections::user_uuid.eq(users_organizations::user_uuid)))
.filter(users_organizations::access_all.eq(true).or(
// AccessAll..
users_collections::collection_uuid.eq(&collection_uuid), // ..or access to collection with cipher
))
.select(users_organizations::all_columns)
.load::<Self>(conn)
.expect("Error loading user organizations")
})
.await
.filter(users_organizations::org_uuid.eq(org_uuid))
.left_join(users_collections::table.on(
users_collections::user_uuid.eq(users_organizations::user_uuid)
))
.filter(
users_organizations::access_all.eq(true).or( // AccessAll..
users_collections::collection_uuid.eq(&collection_uuid) // ..or access to collection with cipher
)
)
.select(users_organizations::all_columns)
.load::<Self>(conn)
.expect("Error loading user organizations")
}}
}
pub async fn find_by_external_id_and_org(ext_id: &str, org_uuid: &OrganizationId, conn: &DbConn) -> Option<Self> {
conn.run(move |conn| {
db_run! { conn: {
users_organizations::table
.filter(users_organizations::external_id.eq(ext_id).and(users_organizations::org_uuid.eq(org_uuid)))
.first::<Self>(conn)
.ok()
})
.await
.filter(
users_organizations::external_id.eq(ext_id)
.and(users_organizations::org_uuid.eq(org_uuid))
)
.first::<Self>(conn)
.ok()
}}
}
pub async fn find_main_user_org(user_uuid: &str, conn: &DbConn) -> Option<Self> {
conn.run(move |conn| {
db_run! { conn: {
users_organizations::table
.filter(users_organizations::user_uuid.eq(user_uuid))
.filter(users_organizations::status.ne(MembershipStatus::Revoked as i32))
.order(users_organizations::atype.asc())
.first::<Self>(conn)
.ok()
})
.await
}}
}
}
@@ -1204,19 +1186,20 @@ impl OrganizationApiKey {
}
pub async fn find_by_org_uuid(org_uuid: &OrganizationId, conn: &DbConn) -> Option<Self> {
conn.run(move |conn| {
organization_api_key::table.filter(organization_api_key::org_uuid.eq(org_uuid)).first::<Self>(conn).ok()
})
.await
db_run! { conn: {
organization_api_key::table
.filter(organization_api_key::org_uuid.eq(org_uuid))
.first::<Self>(conn)
.ok()
}}
}
pub async fn delete_all_by_organization(org_uuid: &OrganizationId, conn: &DbConn) -> EmptyResult {
conn.run(move |conn| {
db_run! { conn: {
diesel::delete(organization_api_key::table.filter(organization_api_key::org_uuid.eq(org_uuid)))
.execute(conn)
.map_res("Error removing organization api key from organization")
})
.await
}}
}
}
+82 -55
View File
@@ -1,19 +1,11 @@
use chrono::{NaiveDateTime, Utc};
use data_encoding::BASE64URL_NOPAD;
use diesel::prelude::*;
use serde_json::Value;
use uuid::Uuid;
use crate::{
CONFIG,
api::EmptyResult,
config::PathType,
db::{DbConn, schema::sends},
error::MapResult,
util::{LowerCase, NumberOrString, format_date},
};
use crate::{config::PathType, util::LowerCase, CONFIG};
use super::{OrganizationId, User, UserId};
use crate::db::schema::sends;
use diesel::prelude::*;
use id::SendId;
#[derive(Identifiable, Queryable, Insertable, AsChangeset)]
@@ -115,33 +107,37 @@ impl Send {
pub fn check_password(&self, password: &str) -> bool {
match (&self.password_hash, &self.password_salt, self.password_iter) {
(Some(hash), Some(salt), Some(iter)) => {
crate::crypto::verify_password_hash(password.as_bytes(), salt, hash, iter.cast_unsigned())
crate::crypto::verify_password_hash(password.as_bytes(), salt, hash, iter as u32)
}
_ => false,
}
}
pub async fn creator_identifier(&self, conn: &DbConn) -> Option<String> {
if let Some(hide_email) = self.hide_email
&& hide_email
{
return None;
if let Some(hide_email) = self.hide_email {
if hide_email {
return None;
}
}
if let Some(user_uuid) = &self.user_uuid
&& let Some(user) = User::find_by_uuid(user_uuid, conn).await
{
return Some(user.email);
if let Some(user_uuid) = &self.user_uuid {
if let Some(user) = User::find_by_uuid(user_uuid, conn).await {
return Some(user.email);
}
}
None
}
pub fn to_json(&self) -> Value {
use crate::util::format_date;
use data_encoding::BASE64URL_NOPAD;
use uuid::Uuid;
let mut data = serde_json::from_str::<LowerCase<Value>>(&self.data).map(|d| d.data).unwrap_or_default();
// Mobile clients expect size to be a string instead of a number
if let Some(size) = data.get("size").and_then(Value::as_i64) {
if let Some(size) = data.get("size").and_then(|v| v.as_i64()) {
data["size"] = Value::String(size.to_string());
}
@@ -171,10 +167,12 @@ impl Send {
}
pub async fn to_json_access(&self, conn: &DbConn) -> Value {
use crate::util::format_date;
let mut data = serde_json::from_str::<LowerCase<Value>>(&self.data).map(|d| d.data).unwrap_or_default();
// Mobile clients expect size to be a string instead of a number
if let Some(size) = data.get("size").and_then(Value::as_i64) {
if let Some(size) = data.get("size").and_then(|v| v.as_i64()) {
data["size"] = Value::String(size.to_string());
}
@@ -193,6 +191,12 @@ impl Send {
}
}
use crate::db::DbConn;
use crate::api::EmptyResult;
use crate::error::MapResult;
use crate::util::NumberOrString;
impl Send {
pub async fn save(&mut self, conn: &DbConn) -> EmptyResult {
self.update_users_revision(conn).await;
@@ -233,13 +237,14 @@ impl Send {
if self.atype == SendType::File as i32 {
let operator = CONFIG.opendal_operator_for_path_type(&PathType::Sends)?;
operator.delete_with(&self.uuid).recursive(true).await.ok();
operator.remove_all(&self.uuid).await.ok();
}
conn.run(move |conn| {
diesel::delete(sends::table.filter(sends::uuid.eq(&self.uuid))).execute(conn).map_res("Error deleting send")
})
.await
db_run! { conn: {
diesel::delete(sends::table.filter(sends::uuid.eq(&self.uuid)))
.execute(conn)
.map_res("Error deleting send")
}}
}
/// Purge all sends that are past their deletion date.
@@ -251,12 +256,15 @@ impl Send {
pub async fn update_users_revision(&self, conn: &DbConn) -> Vec<UserId> {
let mut user_uuids = Vec::new();
if let Some(user_uuid) = &self.user_uuid {
User::update_uuid_revision(user_uuid, conn).await;
user_uuids.push(user_uuid.clone());
} else {
// Belongs to Organization, not implemented
}
match &self.user_uuid {
Some(user_uuid) => {
User::update_uuid_revision(user_uuid, conn).await;
user_uuids.push(user_uuid.clone())
}
None => {
// Belongs to Organization, not implemented
}
};
user_uuids
}
@@ -268,6 +276,9 @@ impl Send {
}
pub async fn find_by_access_id(access_id: &str, conn: &DbConn) -> Option<Self> {
use data_encoding::BASE64URL_NOPAD;
use uuid::Uuid;
let Ok(uuid_vec) = BASE64URL_NOPAD.decode(access_id.as_bytes()) else {
return None;
};
@@ -281,38 +292,50 @@ impl Send {
}
pub async fn find_by_uuid(uuid: &SendId, conn: &DbConn) -> Option<Self> {
conn.run(move |conn| sends::table.filter(sends::uuid.eq(uuid)).first::<Self>(conn).ok()).await
db_run! { conn: {
sends::table
.filter(sends::uuid.eq(uuid))
.first::<Self>(conn)
.ok()
}}
}
pub async fn find_by_uuid_and_user(uuid: &SendId, user_uuid: &UserId, conn: &DbConn) -> Option<Self> {
conn.run(move |conn| {
sends::table.filter(sends::uuid.eq(uuid)).filter(sends::user_uuid.eq(user_uuid)).first::<Self>(conn).ok()
})
.await
db_run! { conn: {
sends::table
.filter(sends::uuid.eq(uuid))
.filter(sends::user_uuid.eq(user_uuid))
.first::<Self>(conn)
.ok()
}}
}
pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
conn.run(move |conn| {
sends::table.filter(sends::user_uuid.eq(user_uuid)).load::<Self>(conn).expect("Error loading sends")
})
.await
db_run! { conn: {
sends::table
.filter(sends::user_uuid.eq(user_uuid))
.load::<Self>(conn)
.expect("Error loading sends")
}}
}
pub async fn size_by_user(user_uuid: &UserId, conn: &DbConn) -> Option<i64> {
let sends = Self::find_by_user(user_uuid, conn).await;
#[derive(serde::Deserialize)]
struct FileData {
#[serde(rename = "size", alias = "Size")]
size: NumberOrString,
}
let sends = Self::find_by_user(user_uuid, conn).await;
let mut total: i64 = 0;
for send in sends {
if send.atype == SendType::File as i32
&& let Ok(size) =
if send.atype == SendType::File as i32 {
if let Ok(size) =
serde_json::from_str::<FileData>(&send.data).map_err(Into::into).and_then(|d| d.size.into_i64())
{
total = total.checked_add(size)?;
{
total = total.checked_add(size)?;
};
}
}
@@ -320,18 +343,22 @@ impl Send {
}
pub async fn find_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> {
conn.run(move |conn| {
sends::table.filter(sends::organization_uuid.eq(org_uuid)).load::<Self>(conn).expect("Error loading sends")
})
.await
db_run! { conn: {
sends::table
.filter(sends::organization_uuid.eq(org_uuid))
.load::<Self>(conn)
.expect("Error loading sends")
}}
}
pub async fn find_by_past_deletion_date(conn: &DbConn) -> Vec<Self> {
let now = Utc::now().naive_utc();
conn.run(move |conn| {
sends::table.filter(sends::deletion_date.lt(now)).load::<Self>(conn).expect("Error loading sends")
})
.await
db_run! { conn: {
sends::table
.filter(sends::deletion_date.lt(now))
.load::<Self>(conn)
.expect("Error loading sends")
}}
}
}
+27 -49
View File
@@ -1,29 +1,31 @@
use chrono::{NaiveDateTime, Utc};
use std::time::Duration;
use chrono::{NaiveDateTime, Utc};
use diesel::{
deserialize::FromSql,
expression::AsExpression,
prelude::*,
serialize::{Output, ToSql},
sql_types::Text,
};
use crate::api::EmptyResult;
use crate::db::schema::sso_auth;
use crate::db::{DbConn, DbPool};
use crate::error::MapResult;
use crate::sso::{OIDCCode, OIDCCodeChallenge, OIDCIdentifier, OIDCState, SSO_AUTH_EXPIRATION};
use crate::{
api::EmptyResult,
db::{DbConn, DbPool, schema::sso_auth},
error::MapResult,
sso::{OIDCCode, OIDCCodeChallenge, OIDCIdentifier, OIDCState, SSO_AUTH_EXPIRATION},
};
use diesel::deserialize::FromSql;
use diesel::expression::AsExpression;
use diesel::prelude::*;
use diesel::serialize::{Output, ToSql};
use diesel::sql_types::Text;
#[derive(AsExpression, Clone, Debug, Serialize, Deserialize, FromSqlRow)]
#[diesel(sql_type = Text)]
pub struct OIDCCodeResponseError {
pub error: String,
pub error_description: Option<String>,
pub enum OIDCCodeWrapper {
Ok {
code: OIDCCode,
},
Error {
error: String,
error_description: Option<String>,
},
}
impl_FromToSqlText!(OIDCCodeResponseError);
impl_FromToSqlText!(OIDCCodeWrapper);
#[derive(AsExpression, Clone, Debug, Serialize, Deserialize, FromSqlRow)]
#[diesel(sql_type = Text)]
@@ -48,23 +50,15 @@ pub struct SsoAuth {
pub client_challenge: OIDCCodeChallenge,
pub nonce: String,
pub redirect_uri: String,
pub code_response: Option<OIDCCode>,
pub code_response_error: Option<OIDCCodeResponseError>,
pub code_response: Option<OIDCCodeWrapper>,
pub auth_response: Option<OIDCAuthenticatedUser>,
pub created_at: NaiveDateTime,
pub updated_at: NaiveDateTime,
pub binding_hash: Option<String>,
}
/// Local methods
impl SsoAuth {
pub fn new(
state: OIDCState,
client_challenge: OIDCCodeChallenge,
nonce: String,
redirect_uri: String,
binding_hash: Option<String>,
) -> Self {
pub fn new(state: OIDCState, client_challenge: OIDCCodeChallenge, nonce: String, redirect_uri: String) -> Self {
let now = Utc::now().naive_utc();
SsoAuth {
@@ -75,9 +69,7 @@ impl SsoAuth {
created_at: now,
updated_at: now,
code_response: None,
code_response_error: None,
auth_response: None,
binding_hash,
}
}
}
@@ -108,22 +100,10 @@ impl SsoAuth {
}
pub async fn find(state: &OIDCState, conn: &DbConn) -> Option<Self> {
let oldest = Utc::now().naive_utc() - *SSO_AUTH_EXPIRATION;
conn.run(move |conn| {
sso_auth::table
.filter(sso_auth::state.eq(state))
.filter(sso_auth::created_at.ge(oldest))
.first::<Self>(conn)
.ok()
})
.await
}
pub async fn find_by_code(code: &OIDCCode, conn: &DbConn) -> Option<Self> {
let oldest = Utc::now().naive_utc() - *SSO_AUTH_EXPIRATION;
db_run! { conn: {
sso_auth::table
.filter(sso_auth::code_response.eq(code))
.filter(sso_auth::state.eq(state))
.filter(sso_auth::created_at.ge(oldest))
.first::<Self>(conn)
.ok()
@@ -131,24 +111,22 @@ impl SsoAuth {
}
pub async fn delete(self, conn: &DbConn) -> EmptyResult {
conn.run(move |conn| {
db_run! {conn: {
diesel::delete(sso_auth::table.filter(sso_auth::state.eq(self.state)))
.execute(conn)
.map_res("Error deleting sso_auth")
})
.await
}}
}
pub async fn delete_expired(pool: DbPool) -> EmptyResult {
debug!("Purging expired sso_auth");
if let Ok(conn) = pool.get().await {
let oldest = Utc::now().naive_utc() - *SSO_AUTH_EXPIRATION;
conn.run(move |conn| {
db_run! { conn: {
diesel::delete(sso_auth::table.filter(sso_auth::created_at.lt(oldest)))
.execute(conn)
.map_res("Error deleting expired SSO nonce")
})
.await
}}
} else {
err!("Failed to get DB connection while purging expired sso_auth")
}
+28 -39
View File
@@ -1,17 +1,13 @@
use super::UserId;
use crate::api::core::two_factor::webauthn::WebauthnRegistration;
use crate::db::schema::twofactor;
use crate::{api::EmptyResult, db::DbConn, error::MapResult};
use diesel::prelude::*;
use serde_json::Value;
use webauthn_rs::prelude::{Credential, ParsedAttestation};
use webauthn_rs_core::proto::CredentialV3;
use webauthn_rs_proto::{AttestationFormat, RegisteredExtensions};
use crate::{
api::{EmptyResult, core::two_factor::webauthn::WebauthnRegistration},
db::{DbConn, schema::twofactor},
error::MapResult,
};
use super::UserId;
#[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = twofactor)]
#[diesel(primary_key(uuid))]
@@ -118,59 +114,54 @@ impl TwoFactor {
}
pub async fn delete(self, conn: &DbConn) -> EmptyResult {
conn.run(move |conn| {
db_run! { conn: {
diesel::delete(twofactor::table.filter(twofactor::uuid.eq(self.uuid)))
.execute(conn)
.map_res("Error deleting twofactor")
})
.await
}}
}
pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
conn.run(move |conn| {
db_run! { conn: {
twofactor::table
.filter(twofactor::user_uuid.eq(user_uuid))
.filter(twofactor::atype.lt(1000)) // Filter implementation types
.load::<Self>(conn)
.expect("Error loading twofactor")
})
.await
}}
}
pub async fn find_by_user_and_type(user_uuid: &UserId, atype: i32, conn: &DbConn) -> Option<Self> {
conn.run(move |conn| {
db_run! { conn: {
twofactor::table
.filter(twofactor::user_uuid.eq(user_uuid))
.filter(twofactor::atype.eq(atype))
.first::<Self>(conn)
.ok()
})
.await
}}
}
pub async fn delete_all_by_user(user_uuid: &UserId, conn: &DbConn) -> EmptyResult {
conn.run(move |conn| {
db_run! { conn: {
diesel::delete(twofactor::table.filter(twofactor::user_uuid.eq(user_uuid)))
.execute(conn)
.map_res("Error deleting twofactors")
})
.await
}}
}
pub async fn migrate_u2f_to_webauthn(conn: &DbConn) -> EmptyResult {
use crate::api::core::two_factor::webauthn::{U2FRegistration, get_webauthn_registrations};
let u2f_factors = db_run! { conn: {
twofactor::table
.filter(twofactor::atype.eq(TwoFactorType::U2f as i32))
.load::<Self>(conn)
.expect("Error loading twofactor")
}};
use crate::api::core::two_factor::webauthn::U2FRegistration;
use crate::api::core::two_factor::webauthn::{get_webauthn_registrations, WebauthnRegistration};
use webauthn_rs::prelude::{COSEEC2Key, COSEKey, COSEKeyType, ECDSACurve};
use webauthn_rs_proto::{COSEAlgorithm, UserVerificationPolicy};
let u2f_factors = conn
.run(move |conn| {
twofactor::table
.filter(twofactor::atype.eq(TwoFactorType::U2f as i32))
.load::<Self>(conn)
.expect("Error loading twofactor")
})
.await;
for mut u2f in u2f_factors {
let mut regs: Vec<U2FRegistration> = serde_json::from_str(&u2f.data)?;
// If there are no registrations or they are migrated (we do the migration in batch so we can consider them all migrated when the first one is)
@@ -236,14 +227,12 @@ impl TwoFactor {
}
pub async fn migrate_credential_to_passkey(conn: &DbConn) -> EmptyResult {
let webauthn_factors = conn
.run(move |conn| {
twofactor::table
.filter(twofactor::atype.eq(TwoFactorType::Webauthn as i32))
.load::<Self>(conn)
.expect("Error loading twofactor")
})
.await;
let webauthn_factors = db_run! { conn: {
twofactor::table
.filter(twofactor::atype.eq(TwoFactorType::Webauthn as i32))
.load::<Self>(conn)
.expect("Error loading twofactor")
}};
for webauthn_factor in webauthn_factors {
// assume that a failure to parse into the old struct, means that it was already converted
@@ -252,7 +241,7 @@ impl TwoFactor {
continue;
};
let regs = regs.into_iter().map(Into::into).collect::<Vec<WebauthnRegistration>>();
let regs = regs.into_iter().map(|r| r.into()).collect::<Vec<WebauthnRegistration>>();
TwoFactor::new(webauthn_factor.user_uuid.clone(), TwoFactorType::Webauthn, serde_json::to_string(&regs)?)
.save(conn)
+23 -25
View File
@@ -1,11 +1,8 @@
use chrono::Utc;
use diesel::prelude::*;
use crate::{
api::EmptyResult,
db::{DbConn, schema::twofactor_duo_ctx},
error::MapResult,
};
use crate::db::schema::twofactor_duo_ctx;
use crate::{api::EmptyResult, db::DbConn, error::MapResult};
use diesel::prelude::*;
#[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = twofactor_duo_ctx)]
@@ -19,10 +16,12 @@ pub struct TwoFactorDuoContext {
impl TwoFactorDuoContext {
pub async fn find_by_state(state: &str, conn: &DbConn) -> Option<Self> {
conn.run(move |conn| {
twofactor_duo_ctx::table.filter(twofactor_duo_ctx::state.eq(state)).first::<Self>(conn).ok()
})
.await
db_run! { conn: {
twofactor_duo_ctx::table
.filter(twofactor_duo_ctx::state.eq(state))
.first::<Self>(conn)
.ok()
}}
}
pub async fn save(state: &str, user_email: &str, nonce: &str, ttl: i64, conn: &DbConn) -> EmptyResult {
@@ -30,42 +29,41 @@ impl TwoFactorDuoContext {
let exists = Self::find_by_state(state, conn).await;
if exists.is_some() {
return Ok(());
}
};
let exp = Utc::now().timestamp() + ttl;
conn.run(move |conn| {
db_run! { conn: {
diesel::insert_into(twofactor_duo_ctx::table)
.values((
twofactor_duo_ctx::state.eq(state),
twofactor_duo_ctx::user_email.eq(user_email),
twofactor_duo_ctx::nonce.eq(nonce),
twofactor_duo_ctx::exp.eq(exp),
))
.execute(conn)
.map_res("Error saving context to twofactor_duo_ctx")
})
.await
twofactor_duo_ctx::exp.eq(exp)
))
.execute(conn)
.map_res("Error saving context to twofactor_duo_ctx")
}}
}
pub async fn find_expired(conn: &DbConn) -> Vec<Self> {
let now = Utc::now().timestamp();
conn.run(move |conn| {
db_run! { conn: {
twofactor_duo_ctx::table
.filter(twofactor_duo_ctx::exp.lt(now))
.load::<Self>(conn)
.expect("Error finding expired contexts in twofactor_duo_ctx")
})
.await
}}
}
pub async fn delete(&self, conn: &DbConn) -> EmptyResult {
conn.run(move |conn| {
diesel::delete(twofactor_duo_ctx::table.filter(twofactor_duo_ctx::state.eq(&self.state)))
db_run! { conn: {
diesel::delete(
twofactor_duo_ctx::table
.filter(twofactor_duo_ctx::state.eq(&self.state)))
.execute(conn)
.map_res("Error deleting from twofactor_duo_ctx")
})
.await
}}
}
pub async fn purge_expired_duo_contexts(conn: &DbConn) {
+19 -26
View File
@@ -1,17 +1,17 @@
use chrono::{NaiveDateTime, Utc};
use diesel::prelude::*;
use crate::db::schema::twofactor_incomplete;
use crate::{
CONFIG,
api::EmptyResult,
auth::ClientIp,
db::{
DbConn,
models::{DeviceId, UserId},
schema::twofactor_incomplete,
DbConn,
},
error::MapResult,
CONFIG,
};
use diesel::prelude::*;
#[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = twofactor_incomplete)]
@@ -49,7 +49,7 @@ impl TwoFactorIncomplete {
return Ok(());
}
conn.run(move |conn| {
db_run! { conn: {
diesel::insert_into(twofactor_incomplete::table)
.values((
twofactor_incomplete::user_uuid.eq(user_uuid),
@@ -61,8 +61,7 @@ impl TwoFactorIncomplete {
))
.execute(conn)
.map_res("Error adding twofactor_incomplete record")
})
.await
}}
}
pub async fn mark_complete(user_uuid: &UserId, device_uuid: &DeviceId, conn: &DbConn) -> EmptyResult {
@@ -74,24 +73,22 @@ impl TwoFactorIncomplete {
}
pub async fn find_by_user_and_device(user_uuid: &UserId, device_uuid: &DeviceId, conn: &DbConn) -> Option<Self> {
conn.run(move |conn| {
db_run! { conn: {
twofactor_incomplete::table
.filter(twofactor_incomplete::user_uuid.eq(user_uuid))
.filter(twofactor_incomplete::device_uuid.eq(device_uuid))
.first::<Self>(conn)
.ok()
})
.await
}}
}
pub async fn find_logins_before(dt: &NaiveDateTime, conn: &DbConn) -> Vec<Self> {
conn.run(move |conn| {
db_run! { conn: {
twofactor_incomplete::table
.filter(twofactor_incomplete::login_time.lt(dt))
.load::<Self>(conn)
.expect("Error loading twofactor_incomplete")
})
.await
}}
}
pub async fn delete(self, conn: &DbConn) -> EmptyResult {
@@ -99,24 +96,20 @@ impl TwoFactorIncomplete {
}
pub async fn delete_by_user_and_device(user_uuid: &UserId, device_uuid: &DeviceId, conn: &DbConn) -> EmptyResult {
conn.run(move |conn| {
diesel::delete(
twofactor_incomplete::table
.filter(twofactor_incomplete::user_uuid.eq(user_uuid))
.filter(twofactor_incomplete::device_uuid.eq(device_uuid)),
)
.execute(conn)
.map_res("Error in twofactor_incomplete::delete_by_user_and_device()")
})
.await
db_run! { conn: {
diesel::delete(twofactor_incomplete::table
.filter(twofactor_incomplete::user_uuid.eq(user_uuid))
.filter(twofactor_incomplete::device_uuid.eq(device_uuid)))
.execute(conn)
.map_res("Error in twofactor_incomplete::delete_by_user_and_device()")
}}
}
pub async fn delete_all_by_user(user_uuid: &UserId, conn: &DbConn) -> EmptyResult {
conn.run(move |conn| {
db_run! { conn: {
diesel::delete(twofactor_incomplete::table.filter(twofactor_incomplete::user_uuid.eq(user_uuid)))
.execute(conn)
.map_res("Error in twofactor_incomplete::delete_all_by_user()")
})
.await
}}
}
}
+77 -76
View File
@@ -1,26 +1,22 @@
use crate::db::schema::{invitations, sso_users, twofactor_incomplete, users};
use chrono::{NaiveDateTime, TimeDelta, Utc};
use derive_more::{AsRef, Deref, Display, From};
use diesel::prelude::*;
use serde_json::Value;
use crate::{
CONFIG,
api::EmptyResult,
crypto,
db::{
DbConn,
models::DeviceId,
schema::{invitations, sso_users, twofactor_incomplete, users},
},
error::MapResult,
sso::OIDCIdentifier,
util::{format_date, get_uuid, retry},
};
use macros::UuidFromParam;
use super::{
Cipher, Device, EmergencyAccess, Favorite, Folder, Membership, MembershipType, TwoFactor, TwoFactorIncomplete,
};
use crate::{
api::EmptyResult,
crypto,
db::{models::DeviceId, DbConn},
error::MapResult,
sso::OIDCIdentifier,
util::{format_date, get_uuid, retry},
CONFIG,
};
use macros::UuidFromParam;
#[derive(Identifiable, Queryable, Insertable, AsChangeset, Selectable)]
#[diesel(table_name = users)]
@@ -141,8 +137,8 @@ impl User {
_totp_secret: None,
totp_recover: None,
equivalent_domains: "[]".to_owned(),
excluded_globals: "[]".to_owned(),
equivalent_domains: "[]".to_string(),
excluded_globals: "[]".to_string(),
client_kdf_type: Self::CLIENT_KDF_TYPE_DEFAULT,
client_kdf_iter: Self::CLIENT_KDF_ITER_DEFAULT,
@@ -162,7 +158,7 @@ impl User {
password.as_bytes(),
&self.salt,
&self.password_hash,
self.password_iterations.cast_unsigned(),
self.password_iterations as u32,
)
}
@@ -197,8 +193,7 @@ impl User {
allow_next_route: Option<Vec<String>>,
conn: &DbConn,
) -> EmptyResult {
self.password_hash =
crypto::hash_password(password.as_bytes(), &self.salt, self.password_iterations.cast_unsigned());
self.password_hash = crypto::hash_password(password.as_bytes(), &self.salt, self.password_iterations as u32);
if let Some(route) = allow_next_route {
self.set_stamp_exception(route);
@@ -243,10 +238,10 @@ impl User {
pub fn display_name(&self) -> &str {
// default to email if name is empty
if self.name.is_empty() {
&self.email
} else {
if !&self.name.is_empty() {
&self.name
} else {
&self.email
}
}
}
@@ -342,14 +337,15 @@ impl User {
TwoFactorIncomplete::delete_all_by_user(&self.uuid, conn).await?;
Invitation::take(&self.email, conn).await; // Delete invitation if any
conn.run(move |conn| {
diesel::delete(users::table.filter(users::uuid.eq(self.uuid))).execute(conn).map_res("Error deleting user")
})
.await
db_run! { conn: {
diesel::delete(users::table.filter(users::uuid.eq(self.uuid)))
.execute(conn)
.map_res("Error deleting user")
}}
}
pub async fn update_uuid_revision(uuid: &UserId, conn: &DbConn) {
if let Err(e) = Self::update_revision_impl(uuid, &Utc::now().naive_utc(), conn).await {
if let Err(e) = Self::_update_revision(uuid, &Utc::now().naive_utc(), conn).await {
warn!("Failed to update revision for {uuid}: {e:#?}");
}
}
@@ -357,62 +353,68 @@ impl User {
pub async fn update_all_revisions(conn: &DbConn) -> EmptyResult {
let updated_at = Utc::now().naive_utc();
conn.run(move |conn| {
retry(|| diesel::update(users::table).set(users::updated_at.eq(updated_at)).execute(conn), 10)
.map_res("Error updating revision date for all users")
})
.await
db_run! { conn: {
retry(|| {
diesel::update(users::table)
.set(users::updated_at.eq(updated_at))
.execute(conn)
}, 10)
.map_res("Error updating revision date for all users")
}}
}
pub async fn update_revision(&mut self, conn: &DbConn) -> EmptyResult {
self.updated_at = Utc::now().naive_utc();
Self::update_revision_impl(&self.uuid, &self.updated_at, conn).await
Self::_update_revision(&self.uuid, &self.updated_at, conn).await
}
async fn update_revision_impl(uuid: &UserId, date: &NaiveDateTime, conn: &DbConn) -> EmptyResult {
conn.run(move |conn| {
retry(
|| {
diesel::update(users::table.filter(users::uuid.eq(uuid)))
.set(users::updated_at.eq(date))
.execute(conn)
},
10,
)
async fn _update_revision(uuid: &UserId, date: &NaiveDateTime, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
retry(|| {
diesel::update(users::table.filter(users::uuid.eq(uuid)))
.set(users::updated_at.eq(date))
.execute(conn)
}, 10)
.map_res("Error updating user revision")
})
.await
}}
}
pub async fn find_by_mail(mail: &str, conn: &DbConn) -> Option<Self> {
let lower_mail = mail.to_lowercase();
conn.run(move |conn| users::table.filter(users::email.eq(lower_mail)).first::<Self>(conn).ok()).await
db_run! { conn: {
users::table
.filter(users::email.eq(lower_mail))
.first::<Self>(conn)
.ok()
}}
}
pub async fn find_by_uuid(uuid: &UserId, conn: &DbConn) -> Option<Self> {
conn.run(move |conn| users::table.filter(users::uuid.eq(uuid)).first::<Self>(conn).ok()).await
db_run! { conn: {
users::table
.filter(users::uuid.eq(uuid))
.first::<Self>(conn)
.ok()
}}
}
pub async fn find_by_device_for_email2fa(device_uuid: &DeviceId, conn: &DbConn) -> Option<Self> {
if let Some(user_uuid) = conn
.run(move |conn| {
twofactor_incomplete::table
.filter(twofactor_incomplete::device_uuid.eq(device_uuid))
.order_by(twofactor_incomplete::login_time.desc())
.select(twofactor_incomplete::user_uuid)
.first::<UserId>(conn)
.ok()
})
.await
{
if let Some(user_uuid) = db_run! ( conn: {
twofactor_incomplete::table
.filter(twofactor_incomplete::device_uuid.eq(device_uuid))
.order_by(twofactor_incomplete::login_time.desc())
.select(twofactor_incomplete::user_uuid)
.first::<UserId>(conn)
.ok()
}) {
return Self::find_by_uuid(&user_uuid, conn).await;
}
None
}
pub async fn get_all(conn: &DbConn) -> Vec<(Self, Option<SsoUser>)> {
conn.run(move |conn| {
db_run! { conn: {
users::table
.left_join(sso_users::table)
.select(<(Self, Option<SsoUser>)>::as_select())
@@ -420,8 +422,7 @@ impl User {
.expect("Error loading groups for user")
.into_iter()
.collect()
})
.await
}}
}
pub async fn last_active(&self, conn: &DbConn) -> Option<NaiveDateTime> {
@@ -466,18 +467,21 @@ impl Invitation {
}
pub async fn delete(self, conn: &DbConn) -> EmptyResult {
conn.run(move |conn| {
db_run! { conn: {
diesel::delete(invitations::table.filter(invitations::email.eq(self.email)))
.execute(conn)
.map_res("Error deleting invitation")
})
.await
}}
}
pub async fn find_by_mail(mail: &str, conn: &DbConn) -> Option<Self> {
let lower_mail = mail.to_lowercase();
conn.run(move |conn| invitations::table.filter(invitations::email.eq(lower_mail)).first::<Self>(conn).ok())
.await
db_run! { conn: {
invitations::table
.filter(invitations::email.eq(lower_mail))
.first::<Self>(conn)
.ok()
}}
}
pub async fn take(mail: &str, conn: &DbConn) -> bool {
@@ -527,37 +531,34 @@ impl SsoUser {
}
pub async fn find_by_identifier(identifier: &str, conn: &DbConn) -> Option<(User, Self)> {
conn.run(move |conn| {
db_run! { conn: {
users::table
.inner_join(sso_users::table)
.select(<(User, Self)>::as_select())
.filter(sso_users::identifier.eq(identifier))
.first::<(User, Self)>(conn)
.ok()
})
.await
}}
}
pub async fn find_by_mail(mail: &str, conn: &DbConn) -> Option<(User, Option<Self>)> {
let lower_mail = mail.to_lowercase();
conn.run(move |conn| {
db_run! { conn: {
users::table
.left_join(sso_users::table)
.select(<(User, Option<Self>)>::as_select())
.filter(users::email.eq(lower_mail))
.first::<(User, Option<Self>)>(conn)
.ok()
})
.await
}}
}
pub async fn delete(user_uuid: &UserId, conn: &DbConn) -> EmptyResult {
conn.run(move |conn| {
db_run! { conn: {
diesel::delete(sso_users::table.filter(sso_users::user_uuid.eq(user_uuid)))
.execute(conn)
.map_res("Error deleting sso user")
})
.await
}}
}
}
+5 -6
View File
@@ -1,6 +1,5 @@
use std::{cell::RefCell, collections::HashMap, time::Instant};
use diesel::connection::{Instrumentation, InstrumentationEvent};
use std::{cell::RefCell, collections::HashMap, time::Instant};
thread_local! {
static QUERY_PERF_TRACKER: RefCell<HashMap<String, Instant>> = RefCell::new(HashMap::new());
@@ -12,7 +11,7 @@ pub fn simple_logger() -> Option<Box<dyn Instrumentation>> {
url,
..
} => {
debug!("Establishing connection: {url}");
debug!("Establishing connection: {url}")
}
InstrumentationEvent::FinishEstablishConnection {
url,
@@ -20,9 +19,9 @@ pub fn simple_logger() -> Option<Box<dyn Instrumentation>> {
..
} => {
if let Some(e) = error {
error!("Error during establishing a connection with {url}: {e:?}");
error!("Error during establishing a connection with {url}: {e:?}")
} else {
debug!("Connection established: {url}");
debug!("Connection established: {url}")
}
}
InstrumentationEvent::StartQuery {
@@ -48,7 +47,7 @@ pub fn simple_logger() -> Option<Box<dyn Instrumentation>> {
} else if duration.as_secs() >= 1 {
info!("SLOW QUERY [{:.2}s]: {}", duration.as_secs_f32(), query_string);
} else {
debug!("QUERY [{duration:?}]: {query_string}");
debug!("QUERY [{:?}]: {}", duration, query_string);
}
}
});
-13
View File
@@ -262,11 +262,9 @@ table! {
nonce -> Text,
redirect_uri -> Text,
code_response -> Nullable<Text>,
code_response_error -> Nullable<Text>,
auth_response -> Nullable<Text>,
created_at -> Timestamp,
updated_at -> Timestamp,
binding_hash -> Nullable<Text>,
}
}
@@ -343,16 +341,6 @@ table! {
}
}
table! {
archives (user_uuid, cipher_uuid) {
user_uuid -> Text,
cipher_uuid -> Text,
archived_at -> Timestamp,
}
}
joinable!(archives -> users (user_uuid));
joinable!(archives -> ciphers (cipher_uuid));
joinable!(attachments -> ciphers (cipher_uuid));
joinable!(ciphers -> organizations (organization_uuid));
joinable!(ciphers -> users (user_uuid));
@@ -384,7 +372,6 @@ joinable!(auth_requests -> users (user_uuid));
joinable!(sso_users -> users (user_uuid));
allow_tables_to_appear_in_same_query!(
archives,
attachments,
ciphers,
ciphers_collections,
+44 -48
View File
@@ -1,11 +1,10 @@
//
// Error generator macro
//
use std::error::Error as StdError;
use crate::db::models::EventType;
use crate::http_client::CustomHttpClientError;
use serde::ser::{Serialize, SerializeStruct, Serializer};
use std::error::Error as StdError;
macro_rules! make_error {
( $( $name:ident ( $ty:ty ): $src_fn:expr, $usr_msg_fun:expr ),+ $(,)? ) => {
@@ -15,24 +14,24 @@ macro_rules! make_error {
#[derive(Debug)]
pub struct ErrorEvent { pub event: EventType }
pub struct Error { message: String, kind: ErrorKind, code: u16, event: Option<ErrorEvent> }
pub struct Error { message: String, error: ErrorKind, error_code: u16, event: Option<ErrorEvent> }
$(impl From<$ty> for Error {
fn from(err: $ty) -> Self { Error::from((stringify!($name), err)) }
})+
$(impl<S: Into<String>> From<(S, $ty)> for Error {
fn from(val: (S, $ty)) -> Self {
Error { message: val.0.into(), kind: ErrorKind::$name(val.1), code: BAD_REQUEST, event: None }
Error { message: val.0.into(), error: ErrorKind::$name(val.1), error_code: BAD_REQUEST, event: None }
}
})+
impl StdError for Error {
fn source(&self) -> Option<&(dyn StdError + 'static)> {
match &self.kind {$( ErrorKind::$name(e) => $src_fn(e), )+}
match &self.error {$( ErrorKind::$name(e) => $src_fn(e), )+}
}
}
impl std::fmt::Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match &self.kind {$(
match &self.error {$(
ErrorKind::$name(e) => f.write_str(&$usr_msg_fun(e, &self.message)),
)+}
}
@@ -40,10 +39,10 @@ macro_rules! make_error {
};
}
use diesel::ConnectionError as DieselConErr;
use diesel::r2d2::Error as R2d2Err;
use diesel::r2d2::PoolError as R2d2PoolErr;
use diesel::result::Error as DieselErr;
use diesel::ConnectionError as DieselConErr;
use handlebars::RenderError as HbErr;
use jsonwebtoken::errors::Error as JwtErr;
use lettre::address::AddressError as AddrErr;
@@ -72,46 +71,46 @@ pub struct Compact {}
// The second one contains the function used to obtain the response sent to the client
make_error! {
// Just an empty error
Empty(Empty): no_source, serialize,
Empty(Empty): _no_source, _serialize,
// Used to represent err! calls
Simple(String): no_source, api_error,
Compact(Compact): no_source, compact_api_error,
Simple(String): _no_source, _api_error,
Compact(Compact): _no_source, _compact_api_error,
// Used in our custom http client to handle non-global IPs and blocked domains
CustomHttpClient(CustomHttpClientError): has_source, api_error,
CustomHttpClient(CustomHttpClientError): _has_source, _api_error,
// Used for special return values, like 2FA errors
Json(Value): no_source, serialize,
Db(DieselErr): has_source, api_error,
R2d2(R2d2Err): has_source, api_error,
R2d2Pool(R2d2PoolErr): has_source, api_error,
Serde(SerdeErr): has_source, api_error,
JWt(JwtErr): has_source, api_error,
Handlebars(HbErr): has_source, api_error,
Json(Value): _no_source, _serialize,
Db(DieselErr): _has_source, _api_error,
R2d2(R2d2Err): _has_source, _api_error,
R2d2Pool(R2d2PoolErr): _has_source, _api_error,
Serde(SerdeErr): _has_source, _api_error,
JWt(JwtErr): _has_source, _api_error,
Handlebars(HbErr): _has_source, _api_error,
Io(IoErr): has_source, api_error,
Time(TimeErr): has_source, api_error,
Req(ReqErr): has_source, api_error,
Regex(RegexErr): has_source, api_error,
Yubico(YubiErr): has_source, api_error,
Io(IoErr): _has_source, _api_error,
Time(TimeErr): _has_source, _api_error,
Req(ReqErr): _has_source, _api_error,
Regex(RegexErr): _has_source, _api_error,
Yubico(YubiErr): _has_source, _api_error,
Lettre(LettreErr): has_source, api_error,
Address(AddrErr): has_source, api_error,
Smtp(SmtpErr): has_source, api_error,
OpenSSL(SSLErr): has_source, api_error,
Rocket(RocketErr): has_source, api_error,
Lettre(LettreErr): _has_source, _api_error,
Address(AddrErr): _has_source, _api_error,
Smtp(SmtpErr): _has_source, _api_error,
OpenSSL(SSLErr): _has_source, _api_error,
Rocket(RocketErr): _has_source, _api_error,
DieselCon(DieselConErr): has_source, api_error,
Webauthn(WebauthnErr): has_source, api_error,
DieselCon(DieselConErr): _has_source, _api_error,
Webauthn(WebauthnErr): _has_source, _api_error,
OpenDAL(OpenDALErr): has_source, api_error,
OpenDAL(OpenDALErr): _has_source, _api_error,
}
impl std::fmt::Debug for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self.source() {
Some(e) => write!(f, "{}.\n[CAUSE] {:#?}", self.message, e),
None => match self.kind {
None => match self.error {
ErrorKind::Empty(_) => Ok(()),
ErrorKind::Simple(ref s) => {
if &self.message == s {
@@ -136,7 +135,6 @@ impl Error {
(usr_msg.clone(), usr_msg.into()).into()
}
#[must_use]
pub fn empty() -> Self {
Empty {}.into()
}
@@ -149,13 +147,13 @@ impl Error {
#[must_use]
pub fn with_kind(mut self, kind: ErrorKind) -> Self {
self.kind = kind;
self.error = kind;
self
}
#[must_use]
pub const fn with_code(mut self, code: u16) -> Self {
self.code = code;
self.error_code = code;
self
}
@@ -196,14 +194,14 @@ impl<S> MapResult<S> for Option<S> {
}
}
const fn has_source<T>(e: T) -> Option<T> {
const fn _has_source<T>(e: T) -> Option<T> {
Some(e)
}
fn no_source<T, S>(_: T) -> Option<S> {
fn _no_source<T, S>(_: T) -> Option<S> {
None
}
fn serialize(e: &impl Serialize, _msg: &str) -> String {
fn _serialize(e: &impl Serialize, _msg: &str) -> String {
serde_json::to_string(e).unwrap()
}
@@ -282,14 +280,14 @@ struct ApiErrorResponse<'a>(ApiErrorMsg<'a>);
/// The custom serialization adds all other needed fields
struct CompactApiErrorResponse<'a>(ApiErrorMsg<'a>);
fn api_error(_: &impl std::any::Any, msg: &str) -> String {
fn _api_error(_: &impl std::any::Any, msg: &str) -> String {
let response = ApiErrorMsg {
message: msg,
};
serde_json::to_string(&ApiErrorResponse(response)).unwrap()
}
fn compact_api_error(_: &impl std::any::Any, msg: &str) -> String {
fn _compact_api_error(_: &impl std::any::Any, msg: &str) -> String {
let response = ApiErrorMsg {
message: msg,
};
@@ -301,20 +299,18 @@ fn compact_api_error(_: &impl std::any::Any, msg: &str) -> String {
//
use std::io::Cursor;
use rocket::{
http::{ContentType, Status},
request::Request,
response::{self, Responder, Response},
};
use rocket::http::{ContentType, Status};
use rocket::request::Request;
use rocket::response::{self, Responder, Response};
impl Responder<'_, 'static> for Error {
fn respond_to(self, _: &Request<'_>) -> response::Result<'static> {
match self.kind {
match self.error {
ErrorKind::Empty(_) | ErrorKind::Simple(_) | ErrorKind::Compact(_) => {} // Don't print the error in this situation
_ => error!(target: "error", "{self:#?}"),
}
};
let code = Status::from_code(self.code).unwrap_or(Status::BadRequest);
let code = Status::from_code(self.error_code).unwrap_or(Status::BadRequest);
let body = self.to_string();
Response::build().status(code).header(ContentType::JSON).sized_body(Some(body.len()), Cursor::new(body)).ok()
}
+57 -311
View File
@@ -1,25 +1,22 @@
use std::{
fmt,
net::{IpAddr, SocketAddr},
str::FromStr,
sync::{Arc, LazyLock, Mutex},
time::Duration,
};
use hickory_resolver::{TokioResolver, net::runtime::TokioRuntimeProvider};
use hickory_resolver::{name_server::TokioConnectionProvider, TokioResolver};
use regex::Regex;
use reqwest::{
Client, ClientBuilder,
dns::{Name, Resolve, Resolving},
header,
header, Client, ClientBuilder,
};
use url::Host;
use crate::{CONFIG, util::is_global};
use crate::{util::is_global, CONFIG};
pub fn make_http_request(method: reqwest::Method, url: &str) -> Result<reqwest::RequestBuilder, crate::Error> {
static INSTANCE: LazyLock<Client> =
LazyLock::new(|| get_reqwest_client_builder().build().expect("Failed to build client"));
let Ok(url) = url::Url::parse(url) else {
err!("Invalid URL");
};
@@ -29,6 +26,9 @@ pub fn make_http_request(method: reqwest::Method, url: &str) -> Result<reqwest::
should_block_host(&host)?;
static INSTANCE: LazyLock<Client> =
LazyLock::new(|| get_reqwest_client_builder().build().expect("Failed to build client"));
Ok(INSTANCE.request(method, url))
}
@@ -59,6 +59,16 @@ pub fn get_reqwest_client_builder() -> ClientBuilder {
.timeout(Duration::from_secs(10))
}
pub fn should_block_address(domain_or_ip: &str) -> bool {
if let Ok(ip) = IpAddr::from_str(domain_or_ip) {
if should_block_ip(ip) {
return true;
}
}
should_block_address_regex(domain_or_ip)
}
fn should_block_ip(ip: IpAddr) -> bool {
if !CONFIG.http_request_block_non_global_ips() {
return false;
@@ -68,19 +78,18 @@ fn should_block_ip(ip: IpAddr) -> bool {
}
fn should_block_address_regex(domain_or_ip: &str) -> bool {
static COMPILED_REGEX: Mutex<Option<(String, Regex)>> = Mutex::new(None);
let Some(block_regex) = CONFIG.http_request_block_regex() else {
return false;
};
static COMPILED_REGEX: Mutex<Option<(String, Regex)>> = Mutex::new(None);
let mut guard = COMPILED_REGEX.lock().unwrap();
// If the stored regex is up to date, use it
if let Some((value, regex)) = &*guard
&& value == &block_regex
{
return regex.is_match(domain_or_ip);
if let Some((value, regex)) = &*guard {
if value == &block_regex {
return regex.is_match(domain_or_ip);
}
}
// If we don't have a regex stored, or it's not up to date, recreate it
@@ -91,63 +100,20 @@ fn should_block_address_regex(domain_or_ip: &str) -> bool {
is_match
}
pub fn get_valid_host(host: &str) -> Result<Host, CustomHttpClientError> {
let Ok(host) = Host::parse(host) else {
return Err(CustomHttpClientError::Invalid {
domain: host.to_owned(),
});
};
// Some extra checks to validate hosts
match host {
Host::Domain(ref domain) => {
// Host::parse() does not verify length or all possible invalid characters
// We do some extra checks here to prevent issues
if domain.len() > 253 {
debug!("Domain validation error: '{domain}' exceeds 253 characters");
return Err(CustomHttpClientError::Invalid {
domain: host.to_string(),
});
}
if !domain.split('.').all(|label| {
!label.is_empty()
// Labels can't be longer than 63 chars
&& label.len() <= 63
// Labels are not allowed to start or end with a hyphen `-`
&& !label.starts_with('-')
&& !label.ends_with('-')
// Only ASCII Alphanumeric characters are allowed
// We already received a punycoded domain back, so no unicode should exists here
&& label.chars().all(|c| c.is_ascii_alphanumeric() || c == '-')
}) {
debug!(
"Domain validation error: '{domain}' labels contain invalid characters or exceed the maximum length"
);
return Err(CustomHttpClientError::Invalid {
domain: host.to_string(),
});
}
}
Host::Ipv4(_) | Host::Ipv6(_) => {}
}
Ok(host)
}
pub fn should_block_host<S: AsRef<str>>(host: &Host<S>) -> Result<(), CustomHttpClientError> {
fn should_block_host(host: &Host<&str>) -> Result<(), CustomHttpClientError> {
let (ip, host_str): (Option<IpAddr>, String) = match host {
Host::Ipv4(ip) => (Some(IpAddr::V4(*ip)), ip.to_string()),
Host::Ipv6(ip) => (Some(IpAddr::V6(*ip)), ip.to_string()),
Host::Domain(d) => (None, d.as_ref().to_owned()),
Host::Domain(d) => (None, (*d).to_string()),
};
if let Some(ip) = ip
&& should_block_ip(ip)
{
return Err(CustomHttpClientError::NonGlobalIp {
domain: None,
ip,
});
if let Some(ip) = ip {
if should_block_ip(ip) {
return Err(CustomHttpClientError::NonGlobalIp {
domain: None,
ip,
});
}
}
if should_block_address_regex(&host_str) {
@@ -168,9 +134,6 @@ pub enum CustomHttpClientError {
domain: Option<String>,
ip: IpAddr,
},
Invalid {
domain: String,
},
}
impl CustomHttpClientError {
@@ -192,7 +155,7 @@ impl fmt::Display for CustomHttpClientError {
match self {
Self::Blocked {
domain,
} => write!(f, "Blocked domain: '{domain}' matched HTTP_REQUEST_BLOCK_REGEX"),
} => write!(f, "Blocked domain: {domain} matched HTTP_REQUEST_BLOCK_REGEX"),
Self::NonGlobalIp {
domain: Some(domain),
ip,
@@ -200,10 +163,7 @@ impl fmt::Display for CustomHttpClientError {
Self::NonGlobalIp {
domain: None,
ip,
} => write!(f, "IP '{ip}' is not a global IP!"),
Self::Invalid {
domain,
} => write!(f, "Invalid host: '{domain}' contains invalid characters or exceeds the maximum length"),
} => write!(f, "IP {ip} is not a global IP!"),
}
}
}
@@ -224,47 +184,42 @@ impl CustomDnsResolver {
}
fn new() -> Arc<Self> {
TokioResolver::builder(TokioRuntimeProvider::default())
.and_then(|mut builder| {
// Hickory's default since v0.26 is `Ipv6AndIpv4`, which sorts IPv6 first
// This might cause issues on IPv4 only systems or containers
// Unless someone enabled DNS_PREFER_IPV6, use Ipv4AndIpv6, which returns IPv4 first which was our previous default
if !CONFIG.dns_prefer_ipv6() {
builder.options_mut().ip_strategy = hickory_resolver::config::LookupIpStrategy::Ipv4AndIpv6;
match TokioResolver::builder(TokioConnectionProvider::default()) {
Ok(mut builder) => {
if CONFIG.dns_prefer_ipv6() {
builder.options_mut().ip_strategy = hickory_resolver::config::LookupIpStrategy::Ipv6thenIpv4;
}
builder.build()
})
.inspect_err(|e| warn!("Error creating Hickory resolver, falling back to default: {e:?}"))
.map_or_else(|_| Arc::new(Self::Default()), |resolver| Arc::new(Self::Hickory(Arc::new(resolver))))
let resolver = builder.build();
Arc::new(Self::Hickory(Arc::new(resolver)))
}
Err(e) => {
warn!("Error creating Hickory resolver, falling back to default: {e:?}");
Arc::new(Self::Default())
}
}
}
// Note that we get an iterator of addresses, but we only grab the first one for convenience
async fn resolve_domain(&self, name: &str) -> Result<Vec<SocketAddr>, BoxError> {
async fn resolve_domain(&self, name: &str) -> Result<Option<SocketAddr>, BoxError> {
pre_resolve(name)?;
let results: Vec<SocketAddr> = match self {
Self::Default() => tokio::net::lookup_host((name, 0)).await?.collect(),
Self::Hickory(r) => r.lookup_ip(name).await?.iter().map(|i| SocketAddr::new(i, 0)).collect(),
let result = match self {
Self::Default() => tokio::net::lookup_host(name).await?.next(),
Self::Hickory(r) => r.lookup_ip(name).await?.iter().next().map(|a| SocketAddr::new(a, 0)),
};
for addr in &results {
if let Some(addr) = &result {
post_resolve(name, addr.ip())?;
}
Ok(results)
Ok(result)
}
}
fn pre_resolve(name: &str) -> Result<(), CustomHttpClientError> {
let Ok(host) = get_valid_host(name) else {
return Err(CustomHttpClientError::Invalid {
domain: name.to_owned(),
});
};
if should_block_host(&host).is_err() {
if should_block_address(name) {
return Err(CustomHttpClientError::Blocked {
domain: name.to_owned(),
domain: name.to_string(),
});
}
@@ -274,7 +229,7 @@ fn pre_resolve(name: &str) -> Result<(), CustomHttpClientError> {
fn post_resolve(name: &str, ip: IpAddr) -> Result<(), CustomHttpClientError> {
if should_block_ip(ip) {
Err(CustomHttpClientError::NonGlobalIp {
domain: Some(name.to_owned()),
domain: Some(name.to_string()),
ip,
})
} else {
@@ -287,11 +242,8 @@ impl Resolve for CustomDnsResolver {
let this = self.clone();
Box::pin(async move {
let name = name.as_str();
let results = this.resolve_domain(name).await?;
if results.is_empty() {
warn!("Unable to resolve {name} to any valid IP address");
}
Ok::<reqwest::dns::Addrs, _>(Box::new(results.into_iter()))
let result = this.resolve_domain(name).await?;
Ok::<reqwest::dns::Addrs, _>(Box::new(result.into_iter()))
})
}
}
@@ -319,7 +271,7 @@ pub(crate) mod aws {
let future = async move {
let method = reqwest::Method::from_bytes(request.method().as_bytes())
.map_err(|e| ConnectorError::user(Box::new(e)))?;
let mut req_builder = client.request(method, request.uri().to_owned());
let mut req_builder = client.request(method, request.uri().to_string());
for (name, value) in request.headers() {
req_builder = req_builder.header(name, value);
@@ -353,209 +305,3 @@ pub(crate) mod aws {
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::util::is_global_hardcoded;
use std::net::Ipv4Addr;
use url::Host;
// ===
// IPv4 numeric-format normalization
fn parse_to_ip(s: &str) -> Option<IpAddr> {
match Host::parse(s).ok()? {
Host::Ipv4(v4) => Some(IpAddr::V4(v4)),
Host::Ipv6(v6) => Some(IpAddr::V6(v6)),
Host::Domain(_) => None,
}
}
#[test]
fn dotted_decimal_loopback_normalizes() {
let ip = parse_to_ip("127.0.0.1").unwrap();
assert_eq!(ip, IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)));
assert!(!is_global_hardcoded(ip));
}
#[test]
fn single_decimal_loopback_normalizes() {
// 127.0.0.1 == 2130706433
let ip = parse_to_ip("2130706433").unwrap();
assert_eq!(ip, IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)));
assert!(!is_global_hardcoded(ip));
}
#[test]
fn hex_loopback_normalizes() {
let ip = parse_to_ip("0x7f000001").unwrap();
assert_eq!(ip, IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)));
assert!(!is_global_hardcoded(ip));
}
#[test]
fn dotted_hex_loopback_normalizes() {
let ip = parse_to_ip("0x7f.0.0.1").unwrap();
assert_eq!(ip, IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)));
assert!(!is_global_hardcoded(ip));
}
#[test]
fn octal_loopback_normalizes() {
// 017700000001 == 127.0.0.1
let ip = parse_to_ip("017700000001").unwrap();
assert_eq!(ip, IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)));
assert!(!is_global_hardcoded(ip));
}
#[test]
fn dotted_octal_loopback_normalizes() {
let ip = parse_to_ip("0177.0.0.01").unwrap();
assert_eq!(ip, IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)));
assert!(!is_global_hardcoded(ip));
}
#[test]
fn aws_metadata_decimal_blocked() {
// 169.254.169.254 == 2852039166 (link-local, AWS IMDS)
let ip = parse_to_ip("2852039166").unwrap();
assert_eq!(ip, IpAddr::V4(Ipv4Addr::new(169, 254, 169, 254)));
assert!(!is_global_hardcoded(ip));
}
#[test]
fn rfc1918_hex_blocked() {
// 10.0.0.1
let ip = parse_to_ip("0x0a000001").unwrap();
assert!(!is_global_hardcoded(ip));
}
#[test]
fn public_ip_decimal_allowed() {
// 8.8.8.8 == 134744072
let ip = parse_to_ip("134744072").unwrap();
assert_eq!(ip, IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)));
assert!(is_global_hardcoded(ip));
}
// ===
// get_valid_host integration: numeric forms become Host::Ipv4
#[test]
fn get_valid_host_normalizes_decimal_int() {
let h = get_valid_host("2130706433").expect("valid");
assert!(matches!(h, Host::Ipv4(ip) if ip == Ipv4Addr::new(127, 0, 0, 1)));
}
#[test]
fn get_valid_host_normalizes_hex() {
let h = get_valid_host("0x7f000001").expect("valid");
assert!(matches!(h, Host::Ipv4(ip) if ip == Ipv4Addr::new(127, 0, 0, 1)));
}
#[test]
fn get_valid_host_normalizes_octal() {
let h = get_valid_host("017700000001").expect("valid");
assert!(matches!(h, Host::Ipv4(ip) if ip == Ipv4Addr::new(127, 0, 0, 1)));
}
// ===
// IPv6 formats
#[test]
fn ipv6_loopback_blocked() {
let h = get_valid_host("[::1]").expect("valid");
let Host::Ipv6(ip) = h else {
panic!("expected v6")
};
assert!(!is_global_hardcoded(IpAddr::V6(ip)));
}
#[test]
fn ipv4_mapped_in_ipv6_loopback_blocked() {
// ::ffff:127.0.0.1 — v4-mapped form; is_global_hardcoded blocks via ::ffff:0:0/96
let h = get_valid_host("[::ffff:127.0.0.1]").expect("valid");
let Host::Ipv6(ip) = h else {
panic!("expected v6")
};
assert!(!is_global_hardcoded(IpAddr::V6(ip)));
}
#[test]
fn ipv6_unique_local_blocked() {
let h = get_valid_host("[fc00::1]").expect("valid");
let Host::Ipv6(ip) = h else {
panic!("expected v6")
};
assert!(!is_global_hardcoded(IpAddr::V6(ip)));
}
// ===
// Punycode / IDN
#[test]
fn punycode_passthrough() {
let h = get_valid_host("xn--deadbeafcaf-lbb.test").expect("valid");
match h {
Host::Domain(d) => assert_eq!(d, "xn--deadbeafcaf-lbb.test"),
_ => panic!("expected domain"),
}
}
#[test]
fn idn_unicode_gets_punycoded() {
let h = get_valid_host("deadbeafcafé.test").expect("valid");
match h {
Host::Domain(d) => assert_eq!(d, "xn--deadbeafcaf-lbb.test"),
_ => panic!("expected domain"),
}
}
#[test]
fn idn_unicode_gets_punycoded_tld() {
let h = get_valid_host("deadbeaf.café").expect("valid");
match h {
Host::Domain(d) => assert_eq!(d, "deadbeaf.xn--caf-dma"),
_ => panic!("expected domain"),
}
}
#[test]
fn idn_emoji_gets_punycoded() {
let h = get_valid_host("xn--t88h.test").expect("valid"); // 🛡️.test
match h {
Host::Domain(d) => assert_eq!(d, "xn--t88h.test"),
_ => panic!("expected domain"),
}
}
#[test]
fn idn_unicode_to_punycode_roundtrip() {
let from_unicode = get_valid_host("🛡️.test").expect("valid");
let from_puny = get_valid_host("xn--t88h.test").expect("valid");
match (from_unicode, from_puny) {
(Host::Domain(a), Host::Domain(b)) => assert_eq!(a, b),
_ => panic!("expected domains"),
}
}
#[test]
fn invalid_punycode_rejected() {
// bare invalid punycode
assert!(get_valid_host("xn--").is_err());
}
#[test]
fn underscore_in_label_rejected() {
assert!(get_valid_host("dead_beaf.cafe").is_err());
}
#[test]
fn label_too_long_rejected() {
let label = "a".repeat(64);
assert!(get_valid_host(&format!("{label}.test")).is_err());
}
#[test]
fn domain_too_long_rejected() {
let big = "a.".repeat(130) + "test"; // > 253
assert!(get_valid_host(&big).is_err());
}
}
+30 -26
View File
@@ -1,17 +1,16 @@
use chrono::NaiveDateTime;
use percent_encoding::{percent_encode, NON_ALPHANUMERIC};
use std::{env::consts::EXE_SUFFIX, str::FromStr};
use chrono::NaiveDateTime;
use lettre::{
Address, AsyncSendmailTransport, AsyncSmtpTransport, AsyncTransport, Tokio1Executor,
message::{Attachment, Body, Mailbox, Message, MultiPart, SinglePart},
transport::smtp::authentication::{Credentials, Mechanism as SmtpAuthMechanism},
transport::smtp::client::{Tls, TlsParameters},
transport::smtp::extension::ClientId,
Address, AsyncSendmailTransport, AsyncSmtpTransport, AsyncTransport, Tokio1Executor,
};
use percent_encoding::{NON_ALPHANUMERIC, percent_encode};
use crate::{
CONFIG,
api::EmptyResult,
auth::{
encode_jwt, generate_delete_claims, generate_emergency_access_invite_claims, generate_invite_claims,
@@ -19,7 +18,7 @@ use crate::{
},
db::models::{Device, DeviceType, EmergencyAccessId, MembershipId, OrganizationId, User, UserId},
error::Error,
util::upcase_first,
CONFIG,
};
fn sendmail_transport() -> AsyncSendmailTransport<Tokio1Executor> {
@@ -39,9 +38,7 @@ fn smtp_transport() -> AsyncSmtpTransport<Tokio1Executor> {
.timeout(Some(Duration::from_secs(CONFIG.smtp_timeout())));
// Determine security
let smtp_client = if CONFIG.smtp_security() == *"off" {
smtp_client
} else {
let smtp_client = if CONFIG.smtp_security() != *"off" {
let mut tls_parameters = TlsParameters::builder(host);
if CONFIG.smtp_accept_invalid_hostnames() {
tls_parameters = tls_parameters.dangerous_accept_invalid_hostnames(true);
@@ -56,6 +53,8 @@ fn smtp_transport() -> AsyncSmtpTransport<Tokio1Executor> {
} else {
smtp_client.tls(Tls::Required(tls_parameters))
}
} else {
smtp_client
};
let smtp_client = match (CONFIG.smtp_username(), CONFIG.smtp_password()) {
@@ -82,12 +81,12 @@ fn smtp_transport() -> AsyncSmtpTransport<Tokio1Executor> {
}
}
if selected_mechanisms.is_empty() {
if !selected_mechanisms.is_empty() {
smtp_client.authentication(selected_mechanisms)
} else {
// Only show a warning, and return without setting an actual authentication mechanism
warn!("No valid SMTP Auth mechanism found for '{mechanism}', using default values");
smtp_client
} else {
smtp_client.authentication(selected_mechanisms)
}
}
_ => smtp_client,
@@ -130,16 +129,14 @@ fn get_template(template_name: &str, data: &serde_json::Value) -> Result<(String
let text = CONFIG.render_template(template_name, data)?;
let mut text_split = text.split("<!---------------->");
let subject = if let Some(s) = text_split.next() {
s.trim().to_owned()
} else {
err!("Template doesn't contain subject")
let subject = match text_split.next() {
Some(s) => s.trim().to_string(),
None => err!("Template doesn't contain subject"),
};
let body = if let Some(s) = text_split.next() {
s.trim().to_owned()
} else {
err!("Template doesn't contain body")
let body = match text_split.next() {
Some(s) => s.trim().to_string(),
None => err!("Template doesn't contain body"),
};
if text_split.next().is_some() {
@@ -207,8 +204,9 @@ pub async fn send_verify_email(address: &str, user_id: &UserId) -> EmptyResult {
pub async fn send_register_verify_email(email: &str, token: &str) -> EmptyResult {
let mut query = url::Url::parse("https://query.builder").unwrap();
query.query_pairs_mut().append_pair("email", email).append_pair("token", token);
let Some(query_string) = query.query() else {
err!("Failed to build verify URL query parameters")
let query_string = match query.query() {
None => err!("Failed to build verify URL query parameters"),
Some(query) => query,
};
let (subject, body_html, body_text) = get_text(
@@ -506,6 +504,8 @@ pub async fn send_invite_confirmed(address: &str, org_name: &str) -> EmptyResult
}
pub async fn send_new_device_logged_in(address: &str, ip: &str, dt: &NaiveDateTime, device: &Device) -> EmptyResult {
use crate::util::upcase_first;
let fmt = "%A, %B %_d, %Y at %r %Z";
let (subject, body_html, body_text) = get_text(
"email/new_device_logged_in",
@@ -529,6 +529,8 @@ pub async fn send_incomplete_2fa_login(
device_name: &str,
device_type: &str,
) -> EmptyResult {
use crate::util::upcase_first;
let fmt = "%A, %B %_d, %Y at %r %Z";
let (subject, body_html, body_text) = get_text(
"email/incomplete_2fa_login",
@@ -653,7 +655,7 @@ pub async fn send_protected_action_token(address: &str, token: &str) -> EmptyRes
async fn send_with_selected_transport(email: Message) -> EmptyResult {
if CONFIG.use_sendmail() {
match sendmail_transport().send(email).await {
Ok(()) => Ok(()),
Ok(_) => Ok(()),
// Match some common errors and make them more user friendly
Err(e) => {
if e.is_client() {
@@ -662,9 +664,10 @@ async fn send_with_selected_transport(email: Message) -> EmptyResult {
} else if e.is_response() {
debug!("Sendmail response error: {e:?}");
err!(format!("Sendmail response error: {e}"));
} else {
debug!("Sendmail error: {e:?}");
err!(format!("Sendmail error: {e}"));
}
debug!("Sendmail error: {e:?}");
err!(format!("Sendmail error: {e}"));
}
}
} else {
@@ -692,9 +695,10 @@ async fn send_with_selected_transport(email: Message) -> EmptyResult {
} else if e.is_tls() {
debug!("SMTP encryption error: {e:#?}");
err!(format!("SMTP encryption error: {e}"));
} else {
debug!("SMTP error: {e:#?}");
err!(format!("SMTP error: {e}"));
}
debug!("SMTP error: {e:#?}");
err!(format!("SMTP error: {e}"));
}
}
}
+31 -39
View File
@@ -33,7 +33,6 @@ use std::{
path::Path,
process::exit,
str::FromStr,
sync::{Arc, atomic::Ordering},
thread,
};
@@ -45,8 +44,6 @@ use tokio::{
#[cfg(unix)]
use tokio::signal::unix::SignalKind;
use rocket::data::{Limits, ToByteUnit};
#[macro_use]
mod error;
mod api;
@@ -60,19 +57,19 @@ mod mail;
mod ratelimit;
mod sso;
mod sso_client;
mod storage;
mod util;
use crate::api::{
WS_ANONYMOUS_SUBSCRIPTIONS, WS_USERS, core::two_factor::duo_oidc::purge_duo_contexts, purge_auth_requests,
};
pub use config::{CONFIG, PathType};
use crate::api::core::two_factor::duo_oidc::purge_duo_contexts;
use crate::api::purge_auth_requests;
use crate::api::{WS_ANONYMOUS_SUBSCRIPTIONS, WS_USERS};
pub use config::{PathType, CONFIG};
pub use error::{Error, MapResult};
use rocket::data::{Limits, ToByteUnit};
use std::sync::{atomic::Ordering, Arc};
pub use util::is_running_in_container;
#[rocket::main]
async fn main() -> Result<(), Error> {
install_rustls_crypto_provider();
parse_args();
launch_info();
@@ -138,23 +135,26 @@ fn parse_args() {
if let Some(command) = pargs.subcommand().unwrap_or_default() {
if command == "hash" {
use argon2::{
Algorithm::Argon2id, Argon2, ParamsBuilder, PasswordHasher, Version::V0x13, password_hash::SaltString,
password_hash::SaltString, Algorithm::Argon2id, Argon2, ParamsBuilder, PasswordHasher, Version::V0x13,
};
let mut argon2_params = ParamsBuilder::new();
let preset: Option<String> = pargs.opt_value_from_str(["-p", "--preset"]).unwrap_or_default();
let selected_preset;
if preset.as_deref() == Some("owasp") {
selected_preset = "owasp";
argon2_params.m_cost(19456);
argon2_params.t_cost(2);
argon2_params.p_cost(1);
} else {
// Bitwarden preset is the default
selected_preset = "bitwarden";
argon2_params.m_cost(65540);
argon2_params.t_cost(3);
argon2_params.p_cost(4);
match preset.as_deref() {
Some("owasp") => {
selected_preset = "owasp";
argon2_params.m_cost(19456);
argon2_params.t_cost(2);
argon2_params.p_cost(1);
}
_ => {
// Bitwarden preset is the default
selected_preset = "bitwarden";
argon2_params.m_cost(65540);
argon2_params.t_cost(3);
argon2_params.p_cost(4);
}
}
println!("Generate an Argon2id PHC string using the '{selected_preset}' preset:\n");
@@ -202,14 +202,6 @@ fn parse_args() {
}
}
fn install_rustls_crypto_provider() {
if rustls::crypto::CryptoProvider::get_default().is_none() {
rustls::crypto::ring::default_provider()
.install_default()
.expect("failed to install rustls ring crypto provider");
}
}
fn launch_info() {
println!(
"\
@@ -245,7 +237,7 @@ fn init_logging() -> Result<log::LevelFilter, Error> {
let level = caps
.get(1)
.and_then(|m| log::LevelFilter::from_str(m.as_str()).ok())
.ok_or(Error::new("Failed to parse global log level".to_owned(), ""))?;
.ok_or(Error::new("Failed to parse global log level".to_string(), ""))?;
let levels_override: Vec<(&str, log::LevelFilter)> = caps
.get(2)
@@ -254,13 +246,13 @@ fn init_logging() -> Result<log::LevelFilter, Error> {
.split(',')
.collect::<Vec<&str>>()
.into_iter()
.filter_map(|s| match s.split_once('=') {
.flat_map(|s| match s.split_once('=') {
Some((log, lvl_str)) => log::LevelFilter::from_str(lvl_str).ok().map(|lvl| (log, lvl)),
_ => None,
})
.collect()
})
.ok_or(Error::new("Failed to parse overrides".to_owned(), ""))?;
.ok_or(Error::new("Failed to parse overrides".to_string(), ""))?;
(level, levels_override)
} else {
@@ -336,7 +328,7 @@ fn init_logging() -> Result<log::LevelFilter, Error> {
("vaultwarden::db::query_logger", log::LevelFilter::Off),
]);
for (path, level) in levels_override {
for (path, level) in levels_override.into_iter() {
let _ = default_levels.insert(path, level);
}
@@ -350,7 +342,7 @@ fn init_logging() -> Result<log::LevelFilter, Error> {
let mut logger = fern::Dispatch::new().level(level).chain(std::io::stdout());
for (path, level) in default_levels {
logger = logger.level_for(path.to_owned(), level);
logger = logger.level_for(path.to_string(), level);
}
if CONFIG.extended_logging() {
@@ -361,7 +353,7 @@ fn init_logging() -> Result<log::LevelFilter, Error> {
record.target(),
record.level(),
message
));
))
});
} else {
logger = logger.format(|out, message, _| out.finish(format_args!("{message}")));
@@ -607,7 +599,9 @@ async fn launch_rocket(pool: db::DbPool, extra_debug: bool) -> Result<(), Error>
#[cfg(all(unix, sqlite))]
{
if db::ACTIVE_DB_TYPE.get() == Some(&db::DbConnType::Sqlite) {
if db::ACTIVE_DB_TYPE.get() != Some(&db::DbConnType::Sqlite) {
debug!("PostgreSQL and MySQL/MariaDB do not support this backup feature, skip adding USR1 signal.");
} else {
tokio::spawn(async move {
let mut signal_user1 = tokio::signal::unix::signal(SignalKind::user_defined1()).unwrap();
loop {
@@ -620,8 +614,6 @@ async fn launch_rocket(pool: db::DbPool, extra_debug: bool) -> Result<(), Error>
}
}
});
} else {
debug!("PostgreSQL and MySQL/MariaDB do not support this backup feature, skip adding USR1 signal.");
}
}
@@ -669,7 +661,7 @@ fn schedule_jobs(pool: db::DbPool) {
let runtime = tokio::runtime::Runtime::new().unwrap();
thread::Builder::new()
.name("job-scheduler".to_owned())
.name("job-scheduler".to_string())
.spawn(move || {
use job_scheduler_ng::{Job, JobScheduler};
let _runtime_guard = runtime.enter();
+4 -4
View File
@@ -1,8 +1,8 @@
use std::{net::IpAddr, num::NonZeroU32, sync::LazyLock, time::Duration};
use governor::{Quota, RateLimiter, clock::DefaultClock, state::keyed::DashMapStateStore};
use governor::{clock::DefaultClock, state::keyed::DashMapStateStore, Quota, RateLimiter};
use crate::{CONFIG, Error};
use crate::{Error, CONFIG};
type Limiter<T = IpAddr> = RateLimiter<T, DashMapStateStore<T>, DefaultClock>;
@@ -20,7 +20,7 @@ static LIMITER_ADMIN: LazyLock<Limiter> = LazyLock::new(|| {
pub fn check_limit_login(ip: &IpAddr) -> Result<(), Error> {
match LIMITER_LOGIN.check_key(ip) {
Ok(()) => Ok(()),
Ok(_) => Ok(()),
Err(_e) => {
err_code!("Too many login requests", 429);
}
@@ -29,7 +29,7 @@ pub fn check_limit_login(ip: &IpAddr) -> Result<(), Error> {
pub fn check_limit_admin(ip: &IpAddr) -> Result<(), Error> {
match LIMITER_ADMIN.check_key(ip) {
Ok(()) => Ok(()),
Ok(_) => Ok(()),
Err(_e) => {
err_code!("Too many admin requests", 429);
}
+43 -48
View File
@@ -6,18 +6,18 @@ use regex::Regex;
use url::Url;
use crate::{
CONFIG,
api::ApiResult,
auth,
auth::{AuthMethod, AuthTokens, BW_EXPIRATION, DEFAULT_REFRESH_VALIDITY, TokenWrapper},
auth::{AuthMethod, AuthTokens, TokenWrapper, BW_EXPIRATION, DEFAULT_REFRESH_VALIDITY},
db::{
models::{Device, OIDCAuthenticatedUser, OIDCCodeWrapper, SsoAuth, SsoUser, User},
DbConn,
models::{Device, OIDCAuthenticatedUser, SsoAuth, SsoUser, User},
},
sso_client::Client,
CONFIG,
};
pub static FAKE_SSO_IDENTIFIER: &str = "00000000-01DC-01DC-01DC-000000000000";
pub static FAKE_IDENTIFIER: &str = "VW_DUMMY_IDENTIFIER_FOR_OIDC";
static SSO_JWT_ISSUER: LazyLock<String> = LazyLock::new(|| format!("{}|sso", CONFIG.domain_origin()));
@@ -123,7 +123,7 @@ pub fn encode_ssotoken_claims() -> String {
nbf: time_now.timestamp(),
exp: (time_now + chrono::TimeDelta::try_minutes(2).unwrap()).timestamp(),
iss: SSO_JWT_ISSUER.to_string(),
sub: "vaultwarden".to_owned(),
sub: "vaultwarden".to_string(),
};
auth::encode_jwt(&claims)
@@ -171,14 +171,12 @@ fn decode_token_claims(token_name: &str, token: &str) -> ApiResult<BasicTokenCla
}
pub fn decode_state(base64_state: &str) -> ApiResult<OIDCState> {
let state = if let Ok(vec) = data_encoding::BASE64.decode(base64_state.as_bytes()) {
if let Ok(valid) = String::from_utf8(vec) {
OIDCState(valid)
} else {
err!(format!("Invalid utf8 chars in {base64_state} after base64 decoding"))
}
} else {
err!(format!("Failed to decode {base64_state} using base64"))
let state = match data_encoding::BASE64.decode(base64_state.as_bytes()) {
Ok(vec) => match String::from_utf8(vec) {
Ok(valid) => OIDCState(valid),
Err(_) => err!(format!("Invalid utf8 chars in {base64_state} after base64 decoding")),
},
Err(_) => err!(format!("Failed to decode {base64_state} using base64")),
};
Ok(state)
@@ -190,26 +188,22 @@ pub async fn authorize_url(
client_challenge: OIDCCodeChallenge,
client_id: &str,
raw_redirect_uri: &str,
binding_hash: Option<String>,
conn: DbConn,
) -> ApiResult<Url> {
let redirect_uri = match client_id {
"web" | "browser" => format!("{}/sso-connector.html", CONFIG.domain()),
"desktop" | "mobile" => "bitwarden://sso-callback".to_owned(),
"desktop" | "mobile" => "bitwarden://sso-callback".to_string(),
"cli" => {
let port_regex = Regex::new(r"^http://localhost:([0-9]{4})$").unwrap();
if let Some(port) =
port_regex.captures(raw_redirect_uri).and_then(|captures| captures.get(1).map(|c| c.as_str()))
{
format!("http://localhost:{port}")
} else {
err!("Failed to extract port number")
match port_regex.captures(raw_redirect_uri).and_then(|captures| captures.get(1).map(|c| c.as_str())) {
Some(port) => format!("http://localhost:{port}"),
None => err!("Failed to extract port number"),
}
}
_ => err!(format!("Unsupported client {client_id}")),
};
let (auth_url, sso_auth) = Client::authorize_url(state, client_challenge, redirect_uri, binding_hash).await?;
let (auth_url, sso_auth) = Client::authorize_url(state, client_challenge, redirect_uri).await?;
sso_auth.save(&conn).await?;
Ok(auth_url)
}
@@ -245,32 +239,33 @@ impl OIDCIdentifier {
// - second time we will rely on `SsoAuth.auth_response` since the `code` has already been exchanged.
// The `SsoAuth` will ensure that the user is authorized only once.
pub async fn exchange_code(
code: &OIDCCode,
state: &OIDCState,
client_verifier: OIDCCodeVerifier,
conn: &DbConn,
) -> ApiResult<(SsoAuth, OIDCAuthenticatedUser)> {
use openidconnect::OAuth2TokenResponse;
let Some(mut sso_auth) = SsoAuth::find_by_code(code, conn).await else {
err!("Invalid code cannot retrieve sso auth")
let mut sso_auth = match SsoAuth::find(state, conn).await {
None => err!(format!("Invalid state cannot retrieve sso auth")),
Some(sso_auth) => sso_auth,
};
if let Some(authenticated_user) = sso_auth.auth_response.clone() {
return Ok((sso_auth, authenticated_user));
}
let code = match (sso_auth.code_response.clone(), sso_auth.code_response_error.as_ref()) {
(Some(code), None) => code,
(_, Some(re)) => {
let error_msg = format!(
"SSO authorization failed: {}, {}",
re.error,
re.error_description.as_ref().unwrap_or(&String::new())
);
let code = match sso_auth.code_response.clone() {
Some(OIDCCodeWrapper::Ok {
code,
}) => code.clone(),
Some(OIDCCodeWrapper::Error {
error,
error_description,
}) => {
sso_auth.delete(conn).await?;
err!(error_msg);
err!(format!("SSO authorization failed: {error}, {}", error_description.as_ref().unwrap_or(&String::new())))
}
(None, _) => {
None => {
sso_auth.delete(conn).await?;
err!("Missing authorization provider return");
}
@@ -288,10 +283,10 @@ pub async fn exchange_code(
let email_verified = id_claims.email_verified().or(user_info.email_verified());
let user_name = id_claims.preferred_username().or(user_info.preferred_username()).map(|un| un.to_string());
let user_name = id_claims.preferred_username().map(|un| un.to_string());
let refresh_token = token_response.refresh_token().map(openidconnect::RefreshToken::secret);
if refresh_token.is_none() && CONFIG.sso_scopes_vec().contains(&"offline_access".to_owned()) {
let refresh_token = token_response.refresh_token().map(|t| t.secret());
if refresh_token.is_none() && CONFIG.sso_scopes_vec().contains(&"offline_access".to_string()) {
error!("Scope offline_access is present but response contain no refresh_token");
}
@@ -335,9 +330,7 @@ pub async fn redeem(
user_sso.save(conn).await?;
}
if CONFIG.sso_auth_only_not_session() {
Ok(AuthTokens::new(device, user, AuthMethod::Sso, client_id))
} else {
if !CONFIG.sso_auth_only_not_session() {
let now = Utc::now();
let (ap_nbf, ap_exp) =
@@ -350,7 +343,9 @@ pub async fn redeem(
let access_claims =
auth::LoginJwtClaims::new(device, user, ap_nbf, ap_exp, AuthMethod::Sso.scope_vec(), client_id, now);
create_auth_tokens_impl(device, auth_user.refresh_token, access_claims, auth_user.access_token)
_create_auth_tokens(device, auth_user.refresh_token, access_claims, auth_user.access_token)
} else {
Ok(AuthTokens::new(device, user, AuthMethod::Sso, client_id))
}
}
@@ -364,9 +359,7 @@ pub fn create_auth_tokens(
access_token: String,
expires_in: Option<Duration>,
) -> ApiResult<AuthTokens> {
if CONFIG.sso_auth_only_not_session() {
Ok(AuthTokens::new(device, user, AuthMethod::Sso, client_id))
} else {
if !CONFIG.sso_auth_only_not_session() {
let now = Utc::now();
let (ap_nbf, ap_exp) = match (decode_token_claims("access_token", &access_token), expires_in) {
@@ -378,11 +371,13 @@ pub fn create_auth_tokens(
let access_claims =
auth::LoginJwtClaims::new(device, user, ap_nbf, ap_exp, AuthMethod::Sso.scope_vec(), client_id, now);
create_auth_tokens_impl(device, refresh_token, access_claims, access_token)
_create_auth_tokens(device, refresh_token, access_claims, access_token)
} else {
Ok(AuthTokens::new(device, user, AuthMethod::Sso, client_id))
}
}
fn create_auth_tokens_impl(
fn _create_auth_tokens(
device: &Device,
refresh_token: Option<String>,
access_claims: auth::LoginJwtClaims,
@@ -466,7 +461,7 @@ pub async fn exchange_refresh_token(
now,
);
create_auth_tokens_impl(device, None, access_claims, access_token)
_create_auth_tokens(device, None, access_claims, access_token)
}
None => err!("No token present while in SSO"),
}
+22 -70
View File
@@ -1,31 +1,17 @@
use std::{borrow::Cow, future::Future, pin::Pin, sync::LazyLock, time::Duration};
use std::{borrow::Cow, sync::LazyLock, time::Duration};
use openidconnect::{
AccessToken, AsyncHttpClient, AuthDisplay, AuthPrompt, AuthenticationFlow, AuthorizationCode, AuthorizationRequest,
ClientId, ClientSecret, CsrfToken, EmptyAdditionalClaims, EmptyExtraTokenFields, EndpointNotSet, EndpointSet,
HttpClientError, HttpRequest, HttpResponse, IdTokenClaims, IdTokenFields, Nonce, OAuth2TokenResponse,
PkceCodeChallenge, PkceCodeVerifier, RefreshToken, ResponseType, Scope, StandardErrorResponse,
StandardTokenResponse,
core::{
CoreAuthDisplay, CoreAuthPrompt, CoreClient, CoreErrorResponseType, CoreGenderClaim, CoreIdTokenVerifier,
CoreJsonWebKey, CoreJweContentEncryptionAlgorithm, CoreJwsSigningAlgorithm, CoreProviderMetadata,
CoreResponseType, CoreRevocableToken, CoreRevocationErrorResponse, CoreTokenIntrospectionResponse,
CoreTokenResponse, CoreTokenType, CoreUserInfoClaims,
},
http, url,
};
use openidconnect::{core::*, reqwest, *};
use regex::Regex;
use url::Url;
use crate::{
CONFIG,
api::{ApiResult, EmptyResult},
db::models::SsoAuth,
http_client::get_reqwest_client_builder,
sso::{OIDCCode, OIDCCodeChallenge, OIDCCodeVerifier, OIDCState},
CONFIG,
};
static CLIENT_CACHE_KEY: LazyLock<String> = LazyLock::new(|| "sso-client".to_owned());
static CLIENT_CACHE_KEY: LazyLock<String> = LazyLock::new(|| "sso-client".to_string());
static CLIENT_CACHE: LazyLock<moka::sync::Cache<String, Client>> = LazyLock::new(|| {
moka::sync::Cache::builder()
.max_capacity(1)
@@ -60,51 +46,19 @@ pub type RefreshTokenResponse = (Option<String>, String, Option<Duration>);
#[derive(Clone)]
pub struct Client {
pub http_client: OidcHttpClient,
pub http_client: reqwest::Client,
pub core_client: CustomClient,
}
#[derive(Clone)]
pub struct OidcHttpClient {
client: reqwest::Client,
}
impl OidcHttpClient {
fn new() -> Result<Self, reqwest::Error> {
get_reqwest_client_builder().redirect(reqwest::redirect::Policy::none()).build().map(|client| Self {
client,
})
}
}
impl<'c> AsyncHttpClient<'c> for OidcHttpClient {
type Error = HttpClientError<reqwest::Error>;
type Future = Pin<Box<dyn Future<Output = Result<HttpResponse, Self::Error>> + Send + Sync + 'c>>;
fn call(&'c self, request: HttpRequest) -> Self::Future {
Box::pin(async move {
let response = self.client.execute(request.try_into().map_err(Box::new)?).await.map_err(Box::new)?;
let mut builder = http::Response::builder().status(response.status()).version(response.version());
for (name, value) in response.headers() {
builder = builder.header(name, value);
}
builder.body(response.bytes().await.map_err(Box::new)?.to_vec()).map_err(HttpClientError::Http)
})
}
}
impl Client {
// Call the OpenId discovery endpoint to retrieve configuration
async fn get_client() -> ApiResult<Self> {
async fn _get_client() -> ApiResult<Self> {
let client_id = ClientId::new(CONFIG.sso_client_id());
let client_secret = ClientSecret::new(CONFIG.sso_client_secret());
let issuer_url = CONFIG.sso_issuer_url()?;
let http_client = match OidcHttpClient::new() {
let http_client = match reqwest::ClientBuilder::new().redirect(reqwest::redirect::Policy::none()).build() {
Err(err) => err!(format!("Failed to build http client: {err}")),
Ok(client) => client,
};
@@ -116,16 +70,14 @@ impl Client {
let base_client = CoreClient::from_provider_metadata(provider_metadata, client_id, Some(client_secret));
let token_uri = if let Some(uri) = base_client.token_uri() {
uri.clone()
} else {
err!("Failed to discover token_url, cannot proceed")
let token_uri = match base_client.token_uri() {
Some(uri) => uri.clone(),
None => err!("Failed to discover token_url, cannot proceed"),
};
let user_info_url = if let Some(url) = base_client.user_info_url() {
url.clone()
} else {
err!("Failed to discover user_info url, cannot proceed")
let user_info_url = match base_client.user_info_url() {
Some(url) => url.clone(),
None => err!("Failed to discover user_info url, cannot proceed"),
};
let core_client = base_client
@@ -144,13 +96,13 @@ impl Client {
if CONFIG.sso_client_cache_expiration() > 0 {
match CLIENT_CACHE.get(&*CLIENT_CACHE_KEY) {
Some(client) => Ok(client),
None => Self::get_client().await.inspect(|client| {
None => Self::_get_client().await.inspect(|client| {
debug!("Inserting new client in cache");
CLIENT_CACHE.insert(CLIENT_CACHE_KEY.clone(), client.clone());
}),
}
} else {
Self::get_client().await
Self::_get_client().await
}
}
@@ -165,7 +117,6 @@ impl Client {
state: OIDCState,
client_challenge: OIDCCodeChallenge,
redirect_uri: String,
binding_hash: Option<String>,
) -> ApiResult<(Url, SsoAuth)> {
let scopes = CONFIG.sso_scopes_vec().into_iter().map(Scope::new);
let base64_state = data_encoding::BASE64.encode(state.to_string().as_bytes());
@@ -188,7 +139,7 @@ impl Client {
}
let (auth_url, _, nonce) = auth_req.url();
Ok((auth_url, SsoAuth::new(state, client_challenge, nonce.secret().clone(), redirect_uri, binding_hash)))
Ok((auth_url, SsoAuth::new(state, client_challenge, nonce.secret().clone(), redirect_uri)))
}
pub async fn exchange_code(
@@ -229,14 +180,15 @@ impl Client {
Ok(token_response) => {
let oidc_nonce = Nonce::new(sso_auth.nonce.clone());
let Some(id_token) = token_response.extra_fields().id_token() else {
err!("Token response did not contain an id_token")
let id_token = match token_response.extra_fields().id_token() {
None => err!("Token response did not contain an id_token"),
Some(token) => token,
};
if CONFIG.sso_debug_tokens() {
debug!("Id token: {}", id_token.to_string());
debug!("Access token: {}", token_response.access_token().secret());
debug!("Refresh token: {:?}", token_response.refresh_token().map(RefreshToken::secret));
debug!("Refresh token: {:?}", token_response.refresh_token().map(|t| t.secret()));
debug!("Expiration time: {:?}", token_response.expires_in());
}
@@ -289,12 +241,12 @@ impl Client {
let client = Client::cached().await?;
REFRESH_CACHE
.get_with(refresh_token.clone(), async move { client.exchange_refresh_token_impl(refresh_token).await })
.get_with(refresh_token.clone(), async move { client._exchange_refresh_token(refresh_token).await })
.await
.map_err(Into::into)
}
async fn exchange_refresh_token_impl(&self, refresh_token: String) -> Result<RefreshTokenResponse, String> {
async fn _exchange_refresh_token(&self, refresh_token: String) -> Result<RefreshTokenResponse, String> {
let rt = RefreshToken::new(refresh_token);
match self.core_client.exchange_refresh_token(&rt).request_async(&self.http_client).await {
+2 -13
View File
@@ -1,17 +1,6 @@
body {
padding-top: 75px;
}
/* Some extra width's for the main layout */
@media (min-width: 1600px) {
.container-xxl {
max-width: 1520px;
}
}
@media (min-width: 1800px) {
.container-xxl {
max-width: 1720px;
}
}
img {
width: 48px;
height: 48px;
@@ -49,8 +38,8 @@ img {
max-width: 130px;
}
#users-table .vw-actions, #orgs-table .vw-actions {
min-width: 170px;
max-width: 180px;
min-width: 155px;
max-width: 160px;
}
#users-table .vw-org-cell {
max-height: 120px;
+2 -2
View File
@@ -4,10 +4,10 @@
*
* To rebuild or modify this file with the latest versions of the included
* software please visit:
* https://datatables.net/download/#bs5/dt-2.3.8
* https://datatables.net/download/#bs5/dt-2.3.7
*
* Included libraries:
* DataTables 2.3.8
* DataTables 2.3.7
*/
:root {
+11 -41
View File
@@ -4,13 +4,13 @@
*
* To rebuild or modify this file with the latest versions of the included
* software please visit:
* https://datatables.net/download/#bs5/dt-2.3.8
* https://datatables.net/download/#bs5/dt-2.3.7
*
* Included libraries:
* DataTables 2.3.8
* DataTables 2.3.7
*/
/*! DataTables 2.3.8
/*! DataTables 2.3.7
* © SpryMedia Ltd - datatables.net/license
*/
@@ -525,7 +525,7 @@
*
* @type string
*/
builder: "bs5/dt-2.3.8",
builder: "bs5/dt-2.3.7",
/**
* Buttons. For use with the Buttons extension for DataTables. This is
@@ -3607,11 +3607,6 @@
if ( holdPosition !== true ) {
settings._iDisplayStart = 0;
}
else {
// Keep position, but make sure that there is actually data to display,
// otherwise we need to rewind a bit (e.g. if rows were deleted)
_fnLengthOverflow(settings);
}
// Let any modules know about the draw hold position state (used by
// scrolling internally)
@@ -4925,12 +4920,6 @@
var args = [settings, settings.json];
// If the footer element is empty after initialisation, then remove it
let tfoot = $(settings.tfoot);
if (tfoot.children().length === 0) {
tfoot.remove();
}
settings._bInitComplete = true;
// Table is fully set up and we have data, so calculate the
@@ -5387,12 +5376,12 @@
// the content of the cell so that the width applied to the header and body
// both match, but we want to hide it completely.
$('th, td', headerCopy).each(function () {
$(this.childNodes).wrapAll('<div class="dt-scroll-sizing" />');
$(this.childNodes).wrapAll('<div class="dt-scroll-sizing">');
});
if ( footer ) {
$('th, td', footerCopy).each(function () {
$(this.childNodes).wrapAll('<div class="dt-scroll-sizing" />');
$(this.childNodes).wrapAll('<div class="dt-scroll-sizing">');
});
}
@@ -5420,10 +5409,6 @@
// Correct DOM ordering for colgroup - comes before the thead
table.children('colgroup').prependTo(table);
// Remove tabindex from the hidden row elements
table.find('thead, tfoot').find('[tabindex]').removeAttr('tabindex');
table.find('thead, tfoot').find('role').removeAttr('role');
// Adjust the position of the header in case we loose the y-scrollbar
divBody.trigger('scroll');
@@ -5747,12 +5732,8 @@
.replace(/id=".*?"/g, '')
.replace(/name=".*?"/g, '');
// Don't want script, dialog or template tags in the width
// calculations as they are hidden content
cellString = cellString
.replace(/<script[\s\S]*?<\/script>/gi, ' ')
.replace(/<dialog[\s\S]*?<\/dialog>/gi, ' ')
.replace(/<template[\s\S]*?<\/template>/gi, ' ');
// Don't want Javascript at all in these calculation cells.
cellString = cellString.replace(/<script.*?<\/script>/gi, ' ');
var noHtml = _stripHtml(cellString, ' ')
.replace( /&nbsp;/g, ' ' );
@@ -10323,7 +10304,7 @@
* @type string
* @default Version number
*/
DataTable.version = "2.3.8";
DataTable.version = "2.3.7";
/**
* Private data store, containing all of the settings objects that are
@@ -12605,7 +12586,6 @@
var __mlWarning = false;
var __luxon; // Can be assigned in DateTable.use()
var __moment; // Can be assigned in DateTable.use()
var __reIsoTimezone = /[T\s]\d{2}.*?(Z|[+-]\d{2}(?::?\d{2})?)$/;
/**
*
@@ -12626,7 +12606,7 @@
resolveWindowLibs();
if (__moment) {
dt = __moment( d, format, locale, true );
dt = __moment.utc( d, format, locale, true );
if (! dt.isValid()) {
return null;
@@ -12736,16 +12716,6 @@
return d;
}
// Determine if there is a timezone. If there is, we want to reuse
// it for the output, so the timezone doesn't change between the
// input and output.
let options = {};
let tzMatch = typeof d === 'string' ? d.match(__reIsoTimezone) : null;
if (tzMatch) {
options.timeZone = tzMatch[1] === 'Z' ? 'UTC' : tzMatch[1];
}
var dt = __mldObj(d, from, locale);
if (dt === null) {
@@ -12759,7 +12729,7 @@
var formatted = to === null
? __mld(dt, 'toDate', 'toJSDate', '')[localeString](
navigator.language,
options
{ timeZone: "UTC" }
)
: __mld(dt, 'format', 'toFormat', 'toISOString', to);
+1 -1
View File
@@ -27,7 +27,7 @@
</symbol>
</svg>
<nav class="navbar navbar-expand-md navbar-dark bg-dark mb-4 shadow fixed-top">
<div class="container-xxl">
<div class="container-xl">
<a class="navbar-brand" href="{{urlpath}}/admin"><img class="vaultwarden-icon" src="{{urlpath}}/vw_static/vaultwarden-icon.png" alt="V">aultwarden Admin</a>
<button class="navbar-toggler" type="button" data-bs-toggle="collapse" data-bs-target="#navbarCollapse"
aria-controls="navbarCollapse" aria-expanded="false" aria-label="Toggle navigation">
+1 -1
View File
@@ -1,4 +1,4 @@
<main class="container-xxl">
<main class="container-xl">
<div id="diagnostics-block" class="my-3 p-3 rounded shadow">
<h6 class="border-bottom pb-2 mb-2">Diagnostics</h6>

Some files were not shown because too many files have changed in this diff Show More