Compare commits

..

27 Commits

Author SHA1 Message Date
user71424q d626ea81ab Serve Apple app site association file (#7191) 2026-05-17 21:46:10 +02:00
Mathijs van Veluw 1ba2c6a26c Switch to Edition 2024, more clippy lints, and less macro calls (#7200)
* Update to Rust 2024 Edition

Updated to the Rust 2024 Edition and added and fixed several lint checks.
This is a large change which, because of the extra lints, added some possible fixes for issues.

Signed-off-by: BlackDex <black.dex@gmail.com>

* Reorder and merge imports

Signed-off-by: BlackDex <black.dex@gmail.com>

* Remove "db_run!" macro calls where possible

Signed-off-by: BlackDex <black.dex@gmail.com>

---------

Signed-off-by: BlackDex <black.dex@gmail.com>
2026-05-17 19:38:49 +02:00
Mathijs van Veluw 22f5e0496c Updates and fixes (#7235)
* Update crates and gha

Updated all the crates
Updated GitHub Actions

Signed-off-by: BlackDex <black.dex@gmail.com>

* Fix restoring revoked user

A new endpoint is used to restore a revoked user.
This commit fixes that.

Fixes #7224

Signed-off-by: BlackDex <black.dex@gmail.com>

* Update datatables

Signed-off-by: BlackDex <black.dex@gmail.com>

---------

Signed-off-by: BlackDex <black.dex@gmail.com>
2026-05-17 00:43:58 +02:00
Daniel 70f9dfbe8b Switch to xx-cargo (#6640)
- removes a lot of additional configuration lines from the Dockerfile
- includes the workaround for the `openssl-sys` build issues with improper `pkg-config` setup
- for reference: https://github.com/tonistiigi/xx/pull/108#issuecomment-3700635977
2026-05-16 21:19:01 +02:00
mfw78 54895ad4be Reject unrecognised DATABASE_URL instead of silent SQLite fallback (#7061)
* Panic on unrecognised DATABASE_URL instead of silent SQLite fallback

Previously, any DATABASE_URL that did not match the mysql: or postgresql:
prefix was silently treated as a SQLite file path. This caused data loss
in containerised environments when the URL was misconfigured (typos,
quoting issues), as vaultwarden would create an ephemeral SQLite database
that was wiped on restart.

Now, an explicit sqlite:// prefix is supported and used as the default.
Bare paths without a recognised scheme are still accepted for backwards
compatibility, but only if the database file already exists. If not, the
process panics with a clear error message.

Relates to #2835, #1910, #860.

* Use err!() instead of panic!() for unrecognised DATABASE_URL

Follow the established codebase convention where configuration
validation errors use err!() to propagate gracefully, rather than
panic!(). The error propagates through from_config() and is caught
by create_db_pool() which logs and calls exit(1).

* Use 'scheme' instead of 'prefix' in DATABASE_URL messages

Per review feedback, 'scheme' is the more accurate term for the
sqlite:// portion of the URL.
2026-05-16 21:18:53 +02:00
Timshel a057c7deae sso_auth improvements (#7197)
Co-authored-by: Timshel <timshel@users.noreply.github.com>
2026-05-16 21:18:46 +02:00
Stefan Melmuk 2f85b62d2f fix email 2fa for bw cli (#7225) 2026-05-16 21:18:39 +02:00
Mathijs van Veluw 9bc14e6e77 Fix SSO Cookie path (#7187)
Signed-off-by: BlackDex <black.dex@gmail.com>
2026-05-15 20:31:58 +02:00
Chase Douglas cdf711bb30 OpenDAL S3 parameter support (#6127)
* deps: upgrade the reqwest stack to 0.13

The reqwest 0.13 rustls feature selects the aws-lc provider. Use
rustls-no-provider instead, add rustls 0.23 with the ring provider, and
install that provider at process startup. This keeps Vaultwarden on the
existing ring crypto provider while giving reqwest, OpenDAL and lettre a
process-wide rustls provider.

Disable openidconnect default features and provide a small
AsyncHttpClient wrapper around Vaultwarden's shared reqwest client
builder. This preserves custom DNS, request blocking, timeouts and the
no-redirect OIDC behavior without openidconnect enabling its own reqwest
stack.

Upgrade yubico_ng to 0.15.0 and OpenDAL to 0.56.0. OpenDAL 0.56 also
moves S3 signing to reqsign 3, so switch the optional S3 dependencies
from reqsign/anyhow to reqsign-core and reqsign-aws-v4 and adapt the AWS
SDK credential bridge to the new ProvideCredential API.

Adjust the local OpenDAL call sites for the 0.56 API: use the FS_SCHEME
constant for filesystem checks and replace deprecated remove_all() with
delete_with(...).recursive(true) for Send file cleanup.

* storage: add OpenDAL S3 URI options

OpenDAL S3 storage accepts bucket and root path data today, but
serverless deployments also need URI query parameters to describe provider
behavior in one DATA_FOLDER value.

Update OpenDAL to 0.56.0 and build S3 operators with
S3Config::from_uri(). Keep Vaultwarden's AWS SDK credential chain by
installing a reqsign provider when the URI does not explicitly request
OpenDAL-native credential handling.

Move path handling and operator construction into storage.rs so S3-specific
parsing, credential setup, and URI path manipulation stay out of
configuration handling. Local filesystem behavior is unchanged, and S3
child paths are derived before query strings.
2026-05-15 20:30:31 +02:00
Mathijs van Veluw f21a3adae2 Update hickory (#7175)
Signed-off-by: BlackDex <black.dex@gmail.com>
2026-05-02 18:56:15 +02:00
Mathijs van Veluw 07aa377af7 Update crates and web-vault (#7171)
- Update crates including fixing a regression of Diesel
- Update web-vault to v2026.4.1
- Adjusted the README to address the secure context and needing HTTPS

Fixes #7132
Closes #7137

Signed-off-by: BlackDex <black.dex@gmail.com>
2026-04-30 21:45:45 +02:00
Eldred Habert 14258caec9 Allow SQLite to be linked against dynamically (#7057)
Keeping the default behaviour of SQLite being built statically,
so as not to break anyone's workflow, but allowing for downstream
packagers to link dynamically against SQLite (where it's fine because
that's the point of package managers).

Note that SQLite is still *not* enabled by default, thanks to the `?` operator.

Co-authored-by: Daniel García <dani-garcia@users.noreply.github.com>
2026-04-29 22:59:18 +02:00
eason cb46fcb948 fix: return Err instead of panic on unknown cipher atype in to_json() (#7068)
`Cipher::to_json()` returns `Result<Value, Error>` but its match arm for
unknown `atype` values called `panic!("Wrong type")` instead of
propagating an error. This means if a cipher with an invalid/unknown type
ends up in the database (via direct DB edits, data migration issues, or
future type additions in the upstream Bitwarden protocol), the entire
server process would crash on the next sync request.

Replace the `panic!` with `err!()` so callers receive a proper `Err` and
can handle or log it gracefully without taking down the server.

Co-authored-by: easonysliu <easonysliu@tencent.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: Daniel García <dani-garcia@users.noreply.github.com>
2026-04-29 22:58:50 +02:00
Johny Jiménez b89648a136 Replace organization_uuid unwrap with proper error handling (#6936)
The collection update endpoints (post_collections_update and
post_collections_admin) call .unwrap() on cipher.organization_uuid
in four places. If a user-owned cipher without an organization
somehow reaches these code paths, the server would panic.

Extract the organization UUID early with a descriptive error message
instead of relying on .unwrap(), preventing potential panics and
providing a clear API error response.

Co-authored-by: Daniel García <dani-garcia@users.noreply.github.com>
2026-04-29 22:58:39 +02:00
Daniel García c3bd1eb565 Fix merge conflict (#7164) 2026-04-29 22:47:42 +02:00
Shocker 8c3c969938 Fix favicon fetching to check all icon links instead of just the first one (#6880)
* Fix favicon fetching to check all icon links instead of just the first one

* revert max icons limit removal

* optimize code

* code formatting
2026-04-29 22:32:48 +02:00
Matt Aaron 38a6850b8d Add support for archiving items (#6916)
* Add archiving

* Update Diesel macros and remove unnecessary SUPPORTED_FEATURE_FLAG

* Add IF EXISTS to down.sql migratinos

* Rename migration folders, separate logic based on PR threads
2026-04-29 22:29:42 +02:00
Mathijs van Veluw d297e274a3 Several SSO Fixes (#7163)
* Ensure SSO token is only usable on the same client

This commit adds an extra check via cookies to ensure the same browser/client is used to request and provide the SSO token.
Previously it would be able to provide a custom link which attackers could use to steal data.
While an attacker would still need the Master Password to be able to decrypt or execute specific actions, they were able to fetch encrypted data.

Solved with some help of Claude Code.

Signed-off-by: BlackDex <black.dex@gmail.com>

* Check email-verified on SSO login/create

This commit prevents possible account takeover via SSO which doesn't check/validate or provide validated status of the email.
It was checked at other locations, but was skipped here.

Signed-off-by: BlackDex <black.dex@gmail.com>

* Prevent data disclosure via SSO endpoints

This commit prevents some data disclosure and user enumeration by only returning the fake SSO identifier.
Since we do not check the identifier anywhere useful, returning the fake one is just fine.

During an invite to an org, that link contains the correct UUID and will be used for the master password requirements.
For anything else, server admins should set the `SSO_MASTER_PASSWORD_POLICY` env variable.

Signed-off-by: BlackDex <black.dex@gmail.com>

* Adjust admin layout to fix issues when SSO is enabled

Signed-off-by: BlackDex <black.dex@gmail.com>

---------

Signed-off-by: BlackDex <black.dex@gmail.com>
2026-04-29 22:25:36 +02:00
Mathijs van Veluw a354e57659 Fix Host/IP resolving (#7162)
IPv4 addresses can also be in decimal or hex formats.
These were not checked during the Global IP check, and could bypass it.

We now convert everything to the right format before running this check and it will catch these formats.

Also updated the `is_global()` function to match Rust's still unstable version.
And updated the Image Magic checks to be more precise and filter out any possible broken or invalid formats.

While at it, also added several checks to ensure these special formatted IPv4 addresses are still blocked and punycode domains are also correctly resolved.

Signed-off-by: BlackDex <black.dex@gmail.com>
2026-04-29 22:20:59 +02:00
Mathijs van Veluw 5cc7360816 Update crates and fix a nightly lint (#7161)
Updated all the crates including two which reported a possible CVE
Updated Typos

Signed-off-by: BlackDex <black.dex@gmail.com>
2026-04-29 22:10:26 +02:00
Timshel 62748100f0 Fix hardcoded sso identifier (#7157)
Co-authored-by: Timshel <timshel@users.noreply.github.com>
2026-04-28 19:09:47 +02:00
Daniel fcbdebd6d7 Apply ref_option lint findings (#7143)
Quote from the lint description:
"More flexibility, better memory optimization, and more idiomatic Rust code.

&Option<T> in a function signature breaks encapsulation because the caller must own T and move it into an Option to call with it. When returned, the owner must internally store it as Option<T> in order to return it. At a lower level, &Option<T> points to memory with the presence bit flag plus the T value, whereas Option<&T> is usually optimized to a single pointer, so it may be more optimal."
2026-04-28 18:34:40 +02:00
Daniel 454b8e2a35 Apply duration_suboptimal_units lint findings (#7144)
Quote from lint description:
"Using a smaller unit for a duration that is evenly divisible by a larger unit reduces readability. Readers have to mentally convert values, which can be error-prone and makes the code less clear."
2026-04-28 18:34:15 +02:00
Daniel 7883da554e Add DuckDuckGo browser device type (#7147)
- sync with upstream
2026-04-28 18:34:03 +02:00
Stefan Melmuk fd2b6528a9 add new /identity/accounts/prelogin/password (#7156) 2026-04-28 18:33:52 +02:00
Timshel cc57e60886 Dummy identifier need to pass for a guid (#7154)
Co-authored-by: Timshel <timshel@users.noreply.github.com>
2026-04-28 18:33:49 +02:00
Timshel e5681258f0 SSO fallback to UserInfo preferred_username (#7128)
Co-authored-by: Timshel <timshel@users.noreply.github.com>
2026-04-28 18:33:45 +02:00
102 changed files with 5051 additions and 3969 deletions
+5 -4
View File
@@ -50,10 +50,11 @@
#########################
## Database URL
## When using SQLite, this is the path to the DB file, and it defaults to
## %DATA_FOLDER%/db.sqlite3. If DATA_FOLDER is set to an external location, this
## must be set to a local sqlite3 file path.
# DATABASE_URL=data/db.sqlite3
## When using SQLite, this should use the sqlite:// scheme followed by the path
## to the DB file. It defaults to sqlite://%DATA_FOLDER%/db.sqlite3.
## Bare paths without the sqlite:// scheme are supported for backwards compatibility,
## but only if the database file already exists.
# DATABASE_URL=sqlite://data/db.sqlite3
## When using MySQL, specify an appropriate connection URI.
## Details: https://docs.diesel.rs/2.1.x/diesel/mysql/struct.MysqlConnection.html
# DATABASE_URL=mysql://user:password@host[:port]/database_name
+1 -1
View File
@@ -50,6 +50,6 @@ jobs:
severity: CRITICAL,HIGH
- name: Upload Trivy scan results to GitHub Security tab
uses: github/codeql-action/upload-sarif@95e58e9a2cdfd71adc6e0353d5c52f41a045d225 # v4.35.2
uses: github/codeql-action/upload-sarif@9e0d7b8d25671d64c341c19c0152d693099fb5ba # v4.35.5
with:
sarif_file: 'trivy-results.sarif'
+1 -1
View File
@@ -23,4 +23,4 @@ jobs:
# When this version is updated, do not forget to update this in `.pre-commit-config.yaml` too
- name: Spell Check Repo
uses: crate-ci/typos@cf5f1c29a8ac336af8568821ec41919923b05a83 # v1.45.1
uses: crate-ci/typos@5374cbf686e897b15713110e233094e2874de7ef # v1.46.1
+1 -1
View File
@@ -24,7 +24,7 @@ jobs:
persist-credentials: false
- name: Run zizmor
uses: zizmorcore/zizmor-action@b1d7e1fb5de872772f31590499237e7cce841e8e # v0.5.3
uses: zizmorcore/zizmor-action@b572f7b1a1c2d41efaab43d504f68d215c3cd727 # v0.5.4
with:
# intentionally not scanning the entire repository,
# since it contains integration tests.
+1 -1
View File
@@ -18,7 +18,7 @@ repos:
# When this version is updated, do not forget to update this in `.github/workflows/typos.yaml` too
- repo: https://github.com/crate-ci/typos
rev: cf5f1c29a8ac336af8568821ec41919923b05a83 # v1.45.1
rev: 5374cbf686e897b15713110e233094e2874de7ef # v1.46.1
hooks:
- id: typos
+2
View File
@@ -23,4 +23,6 @@ extend-ignore-re = [
# https://github.com/bitwarden/server/blob/dff9f1cf538198819911cf2c20f8cda3307701c5/src/Notifications/HubHelpers.cs#L86
# https://github.com/bitwarden/clients/blob/9612a4ac45063e372a6fbe87eb253c7cb3c588fb/libs/common/src/auth/services/anonymous-hub.service.ts#L45
"AuthRequestResponseRecieved",
# Ignore Punycode/IDN tests
"xn--.+"
]
Generated
+291 -447
View File
File diff suppressed because it is too large Load Diff
+130 -66
View File
@@ -1,5 +1,5 @@
[workspace.package]
edition = "2021"
edition = "2024"
rust-version = "1.93.0"
license = "AGPL-3.0-only"
repository = "https://github.com/dani-garcia/vaultwarden"
@@ -24,20 +24,31 @@ publish.workspace = true
[features]
default = [
# "sqlite",
# "sqlite_system",
# "mysql",
# "postgresql",
]
# Empty to keep compatibility, prefer to set USE_SYSLOG=true
enable_syslog = []
# Please enable at least one of these DB backends.
mysql = ["diesel/mysql", "diesel_migrations/mysql"]
postgresql = ["diesel/postgres", "diesel_migrations/postgres"]
sqlite = ["diesel/sqlite", "diesel_migrations/sqlite", "dep:libsqlite3-sys"]
sqlite_system = ["diesel/sqlite", "diesel_migrations/sqlite"] # Dynamically link SQLite
sqlite = ["sqlite_system", "libsqlite3-sys/bundled"] # Statically link SQLite into the binary instead of dynamically.
# Enable to use a vendored and statically linked openssl
vendored_openssl = ["openssl/vendored"]
# Enable MiMalloc memory allocator to replace the default malloc
# This can improve performance for Alpine builds
enable_mimalloc = ["dep:mimalloc"]
s3 = ["opendal/services-s3", "dep:aws-config", "dep:aws-credential-types", "dep:aws-smithy-runtime-api", "dep:anyhow", "dep:http", "dep:reqsign"]
s3 = [
"opendal/services-s3",
"dep:aws-config",
"dep:aws-credential-types",
"dep:aws-smithy-runtime-api",
"dep:http",
"dep:reqsign-aws-v4",
"dep:reqsign-core",
]
# OIDC specific features
oidc-accept-rfc3339-timestamps = ["openidconnect/accept-rfc3339-timestamps"]
@@ -57,7 +68,8 @@ macros = { path = "./macros" }
# Logging
log = "0.4.29"
fern = { version = "0.7.1", features = ["syslog-7", "reopen-1"] }
tracing = { version = "0.1.44", features = ["log"] } # Needed to have lettre and webauthn-rs trace logging to work
# We need the `log` feature for `tracing` to enable logging for several crates to work, like lettre or webauthn-rs
tracing = { version = "0.1.44", features = ["log"] }
# A `dotenv` implementation for Rust
dotenvy = { version = "0.15.7", default-features = false }
@@ -68,8 +80,8 @@ num-derive = "0.4.2"
bigdecimal = "0.4.10"
# Web framework
rocket = { version = "0.5.1", features = ["tls", "json"], default-features = false }
rocket_ws = { version ="0.1.1" }
rocket = { version = "0.5.1", default-features = false, features = ["json", "tls"] }
rocket_ws = { version = "0.1.1" }
# WebSockets libraries
rmpv = "1.3.1" # MessagePack library
@@ -79,34 +91,48 @@ dashmap = "6.1.0"
# Async futures
futures = "0.3.32"
tokio = { version = "1.52.1", features = ["rt-multi-thread", "fs", "io-util", "parking_lot", "time", "signal", "net"] }
tokio-util = { version = "0.7.18", features = ["compat"]}
tokio = { version = "1.52.3", features = [
"fs",
"io-util",
"net",
"parking_lot",
"rt-multi-thread",
"signal",
"time",
] }
tokio-util = { version = "0.7.18", features = ["compat"] }
# A generic serialization/deserialization framework
serde = { version = "1.0.228", features = ["derive"] }
serde_json = "1.0.149"
# A safe, extensible ORM and Query builder
# Currently pinned diesel to v2.3.3 as newer version break MySQL/MariaDB compatibility
diesel = { version = "2.3.7", features = ["chrono", "r2d2", "numeric"] }
diesel_migrations = "2.3.1"
diesel = { version = "2.3.9", features = ["chrono", "r2d2", "numeric"] }
diesel_migrations = "2.3.2"
derive_more = { version = "2.1.1", features = ["from", "into", "as_ref", "deref", "display"] }
derive_more = { version = "2.1.1", features = [
"as_ref",
"deref",
"display",
"from",
"into",
] }
diesel-derive-newtype = "2.1.2"
# Bundled/Static SQLite
libsqlite3-sys = { version = "0.36.0", features = ["bundled"], optional = true }
# SQLite, statically bundled unless the `sqlite_system` feature is enabled
libsqlite3-sys = { version = "0.37.0", optional = true }
# Crypto-related libraries
rand = "0.10.1"
ring = "0.17.14"
rustls = { version = "0.23.40", features = ["ring", "std"], default-features = false }
subtle = "2.6.1"
# UUID generation
uuid = { version = "1.23.1", features = ["v4"] }
# Date and time libraries
chrono = { version = "0.4.44", features = ["clock", "serde"], default-features = false }
chrono = { version = "0.4.44", default-features = false, features = ["clock", "serde"] }
chrono-tz = "0.10.4"
time = "0.3.47"
@@ -114,29 +140,42 @@ time = "0.3.47"
job_scheduler_ng = "2.4.0"
# Data encoding library Hex/Base32/Base64
data-encoding = "2.10.0"
data-encoding = "2.11.0"
# JWT library
jsonwebtoken = { version = "10.3.0", features = ["use_pem", "rust_crypto"], default-features = false }
jsonwebtoken = { version = "10.4.0", default-features = false, features = ["rust_crypto", "use_pem"] }
# TOTP library
totp-lite = "2.0.1"
# Yubico Library
yubico = { package = "yubico_ng", version = "0.14.1", features = ["online-tokio"], default-features = false }
yubico = { package = "yubico_ng", version = "0.15.0", default-features = false, features = ["online-tokio"] }
# WebAuthn libraries
# danger-allow-state-serialisation is needed to save the state in the db
# danger-credential-internals is needed to support U2F to Webauthn migration
webauthn-rs = { version = "0.5.4", features = ["danger-allow-state-serialisation", "danger-credential-internals"] }
webauthn-rs-proto = "0.5.4"
webauthn-rs-core = "0.5.4"
webauthn-rs = { version = "0.5.5", features = ["danger-allow-state-serialisation", "danger-credential-internals"] }
webauthn-rs-proto = "0.5.5"
webauthn-rs-core = "0.5.5"
# Handling of URL's for WebAuthn and favicons
url = "2.5.8"
# Email libraries
lettre = { version = "0.11.21", features = ["smtp-transport", "sendmail-transport", "builder", "serde", "hostname", "tracing", "tokio1-rustls", "ring", "rustls-native-certs"], default-features = false }
lettre = { version = "0.11.22", default-features = false, features = [
# Misc
"tracing",
"serde",
"builder",
"hostname",
# TLS/Security
"ring",
"rustls-native-certs",
"tokio1-rustls",
# Transport
"smtp-transport",
"sendmail-transport",
] }
percent-encoding = "2.3.2" # URL encoding library used for URL's in the emails
email_address = "0.2.9"
@@ -144,12 +183,33 @@ email_address = "0.2.9"
handlebars = { version = "6.4.0", features = ["dir_source"] }
# HTTP client (Used for favicons, version check, DUO and HIBP API)
reqwest = { version = "0.12.28", features = ["rustls-tls", "rustls-tls-native-roots", "stream", "json", "deflate", "gzip", "brotli", "zstd", "socks", "cookies", "charset", "http2", "system-proxy"], default-features = false}
hickory-resolver = "0.26.0"
reqwest = { version = "0.13.3", default-features = false, features = [
# Misc
"charset",
"cookies",
"http2",
"json",
"form",
"rustls-no-provider",
"stream",
# Compression
"brotli",
"deflate",
"gzip",
"zstd",
# Proxy
"socks",
"system-proxy",
] }
hickory-resolver = "0.26.1"
# Favicon extraction libraries
html5gum = "0.8.3"
regex = { version = "1.12.3", features = ["std", "perf", "unicode-perl"], default-features = false }
regex = { version = "1.12.3", default-features = false, features = [
"perf",
"std",
"unicode-perl",
] }
data-url = "0.3.2"
bytes = "1.11.1"
svg-hush = "0.9.6"
@@ -162,17 +222,17 @@ cookie = "0.18.1"
cookie_store = "0.22.1"
# Used by U2F, JWT and PostgreSQL
openssl = "0.10.78"
openssl = "0.10.79"
# CLI argument parsing
pico-args = "0.5.0"
# Macro ident concatenation
pastey = "0.2.1"
pastey = "0.2.2"
governor = "0.10.4"
# OIDC for SSO
openidconnect = { version = "4.0.1", features = ["reqwest", "rustls-tls"] }
openidconnect = { version = "4.0.1", default-features = false }
moka = { version = "0.12.15", features = ["future"] }
# Check client versions for specific features.
@@ -180,7 +240,7 @@ semver = "1.0.28"
# Allow overriding the default memory allocator
# Mainly used for the musl builds, since the default musl malloc is very slow
mimalloc = { version = "0.1.50", features = ["secure"], default-features = false, optional = true }
mimalloc = { version = "0.1.50", optional = true, default-features = false, features = ["secure"] }
which = "8.0.2"
@@ -188,21 +248,26 @@ which = "8.0.2"
argon2 = "0.5.3"
# Reading a password from the cli for generating the Argon2id ADMIN_TOKEN
rpassword = "7.4.0"
rpassword = "7.5.2"
# Loading a dynamic CSS Stylesheet
grass_compiler = { version = "0.13.4", default-features = false }
# File are accessed through Apache OpenDAL
opendal = { version = "0.55.0", features = ["services-fs"], default-features = false }
opendal = { version = "0.56.0", default-features = false, features = ["services-fs"] }
# For retrieving AWS credentials, including temporary SSO credentials
anyhow = { version = "1.0.102", optional = true }
aws-config = { version = "1.8.16", features = ["behavior-version-latest", "rt-tokio", "credentials-process", "sso"], default-features = false, optional = true }
aws-config = { version = "1.8.16", optional = true, default-features = false, features = [
"behavior-version-latest",
"credentials-process",
"rt-tokio",
"sso",
] }
aws-credential-types = { version = "1.2.14", optional = true }
aws-smithy-runtime-api = { version = "1.12.0", optional = true }
http = { version = "1.4.0", optional = true }
reqsign = { version = "0.16.5", optional = true }
reqsign-aws-v4 = { version = "3.0.0", optional = true }
reqsign-core = { version = "3.0.0", optional = true }
# Strip debuginfo from the release builds
# The debug symbols are to provide better panic traces
@@ -262,75 +327,74 @@ unsafe_code = "forbid"
non_ascii_idents = "forbid"
# Deny
deprecated_in_future = "deny"
warnings = "deny" # Explicitly deny all warnings since we deny all warnings in the end
# Deny lint groups
deprecated_safe = { level = "deny", priority = -1 }
future_incompatible = { level = "deny", priority = -1 }
keyword_idents = { level = "deny", priority = -1 }
let_underscore = { level = "deny", priority = -1 }
nonstandard_style = { level = "deny", priority = -1 }
noop_method_call = "deny"
refining_impl_trait = { level = "deny", priority = -1 }
rust_2018_idioms = { level = "deny", priority = -1 }
rust_2021_compatibility = { level = "deny", priority = -1 }
rust_2024_compatibility = { level = "deny", priority = -1 }
unused = { level = "deny", priority = -1 }
# Deny individual lints
closure_returning_async_block = "deny"
deprecated_in_future = "deny"
single_use_lifetimes = "deny"
trivial_casts = "deny"
trivial_numeric_casts = "deny"
unused = { level = "deny", priority = -1 }
unused_import_braces = "deny"
unused_lifetimes = "deny"
unused_qualifications = "deny"
variant_size_differences = "deny"
# Allow the following lints since these cause issues with Rust v1.84.0 or newer
# Building Vaultwarden with Rust v1.85.0 with edition 2024 also works without issues
edition_2024_expr_fragment_specifier = "allow" # Once changed to Rust 2024 this should be removed and macro's should be validated again
if_let_rescope = "allow"
tail_expr_drop_order = "allow"
# https://rust-lang.github.io/rust-clippy/stable/index.html
[workspace.lints.clippy]
# Warn
# Warn only so you can still use these during development, but not in the final code
dbg_macro = "warn"
todo = "warn"
# Ignore/Allow
result_large_err = "allow"
# Deny
# Warn on these lint group (Some might be warn by default already though)
# Will be denied during CI!
complexity = { level = "warn", priority = -1 }
pedantic = { level = "warn", priority = -1 }
perf = { level = "warn", priority = -1 }
style = { level = "warn", priority = -1 }
suspicious = { level = "warn", priority = -1 }
# Deny individual lints
branches_sharing_code = "deny"
case_sensitive_file_extension_comparisons = "deny"
cast_lossless = "deny"
clone_on_ref_ptr = "deny"
equatable_if_let = "deny"
excessive_precision = "deny"
filter_map_next = "deny"
float_cmp_const = "deny"
implicit_clone = "deny"
inefficient_to_string = "deny"
iter_on_empty_collections = "deny"
iter_on_single_items = "deny"
linkedlist = "deny"
macro_use_imports = "deny"
manual_assert = "deny"
manual_instant_elapsed = "deny"
manual_string_new = "deny"
match_wildcard_for_single_variants = "deny"
mem_forget = "deny"
needless_borrow = "deny"
needless_collect = "deny"
needless_continue = "deny"
needless_lifetimes = "deny"
option_option = "deny"
redundant_clone = "deny"
string_add_assign = "deny"
unnecessary_join = "deny"
unnecessary_self_imports = "deny"
unnested_or_patterns = "deny"
unused_async = "deny"
unused_self = "deny"
useless_let_if_seq = "deny"
verbose_file_reads = "deny"
zero_sized_map_values = "deny"
str_to_string = "deny"
# Pedantic Opt-Outs
inline_always = "allow" # We use this sparsely
struct_field_names = "allow" # Noisy and some items are Bitwarden controlled
large_futures = "allow" # Causes a fail in some Rocket macro's, since we experience no issues, allow it
too_many_lines = "allow" # For now, allow this, good to enable in the future and see if we can refactor
unnecessary_wraps = "allow" # Too much false positives because of Rocket integrations
# We do not use these doc items
doc_link_with_quotes = "allow"
doc_markdown = "allow"
missing_errors_doc = "allow"
missing_panics_doc = "allow"
[lints]
workspace = true
+3 -2
View File
@@ -59,8 +59,9 @@ A nearly complete implementation of the Bitwarden Client API is provided, includ
## Usage
> [!IMPORTANT]
> The web-vault requires the use a secure context for the [Web Crypto API](https://developer.mozilla.org/en-US/docs/Web/API/Web_Crypto_API).
> That means it will only work via `http://localhost:8000` (using the port from the example below) or if you [enable HTTPS](https://github.com/dani-garcia/vaultwarden/wiki/Enabling-HTTPS).
> The web-vault requires the use of HTTPS and a secure context for the [Web Crypto API](https://developer.mozilla.org/en-US/docs/Web/API/Web_Crypto_API). <br>
> That means it will only work if you [enable HTTPS](https://github.com/dani-garcia/vaultwarden/wiki/Enabling-HTTPS). <br>
> We also suggest to use a [reverse proxy](https://github.com/dani-garcia/vaultwarden/wiki/Proxy-examples).
The recommended way to install and use Vaultwarden is via our container images which are published to [ghcr.io](https://github.com/dani-garcia/vaultwarden/pkgs/container/vaultwarden), [docker.io](https://hub.docker.com/r/vaultwarden/server) and [quay.io](https://quay.io/repository/vaultwarden/server).
See [which container image to use](https://github.com/dani-garcia/vaultwarden/wiki/Which-container-image-to-use) for an explanation of the provided tags.
+10 -12
View File
@@ -1,22 +1,21 @@
use std::env;
use std::process::Command;
use std::{env, io::Error, process::Command};
fn main() {
// This allow using #[cfg(sqlite)] instead of #[cfg(feature = "sqlite")], which helps when trying to add them through macros
#[cfg(feature = "sqlite")]
// These allow using e.g. #[cfg(mysql)] instead of #[cfg(feature = "mysql")], which helps when trying to add them through macros
#[cfg(feature = "sqlite_system")] // The `sqlite` feature implies this one.
println!("cargo:rustc-cfg=sqlite");
#[cfg(feature = "mysql")]
println!("cargo:rustc-cfg=mysql");
#[cfg(feature = "postgresql")]
println!("cargo:rustc-cfg=postgresql");
#[cfg(feature = "s3")]
println!("cargo:rustc-cfg=s3");
#[cfg(not(any(feature = "sqlite", feature = "mysql", feature = "postgresql")))]
#[cfg(not(any(feature = "sqlite_system", feature = "mysql", feature = "postgresql")))]
compile_error!(
"You need to enable one DB backend. To build with previous defaults do: cargo build --features sqlite"
);
#[cfg(feature = "s3")]
println!("cargo:rustc-cfg=s3");
// Use check-cfg to let cargo know which cfg's we define,
// and avoid warnings when they are used in the code.
println!("cargo::rustc-check-cfg=cfg(sqlite)");
@@ -42,13 +41,12 @@ fn main() {
}
}
fn run(args: &[&str]) -> Result<String, std::io::Error> {
fn run(args: &[&str]) -> Result<String, Error> {
let out = Command::new(args[0]).args(&args[1..]).output()?;
if !out.status.success() {
use std::io::Error;
return Err(Error::other("Command not successful"));
}
Ok(String::from_utf8(out.stdout).unwrap().trim().to_string())
Ok(String::from_utf8(out.stdout).unwrap().trim().to_owned())
}
/// This method reads info from Git, namely tags, branch, and revision
@@ -58,7 +56,7 @@ fn run(args: &[&str]) -> Result<String, std::io::Error> {
/// - `env!("GIT_BRANCH")`
/// - `env!("GIT_REV")`
/// - `env!("VW_VERSION")`
fn version_from_git_info() -> Result<String, std::io::Error> {
fn version_from_git_info() -> Result<String, Error> {
// The exact tag for the current commit, can be empty when
// the current commit doesn't have an associated tag
let exact_tag = run(&["git", "describe", "--abbrev=0", "--tags", "--exact-match"]).ok();
+2 -2
View File
@@ -1,6 +1,6 @@
---
vault_version: "v2026.3.1"
vault_image_digest: "sha256:c1b1f212333f95bff4ef8d00e8e3589c4ae8eda018691f28f8bddc7e971dd767"
vault_version: "v2026.4.1"
vault_image_digest: "sha256:ca2a4251c4e63c9ad428262b4dd452789a1b9f6fce71da351e93dceed0d2edbe"
# Cross Compile Docker Helper Scripts v1.9.0
# We use the linux/amd64 platform shell scripts since there is no difference between the different platform scripts
# https://github.com/tonistiigi/xx | https://hub.docker.com/r/tonistiigi/xx/tags
+6 -7
View File
@@ -19,15 +19,15 @@
# - From https://hub.docker.com/r/vaultwarden/web-vault/tags,
# click the tag name to view the digest of the image it currently points to.
# - From the command line:
# $ docker pull docker.io/vaultwarden/web-vault:v2026.3.1
# $ docker image inspect --format "{{.RepoDigests}}" docker.io/vaultwarden/web-vault:v2026.3.1
# [docker.io/vaultwarden/web-vault@sha256:c1b1f212333f95bff4ef8d00e8e3589c4ae8eda018691f28f8bddc7e971dd767]
# $ docker pull docker.io/vaultwarden/web-vault:v2026.4.1
# $ docker image inspect --format "{{.RepoDigests}}" docker.io/vaultwarden/web-vault:v2026.4.1
# [docker.io/vaultwarden/web-vault@sha256:ca2a4251c4e63c9ad428262b4dd452789a1b9f6fce71da351e93dceed0d2edbe]
#
# - Conversely, to get the tag name from the digest:
# $ docker image inspect --format "{{.RepoTags}}" docker.io/vaultwarden/web-vault@sha256:c1b1f212333f95bff4ef8d00e8e3589c4ae8eda018691f28f8bddc7e971dd767
# [docker.io/vaultwarden/web-vault:v2026.3.1]
# $ docker image inspect --format "{{.RepoTags}}" docker.io/vaultwarden/web-vault@sha256:ca2a4251c4e63c9ad428262b4dd452789a1b9f6fce71da351e93dceed0d2edbe
# [docker.io/vaultwarden/web-vault:v2026.4.1]
#
FROM --platform=linux/amd64 docker.io/vaultwarden/web-vault@sha256:c1b1f212333f95bff4ef8d00e8e3589c4ae8eda018691f28f8bddc7e971dd767 AS vault
FROM --platform=linux/amd64 docker.io/vaultwarden/web-vault@sha256:ca2a4251c4e63c9ad428262b4dd452789a1b9f6fce71da351e93dceed0d2edbe AS vault
########################## ALPINE BUILD IMAGES ##########################
## NOTE: The Alpine Base Images do not support other platforms then linux/amd64 and linux/arm64
@@ -57,7 +57,6 @@ ENV DEBIAN_FRONTEND=noninteractive \
# Debian Trixie uses libpq v17
PQ_LIB_DIR="/usr/local/musl/pq17/lib"
# Create CARGO_HOME folder and don't download rust docs
RUN mkdir -pv "${CARGO_HOME}" && \
rustup set profile minimal
+17 -39
View File
@@ -19,15 +19,15 @@
# - From https://hub.docker.com/r/vaultwarden/web-vault/tags,
# click the tag name to view the digest of the image it currently points to.
# - From the command line:
# $ docker pull docker.io/vaultwarden/web-vault:v2026.3.1
# $ docker image inspect --format "{{.RepoDigests}}" docker.io/vaultwarden/web-vault:v2026.3.1
# [docker.io/vaultwarden/web-vault@sha256:c1b1f212333f95bff4ef8d00e8e3589c4ae8eda018691f28f8bddc7e971dd767]
# $ docker pull docker.io/vaultwarden/web-vault:v2026.4.1
# $ docker image inspect --format "{{.RepoDigests}}" docker.io/vaultwarden/web-vault:v2026.4.1
# [docker.io/vaultwarden/web-vault@sha256:ca2a4251c4e63c9ad428262b4dd452789a1b9f6fce71da351e93dceed0d2edbe]
#
# - Conversely, to get the tag name from the digest:
# $ docker image inspect --format "{{.RepoTags}}" docker.io/vaultwarden/web-vault@sha256:c1b1f212333f95bff4ef8d00e8e3589c4ae8eda018691f28f8bddc7e971dd767
# [docker.io/vaultwarden/web-vault:v2026.3.1]
# $ docker image inspect --format "{{.RepoTags}}" docker.io/vaultwarden/web-vault@sha256:ca2a4251c4e63c9ad428262b4dd452789a1b9f6fce71da351e93dceed0d2edbe
# [docker.io/vaultwarden/web-vault:v2026.4.1]
#
FROM --platform=linux/amd64 docker.io/vaultwarden/web-vault@sha256:c1b1f212333f95bff4ef8d00e8e3589c4ae8eda018691f28f8bddc7e971dd767 AS vault
FROM --platform=linux/amd64 docker.io/vaultwarden/web-vault@sha256:ca2a4251c4e63c9ad428262b4dd452789a1b9f6fce71da351e93dceed0d2edbe AS vault
########################## Cross Compile Docker Helper Scripts ##########################
## We use the linux/amd64 no matter which Build Platform, since these are all bash scripts
@@ -51,7 +51,7 @@ ENV DEBIAN_FRONTEND=noninteractive \
TERM=xterm-256color \
CARGO_HOME="/root/.cargo" \
USER="root"
# Install clang to get `xx-cargo` working
# Install clang && xx-c-essentials to get `xx-cargo` working
# Install pkg-config to allow amd64 builds to find all libraries
# Install git so build.rs can determine the correct version
# Install the libc cross packages based upon the debian-arch
@@ -59,19 +59,16 @@ RUN apt-get update && \
apt-get install -y \
--no-install-recommends \
clang \
pkg-config \
git \
"libc6-$(xx-info debian-arch)-cross" \
"libc6-dev-$(xx-info debian-arch)-cross" \
"linux-libc-dev-$(xx-info debian-arch)-cross" && \
git && \
xx-apt-get install -y \
--no-install-recommends \
gcc \
libpq-dev \
libpq5 \
libssl-dev \
libmariadb-dev \
zlib1g-dev && \
pkg-config \
zlib1g-dev \
xx-c-essentials && \
# Run xx-cargo early, since it sometimes seems to break when run at a later stage
echo "export CARGO_TARGET=$(xx-cargo --print-target-triple)" >> /env-cargo
@@ -83,29 +80,6 @@ RUN mkdir -pv "${CARGO_HOME}" && \
RUN USER=root cargo new --bin /app
WORKDIR /app
# Environment variables for Cargo on Debian based builds
ARG TARGET_PKG_CONFIG_PATH
RUN source /env-cargo && \
if xx-info is-cross ; then \
# We can't use xx-cargo since that uses clang, which doesn't work for our libraries.
# Because of this we generate the needed environment variables here which we can load in the needed steps.
echo "export CC_$(echo "${CARGO_TARGET}" | tr '[:upper:]' '[:lower:]' | tr - _)=/usr/bin/$(xx-info)-gcc" >> /env-cargo && \
echo "export CARGO_TARGET_$(echo "${CARGO_TARGET}" | tr '[:lower:]' '[:upper:]' | tr - _)_LINKER=/usr/bin/$(xx-info)-gcc" >> /env-cargo && \
echo "export CROSS_COMPILE=1" >> /env-cargo && \
echo "export PKG_CONFIG_ALLOW_CROSS=1" >> /env-cargo && \
# For some architectures `xx-info` returns a triple which doesn't matches the path on disk
# In those cases you can override this by setting the `TARGET_PKG_CONFIG_PATH` build-arg
if [[ -n "${TARGET_PKG_CONFIG_PATH}" ]]; then \
echo "export TARGET_PKG_CONFIG_PATH=${TARGET_PKG_CONFIG_PATH}" >> /env-cargo ; \
else \
echo "export PKG_CONFIG_PATH=/usr/lib/$(xx-info)/pkgconfig" >> /env-cargo ; \
fi && \
echo "# End of env-cargo" >> /env-cargo ; \
fi && \
# Output the current contents of the file
cat /env-cargo
RUN source /env-cargo && \
rustup target add "${CARGO_TARGET}"
@@ -122,7 +96,9 @@ ARG DB=sqlite,mysql,postgresql
# dummy project, except the target folder
# This folder contains the compiled dependencies
RUN source /env-cargo && \
cargo build --features ${DB} --profile "${CARGO_PROFILE}" --target="${CARGO_TARGET}" && \
# Workaround for xx related build issues
# https://github.com/tonistiigi/xx/pull/108#issuecomment-3700635977
PKG_CONFIG="$(command -v "$(xx-info)-pkg-config")" xx-cargo build --features ${DB} --profile "${CARGO_PROFILE}" && \
find . -not -path "./target*" -delete
# Copies the complete project
@@ -137,7 +113,9 @@ RUN source /env-cargo && \
# Also do this for build.rs to ensure the version is rechecked
touch build.rs src/main.rs && \
# Create a symlink to the binary target folder to easy copy the binary in the final stage
cargo build --features ${DB} --profile "${CARGO_PROFILE}" --target="${CARGO_TARGET}" && \
# Workaround for xx related build issues
# https://github.com/tonistiigi/xx/pull/108#issuecomment-3700635977
PKG_CONFIG="$(command -v "$(xx-info)-pkg-config")" xx-cargo build --features ${DB} --profile "${CARGO_PROFILE}" && \
if [[ "${CARGO_PROFILE}" == "dev" ]] ; then \
ln -vfsr "/app/target/${CARGO_TARGET}/debug" /app/target/final ; \
else \
+20 -34
View File
@@ -27,6 +27,11 @@
# $ docker image inspect --format "{{ '{{' }}.RepoTags}}" docker.io/vaultwarden/web-vault@{{ vault_image_digest }}
# [docker.io/vaultwarden/web-vault:{{ vault_version | replace('+', '_') }}]
#
{% macro xx_cargo_config() -%}
# Workaround for xx related build issues
# https://github.com/tonistiigi/xx/pull/108#issuecomment-3700635977
PKG_CONFIG="$(command -v "$(xx-info)-pkg-config")" xx-cargo build --features ${DB} --profile "${CARGO_PROFILE}"
{%- endmacro %}
FROM --platform=linux/amd64 docker.io/vaultwarden/web-vault@{{ vault_image_digest }} AS vault
{% if base == "debian" %}
@@ -66,10 +71,10 @@ ENV DEBIAN_FRONTEND=noninteractive \
# Use PostgreSQL v17 during Alpine/MUSL builds instead of the default v16
# Debian Trixie uses libpq v17
PQ_LIB_DIR="/usr/local/musl/pq17/lib"
{% endif %}
{%- endif %}
{% if base == "debian" %}
# Install clang to get `xx-cargo` working
# Install clang && xx-c-essentials to get `xx-cargo` working
# Install pkg-config to allow amd64 builds to find all libraries
# Install git so build.rs can determine the correct version
# Install the libc cross packages based upon the debian-arch
@@ -77,19 +82,16 @@ RUN apt-get update && \
apt-get install -y \
--no-install-recommends \
clang \
pkg-config \
git \
"libc6-$(xx-info debian-arch)-cross" \
"libc6-dev-$(xx-info debian-arch)-cross" \
"linux-libc-dev-$(xx-info debian-arch)-cross" && \
git && \
xx-apt-get install -y \
--no-install-recommends \
gcc \
libpq-dev \
libpq5 \
libssl-dev \
libmariadb-dev \
zlib1g-dev && \
pkg-config \
zlib1g-dev \
xx-c-essentials && \
# Run xx-cargo early, since it sometimes seems to break when run at a later stage
echo "export CARGO_TARGET=$(xx-cargo --print-target-triple)" >> /env-cargo
{% endif %}
@@ -102,31 +104,7 @@ RUN mkdir -pv "${CARGO_HOME}" && \
RUN USER=root cargo new --bin /app
WORKDIR /app
{% if base == "debian" %}
# Environment variables for Cargo on Debian based builds
ARG TARGET_PKG_CONFIG_PATH
RUN source /env-cargo && \
if xx-info is-cross ; then \
# We can't use xx-cargo since that uses clang, which doesn't work for our libraries.
# Because of this we generate the needed environment variables here which we can load in the needed steps.
echo "export CC_$(echo "${CARGO_TARGET}" | tr '[:upper:]' '[:lower:]' | tr - _)=/usr/bin/$(xx-info)-gcc" >> /env-cargo && \
echo "export CARGO_TARGET_$(echo "${CARGO_TARGET}" | tr '[:lower:]' '[:upper:]' | tr - _)_LINKER=/usr/bin/$(xx-info)-gcc" >> /env-cargo && \
echo "export CROSS_COMPILE=1" >> /env-cargo && \
echo "export PKG_CONFIG_ALLOW_CROSS=1" >> /env-cargo && \
# For some architectures `xx-info` returns a triple which doesn't matches the path on disk
# In those cases you can override this by setting the `TARGET_PKG_CONFIG_PATH` build-arg
if [[ -n "${TARGET_PKG_CONFIG_PATH}" ]]; then \
echo "export TARGET_PKG_CONFIG_PATH=${TARGET_PKG_CONFIG_PATH}" >> /env-cargo ; \
else \
echo "export PKG_CONFIG_PATH=/usr/lib/$(xx-info)/pkgconfig" >> /env-cargo ; \
fi && \
echo "# End of env-cargo" >> /env-cargo ; \
fi && \
# Output the current contents of the file
cat /env-cargo
{% elif base == "alpine" %}
{% if base == "alpine" %}
# Environment variables for Cargo on Alpine based builds
RUN echo "export CARGO_TARGET=${RUST_MUSL_CROSS_TARGET}" >> /env-cargo && \
# Output the current contents of the file
@@ -154,7 +132,11 @@ ARG DB=sqlite,mysql,postgresql,enable_mimalloc
# dummy project, except the target folder
# This folder contains the compiled dependencies
RUN source /env-cargo && \
{% if base == "debian" %}
{{ xx_cargo_config() }} && \
{% elif base == "alpine" %}
cargo build --features ${DB} --profile "${CARGO_PROFILE}" --target="${CARGO_TARGET}" && \
{% endif %}
find . -not -path "./target*" -delete
# Copies the complete project
@@ -169,7 +151,11 @@ RUN source /env-cargo && \
# Also do this for build.rs to ensure the version is rechecked
touch build.rs src/main.rs && \
# Create a symlink to the binary target folder to easy copy the binary in the final stage
{% if base == "debian" %}
{{ xx_cargo_config() }} && \
{% elif base == "alpine" %}
cargo build --features ${DB} --profile "${CARGO_PROFILE}" --target="${CARGO_TARGET}" && \
{% endif %}
if [[ "${CARGO_PROFILE}" == "dev" ]] ; then \
ln -vfsr "/app/target/${CARGO_TARGET}/debug" /app/target/final ; \
else \
+5 -4
View File
@@ -1,14 +1,15 @@
use proc_macro::TokenStream;
use quote::quote;
use syn::{DeriveInput, parse_macro_input};
#[proc_macro_derive(UuidFromParam)]
pub fn derive_uuid_from_param(input: TokenStream) -> TokenStream {
let ast = syn::parse(input).unwrap();
let ast = parse_macro_input!(input as DeriveInput);
impl_derive_uuid_macro(&ast)
}
fn impl_derive_uuid_macro(ast: &syn::DeriveInput) -> TokenStream {
fn impl_derive_uuid_macro(ast: &DeriveInput) -> TokenStream {
let name = &ast.ident;
let gen_derive = quote! {
#[automatically_derived]
@@ -30,12 +31,12 @@ fn impl_derive_uuid_macro(ast: &syn::DeriveInput) -> TokenStream {
#[proc_macro_derive(IdFromParam)]
pub fn derive_id_from_param(input: TokenStream) -> TokenStream {
let ast = syn::parse(input).unwrap();
let ast = parse_macro_input!(input as DeriveInput);
impl_derive_safestring_macro(&ast)
}
fn impl_derive_safestring_macro(ast: &syn::DeriveInput) -> TokenStream {
fn impl_derive_safestring_macro(ast: &DeriveInput) -> TokenStream {
let name = &ast.ident;
let gen_derive = quote! {
#[automatically_derived]
@@ -0,0 +1 @@
DROP TABLE IF EXISTS archives;
@@ -0,0 +1,10 @@
DROP TABLE IF EXISTS archives;
CREATE TABLE archives (
user_uuid CHAR(36) NOT NULL,
cipher_uuid CHAR(36) NOT NULL,
archived_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (user_uuid, cipher_uuid),
FOREIGN KEY (user_uuid) REFERENCES users (uuid) ON DELETE CASCADE,
FOREIGN KEY (cipher_uuid) REFERENCES ciphers (uuid) ON DELETE CASCADE
);
@@ -0,0 +1 @@
ALTER TABLE sso_auth DROP COLUMN binding_hash;
@@ -0,0 +1 @@
ALTER TABLE sso_auth ADD COLUMN binding_hash TEXT;
@@ -0,0 +1 @@
ALTER TABLE sso_auth DROP COLUMN code_response_error;
@@ -0,0 +1 @@
ALTER TABLE sso_auth ADD COLUMN code_response_error TEXT;
@@ -0,0 +1 @@
DROP TABLE IF EXISTS archives;
@@ -0,0 +1,8 @@
DROP TABLE IF EXISTS archives;
CREATE TABLE archives (
user_uuid CHAR(36) NOT NULL REFERENCES users (uuid) ON DELETE CASCADE,
cipher_uuid CHAR(36) NOT NULL REFERENCES ciphers (uuid) ON DELETE CASCADE,
archived_at TIMESTAMP NOT NULL DEFAULT now(),
PRIMARY KEY (user_uuid, cipher_uuid)
);
@@ -0,0 +1 @@
ALTER TABLE sso_auth DROP COLUMN binding_hash;
@@ -0,0 +1 @@
ALTER TABLE sso_auth ADD COLUMN binding_hash TEXT;
@@ -0,0 +1 @@
ALTER TABLE sso_auth DROP COLUMN IF EXISTS code_response_error;
@@ -0,0 +1 @@
ALTER TABLE sso_auth ADD COLUMN IF NOT EXISTS code_response_error TEXT;
@@ -0,0 +1 @@
DROP TABLE IF EXISTS archives;
@@ -0,0 +1,8 @@
DROP TABLE IF EXISTS archives;
CREATE TABLE archives (
user_uuid CHAR(36) NOT NULL REFERENCES users (uuid) ON DELETE CASCADE,
cipher_uuid CHAR(36) NOT NULL REFERENCES ciphers (uuid) ON DELETE CASCADE,
archived_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (user_uuid, cipher_uuid)
);
@@ -0,0 +1 @@
ALTER TABLE sso_auth DROP COLUMN binding_hash;
@@ -0,0 +1 @@
ALTER TABLE sso_auth ADD COLUMN binding_hash TEXT;
@@ -0,0 +1 @@
ALTER TABLE sso_auth DROP COLUMN code_response_error;
@@ -0,0 +1 @@
ALTER TABLE sso_auth ADD COLUMN code_response_error TEXT;
+1 -1
View File
@@ -1,4 +1,4 @@
edition = "2021"
edition = "2024"
max_width = 120
newline_style = "Unix"
use_small_heuristics = "Off"
+65 -70
View File
@@ -2,40 +2,40 @@ use std::{env, sync::LazyLock};
use reqwest::Method;
use rocket::{
Catcher, Route,
form::Form,
http::{Cookie, CookieJar, MediaType, SameSite, Status},
request::{FromRequest, Outcome, Request},
response::{content::RawHtml as Html, Redirect},
response::{Redirect, content::RawHtml as Html},
serde::json::Json,
Catcher, Route,
};
use serde::de::DeserializeOwned;
use serde_json::Value;
use crate::{
CONFIG, VERSION,
api::{
ApiResult, EmptyResult, JsonResult, Notify,
core::{log_event, two_factor},
unregister_push_device, ApiResult, EmptyResult, JsonResult, Notify,
unregister_push_device,
},
auth::{decode_admin, encode_jwt, generate_admin_claims, ClientIp, Secure},
auth::{ClientIp, Secure, decode_admin, encode_jwt, generate_admin_claims},
config::ConfigBuilder,
db::{
backup_sqlite, get_sql_server_version,
ACTIVE_DB_TYPE, DbConn, DbConnType, backup_sqlite, get_sql_server_version,
models::{
Attachment, Cipher, Collection, Device, Event, EventType, Group, Invitation, Membership, MembershipId,
MembershipType, OrgPolicy, Organization, OrganizationId, SsoUser, TwoFactor, User, UserId,
},
DbConn, DbConnType, ACTIVE_DB_TYPE,
},
error::{Error, MapResult},
http_client::make_http_request,
mail,
sso::FAKE_SSO_IDENTIFIER,
util::{
container_base_image, format_naive_datetime_local, get_active_web_release, get_display_size,
is_running_in_container, parse_experimental_client_feature_flags, FeatureFlagFilter, NumberOrString,
FeatureFlagFilter, NumberOrString, container_base_image, format_naive_datetime_local, get_active_web_release,
get_display_size, is_running_in_container, parse_experimental_client_feature_flags,
},
CONFIG, VERSION,
};
pub fn routes() -> Vec<Route> {
@@ -93,8 +93,7 @@ static DB_TYPE: LazyLock<&str> = LazyLock::new(|| match ACTIVE_DB_TYPE.get() {
});
#[cfg(sqlite)]
static CAN_BACKUP: LazyLock<bool> =
LazyLock::new(|| ACTIVE_DB_TYPE.get().map(|t| *t == DbConnType::Sqlite).unwrap_or(false));
static CAN_BACKUP: LazyLock<bool> = LazyLock::new(|| ACTIVE_DB_TYPE.get().is_some_and(|t| *t == DbConnType::Sqlite));
#[cfg(not(sqlite))]
static CAN_BACKUP: LazyLock<bool> = LazyLock::new(|| false);
@@ -200,13 +199,7 @@ fn post_admin_login(
}
// If the token is invalid, redirect to login page
if !_validate_token(&data.token) {
error!("Invalid admin token. IP: {}", ip.ip);
Err(AdminResponse::Unauthorized(render_admin_login(
Some("Invalid admin token, please try again."),
redirect.as_deref(),
)))
} else {
if validate_token(&data.token) {
// If the token received is valid, generate JWT and save it as a cookie
let claims = generate_admin_claims();
let jwt = encode_jwt(&claims);
@@ -224,10 +217,16 @@ fn post_admin_login(
} else {
Err(AdminResponse::Ok(render_admin_page()))
}
} else {
error!("Invalid admin token. IP: {}", ip.ip);
Err(AdminResponse::Unauthorized(render_admin_login(
Some("Invalid admin token, please try again."),
redirect.as_deref(),
)))
}
}
fn _validate_token(token: &str) -> bool {
fn validate_token(token: &str) -> bool {
match CONFIG.admin_token().as_ref() {
None => false,
Some(t) if t.starts_with("$argon2") => {
@@ -307,21 +306,14 @@ async fn get_user_or_404(user_id: &UserId, conn: &DbConn) -> ApiResult<User> {
#[post("/invite", format = "application/json", data = "<data>")]
async fn invite_user(data: Json<InviteData>, _token: AdminToken, conn: DbConn) -> JsonResult {
let data: InviteData = data.into_inner();
if User::find_by_mail(&data.email, &conn).await.is_some() {
err_code!("User already exists", Status::Conflict.code)
}
let mut user = User::new(&data.email, None);
async fn _generate_invite(user: &User, conn: &DbConn) -> EmptyResult {
async fn generate_invite(user: &User, conn: &DbConn) -> EmptyResult {
if CONFIG.mail_enabled() {
let org_id: OrganizationId = if CONFIG.sso_enabled() {
FAKE_SSO_IDENTIFIER.into()
} else {
FAKE_ADMIN_UUID.into()
};
let member_id: MembershipId = FAKE_ADMIN_UUID.to_string().into();
let member_id: MembershipId = FAKE_ADMIN_UUID.to_owned().into();
mail::send_invite(user, org_id, member_id, &CONFIG.invitation_org_name(), None).await
} else {
let invitation = Invitation::new(&user.email);
@@ -329,7 +321,14 @@ async fn invite_user(data: Json<InviteData>, _token: AdminToken, conn: DbConn) -
}
}
_generate_invite(&user, &conn).await.map_err(|e| e.with_code(Status::InternalServerError.code))?;
let data: InviteData = data.into_inner();
if User::find_by_mail(&data.email, &conn).await.is_some() {
err_code!("User already exists", Status::Conflict.code)
}
let mut user = User::new(&data.email, None);
generate_invite(&user, &conn).await.map_err(|e| e.with_code(Status::InternalServerError.code))?;
user.save(&conn).await.map_err(|e| e.with_code(Status::InternalServerError.code))?;
Ok(Json(user.to_json(&conn).await))
@@ -386,7 +385,7 @@ async fn users_overview(_token: AdminToken, conn: DbConn) -> ApiResult<Html<Stri
None => json!("Never"),
};
usr["sso_identifier"] = json!(sso_u.map(|u| u.identifier.to_string()).unwrap_or(String::new()));
usr["sso_identifier"] = json!(sso_u.map_or(String::new(), |u| u.identifier.to_string()));
users_json.push(usr);
}
@@ -469,10 +468,10 @@ async fn deauth_user(user_id: UserId, _token: AdminToken, conn: DbConn, nt: Noti
if CONFIG.push_enabled() {
for device in Device::find_push_devices_by_user(&user.uuid, &conn).await {
match unregister_push_device(&device.push_uuid).await {
match unregister_push_device(device.push_uuid.as_ref()).await {
Ok(r) => r,
Err(e) => error!("Unable to unregister devices from Bitwarden server: {e}"),
};
}
}
}
@@ -528,7 +527,7 @@ async fn resend_user_invite(user_id: UserId, _token: AdminToken, conn: DbConn) -
} else {
FAKE_ADMIN_UUID.into()
};
let member_id: MembershipId = FAKE_ADMIN_UUID.to_string().into();
let member_id: MembershipId = FAKE_ADMIN_UUID.to_owned().into();
mail::send_invite(&user, org_id, member_id, &CONFIG.invitation_org_name(), None).await
} else {
Ok(())
@@ -554,9 +553,10 @@ async fn update_membership_type(data: Json<MembershipTypeData>, token: AdminToke
err!("The specified user isn't member of the organization")
};
let new_type = match MembershipType::from_str(&data.user_type.into_string()) {
Some(new_type) => new_type as i32,
None => err!("Invalid type"),
let new_type = if let Some(new_type) = MembershipType::from_str(&data.user_type.into_string()) {
new_type as i32
} else {
err!("Invalid type")
};
if member_to_edit.atype == MembershipType::Owner && new_type != MembershipType::Owner {
@@ -656,42 +656,40 @@ async fn get_release_info(has_http_access: bool) -> (String, String, String) {
.await
{
Ok(r) => r.tag_name,
_ => "-".to_string(),
_ => "-".to_owned(),
},
match get_json_api::<GitCommit>("https://api.github.com/repos/dani-garcia/vaultwarden/commits/main").await {
Ok(mut c) => {
c.sha.truncate(8);
c.sha
}
_ => "-".to_string(),
_ => "-".to_owned(),
},
// Do not fetch the web-vault version when running within a container
// The web-vault version is embedded within the container it self, and should not be updated manually
match get_json_api::<GitRelease>("https://api.github.com/repos/dani-garcia/bw_web_builds/releases/latest")
.await
{
Ok(r) => r.tag_name.trim_start_matches('v').to_string(),
_ => "-".to_string(),
Ok(r) => r.tag_name.trim_start_matches('v').to_owned(),
_ => "-".to_owned(),
},
)
} else {
("-".to_string(), "-".to_string(), "-".to_string())
("-".to_owned(), "-".to_owned(), "-".to_owned())
}
}
async fn get_ntp_time(has_http_access: bool) -> String {
if has_http_access {
if let Ok(cf_trace) = get_text_api("https://cloudflare.com/cdn-cgi/trace").await {
for line in cf_trace.lines() {
if let Some((key, value)) = line.split_once('=') {
if key == "ts" {
let ts = value.split_once('.').map_or(value, |(s, _)| s);
if let Ok(dt) = chrono::DateTime::parse_from_str(ts, "%s") {
return dt.format("%Y-%m-%d %H:%M:%S UTC").to_string();
}
break;
}
if has_http_access && let Ok(cf_trace) = get_text_api("https://cloudflare.com/cdn-cgi/trace").await {
for line in cf_trace.lines() {
if let Some((key, value)) = line.split_once('=')
&& key == "ts"
{
let ts = value.split_once('.').map_or(value, |(s, _)| s);
if let Ok(dt) = chrono::DateTime::parse_from_str(ts, "%s") {
return dt.format("%Y-%m-%d %H:%M:%S UTC").to_string();
}
break;
}
}
}
@@ -734,7 +732,7 @@ async fn diagnostics(_token: AdminToken, ip_header: IpHeader, conn: DbConn) -> A
// Check if we are able to resolve DNS entries
let dns_resolved = match ("github.com", 0).to_socket_addrs().map(|mut i| i.next()) {
Ok(Some(a)) => a.ip().to_string(),
_ => "Unable to resolve domain name.".to_string(),
_ => "Unable to resolve domain name.".to_owned(),
};
let (latest_vw_release, latest_vw_commit, latest_web_release) = get_release_info(has_http_access).await;
@@ -745,7 +743,7 @@ async fn diagnostics(_token: AdminToken, ip_header: IpHeader, conn: DbConn) -> A
let invalid_feature_flags: Vec<String> = parse_experimental_client_feature_flags(
&CONFIG.experimental_client_feature_flags(),
FeatureFlagFilter::InvalidOnly,
&FeatureFlagFilter::InvalidOnly,
)
.into_keys()
.collect();
@@ -834,33 +832,30 @@ impl<'r> FromRequest<'r> for AdminToken {
type Error = &'static str;
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
let ip = match ClientIp::from_request(request).await {
Outcome::Success(ip) => ip,
_ => err_handler!("Error getting Client IP"),
let Outcome::Success(ip) = ClientIp::from_request(request).await else {
err_handler!("Error getting Client IP")
};
if !CONFIG.disable_admin_token() {
let cookies = request.cookies();
let access_token = match cookies.get(COOKIE_NAME) {
Some(cookie) => cookie.value(),
None => {
let requested_page =
request.segments::<std::path::PathBuf>(0..).unwrap_or_default().display().to_string();
// When the requested page is empty, it is `/admin`, in that case, Forward, so it will render the login page
// Else, return a 401 failure, which will be caught
if requested_page.is_empty() {
return Outcome::Forward(Status::Unauthorized);
} else {
return Outcome::Error((Status::Unauthorized, "Unauthorized"));
}
let access_token = if let Some(cookie) = cookies.get(COOKIE_NAME) {
cookie.value()
} else {
let requested_page =
request.segments::<std::path::PathBuf>(0..).unwrap_or_default().display().to_string();
// When the requested page is empty, it is `/admin`, in that case, Forward, so it will render the login page
// Else, return a 401 failure, which will be caught
if requested_page.is_empty() {
return Outcome::Forward(Status::Unauthorized);
}
return Outcome::Error((Status::Unauthorized, "Unauthorized"));
};
if decode_admin(access_token).is_err() {
// Remove admin cookie
cookies.remove(Cookie::build(COOKIE_NAME).path(admin_path()));
error!("Invalid or expired admin JWT. IP: {}.", &ip.ip);
error!("Invalid or expired admin JWT. IP: {}.", ip.ip);
return Outcome::Error((Status::Unauthorized, "Session expired"));
}
}
+93 -97
View File
@@ -1,34 +1,37 @@
use std::collections::HashSet;
use crate::db::DbPool;
use chrono::Utc;
use rocket::serde::json::Json;
use serde_json::Value;
use crate::{
api::{
core::{accept_org_invite, log_user_event, two_factor::email},
master_password_policy, register_push_device, unregister_push_device, AnonymousNotify, ApiResult, EmptyResult,
JsonResult, Notify, PasswordOrOtpData, UpdateType,
},
auth::{decode_delete, decode_invite, decode_verify_email, ClientHeaders, Headers},
crypto,
db::{
models::{
AuthRequest, AuthRequestId, Cipher, CipherId, Device, DeviceId, DeviceType, EmergencyAccess,
EmergencyAccessId, EventType, Folder, FolderId, Invitation, Membership, MembershipId, OrgPolicy,
OrgPolicyType, Organization, OrganizationId, Send, SendId, User, UserId, UserKdfType,
},
DbConn,
},
mail,
util::{deser_opt_nonempty_str, format_date, NumberOrString},
CONFIG,
};
use rocket::{
http::Status,
request::{FromRequest, Outcome, Request},
serde::json::Json,
};
use serde_json::Value;
use crate::{
CONFIG,
api::{
AnonymousNotify, ApiResult, EmptyResult, JsonResult, Notify, PasswordOrOtpData, UpdateType,
core::{accept_org_invite, log_user_event, two_factor::email},
master_password_policy, register_push_device, unregister_push_device,
},
auth::{ClientHeaders, Headers, decode_delete, decode_invite, decode_verify_email},
crypto,
db::{
DbConn, DbPool,
models::{
AuthRequest, AuthRequestId, Cipher, CipherId, Device, DeviceId, DeviceType, DeviceWithAuthRequest,
EmergencyAccess, EmergencyAccessId, EventType, Folder, FolderId, Invitation, Membership, MembershipId,
OrgPolicy, OrgPolicyType, Organization, OrganizationId, Send, SendId, User, UserId, UserKdfType,
},
},
mail,
util::{NumberOrString, deser_opt_nonempty_str, format_date},
};
use super::{
ciphers::{CipherData, update_cipher_from_data},
sends::{SendData, update_send_from_data},
};
pub fn routes() -> Vec<rocket::Route> {
@@ -54,9 +57,9 @@ pub fn routes() -> Vec<rocket::Route> {
delete_account,
revision_date,
password_hint,
prelogin,
post_prelogin,
verify_password,
api_key,
post_api_key,
rotate_api_key,
get_known_device,
get_all_devices,
@@ -137,17 +140,17 @@ struct KeysData {
}
/// Trims whitespace from password hints, and converts blank password hints to `None`.
fn clean_password_hint(password_hint: &Option<String>) -> Option<String> {
fn clean_password_hint(password_hint: Option<&String>) -> Option<String> {
match password_hint {
None => None,
Some(h) => match h.trim() {
"" => None,
ht => Some(ht.to_string()),
ht => Some(ht.to_owned()),
},
}
}
fn enforce_password_hint_setting(password_hint: &Option<String>) -> EmptyResult {
fn enforce_password_hint_setting(password_hint: Option<&String>) -> EmptyResult {
if password_hint.is_some() && !CONFIG.password_hints_allowed() {
err!("Password hints have been disabled by the administrator. Remove the hint and try again.");
}
@@ -166,7 +169,7 @@ async fn is_email_2fa_required(member_id: Option<MembershipId>, conn: &DbConn) -
false
}
pub async fn _register(data: Json<RegisterData>, email_verification: bool, conn: DbConn) -> JsonResult {
pub async fn register(data: Json<RegisterData>, email_verification: bool, conn: DbConn) -> JsonResult {
let mut data: RegisterData = data.into_inner();
let email = data.email.to_lowercase();
@@ -237,16 +240,16 @@ pub async fn _register(data: Json<RegisterData>, email_verification: bool, conn:
// Check if the length of the username exceeds 50 characters (Same is Upstream Bitwarden)
// This also prevents issues with very long usernames causing to large JWT's. See #2419
if let Some(ref name) = data.name {
if name.len() > 50 {
err!("The field Name must be a string with a maximum length of 50.");
}
if let Some(ref name) = data.name
&& name.len() > 50
{
err!("The field Name must be a string with a maximum length of 50.");
}
// Check against the password hint setting here so if it fails, the user
// can retry without losing their invitation below.
let password_hint = clean_password_hint(&data.master_password_hint);
enforce_password_hint_setting(&password_hint)?;
let password_hint = clean_password_hint(data.master_password_hint.as_ref());
enforce_password_hint_setting(password_hint.as_ref())?;
let mut user = match User::find_by_mail(&email, &conn).await {
Some(user) => {
@@ -353,8 +356,8 @@ async fn post_set_password(data: Json<SetPasswordData>, headers: Headers, conn:
// Check against the password hint setting here so if it fails,
// the user can retry without losing their invitation below.
let password_hint = clean_password_hint(&data.master_password_hint);
enforce_password_hint_setting(&password_hint)?;
let password_hint = clean_password_hint(data.master_password_hint.as_ref());
enforce_password_hint_setting(password_hint.as_ref())?;
set_kdf_data(&mut user, &data.kdf)?;
@@ -373,18 +376,19 @@ async fn post_set_password(data: Json<SetPasswordData>, headers: Headers, conn:
user.public_key = Some(keys.public_key);
}
if let Some(identifier) = data.org_identifier {
if identifier != crate::sso::FAKE_SSO_IDENTIFIER && identifier != crate::api::admin::FAKE_ADMIN_UUID {
let Some(org) = Organization::find_by_uuid(&identifier.into(), &conn).await else {
err!("Failed to retrieve the associated organization")
};
if let Some(identifier) = data.org_identifier
&& identifier != crate::sso::FAKE_SSO_IDENTIFIER
&& identifier != crate::api::admin::FAKE_ADMIN_UUID
{
let Some(org) = Organization::find_by_uuid(&identifier.into(), &conn).await else {
err!("Failed to retrieve the associated organization")
};
let Some(membership) = Membership::find_by_user_and_org(&user.uuid, &org.uuid, &conn).await else {
err!("Failed to retrieve the invitation")
};
let Some(membership) = Membership::find_by_user_and_org(&user.uuid, &org.uuid, &conn).await else {
err!("Failed to retrieve the invitation")
};
accept_org_invite(&user, membership, None, &conn).await?;
}
accept_org_invite(&user, membership, None, &conn).await?;
}
if CONFIG.mail_enabled() {
@@ -451,10 +455,10 @@ async fn put_avatar(data: Json<AvatarData>, headers: Headers, conn: DbConn) -> J
// It looks like it only supports the 6 hex color format.
// If you try to add the short value it will not show that color.
// Check and force 7 chars, including the #.
if let Some(color) = &data.avatar_color {
if color.len() != 7 {
err!("The field AvatarColor must be a HTML/Hex color code with a length of 7 characters")
}
if let Some(color) = &data.avatar_color
&& color.len() != 7
{
err!("The field AvatarColor must be a HTML/Hex color code with a length of 7 characters")
}
let mut user = headers.user;
@@ -515,8 +519,8 @@ async fn post_password(data: Json<ChangePassData>, headers: Headers, conn: DbCon
err!("Invalid password")
}
user.password_hint = clean_password_hint(&data.master_password_hint);
enforce_password_hint_setting(&user.password_hint)?;
user.password_hint = clean_password_hint(data.master_password_hint.as_ref());
enforce_password_hint_setting(user.password_hint.as_ref())?;
log_user_event(EventType::UserChangedPassword as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &conn)
.await;
@@ -668,9 +672,6 @@ struct UpdateResetPasswordData {
reset_password_key: String,
}
use super::ciphers::CipherData;
use super::sends::{update_send_from_data, SendData};
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct KeyData {
@@ -840,7 +841,7 @@ async fn post_rotatekey(data: Json<KeyData>, headers: Headers, conn: DbConn, nt:
};
saved_folder.name = folder_data.name;
saved_folder.save(&conn).await?
saved_folder.save(&conn).await?;
}
}
@@ -853,7 +854,7 @@ async fn post_rotatekey(data: Json<KeyData>, headers: Headers, conn: DbConn, nt:
};
saved_emergency_access.key_encrypted = Some(emergency_access_data.key_encrypted);
saved_emergency_access.save(&conn).await?
saved_emergency_access.save(&conn).await?;
}
// Update reset password data
@@ -865,7 +866,7 @@ async fn post_rotatekey(data: Json<KeyData>, headers: Headers, conn: DbConn, nt:
};
membership.reset_password_key = Some(reset_password_data.reset_password_key);
membership.save(&conn).await?
membership.save(&conn).await?;
}
// Update send data
@@ -878,8 +879,6 @@ async fn post_rotatekey(data: Json<KeyData>, headers: Headers, conn: DbConn, nt:
}
// Update cipher data
use super::ciphers::update_cipher_from_data;
for cipher_data in data.account_data.ciphers {
if cipher_data.organization_id.is_none() {
let Some(saved_cipher) = existing_ciphers.iter_mut().find(|c| &c.uuid == cipher_data.id.as_ref().unwrap())
@@ -890,7 +889,7 @@ async fn post_rotatekey(data: Json<KeyData>, headers: Headers, conn: DbConn, nt:
// Prevent triggering cipher updates via WebSockets by settings UpdateType::None
// The user sessions are invalidated because all the ciphers were re-encrypted and thus triggering an update could cause issues.
// We force the users to logout after the user has been saved to try and prevent these issues.
update_cipher_from_data(saved_cipher, cipher_data, &headers, None, &conn, &nt, UpdateType::None).await?
update_cipher_from_data(saved_cipher, cipher_data, &headers, None, &conn, &nt, UpdateType::None).await?;
}
}
@@ -1020,24 +1019,22 @@ async fn post_email(data: Json<ChangeEmailData>, headers: Headers, conn: DbConn,
err!("Email already in use");
}
match user.email_new {
Some(ref val) => {
if val != &data.new_email {
err!("Email change mismatch");
}
if let Some(ref val) = user.email_new {
if val != &data.new_email {
err!("Email change mismatch");
}
None => err!("No email change pending"),
} else {
err!("No email change pending")
}
if CONFIG.mail_enabled() {
// Only check the token if we sent out an email...
match user.email_new_token {
Some(ref val) => {
if *val != data.token.into_string() {
err!("Token mismatch");
}
if let Some(ref val) = user.email_new_token {
if *val != data.token.into_string() {
err!("Token mismatch");
}
None => err!("No email change pending"),
} else {
err!("No email change pending")
}
user.verified_at = Some(Utc::now().naive_utc());
} else {
@@ -1114,10 +1111,10 @@ async fn post_delete_recover(data: Json<DeleteRecoverData>, conn: DbConn) -> Emp
let data: DeleteRecoverData = data.into_inner();
if CONFIG.mail_enabled() {
if let Some(user) = User::find_by_mail(&data.email, &conn).await {
if let Err(e) = mail::send_delete_account(&user.email, &user.uuid).await {
error!("Error sending delete account email: {e:#?}");
}
if let Some(user) = User::find_by_mail(&data.email, &conn).await
&& let Err(e) = mail::send_delete_account(&user.email, &user.uuid).await
{
error!("Error sending delete account email: {e:#?}");
}
Ok(())
} else {
@@ -1169,6 +1166,7 @@ async fn delete_account(data: Json<PasswordOrOtpData>, headers: Headers, conn: D
user.delete(&conn).await
}
#[expect(clippy::needless_pass_by_value, reason = "Not beneficial for Headers")]
#[get("/accounts/revision-date")]
fn revision_date(headers: Headers) -> JsonResult {
let revision_date = headers.user.updated_at.and_utc().timestamp_millis();
@@ -1183,12 +1181,12 @@ struct PasswordHintData {
#[post("/accounts/password-hint", data = "<data>")]
async fn password_hint(data: Json<PasswordHintData>, conn: DbConn) -> EmptyResult {
const NO_HINT: &str = "Sorry, you have no password hint...";
if !CONFIG.password_hints_allowed() || (!CONFIG.mail_enabled() && !CONFIG.show_password_hint()) {
err!("This server is not configured to provide password hints.");
}
const NO_HINT: &str = "Sorry, you have no password hint...";
let data: PasswordHintData = data.into_inner();
let email = &data.email;
@@ -1199,9 +1197,9 @@ async fn password_hint(data: Json<PasswordHintData>, conn: DbConn) -> EmptyResul
// There is still a timing side channel here in that the code
// paths that send mail take noticeably longer than ones that
// don't. Add a randomized sleep to mitigate this somewhat.
use rand::{rngs::SmallRng, RngExt};
use rand::{RngExt, rngs::SmallRng};
let mut rng: SmallRng = rand::make_rng();
let sleep_ms = rng.random_range(900..=1100) as u64;
let sleep_ms: u64 = rng.random_range(900..=1100);
tokio::time::sleep(tokio::time::Duration::from_millis(sleep_ms)).await;
Ok(())
} else {
@@ -1229,11 +1227,11 @@ pub struct PreloginData {
}
#[post("/accounts/prelogin", data = "<data>")]
async fn prelogin(data: Json<PreloginData>, conn: DbConn) -> Json<Value> {
_prelogin(data, conn).await
async fn post_prelogin(data: Json<PreloginData>, conn: DbConn) -> Json<Value> {
prelogin(data, conn).await
}
pub async fn _prelogin(data: Json<PreloginData>, conn: DbConn) -> Json<Value> {
pub async fn prelogin(data: Json<PreloginData>, conn: DbConn) -> Json<Value> {
let data: PreloginData = data.into_inner();
let (kdf_type, kdf_iter, kdf_mem, kdf_para) = match User::find_by_mail(&data.email, &conn).await {
@@ -1283,9 +1281,7 @@ async fn verify_password(data: Json<SecretVerificationRequest>, headers: Headers
Ok(Json(master_password_policy(&user, &conn).await))
}
async fn _api_key(data: Json<PasswordOrOtpData>, rotate: bool, headers: Headers, conn: DbConn) -> JsonResult {
use crate::util::format_date;
async fn update_api_key(data: Json<PasswordOrOtpData>, rotate: bool, headers: Headers, conn: DbConn) -> JsonResult {
let data: PasswordOrOtpData = data.into_inner();
let mut user = headers.user;
@@ -1304,13 +1300,13 @@ async fn _api_key(data: Json<PasswordOrOtpData>, rotate: bool, headers: Headers,
}
#[post("/accounts/api-key", data = "<data>")]
async fn api_key(data: Json<PasswordOrOtpData>, headers: Headers, conn: DbConn) -> JsonResult {
_api_key(data, false, headers, conn).await
async fn post_api_key(data: Json<PasswordOrOtpData>, headers: Headers, conn: DbConn) -> JsonResult {
update_api_key(data, false, headers, conn).await
}
#[post("/accounts/rotate-api-key", data = "<data>")]
async fn rotate_api_key(data: Json<PasswordOrOtpData>, headers: Headers, conn: DbConn) -> JsonResult {
_api_key(data, true, headers, conn).await
update_api_key(data, true, headers, conn).await
}
#[get("/devices/knowndevice")]
@@ -1353,7 +1349,7 @@ impl<'r> FromRequest<'r> for KnownDevice {
};
let uuid = if let Some(uuid) = req.headers().get_one("X-Device-Identifier") {
uuid.to_string().into()
uuid.to_owned().into()
} else {
return Outcome::Error((Status::BadRequest, "X-Device-Identifier value is required"));
};
@@ -1368,7 +1364,7 @@ impl<'r> FromRequest<'r> for KnownDevice {
#[get("/devices")]
async fn get_all_devices(headers: Headers, conn: DbConn) -> JsonResult {
let devices = Device::find_with_auth_request_by_user(&headers.user.uuid, &conn).await;
let devices = devices.iter().map(|device| device.to_json()).collect::<Vec<Value>>();
let devices = devices.iter().map(DeviceWithAuthRequest::to_json).collect::<Vec<Value>>();
Ok(Json(json!({
"data": devices,
@@ -1438,7 +1434,7 @@ async fn put_clear_device_token(device_id: DeviceId, conn: DbConn) -> EmptyResul
if let Some(device) = Device::find_by_uuid(&device_id, &conn).await {
Device::clear_push_token_by_uuid(&device_id, &conn).await?;
unregister_push_device(&device.push_uuid).await?;
unregister_push_device(device.push_uuid.as_ref()).await?;
}
Ok(())
@@ -1708,6 +1704,6 @@ pub async fn purge_auth_requests(pool: DbPool) {
if let Ok(conn) = pool.get().await {
AuthRequest::purge_expired_auth_requests(&conn).await;
} else {
error!("Failed to get DB connection while purging auth requests")
error!("Failed to get DB connection while purging auth requests");
}
}
+297 -130
View File
@@ -2,30 +2,30 @@ use std::collections::{HashMap, HashSet};
use chrono::{NaiveDateTime, Utc};
use num_traits::ToPrimitive;
use rocket::fs::TempFile;
use rocket::serde::json::Json;
use rocket::{
form::{Form, FromForm},
Route,
form::{Form, FromForm},
fs::TempFile,
serde::json::Json,
};
use serde_json::Value;
use crate::auth::ClientVersion;
use crate::util::{deser_opt_nonempty_str, save_temp_file, NumberOrString};
use crate::{
api::{self, core::log_event, EmptyResult, JsonResult, Notify, PasswordOrOtpData, UpdateType},
CONFIG,
api::{self, EmptyResult, JsonResult, Notify, PasswordOrOtpData, UpdateType, core::log_event},
auth::ClientVersion,
auth::{Headers, OrgIdGuard, OwnerHeaders},
config::PathType,
crypto,
db::{
models::{
Attachment, AttachmentId, Cipher, CipherId, Collection, CollectionCipher, CollectionGroup, CollectionId,
CollectionUser, EventType, Favorite, Folder, FolderCipher, FolderId, Group, Membership, MembershipType,
OrgPolicy, OrgPolicyType, OrganizationId, RepromptType, Send, UserId,
},
DbConn, DbPool,
models::{
Archive, Attachment, AttachmentId, Cipher, CipherId, Collection, CollectionCipher, CollectionGroup,
CollectionId, CollectionUser, EventType, Favorite, Folder, FolderCipher, FolderId, Group, Membership,
MembershipType, OrgPolicy, OrgPolicyType, OrganizationId, RepromptType, Send, UserId,
},
},
CONFIG,
util::{NumberOrString, deser_opt_nonempty_str, save_temp_file},
};
use super::folders::FolderData;
@@ -96,6 +96,10 @@ pub fn routes() -> Vec<Route> {
post_collections_update,
post_collections_admin,
put_collections_admin,
archive_cipher_put,
archive_cipher_selected,
unarchive_cipher_put,
unarchive_cipher_selected,
]
}
@@ -104,7 +108,7 @@ pub async fn purge_trashed_ciphers(pool: DbPool) {
if let Ok(conn) = pool.get().await {
Cipher::purge_trash(&conn).await;
} else {
error!("Failed to get DB connection while purging trashed ciphers")
error!("Failed to get DB connection while purging trashed ciphers");
}
}
@@ -160,7 +164,7 @@ async fn sync(data: SyncData, headers: Headers, client_version: Option<ClientVer
let domains_json = if data.exclude_domains {
Value::Null
} else {
api::core::_get_eq_domains(&headers, true).into_inner()
api::core::get_eq_domains(&headers, true).into_inner()
};
// This is very similar to the the userDecryptionOptions sent in connect/token,
@@ -293,6 +297,7 @@ pub struct CipherData {
// when using older client versions, or if the operation doesn't involve
// updating an existing cipher.
last_known_revision_date: Option<String>,
archived_date: Option<String>,
}
#[derive(Debug, Deserialize)]
@@ -396,20 +401,34 @@ pub async fn update_cipher_from_data(
nt: &Notify<'_>,
ut: UpdateType,
) -> EmptyResult {
// Cleanup cipher data, like removing the 'Response' key.
// This key is somewhere generated during Javascript so no way for us this fix this.
// Also, upstream only retrieves keys they actually want to store, and thus skip the 'Response' key.
// We do not mind which data is in it, the keep our model more flexible when there are upstream changes.
// But, we at least know we do not need to store and return this specific key.
fn clean_cipher_data(mut json_data: Value) -> Value {
if json_data.is_array() {
json_data.as_array_mut().unwrap().iter_mut().for_each(|ref mut f| {
f.as_object_mut().unwrap().remove("response");
});
}
json_data
}
enforce_personal_ownership_policy(Some(&data), headers, conn).await?;
// Check that the client isn't updating an existing cipher with stale data.
// And only perform this check when not importing ciphers, else the date/time check will fail.
if ut != UpdateType::None {
if let Some(dt) = data.last_known_revision_date {
match NaiveDateTime::parse_from_str(&dt, "%+") {
// ISO 8601 format
Err(err) => warn!("Error parsing LastKnownRevisionDate '{dt}': {err}"),
Ok(dt) if cipher.updated_at.signed_duration_since(dt).num_seconds() > 1 => {
err!("The client copy of this cipher is out of date. Resync the client and try again.")
}
Ok(_) => (),
if ut != UpdateType::None
&& let Some(dt) = data.last_known_revision_date
{
match NaiveDateTime::parse_from_str(&dt, "%+") {
// ISO 8601 format
Err(err) => warn!("Error parsing LastKnownRevisionDate '{dt}': {err}"),
Ok(dt) if cipher.updated_at.signed_duration_since(dt).num_seconds() > 1 => {
err!("The client copy of this cipher is out of date. Resync the client and try again.")
}
Ok(_) => (),
}
}
@@ -451,25 +470,22 @@ pub async fn update_cipher_from_data(
cipher.user_uuid = Some(headers.user.uuid.clone());
}
if let Some(ref folder_id) = data.folder_id {
if Folder::find_by_uuid_and_user(folder_id, &headers.user.uuid, conn).await.is_none() {
err!("Invalid folder", "Folder does not exist or belongs to another user");
}
if let Some(ref folder_id) = data.folder_id
&& Folder::find_by_uuid_and_user(folder_id, &headers.user.uuid, conn).await.is_none()
{
err!("Invalid folder", "Folder does not exist or belongs to another user");
}
// Modify attachments name and keys when rotating
if let Some(attachments) = data.attachments2 {
for (id, attachment) in attachments {
let mut saved_att = match Attachment::find_by_id(&id, conn).await {
Some(att) => att,
None => {
// Warn and continue here.
// A missing attachment means it was removed via an other client.
// Also the Desktop Client supports removing attachments and save an update afterwards.
// Bitwarden it self ignores these mismatches server side.
warn!("Attachment {id} doesn't exist");
continue;
}
let Some(mut saved_att) = Attachment::find_by_id(&id, conn).await else {
// Warn and continue here.
// A missing attachment means it was removed via an other client.
// Also the Desktop Client supports removing attachments and save an update afterwards.
// Bitwarden it self ignores these mismatches server side.
warn!("Attachment {id} doesn't exist");
continue;
};
if saved_att.cipher_uuid != cipher.uuid {
@@ -486,20 +502,6 @@ pub async fn update_cipher_from_data(
}
}
// Cleanup cipher data, like removing the 'Response' key.
// This key is somewhere generated during Javascript so no way for us this fix this.
// Also, upstream only retrieves keys they actually want to store, and thus skip the 'Response' key.
// We do not mind which data is in it, the keep our model more flexible when there are upstream changes.
// But, we at least know we do not need to store and return this specific key.
fn _clean_cipher_data(mut json_data: Value) -> Value {
if json_data.is_array() {
json_data.as_array_mut().unwrap().iter_mut().for_each(|ref mut f| {
f.as_object_mut().unwrap().remove("response");
});
};
json_data
}
let type_data_opt = match data.r#type {
1 => data.login,
2 => data.secure_note,
@@ -509,23 +511,22 @@ pub async fn update_cipher_from_data(
_ => err!("Invalid type"),
};
let type_data = match type_data_opt {
Some(mut data) => {
// Remove the 'Response' key from the base object.
data.as_object_mut().unwrap().remove("response");
// Remove the 'Response' key from every Uri.
if data["uris"].is_array() {
data["uris"] = _clean_cipher_data(data["uris"].clone());
}
data
let type_data = if let Some(mut data) = type_data_opt {
// Remove the 'Response' key from the base object.
data.as_object_mut().unwrap().remove("response");
// Remove the 'Response' key from every Uri.
if data["uris"].is_array() {
data["uris"] = clean_cipher_data(data["uris"].clone());
}
None => err!("Data missing"),
data
} else {
err!("Data missing")
};
cipher.key = data.key;
cipher.name = data.name;
cipher.notes = data.notes;
cipher.fields = data.fields.map(|f| _clean_cipher_data(f).to_string());
cipher.fields = data.fields.map(|f| clean_cipher_data(f).to_string());
cipher.data = type_data.to_string();
cipher.password_history = data.password_history.map(|f| f.to_string());
cipher.reprompt = data.reprompt.filter(|r| *r == RepromptType::None as i32 || *r == RepromptType::Password as i32);
@@ -534,6 +535,13 @@ pub async fn update_cipher_from_data(
cipher.move_to_folder(data.folder_id, &headers.user.uuid, conn).await?;
cipher.set_favorite(data.favorite, &headers.user.uuid, conn).await?;
if let Some(dt_str) = data.archived_date {
match NaiveDateTime::parse_from_str(&dt_str, "%+") {
Ok(dt) => cipher.set_archived_at(dt, &headers.user.uuid, conn).await?,
Err(err) => warn!("Error parsing ArchivedDate '{dt_str}': {err}"),
}
}
if ut != UpdateType::None {
// Only log events for organizational ciphers
if let Some(org_id) = &cipher.organization_uuid {
@@ -600,7 +608,7 @@ async fn post_ciphers_import(data: Json<ImportData>, headers: Headers, conn: DbC
let existing_folders: HashSet<Option<FolderId>> =
Folder::find_by_user(&headers.user.uuid, &conn).await.into_iter().map(|f| Some(f.uuid)).collect();
let mut folders: Vec<FolderId> = Vec::with_capacity(data.folders.len());
for folder in data.folders.into_iter() {
for folder in data.folders {
let folder_id = if existing_folders.contains(&folder.id) {
folder.id.unwrap()
} else {
@@ -630,7 +638,7 @@ async fn post_ciphers_import(data: Json<ImportData>, headers: Headers, conn: DbC
let mut user = headers.user;
user.update_revision(&conn).await?;
nt.send_user_update(UpdateType::SyncVault, &user, &headers.device.push_uuid, &conn).await;
nt.send_user_update(UpdateType::SyncVault, &user, headers.device.push_uuid.as_ref(), &conn).await;
Ok(())
}
@@ -725,10 +733,10 @@ async fn put_cipher_partial(
err!("Cipher does not exist", "Cipher is not accessible for the current user")
}
if let Some(ref folder_id) = data.folder_id {
if Folder::find_by_uuid_and_user(folder_id, &headers.user.uuid, &conn).await.is_none() {
err!("Invalid folder", "Folder does not exist or belongs to another user");
}
if let Some(ref folder_id) = data.folder_id
&& Folder::find_by_uuid_and_user(folder_id, &headers.user.uuid, &conn).await.is_none()
{
err!("Invalid folder", "Folder does not exist or belongs to another user");
}
// Move cipher
@@ -802,12 +810,16 @@ async fn post_collections_update(
err!("Collection cannot be changed")
}
let Some(ref org_uuid) = cipher.organization_uuid else {
err!("Cipher is not owned by an organization")
};
let posted_collections = HashSet::<CollectionId>::from_iter(data.collection_ids);
let current_collections =
HashSet::<CollectionId>::from_iter(cipher.get_collections(headers.user.uuid.clone(), &conn).await);
for collection in posted_collections.symmetric_difference(&current_collections) {
match Collection::find_by_uuid_and_org(collection, cipher.organization_uuid.as_ref().unwrap(), &conn).await {
match Collection::find_by_uuid_and_org(collection, org_uuid, &conn).await {
None => err!("Invalid collection ID provided"),
Some(collection) => {
if collection.is_writable_by_user(&headers.user.uuid, &conn).await {
@@ -838,7 +850,7 @@ async fn post_collections_update(
log_event(
EventType::CipherUpdatedCollections as i32,
&cipher.uuid,
&cipher.organization_uuid.clone().unwrap(),
org_uuid,
&headers.user.uuid,
headers.device.atype,
&headers.ip.ip,
@@ -878,12 +890,16 @@ async fn post_collections_admin(
err!("Collection cannot be changed")
}
let Some(ref org_uuid) = cipher.organization_uuid else {
err!("Cipher is not owned by an organization")
};
let posted_collections = HashSet::<CollectionId>::from_iter(data.collection_ids);
let current_collections =
HashSet::<CollectionId>::from_iter(cipher.get_admin_collections(headers.user.uuid.clone(), &conn).await);
for collection in posted_collections.symmetric_difference(&current_collections) {
match Collection::find_by_uuid_and_org(collection, cipher.organization_uuid.as_ref().unwrap(), &conn).await {
match Collection::find_by_uuid_and_org(collection, org_uuid, &conn).await {
None => err!("Invalid collection ID provided"),
Some(collection) => {
if collection.is_writable_by_user(&headers.user.uuid, &conn).await {
@@ -914,7 +930,7 @@ async fn post_collections_admin(
log_event(
EventType::CipherUpdatedCollections as i32,
&cipher.uuid,
&cipher.organization_uuid.unwrap(),
org_uuid,
&headers.user.uuid,
headers.device.atype,
&headers.ip.ip,
@@ -984,7 +1000,7 @@ async fn put_cipher_share_selected(
err!("You must select at least one collection.")
}
for cipher in data.ciphers.iter() {
for cipher in &data.ciphers {
if cipher.id.is_none() {
err!("Request missing ids field")
}
@@ -996,16 +1012,15 @@ async fn put_cipher_share_selected(
collection_ids: data.collection_ids.clone(),
};
match shared_cipher_data.cipher.id.take() {
Some(id) => {
share_cipher_by_uuid(&id, shared_cipher_data, &headers, &conn, &nt, Some(UpdateType::None)).await?
}
None => err!("Request missing ids field"),
if let Some(id) = shared_cipher_data.cipher.id.take() {
share_cipher_by_uuid(&id, shared_cipher_data, &headers, &conn, &nt, Some(UpdateType::None)).await?
} else {
err!("Request missing ids field")
};
}
// Multi share actions do not send out a push for each cipher, we need to send a general sync here
nt.send_user_update(UpdateType::SyncCiphers, &headers.user, &headers.device.push_uuid, &conn).await;
nt.send_user_update(UpdateType::SyncCiphers, &headers.user, headers.device.push_uuid.as_ref(), &conn).await;
Ok(())
}
@@ -1018,15 +1033,14 @@ async fn share_cipher_by_uuid(
nt: &Notify<'_>,
override_ut: Option<UpdateType>,
) -> JsonResult {
let mut cipher = match Cipher::find_by_uuid(cipher_id, conn).await {
Some(cipher) => {
if cipher.is_write_accessible_to_user(&headers.user.uuid, conn).await {
cipher
} else {
err!("Cipher is not write accessible")
}
let mut cipher = if let Some(cipher) = Cipher::find_by_uuid(cipher_id, conn).await {
if cipher.is_write_accessible_to_user(&headers.user.uuid, conn).await {
cipher
} else {
err!("Cipher is not write accessible")
}
None => err!("Cipher doesn't exist"),
} else {
err!("Cipher doesn't exist")
};
let mut shared_to_collections = vec![];
@@ -1045,7 +1059,7 @@ async fn share_cipher_by_uuid(
}
}
}
};
}
// When LastKnownRevisionDate is None, it is a new cipher, so send CipherCreate.
// If there is an override, like when handling multiple items, we want to prevent a push notification for every single item
@@ -1243,10 +1257,10 @@ async fn save_attachment(
err!("Cipher is neither owned by a user nor an organization");
};
if let Some(size_limit) = size_limit {
if size > size_limit {
err!("Attachment storage limit exceeded with this file");
}
if let Some(size_limit) = size_limit
&& size > size_limit
{
err!("Attachment storage limit exceeded with this file");
}
let file_id = match &attachment {
@@ -1388,7 +1402,7 @@ async fn post_attachment_share(
conn: DbConn,
nt: Notify<'_>,
) -> JsonResult {
_delete_cipher_attachment_by_id(&cipher_id, &attachment_id, &headers, &conn, &nt).await?;
delete_cipher_attachment_by_id(&cipher_id, &attachment_id, &headers, &conn, &nt).await?;
post_attachment(cipher_id, data, headers, conn, nt).await
}
@@ -1422,7 +1436,7 @@ async fn delete_attachment(
conn: DbConn,
nt: Notify<'_>,
) -> JsonResult {
_delete_cipher_attachment_by_id(&cipher_id, &attachment_id, &headers, &conn, &nt).await
delete_cipher_attachment_by_id(&cipher_id, &attachment_id, &headers, &conn, &nt).await
}
#[delete("/ciphers/<cipher_id>/attachment/<attachment_id>/admin")]
@@ -1433,42 +1447,42 @@ async fn delete_attachment_admin(
conn: DbConn,
nt: Notify<'_>,
) -> JsonResult {
_delete_cipher_attachment_by_id(&cipher_id, &attachment_id, &headers, &conn, &nt).await
delete_cipher_attachment_by_id(&cipher_id, &attachment_id, &headers, &conn, &nt).await
}
#[post("/ciphers/<cipher_id>/delete")]
async fn delete_cipher_post(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
_delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::HardSingle, &nt).await
delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::HardSingle, &nt).await
// permanent delete
}
#[post("/ciphers/<cipher_id>/delete-admin")]
async fn delete_cipher_post_admin(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
_delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::HardSingle, &nt).await
delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::HardSingle, &nt).await
// permanent delete
}
#[put("/ciphers/<cipher_id>/delete")]
async fn delete_cipher_put(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
_delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::SoftSingle, &nt).await
delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::SoftSingle, &nt).await
// soft delete
}
#[put("/ciphers/<cipher_id>/delete-admin")]
async fn delete_cipher_put_admin(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
_delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::SoftSingle, &nt).await
delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::SoftSingle, &nt).await
// soft delete
}
#[delete("/ciphers/<cipher_id>")]
async fn delete_cipher(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
_delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::HardSingle, &nt).await
delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::HardSingle, &nt).await
// permanent delete
}
#[delete("/ciphers/<cipher_id>/admin")]
async fn delete_cipher_admin(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> EmptyResult {
_delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::HardSingle, &nt).await
delete_cipher_by_uuid(&cipher_id, &headers, &conn, &CipherDeleteOptions::HardSingle, &nt).await
// permanent delete
}
@@ -1479,7 +1493,7 @@ async fn delete_cipher_selected(
conn: DbConn,
nt: Notify<'_>,
) -> EmptyResult {
_delete_multiple_ciphers(data, headers, conn, CipherDeleteOptions::HardMulti, nt).await
delete_multiple_ciphers(data, headers, conn, CipherDeleteOptions::HardMulti, nt).await
// permanent delete
}
@@ -1490,7 +1504,7 @@ async fn delete_cipher_selected_post(
conn: DbConn,
nt: Notify<'_>,
) -> EmptyResult {
_delete_multiple_ciphers(data, headers, conn, CipherDeleteOptions::HardMulti, nt).await
delete_multiple_ciphers(data, headers, conn, CipherDeleteOptions::HardMulti, nt).await
// permanent delete
}
@@ -1501,7 +1515,7 @@ async fn delete_cipher_selected_put(
conn: DbConn,
nt: Notify<'_>,
) -> EmptyResult {
_delete_multiple_ciphers(data, headers, conn, CipherDeleteOptions::SoftMulti, nt).await
delete_multiple_ciphers(data, headers, conn, CipherDeleteOptions::SoftMulti, nt).await
// soft delete
}
@@ -1512,7 +1526,7 @@ async fn delete_cipher_selected_admin(
conn: DbConn,
nt: Notify<'_>,
) -> EmptyResult {
_delete_multiple_ciphers(data, headers, conn, CipherDeleteOptions::HardMulti, nt).await
delete_multiple_ciphers(data, headers, conn, CipherDeleteOptions::HardMulti, nt).await
// permanent delete
}
@@ -1523,7 +1537,7 @@ async fn delete_cipher_selected_post_admin(
conn: DbConn,
nt: Notify<'_>,
) -> EmptyResult {
_delete_multiple_ciphers(data, headers, conn, CipherDeleteOptions::HardMulti, nt).await
delete_multiple_ciphers(data, headers, conn, CipherDeleteOptions::HardMulti, nt).await
// permanent delete
}
@@ -1534,18 +1548,18 @@ async fn delete_cipher_selected_put_admin(
conn: DbConn,
nt: Notify<'_>,
) -> EmptyResult {
_delete_multiple_ciphers(data, headers, conn, CipherDeleteOptions::SoftMulti, nt).await
delete_multiple_ciphers(data, headers, conn, CipherDeleteOptions::SoftMulti, nt).await
// soft delete
}
#[put("/ciphers/<cipher_id>/restore")]
async fn restore_cipher_put(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> JsonResult {
_restore_cipher_by_uuid(&cipher_id, &headers, false, &conn, &nt).await
restore_cipher_by_uuid(&cipher_id, &headers, false, &conn, &nt).await
}
#[put("/ciphers/<cipher_id>/restore-admin")]
async fn restore_cipher_put_admin(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> JsonResult {
_restore_cipher_by_uuid(&cipher_id, &headers, false, &conn, &nt).await
restore_cipher_by_uuid(&cipher_id, &headers, false, &conn, &nt).await
}
#[put("/ciphers/restore-admin", data = "<data>")]
@@ -1555,7 +1569,7 @@ async fn restore_cipher_selected_admin(
conn: DbConn,
nt: Notify<'_>,
) -> JsonResult {
_restore_multiple_ciphers(data, &headers, &conn, &nt).await
restore_multiple_ciphers(data, &headers, &conn, &nt).await
}
#[put("/ciphers/restore", data = "<data>")]
@@ -1565,7 +1579,7 @@ async fn restore_cipher_selected(
conn: DbConn,
nt: Notify<'_>,
) -> JsonResult {
_restore_multiple_ciphers(data, &headers, &conn, &nt).await
restore_multiple_ciphers(data, &headers, &conn, &nt).await
}
#[derive(Deserialize)]
@@ -1586,10 +1600,10 @@ async fn move_cipher_selected(
let data = data.into_inner();
let user_id = &headers.user.uuid;
if let Some(ref folder_id) = data.folder_id {
if Folder::find_by_uuid_and_user(folder_id, user_id, &conn).await.is_none() {
err!("Invalid folder", "Folder does not exist or belongs to another user");
}
if let Some(ref folder_id) = data.folder_id
&& Folder::find_by_uuid_and_user(folder_id, user_id, &conn).await.is_none()
{
err!("Invalid folder", "Folder does not exist or belongs to another user");
}
let cipher_count = data.ids.len();
@@ -1618,7 +1632,7 @@ async fn move_cipher_selected(
.await;
} else {
// Multi move actions do not send out a push for each cipher, we need to send a general sync here
nt.send_user_update(UpdateType::SyncCiphers, &headers.user, &headers.device.push_uuid, &conn).await;
nt.send_user_update(UpdateType::SyncCiphers, &headers.user, headers.device.push_uuid.as_ref(), &conn).await;
}
if cipher_count != accessible_ciphers_count {
@@ -1670,7 +1684,7 @@ async fn purge_org_vault(
match Membership::find_confirmed_by_user_and_org(&user.uuid, &organization.org_id, &conn).await {
Some(member) if member.atype == MembershipType::Owner => {
Cipher::delete_all_by_organization(&organization.org_id, &conn).await?;
nt.send_user_update(UpdateType::SyncVault, &user, &headers.device.push_uuid, &conn).await;
nt.send_user_update(UpdateType::SyncVault, &user, headers.device.push_uuid.as_ref(), &conn).await;
log_event(
EventType::OrganizationPurgedVault as i32,
@@ -1710,11 +1724,41 @@ async fn purge_personal_vault(
}
user.update_revision(&conn).await?;
nt.send_user_update(UpdateType::SyncVault, &user, &headers.device.push_uuid, &conn).await;
nt.send_user_update(UpdateType::SyncVault, &user, headers.device.push_uuid.as_ref(), &conn).await;
Ok(())
}
#[put("/ciphers/<cipher_id>/archive")]
async fn archive_cipher_put(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> JsonResult {
archive_cipher(&cipher_id, &headers, false, &conn, &nt).await
}
#[put("/ciphers/archive", data = "<data>")]
async fn archive_cipher_selected(
data: Json<CipherIdsData>,
headers: Headers,
conn: DbConn,
nt: Notify<'_>,
) -> JsonResult {
archive_multiple_ciphers(data, &headers, &conn, &nt).await
}
#[put("/ciphers/<cipher_id>/unarchive")]
async fn unarchive_cipher_put(cipher_id: CipherId, headers: Headers, conn: DbConn, nt: Notify<'_>) -> JsonResult {
unarchive_cipher(&cipher_id, &headers, false, &conn, &nt).await
}
#[put("/ciphers/unarchive", data = "<data>")]
async fn unarchive_cipher_selected(
data: Json<CipherIdsData>,
headers: Headers,
conn: DbConn,
nt: Notify<'_>,
) -> JsonResult {
unarchive_multiple_ciphers(data, &headers, &conn, &nt).await
}
#[derive(PartialEq)]
pub enum CipherDeleteOptions {
SoftSingle,
@@ -1723,7 +1767,7 @@ pub enum CipherDeleteOptions {
HardMulti,
}
async fn _delete_cipher_by_uuid(
async fn delete_cipher_by_uuid(
cipher_id: &CipherId,
headers: &Headers,
conn: &DbConn,
@@ -1789,7 +1833,7 @@ struct CipherIdsData {
ids: Vec<CipherId>,
}
async fn _delete_multiple_ciphers(
async fn delete_multiple_ciphers(
data: Json<CipherIdsData>,
headers: Headers,
conn: DbConn,
@@ -1799,18 +1843,18 @@ async fn _delete_multiple_ciphers(
let data = data.into_inner();
for cipher_id in data.ids {
if let error @ Err(_) = _delete_cipher_by_uuid(&cipher_id, &headers, &conn, &delete_options, &nt).await {
if let error @ Err(_) = delete_cipher_by_uuid(&cipher_id, &headers, &conn, &delete_options, &nt).await {
return error;
};
}
}
// Multi delete actions do not send out a push for each cipher, we need to send a general sync here
nt.send_user_update(UpdateType::SyncCiphers, &headers.user, &headers.device.push_uuid, &conn).await;
nt.send_user_update(UpdateType::SyncCiphers, &headers.user, headers.device.push_uuid.as_ref(), &conn).await;
Ok(())
}
async fn _restore_cipher_by_uuid(
async fn restore_cipher_by_uuid(
cipher_id: &CipherId,
headers: &Headers,
multi_restore: bool,
@@ -1856,7 +1900,7 @@ async fn _restore_cipher_by_uuid(
Ok(Json(cipher.to_json(&headers.host, &headers.user.uuid, None, CipherSyncType::User, conn).await?))
}
async fn _restore_multiple_ciphers(
async fn restore_multiple_ciphers(
data: Json<CipherIdsData>,
headers: &Headers,
conn: &DbConn,
@@ -1866,14 +1910,14 @@ async fn _restore_multiple_ciphers(
let mut ciphers: Vec<Value> = Vec::new();
for cipher_id in data.ids {
match _restore_cipher_by_uuid(&cipher_id, headers, true, conn, nt).await {
match restore_cipher_by_uuid(&cipher_id, headers, true, conn, nt).await {
Ok(json) => ciphers.push(json.into_inner()),
err => return err,
}
}
// Multi move actions do not send out a push for each cipher, we need to send a general sync here
nt.send_user_update(UpdateType::SyncCiphers, &headers.user, &headers.device.push_uuid, conn).await;
nt.send_user_update(UpdateType::SyncCiphers, &headers.user, headers.device.push_uuid.as_ref(), conn).await;
Ok(Json(json!({
"data": ciphers,
@@ -1882,7 +1926,7 @@ async fn _restore_multiple_ciphers(
})))
}
async fn _delete_cipher_attachment_by_id(
async fn delete_cipher_attachment_by_id(
cipher_id: &CipherId,
attachment_id: &AttachmentId,
headers: &Headers,
@@ -1933,6 +1977,122 @@ async fn _delete_cipher_attachment_by_id(
Ok(Json(json!({"cipher":cipher_json})))
}
async fn archive_cipher(
cipher_id: &CipherId,
headers: &Headers,
multi_archive: bool,
conn: &DbConn,
nt: &Notify<'_>,
) -> JsonResult {
let Some(cipher) = Cipher::find_by_uuid(cipher_id, conn).await else {
err!("Cipher doesn't exist")
};
if !cipher.is_accessible_to_user(&headers.user.uuid, conn).await {
err!("Cipher is not accessible for the current user")
}
cipher.set_archived_at(Utc::now().naive_utc(), &headers.user.uuid, conn).await?;
if !multi_archive {
nt.send_cipher_update(
UpdateType::SyncCipherUpdate,
&cipher,
&cipher.update_users_revision(conn).await,
&headers.device,
None,
conn,
)
.await;
}
Ok(Json(cipher.to_json(&headers.host, &headers.user.uuid, None, CipherSyncType::User, conn).await?))
}
async fn unarchive_cipher(
cipher_id: &CipherId,
headers: &Headers,
multi_unarchive: bool,
conn: &DbConn,
nt: &Notify<'_>,
) -> JsonResult {
let Some(cipher) = Cipher::find_by_uuid(cipher_id, conn).await else {
err!("Cipher doesn't exist")
};
if !cipher.is_accessible_to_user(&headers.user.uuid, conn).await {
err!("Cipher is not accessible for the current user")
}
cipher.unarchive(&headers.user.uuid, conn).await?;
if !multi_unarchive {
nt.send_cipher_update(
UpdateType::SyncCipherUpdate,
&cipher,
&cipher.update_users_revision(conn).await,
&headers.device,
None,
conn,
)
.await;
}
Ok(Json(cipher.to_json(&headers.host, &headers.user.uuid, None, CipherSyncType::User, conn).await?))
}
async fn archive_multiple_ciphers(
data: Json<CipherIdsData>,
headers: &Headers,
conn: &DbConn,
nt: &Notify<'_>,
) -> JsonResult {
let data = data.into_inner();
let mut ciphers: Vec<Value> = Vec::new();
for cipher_id in data.ids {
match archive_cipher(&cipher_id, headers, true, conn, nt).await {
Ok(json) => ciphers.push(json.into_inner()),
err => return err,
}
}
// Multi archive does not send out a push for each cipher, we need to send a general sync here
nt.send_user_update(UpdateType::SyncCiphers, &headers.user, headers.device.push_uuid.as_ref(), conn).await;
Ok(Json(json!({
"data": ciphers,
"object": "list",
"continuationToken": null
})))
}
async fn unarchive_multiple_ciphers(
data: Json<CipherIdsData>,
headers: &Headers,
conn: &DbConn,
nt: &Notify<'_>,
) -> JsonResult {
let data = data.into_inner();
let mut ciphers: Vec<Value> = Vec::new();
for cipher_id in data.ids {
match unarchive_cipher(&cipher_id, headers, true, conn, nt).await {
Ok(json) => ciphers.push(json.into_inner()),
err => return err,
}
}
// Multi unarchive does not send out a push for each cipher, we need to send a general sync here
nt.send_user_update(UpdateType::SyncCiphers, &headers.user, headers.device.push_uuid.as_ref(), conn).await;
Ok(Json(json!({
"data": ciphers,
"object": "list",
"continuationToken": null
})))
}
/// This will hold all the necessary data to improve a full sync of all the ciphers
/// It can be used during the `Cipher::to_json()` call.
/// It will prevent the so called N+1 SQL issue by running just a few queries which will hold all the data needed.
@@ -1942,6 +2102,7 @@ pub struct CipherSyncData {
pub cipher_folders: HashMap<CipherId, FolderId>,
pub cipher_favorites: HashSet<CipherId>,
pub cipher_collections: HashMap<CipherId, Vec<CollectionId>>,
pub cipher_archives: HashMap<CipherId, NaiveDateTime>,
pub members: HashMap<OrganizationId, Membership>,
pub user_collections: HashMap<CollectionId, CollectionUser>,
pub user_collections_groups: HashMap<CollectionId, CollectionGroup>,
@@ -1958,20 +2119,25 @@ impl CipherSyncData {
pub async fn new(user_id: &UserId, sync_type: CipherSyncType, conn: &DbConn) -> Self {
let cipher_folders: HashMap<CipherId, FolderId>;
let cipher_favorites: HashSet<CipherId>;
let cipher_archives: HashMap<CipherId, NaiveDateTime>;
match sync_type {
// User Sync supports Folders and Favorites
// User Sync supports Folders, Favorites, and Archives
CipherSyncType::User => {
// Generate a HashMap with the Cipher UUID as key and the Folder UUID as value
cipher_folders = FolderCipher::find_by_user(user_id, conn).await.into_iter().collect();
// Generate a HashSet of all the Cipher UUID's which are marked as favorite
cipher_favorites = Favorite::get_all_cipher_uuid_by_user(user_id, conn).await.into_iter().collect();
// Generate a HashMap with the Cipher UUID as key and the archived date time as value
cipher_archives = Archive::find_by_user(user_id, conn).await.into_iter().collect();
}
// Organization Sync does not support Folders and Favorites.
// Organization Sync does not support Folders, Favorites, or Archives.
// If these are set, it will cause issues in the web-vault.
CipherSyncType::Organization => {
cipher_folders = HashMap::with_capacity(0);
cipher_favorites = HashSet::with_capacity(0);
cipher_archives = HashMap::with_capacity(0);
}
}
@@ -2038,6 +2204,7 @@ impl CipherSyncData {
cipher_folders,
cipher_favorites,
cipher_collections,
cipher_archives,
members,
user_collections,
user_collections_groups,
+28 -24
View File
@@ -1,23 +1,23 @@
use chrono::{TimeDelta, Utc};
use rocket::{serde::json::Json, Route};
use rocket::{Route, serde::json::Json};
use serde_json::Value;
use crate::{
CONFIG,
api::{
core::{CipherSyncData, CipherSyncType},
EmptyResult, JsonResult,
core::{CipherSyncData, CipherSyncType},
},
auth::{decode_emergency_access_invite, Headers},
auth::{Headers, decode_emergency_access_invite},
db::{
DbConn, DbPool,
models::{
Cipher, EmergencyAccess, EmergencyAccessId, EmergencyAccessStatus, EmergencyAccessType, Invitation,
Membership, MembershipType, OrgPolicy, TwoFactor, User, UserId,
},
DbConn, DbPool,
},
mail,
util::NumberOrString,
CONFIG,
};
pub fn routes() -> Vec<Route> {
@@ -55,7 +55,7 @@ async fn get_contacts(headers: Headers, conn: DbConn) -> Json<Value> {
let mut emergency_access_list_json = Vec::with_capacity(emergency_access_list.len());
for ea in emergency_access_list {
if let Some(grantee) = ea.to_json_grantee_details(&conn).await {
emergency_access_list_json.push(grantee)
emergency_access_list_json.push(grantee);
}
}
@@ -89,11 +89,14 @@ async fn get_grantees(headers: Headers, conn: DbConn) -> Json<Value> {
async fn get_emergency_access(emer_id: EmergencyAccessId, headers: Headers, conn: DbConn) -> JsonResult {
check_emergency_access_enabled()?;
match EmergencyAccess::find_by_uuid_and_grantor_uuid(&emer_id, &headers.user.uuid, &conn).await {
Some(emergency_access) => Ok(Json(
if let Some(emergency_access) =
EmergencyAccess::find_by_uuid_and_grantor_uuid(&emer_id, &headers.user.uuid, &conn).await
{
Ok(Json(
emergency_access.to_json_grantee_details(&conn).await.expect("Grantee user should exist but does not!"),
)),
None => err!("Emergency access not valid."),
))
} else {
err!("Emergency access not valid.")
}
}
@@ -136,9 +139,10 @@ async fn post_emergency_access(
err!("Emergency access not valid.")
};
let new_type = match EmergencyAccessType::from_str(&data.r#type.into_string()) {
Some(new_type) => new_type as i32,
None => err!("Invalid emergency access type."),
let new_type = if let Some(new_type) = EmergencyAccessType::from_str(&data.r#type.into_string()) {
new_type as i32
} else {
err!("Invalid emergency access type.")
};
emergency_access.atype = new_type;
@@ -205,9 +209,10 @@ async fn send_invite(data: Json<EmergencyAccessInviteData>, headers: Headers, co
let emergency_access_status = EmergencyAccessStatus::Invited as i32;
let new_type = match EmergencyAccessType::from_str(&data.r#type.into_string()) {
Some(new_type) => new_type as i32,
None => err!("Invalid emergency access type."),
let new_type = if let Some(new_type) = EmergencyAccessType::from_str(&data.r#type.into_string()) {
new_type as i32
} else {
err!("Invalid emergency access type.")
};
let grantor_user = headers.user;
@@ -342,12 +347,11 @@ async fn accept_invite(
err!("Claim email does not match current users email")
}
let grantee_user = match User::find_by_mail(&claims.email, &conn).await {
Some(user) => {
Invitation::take(&claims.email, &conn).await;
user
}
None => err!("Invited user not found"),
let grantee_user = if let Some(user) = User::find_by_mail(&claims.email, &conn).await {
Invitation::take(&claims.email, &conn).await;
user
} else {
err!("Invited user not found")
};
// We need to search for the uuid in combination with the email, since we do not yet store the uuid of the grantee in the database.
@@ -766,7 +770,7 @@ pub async fn emergency_request_timeout_job(pool: DbPool) {
}
}
} else {
error!("Failed to get DB connection while searching emergency request timed out")
error!("Failed to get DB connection while searching emergency request timed out");
}
}
@@ -825,6 +829,6 @@ pub async fn emergency_notification_reminder_job(pool: DbPool) {
}
}
} else {
error!("Failed to get DB connection while searching emergency notification reminder")
error!("Failed to get DB connection while searching emergency notification reminder");
}
}
+54 -60
View File
@@ -1,18 +1,18 @@
use std::net::IpAddr;
use chrono::NaiveDateTime;
use rocket::{form::FromForm, serde::json::Json, Route};
use rocket::{Route, form::FromForm, serde::json::Json};
use serde_json::Value;
use crate::{
CONFIG,
api::{EmptyResult, JsonResult},
auth::{AdminHeaders, Headers},
db::{
models::{Cipher, CipherId, Event, Membership, MembershipId, OrganizationId, UserId},
DbConn, DbPool,
models::{Cipher, CipherId, Event, Membership, MembershipId, OrganizationId, UserId},
},
util::parse_date,
CONFIG,
};
/// ###############################################################################################################
@@ -38,9 +38,7 @@ async fn get_org_events(org_id: OrganizationId, data: EventRange, headers: Admin
// Return an empty vec when we org events are disabled.
// This prevents client errors
let events_json: Vec<Value> = if !CONFIG.org_events_enabled() {
Vec::with_capacity(0)
} else {
let events_json: Vec<Value> = if CONFIG.org_events_enabled() {
let start_date = parse_date(&data.start);
let end_date = if let Some(before_date) = &data.continuation_token {
parse_date(before_date)
@@ -51,8 +49,10 @@ async fn get_org_events(org_id: OrganizationId, data: EventRange, headers: Admin
Event::find_by_organization_uuid(&org_id, &start_date, &end_date, &conn)
.await
.iter()
.map(|e| e.to_json())
.map(Event::to_json)
.collect()
} else {
Vec::with_capacity(0)
};
Ok(Json(json!({
@@ -64,27 +64,21 @@ async fn get_org_events(org_id: OrganizationId, data: EventRange, headers: Admin
#[get("/ciphers/<cipher_id>/events?<data..>")]
async fn get_cipher_events(cipher_id: CipherId, data: EventRange, headers: Headers, conn: DbConn) -> JsonResult {
// Return an empty vec when we org events are disabled.
// Return an empty vec when org events are disabled.
// This prevents client errors
let events_json: Vec<Value> = if !CONFIG.org_events_enabled() {
Vec::with_capacity(0)
} else {
let mut events_json = Vec::with_capacity(0);
if Membership::user_has_ge_admin_access_to_cipher(&headers.user.uuid, &cipher_id, &conn).await {
let start_date = parse_date(&data.start);
let end_date = if let Some(before_date) = &data.continuation_token {
parse_date(before_date)
} else {
parse_date(&data.end)
};
let events_json: Vec<Value> = if CONFIG.org_events_enabled()
&& Membership::user_has_ge_admin_access_to_cipher(&headers.user.uuid, &cipher_id, &conn).await
{
let start_date = parse_date(&data.start);
let end_date = if let Some(before_date) = &data.continuation_token {
parse_date(before_date)
} else {
parse_date(&data.end)
};
events_json = Event::find_by_cipher_uuid(&cipher_id, &start_date, &end_date, &conn)
.await
.iter()
.map(|e| e.to_json())
.collect()
}
events_json
Event::find_by_cipher_uuid(&cipher_id, &start_date, &end_date, &conn).await.iter().map(Event::to_json).collect()
} else {
Vec::with_capacity(0)
};
Ok(Json(json!({
@@ -107,9 +101,7 @@ async fn get_user_events(
}
// Return an empty vec when we org events are disabled.
// This prevents client errors
let events_json: Vec<Value> = if !CONFIG.org_events_enabled() {
Vec::with_capacity(0)
} else {
let events_json: Vec<Value> = if CONFIG.org_events_enabled() {
let start_date = parse_date(&data.start);
let end_date = if let Some(before_date) = &data.continuation_token {
parse_date(before_date)
@@ -120,8 +112,10 @@ async fn get_user_events(
Event::find_by_org_and_member(&org_id, &member_id, &start_date, &end_date, &conn)
.await
.iter()
.map(|e| e.to_json())
.map(Event::to_json)
.collect()
} else {
Vec::with_capacity(0)
};
Ok(Json(json!({
@@ -134,7 +128,8 @@ async fn get_user_events(
fn get_continuation_token(events_json: &[Value]) -> Option<&str> {
// When the length of the vec equals the max page_size there probably is more data
// When it is less, then all events are loaded.
if events_json.len() as i64 == Event::PAGE_SIZE {
#[expect(clippy::cast_possible_truncation, reason = "PAGE_SIZE fits within usize")]
if events_json.len() == Event::PAGE_SIZE as usize {
if let Some(last_event) = events_json.last() {
last_event["date"].as_str()
} else {
@@ -176,7 +171,7 @@ async fn post_events_collect(data: Json<Vec<EventCollection>>, headers: Headers,
let event_date = parse_date(&event.date);
match event.r#type {
1000..=1099 => {
_log_user_event(
log_user_event_impl(
event.r#type,
&headers.user.uuid,
headers.device.atype,
@@ -188,7 +183,7 @@ async fn post_events_collect(data: Json<Vec<EventCollection>>, headers: Headers,
}
1600..=1699 => {
if let Some(org_id) = &event.organization_id {
_log_event(
log_event_impl(
event.r#type,
org_id,
org_id,
@@ -202,22 +197,21 @@ async fn post_events_collect(data: Json<Vec<EventCollection>>, headers: Headers,
}
}
_ => {
if let Some(cipher_uuid) = &event.cipher_id {
if let Some(cipher) = Cipher::find_by_uuid(cipher_uuid, &conn).await {
if let Some(org_id) = cipher.organization_uuid {
_log_event(
event.r#type,
cipher_uuid,
&org_id,
&headers.user.uuid,
headers.device.atype,
Some(event_date),
&headers.ip.ip,
&conn,
)
.await;
}
}
if let Some(cipher_uuid) = &event.cipher_id
&& let Some(cipher) = Cipher::find_by_uuid(cipher_uuid, &conn).await
&& let Some(org_id) = cipher.organization_uuid
{
log_event_impl(
event.r#type,
cipher_uuid,
&org_id,
&headers.user.uuid,
headers.device.atype,
Some(event_date),
&headers.ip.ip,
&conn,
)
.await;
}
}
}
@@ -229,10 +223,10 @@ pub async fn log_user_event(event_type: i32, user_id: &UserId, device_type: i32,
if !CONFIG.org_events_enabled() {
return;
}
_log_user_event(event_type, user_id, device_type, None, ip, conn).await;
log_user_event_impl(event_type, user_id, device_type, None, ip, conn).await;
}
async fn _log_user_event(
async fn log_user_event_impl(
event_type: i32,
user_id: &UserId,
device_type: i32,
@@ -278,11 +272,11 @@ pub async fn log_event(
if !CONFIG.org_events_enabled() {
return;
}
_log_event(event_type, source_uuid, org_id, act_user_id, device_type, None, ip, conn).await;
log_event_impl(event_type, source_uuid, org_id, act_user_id, device_type, None, ip, conn).await;
}
#[allow(clippy::too_many_arguments)]
async fn _log_event(
#[expect(clippy::too_many_arguments)]
async fn log_event_impl(
event_type: i32,
source_uuid: &str,
org_id: &OrganizationId,
@@ -298,24 +292,24 @@ async fn _log_event(
// 1000..=1099 Are user events, they need to be logged via log_user_event()
// Cipher Events
1100..=1199 => {
event.cipher_uuid = Some(source_uuid.to_string().into());
event.cipher_uuid = Some(source_uuid.to_owned().into());
}
// Collection Events
1300..=1399 => {
event.collection_uuid = Some(source_uuid.to_string().into());
event.collection_uuid = Some(source_uuid.to_owned().into());
}
// Group Events
1400..=1499 => {
event.group_uuid = Some(source_uuid.to_string().into());
event.group_uuid = Some(source_uuid.to_owned().into());
}
// Org User Events
1500..=1599 => {
event.org_user_uuid = Some(source_uuid.to_string().into());
event.org_user_uuid = Some(source_uuid.to_owned().into());
}
// 1600..=1699 Are organizational events, and they do not need the source_uuid
// Policy Events
1700..=1799 => {
event.policy_uuid = Some(source_uuid.to_string().into());
event.policy_uuid = Some(source_uuid.to_owned().into());
}
// Ignore others
_ => {}
@@ -338,6 +332,6 @@ pub async fn event_cleanup_job(pool: DbPool) {
if let Ok(conn) = pool.get().await {
Event::clean_events(&conn).await.ok();
} else {
error!("Failed to get DB connection while trying to cleanup the events table")
error!("Failed to get DB connection while trying to cleanup the events table");
}
}
+5 -4
View File
@@ -5,8 +5,8 @@ use crate::{
api::{EmptyResult, JsonResult, Notify, UpdateType},
auth::Headers,
db::{
models::{Folder, FolderId},
DbConn,
models::{Folder, FolderId},
},
util::deser_opt_nonempty_str,
};
@@ -29,9 +29,10 @@ async fn get_folders(headers: Headers, conn: DbConn) -> Json<Value> {
#[get("/folders/<folder_id>")]
async fn get_folder(folder_id: FolderId, headers: Headers, conn: DbConn) -> JsonResult {
match Folder::find_by_uuid_and_user(&folder_id, &headers.user.uuid, &conn).await {
Some(folder) => Ok(Json(folder.to_json())),
_ => err!("Invalid folder", "Folder does not exist or belongs to another user"),
if let Some(folder) = Folder::find_by_uuid_and_user(&folder_id, &headers.user.uuid, &conn).await {
Ok(Json(folder.to_json()))
} else {
err!("Invalid folder", "Folder does not exist or belongs to another user")
}
}
+50 -40
View File
@@ -1,4 +1,6 @@
pub mod accounts;
pub mod two_factor;
mod ciphers;
mod emergency_access;
mod events;
@@ -6,17 +8,32 @@ mod folders;
mod organizations;
mod public;
mod sends;
pub mod two_factor;
pub use accounts::purge_auth_requests;
pub use ciphers::{purge_trashed_ciphers, CipherData, CipherSyncData, CipherSyncType};
pub use ciphers::{CipherData, CipherSyncData, CipherSyncType, purge_trashed_ciphers};
pub use emergency_access::{emergency_notification_reminder_job, emergency_request_timeout_job};
pub use events::{event_cleanup_job, log_event, log_user_event};
use reqwest::Method;
pub use sends::purge_sends;
use reqwest::Method;
use rocket::{Catcher, Route, serde::json::Json, serde::json::Value};
use crate::{
CONFIG,
api::{EmptyResult, JsonResult, Notify, UpdateType},
auth::Headers,
db::{
DbConn,
models::{Membership, MembershipStatus, OrgPolicy, Organization, User},
},
error::Error,
http_client::make_http_request,
mail,
util::{FeatureFlagFilter, parse_experimental_client_feature_flags},
};
pub fn routes() -> Vec<Route> {
let mut eq_domains_routes = routes![get_eq_domains, post_eq_domains, put_eq_domains];
let mut eq_domains_routes = routes![get_settings_domains, post_settings_domains, put_settings_domains];
let mut hibp_routes = routes![hibp_breach];
let mut meta_routes = routes![alive, now, version, config, get_api_webauthn];
@@ -44,25 +61,6 @@ pub fn events_routes() -> Vec<Route> {
routes
}
//
// Move this somewhere else
//
use rocket::{serde::json::Json, serde::json::Value, Catcher, Route};
use crate::{
api::{EmptyResult, JsonResult, Notify, UpdateType},
auth::Headers,
db::{
models::{Membership, MembershipStatus, OrgPolicy, Organization, User},
DbConn,
},
error::Error,
http_client::make_http_request,
mail,
util::{parse_experimental_client_feature_flags, FeatureFlagFilter},
CONFIG,
};
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GlobalDomain {
@@ -73,15 +71,17 @@ struct GlobalDomain {
const GLOBAL_DOMAINS: &str = include_str!("../../static/global_domains.json");
#[expect(clippy::needless_pass_by_value, reason = "Not beneficial for Headers")]
#[get("/settings/domains")]
fn get_eq_domains(headers: Headers) -> Json<Value> {
_get_eq_domains(&headers, false)
fn get_settings_domains(headers: Headers) -> Json<Value> {
get_eq_domains(&headers, false)
}
fn _get_eq_domains(headers: &Headers, no_excluded: bool) -> Json<Value> {
let user = &headers.user;
fn get_eq_domains(headers: &Headers, no_excluded: bool) -> Json<Value> {
use serde_json::from_str;
let user = &headers.user;
let equivalent_domains: Vec<Vec<String>> = from_str(&user.equivalent_domains).unwrap();
let excluded_globals: Vec<i32> = from_str(&user.excluded_globals).unwrap();
@@ -110,28 +110,39 @@ struct EquivDomainData {
}
#[post("/settings/domains", data = "<data>")]
async fn post_eq_domains(data: Json<EquivDomainData>, headers: Headers, conn: DbConn, nt: Notify<'_>) -> JsonResult {
async fn post_settings_domains(
data: Json<EquivDomainData>,
headers: Headers,
conn: DbConn,
nt: Notify<'_>,
) -> JsonResult {
use serde_json::to_string;
let data: EquivDomainData = data.into_inner();
let excluded_globals = data.excluded_global_equivalent_domains.unwrap_or_default();
let equivalent_domains = data.equivalent_domains.unwrap_or_default();
let mut user = headers.user;
use serde_json::to_string;
user.excluded_globals = to_string(&excluded_globals).unwrap_or_else(|_| "[]".to_string());
user.equivalent_domains = to_string(&equivalent_domains).unwrap_or_else(|_| "[]".to_string());
user.excluded_globals = to_string(&excluded_globals).unwrap_or_else(|_| "[]".to_owned());
user.equivalent_domains = to_string(&equivalent_domains).unwrap_or_else(|_| "[]".to_owned());
user.save(&conn).await?;
nt.send_user_update(UpdateType::SyncSettings, &user, &headers.device.push_uuid, &conn).await;
nt.send_user_update(UpdateType::SyncSettings, &user, headers.device.push_uuid.as_ref(), &conn).await;
Ok(Json(json!({})))
}
#[put("/settings/domains", data = "<data>")]
async fn put_eq_domains(data: Json<EquivDomainData>, headers: Headers, conn: DbConn, nt: Notify<'_>) -> JsonResult {
post_eq_domains(data, headers, conn, nt).await
async fn put_settings_domains(
data: Json<EquivDomainData>,
headers: Headers,
conn: DbConn,
nt: Notify<'_>,
) -> JsonResult {
post_settings_domains(data, headers, conn, nt).await
}
#[get("/hibp/breach?<username>")]
@@ -204,11 +215,11 @@ fn config() -> Json<Value> {
// Client (v2026.2.1): https://github.com/bitwarden/clients/blob/f96380c3138291a028bdd2c7a5fee540d5c98ba5/libs/common/src/enums/feature-flag.enum.ts#L12
// Android (v2026.2.1): https://github.com/bitwarden/android/blob/6902c19c0093fa476bbf74ccaa70c9f14afbb82f/core/src/main/kotlin/com/bitwarden/core/data/manager/model/FlagKey.kt#L31
// iOS (v2026.2.1): https://github.com/bitwarden/ios/blob/cdd9ba1770ca2ffc098d02d12cc3208e3a830454/BitwardenShared/Core/Platform/Models/Enum/FeatureFlag.swift#L7
let feature_states = parse_experimental_client_feature_flags(
let mut feature_states = parse_experimental_client_feature_flags(
&CONFIG.experimental_client_feature_flags(),
FeatureFlagFilter::ValidOnly,
&FeatureFlagFilter::ValidOnly,
);
// Add default feature_states here if needed, currently no features are needed by default.
feature_states.insert("pm-19148-innovation-archive".to_owned(), true);
Json(json!({
// Note: The clients use this version to handle backwards compatibility concerns
@@ -278,9 +289,8 @@ async fn accept_org_invite(
member.save(conn).await?;
if CONFIG.mail_enabled() {
let org = match Organization::find_by_uuid(&member.org_uuid, conn).await {
Some(org) => org,
None => err!("Organization not found."),
let Some(org) = Organization::find_by_uuid(&member.org_uuid, conn).await else {
err!("Organization not found.")
};
// User was invited to an organization, so they must be confirmed manually after acceptance
mail::send_invite_accepted(&user.email, &member.invited_by_email.unwrap_or(org.billing_email), &org.name)
+162 -165
View File
@@ -1,28 +1,28 @@
use num_traits::FromPrimitive;
use rocket::serde::json::Json;
use rocket::Route;
use serde_json::Value;
use std::collections::{HashMap, HashSet};
use crate::api::admin::FAKE_ADMIN_UUID;
use num_traits::FromPrimitive;
use rocket::{Route, serde::json::Json};
use serde_json::Value;
use crate::{
CONFIG,
api::admin::FAKE_ADMIN_UUID,
api::{
core::{accept_org_invite, log_event, two_factor, CipherSyncData, CipherSyncType},
EmptyResult, JsonResult, Notify, PasswordOrOtpData, UpdateType,
core::{CipherSyncData, CipherSyncType, accept_org_invite, log_event, two_factor},
},
auth::{decode_invite, AdminHeaders, Headers, ManagerHeaders, ManagerHeadersLoose, OrgMemberHeaders, OwnerHeaders},
auth::{AdminHeaders, Headers, ManagerHeaders, ManagerHeadersLoose, OrgMemberHeaders, OwnerHeaders, decode_invite},
db::{
DbConn,
models::{
Cipher, CipherId, Collection, CollectionCipher, CollectionGroup, CollectionId, CollectionUser, EventType,
Group, GroupId, GroupUser, Invitation, Membership, MembershipId, MembershipStatus, MembershipType,
OrgPolicy, OrgPolicyType, Organization, OrganizationApiKey, OrganizationId, User, UserId,
},
DbConn,
},
mail,
sso::FAKE_SSO_IDENTIFIER,
util::{convert_json_key_lcase_first, NumberOrString},
CONFIG,
util::{NumberOrString, convert_json_key_lcase_first},
};
pub fn routes() -> Vec<Route> {
@@ -78,6 +78,7 @@ pub fn routes() -> Vec<Route> {
revoke_member,
bulk_revoke_members,
restore_member,
restore_member_vnext,
bulk_restore_members,
get_groups,
get_groups_details,
@@ -96,7 +97,7 @@ pub fn routes() -> Vec<Route> {
get_reset_password_details,
put_reset_password,
get_org_export,
api_key,
post_api_key,
rotate_api_key,
get_billing_metadata,
get_billing_warnings,
@@ -285,9 +286,10 @@ async fn get_organization(org_id: OrganizationId, headers: OwnerHeaders, conn: D
if org_id != headers.org_id {
err!("Organization not found", "Organization id's do not match");
}
match Organization::find_by_uuid(&org_id, &conn).await {
Some(organization) => Ok(Json(organization.to_json())),
None => err!("Can't find organization details"),
if let Some(organization) = Organization::find_by_uuid(&org_id, &conn).await {
Ok(Json(organization.to_json()))
} else {
err!("Can't find organization details")
}
}
@@ -366,7 +368,7 @@ async fn get_auto_enroll_status(identifier: &str, headers: Headers, conn: DbConn
};
let (id, identifier, rp_auto_enroll) = match org {
None => (identifier.to_string(), identifier.to_string(), false),
None => (identifier.to_owned(), identifier.to_owned(), false),
Some(org) => (
org.uuid.to_string(),
org.uuid.to_string(),
@@ -392,7 +394,7 @@ async fn get_org_collections(org_id: OrganizationId, headers: ManagerHeadersLoos
}
Ok(Json(json!({
"data": _get_org_collections(&org_id, &conn).await,
"data": get_org_collections_impl(&org_id, &conn).await,
"object": "list",
"continuationToken": null,
})))
@@ -464,7 +466,7 @@ async fn get_org_collections_details(org_id: OrganizationId, headers: ManagerHea
CollectionGroup::find_by_collection(&col.uuid, &conn)
.await
.iter()
.map(|collection_group| collection_group.to_json_details_for_group())
.map(CollectionGroup::to_json_details_for_group)
.collect()
} else {
Vec::with_capacity(0)
@@ -476,7 +478,7 @@ async fn get_org_collections_details(org_id: OrganizationId, headers: ManagerHea
json_object["groups"] = json!(groups);
json_object["object"] = json!("collectionAccessDetails");
json_object["unmanaged"] = json!(false);
data.push(json_object)
data.push(json_object);
}
Ok(Json(json!({
@@ -486,7 +488,7 @@ async fn get_org_collections_details(org_id: OrganizationId, headers: ManagerHea
})))
}
async fn _get_org_collections(org_id: &OrganizationId, conn: &DbConn) -> Value {
async fn get_org_collections_impl(org_id: &OrganizationId, conn: &DbConn) -> Value {
Collection::find_by_organization(org_id, conn).await.iter().map(Collection::to_json).collect::<Value>()
}
@@ -572,7 +574,7 @@ async fn post_bulk_access_collections(
if Organization::find_by_uuid(&org_id, &conn).await.is_none() {
err!("Can't find organization details")
};
}
for col_id in data.collection_ids {
let Some(collection) = Collection::find_by_uuid_and_org(&col_id, &org_id, &conn).await else {
@@ -649,7 +651,7 @@ async fn post_organization_collection_update(
if Organization::find_by_uuid(&org_id, &conn).await.is_none() {
err!("Can't find organization details")
};
}
let Some(mut collection) = Collection::find_by_uuid_and_org(&col_id, &org_id, &conn).await else {
err!("Collection not found")
@@ -700,7 +702,7 @@ async fn post_organization_collection_update(
Ok(Json(collection.to_json_details(&headers.user.uuid, None, &conn).await))
}
async fn _delete_organization_collection(
async fn delete_organization_collection_impl(
org_id: &OrganizationId,
col_id: &CollectionId,
headers: &ManagerHeaders,
@@ -732,7 +734,7 @@ async fn delete_organization_collection(
headers: ManagerHeaders,
conn: DbConn,
) -> EmptyResult {
_delete_organization_collection(&org_id, &col_id, &headers, &conn).await
delete_organization_collection_impl(&org_id, &col_id, &headers, &conn).await
}
#[post("/organizations/<org_id>/collections/<col_id>/delete")]
@@ -742,7 +744,7 @@ async fn post_organization_collection_delete(
headers: ManagerHeaders,
conn: DbConn,
) -> EmptyResult {
_delete_organization_collection(&org_id, &col_id, &headers, &conn).await
delete_organization_collection_impl(&org_id, &col_id, &headers, &conn).await
}
#[derive(Deserialize, Debug)]
@@ -768,7 +770,7 @@ async fn bulk_delete_organization_collections(
let headers = ManagerHeaders::from_loose(headers, &collections, &conn).await?;
for col_id in collections {
_delete_organization_collection(&org_id, &col_id, &headers, &conn).await?
delete_organization_collection_impl(&org_id, &col_id, &headers, &conn).await?;
}
Ok(())
}
@@ -798,7 +800,7 @@ async fn get_org_collection_detail(
CollectionGroup::find_by_collection(&collection.uuid, &conn)
.await
.iter()
.map(|collection_group| collection_group.to_json_details_for_group())
.map(CollectionGroup::to_json_details_for_group)
.collect()
} else {
// The Bitwarden clients seem to call this API regardless of whether groups are enabled,
@@ -885,13 +887,13 @@ async fn get_org_details(data: OrgIdData, headers: ManagerHeadersLoose, conn: Db
}
Ok(Json(json!({
"data": _get_org_details(&data.organization_id, &headers.host, &headers.user.uuid, &conn).await?,
"data": get_org_details_impl(&data.organization_id, &headers.host, &headers.user.uuid, &conn).await?,
"object": "list",
"continuationToken": null,
})))
}
async fn _get_org_details(
async fn get_org_details_impl(
org_id: &OrganizationId,
host: &str,
user_id: &UserId,
@@ -907,36 +909,21 @@ async fn _get_org_details(
Ok(json!(ciphers_json))
}
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct OrgDomainDetails {
email: String,
}
// Returning a Domain/Organization here allow to prefill it and prevent prompting the user
// So we either return an Org name associated to the user or a dummy value.
// So we return a dummy value, since we only support a single SSO integration, and do not use the response anywhere
// In use since `v2025.6.0`, appears to use only the first `organizationIdentifier`
#[post("/organizations/domain/sso/verified", data = "<data>")]
async fn get_org_domain_sso_verified(data: Json<OrgDomainDetails>, conn: DbConn) -> JsonResult {
let data: OrgDomainDetails = data.into_inner();
let identifiers = match Organization::find_org_user_email(&data.email, &conn)
.await
.into_iter()
.map(|o| (o.name, o.uuid.to_string()))
.collect::<Vec<(String, String)>>()
{
v if !v.is_empty() => v,
_ => vec![(FAKE_SSO_IDENTIFIER.to_string(), FAKE_SSO_IDENTIFIER.to_string())],
};
#[post("/organizations/domain/sso/verified")]
fn get_org_domain_sso_verified() -> JsonResult {
// Always return a dummy value, no matter if SSO is enabled or not
Ok(Json(json!({
"object": "list",
"data": identifiers.into_iter().map(|(name, identifier)| json!({
"organizationName": name, // appear unused
"organizationIdentifier": identifier,
"domainName": CONFIG.domain(), // appear unused
})).collect::<Vec<Value>>()
"data": [{
"organizationIdentifier": FAKE_SSO_IDENTIFIER,
// These appear to be unused
"organizationName": FAKE_SSO_IDENTIFIER,
"domainName": CONFIG.domain()
}],
"continuationToken": null
})))
}
@@ -989,14 +976,13 @@ async fn post_org_keys(
}
let data: OrgKeyData = data.into_inner();
let mut org = match Organization::find_by_uuid(&org_id, &conn).await {
Some(organization) => {
if organization.private_key.is_some() && organization.public_key.is_some() {
err!("Organization Keys already exist")
}
organization
let mut org = if let Some(organization) = Organization::find_by_uuid(&org_id, &conn).await {
if organization.private_key.is_some() && organization.public_key.is_some() {
err!("Organization Keys already exist")
}
None => err!("Can't find organization details"),
organization
} else {
err!("Can't find organization details")
};
org.private_key = Some(data.encrypted_private_key);
@@ -1057,9 +1043,10 @@ async fn send_invite(
// The from_str() will convert the custom role type into a manager role type
let raw_type = &data.r#type.into_string();
// Membership::from_str will convert custom (4) to manager (3)
let new_type = match MembershipType::from_str(raw_type) {
Some(new_type) => new_type as i32,
None => err!("Invalid type"),
let new_type = if let Some(new_type) = MembershipType::from_str(raw_type) {
new_type as i32
} else {
err!("Invalid type")
};
if new_type != MembershipType::User && headers.membership_type != MembershipType::Owner {
@@ -1076,7 +1063,7 @@ async fn send_invite(
&& data.permissions.get("createNewCollections") == Some(&json!(true)));
let mut user_created: bool = false;
for email in data.emails.iter() {
for email in &data.emails {
let mut member_status = MembershipStatus::Invited as i32;
let user = match User::find_by_mail(email, &conn).await {
None => {
@@ -1100,13 +1087,13 @@ async fn send_invite(
Some(user) => {
if Membership::find_by_user_and_org(&user.uuid, &org_id, &conn).await.is_some() {
err!(format!("User already in organization: {email}"))
} else {
// automatically accept existing users if mail is disabled
if !CONFIG.mail_enabled() && !user.password_hash.is_empty() {
member_status = MembershipStatus::Accepted as i32;
}
user
}
// automatically accept existing users if mail is disabled
if !CONFIG.mail_enabled() && !user.password_hash.is_empty() {
member_status = MembershipStatus::Accepted as i32;
}
user
}
};
@@ -1117,9 +1104,10 @@ async fn send_invite(
new_member.save(&conn).await?;
if CONFIG.mail_enabled() {
let org_name = match Organization::find_by_uuid(&org_id, &conn).await {
Some(org) => org.name,
None => err!("Error looking up organization"),
let org_name = if let Some(org) = Organization::find_by_uuid(&org_id, &conn).await {
org.name
} else {
err!("Error looking up organization")
};
if let Err(e) = mail::send_invite(
@@ -1173,7 +1161,7 @@ async fn send_invite(
}
}
for group_id in data.groups.iter() {
for group_id in &data.groups {
let mut group_entry = GroupUser::new(group_id.clone(), new_member.uuid.clone());
group_entry.save(&conn).await?;
}
@@ -1196,8 +1184,8 @@ async fn bulk_reinvite_members(
let mut bulk_response = Vec::new();
for member_id in data.ids {
let err_msg = match _reinvite_member(&org_id, &member_id, &headers.user.email, &conn).await {
Ok(_) => String::new(),
let err_msg = match reinvite_member_impl(&org_id, &member_id, &headers.user.email, &conn).await {
Ok(()) => String::new(),
Err(e) => format!("{e:?}"),
};
@@ -1207,7 +1195,7 @@ async fn bulk_reinvite_members(
"id": member_id,
"error": err_msg
}
))
));
}
Ok(Json(json!({
@@ -1227,10 +1215,10 @@ async fn reinvite_member(
if org_id != headers.org_id {
err!("Organization not found", "Organization id's do not match");
}
_reinvite_member(&org_id, &member_id, &headers.user.email, &conn).await
reinvite_member_impl(&org_id, &member_id, &headers.user.email, &conn).await
}
async fn _reinvite_member(
async fn reinvite_member_impl(
org_id: &OrganizationId,
member_id: &MembershipId,
invited_by_email: &str,
@@ -1252,13 +1240,14 @@ async fn _reinvite_member(
err!("Invitations are not allowed.")
}
let org_name = match Organization::find_by_uuid(org_id, conn).await {
Some(org) => org.name,
None => err!("Error looking up organization."),
let org_name = if let Some(org) = Organization::find_by_uuid(org_id, conn).await {
org.name
} else {
err!("Error looking up organization.")
};
if CONFIG.mail_enabled() {
mail::send_invite(&user, org_id.clone(), member.uuid, &org_name, Some(invited_by_email.to_string())).await?;
mail::send_invite(&user, org_id.clone(), member.uuid, &org_name, Some(invited_by_email.to_owned())).await?;
} else if user.password_hash.is_empty() {
let invitation = Invitation::new(&user.email);
invitation.save(conn).await?;
@@ -1366,8 +1355,8 @@ async fn bulk_confirm_invite(
for invite in keys {
let member_id = invite.id.unwrap();
let user_key = invite.key.unwrap_or_default();
let err_msg = match _confirm_invite(&org_id, &member_id, &user_key, &headers, &conn, &nt).await {
Ok(_) => String::new(),
let err_msg = match confirm_invite_impl(&org_id, &member_id, &user_key, &headers, &conn, &nt).await {
Ok(()) => String::new(),
Err(e) => format!("{e:?}"),
};
@@ -1401,10 +1390,10 @@ async fn confirm_invite(
) -> EmptyResult {
let data = data.into_inner();
let user_key = data.key.unwrap_or_default();
_confirm_invite(&org_id, &member_id, &user_key, &headers, &conn, &nt).await
confirm_invite_impl(&org_id, &member_id, &user_key, &headers, &conn, &nt).await
}
async fn _confirm_invite(
async fn confirm_invite_impl(
org_id: &OrganizationId,
member_id: &MembershipId,
key: &str,
@@ -1432,7 +1421,7 @@ async fn _confirm_invite(
}
member_to_confirm.status = MembershipStatus::Confirmed as i32;
member_to_confirm.akey = key.to_string();
member_to_confirm.akey = key.to_owned();
// This check is also done at accept_invite, _confirm_invite, _activate_member, edit_member, admin::update_membership_type
OrgPolicy::check_user_allowed(&member_to_confirm, "confirm", conn).await?;
@@ -1449,13 +1438,15 @@ async fn _confirm_invite(
.await;
if CONFIG.mail_enabled() {
let org_name = match Organization::find_by_uuid(org_id, conn).await {
Some(org) => org.name,
None => err!("Error looking up organization."),
let org_name = if let Some(org) = Organization::find_by_uuid(org_id, conn).await {
org.name
} else {
err!("Error looking up organization.")
};
let address = match User::find_by_uuid(&member_to_confirm.user_uuid, conn).await {
Some(user) => user.email,
None => err!("Error looking up user."),
let address = if let Some(user) = User::find_by_uuid(&member_to_confirm.user_uuid, conn).await {
user.email
} else {
err!("Error looking up user.")
};
mail::send_invite_confirmed(&address, &org_name).await?;
}
@@ -1463,7 +1454,7 @@ async fn _confirm_invite(
let save_result = member_to_confirm.save(conn).await;
if let Some(user) = User::find_by_uuid(&member_to_confirm.user_uuid, conn).await {
nt.send_user_update(UpdateType::SyncOrgKeys, &user, &headers.device.push_uuid, conn).await;
nt.send_user_update(UpdateType::SyncOrgKeys, &user, headers.device.push_uuid.as_ref(), conn).await;
}
save_result
@@ -1651,8 +1642,8 @@ async fn bulk_delete_member(
let mut bulk_response = Vec::new();
for member_id in data.ids {
let err_msg = match _delete_member(&org_id, &member_id, &headers, &conn, &nt).await {
Ok(_) => String::new(),
let err_msg = match delete_member_impl(&org_id, &member_id, &headers, &conn, &nt).await {
Ok(()) => String::new(),
Err(e) => format!("{e:?}"),
};
@@ -1662,7 +1653,7 @@ async fn bulk_delete_member(
"id": member_id,
"error": err_msg
}
))
));
}
Ok(Json(json!({
@@ -1680,10 +1671,10 @@ async fn delete_member(
conn: DbConn,
nt: Notify<'_>,
) -> EmptyResult {
_delete_member(&org_id, &member_id, &headers, &conn, &nt).await
delete_member_impl(&org_id, &member_id, &headers, &conn, &nt).await
}
async fn _delete_member(
async fn delete_member_impl(
org_id: &OrganizationId,
member_id: &MembershipId,
headers: &AdminHeaders,
@@ -1721,7 +1712,7 @@ async fn _delete_member(
.await;
if let Some(user) = User::find_by_uuid(&member_to_delete.user_uuid, conn).await {
nt.send_user_update(UpdateType::SyncOrgKeys, &user, &headers.device.push_uuid, conn).await;
nt.send_user_update(UpdateType::SyncOrgKeys, &user, headers.device.push_uuid.as_ref(), conn).await;
}
member_to_delete.delete(conn).await
@@ -1767,8 +1758,8 @@ async fn bulk_public_keys(
})))
}
use super::ciphers::update_cipher_from_data;
use super::ciphers::CipherData;
use super::ciphers::update_cipher_from_data;
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
@@ -1916,24 +1907,24 @@ async fn post_bulk_collections(data: Json<BulkCollectionsData>, headers: Headers
}
}
for cipher_id in data.cipher_ids.iter() {
for cipher_id in &data.cipher_ids {
// Only act on existing cipher uuid's
// Do not abort the operation just ignore it, it could be a cipher was just deleted for example
if let Some(cipher) = Cipher::find_by_uuid_and_org(cipher_id, &data.organization_id, &conn).await {
if cipher.is_write_accessible_to_user(&headers.user.uuid, &conn).await {
// When selecting a specific collection from the left filter list, and use the bulk option, you can remove an item from that collection
// In these cases the client will call this endpoint twice, once for adding the new collections and a second for deleting.
if data.remove_collections {
for collection in &data.collection_ids {
CollectionCipher::delete(&cipher.uuid, collection, &conn).await?;
}
} else {
for collection in &data.collection_ids {
CollectionCipher::save(&cipher.uuid, collection, &conn).await?;
}
if let Some(cipher) = Cipher::find_by_uuid_and_org(cipher_id, &data.organization_id, &conn).await
&& cipher.is_write_accessible_to_user(&headers.user.uuid, &conn).await
{
// When selecting a specific collection from the left filter list, and use the bulk option, you can remove an item from that collection
// In these cases the client will call this endpoint twice, once for adding the new collections and a second for deleting.
if data.remove_collections {
for collection in &data.collection_ids {
CollectionCipher::delete(&cipher.uuid, collection, &conn).await?;
}
} else {
for collection in &data.collection_ids {
CollectionCipher::save(&cipher.uuid, collection, &conn).await?;
}
}
};
}
}
Ok(())
@@ -1979,11 +1970,11 @@ async fn list_policies_token(org_id: OrganizationId, token: &str, conn: DbConn)
}
// Called during the SSO enrollment return the default policy
#[get("/organizations/vaultwarden-dummy-oidc-identifier/policies/master-password", rank = 1)]
#[get("/organizations/00000000-01DC-01DC-01DC-000000000000/policies/master-password", rank = 1)]
fn get_dummy_master_password_policy() -> JsonResult {
let (enabled, data) = match CONFIG.sso_master_password_policy_value() {
Some(policy) if CONFIG.sso_enabled() => (true, policy.to_string()),
_ => (false, "null".to_string()),
_ => (false, "null".to_owned()),
};
let policy = OrgPolicy::new(FAKE_SSO_IDENTIFIER.into(), OrgPolicyType::MasterPassword, enabled, data);
Ok(Json(policy.to_json()))
@@ -1996,7 +1987,7 @@ async fn get_master_password_policy(org_id: OrganizationId, _headers: OrgMemberH
OrgPolicy::find_by_org_and_type(&org_id, OrgPolicyType::MasterPassword, &conn).await.unwrap_or_else(|| {
let (enabled, data) = match CONFIG.sso_master_password_policy_value() {
Some(policy) if CONFIG.sso_enabled() => (true, policy.to_string()),
_ => (false, "null".to_string()),
_ => (false, "null".to_owned()),
};
OrgPolicy::new(org_id, OrgPolicyType::MasterPassword, enabled, data)
@@ -2017,7 +2008,7 @@ async fn get_policy(org_id: OrganizationId, pol_type: i32, headers: AdminHeaders
let policy = match OrgPolicy::find_by_org_and_type(&org_id, pol_type_enum, &conn).await {
Some(p) => p,
None => OrgPolicy::new(org_id.clone(), pol_type_enum, false, "null".to_string()),
None => OrgPolicy::new(org_id.clone(), pol_type_enum, false, "null".to_owned()),
};
Ok(Json(policy.to_json()))
@@ -2092,7 +2083,7 @@ async fn put_policy(
// When enabling the SingleOrg policy, remove this org's members that are members of other orgs
if pol_type_enum == OrgPolicyType::SingleOrg && data.enabled {
for mut member in Membership::find_by_org(&org_id, &conn).await.into_iter() {
for mut member in Membership::find_by_org(&org_id, &conn).await {
// Policy only applies to non-Owner/non-Admin members who have accepted joining the org
// Exclude invited and revoked users when checking for this policy.
// Those users will not be allowed to accept or be activated because of the policy checks done there.
@@ -2127,7 +2118,7 @@ async fn put_policy(
let mut policy = match OrgPolicy::find_by_org_and_type(&org_id, pol_type_enum, &conn).await {
Some(p) => p,
None => OrgPolicy::new(org_id.clone(), pol_type_enum, false, "{}".to_string()),
None => OrgPolicy::new(org_id.clone(), pol_type_enum, false, "{}".to_owned()),
};
policy.enabled = data.enabled;
@@ -2201,7 +2192,7 @@ fn get_plans() -> Json<Value> {
#[get("/organizations/<_org_id>/billing/metadata")]
fn get_billing_metadata(_org_id: OrganizationId, _headers: OrgMemberHeaders) -> Json<Value> {
// Prevent a 404 error, which also causes Javascript errors.
Json(_empty_data_json())
Json(empty_data_json())
}
#[get("/organizations/<_org_id>/billing/vnext/warnings")]
@@ -2223,7 +2214,7 @@ fn get_self_host_billing_metadata(_org_id: OrganizationId, _headers: OrgMemberHe
}))
}
fn _empty_data_json() -> Value {
fn empty_data_json() -> Value {
json!({
"object": "list",
"data": [],
@@ -2244,7 +2235,7 @@ async fn revoke_member(
headers: AdminHeaders,
conn: DbConn,
) -> EmptyResult {
_revoke_member(&org_id, &member_id, &headers, &conn).await
revoke_member_impl(&org_id, &member_id, &headers, &conn).await
}
#[put("/organizations/<org_id>/users/revoke", data = "<data>")]
@@ -2263,8 +2254,8 @@ async fn bulk_revoke_members(
match data.ids {
Some(members) => {
for member_id in members {
let err_msg = match _revoke_member(&org_id, &member_id, &headers, &conn).await {
Ok(_) => String::new(),
let err_msg = match revoke_member_impl(&org_id, &member_id, &headers, &conn).await {
Ok(()) => String::new(),
Err(e) => format!("{e:?}"),
};
@@ -2287,7 +2278,7 @@ async fn bulk_revoke_members(
})))
}
async fn _revoke_member(
async fn revoke_member_impl(
org_id: &OrganizationId,
member_id: &MembershipId,
headers: &AdminHeaders,
@@ -2330,6 +2321,18 @@ async fn _revoke_member(
Ok(())
}
#[put("/organizations/<org_id>/users/<member_id>/restore/vnext")]
async fn restore_member_vnext(
org_id: OrganizationId,
member_id: MembershipId,
headers: AdminHeaders,
conn: DbConn,
) -> EmptyResult {
// Vaultwarden does not (yet) support the per User Collection linked to the `Enforce organization data ownership` policy.
// Therefor we ignore the `defaultUserCollectionName` data sent and just call restore_member
restore_member_impl(&org_id, &member_id, &headers, &conn).await
}
#[put("/organizations/<org_id>/users/<member_id>/restore")]
async fn restore_member(
org_id: OrganizationId,
@@ -2337,7 +2340,7 @@ async fn restore_member(
headers: AdminHeaders,
conn: DbConn,
) -> EmptyResult {
_restore_member(&org_id, &member_id, &headers, &conn).await
restore_member_impl(&org_id, &member_id, &headers, &conn).await
}
#[put("/organizations/<org_id>/users/restore", data = "<data>")]
@@ -2354,8 +2357,8 @@ async fn bulk_restore_members(
let mut bulk_response = Vec::new();
for member_id in data.ids {
let err_msg = match _restore_member(&org_id, &member_id, &headers, &conn).await {
Ok(_) => String::new(),
let err_msg = match restore_member_impl(&org_id, &member_id, &headers, &conn).await {
Ok(()) => String::new(),
Err(e) => format!("{e:?}"),
};
@@ -2375,7 +2378,7 @@ async fn bulk_restore_members(
})))
}
async fn _restore_member(
async fn restore_member_impl(
org_id: &OrganizationId,
member_id: &MembershipId,
headers: &AdminHeaders,
@@ -2431,11 +2434,11 @@ async fn get_groups_data(
if details {
for g in groups {
groups_json.push(g.to_json_details(&conn).await)
groups_json.push(g.to_json_details(&conn).await);
}
} else {
for g in groups {
groups_json.push(g.to_json())
groups_json.push(g.to_json());
}
}
groups_json
@@ -2674,15 +2677,15 @@ async fn post_delete_group(
headers: AdminHeaders,
conn: DbConn,
) -> EmptyResult {
_delete_group(&org_id, &group_id, &headers, &conn).await
delete_group_impl(&org_id, &group_id, &headers, &conn).await
}
#[delete("/organizations/<org_id>/groups/<group_id>")]
async fn delete_group(org_id: OrganizationId, group_id: GroupId, headers: AdminHeaders, conn: DbConn) -> EmptyResult {
_delete_group(&org_id, &group_id, &headers, &conn).await
delete_group_impl(&org_id, &group_id, &headers, &conn).await
}
async fn _delete_group(
async fn delete_group_impl(
org_id: &OrganizationId,
group_id: &GroupId,
headers: &AdminHeaders,
@@ -2730,7 +2733,7 @@ async fn bulk_delete_groups(
let data: BulkGroupIds = data.into_inner();
for group_id in data.ids {
_delete_group(&org_id, &group_id, &headers, &conn).await?
delete_group_impl(&org_id, &group_id, &headers, &conn).await?;
}
Ok(())
}
@@ -2767,7 +2770,7 @@ async fn get_group_members(
if Group::find_by_uuid_and_org(&group_id, &org_id, &conn).await.is_none() {
err!("Group could not be found!", "Group uuid is invalid or does not belong to the organization")
};
}
let group_members: Vec<MembershipId> = GroupUser::find_by_group(&group_id, &org_id, &conn)
.await
@@ -2795,7 +2798,7 @@ async fn put_group_members(
if Group::find_by_uuid_and_org(&group_id, &org_id, &conn).await.is_none() {
err!("Group could not be found!", "Group uuid is invalid or does not belong to the organization")
};
}
let assigned_members = data.into_inner();
@@ -3049,10 +3052,7 @@ async fn put_reset_password_enrollment(
err!("User to enroll isn't member of required organization", "The user_id and acting user do not match");
}
let Some(mut membership) = Membership::find_confirmed_by_user_and_org(&headers.user.uuid, &org_id, &conn).await
else {
err!("User to enroll isn't member of required organization")
};
let mut membership = headers.membership;
check_reset_password_applicable(&org_id, &conn).await?;
@@ -3105,12 +3105,12 @@ async fn get_org_export(org_id: OrganizationId, headers: AdminHeaders, conn: DbC
}
Ok(Json(json!({
"collections": convert_json_key_lcase_first(_get_org_collections(&org_id, &conn).await),
"ciphers": convert_json_key_lcase_first(_get_org_details(&org_id, &headers.host, &headers.user.uuid, &conn).await?),
"collections": convert_json_key_lcase_first(get_org_collections_impl(&org_id, &conn).await),
"ciphers": convert_json_key_lcase_first(get_org_details_impl(&org_id, &headers.host, &headers.user.uuid, &conn).await?),
})))
}
async fn _api_key(
async fn api_key(
org_id: &OrganizationId,
data: Json<PasswordOrOtpData>,
rotate: bool,
@@ -3126,21 +3126,18 @@ async fn _api_key(
// Validate the admin users password/otp
data.validate(&user, true, &conn).await?;
let org_api_key = match OrganizationApiKey::find_by_org_uuid(org_id, &conn).await {
Some(mut org_api_key) => {
if rotate {
org_api_key.api_key = crate::crypto::generate_api_key();
org_api_key.revision_date = chrono::Utc::now().naive_utc();
org_api_key.save(&conn).await.expect("Error rotating organization API Key");
}
org_api_key
}
None => {
let api_key = crate::crypto::generate_api_key();
let new_org_api_key = OrganizationApiKey::new(org_id.clone(), api_key);
new_org_api_key.save(&conn).await.expect("Error creating organization API Key");
new_org_api_key
let org_api_key = if let Some(mut org_api_key) = OrganizationApiKey::find_by_org_uuid(org_id, &conn).await {
if rotate {
org_api_key.api_key = crate::crypto::generate_api_key();
org_api_key.revision_date = chrono::Utc::now().naive_utc();
org_api_key.save(&conn).await.expect("Error rotating organization API Key");
}
org_api_key
} else {
let api_key = crate::crypto::generate_api_key();
let new_org_api_key = OrganizationApiKey::new(org_id.clone(), api_key);
new_org_api_key.save(&conn).await.expect("Error creating organization API Key");
new_org_api_key
};
Ok(Json(json!({
@@ -3151,13 +3148,13 @@ async fn _api_key(
}
#[post("/organizations/<org_id>/api-key", data = "<data>")]
async fn api_key(
async fn post_api_key(
org_id: OrganizationId,
data: Json<PasswordOrOtpData>,
headers: AdminHeaders,
conn: DbConn,
) -> JsonResult {
_api_key(&org_id, data, false, headers, conn).await
api_key(&org_id, data, false, headers, conn).await
}
#[post("/organizations/<org_id>/rotate-api-key", data = "<data>")]
@@ -3167,5 +3164,5 @@ async fn rotate_api_key(
headers: AdminHeaders,
conn: DbConn,
) -> JsonResult {
_api_key(&org_id, data, true, headers, conn).await
api_key(&org_id, data, true, headers, conn).await
}
+63 -66
View File
@@ -1,23 +1,24 @@
use chrono::Utc;
use rocket::{
request::{FromRequest, Outcome},
serde::json::Json,
Request, Route,
};
use std::collections::HashSet;
use chrono::Utc;
use rocket::{
Request, Route,
request::{FromRequest, Outcome},
serde::json::Json,
};
use crate::{
CONFIG,
api::EmptyResult,
auth,
db::{
DbConn,
models::{
Group, GroupUser, Invitation, Membership, MembershipStatus, MembershipType, Organization,
OrganizationApiKey, OrganizationId, User,
},
DbConn,
},
mail, CONFIG,
mail,
};
pub fn routes() -> Vec<Route> {
@@ -90,19 +91,18 @@ async fn ldap_import(data: Json<OrgImportData>, token: PublicToken, conn: DbConn
}
} else {
// If user is not part of the organization
let user = match User::find_by_mail(&user_data.email, &conn).await {
Some(user) => user, // exists in vaultwarden
None => {
// User does not exist yet
let mut new_user = User::new(&user_data.email, None);
new_user.save(&conn).await?;
let user = if let Some(user) = User::find_by_mail(&user_data.email, &conn).await {
user
} else {
// User does not exist yet
let mut new_user = User::new(&user_data.email, None);
new_user.save(&conn).await?;
if !CONFIG.mail_enabled() {
Invitation::new(&new_user.email).save(&conn).await?;
}
user_created = true;
new_user
if !CONFIG.mail_enabled() {
Invitation::new(&new_user.email).save(&conn).await?;
}
user_created = true;
new_user
};
let member_status = if CONFIG.mail_enabled() || user.password_hash.is_empty() {
MembershipStatus::Invited as i32
@@ -110,9 +110,10 @@ async fn ldap_import(data: Json<OrgImportData>, token: PublicToken, conn: DbConn
MembershipStatus::Accepted as i32 // Automatically mark user as accepted if no email invites
};
let (org_name, org_email) = match Organization::find_by_uuid(&org_id, &conn).await {
Some(org) => (org.name, org.billing_email),
None => err!("Error looking up organization"),
let (org_name, org_email) = if let Some(org) = Organization::find_by_uuid(&org_id, &conn).await {
(org.name, org.billing_email)
} else {
err!("Error looking up organization")
};
let mut new_member = Membership::new(user.uuid.clone(), org_id.clone(), Some(org_email.clone()));
@@ -123,37 +124,33 @@ async fn ldap_import(data: Json<OrgImportData>, token: PublicToken, conn: DbConn
new_member.save(&conn).await?;
if CONFIG.mail_enabled() {
if let Err(e) =
if CONFIG.mail_enabled()
&& let Err(e) =
mail::send_invite(&user, org_id.clone(), new_member.uuid.clone(), &org_name, Some(org_email)).await
{
// Upon error delete the user, invite and org member records when needed
if user_created {
user.delete(&conn).await?;
} else {
new_member.delete(&conn).await?;
}
err!(format!("Error sending invite: {e:?} "));
{
// Upon error delete the user, invite and org member records when needed
if user_created {
user.delete(&conn).await?;
} else {
new_member.delete(&conn).await?;
}
err!(format!("Error sending invite: {e:?} "));
}
}
}
if CONFIG.org_groups_enabled() {
for group_data in &data.groups {
let group_uuid = match Group::find_by_external_id_and_org(&group_data.external_id, &org_id, &conn).await {
Some(group) => group.uuid,
None => {
let mut group = Group::new(
org_id.clone(),
group_data.name.clone(),
false,
Some(group_data.external_id.clone()),
);
group.save(&conn).await?;
group.uuid
}
let group_uuid = if let Some(group) =
Group::find_by_external_id_and_org(&group_data.external_id, &org_id, &conn).await
{
group.uuid
} else {
let mut group =
Group::new(org_id.clone(), group_data.name.clone(), false, Some(group_data.external_id.clone()));
group.save(&conn).await?;
group.uuid
};
GroupUser::delete_all_by_group(&group_uuid, &org_id, &conn).await?;
@@ -174,18 +171,17 @@ async fn ldap_import(data: Json<OrgImportData>, token: PublicToken, conn: DbConn
// Generate a HashSet to quickly verify if a member is listed or not.
let sync_members: HashSet<String> = data.members.into_iter().map(|m| m.external_id).collect();
for member in Membership::find_by_org(&org_id, &conn).await {
if let Some(ref user_external_id) = member.external_id {
if !sync_members.contains(user_external_id) {
if member.atype == MembershipType::Owner && member.status == MembershipStatus::Confirmed as i32 {
// Removing owner, check that there is at least one other confirmed owner
if Membership::count_confirmed_by_org_and_type(&org_id, MembershipType::Owner, &conn).await <= 1
{
warn!("Can't delete the last owner");
continue;
}
if let Some(ref user_external_id) = member.external_id
&& !sync_members.contains(user_external_id)
{
if member.atype == MembershipType::Owner && member.status == MembershipStatus::Confirmed as i32 {
// Removing owner, check that there is at least one other confirmed owner
if Membership::count_confirmed_by_org_and_type(&org_id, MembershipType::Owner, &conn).await <= 1 {
warn!("Can't delete the last owner");
continue;
}
member.delete(&conn).await?;
}
member.delete(&conn).await?;
}
}
}
@@ -202,12 +198,14 @@ impl<'r> FromRequest<'r> for PublicToken {
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
let headers = request.headers();
// Get access_token
let access_token: &str = match headers.get_one("Authorization") {
Some(a) => match a.rsplit("Bearer ").next() {
Some(split) => split,
None => err_handler!("No access token provided"),
},
None => err_handler!("No access token provided"),
let access_token: &str = if let Some(a) = headers.get_one("Authorization") {
if let Some(split) = a.rsplit("Bearer ").next() {
split
} else {
err_handler!("No access token provided")
}
} else {
err_handler!("No access token provided")
};
// Check JWT token is valid and get device and user from it
let Ok(claims) = auth::decode_api_org(access_token) else {
@@ -229,14 +227,13 @@ impl<'r> FromRequest<'r> for PublicToken {
// Check if claims.sub is org_api_key.uuid
// Check if claims.client_sub is org_api_key.org_uuid
let conn = match DbConn::from_request(request).await {
Outcome::Success(conn) => conn,
_ => err_handler!("Error getting DB"),
let Outcome::Success(conn) = DbConn::from_request(request).await else {
err_handler!("Error getting DB")
};
let Some(org_id) = claims.client_id.strip_prefix("organization.") else {
err_handler!("Malformed client_id")
};
let org_id: OrganizationId = org_id.to_string().into();
let org_id: OrganizationId = org_id.to_owned().into();
let Some(org_api_key) = OrganizationApiKey::find_by_org_uuid(&org_id, &conn).await else {
err_handler!("Invalid client_id")
};
+36 -34
View File
@@ -10,15 +10,15 @@ use rocket::{
use serde_json::Value;
use crate::{
CONFIG,
api::{ApiResult, EmptyResult, JsonResult, Notify, UpdateType},
auth::{ClientIp, Headers, Host},
config::PathType,
db::{
models::{Device, OrgPolicy, OrgPolicyType, Send, SendFileId, SendId, SendType, UserId},
DbConn, DbPool,
models::{Device, OrgPolicy, OrgPolicyType, Send, SendFileId, SendId, SendType, UserId},
},
util::{save_temp_file, NumberOrString},
CONFIG,
util::{NumberOrString, save_temp_file},
};
const SEND_INACCESSIBLE_MSG: &str = "Send does not exist or is no longer available";
@@ -63,7 +63,7 @@ pub async fn purge_sends(pool: DbPool) {
if let Ok(conn) = pool.get().await {
Send::purge(&conn).await;
} else {
error!("Failed to get DB connection while purging sends")
error!("Failed to get DB connection while purging sends");
}
}
@@ -168,7 +168,7 @@ fn create_send(data: SendData, user_id: UserId) -> ApiResult<Send> {
#[get("/sends")]
async fn get_sends(headers: Headers, conn: DbConn) -> Json<Value> {
let sends = Send::find_by_user(&headers.user.uuid, &conn);
let sends_json: Vec<Value> = sends.await.iter().map(|s| s.to_json()).collect();
let sends_json: Vec<Value> = sends.await.iter().map(Send::to_json).collect();
Json(json!({
"data": sends_json,
@@ -179,9 +179,10 @@ async fn get_sends(headers: Headers, conn: DbConn) -> Json<Value> {
#[get("/sends/<send_id>")]
async fn get_send(send_id: SendId, headers: Headers, conn: DbConn) -> JsonResult {
match Send::find_by_uuid_and_user(&send_id, &headers.user.uuid, &conn).await {
Some(send) => Ok(Json(send.to_json())),
None => err!("Send not found", "Invalid send uuid or does not belong to user"),
if let Some(send) = Send::find_by_uuid_and_user(&send_id, &headers.user.uuid, &conn).await {
Ok(Json(send.to_json()))
} else {
err!("Send not found", "Invalid send uuid or does not belong to user")
}
}
@@ -310,9 +311,10 @@ async fn post_send_file_v2(data: Json<SendData>, headers: Headers, conn: DbConn)
enforce_disable_hide_email_policy(&data, &headers, &conn).await?;
let file_length = match &data.file_length {
Some(m) => m.into_i64()?,
_ => err!("Invalid send length"),
let file_length = if let Some(m) = &data.file_length {
m.into_i64()?
} else {
err!("Invalid send length")
};
if file_length < 0 {
err!("Send size can't be negative")
@@ -457,16 +459,16 @@ async fn post_access(
err_code!(SEND_INACCESSIBLE_MSG, 404)
};
if let Some(max_access_count) = send.max_access_count {
if send.access_count >= max_access_count {
err_code!(SEND_INACCESSIBLE_MSG, 404);
}
if let Some(max_access_count) = send.max_access_count
&& send.access_count >= max_access_count
{
err_code!(SEND_INACCESSIBLE_MSG, 404);
}
if let Some(expiration) = send.expiration_date {
if Utc::now().naive_utc() >= expiration {
err_code!(SEND_INACCESSIBLE_MSG, 404)
}
if let Some(expiration) = send.expiration_date
&& Utc::now().naive_utc() >= expiration
{
err_code!(SEND_INACCESSIBLE_MSG, 404)
}
if Utc::now().naive_utc() >= send.deletion_date {
@@ -517,16 +519,16 @@ async fn post_access_file(
err_code!(SEND_INACCESSIBLE_MSG, 404)
};
if let Some(max_access_count) = send.max_access_count {
if send.access_count >= max_access_count {
err_code!(SEND_INACCESSIBLE_MSG, 404)
}
if let Some(max_access_count) = send.max_access_count
&& send.access_count >= max_access_count
{
err_code!(SEND_INACCESSIBLE_MSG, 404)
}
if let Some(expiration) = send.expiration_date {
if Utc::now().naive_utc() >= expiration {
err_code!(SEND_INACCESSIBLE_MSG, 404)
}
if let Some(expiration) = send.expiration_date
&& Utc::now().naive_utc() >= expiration
{
err_code!(SEND_INACCESSIBLE_MSG, 404)
}
if Utc::now().naive_utc() >= send.deletion_date {
@@ -568,22 +570,22 @@ async fn post_access_file(
async fn download_url(host: &Host, send_id: &SendId, file_id: &SendFileId) -> Result<String, crate::Error> {
let operator = CONFIG.opendal_operator_for_path_type(&PathType::Sends)?;
if operator.info().scheme() == <&'static str>::from(opendal::Scheme::Fs) {
if crate::storage::is_fs_operator(&operator) {
let token_claims = crate::auth::generate_send_claims(send_id, file_id);
let token = crate::auth::encode_jwt(&token_claims);
Ok(format!("{}/api/sends/{send_id}/{file_id}?t={token}", &host.host))
Ok(format!("{}/api/sends/{send_id}/{file_id}?t={token}", host.host))
} else {
Ok(operator.presign_read(&format!("{send_id}/{file_id}"), Duration::from_secs(5 * 60)).await?.uri().to_string())
Ok(operator.presign_read(&format!("{send_id}/{file_id}"), Duration::from_mins(5)).await?.uri().to_string())
}
}
#[get("/sends/<send_id>/<file_id>?<t>")]
async fn download_send(send_id: SendId, file_id: SendFileId, t: &str) -> Option<NamedFile> {
if let Ok(claims) = crate::auth::decode_send(t) {
if claims.sub == format!("{send_id}/{file_id}") {
return NamedFile::open(Path::new(&CONFIG.sends_folder()).join(send_id).join(file_id)).await.ok();
}
if let Ok(claims) = crate::auth::decode_send(t)
&& claims.sub == format!("{send_id}/{file_id}")
{
return NamedFile::open(Path::new(&CONFIG.sends_folder()).join(send_id).join(file_id)).await.ok();
}
None
}
+11 -11
View File
@@ -1,14 +1,13 @@
use data_encoding::BASE32;
use rocket::serde::json::Json;
use rocket::Route;
use rocket::{Route, serde::json::Json};
use crate::{
api::{core::log_user_event, core::two_factor::_generate_recover_code, EmptyResult, JsonResult, PasswordOrOtpData},
api::{EmptyResult, JsonResult, PasswordOrOtpData, core::log_user_event, core::two_factor::generate_recover_code},
auth::{ClientIp, Headers},
crypto,
db::{
models::{EventType, TwoFactor, TwoFactorType, UserId},
DbConn,
models::{EventType, TwoFactor, TwoFactorType, UserId},
},
util::NumberOrString,
};
@@ -70,9 +69,10 @@ async fn activate_authenticator(data: Json<EnableAuthenticatorData>, headers: He
.await?;
// Validate key as base32 and 20 bytes length
let decoded_key: Vec<u8> = match BASE32.decode(key.as_bytes()) {
Ok(decoded) => decoded,
_ => err!("Invalid totp secret"),
let decoded_key: Vec<u8> = if let Ok(decoded) = BASE32.decode(key.as_bytes()) {
decoded
} else {
err!("Invalid totp secret")
};
if decoded_key.len() != 20 {
@@ -82,7 +82,7 @@ async fn activate_authenticator(data: Json<EnableAuthenticatorData>, headers: He
// Validate the token provided with the key, and save new twofactor
validate_totp_code(&user.uuid, &token, &key.to_uppercase(), &headers.ip, &conn).await?;
_generate_recover_code(&mut user, &conn).await;
generate_recover_code(&mut user, &conn).await;
log_user_event(EventType::UserUpdated2fa as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &conn).await;
@@ -119,7 +119,7 @@ pub async fn validate_totp_code(
ip: &ClientIp,
conn: &DbConn,
) -> EmptyResult {
use totp_lite::{totp_custom, Sha1};
use totp_lite::{Sha1, totp_custom};
let Ok(decoded_secret) = BASE32.decode(secret.as_bytes()) else {
err!("Invalid TOTP secret")
@@ -128,7 +128,7 @@ pub async fn validate_totp_code(
let mut twofactor = match TwoFactor::find_by_user_and_type(user_id, TwoFactorType::Authenticator as i32, conn).await
{
Some(tf) => tf,
_ => TwoFactor::new(user_id.clone(), TwoFactorType::Authenticator, secret.to_string()),
_ => TwoFactor::new(user_id.clone(), TwoFactorType::Authenticator, secret.to_owned()),
};
// The amount of steps back and forward in time
@@ -145,7 +145,7 @@ pub async fn validate_totp_code(
// We need to calculate the time offsite and cast it as an u64.
// Since we only have times into the future and the totp generator needs an u64 instead of the default i64.
let time = (current_timestamp + step * 30i64) as u64;
let time: u64 = (current_timestamp + step * 30i64).cast_unsigned();
let generated = totp_custom::<Sha1>(30, 6, &decoded_secret, time);
// Check the given code equals the generated and if the time_step is larger then the one last used.
+16 -17
View File
@@ -1,22 +1,21 @@
use chrono::Utc;
use data_encoding::BASE64;
use rocket::serde::json::Json;
use rocket::Route;
use rocket::{Route, serde::json::Json};
use crate::{
CONFIG,
api::{
core::log_user_event, core::two_factor::_generate_recover_code, ApiResult, EmptyResult, JsonResult,
PasswordOrOtpData,
ApiResult, EmptyResult, JsonResult, PasswordOrOtpData, core::log_user_event,
core::two_factor::generate_recover_code,
},
auth::Headers,
crypto,
db::{
models::{EventType, TwoFactor, TwoFactorType, User, UserId},
DbConn,
models::{EventType, TwoFactor, TwoFactorType, User, UserId},
},
error::MapResult,
http_client::make_http_request,
CONFIG,
};
pub fn routes() -> Vec<Route> {
@@ -82,8 +81,7 @@ enum DuoStatus {
impl DuoStatus {
fn data(self) -> Option<DuoData> {
match self {
DuoStatus::Global(data) => Some(data),
DuoStatus::User(data) => Some(data),
DuoStatus::Global(data) | DuoStatus::User(data) => Some(data),
DuoStatus::Disabled(_) => None,
}
}
@@ -182,7 +180,7 @@ async fn activate_duo(data: Json<EnableDuoData>, headers: Headers, conn: DbConn)
let twofactor = TwoFactor::new(user.uuid.clone(), type_, data_str);
twofactor.save(&conn).await?;
_generate_recover_code(&mut user, &conn).await;
generate_recover_code(&mut user, &conn).await;
log_user_event(EventType::UserUpdated2fa as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &conn).await;
@@ -201,14 +199,14 @@ async fn activate_duo_put(data: Json<EnableDuoData>, headers: Headers, conn: DbC
}
async fn duo_api_request(method: &str, path: &str, params: &str, data: &DuoData) -> EmptyResult {
use reqwest::{header, Method};
use reqwest::{Method, header};
use std::str::FromStr;
// https://duo.com/docs/authapi#api-details
let url = format!("https://{}{path}", &data.host);
let date = Utc::now().to_rfc2822();
let url = format!("https://{}{path}", data.host);
let dt = Utc::now().to_rfc2822();
let username = &data.ik;
let fields = [&date, method, &data.host, path, params];
let fields = [&dt, method, &data.host, path, params];
let password = crypto::hmac_sign(&data.sk, &fields.join("\n"));
let m = Method::from_str(method).unwrap_or_default();
@@ -216,7 +214,7 @@ async fn duo_api_request(method: &str, path: &str, params: &str, data: &DuoData)
make_http_request(m, &url)?
.basic_auth(username, Some(password))
.header(header::USER_AGENT, "vaultwarden:Duo/1.0 (Rust)")
.header(header::DATE, date)
.header(header::DATE, dt)
.send()
.await?
.error_for_status()?;
@@ -356,9 +354,10 @@ fn parse_duo_values(key: &str, val: &str, ikey: &str, prefix: &str, time: i64) -
err!("Invalid ikey")
}
let expire: i64 = match expire.parse() {
Ok(e) => e,
Err(_) => err!("Invalid expire time"),
let expire: i64 = if let Ok(e) = expire.parse() {
e
} else {
err!("Invalid expire time")
};
if time >= expire {
+21 -23
View File
@@ -1,23 +1,24 @@
use std::collections::HashMap;
use chrono::Utc;
use data_encoding::HEXLOWER;
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation};
use reqwest::{header, StatusCode};
use ring::digest::{digest, Digest, SHA512_256};
use reqwest::{StatusCode, header};
use ring::digest::{Digest, SHA512_256, digest};
use serde::Serialize;
use std::collections::HashMap;
use url::Url;
use crate::{
api::{core::two_factor::duo::get_duo_keys_email, EmptyResult},
CONFIG,
api::{EmptyResult, core::two_factor::duo::get_duo_keys_email},
crypto,
db::{
models::{DeviceId, EventType, TwoFactorDuoContext},
DbConn, DbPool,
models::{DeviceId, EventType, TwoFactorDuoContext},
},
error::Error,
http_client::make_http_request,
CONFIG,
};
use url::Url;
// The location on this service that Duo should redirect users to. For us, this is a bridge
// built in to the Bitwarden clients.
@@ -124,7 +125,7 @@ impl DuoClient {
ClientAssertion {
iss: self.client_id.clone(),
sub: self.client_id.clone(),
aud: url.to_string(),
aud: url.to_owned(),
exp: now + JWT_VALIDITY_SECS,
jti: jwt_id,
iat: now,
@@ -302,7 +303,7 @@ impl DuoClient {
if !(matching_nonces && matching_usernames) {
err!("Error validating Duo authorization, nonce or username mismatch.")
};
}
Ok(())
}
@@ -347,7 +348,7 @@ pub async fn purge_duo_contexts(pool: DbPool) {
if let Ok(conn) = pool.get().await {
TwoFactorDuoContext::purge_expired_duo_contexts(&conn).await;
} else {
error!("Failed to get DB connection while purging expired Duo authentications")
error!("Failed to get DB connection while purging expired Duo authentications");
}
}
@@ -394,7 +395,7 @@ pub async fn get_duo_auth_url(
match client.health_check().await {
Ok(()) => {}
Err(e) => return Err(e),
};
}
// Generate random OAuth2 state and OIDC Nonce
let state: String = crypto::get_random_string_alphanum(STATE_LENGTH);
@@ -438,16 +439,13 @@ pub async fn validate_duo_login(
// Get the context by the state reported by the client. If we don't have one,
// it means the context is either missing or expired.
let ctx = match extract_context(state, conn).await {
Some(c) => c,
None => {
err!(
"Error validating duo authentication",
ErrorEvent {
event: EventType::UserFailedLogIn2fa
}
)
}
let Some(ctx) = extract_context(state, conn).await else {
err!(
"Error validating duo authentication",
ErrorEvent {
event: EventType::UserFailedLogIn2fa
}
)
};
// Context validation steps
@@ -476,13 +474,13 @@ pub async fn validate_duo_login(
match client.health_check().await {
Ok(()) => {}
Err(e) => return Err(e),
};
}
let d: Digest = digest(&SHA512_256, format!("{}{device_identifier}", ctx.nonce).as_bytes());
let hash: String = HEXLOWER.encode(d.as_ref());
match client.exchange_authz_code_for_result(code, email, hash.as_str()).await {
Ok(_) => Ok(()),
Ok(()) => Ok(()),
Err(_) => {
err!(
"Error validating duo authentication",
+25 -22
View File
@@ -1,20 +1,20 @@
use chrono::{DateTime, TimeDelta, Utc};
use rocket::serde::json::Json;
use rocket::Route;
use rocket::{Route, serde::json::Json};
use crate::{
CONFIG,
api::{
core::{log_user_event, two_factor::_generate_recover_code},
EmptyResult, JsonResult, PasswordOrOtpData,
core::{log_user_event, two_factor::generate_recover_code},
},
auth::{ClientHeaders, Headers},
crypto,
db::{
models::{AuthRequest, AuthRequestId, DeviceId, EventType, TwoFactor, TwoFactorType, User, UserId},
DbConn,
models::{AuthRequest, AuthRequestId, DeviceId, EventType, TwoFactor, TwoFactorType, User, UserId},
},
error::{Error, MapResult},
mail, CONFIG,
mail,
};
pub fn routes() -> Vec<Route> {
@@ -25,7 +25,7 @@ pub fn routes() -> Vec<Route> {
#[serde(rename_all = "camelCase")]
struct SendEmailLoginData {
#[serde(alias = "DeviceIdentifier")]
device_identifier: DeviceId,
device_identifier: Option<DeviceId>,
#[serde(alias = "Email")]
email: Option<String>,
#[serde(alias = "MasterPasswordHash")]
@@ -91,8 +91,11 @@ async fn send_email_login(data: Json<SendEmailLoginData>, client_headers: Client
user
} else {
let Some(device_identifier) = &data.device_identifier else {
err!("No device identifier has been submitted.")
};
// SSO login only sends device id, so we get the user by the most recently used device
let Some(user) = User::find_by_device_for_email2fa(&data.device_identifier, &conn).await else {
let Some(user) = User::find_by_device_for_email2fa(device_identifier, &conn).await else {
err!("Username or password is incorrect. Try again.")
};
@@ -229,7 +232,7 @@ async fn email(data: Json<EmailData>, headers: Headers, conn: DbConn) -> JsonRes
twofactor.data = email_data.to_json();
twofactor.save(&conn).await?;
_generate_recover_code(&mut user, &conn).await;
generate_recover_code(&mut user, &conn).await;
log_user_event(EventType::UserUpdated2fa as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &conn).await;
@@ -281,9 +284,9 @@ pub async fn validate_email_code_str(
twofactor.data = email_data.to_json();
twofactor.save(conn).await?;
let date = DateTime::from_timestamp(email_data.token_sent, 0).expect("Email token timestamp invalid.").naive_utc();
let max_time = CONFIG.email_expiration_time() as i64;
if date + TimeDelta::try_seconds(max_time).unwrap() < Utc::now().naive_utc() {
let dt = DateTime::from_timestamp(email_data.token_sent, 0).expect("Email token timestamp invalid.").naive_utc();
let max_time = CONFIG.email_expiration_time().cast_signed();
if dt + TimeDelta::try_seconds(max_time).unwrap() < Utc::now().naive_utc() {
err!(
"Token has expired",
ErrorEvent {
@@ -339,9 +342,10 @@ impl EmailTokenData {
pub fn from_json(string: &str) -> Result<EmailTokenData, Error> {
let res: Result<EmailTokenData, serde_json::Error> = serde_json::from_str(string);
match res {
Ok(x) => Ok(x),
Err(_) => err!("Could not decode EmailTokenData from string"),
if let Ok(x) = res {
Ok(x)
} else {
err!("Could not decode EmailTokenData from string")
}
}
}
@@ -359,18 +363,17 @@ pub async fn activate_email_2fa(user: &User, conn: &DbConn) -> EmptyResult {
pub fn obscure_email(email: &str) -> String {
let split: Vec<&str> = email.rsplitn(2, '@').collect();
let mut name = split[1].to_string();
let mut name = split[1].to_owned();
let domain = &split[0];
let name_size = name.chars().count();
let new_name = match name_size {
1..=3 => "*".repeat(name_size),
_ => {
let stars = "*".repeat(name_size - 2);
name.truncate(2);
format!("{name}{stars}")
}
let new_name = if let 1..=3 = name_size {
"*".repeat(name_size)
} else {
let stars = "*".repeat(name_size - 2);
name.truncate(2);
format!("{name}{stars}")
};
format!("{new_name}@{domain}")
+14 -21
View File
@@ -1,28 +1,27 @@
use chrono::{TimeDelta, Utc};
use data_encoding::BASE32;
use num_traits::FromPrimitive;
use rocket::serde::json::Json;
use rocket::Route;
use rocket::{Route, serde::json::Json};
use serde::Deserialize;
use serde_json::Value;
use crate::{
CONFIG,
api::{
core::{log_event, log_user_event},
EmptyResult, JsonResult, PasswordOrOtpData,
core::{log_event, log_user_event},
},
auth::Headers,
crypto,
db::{
DbConn, DbPool,
models::{
DeviceType, EventType, Membership, MembershipType, OrgPolicyType, Organization, OrganizationId, TwoFactor,
TwoFactorIncomplete, TwoFactorType, User, UserId,
},
DbConn, DbPool,
},
mail,
util::NumberOrString,
CONFIG,
};
pub mod authenticator;
@@ -37,7 +36,7 @@ fn has_global_duo_credentials() -> bool {
CONFIG._enable_duo() && CONFIG.duo_host().is_some() && CONFIG.duo_ikey().is_some() && CONFIG.duo_skey().is_some()
}
pub fn is_twofactor_provider_usable(provider_type: TwoFactorType, provider_data: Option<&str>) -> bool {
pub fn is_twofactor_provider_usable(provider_type: &TwoFactorType, provider_data: Option<&str>) -> bool {
#[derive(Deserialize)]
struct DuoProviderData {
host: String,
@@ -46,7 +45,7 @@ pub fn is_twofactor_provider_usable(provider_type: TwoFactorType, provider_data:
}
match provider_type {
TwoFactorType::Authenticator => true,
TwoFactorType::Authenticator | TwoFactorType::RecoveryCode => true,
TwoFactorType::Email => CONFIG._enable_email_2fa(),
TwoFactorType::Duo | TwoFactorType::OrganizationDuo => {
provider_data
@@ -59,7 +58,6 @@ pub fn is_twofactor_provider_usable(provider_type: TwoFactorType, provider_data:
}
TwoFactorType::Webauthn => CONFIG.is_webauthn_2fa_supported(),
TwoFactorType::Remember => !CONFIG.disable_2fa_remember(),
TwoFactorType::RecoveryCode => true,
TwoFactorType::U2f
| TwoFactorType::U2fRegisterChallenge
| TwoFactorType::U2fLoginChallenge
@@ -96,7 +94,7 @@ async fn get_twofactor(headers: Headers, conn: DbConn) -> Json<Value> {
.iter()
.filter_map(|tf| {
let provider_type = TwoFactorType::from_i32(tf.atype)?;
is_twofactor_provider_usable(provider_type, Some(&tf.data)).then(|| TwoFactor::to_json_provider(tf))
is_twofactor_provider_usable(&provider_type, Some(&tf.data)).then(|| TwoFactor::to_json_provider(tf))
})
.collect();
@@ -120,7 +118,7 @@ async fn get_recover(data: Json<PasswordOrOtpData>, headers: Headers, conn: DbCo
})))
}
async fn _generate_recover_code(user: &mut User, conn: &DbConn) {
async fn generate_recover_code(user: &mut User, conn: &DbConn) {
if user.totp_recover.is_none() {
let totp_recover = crypto::encode_random_bytes::<20>(&BASE32);
user.totp_recover = Some(totp_recover);
@@ -180,9 +178,7 @@ pub async fn enforce_2fa_policy(
ip: &std::net::IpAddr,
conn: &DbConn,
) -> EmptyResult {
for member in
Membership::find_by_user_and_policy(&user.uuid, OrgPolicyType::TwoFactorAuthentication, conn).await.into_iter()
{
for member in Membership::find_by_user_and_policy(&user.uuid, OrgPolicyType::TwoFactorAuthentication, conn).await {
// Policy only applies to non-Owner/non-Admin members who have accepted joining the org
if member.atype < MembershipType::Admin {
if CONFIG.mail_enabled() {
@@ -217,7 +213,7 @@ pub async fn enforce_2fa_policy_for_org(
conn: &DbConn,
) -> EmptyResult {
let org = Organization::find_by_uuid(org_id, conn).await.unwrap();
for member in Membership::find_confirmed_by_org(org_id, conn).await.into_iter() {
for member in Membership::find_confirmed_by_org(org_id, conn).await {
// Don't enforce the policy for Admins and Owners.
if member.atype < MembershipType::Admin && TwoFactor::find_by_user(&member.user_uuid, conn).await.is_empty() {
if CONFIG.mail_enabled() {
@@ -251,12 +247,9 @@ pub async fn send_incomplete_2fa_notifications(pool: DbPool) {
return;
}
let conn = match pool.get().await {
Ok(conn) => conn,
_ => {
error!("Failed to get DB connection in send_incomplete_2fa_notifications()");
return;
}
let Ok(conn) = pool.get().await else {
error!("Failed to get DB connection in send_incomplete_2fa_notifications()");
return;
};
let now = Utc::now().naive_utc();
@@ -278,7 +271,7 @@ pub async fn send_incomplete_2fa_notifications(pool: DbPool) {
)
.await
{
Ok(_) => {
Ok(()) => {
if let Err(e) = login.delete(&conn).await {
error!("Error deleting incomplete 2FA record: {e:#?}");
}
+16 -10
View File
@@ -1,16 +1,17 @@
use chrono::{naive::serde::ts_seconds, NaiveDateTime, TimeDelta, Utc};
use rocket::{serde::json::Json, Route};
use chrono::{NaiveDateTime, TimeDelta, Utc, naive::serde::ts_seconds};
use rocket::{Route, serde::json::Json};
use crate::{
CONFIG,
api::EmptyResult,
auth::Headers,
crypto,
db::{
models::{TwoFactor, TwoFactorType, UserId},
DbConn,
models::{TwoFactor, TwoFactorType, UserId},
},
error::{Error, MapResult},
mail, CONFIG,
mail,
};
pub fn routes() -> Vec<Route> {
@@ -44,9 +45,10 @@ impl ProtectedActionData {
pub fn from_json(string: &str) -> Result<Self, Error> {
let res: Result<Self, serde_json::Error> = serde_json::from_str(string);
match res {
Ok(x) => Ok(x),
Err(_) => err!("Could not decode ProtectedActionData from string"),
if let Ok(x) = res {
Ok(x)
} else {
err!("Could not decode ProtectedActionData from string")
}
}
@@ -62,7 +64,9 @@ impl ProtectedActionData {
#[post("/accounts/request-otp")]
async fn request_otp(headers: Headers, conn: DbConn) -> EmptyResult {
if !CONFIG.mail_enabled() {
err!("Email is disabled for this server. Either enable email or login using your master password instead of login via device.");
err!(
"Email is disabled for this server. Either enable email or login using your master password instead of login via device."
);
}
let user = headers.user;
@@ -102,7 +106,9 @@ struct ProtectedActionVerify {
#[post("/accounts/verify-otp", data = "<data>")]
async fn verify_otp(data: Json<ProtectedActionVerify>, headers: Headers, conn: DbConn) -> EmptyResult {
if !CONFIG.mail_enabled() {
err!("Email is disabled for this server. Either enable email or login using your master password instead of login via device.");
err!(
"Email is disabled for this server. Either enable email or login using your master password instead of login via device."
);
}
let user = headers.user;
@@ -133,7 +139,7 @@ pub async fn validate_protected_action_otp(
}
// Check if the token has expired (Using the email 2fa expiration time)
let max_time = CONFIG.email_expiration_time() as i64;
let max_time = CONFIG.email_expiration_time().cast_signed();
if pa_data.time_since_sent().num_seconds() > max_time {
pa.delete(conn).await?;
err!("Token has expired")
+45 -45
View File
@@ -1,34 +1,35 @@
use crate::{
api::{
core::{log_user_event, two_factor::_generate_recover_code},
EmptyResult, JsonResult, PasswordOrOtpData,
},
auth::Headers,
crypto::ct_eq,
db::{
models::{EventType, TwoFactor, TwoFactorType, UserId},
DbConn,
},
error::Error,
util::NumberOrString,
CONFIG,
};
use rocket::serde::json::Json;
use rocket::Route;
use std::{str::FromStr, sync::LazyLock, time::Duration};
use rocket::{Route, serde::json::Json};
use serde_json::Value;
use std::str::FromStr;
use std::sync::LazyLock;
use std::time::Duration;
use url::Url;
use uuid::Uuid;
use webauthn_rs::prelude::{Base64UrlSafeData, Credential, Passkey, PasskeyAuthentication, PasskeyRegistration};
use webauthn_rs::{Webauthn, WebauthnBuilder};
use webauthn_rs::{
Webauthn, WebauthnBuilder,
prelude::{Base64UrlSafeData, Credential, Passkey, PasskeyAuthentication, PasskeyRegistration},
};
use webauthn_rs_proto::{
AuthenticationExtensionsClientOutputs, AuthenticatorAssertionResponseRaw, AuthenticatorAttestationResponseRaw,
PublicKeyCredential, RegisterPublicKeyCredential, RegistrationExtensionsClientOutputs,
RequestAuthenticationExtensions, UserVerificationPolicy,
};
use crate::{
CONFIG,
api::{
EmptyResult, JsonResult, PasswordOrOtpData,
core::{log_user_event, two_factor::generate_recover_code},
},
auth::Headers,
crypto::ct_eq,
db::{
DbConn,
models::{EventType, TwoFactor, TwoFactorType, UserId},
},
error::Error,
util::NumberOrString,
};
static WEBAUTHN: LazyLock<Webauthn> = LazyLock::new(|| {
let domain = CONFIG.domain();
let domain_origin = CONFIG.domain_origin();
@@ -38,7 +39,7 @@ static WEBAUTHN: LazyLock<Webauthn> = LazyLock::new(|| {
let webauthn = WebauthnBuilder::new(&rp_id, &rp_origin)
.expect("Creating WebauthnBuilder failed")
.rp_name(&domain)
.timeout(Duration::from_millis(60000));
.timeout(Duration::from_mins(1));
webauthn.build().expect("Building Webauthn failed")
});
@@ -149,7 +150,7 @@ async fn generate_webauthn_challenge(data: Json<PasswordOrOtpData>, headers: Hea
)?;
let mut state = serde_json::to_value(&state)?;
state["rs"]["policy"] = Value::String("discouraged".to_string());
state["rs"]["policy"] = Value::String("discouraged".to_owned());
state["rs"]["extensions"].as_object_mut().unwrap().clear();
let type_ = TwoFactorType::WebauthnRegisterChallenge;
@@ -265,13 +266,12 @@ async fn activate_webauthn(data: Json<EnableWebauthnData>, headers: Headers, con
// Retrieve and delete the saved challenge state
let type_ = TwoFactorType::WebauthnRegisterChallenge as i32;
let state = match TwoFactor::find_by_user_and_type(&user.uuid, type_, &conn).await {
Some(tf) => {
let state: PasskeyRegistration = serde_json::from_str(&tf.data)?;
tf.delete(&conn).await?;
state
}
None => err!("Can't recover challenge"),
let state = if let Some(tf) = TwoFactor::find_by_user_and_type(&user.uuid, type_, &conn).await {
let state: PasskeyRegistration = serde_json::from_str(&tf.data)?;
tf.delete(&conn).await?;
state
} else {
err!("Can't recover challenge")
};
// Verify the credentials with the saved state
@@ -291,7 +291,7 @@ async fn activate_webauthn(data: Json<EnableWebauthnData>, headers: Headers, con
TwoFactor::new(user.uuid.clone(), TwoFactorType::Webauthn, serde_json::to_string(&registrations)?)
.save(&conn)
.await?;
_generate_recover_code(&mut user, &conn).await;
generate_recover_code(&mut user, &conn).await;
log_user_event(EventType::UserUpdated2fa as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &conn).await;
@@ -342,9 +342,10 @@ async fn delete_webauthn(data: Json<DeleteU2FData>, headers: Headers, conn: DbCo
// If entry is migrated from u2f, delete the u2f entry as well
if let Some(mut u2f) = TwoFactor::find_by_user_and_type(&headers.user.uuid, TwoFactorType::U2f as i32, &conn).await
{
let mut data: Vec<U2FRegistration> = match serde_json::from_str(&u2f.data) {
Ok(d) => d,
Err(_) => err!("Error parsing U2F data"),
let mut data: Vec<U2FRegistration> = if let Ok(d) = serde_json::from_str(&u2f.data) {
d
} else {
err!("Error parsing U2F data")
};
data.retain(|r| r.reg.key_handle != removed_item.credential.cred_id().as_slice());
@@ -388,10 +389,10 @@ pub async fn generate_webauthn_login(user_id: &UserId, conn: &DbConn) -> JsonRes
// Modify to discourage user verification
let mut state = serde_json::to_value(&state)?;
state["ast"]["policy"] = Value::String("discouraged".to_string());
state["ast"]["policy"] = Value::String("discouraged".to_owned());
// Add appid, this is only needed for U2F compatibility, so maybe it can be removed as well
let app_id = format!("{}/app-id.json", &CONFIG.domain());
let app_id = format!("{}/app-id.json", CONFIG.domain());
state["ast"]["appid"] = Value::String(app_id.clone());
response.public_key.user_verification = UserVerificationPolicy::Discouraged_DO_NOT_USE;
@@ -416,18 +417,17 @@ pub async fn generate_webauthn_login(user_id: &UserId, conn: &DbConn) -> JsonRes
pub async fn validate_webauthn_login(user_id: &UserId, response: &str, conn: &DbConn) -> EmptyResult {
let type_ = TwoFactorType::WebauthnLoginChallenge as i32;
let mut state = match TwoFactor::find_by_user_and_type(user_id, type_, conn).await {
Some(tf) => {
let state: PasskeyAuthentication = serde_json::from_str(&tf.data)?;
tf.delete(conn).await?;
state
}
None => err!(
let mut state = if let Some(tf) = TwoFactor::find_by_user_and_type(user_id, type_, conn).await {
let state: PasskeyAuthentication = serde_json::from_str(&tf.data)?;
tf.delete(conn).await?;
state
} else {
err!(
"Can't recover login challenge",
ErrorEvent {
event: EventType::UserFailedLogIn2fa
}
),
)
};
let rsp: PublicKeyCredentialCopy = serde_json::from_str(response)?;
+10 -10
View File
@@ -1,20 +1,19 @@
use rocket::serde::json::Json;
use rocket::Route;
use rocket::{Route, serde::json::Json};
use serde_json::Value;
use yubico::{config::Config, verify_async};
use crate::{
CONFIG,
api::{
core::{log_user_event, two_factor::_generate_recover_code},
EmptyResult, JsonResult, PasswordOrOtpData,
core::{log_user_event, two_factor::generate_recover_code},
},
auth::Headers,
db::{
models::{EventType, TwoFactor, TwoFactorType},
DbConn,
models::{EventType, TwoFactor, TwoFactorType},
},
error::{Error, MapResult},
CONFIG,
};
pub fn routes() -> Vec<Route> {
@@ -46,7 +45,7 @@ pub struct YubikeyMetadata {
fn parse_yubikeys(data: &EnableYubikeyData) -> Vec<String> {
let data_keys = [&data.key1, &data.key2, &data.key3, &data.key4, &data.key5];
data_keys.iter().filter_map(|e| e.as_ref().cloned()).collect()
data_keys.into_iter().flatten().cloned().collect()
}
fn jsonify_yubikeys(yubikeys: Vec<String>) -> Value {
@@ -64,9 +63,10 @@ fn get_yubico_credentials() -> Result<(String, String), Error> {
err!("Yubico support is disabled");
}
match (CONFIG.yubico_client_id(), CONFIG.yubico_secret_key()) {
(Some(id), Some(secret)) => Ok((id, secret)),
_ => err!("`YUBICO_CLIENT_ID` or `YUBICO_SECRET_KEY` environment variable is not set. Yubikey OTP Disabled"),
if let (Some(id), Some(secret)) = (CONFIG.yubico_client_id(), CONFIG.yubico_secret_key()) {
Ok((id, secret))
} else {
err!("`YUBICO_CLIENT_ID` or `YUBICO_SECRET_KEY` environment variable is not set. Yubikey OTP Disabled")
}
}
@@ -162,7 +162,7 @@ async fn activate_yubikey(data: Json<EnableYubikeyData>, headers: Headers, conn:
yubikey_data.data = serde_json::to_string(&yubikey_metadata).unwrap();
yubikey_data.save(&conn).await?;
_generate_recover_code(&mut user, &conn).await;
generate_recover_code(&mut user, &conn).await;
log_user_event(EventType::UserUpdated2fa as i32, &user.uuid, headers.device.atype, &headers.ip.ip, &conn).await;
+89 -110
View File
@@ -6,28 +6,29 @@ use std::{
};
use bytes::{Bytes, BytesMut};
use futures::{stream::StreamExt, TryFutureExt};
use futures::{TryFutureExt, stream::StreamExt};
use html5gum::{Emitter, HtmlString, Readable, StringReader, Tokenizer};
use regex::Regex;
use reqwest::{
header::{self, HeaderMap, HeaderValue},
Client, Response,
header::{self, HeaderMap, HeaderValue},
};
use rocket::{http::ContentType, response::Redirect, Route};
use svg_hush::{data_url_filter, Filter};
use rocket::{Route, http::ContentType, response::Redirect};
use svg_hush::{Filter, data_url_filter};
use crate::{
CONFIG,
config::PathType,
error::Error,
http_client::{get_reqwest_client_builder, should_block_address, CustomHttpClientError},
http_client::{CustomHttpClientError, get_reqwest_client_builder, get_valid_host, should_block_host},
util::Cached,
CONFIG,
};
pub fn routes() -> Vec<Route> {
match CONFIG.icon_service().as_str() {
"internal" => routes![icon_internal],
_ => routes![icon_external],
if CONFIG.icon_service().as_str() == "internal" {
routes![icon_internal]
} else {
routes![icon_external]
}
}
@@ -81,19 +82,19 @@ static ICON_SIZE_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"(?x)(\d+
// The function name `icon_external` is checked in the `on_response` function in `AppHeaders`
// It is used to prevent sending a specific header which breaks icon downloads.
// If this function needs to be renamed, also adjust the code in `util.rs`
#[get("/<domain>/icon.png")]
fn icon_external(domain: &str) -> Cached<Option<Redirect>> {
if !is_valid_domain(domain) {
warn!("Invalid domain: {domain}");
#[get("/<host>/icon.png")]
fn icon_external(host: &str) -> Cached<Option<Redirect>> {
let Ok(host) = get_valid_host(host) else {
warn!("Invalid host: {host}");
return Cached::ttl(None, CONFIG.icon_cache_negttl(), true);
};
if should_block_host(&host).is_err() {
warn!("Blocked address: {host}");
return Cached::ttl(None, CONFIG.icon_cache_negttl(), true);
}
if should_block_address(domain) {
warn!("Blocked address: {domain}");
return Cached::ttl(None, CONFIG.icon_cache_negttl(), true);
}
let url = CONFIG._icon_service_url().replace("{}", domain);
let url = CONFIG._icon_service_url().replace("{}", &host.to_string());
let redir = match CONFIG.icon_redirect_code() {
301 => Some(Redirect::moved(url)), // legacy permanent redirect
302 => Some(Redirect::found(url)), // legacy temporary redirect
@@ -107,12 +108,21 @@ fn icon_external(domain: &str) -> Cached<Option<Redirect>> {
Cached::ttl(redir, CONFIG.icon_cache_ttl(), true)
}
#[get("/<domain>/icon.png")]
async fn icon_internal(domain: &str) -> Cached<(ContentType, Vec<u8>)> {
#[get("/<host>/icon.png")]
async fn icon_internal(host: &str) -> Cached<(ContentType, Vec<u8>)> {
const FALLBACK_ICON: &[u8] = include_bytes!("../static/images/fallback-icon.png");
if !is_valid_domain(domain) {
warn!("Invalid domain: {domain}");
let Ok(host) = get_valid_host(host) else {
warn!("Invalid host: {host}");
return Cached::ttl(
(ContentType::new("image", "png"), FALLBACK_ICON.to_vec()),
CONFIG.icon_cache_negttl(),
true,
);
};
if should_block_host(&host).is_err() {
warn!("Blocked address: {host}");
return Cached::ttl(
(ContentType::new("image", "png"), FALLBACK_ICON.to_vec()),
CONFIG.icon_cache_negttl(),
@@ -120,16 +130,7 @@ async fn icon_internal(domain: &str) -> Cached<(ContentType, Vec<u8>)> {
);
}
if should_block_address(domain) {
warn!("Blocked address: {domain}");
return Cached::ttl(
(ContentType::new("image", "png"), FALLBACK_ICON.to_vec()),
CONFIG.icon_cache_negttl(),
true,
);
}
match get_icon(domain).await {
match get_icon(&host.to_string()).await {
Some((icon, icon_type)) => {
Cached::ttl((ContentType::new("image", icon_type), icon), CONFIG.icon_cache_ttl(), true)
}
@@ -137,42 +138,6 @@ async fn icon_internal(domain: &str) -> Cached<(ContentType, Vec<u8>)> {
}
}
/// Returns if the domain provided is valid or not.
///
/// This does some manual checks and makes use of Url to do some basic checking.
/// domains can't be larger then 63 characters (not counting multiple subdomains) according to the RFC's, but we limit the total size to 255.
fn is_valid_domain(domain: &str) -> bool {
const ALLOWED_CHARS: &str = "-.";
// If parsing the domain fails using Url, it will not work with reqwest.
if let Err(parse_error) = url::Url::parse(format!("https://{domain}").as_str()) {
debug!("Domain parse error: '{domain}' - {parse_error:?}");
return false;
} else if domain.is_empty()
|| domain.contains("..")
|| domain.starts_with('.')
|| domain.starts_with('-')
|| domain.ends_with('-')
{
debug!(
"Domain validation error: '{domain}' is either empty, contains '..', starts with an '.', starts or ends with a '-'"
);
return false;
} else if domain.len() > 255 {
debug!("Domain validation error: '{domain}' exceeds 255 characters");
return false;
}
for c in domain.chars() {
if !c.is_alphanumeric() && !ALLOWED_CHARS.contains(c) {
debug!("Domain validation error: '{domain}' contains an invalid character '{c}'");
return false;
}
}
true
}
async fn get_icon(domain: &str) -> Option<(Vec<u8>, String)> {
let path = format!("{domain}.png");
@@ -183,7 +148,7 @@ async fn get_icon(domain: &str) -> Option<(Vec<u8>, String)> {
if let Some(icon) = get_cached_icon(&path).await {
let icon_type = get_icon_type(&icon).unwrap_or("x-icon");
return Some((icon, icon_type.to_string()));
return Some((icon, icon_type.to_owned()));
}
if CONFIG.disable_icon_download() {
@@ -194,7 +159,7 @@ async fn get_icon(domain: &str) -> Option<(Vec<u8>, String)> {
match download_icon(domain).await {
Ok((icon, icon_type)) => {
save_icon(&path, icon.to_vec()).await;
Some((icon.to_vec(), icon_type.unwrap_or("x-icon").to_string()))
Some((icon.to_vec(), icon_type.unwrap_or("x-icon").to_owned()))
}
Err(e) => {
// If this error comes from the custom resolver, this means this is a blocked domain
@@ -219,10 +184,10 @@ async fn get_cached_icon(path: &str) -> Option<Vec<u8>> {
}
// Try to read the cached icon, and return it if it exists
if let Ok(operator) = CONFIG.opendal_operator_for_path_type(&PathType::IconCache) {
if let Ok(buf) = operator.read(path).await {
return Some(buf.to_vec());
}
if let Ok(operator) = CONFIG.opendal_operator_for_path_type(&PathType::IconCache)
&& let Ok(buf) = operator.read(path).await
{
return Some(buf.to_vec());
}
None
@@ -316,17 +281,17 @@ fn get_favicons_node(dom: Tokenizer<StringReader<'_>, FaviconEmitter>, icons: &m
}
for icon_tag in icon_tags {
if let Some(icon_href) = icon_tag.attributes.get(ATTR_HREF) {
if let Ok(full_href) = base_url.join(std::str::from_utf8(icon_href).unwrap_or_default()) {
let sizes = if let Some(v) = icon_tag.attributes.get(ATTR_SIZES) {
std::str::from_utf8(v).unwrap_or_default()
} else {
""
};
let priority = get_icon_priority(full_href.as_str(), sizes);
icons.push(Icon::new(priority, full_href.to_string()));
}
};
if let Some(icon_href) = icon_tag.attributes.get(ATTR_HREF)
&& let Ok(full_href) = base_url.join(std::str::from_utf8(icon_href).unwrap_or_default())
{
let sizes = if let Some(v) = icon_tag.attributes.get(ATTR_SIZES) {
std::str::from_utf8(v).unwrap_or_default()
} else {
""
};
let priority = get_icon_priority(full_href.as_str(), sizes);
icons.push(Icon::new(priority, full_href.to_string()));
}
}
}
@@ -367,7 +332,7 @@ async fn get_icon_url(domain: &str) -> Result<IconUrlResult, Error> {
tld = domain_parts.next_back().unwrap(),
base = domain_parts.next_back().unwrap()
);
if is_valid_domain(&base_domain) {
if get_valid_host(&base_domain).is_ok() {
let sslbase = format!("https://{base_domain}");
let httpbase = format!("http://{base_domain}");
debug!("[get_icon_url]: Trying without subdomains '{base_domain}'");
@@ -378,7 +343,7 @@ async fn get_icon_url(domain: &str) -> Result<IconUrlResult, Error> {
// When the domain is not an IP, and has less then 2 dots, try to add www. infront of it.
} else if is_ip.is_err() && domain.matches('.').count() < 2 {
let www_domain = format!("www.{domain}");
if is_valid_domain(&www_domain) {
if get_valid_host(&www_domain).is_ok() {
let sslwww = format!("https://{www_domain}");
let httpwww = format!("http://{www_domain}");
debug!("[get_icon_url]: Trying with www. prefix '{www_domain}'");
@@ -442,7 +407,7 @@ async fn get_page(url: &str) -> Result<Response, Error> {
async fn get_page_with_referer(url: &str, referer: &str) -> Result<Response, Error> {
let mut client = CLIENT.get(url);
if !referer.is_empty() {
client = client.header("Referer", referer)
client = client.header("Referer", referer);
}
Ok(client.send().await?.error_for_status()?)
@@ -530,11 +495,10 @@ async fn download_icon(domain: &str) -> Result<(Bytes, Option<&str>), Error> {
let mut buffer = Bytes::new();
let mut icon_type: Option<&str> = None;
use data_url::DataUrl;
for icon in icon_result.iconlist.iter().take(5) {
let mut icons = icon_result.iconlist.iter().take(5).peekable();
while let Some(icon) = icons.next() {
if icon.href.starts_with("data:image") {
let Ok(datauri) = DataUrl::process(&icon.href) else {
let Ok(datauri) = data_url::DataUrl::process(&icon.href) else {
continue;
};
// Check if we are able to decode the data uri
@@ -558,13 +522,25 @@ async fn download_icon(domain: &str) -> Result<(Bytes, Option<&str>), Error> {
}
}
_ => debug!("Extracted icon from data:image uri is invalid"),
};
}
} else {
let res = get_page_with_referer(&icon.href, &icon_result.referer).await?;
debug!("Trying {}", icon.href);
// Make sure all icons are checked before returning error
let res = match get_page_with_referer(&icon.href, &icon_result.referer).await {
Ok(r) => r,
Err(e) if icons.peek().is_none() => return Err(e),
Err(e) if CustomHttpClientError::downcast_ref(&e).is_some() => return Err(e), // If blacklisted stop immediately instead of checking the rest of the icons. see explanation and actual handling inside get_icon()
Err(e) => {
warn!("Unable to download icon: {e:?}");
// Continue to next icon
continue;
}
};
buffer = stream_to_bytes_limit(res, 5120 * 1024).await?; // 5120KB/5MB for each icon max (Same as icons.bitwarden.net)
// Check if the icon type is allowed, else try an icon from the list.
// Check if the icon type is allowed, else try another icon from the list.
icon_type = get_icon_type(&buffer);
if icon_type.is_none() {
buffer.clear();
@@ -610,22 +586,25 @@ async fn save_icon(path: &str, icon: Vec<u8>) {
fn get_icon_type(bytes: &[u8]) -> Option<&'static str> {
fn check_svg_after_xml_declaration(bytes: &[u8]) -> Option<&'static str> {
// Look for SVG tag within the first 1KB
if let Ok(content) = std::str::from_utf8(&bytes[..bytes.len().min(1024)]) {
if content.contains("<svg") || content.contains("<SVG") {
return Some("svg+xml");
}
if let Ok(content) = std::str::from_utf8(&bytes[..bytes.len().min(1024)])
&& (content.contains("<svg") || content.contains("<SVG"))
{
return Some("svg+xml");
}
None
}
// Some details can be found here:
// - https://www.garykessler.net/library/file_sigs_GCK_latest.html
// - https://en.wikipedia.org/wiki/List_of_file_signatures
match bytes {
[137, 80, 78, 71, ..] => Some("png"),
[0, 0, 1, 0, ..] => Some("x-icon"),
[82, 73, 70, 70, ..] => Some("webp"),
[255, 216, 255, ..] => Some("jpeg"),
[71, 73, 70, 56, ..] => Some("gif"),
[66, 77, ..] => Some("bmp"),
[60, 115, 118, 103, ..] => Some("svg+xml"), // Normal svg
[137, 80, 78, 71, 13, 10, 26, 10, ..] => Some("png"),
[0, 0, 1, 0, n1, n2, ..] if u16::from_le_bytes([*n1, *n2]) > 0 => Some("x-icon"), // https://en.wikipedia.org/wiki/ICO_(file_format)
[82, 73, 70, 70, _, _, _, _, 87, 69, 66, 80, ..] => Some("webp"), // Only match WebP Images
[255, 216, 255, b, ..] if *b >= 0xC0 => Some("jpeg"),
[71, 73, 70, 56, 55 | 57, 97, ..] => Some("gif"),
[66, 77, _, _, _, _, 0, 0, 0, 0, ..] => Some("bmp"), // https://en.wikipedia.org/wiki/BMP_file_format
[60, 115, 118, 103, ..] => Some("svg+xml"), // Normal svg
[60, 63, 120, 109, 108, ..] => check_svg_after_xml_declaration(bytes), // An svg starting with <?xml
_ => None,
}
@@ -753,7 +732,7 @@ impl FaviconEmitter {
let rel_value =
std::str::from_utf8(token.tag.attributes.get(ATTR_REL).unwrap()).unwrap_or_default();
if rel_value.contains("icon") && !rel_value.contains("mask-icon") {
self.emit_token = true
self.emit_token = true;
}
}
_ => (),
@@ -826,13 +805,13 @@ impl Emitter for FaviconEmitter {
fn push_attribute_name(&mut self, s: &[u8]) {
if let Some(attr) = &mut self.current_attribute {
attr.0.extend(s)
attr.0.extend(s);
}
}
fn push_attribute_value(&mut self, s: &[u8]) {
if let Some(attr) = &mut self.current_attribute {
attr.1.extend(s)
attr.1.extend(s);
}
}
+227 -161
View File
@@ -1,17 +1,20 @@
use chrono::Utc;
use num_traits::FromPrimitive;
use rocket::{
Route,
form::{Form, FromForm},
http::{Cookie, CookieJar, SameSite},
response::Redirect,
serde::json::Json,
Route,
};
use serde_json::Value;
use crate::{
CONFIG,
api::{
ApiResult, EmptyResult, JsonResult,
core::{
accounts::{_prelogin, _register, kdf_upgrade, PreloginData, RegisterData},
accounts::{PreloginData, RegisterData, kdf_upgrade, prelogin, register},
log_user_event,
two_factor::{
authenticator, duo, duo_oidc, email, enforce_2fa_policy, is_twofactor_provider_usable, webauthn,
@@ -20,27 +23,29 @@ use crate::{
},
master_password_policy,
push::register_push_device,
ApiResult, EmptyResult, JsonResult,
},
auth,
auth::{generate_organization_api_key_login_claims, AuthMethod, ClientHeaders, ClientIp, ClientVersion},
auth::{AuthMethod, ClientHeaders, ClientIp, ClientVersion, Secure, generate_organization_api_key_login_claims},
crypto,
db::{
models::{
AuthRequest, AuthRequestId, Device, DeviceId, EventType, Invitation, OIDCCodeWrapper, OrganizationApiKey,
OrganizationId, SsoAuth, SsoUser, TwoFactor, TwoFactorIncomplete, TwoFactorType, User, UserId,
},
DbConn,
models::{
AuthRequest, AuthRequestId, Device, DeviceId, EventType, Invitation, OIDCCodeResponseError,
OrganizationApiKey, OrganizationId, SsoAuth, SsoUser, TwoFactor, TwoFactorIncomplete, TwoFactorType, User,
UserId,
},
},
error::MapResult,
mail, sso,
sso::{OIDCCode, OIDCCodeChallenge, OIDCCodeVerifier, OIDCState},
util, CONFIG,
util,
};
pub fn routes() -> Vec<Route> {
routes![
login,
prelogin,
post_prelogin,
prelogin_password,
identity_register,
register_verification_email,
register_finish,
@@ -64,43 +69,43 @@ async fn login(
let login_result = match data.grant_type.as_ref() {
"refresh_token" => {
_check_is_some(&data.refresh_token, "refresh_token cannot be blank")?;
_refresh_login(data, &conn, &client_header.ip).await
check_is_some(data.refresh_token.as_ref(), "refresh_token cannot be blank")?;
refresh_login(data, &conn, &client_header.ip).await
}
"password" if CONFIG.sso_enabled() && CONFIG.sso_only() => err!("SSO sign-in is required"),
"password" => {
_check_is_some(&data.client_id, "client_id cannot be blank")?;
_check_is_some(&data.password, "password cannot be blank")?;
_check_is_some(&data.scope, "scope cannot be blank")?;
_check_is_some(&data.username, "username cannot be blank")?;
check_is_some(data.client_id.as_ref(), "client_id cannot be blank")?;
check_is_some(data.password.as_ref(), "password cannot be blank")?;
check_is_some(data.scope.as_ref(), "scope cannot be blank")?;
check_is_some(data.username.as_ref(), "username cannot be blank")?;
_check_is_some(&data.device_identifier, "device_identifier cannot be blank")?;
_check_is_some(&data.device_name, "device_name cannot be blank")?;
_check_is_some(&data.device_type, "device_type cannot be blank")?;
check_is_some(data.device_identifier.as_ref(), "device_identifier cannot be blank")?;
check_is_some(data.device_name.as_ref(), "device_name cannot be blank")?;
check_is_some(data.device_type.as_ref(), "device_type cannot be blank")?;
_password_login(data, &mut user_id, &conn, &client_header.ip, &client_version).await
password_login(data, &mut user_id, &conn, &client_header.ip, client_version.as_ref()).await
}
"client_credentials" => {
_check_is_some(&data.client_id, "client_id cannot be blank")?;
_check_is_some(&data.client_secret, "client_secret cannot be blank")?;
_check_is_some(&data.scope, "scope cannot be blank")?;
check_is_some(data.client_id.as_ref(), "client_id cannot be blank")?;
check_is_some(data.client_secret.as_ref(), "client_secret cannot be blank")?;
check_is_some(data.scope.as_ref(), "scope cannot be blank")?;
_check_is_some(&data.device_identifier, "device_identifier cannot be blank")?;
_check_is_some(&data.device_name, "device_name cannot be blank")?;
_check_is_some(&data.device_type, "device_type cannot be blank")?;
check_is_some(data.device_identifier.as_ref(), "device_identifier cannot be blank")?;
check_is_some(data.device_name.as_ref(), "device_name cannot be blank")?;
check_is_some(data.device_type.as_ref(), "device_type cannot be blank")?;
_api_key_login(data, &mut user_id, &conn, &client_header.ip).await
api_key_login(data, &mut user_id, &conn, &client_header.ip).await
}
"authorization_code" if CONFIG.sso_enabled() => {
_check_is_some(&data.client_id, "client_id cannot be blank")?;
_check_is_some(&data.code, "code cannot be blank")?;
_check_is_some(&data.code_verifier, "code verifier cannot be blank")?;
check_is_some(data.client_id.as_ref(), "client_id cannot be blank")?;
check_is_some(data.code.as_ref(), "code cannot be blank")?;
check_is_some(data.code_verifier.as_ref(), "code verifier cannot be blank")?;
_check_is_some(&data.device_identifier, "device_identifier cannot be blank")?;
_check_is_some(&data.device_name, "device_name cannot be blank")?;
_check_is_some(&data.device_type, "device_type cannot be blank")?;
check_is_some(data.device_identifier.as_ref(), "device_identifier cannot be blank")?;
check_is_some(data.device_name.as_ref(), "device_name cannot be blank")?;
check_is_some(data.device_type.as_ref(), "device_type cannot be blank")?;
_sso_login(data, &mut user_id, &conn, &client_header.ip, &client_version).await
sso_login(data, &mut user_id, &conn, &client_header.ip, client_version.as_ref()).await
}
"authorization_code" => err!("SSO sign-in is not available"),
t => err!("Invalid type", t),
@@ -121,7 +126,7 @@ async fn login(
Err(e) => {
if let Some(ev) = e.get_event() {
log_user_event(ev.event as i32, &user_id, client_header.device_type, &client_header.ip.ip, &conn)
.await
.await;
}
}
}
@@ -130,7 +135,7 @@ async fn login(
login_result
}
async fn _refresh_login(data: ConnectData, conn: &DbConn, ip: &ClientIp) -> JsonResult {
async fn refresh_login(data: ConnectData, conn: &DbConn, ip: &ClientIp) -> JsonResult {
// When a refresh token is invalid or missing we need to respond with an HTTP BadRequest (400)
// It also needs to return a json which holds at least a key `error` with the value `invalid_grant`
// See the link below for details
@@ -171,19 +176,19 @@ async fn _refresh_login(data: ConnectData, conn: &DbConn, ip: &ClientIp) -> Json
}
// After exchanging the code we need to check first if 2FA is needed before continuing
async fn _sso_login(
async fn sso_login(
data: ConnectData,
user_id: &mut Option<UserId>,
conn: &DbConn,
ip: &ClientIp,
client_version: &Option<ClientVersion>,
client_version: Option<&ClientVersion>,
) -> JsonResult {
AuthMethod::Sso.check_scope(data.scope.as_ref())?;
// Ratelimit the login
crate::ratelimit::check_limit_login(&ip.ip)?;
let (state, code_verifier) = match (data.code.as_ref(), data.code_verifier.as_ref()) {
let (code, code_verifier) = match (data.code.as_ref(), data.code_verifier.as_ref()) {
(None, _) => err!(
"Got no code in OIDC data",
ErrorEvent {
@@ -199,7 +204,7 @@ async fn _sso_login(
(Some(code), Some(code_verifier)) => (code, code_verifier.clone()),
};
let (sso_auth, user_infos) = sso::exchange_code(state, code_verifier, conn).await?;
let (sso_auth, user_infos) = sso::exchange_code(code, code_verifier, conn).await?;
let user_with_sso = match SsoUser::find_by_identifier(&user_infos.identifier, conn).await {
None => match SsoUser::find_by_mail(&user_infos.email, conn).await {
None => None,
@@ -227,7 +232,33 @@ async fn _sso_login(
}
)
}
Some((user, None)) => Some((user, None)),
Some((user, None)) => match user_infos.email_verified {
None if !CONFIG.sso_allow_unknown_email_verification() => {
error!(
"Login failure ({}), existing non SSO user ({}) with same email ({}) and email verification status is unknown",
user_infos.identifier, user.uuid, user.email
);
err_silent!(
"Email verification status is unknown",
ErrorEvent {
event: EventType::UserFailedLogIn
}
)
}
Some(false) => {
error!(
"Login failure ({}), existing non SSO user ({}) with same email ({}) and email is not verified",
user_infos.identifier, user.uuid, user.email
);
err_silent!(
"Email is not verified by the SSO provider",
ErrorEvent {
event: EventType::UserFailedLogIn
}
)
}
_ => Some((user, None)),
},
},
Some((user, sso_user)) => Some((user, Some(sso_user))),
};
@@ -314,12 +345,12 @@ async fn _sso_login(
authenticated_response(&user, &mut device, auth_tokens, twofactor_token, conn, ip).await
}
async fn _password_login(
async fn password_login(
data: ConnectData,
user_id: &mut Option<UserId>,
conn: &DbConn,
ip: &ClientIp,
client_version: &Option<ClientVersion>,
client_version: Option<&ClientVersion>,
) -> JsonResult {
// Validate scope
AuthMethod::Password.check_scope(data.scope.as_ref())?;
@@ -398,9 +429,9 @@ async fn _password_login(
if user.verified_at.is_none() && CONFIG.mail_enabled() && CONFIG.signups_verify() {
if user.last_verifying_at.is_none()
|| now.signed_duration_since(user.last_verifying_at.unwrap()).num_seconds()
> CONFIG.signups_verify_resend_time() as i64
> CONFIG.signups_verify_resend_time().cast_signed()
{
let resend_limit = CONFIG.signups_verify_resend_limit() as i32;
let resend_limit = CONFIG.signups_verify_resend_limit().cast_signed();
if resend_limit == 0 || user.login_verify_count < resend_limit {
// We want to send another email verification if we require signups to verify
// their email address, and we haven't sent them a reminder in a while...
@@ -536,19 +567,19 @@ async fn authenticated_response(
Ok(Json(result))
}
async fn _api_key_login(data: ConnectData, user_id: &mut Option<UserId>, conn: &DbConn, ip: &ClientIp) -> JsonResult {
async fn api_key_login(data: ConnectData, user_id: &mut Option<UserId>, conn: &DbConn, ip: &ClientIp) -> JsonResult {
// Ratelimit the login
crate::ratelimit::check_limit_login(&ip.ip)?;
// Validate scope
match data.scope.as_ref() {
Some(scope) if scope == &AuthMethod::UserApiKey.scope() => _user_api_key_login(data, user_id, conn, ip).await,
Some(scope) if scope == &AuthMethod::OrgApiKey.scope() => _organization_api_key_login(data, conn, ip).await,
Some(scope) if scope == &AuthMethod::UserApiKey.scope() => user_api_key_login(data, user_id, conn, ip).await,
Some(scope) if scope == &AuthMethod::OrgApiKey.scope() => organization_api_key_login(data, conn, ip).await,
_ => err!("Scope not supported"),
}
}
async fn _user_api_key_login(
async fn user_api_key_login(
data: ConnectData,
user_id: &mut Option<UserId>,
conn: &DbConn,
@@ -680,13 +711,13 @@ async fn _user_api_key_login(
Ok(Json(result))
}
async fn _organization_api_key_login(data: ConnectData, conn: &DbConn, ip: &ClientIp) -> JsonResult {
async fn organization_api_key_login(data: ConnectData, conn: &DbConn, ip: &ClientIp) -> JsonResult {
// Get the org via the client_id
let client_id = data.client_id.as_ref().unwrap();
let Some(org_id) = client_id.strip_prefix("organization.") else {
err!("Malformed client_id", format!("IP: {}.", ip.ip))
};
let org_id: OrganizationId = org_id.to_string().into();
let org_id: OrganizationId = org_id.to_owned().into();
let Some(org_api_key) = OrganizationApiKey::find_by_org_uuid(&org_id, conn).await else {
err!("Invalid client_id", format!("IP: {}.", ip.ip))
};
@@ -717,14 +748,13 @@ async fn get_device(data: &ConnectData, conn: &DbConn, user: &User) -> ApiResult
let device_name = data.device_name.clone().expect("No device name provided");
// Find device or create new
match Device::find_by_uuid_and_user(&device_id, &user.uuid, conn).await {
Some(device) => Ok(device),
None => {
let mut device = Device::new(device_id, user.uuid.clone(), device_name, device_type);
// save device without updating `device.updated_at`
device.save(false, conn).await?;
Ok(device)
}
if let Some(device) = Device::find_by_uuid_and_user(&device_id, &user.uuid, conn).await {
Ok(device)
} else {
let mut device = Device::new(device_id, user.uuid.clone(), device_name, device_type);
// save device without updating `device.updated_at`
device.save(false, conn).await?;
Ok(device)
}
}
@@ -733,7 +763,7 @@ async fn twofactor_auth(
data: &ConnectData,
device: &mut Device,
ip: &ClientIp,
client_version: &Option<ClientVersion>,
client_version: Option<&ClientVersion>,
conn: &DbConn,
) -> ApiResult<Option<String>> {
let twofactors = TwoFactor::find_by_user(&user.uuid, conn).await;
@@ -750,7 +780,7 @@ async fn twofactor_auth(
.iter()
.filter_map(|tf| {
let provider_type = TwoFactorType::from_i32(tf.atype)?;
(tf.enabled && is_twofactor_provider_usable(provider_type, Some(&tf.data))).then_some(tf.atype)
(tf.enabled && is_twofactor_provider_usable(&provider_type, Some(&tf.data))).then_some(tf.atype)
})
.collect();
if twofactor_ids.is_empty() {
@@ -758,59 +788,51 @@ async fn twofactor_auth(
}
let selected_id = data.two_factor_provider.unwrap_or(twofactor_ids[0]); // If we aren't given a two factor provider, assume the first one
// Ignore Remember and RecoveryCode Types during this check, these are special
// Ignore Remember and RecoveryCode Types during this check, these are special
if ![TwoFactorType::Remember as i32, TwoFactorType::RecoveryCode as i32].contains(&selected_id)
&& !twofactor_ids.contains(&selected_id)
{
err_json!(
_json_err_twofactor(&twofactor_ids, &user.uuid, data, client_version, conn).await?,
json_err_twofactor(&twofactor_ids, &user.uuid, data, client_version, conn).await?,
"Invalid two factor provider"
)
}
let twofactor_code = match data.two_factor_token {
Some(ref code) => code,
None => {
err_json!(
_json_err_twofactor(&twofactor_ids, &user.uuid, data, client_version, conn).await?,
"2FA token not provided"
)
}
let Some(ref twofactor_code) = data.two_factor_token else {
err_json!(
json_err_twofactor(&twofactor_ids, &user.uuid, data, client_version, conn).await?,
"2FA token not provided"
)
};
let selected_twofactor = twofactors.into_iter().find(|tf| tf.atype == selected_id && tf.enabled);
use crate::crypto::ct_eq;
let selected_data = _selected_data(selected_twofactor);
let selected_data = selected_data(selected_twofactor);
match TwoFactorType::from_i32(selected_id) {
Some(TwoFactorType::Authenticator) => {
authenticator::validate_totp_code_str(&user.uuid, twofactor_code, &selected_data?, ip, conn).await?
authenticator::validate_totp_code_str(&user.uuid, twofactor_code, &selected_data?, ip, conn).await?;
}
Some(TwoFactorType::Webauthn) => webauthn::validate_webauthn_login(&user.uuid, twofactor_code, conn).await?,
Some(TwoFactorType::YubiKey) => yubikey::validate_yubikey_login(twofactor_code, &selected_data?).await?,
Some(TwoFactorType::Duo) => {
match CONFIG.duo_use_iframe() {
true => {
// Legacy iframe prompt flow
duo::validate_duo_login(&user.email, twofactor_code, conn).await?
}
false => {
// OIDC based flow
duo_oidc::validate_duo_login(
&user.email,
twofactor_code,
data.client_id.as_ref().unwrap(),
data.device_identifier.as_ref().unwrap(),
conn,
)
.await?
}
if CONFIG.duo_use_iframe() {
// Legacy iframe prompt flow
duo::validate_duo_login(&user.email, twofactor_code, conn).await?;
} else {
// OIDC based flow
duo_oidc::validate_duo_login(
&user.email,
twofactor_code,
data.client_id.as_ref().unwrap(),
data.device_identifier.as_ref().unwrap(),
conn,
)
.await?;
}
}
Some(TwoFactorType::Email) => {
email::validate_email_code_str(&user.uuid, twofactor_code, &selected_data?, &ip.ip, conn).await?
email::validate_email_code_str(&user.uuid, twofactor_code, &selected_data?, &ip.ip, conn).await?;
}
Some(TwoFactorType::Remember) => {
match device.twofactor_remember {
@@ -818,7 +840,7 @@ async fn twofactor_auth(
// If it is invalid we need to trigger the 2FA Login prompt
Some(ref token)
if !CONFIG.disable_2fa_remember()
&& (ct_eq(token, twofactor_code)
&& (crypto::ct_eq(token, twofactor_code)
&& auth::decode_2fa_remember(twofactor_code)
.is_ok_and(|t| t.sub == device.uuid && t.user_uuid == user.uuid)) => {}
_ => {
@@ -829,7 +851,7 @@ async fn twofactor_auth(
device.save(true, conn).await?;
}
err_json!(
_json_err_twofactor(&twofactor_ids, &user.uuid, data, client_version, conn).await?,
json_err_twofactor(&twofactor_ids, &user.uuid, data, client_version, conn).await?,
"2FA Remember token not provided or expired"
)
}
@@ -870,15 +892,15 @@ async fn twofactor_auth(
Ok(two_factor)
}
fn _selected_data(tf: Option<TwoFactor>) -> ApiResult<String> {
fn selected_data(tf: Option<TwoFactor>) -> ApiResult<String> {
tf.map(|t| t.data).map_res("Two factor doesn't exist")
}
async fn _json_err_twofactor(
async fn json_err_twofactor(
providers: &[i32],
user_id: &UserId,
data: &ConnectData,
client_version: &Option<ClientVersion>,
client_version: Option<&ClientVersion>,
conn: &DbConn,
) -> ApiResult<Value> {
let mut result = json!({
@@ -895,42 +917,38 @@ async fn _json_err_twofactor(
result["TwoFactorProviders2"][provider.to_string()] = Value::Null;
match TwoFactorType::from_i32(*provider) {
Some(TwoFactorType::Authenticator) => { /* Nothing to do for TOTP */ }
Some(TwoFactorType::Webauthn) if CONFIG.is_webauthn_2fa_supported() => {
let request = webauthn::generate_webauthn_login(user_id, conn).await?;
result["TwoFactorProviders2"][provider.to_string()] = request.0;
}
Some(TwoFactorType::Duo) => {
let email = match User::find_by_uuid(user_id, conn).await {
Some(u) => u.email,
None => err!("User does not exist"),
let email = if let Some(u) = User::find_by_uuid(user_id, conn).await {
u.email
} else {
err!("User does not exist")
};
match CONFIG.duo_use_iframe() {
true => {
// Legacy iframe prompt flow
let (signature, host) = duo::generate_duo_signature(&email, conn).await?;
result["TwoFactorProviders2"][provider.to_string()] = json!({
"Host": host,
"Signature": signature,
})
}
false => {
// OIDC based flow
let auth_url = duo_oidc::get_duo_auth_url(
&email,
data.client_id.as_ref().unwrap(),
data.device_identifier.as_ref().unwrap(),
conn,
)
.await?;
if CONFIG.duo_use_iframe() {
// Legacy iframe prompt flow
let (signature, host) = duo::generate_duo_signature(&email, conn).await?;
result["TwoFactorProviders2"][provider.to_string()] = json!({
"Host": host,
"Signature": signature,
});
} else {
// OIDC based flow
let auth_url = duo_oidc::get_duo_auth_url(
&email,
data.client_id.as_ref().unwrap(),
data.device_identifier.as_ref().unwrap(),
conn,
)
.await?;
result["TwoFactorProviders2"][provider.to_string()] = json!({
"AuthUrl": auth_url,
})
}
result["TwoFactorProviders2"][provider.to_string()] = json!({
"AuthUrl": auth_url,
});
}
}
@@ -943,7 +961,7 @@ async fn _json_err_twofactor(
result["TwoFactorProviders2"][provider.to_string()] = json!({
"Nfc": yubikey_metadata.nfc,
})
});
}
Some(tf_type @ TwoFactorType::Email) => {
@@ -961,16 +979,30 @@ async fn _json_err_twofactor(
// Send email immediately if email is the only 2FA option.
if providers.len() == 1 && !disabled_send {
email::send_token(user_id, conn).await?
email::send_token(user_id, conn).await?;
}
let email_data = email::EmailTokenData::from_json(&twofactor.data)?;
result["TwoFactorProviders2"][provider.to_string()] = json!({
"Email": email::obscure_email(&email_data.email),
})
});
}
_ => {}
None
| Some(
TwoFactorType::Authenticator
| TwoFactorType::EmailVerificationChallenge
| TwoFactorType::OrganizationDuo
| TwoFactorType::ProtectedActions
| TwoFactorType::RecoveryCode
| TwoFactorType::Remember
| TwoFactorType::U2f
| TwoFactorType::U2fLoginChallenge
| TwoFactorType::U2fRegisterChallenge
| TwoFactorType::Webauthn
| TwoFactorType::WebauthnLoginChallenge
| TwoFactorType::WebauthnRegisterChallenge,
) => { /* Nothing special to do for these providers */ }
}
}
@@ -978,13 +1010,18 @@ async fn _json_err_twofactor(
}
#[post("/accounts/prelogin", data = "<data>")]
async fn prelogin(data: Json<PreloginData>, conn: DbConn) -> Json<Value> {
_prelogin(data, conn).await
async fn post_prelogin(data: Json<PreloginData>, conn: DbConn) -> Json<Value> {
prelogin(data, conn).await
}
#[post("/accounts/prelogin/password", data = "<data>")]
async fn prelogin_password(data: Json<PreloginData>, conn: DbConn) -> Json<Value> {
prelogin(data, conn).await
}
#[post("/accounts/register", data = "<data>")]
async fn identity_register(data: Json<RegisterData>, conn: DbConn) -> JsonResult {
_register(data, false, conn).await
register(data, false, conn).await
}
#[derive(Debug, Deserialize)]
@@ -1023,13 +1060,13 @@ async fn register_verification_email(
if should_send_mail {
let user = User::find_by_mail(&data.email, &conn).await;
if user.filter(|u| u.private_key.is_some()).is_some() {
if user.as_ref().is_some_and(|u| u.private_key.is_some()) {
// There is still a timing side channel here in that the code
// paths that send mail take noticeably longer than ones that don't.
// Add a randomized sleep to mitigate this somewhat.
use rand::{rngs::SmallRng, RngExt};
use rand::{RngExt, rngs::SmallRng};
let mut rng: SmallRng = rand::make_rng();
let sleep_ms = rng.random_range(900..=1100) as u64;
let sleep_ms: u64 = rng.random_range(900..=1100);
tokio::time::sleep(tokio::time::Duration::from_millis(sleep_ms)).await;
} else {
mail::send_register_verify_email(&data.email, &token).await?;
@@ -1045,7 +1082,7 @@ async fn register_verification_email(
#[post("/accounts/register/finish", data = "<data>")]
async fn register_finish(data: Json<RegisterData>, conn: DbConn) -> JsonResult {
_register(data, true, conn).await
register(data, true, conn).await
}
// https://github.com/bitwarden/jslib/blob/master/common/src/models/request/tokenRequest.ts
@@ -1104,11 +1141,11 @@ struct ConnectData {
// Needed for authorization code
#[field(name = uncased("code"))]
code: Option<OIDCState>,
code: Option<OIDCCode>,
#[field(name = uncased("code_verifier"))]
code_verifier: Option<OIDCCodeVerifier>,
}
fn _check_is_some<T>(value: &Option<T>, msg: &str) -> EmptyResult {
fn check_is_some<T>(value: Option<&T>, msg: &str) -> EmptyResult {
if value.is_none() {
err!(msg)
}
@@ -1127,33 +1164,32 @@ fn prevalidate() -> JsonResult {
}
}
const SSO_BINDING_COOKIE: &str = "VW_SSO_BINDING";
#[get("/connect/oidc-signin?<code>&<state>", rank = 1)]
async fn oidcsignin(code: OIDCCode, state: String, mut conn: DbConn) -> ApiResult<Redirect> {
_oidcsignin_redirect(
state,
OIDCCodeWrapper::Ok {
code,
},
&mut conn,
)
.await
async fn oidcsignin(code: OIDCCode, state: String, cookies: &CookieJar<'_>, mut conn: DbConn) -> ApiResult<Redirect> {
oidcsignin_redirect(state, code, None, cookies, &mut conn).await
}
// Bitwarden client appear to only care for code and state so we pipe it through
// cf: https://github.com/bitwarden/clients/blob/80b74b3300e15b4ae414dc06044cc9b02b6c10a6/libs/auth/src/angular/sso/sso.component.ts#L141
// Bitwarden client appear to only care for code and state
// We save the error in the database and set the encoded state as the code to be able to retrieve them later on
// cf: https://github.com/bitwarden/clients/blob/afd36d290ce18fb0048e0575e7d5a8f78b5dbffc/libs/auth/src/angular/sso/sso.component.ts#L156
#[get("/connect/oidc-signin?<state>&<error>&<error_description>", rank = 2)]
async fn oidcsignin_error(
state: String,
error: String,
error_description: Option<String>,
cookies: &CookieJar<'_>,
mut conn: DbConn,
) -> ApiResult<Redirect> {
_oidcsignin_redirect(
state,
OIDCCodeWrapper::Error {
oidcsignin_redirect(
state.clone(),
state.into(),
Some(OIDCCodeResponseError {
error,
error_description,
},
}),
cookies,
&mut conn,
)
.await
@@ -1162,18 +1198,32 @@ async fn oidcsignin_error(
// The state was encoded using Base64 to ensure no issue with providers.
// iss and scope parameters are needed for redirection to work on IOS.
// We pass the state as the code to get it back later on.
async fn _oidcsignin_redirect(
async fn oidcsignin_redirect(
base64_state: String,
code_response: OIDCCodeWrapper,
code: OIDCCode,
error: Option<OIDCCodeResponseError>,
cookies: &CookieJar<'_>,
conn: &mut DbConn,
) -> ApiResult<Redirect> {
let state = sso::decode_state(&base64_state)?;
let mut sso_auth = match SsoAuth::find(&state, conn).await {
None => err!(format!("Cannot retrieve sso_auth for {state}")),
Some(sso_auth) => sso_auth,
let Some(mut sso_auth) = SsoAuth::find(&state, conn).await else {
err!(format!("Cannot retrieve sso_auth for {state}"))
};
sso_auth.code_response = Some(code_response);
// Browser-binding check
// The cookie was set on /connect/authorize and must come from the same browser that initiated the flow.
let cookie_value = cookies.get(SSO_BINDING_COOKIE).map(|c| c.value().to_owned());
let provided_hash = cookie_value.as_deref().map(|v| crypto::sha256_hex(v.as_bytes()));
match (sso_auth.binding_hash.as_deref(), provided_hash.as_deref()) {
(Some(expected), Some(actual)) if crypto::ct_eq(expected, actual) => {}
_ => err!(format!("SSO session binding mismatch for {state}")),
}
cookies
.remove(Cookie::build(SSO_BINDING_COOKIE).path(format!("{}/identity/connect/", CONFIG.domain_path())).build());
sso_auth.code_response = Some(code.clone());
sso_auth.code_response_error = error;
sso_auth.updated_at = Utc::now().naive_utc();
sso_auth.save(conn).await?;
@@ -1183,7 +1233,7 @@ async fn _oidcsignin_redirect(
};
url.query_pairs_mut()
.append_pair("code", &state)
.append_pair("code", &code)
.append_pair("state", &state)
.append_pair("scope", &AuthMethod::Sso.scope())
.append_pair("iss", &CONFIG.domain());
@@ -1219,7 +1269,7 @@ struct AuthorizeData {
// The `redirect_uri` will change depending of the client (web, android, ios ..)
#[get("/connect/authorize?<data..>")]
async fn authorize(data: AuthorizeData, conn: DbConn) -> ApiResult<Redirect> {
async fn authorize(data: AuthorizeData, cookies: &CookieJar<'_>, secure: Secure, conn: DbConn) -> ApiResult<Redirect> {
let AuthorizeData {
client_id,
redirect_uri,
@@ -1233,7 +1283,23 @@ async fn authorize(data: AuthorizeData, conn: DbConn) -> ApiResult<Redirect> {
err!("Unsupported code challenge method");
}
let auth_url = sso::authorize_url(state, code_challenge, &client_id, &redirect_uri, conn).await?;
// Generate browser-binding token. Stored hashed in DB; raw value handed to the browser as a cookie.
// Validated on /connect/oidc-signin
let binding_token = data_encoding::BASE64URL_NOPAD.encode(&crypto::get_random_bytes::<32>());
let binding_hash = crypto::sha256_hex(binding_token.as_bytes());
let auth_url =
sso::authorize_url(state, code_challenge, &client_id, &redirect_uri, Some(binding_hash), conn).await?;
cookies.add(
Cookie::build((SSO_BINDING_COOKIE, binding_token))
.path(format!("{}/identity/connect/", CONFIG.domain_path()))
.max_age(time::Duration::seconds(sso::SSO_AUTH_EXPIRATION.num_seconds()))
.same_site(SameSite::Lax) // Lax is needed because the IdP runs on a different FQDN
.http_only(true)
.secure(secure.https)
.build(),
);
Ok(Redirect::temporary(String::from(auth_url)))
}
+7 -4
View File
@@ -32,11 +32,13 @@ pub use crate::api::{
web::routes as web_routes,
web::static_files,
};
use crate::db::{
models::{OrgPolicy, OrgPolicyType, User},
DbConn,
use crate::{
CONFIG,
db::{
DbConn,
models::{OrgPolicy, OrgPolicyType, User},
},
};
use crate::CONFIG;
// Type aliases for API methods results
pub type ApiResult<T> = Result<T, crate::error::Error>;
@@ -74,6 +76,7 @@ impl PasswordOrOtpData {
}
}
#[expect(clippy::struct_excessive_bools, reason = "Bitwarden clients expect the data in this specific format")]
#[derive(Debug, Default, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct MasterPasswordPolicy {
+20 -19
View File
@@ -6,17 +6,22 @@ use std::{
use chrono::{NaiveDateTime, Utc};
use rmpv::Value;
use rocket::{futures::StreamExt, Route};
use rocket::{Route, futures::StreamExt};
use rocket_ws::{Message, WebSocket};
use tokio::sync::mpsc::Sender;
use crate::{
CONFIG, Error,
auth::{ClientIp, WsAccessTokenHeader},
db::{
models::{AuthRequestId, Cipher, CollectionId, Device, DeviceId, Folder, PushId, Send as DbSend, User, UserId},
DbConn,
models::{AuthRequestId, Cipher, CollectionId, Device, DeviceId, Folder, PushId, Send as DbSend, User, UserId},
},
Error, CONFIG,
};
use super::{
push::push_auth_request, push::push_auth_response, push_cipher_update, push_folder_update, push_logout,
push_send_update, push_user_update,
};
pub static WS_USERS: LazyLock<Arc<WebSocketUsers>> = LazyLock::new(|| {
@@ -31,11 +36,6 @@ pub static WS_ANONYMOUS_SUBSCRIPTIONS: LazyLock<Arc<AnonymousWebSocketSubscripti
})
});
use super::{
push::push_auth_request, push::push_auth_response, push_cipher_update, push_folder_update, push_logout,
push_send_update, push_user_update,
};
static NOTIFICATIONS_DISABLED: LazyLock<bool> = LazyLock::new(|| !CONFIG.enable_websocket() && !CONFIG.push_enabled());
pub fn routes() -> Vec<Route> {
@@ -102,7 +102,7 @@ impl Drop for WSAnonymousEntryMapGuard {
}
}
#[allow(tail_expr_drop_order)]
#[expect(tail_expr_drop_order)]
#[get("/hub?<data..>")]
fn websockets_hub<'r>(
ws: WebSocket,
@@ -186,7 +186,7 @@ fn websockets_hub<'r>(
})
}
#[allow(tail_expr_drop_order)]
#[expect(tail_expr_drop_order)]
#[get("/anonymous-hub?<token..>")]
fn anonymous_websockets_hub<'r>(ws: WebSocket, token: String, ip: ClientIp) -> Result<rocket_ws::Stream!['r], Error> {
info!("Accepting Anonymous Rocket WS connection from {}", ip.ip);
@@ -268,14 +268,15 @@ fn serialize(val: &Value) -> Vec<u8> {
let mut len_buf: Vec<u8> = Vec::new();
loop {
let mut size_part = size & 0x7f;
#[expect(clippy::cast_possible_truncation, reason = "masked to 7 bits, fits u8")]
let mut size_part = (size & 0x7f) as u8;
size >>= 7;
if size > 0 {
size_part |= 0x80;
}
len_buf.push(size_part as u8);
len_buf.push(size_part);
if size == 0 {
break;
@@ -329,7 +330,7 @@ pub struct WebSocketUsers {
impl WebSocketUsers {
async fn send_update(&self, user_id: &UserId, data: &[u8]) {
if let Some(user) = self.map.get(user_id.as_ref()).map(|v| v.clone()) {
for (_, sender) in user.iter() {
for (_, sender) in &user {
if let Err(e) = sender.send(Message::binary(data)).await {
error!("Error sending WS update {e}");
}
@@ -338,7 +339,7 @@ impl WebSocketUsers {
}
// NOTE: The last modified date needs to be updated before calling these methods
pub async fn send_user_update(&self, ut: UpdateType, user: &User, push_uuid: &Option<PushId>, conn: &DbConn) {
pub async fn send_user_update(&self, ut: UpdateType, user: &User, push_uuid: Option<&PushId>, conn: &DbConn) {
// Skip any processing if both WebSockets and Push are not active
if *NOTIFICATIONS_DISABLED {
return;
@@ -538,10 +539,10 @@ pub struct AnonymousWebSocketSubscriptions {
impl AnonymousWebSocketSubscriptions {
async fn send_update(&self, token: &str, data: &[u8]) {
if let Some(sender) = self.map.get(token).map(|v| v.clone()) {
if let Err(e) = sender.send(Message::binary(data)).await {
error!("Error sending WS update {e}");
}
if let Some(sender) = self.map.get(token).map(|v| v.clone())
&& let Err(e) = sender.send(Message::binary(data)).await
{
error!("Error sending WS update {e}");
}
}
@@ -582,7 +583,7 @@ fn create_update(payload: Vec<(Value, Value)>, ut: UpdateType, acting_device_id:
V::Nil,
"ReceiveMessage".into(),
V::Array(vec![V::Map(vec![
("ContextId".into(), acting_device_id.map(|v| v.to_string().into()).unwrap_or_else(|| V::Nil)),
("ContextId".into(), acting_device_id.map_or(V::Nil, |v| v.to_string().into())),
("Type".into(), (ut as i32).into()),
("Payload".into(), payload.into()),
])]),
+26 -26
View File
@@ -4,21 +4,21 @@ use std::{
};
use reqwest::{
header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE},
Method,
header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE},
};
use serde_json::Value;
use tokio::sync::RwLock;
use crate::{
CONFIG,
api::{ApiResult, EmptyResult, UpdateType},
db::{
models::{AuthRequestId, Cipher, Device, Folder, PushId, Send, User, UserId},
DbConn,
models::{AuthRequestId, Cipher, Device, Folder, PushId, Send, User, UserId},
},
http_client::make_http_request,
util::{format_date, get_uuid},
CONFIG,
};
#[derive(Deserialize)]
@@ -74,9 +74,9 @@ async fn get_auth_api_token() -> ApiResult<String> {
};
let mut api_token = API_TOKEN.write().await;
api_token.valid_until = Instant::now()
.checked_add(Duration::new((json_pushtoken.expires_in / 2) as u64, 0)) // Token valid for half the specified time
.unwrap();
// Token valid for half the specified time
let half_expires_in = u64::from((json_pushtoken.expires_in / 2).max(0).cast_unsigned());
api_token.valid_until = Instant::now().checked_add(Duration::from_secs(half_expires_in)).unwrap();
api_token.access_token = json_pushtoken.access_token;
@@ -135,7 +135,7 @@ pub async fn register_push_device(device: &mut Device, conn: &DbConn) -> EmptyRe
Ok(())
}
pub async fn unregister_push_device(push_id: &Option<PushId>) -> EmptyResult {
pub async fn unregister_push_device(push_id: Option<&PushId>) -> EmptyResult {
if !CONFIG.push_enabled() || push_id.is_none() {
return Ok(());
}
@@ -161,7 +161,7 @@ pub async fn push_cipher_update(ut: UpdateType, cipher: &Cipher, device: &Device
// We shouldn't send a push notification on cipher update if the cipher belongs to an organization, this isn't implemented in the upstream server too.
if cipher.organization_uuid.is_some() {
return;
};
}
let Some(user_id) = &cipher.user_uuid else {
debug!("Cipher has no uuid");
return;
@@ -206,7 +206,7 @@ pub async fn push_logout(user: &User, acting_device: Option<&Device>, conn: &DbC
}
}
pub async fn push_user_update(ut: UpdateType, user: &User, push_uuid: &Option<PushId>, conn: &DbConn) {
pub async fn push_user_update(ut: UpdateType, user: &User, push_uuid: Option<&PushId>, conn: &DbConn) {
if Device::check_user_has_push_device(&user.uuid, conn).await {
tokio::task::spawn(send_to_push_relay(json!({
"userId": user.uuid,
@@ -244,23 +244,23 @@ pub async fn push_folder_update(ut: UpdateType, folder: &Folder, device: &Device
}
pub async fn push_send_update(ut: UpdateType, send: &Send, device: &Device, conn: &DbConn) {
if let Some(s) = &send.user_uuid {
if Device::check_user_has_push_device(s, conn).await {
tokio::task::spawn(send_to_push_relay(json!({
if let Some(s) = &send.user_uuid
&& Device::check_user_has_push_device(s, conn).await
{
tokio::task::spawn(send_to_push_relay(json!({
"userId": send.user_uuid,
"organizationId": null,
"deviceId": device.push_uuid, // Should be the records unique uuid of the acting device (unique uuid per user/device)
"identifier": device.uuid, // Should be the acting device id (aka uuid per device/app)
"type": ut as i32,
"payload": {
"id": send.uuid,
"userId": send.user_uuid,
"organizationId": null,
"deviceId": device.push_uuid, // Should be the records unique uuid of the acting device (unique uuid per user/device)
"identifier": device.uuid, // Should be the acting device id (aka uuid per device/app)
"type": ut as i32,
"payload": {
"id": send.uuid,
"userId": send.user_uuid,
"revisionDate": format_date(&send.revision_date)
},
"clientType": null,
"installationId": null
})));
}
"revisionDate": format_date(&send.revision_date)
},
"clientType": null,
"installationId": null
})));
}
}
@@ -296,7 +296,7 @@ async fn send_to_push_relay(notification_data: Value) {
.await
{
error!("An error occurred while sending a send update to the push relay: {e}");
};
}
}
pub async fn push_auth_request(user_id: &UserId, auth_request_id: &str, device: &Device, conn: &DbConn) {
+38 -10
View File
@@ -1,21 +1,24 @@
use std::path::{Path, PathBuf};
use rocket::{
Catcher, Route,
fs::NamedFile,
http::ContentType,
response::{content::RawCss as Css, content::RawHtml as Html, Redirect},
response::{Redirect, content::RawCss as Css, content::RawHtml as Html},
serde::json::Json,
Catcher, Route,
};
use serde_json::Value;
use crate::{
api::{core::now, ApiResult, EmptyResult},
CONFIG,
api::{ApiResult, EmptyResult, core::now},
auth::decode_file_download,
db::models::{AttachmentId, CipherId},
db::{
DbConn,
models::{AttachmentId, CipherId},
},
error::Error,
util::Cached,
CONFIG,
};
pub fn routes() -> Vec<Route> {
@@ -23,12 +26,20 @@ pub fn routes() -> Vec<Route> {
// crate::utils::LOGGED_ROUTES to make sure they appear in the log
let mut routes = routes![attachments, alive, alive_head, static_files];
if CONFIG.web_vault_enabled() {
routes.append(&mut routes![web_index, web_index_direct, web_index_head, app_id, web_files, vaultwarden_css]);
routes.append(&mut routes![
web_index,
web_index_direct,
web_index_head,
app_id,
apple_app_site_association,
web_files,
vaultwarden_css
]);
}
#[cfg(debug_assertions)]
if CONFIG.reload_templates() {
routes.append(&mut routes![_static_files_dev]);
routes.append(&mut routes![static_files_dev]);
}
routes
@@ -160,6 +171,24 @@ fn app_id() -> Cached<(ContentType, Json<Value>)> {
)
}
#[get("/.well-known/apple-app-site-association")]
fn apple_app_site_association() -> Cached<(ContentType, Json<Value>)> {
Cached::long(
(
ContentType::JSON,
Json(json!({
"webcredentials": {
"apps": [
"LTZ2PFU5D6.com.8bit.bitwarden",
"LTZ2PFU5D6.com.8bit.bitwarden.beta"
]
}
})),
),
true,
)
}
#[get("/<p..>", rank = 10)] // Only match this if the other routes don't match
async fn web_files(p: PathBuf) -> Cached<Option<NamedFile>> {
Cached::long(NamedFile::open(Path::new(&CONFIG.web_vault_folder()).join(p)).await.ok(), true)
@@ -178,7 +207,6 @@ async fn attachments(cipher_id: CipherId, file_id: AttachmentId, token: String)
}
// We use DbConn here to let the alive healthcheck also verify the database connection.
use crate::db::DbConn;
#[get("/alive")]
fn alive(_conn: DbConn) -> Json<String> {
now()
@@ -197,7 +225,7 @@ fn alive_head(_conn: DbConn) -> EmptyResult {
// NOTE: Do not forget to add any new files added to the `static_files` function below!
#[cfg(debug_assertions)]
#[get("/vw_static/<filename>", rank = 1)]
pub async fn _static_files_dev(filename: PathBuf) -> Option<NamedFile> {
pub async fn static_files_dev(filename: PathBuf) -> Option<NamedFile> {
warn!("LOADING STATIC FILES FROM DISK");
let file = filename.to_str().unwrap_or_default();
let ext = filename.extension().unwrap_or_default();
@@ -210,7 +238,7 @@ pub async fn _static_files_dev(filename: PathBuf) -> Option<NamedFile> {
if let Ok(path) = path {
return NamedFile::open(path).await.ok();
};
}
None
}
+73 -81
View File
@@ -5,21 +5,30 @@ use std::{
};
use chrono::{DateTime, TimeDelta, Utc};
use jsonwebtoken::{errors::ErrorKind, Algorithm, DecodingKey, EncodingKey, Header};
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, errors::ErrorKind};
use num_traits::FromPrimitive;
use openssl::rsa::Rsa;
use serde::de::DeserializeOwned;
use serde::ser::Serialize;
use serde::{de::DeserializeOwned, ser::Serialize};
use rocket::{
outcome::try_outcome,
request::{FromRequest, Outcome, Request},
};
use crate::{
CONFIG,
api::ApiResult,
config::PathType,
db::models::{
AttachmentId, CipherId, CollectionId, DeviceId, DeviceType, EmergencyAccessId, MembershipId, OrgApiKeyId,
OrganizationId, SendFileId, SendId, UserId,
db::{
DbConn,
models::{
AttachmentId, CipherId, Collection, CollectionId, Device, DeviceId, DeviceType, EmergencyAccessId,
Membership, MembershipId, MembershipStatus, MembershipType, OrgApiKeyId, OrganizationId, SendFileId,
SendId, User, UserId, UserStampException,
},
},
error::Error,
sso, CONFIG,
sso,
};
const JWT_ALGORITHM: Algorithm = Algorithm::RS256;
@@ -52,16 +61,12 @@ static PRIVATE_RSA_KEY: OnceLock<EncodingKey> = OnceLock::new();
static PUBLIC_RSA_KEY: OnceLock<DecodingKey> = OnceLock::new();
pub async fn initialize_keys() -> Result<(), Error> {
use std::io::Error;
use std::io::Error as IoError;
let rsa_key_filename = std::path::PathBuf::from(CONFIG.private_rsa_key())
.file_name()
.ok_or_else(|| Error::other("Private RSA key path missing filename"))?
.to_str()
.ok_or_else(|| Error::other("Private RSA key path filename is not valid UTF-8"))?
.to_string();
let rsa_key_filename = crate::storage::file_name(&CONFIG.private_rsa_key())
.ok_or_else(|| IoError::other("Private RSA key path missing filename"))?;
let operator = CONFIG.opendal_operator_for_path_type(&PathType::RsaKey).map_err(Error::other)?;
let operator = CONFIG.opendal_operator_for_path_type(&PathType::RsaKey).map_err(IoError::other)?;
let priv_key_buffer = match operator.read(&rsa_key_filename).await {
Ok(buffer) => Some(buffer),
@@ -230,7 +235,7 @@ impl LoginJwtClaims {
// let orgmanager: Vec<_> = orgs.iter().filter(|o| o.atype == 3).map(|o| o.org_uuid.clone()).collect();
if exp <= (now + *BW_EXPIRATION).timestamp() {
warn!("Raise access_token lifetime to more than 5min.")
warn!("Raise access_token lifetime to more than 5min.");
}
// Create the JWT claims struct, to send to the client
@@ -257,7 +262,7 @@ impl LoginJwtClaims {
sstamp: user.security_stamp.clone(),
device: device.uuid.clone(),
devicetype: DeviceType::from_i32(device.atype).to_string(),
client_id: client_id.unwrap_or("undefined".to_string()),
client_id: client_id.unwrap_or("undefined".to_owned()),
scope,
amr: vec!["Application".into()],
}
@@ -510,7 +515,7 @@ pub fn generate_admin_claims() -> BasicJwtClaims {
nbf: time_now.timestamp(),
exp: (time_now + TimeDelta::try_minutes(CONFIG.admin_session_lifetime()).unwrap()).timestamp(),
iss: JWT_ADMIN_ISSUER.to_string(),
sub: "admin_panel".to_string(),
sub: "admin_panel".to_owned(),
}
}
@@ -527,16 +532,6 @@ pub fn generate_send_claims(send_id: &SendId, file_id: &SendFileId) -> BasicJwtC
//
// Bearer token authentication
//
use rocket::{
outcome::try_outcome,
request::{FromRequest, Outcome, Request},
};
use crate::db::{
models::{Collection, Device, Membership, MembershipStatus, MembershipType, User, UserStampException},
DbConn,
};
pub struct Host {
pub host: String,
}
@@ -552,7 +547,7 @@ impl<'r> FromRequest<'r> for Host {
let host = if CONFIG.domain_set() {
CONFIG.domain()
} else if let Some(referer) = headers.get_one("Referer") {
referer.to_string()
referer.to_owned()
} else {
// Try to guess from the headers
let protocol = if let Some(proto) = headers.get_one("X-Forwarded-Proto") {
@@ -588,13 +583,15 @@ impl<'r> FromRequest<'r> for ClientHeaders {
type Error = &'static str;
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
let ip = match ClientIp::from_request(request).await {
Outcome::Success(ip) => ip,
_ => err_handler!("Error getting Client IP"),
let Outcome::Success(ip) = ClientIp::from_request(request).await else {
err_handler!("Error getting Client IP")
};
// When unknown or unable to parse, return 14, which is 'Unknown Browser'
let device_type: i32 =
request.headers().get_one("device-type").map(|d| d.parse().unwrap_or(14)).unwrap_or_else(|| 14);
// When unknown or unable to parse, return 'UnknownBrowser'
let device_type: i32 = request
.headers()
.get_one("device-type")
.and_then(|d| d.parse().ok())
.unwrap_or(DeviceType::UnknownBrowser as i32);
Outcome::Success(ClientHeaders {
device_type,
@@ -618,18 +615,19 @@ impl<'r> FromRequest<'r> for Headers {
let headers = request.headers();
let host = try_outcome!(Host::from_request(request).await).host;
let ip = match ClientIp::from_request(request).await {
Outcome::Success(ip) => ip,
_ => err_handler!("Error getting Client IP"),
let Outcome::Success(ip) = ClientIp::from_request(request).await else {
err_handler!("Error getting Client IP")
};
// Get access_token
let access_token: &str = match headers.get_one("Authorization") {
Some(a) => match a.rsplit("Bearer ").next() {
Some(split) => split,
None => err_handler!("No access token provided"),
},
None => err_handler!("No access token provided"),
let access_token: &str = if let Some(a) = headers.get_one("Authorization") {
if let Some(split) = a.rsplit("Bearer ").next() {
split
} else {
err_handler!("No access token provided")
}
} else {
err_handler!("No access token provided")
};
// Check JWT token is valid and get device and user from it
@@ -640,9 +638,8 @@ impl<'r> FromRequest<'r> for Headers {
let device_id = claims.device;
let user_id = claims.sub;
let conn = match DbConn::from_request(request).await {
Outcome::Success(conn) => conn,
_ => err_handler!("Error getting DB"),
let Outcome::Success(conn) = DbConn::from_request(request).await else {
err_handler!("Error getting DB")
};
let Some(device) = Device::find_by_uuid_and_user(&device_id, &user_id, &conn).await else {
@@ -673,7 +670,7 @@ impl<'r> FromRequest<'r> for Headers {
error!("Error updating user: {e:#?}");
}
err_handler!("Stamp exception is expired")
} else if !stamp_exception.routes.contains(&current_route.to_string()) {
} else if !stamp_exception.routes.contains(&current_route.to_owned()) {
err_handler!("Invalid security stamp: Current route and exception route do not match")
} else if stamp_exception.security_stamp != claims.sstamp {
err_handler!("Invalid security stamp for matched stamp exception")
@@ -761,9 +758,8 @@ impl<'r> FromRequest<'r> for OrgHeaders {
match url_org_id {
Some(org_id) if uuid::Uuid::parse_str(&org_id).is_ok() => {
let conn = match DbConn::from_request(request).await {
Outcome::Success(conn) => conn,
_ => err_handler!("Error getting DB"),
let Outcome::Success(conn) = DbConn::from_request(request).await else {
err_handler!("Error getting DB")
};
let user = headers.user;
@@ -835,16 +831,16 @@ impl<'r> FromRequest<'r> for AdminHeaders {
// but there could be cases where it is a query value.
// First check the path, if this is not a valid uuid, try the query values.
fn get_col_id(request: &Request<'_>) -> Option<CollectionId> {
if let Some(Ok(col_id)) = request.param::<String>(3) {
if uuid::Uuid::parse_str(&col_id).is_ok() {
return Some(col_id.into());
}
if let Some(Ok(col_id)) = request.param::<String>(3)
&& uuid::Uuid::parse_str(&col_id).is_ok()
{
return Some(col_id.into());
}
if let Some(Ok(col_id)) = request.query_value::<String>("collectionId") {
if uuid::Uuid::parse_str(&col_id).is_ok() {
return Some(col_id.into());
}
if let Some(Ok(col_id)) = request.query_value::<String>("collectionId")
&& uuid::Uuid::parse_str(&col_id).is_ok()
{
return Some(col_id.into());
}
None
@@ -868,18 +864,16 @@ impl<'r> FromRequest<'r> for ManagerHeaders {
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
let headers = try_outcome!(OrgHeaders::from_request(request).await);
if headers.is_confirmed_and_manager() {
match get_col_id(request) {
Some(col_id) => {
let conn = match DbConn::from_request(request).await {
Outcome::Success(conn) => conn,
_ => err_handler!("Error getting DB"),
};
if let Some(col_id) = get_col_id(request) {
let Outcome::Success(conn) = DbConn::from_request(request).await else {
err_handler!("Error getting DB")
};
if !Collection::is_coll_manageable_by_user(&col_id, &headers.membership.user_uuid, &conn).await {
err_handler!("The current user isn't a manager for this collection")
}
if !Collection::is_coll_manageable_by_user(&col_id, &headers.membership.user_uuid, &conn).await {
err_handler!("The current user isn't a manager for this collection")
}
_ => err_handler!("Error getting the collection id"),
} else {
err_handler!("Error getting the collection id")
}
Outcome::Success(Self {
@@ -1040,7 +1034,7 @@ impl From<OrgMemberHeaders> for Headers {
//
// Client IP address detection
//
#[derive(Copy, Clone)]
pub struct ClientIp {
pub ip: IpAddr,
}
@@ -1072,6 +1066,7 @@ impl<'r> FromRequest<'r> for ClientIp {
}
}
#[derive(Copy, Clone)]
pub struct Secure {
pub https: bool,
}
@@ -1157,15 +1152,14 @@ pub enum AuthMethod {
impl AuthMethod {
pub fn scope(&self) -> String {
match self {
AuthMethod::OrgApiKey => "api.organization".to_string(),
AuthMethod::Password => "api offline_access".to_string(),
AuthMethod::Sso => "api offline_access".to_string(),
AuthMethod::UserApiKey => "api".to_string(),
AuthMethod::OrgApiKey => "api.organization".to_owned(),
AuthMethod::UserApiKey => "api".to_owned(),
AuthMethod::Password | AuthMethod::Sso => "api offline_access".to_owned(),
}
}
pub fn scope_vec(&self) -> Vec<String> {
self.scope().split_whitespace().map(str::to_string).collect()
self.scope().split_whitespace().map(str::to_owned).collect()
}
pub fn check_scope(&self, scope: Option<&String>) -> ApiResult<String> {
@@ -1278,17 +1272,15 @@ pub async fn refresh_tokens(
};
// Get device by refresh token
let mut device = match Device::find_by_refresh_token(&refresh_claims.device_token, conn).await {
None => err!("Invalid refresh token"),
Some(device) => device,
let Some(mut device) = Device::find_by_refresh_token(&refresh_claims.device_token, conn).await else {
err!("Invalid refresh token")
};
// Save to update `updated_at`.
device.save(true, conn).await?;
let user = match User::find_by_uuid(&device.user_uuid, conn).await {
None => err!("Impossible to find user"),
Some(user) => user,
let Some(user) = User::find_by_uuid(&device.user_uuid, conn).await else {
err!("Impossible to find user")
};
let auth_tokens = match refresh_claims.sub {
+157 -234
View File
@@ -3,8 +3,8 @@ use std::{
fmt,
process::exit,
sync::{
atomic::{AtomicBool, Ordering},
LazyLock, RwLock,
atomic::{AtomicBool, Ordering},
},
};
@@ -14,26 +14,23 @@ use serde::de::{self, Deserialize, Deserializer, MapAccess, Visitor};
use crate::{
error::Error,
storage,
util::{
get_active_web_release, get_env, get_env_bool, is_valid_email, parse_experimental_client_feature_flags,
FeatureFlagFilter,
FeatureFlagFilter, get_active_web_release, get_env, get_env_bool, is_valid_email,
parse_experimental_client_feature_flags,
},
};
static CONFIG_FILE: LazyLock<String> = LazyLock::new(|| {
let data_folder = get_env("DATA_FOLDER").unwrap_or_else(|| String::from("data"));
get_env("CONFIG_FILE").unwrap_or_else(|| format!("{data_folder}/config.json"))
get_env("CONFIG_FILE").unwrap_or_else(|| storage::join_path(&data_folder, "config.json"))
});
static CONFIG_FILE_PARENT_DIR: LazyLock<String> = LazyLock::new(|| {
let path = std::path::PathBuf::from(&*CONFIG_FILE);
path.parent().unwrap_or(std::path::Path::new("data")).to_str().unwrap_or("data").to_string()
});
static CONFIG_FILE_PARENT_DIR: LazyLock<String> =
LazyLock::new(|| storage::parent(&CONFIG_FILE).unwrap_or_else(|| "data".to_owned()));
static CONFIG_FILENAME: LazyLock<String> = LazyLock::new(|| {
let path = std::path::PathBuf::from(&*CONFIG_FILE);
path.file_name().unwrap_or(std::ffi::OsStr::new("config.json")).to_str().unwrap_or("config.json").to_string()
});
static CONFIG_FILENAME: LazyLock<String> =
LazyLock::new(|| storage::file_name(&CONFIG_FILE).unwrap_or_else(|| "config.json".to_owned()));
pub static SKIP_CONFIG_VALIDATION: AtomicBool = AtomicBool::new(false);
@@ -263,7 +260,7 @@ macro_rules! make_config {
}
async fn from_file() -> Result<Self, Error> {
let operator = opendal_operator_for_path(&CONFIG_FILE_PARENT_DIR)?;
let operator = storage::operator_for_path(&CONFIG_FILE_PARENT_DIR)?;
let config_bytes = operator.read(&CONFIG_FILENAME).await?;
println!("[INFO] Using saved config from `{}` for configuration.\n", *CONFIG_FILE);
serde_json::from_slice(&config_bytes.to_vec()).map_err(Into::into)
@@ -363,13 +360,7 @@ macro_rules! make_config {
)+)+
pub fn prepare_json(&self) -> serde_json::Value {
let (def, cfg, overridden) = {
// Lock the inner as short as possible and clone what is needed to prevent deadlocks
let inner = &self.inner.read().unwrap();
(inner._env.build(), inner.config.clone(), inner._overrides.clone())
};
fn _get_form_type(rust_type: &'static str) -> &'static str {
fn get_form_type(rust_type: &'static str) -> &'static str {
match rust_type {
"Pass" => "password",
"String" => "text",
@@ -378,7 +369,7 @@ macro_rules! make_config {
}
}
fn _get_doc(doc_str: &'static str) -> ElementDoc {
fn get_doc(doc_str: &'static str) -> ElementDoc {
let mut split = doc_str.split("|>").map(str::trim);
ElementDoc {
name: split.next().unwrap_or_default(),
@@ -386,6 +377,12 @@ macro_rules! make_config {
}
}
let (def, cfg, overridden) = {
// Lock the inner as short as possible and clone what is needed to prevent deadlocks
let inner = &self.inner.read().unwrap();
(inner._env.build(), inner.config.clone(), inner._overrides.clone())
};
let data: Vec<GroupData> = vec![
$( // This repetition is for each group
GroupData {
@@ -400,8 +397,8 @@ macro_rules! make_config {
name: stringify!($name),
value: serde_json::to_value(&cfg.$name).unwrap_or_default(),
default: serde_json::to_value(&def.$name).unwrap_or_default(),
r#type: _get_form_type(stringify!($ty)),
doc: _get_doc(concat!($($doc),+)),
r#type: get_form_type(stringify!($ty)),
doc: get_doc(concat!($($doc),+)),
overridden: overridden.contains(&pastey::paste!(stringify!([<$name:upper>]))),
},
)+], // End of elements repetition
@@ -411,9 +408,31 @@ macro_rules! make_config {
}
pub fn get_support_json(&self) -> serde_json::Value {
/// We map over the string and remove all alphanumeric, _ and - characters.
/// This is the fastest way (within micro-seconds) instead of using a regex (which takes mili-seconds)
fn privacy_mask(value: &str) -> String {
let mut n: u16 = 0;
let mut colon_match = false;
value
.chars()
.map(|c| {
n += 1;
match c {
':' if n <= 11 => {
colon_match = true;
c
}
'/' if n <= 13 && colon_match => c,
',' => c,
_ => '*',
}
})
.collect::<String>()
}
// Define which config keys need to be masked.
// Pass types will always be masked and no need to put them in the list.
// Besides Pass, only String types will be masked via _privacy_mask.
// Besides Pass, only String types will be masked via privacy_mask.
const PRIVACY_CONFIG: &[&str] = &[
"allowed_connect_src",
"allowed_iframe_ancestors",
@@ -440,28 +459,6 @@ macro_rules! make_config {
inner.config.clone()
};
/// We map over the string and remove all alphanumeric, _ and - characters.
/// This is the fastest way (within micro-seconds) instead of using a regex (which takes mili-seconds)
fn _privacy_mask(value: &str) -> String {
let mut n: u16 = 0;
let mut colon_match = false;
value
.chars()
.map(|c| {
n += 1;
match c {
':' if n <= 11 => {
colon_match = true;
c
}
'/' if n <= 13 && colon_match => c,
',' => c,
_ => '*',
}
})
.collect::<String>()
}
serde_json::Value::Object({
let mut json = serde_json::Map::new();
$($(
@@ -471,7 +468,7 @@ macro_rules! make_config {
for mask_key in PRIVACY_CONFIG {
if let Some(value) = json.get_mut(*mask_key) {
if let Some(s) = value.as_str() {
*value = _privacy_mask(s).into();
*value = privacy_mask(s).into();
}
}
}
@@ -505,23 +502,23 @@ macro_rules! make_config {
make_config! {
folders {
/// Data folder |> Main data folder
data_folder: String, false, def, "data".to_string();
data_folder: String, false, def, "data".to_owned();
/// Database URL
database_url: String, false, auto, |c| format!("{}/db.sqlite3", c.data_folder);
database_url: String, false, auto, |c| format!("sqlite://{}", storage::join_path(&c.data_folder, "db.sqlite3"));
/// Icon cache folder
icon_cache_folder: String, false, auto, |c| format!("{}/icon_cache", c.data_folder);
icon_cache_folder: String, false, auto, |c| storage::join_path(&c.data_folder, "icon_cache");
/// Attachments folder
attachments_folder: String, false, auto, |c| format!("{}/attachments", c.data_folder);
attachments_folder: String, false, auto, |c| storage::join_path(&c.data_folder, "attachments");
/// Sends folder
sends_folder: String, false, auto, |c| format!("{}/sends", c.data_folder);
sends_folder: String, false, auto, |c| storage::join_path(&c.data_folder, "sends");
/// Temp folder |> Used for storing temporary file uploads
tmp_folder: String, false, auto, |c| format!("{}/tmp", c.data_folder);
tmp_folder: String, false, auto, |c| storage::join_path(&c.data_folder, "tmp");
/// Templates folder
templates_folder: String, false, auto, |c| format!("{}/templates", c.data_folder);
templates_folder: String, false, auto, |c| storage::join_path(&c.data_folder, "templates");
/// Session JWT key
rsa_key_filename: String, false, auto, |c| format!("{}/rsa_key", c.data_folder);
rsa_key_filename: String, false, auto, |c| storage::join_path(&c.data_folder, "rsa_key");
/// Web vault folder
web_vault_folder: String, false, def, "web-vault/".to_string();
web_vault_folder: String, false, def, "web-vault/".to_owned();
},
ws {
/// Enable websocket notifications
@@ -531,9 +528,9 @@ make_config! {
/// Enable push notifications
push_enabled: bool, false, def, false;
/// Push relay uri
push_relay_uri: String, false, def, "https://push.bitwarden.com".to_string();
push_relay_uri: String, false, def, "https://push.bitwarden.com".to_owned();
/// Push identity uri
push_identity_uri: String, false, def, "https://identity.bitwarden.com".to_string();
push_identity_uri: String, false, def, "https://identity.bitwarden.com".to_owned();
/// Installation id |> The installation id from https://bitwarden.com/host
push_installation_id: Pass, false, def, String::new();
/// Installation key |> The installation key from https://bitwarden.com/host
@@ -545,38 +542,38 @@ make_config! {
job_poll_interval_ms: u64, false, def, 30_000;
/// Send purge schedule |> Cron schedule of the job that checks for Sends past their deletion date.
/// Defaults to hourly. Set blank to disable this job.
send_purge_schedule: String, false, def, "0 5 * * * *".to_string();
send_purge_schedule: String, false, def, "0 5 * * * *".to_owned();
/// Trash purge schedule |> Cron schedule of the job that checks for trashed items to delete permanently.
/// Defaults to daily. Set blank to disable this job.
trash_purge_schedule: String, false, def, "0 5 0 * * *".to_string();
trash_purge_schedule: String, false, def, "0 5 0 * * *".to_owned();
/// Incomplete 2FA login schedule |> Cron schedule of the job that checks for incomplete 2FA logins.
/// Defaults to once every minute. Set blank to disable this job.
incomplete_2fa_schedule: String, false, def, "30 * * * * *".to_string();
incomplete_2fa_schedule: String, false, def, "30 * * * * *".to_owned();
/// Emergency notification reminder schedule |> Cron schedule of the job that sends expiration reminders to emergency access grantors.
/// Defaults to hourly. (3 minutes after the hour) Set blank to disable this job.
emergency_notification_reminder_schedule: String, false, def, "0 3 * * * *".to_string();
emergency_notification_reminder_schedule: String, false, def, "0 3 * * * *".to_owned();
/// Emergency request timeout schedule |> Cron schedule of the job that grants emergency access requests that have met the required wait time.
/// Defaults to hourly. (7 minutes after the hour) Set blank to disable this job.
emergency_request_timeout_schedule: String, false, def, "0 7 * * * *".to_string();
emergency_request_timeout_schedule: String, false, def, "0 7 * * * *".to_owned();
/// Event cleanup schedule |> Cron schedule of the job that cleans old events from the event table.
/// Defaults to daily. Set blank to disable this job.
event_cleanup_schedule: String, false, def, "0 10 0 * * *".to_string();
event_cleanup_schedule: String, false, def, "0 10 0 * * *".to_owned();
/// Auth Request cleanup schedule |> Cron schedule of the job that cleans old auth requests from the auth request.
/// Defaults to every minute. Set blank to disable this job.
auth_request_purge_schedule: String, false, def, "30 * * * * *".to_string();
auth_request_purge_schedule: String, false, def, "30 * * * * *".to_owned();
/// Duo Auth context cleanup schedule |> Cron schedule of the job that cleans expired Duo contexts from the database. Does nothing if Duo MFA is disabled or set to use the legacy iframe prompt.
/// Defaults to once every minute. Set blank to disable this job.
duo_context_purge_schedule: String, false, def, "30 * * * * *".to_string();
duo_context_purge_schedule: String, false, def, "30 * * * * *".to_owned();
/// Purge incomplete SSO auth. |> Cron schedule of the job that cleans leftover auth in db due to incomplete SSO login.
/// Defaults to daily. Set blank to disable this job.
purge_incomplete_sso_auth: String, false, def, "0 20 0 * * *".to_string();
purge_incomplete_sso_auth: String, false, def, "0 20 0 * * *".to_owned();
},
/// General settings
settings {
/// Domain URL |> This needs to be set to the URL used to access the server, including 'http[s]://'
/// and port, if it's different than the default. Some server functions don't work correctly without this value
domain: String, true, def, "http://localhost".to_string();
domain: String, true, def, "http://localhost".to_owned();
/// Domain Set |> Indicates if the domain is set by the admin. Otherwise the default will be used.
domain_set: bool, false, def, false;
/// Domain origin |> Domain URL origin (in https://example.com:8443/path, https://example.com:8443 is the origin)
@@ -656,7 +653,7 @@ make_config! {
admin_token: Pass, true, option;
/// Invitation organization name |> Name shown in the invitation emails that don't come from a specific organization
invitation_org_name: String, true, def, "Vaultwarden".to_string();
invitation_org_name: String, true, def, "Vaultwarden".to_owned();
/// Events days retain |> Number of days to retain events stored in the database. If unset, events are kept indefinitely.
events_days_retain: i64, false, option;
@@ -666,7 +663,7 @@ make_config! {
advanced {
/// Client IP header |> If not present, the remote IP is used.
/// Set to the string "none" (without quotes), to disable any headers and just use the remote IP
ip_header: String, true, def, "X-Real-IP".to_string();
ip_header: String, true, def, "X-Real-IP".to_owned();
/// Internal IP header property, used to avoid recomputing each time
_ip_header_enabled: bool, false, generated, |c| &c.ip_header.trim().to_lowercase() != "none";
/// Icon service |> The predefined icon services are: internal, bitwarden, duckduckgo, google.
@@ -675,7 +672,7 @@ make_config! {
/// `internal` refers to Vaultwarden's built-in icon fetching implementation. If an external
/// service is set, an icon request to Vaultwarden will return an HTTP redirect to the
/// corresponding icon at the external service.
icon_service: String, false, def, "internal".to_string();
icon_service: String, false, def, "internal".to_owned();
/// _icon_service_url
_icon_service_url: String, false, generated, |c| generate_icon_service_url(&c.icon_service);
/// _icon_service_csp
@@ -726,14 +723,14 @@ make_config! {
/// Enable extended logging
extended_logging: bool, false, def, true;
/// Log timestamp format
log_timestamp_format: String, true, def, "%Y-%m-%d %H:%M:%S.%3f".to_string();
log_timestamp_format: String, true, def, "%Y-%m-%d %H:%M:%S.%3f".to_owned();
/// Enable the log to output to Syslog
use_syslog: bool, false, def, false;
/// Log file path
log_file: String, false, option;
/// Log level |> Valid values are "trace", "debug", "info", "warn", "error" and "off"
/// For a specific module append it as a comma separated value "info,path::to::module=debug"
log_level: String, false, def, "info".to_string();
log_level: String, false, def, "info".to_owned();
/// Enable DB WAL |> Turning this off might lead to worse performance, but might help if using vaultwarden on some exotic filesystems,
/// that do not support WAL. Please make sure you read project wiki on the topic before changing this setting.
@@ -815,7 +812,7 @@ make_config! {
/// Authority Server |> Base url of the OIDC provider discovery endpoint (without `/.well-known/openid-configuration`)
sso_authority: String, true, def, String::new();
/// Authorization request scopes |> List the of the needed scope (`openid` is implicit)
sso_scopes: String, true, def, "email profile".to_string();
sso_scopes: String, true, def, "email profile".to_owned();
/// Authorization request extra parameters
sso_authorize_extra_params: String, true, def, String::new();
/// Use PKCE during Authorization flow
@@ -883,7 +880,7 @@ make_config! {
/// From Address
smtp_from: String, true, def, String::new();
/// From Name
smtp_from_name: String, true, def, "Vaultwarden".to_string();
smtp_from_name: String, true, def, "Vaultwarden".to_owned();
/// Username
smtp_username: String, true, option;
/// Password
@@ -929,10 +926,13 @@ fn validate_config(cfg: &ConfigItems, on_update: bool) -> Result<(), Error> {
{
use crate::db::DbConnType;
let url = &cfg.database_url;
if DbConnType::from_url(url)? == DbConnType::Sqlite && url.contains('/') {
let path = std::path::Path::new(&url);
if let Some(parent) = path.parent() {
if !parent.is_dir() {
if DbConnType::from_url(url)? == DbConnType::Sqlite {
let file_path = url.strip_prefix("sqlite://").unwrap_or(url);
if file_path.contains('/') {
let path = std::path::Path::new(file_path);
if let Some(parent) = path.parent()
&& !parent.is_dir()
{
err!(format!(
"SQLite database directory `{}` does not exist or is not a directory",
parent.display()
@@ -959,10 +959,10 @@ fn validate_config(cfg: &ConfigItems, on_update: bool) -> Result<(), Error> {
err!(format!("`DATABASE_MIN_CONNS` must be smaller than or equal to `DATABASE_MAX_CONNS`.",));
}
if let Some(log_file) = &cfg.log_file {
if std::fs::OpenOptions::new().append(true).create(true).open(log_file).is_err() {
err!("Unable to write to log file", log_file);
}
if let Some(log_file) = &cfg.log_file
&& std::fs::OpenOptions::new().append(true).create(true).open(log_file).is_err()
{
err!("Unable to write to log file", log_file);
}
let dom = cfg.domain.to_lowercase();
@@ -975,7 +975,9 @@ fn validate_config(cfg: &ConfigItems, on_update: bool) -> Result<(), Error> {
let connect_src = cfg.allowed_connect_src.to_lowercase();
for url in connect_src.split_whitespace() {
if !url.starts_with("https://") || Url::parse(url).is_err() {
err!("ALLOWED_CONNECT_SRC variable contains one or more invalid URLs. Only FQDN's starting with https are allowed");
err!(
"ALLOWED_CONNECT_SRC variable contains one or more invalid URLs. Only FQDN's starting with https are allowed"
);
}
}
@@ -991,11 +993,12 @@ fn validate_config(cfg: &ConfigItems, on_update: bool) -> Result<(), Error> {
err!("`ORG_CREATION_USERS` contains invalid email addresses");
}
if let Some(ref token) = cfg.admin_token {
if token.trim().is_empty() && !cfg.disable_admin_token {
println!("[WARNING] `ADMIN_TOKEN` is enabled but has an empty value, so the admin page will be disabled.");
println!("[WARNING] To enable the admin page without a token, use `DISABLE_ADMIN_TOKEN`.");
}
if let Some(ref token) = cfg.admin_token
&& token.trim().is_empty()
&& !cfg.disable_admin_token
{
println!("[WARNING] `ADMIN_TOKEN` is enabled but has an empty value, so the admin page will be disabled.");
println!("[WARNING] To enable the admin page without a token, use `DISABLE_ADMIN_TOKEN`.");
}
if cfg.push_enabled && (cfg.push_installation_id == String::new() || cfg.push_installation_key == String::new()) {
@@ -1029,37 +1032,41 @@ fn validate_config(cfg: &ConfigItems, on_update: bool) -> Result<(), Error> {
}
}
let invalid_flags =
parse_experimental_client_feature_flags(&cfg.experimental_client_feature_flags, FeatureFlagFilter::InvalidOnly);
let invalid_flags = parse_experimental_client_feature_flags(
&cfg.experimental_client_feature_flags,
&FeatureFlagFilter::InvalidOnly,
);
if !invalid_flags.is_empty() {
let feature_flags_error = format!("Unrecognized experimental client feature flags: {:?}.\n\
let feature_flags_error = format!(
"Unrecognized experimental client feature flags: {invalid_flags:?}.\n\
Please ensure all feature flags are spelled correctly and that they are supported in this version.\n\
Supported flags: {:?}\n", invalid_flags, SUPPORTED_FEATURE_FLAGS);
Supported flags: {SUPPORTED_FEATURE_FLAGS:?}\n"
);
if on_update {
err!(feature_flags_error);
} else {
println!("[WARNING] {feature_flags_error}");
}
println!("[WARNING] {feature_flags_error}");
}
#[expect(clippy::items_after_statements, reason = "Keep this close to where it is used")]
const MAX_FILESIZE_KB: i64 = i64::MAX >> 10;
if let Some(limit) = cfg.user_attachment_limit {
if !(0i64..=MAX_FILESIZE_KB).contains(&limit) {
err!("`USER_ATTACHMENT_LIMIT` is out of bounds");
}
if let Some(limit) = cfg.user_attachment_limit
&& !(0i64..=MAX_FILESIZE_KB).contains(&limit)
{
err!("`USER_ATTACHMENT_LIMIT` is out of bounds");
}
if let Some(limit) = cfg.org_attachment_limit {
if !(0i64..=MAX_FILESIZE_KB).contains(&limit) {
err!("`ORG_ATTACHMENT_LIMIT` is out of bounds");
}
if let Some(limit) = cfg.org_attachment_limit
&& !(0i64..=MAX_FILESIZE_KB).contains(&limit)
{
err!("`ORG_ATTACHMENT_LIMIT` is out of bounds");
}
if let Some(limit) = cfg.user_send_limit {
if !(0i64..=MAX_FILESIZE_KB).contains(&limit) {
err!("`USER_SEND_LIMIT` is out of bounds");
}
if let Some(limit) = cfg.user_send_limit
&& !(0i64..=MAX_FILESIZE_KB).contains(&limit)
{
err!("`USER_SEND_LIMIT` is out of bounds");
}
if cfg._enable_duo
@@ -1076,7 +1083,7 @@ fn validate_config(cfg: &ConfigItems, on_update: bool) -> Result<(), Error> {
validate_internal_sso_issuer_url(&cfg.sso_authority)?;
validate_internal_sso_redirect_url(&cfg.sso_callback_path)?;
validate_sso_master_password_policy(&cfg.sso_master_password_policy)?;
validate_sso_master_password_policy(cfg.sso_master_password_policy.as_ref())?;
}
if cfg._enable_yubico {
@@ -1087,7 +1094,9 @@ fn validate_config(cfg: &ConfigItems, on_update: bool) -> Result<(), Error> {
if let Some(yubico_server) = &cfg.yubico_server {
let yubico_server = yubico_server.to_lowercase();
if !yubico_server.starts_with("https://") {
err!("`YUBICO_SERVER` must be a valid URL and start with 'https://'. Either unset this variable or provide a valid URL.")
err!(
"`YUBICO_SERVER` must be a valid URL and start with 'https://'. Either unset this variable or provide a valid URL."
)
}
}
}
@@ -1139,7 +1148,9 @@ fn validate_config(cfg: &ConfigItems, on_update: bool) -> Result<(), Error> {
}
if cfg.smtp_username.is_some() != cfg.smtp_password.is_some() {
err!("Both `SMTP_USERNAME` and `SMTP_PASSWORD` need to be set to enable email authentication without `USE_SENDMAIL`")
err!(
"Both `SMTP_USERNAME` and `SMTP_PASSWORD` need to be set to enable email authentication without `USE_SENDMAIL`"
)
}
}
@@ -1271,7 +1282,7 @@ fn validate_internal_sso_redirect_url(sso_callback_path: &String) -> Result<open
}
fn validate_sso_master_password_policy(
sso_master_password_policy: &Option<String>,
sso_master_password_policy: Option<&String>,
) -> Result<Option<serde_json::Value>, Error> {
let policy = sso_master_password_policy.as_ref().map(|mpp| serde_json::from_str::<serde_json::Value>(mpp));
@@ -1300,7 +1311,7 @@ fn extract_url_origin(url: &str) -> String {
/// All trailing '/' chars are trimmed, even if the path is a lone '/'.
fn extract_url_path(url: &str) -> String {
match Url::parse(url) {
Ok(u) => u.path().trim_end_matches('/').to_string(),
Ok(u) => u.path().trim_end_matches('/').to_owned(),
Err(_) => {
// We already print it in the method above, no need to do it again
String::new()
@@ -1310,7 +1321,7 @@ fn extract_url_path(url: &str) -> String {
fn generate_smtp_img_src(embed_images: bool, domain: &str) -> String {
if embed_images {
"cid:".to_string()
"cid:".to_owned()
} else {
// normalize base_url
let base_url = domain.trim_end_matches('/');
@@ -1329,10 +1340,10 @@ fn generate_sso_callback_path(domain: &str) -> String {
fn generate_icon_service_url(icon_service: &str) -> String {
match icon_service {
"internal" => String::new(),
"bitwarden" => "https://icons.bitwarden.net/{}/icon.png".to_string(),
"duckduckgo" => "https://icons.duckduckgo.com/ip3/{}.ico".to_string(),
"google" => "https://www.google.com/s2/favicons?domain={}&sz=32".to_string(),
_ => icon_service.to_string(),
"bitwarden" => "https://icons.bitwarden.net/{}/icon.png".to_owned(),
"duckduckgo" => "https://icons.duckduckgo.com/ip3/{}.ico".to_owned(),
"google" => "https://www.google.com/s2/favicons?domain={}&sz=32".to_owned(),
_ => icon_service.to_owned(),
}
}
@@ -1341,7 +1352,7 @@ fn generate_icon_service_csp(icon_service: &str, icon_service_url: &str) -> Stri
// We split on the first '{', since that is the variable delimiter for an icon service URL.
// Everything up until the first '{' should be fixed and can be used as an CSP string.
let csp_string = match icon_service_url.split_once('{') {
Some((c, _)) => c.to_string(),
Some((c, _)) => c.to_owned(),
None => String::new(),
};
@@ -1358,96 +1369,12 @@ fn smtp_convert_deprecated_ssl_options(smtp_ssl: Option<bool>, smtp_explicit_tls
println!("[DEPRECATED]: `SMTP_SSL` or `SMTP_EXPLICIT_TLS` is set. Please use `SMTP_SECURITY` instead.");
}
if smtp_explicit_tls.is_some() && smtp_explicit_tls.unwrap() {
return "force_tls".to_string();
return "force_tls".to_owned();
} else if smtp_ssl.is_some() && !smtp_ssl.unwrap() {
return "off".to_string();
return "off".to_owned();
}
// Return the default `starttls` in all other cases
"starttls".to_string()
}
fn opendal_operator_for_path(path: &str) -> Result<opendal::Operator, Error> {
// Cache of previously built operators by path
static OPERATORS_BY_PATH: LazyLock<dashmap::DashMap<String, opendal::Operator>> =
LazyLock::new(dashmap::DashMap::new);
if let Some(operator) = OPERATORS_BY_PATH.get(path) {
return Ok(operator.clone());
}
let operator = if path.starts_with("s3://") {
#[cfg(not(s3))]
return Err(opendal::Error::new(opendal::ErrorKind::ConfigInvalid, "S3 support is not enabled").into());
#[cfg(s3)]
opendal_s3_operator_for_path(path)?
} else {
let builder = opendal::services::Fs::default().root(path);
opendal::Operator::new(builder)?.finish()
};
OPERATORS_BY_PATH.insert(path.to_string(), operator.clone());
Ok(operator)
}
#[cfg(s3)]
fn opendal_s3_operator_for_path(path: &str) -> Result<opendal::Operator, Error> {
use crate::http_client::aws::AwsReqwestConnector;
use aws_config::{default_provider::credentials::DefaultCredentialsChain, provider_config::ProviderConfig};
// This is a custom AWS credential loader that uses the official AWS Rust
// SDK config crate to load credentials. This ensures maximum compatibility
// with AWS credential configurations. For example, OpenDAL doesn't support
// AWS SSO temporary credentials yet.
struct OpenDALS3CredentialLoader {}
#[async_trait]
impl reqsign::AwsCredentialLoad for OpenDALS3CredentialLoader {
async fn load_credential(&self, _client: reqwest::Client) -> anyhow::Result<Option<reqsign::AwsCredential>> {
use aws_credential_types::provider::ProvideCredentials as _;
use tokio::sync::OnceCell;
static DEFAULT_CREDENTIAL_CHAIN: OnceCell<DefaultCredentialsChain> = OnceCell::const_new();
let chain = DEFAULT_CREDENTIAL_CHAIN
.get_or_init(|| {
let reqwest_client = reqwest::Client::builder().build().unwrap();
let connector = AwsReqwestConnector {
client: reqwest_client,
};
let conf = ProviderConfig::default().with_http_client(connector);
DefaultCredentialsChain::builder().configure(conf).build()
})
.await;
let creds = chain.provide_credentials().await?;
Ok(Some(reqsign::AwsCredential {
access_key_id: creds.access_key_id().to_string(),
secret_access_key: creds.secret_access_key().to_string(),
session_token: creds.session_token().map(|s| s.to_string()),
expires_in: creds.expiry().map(|expiration| expiration.into()),
}))
}
}
const OPEN_DAL_S3_CREDENTIAL_LOADER: OpenDALS3CredentialLoader = OpenDALS3CredentialLoader {};
let url = Url::parse(path).map_err(|e| format!("Invalid path S3 URL path {path:?}: {e}"))?;
let bucket = url.host_str().ok_or_else(|| format!("Missing Bucket name in data folder S3 URL {path:?}"))?;
let builder = opendal::services::S3::default()
.customized_credential_load(Box::new(OPEN_DAL_S3_CREDENTIAL_LOADER))
.enable_virtual_host_style()
.bucket(bucket)
.root(url.path())
.default_storage_class("INTELLIGENT_TIERING");
Ok(opendal::Operator::new(builder)?.finish())
"starttls".to_owned()
}
pub enum PathType {
@@ -1490,12 +1417,12 @@ pub const SUPPORTED_FEATURE_FLAGS: &[&str] = &[
impl Config {
pub async fn load() -> Result<Self, Error> {
// Loading from env and file
let _env = ConfigBuilder::from_env();
let _usr = ConfigBuilder::from_file().await.unwrap_or_default();
let env = ConfigBuilder::from_env();
let usr = ConfigBuilder::from_file().await.unwrap_or_default();
// Create merged config, config file overwrites env
let mut _overrides = Vec::new();
let builder = _env.merge(&_usr, true, &mut _overrides);
let mut overrides = Vec::new();
let builder = env.merge(&usr, true, &mut overrides);
// Fill any missing with defaults
let config = builder.build();
@@ -1508,9 +1435,9 @@ impl Config {
rocket_shutdown_handle: None,
templates: load_templates(&config.templates_folder),
config,
_env,
_usr,
_overrides,
_env: env,
_usr: usr,
_overrides: overrides,
}),
})
}
@@ -1547,7 +1474,7 @@ impl Config {
}
//Save to file
let operator = opendal_operator_for_path(&CONFIG_FILE_PARENT_DIR)?;
let operator = storage::operator_for_path(&CONFIG_FILE_PARENT_DIR)?;
operator.write(&CONFIG_FILENAME, config_str).await?;
Ok(())
@@ -1556,8 +1483,8 @@ impl Config {
async fn update_config_partial(&self, other: ConfigBuilder) -> Result<(), Error> {
let builder = {
let usr = &self.inner.read().unwrap()._usr;
let mut _overrides = Vec::new();
usr.merge(&other, false, &mut _overrides)
let mut overrides = Vec::new();
usr.merge(&other, false, &mut overrides)
};
self.update_config(builder, false).await
}
@@ -1580,11 +1507,11 @@ impl Config {
/// Tests whether signup is allowed for an email address, taking into
/// account the signups_allowed and signups_domains_whitelist settings.
pub fn is_signup_allowed(&self, email: &str) -> bool {
if !self.signups_domains_whitelist().is_empty() {
if self.signups_domains_whitelist().is_empty() {
self.signups_allowed()
} else {
// The whitelist setting overrides the signups_allowed setting.
self.is_email_domain_allowed(email)
} else {
self.signups_allowed()
}
}
@@ -1612,7 +1539,7 @@ impl Config {
}
pub async fn delete_user_config(&self) -> Result<(), Error> {
let operator = opendal_operator_for_path(&CONFIG_FILE_PARENT_DIR)?;
let operator = storage::operator_for_path(&CONFIG_FILE_PARENT_DIR)?;
operator.delete(&CONFIG_FILENAME).await?;
// Empty user config
@@ -1636,7 +1563,7 @@ impl Config {
}
pub fn private_rsa_key(&self) -> String {
format!("{}.pem", self.rsa_key_filename())
storage::with_extension(&self.rsa_key_filename(), "pem")
}
pub fn mail_enabled(&self) -> bool {
let inner = &self.inner.read().unwrap().config;
@@ -1677,15 +1604,11 @@ impl Config {
PathType::IconCache => self.icon_cache_folder(),
PathType::Attachments => self.attachments_folder(),
PathType::Sends => self.sends_folder(),
PathType::RsaKey => std::path::Path::new(&self.rsa_key_filename())
.parent()
.ok_or_else(|| std::io::Error::other("Failed to get directory of RSA key file"))?
.to_str()
.ok_or_else(|| std::io::Error::other("Failed to convert RSA key file directory to UTF-8 string"))?
.to_string(),
PathType::RsaKey => storage::parent(&self.private_rsa_key())
.ok_or_else(|| std::io::Error::other("Failed to get directory of RSA key file"))?,
};
opendal_operator_for_path(&path)
storage::operator_for_path(&path)
}
pub fn render_template<T: serde::ser::Serialize>(&self, name: &str, data: &T) -> Result<String, Error> {
@@ -1709,10 +1632,10 @@ impl Config {
}
pub fn shutdown(&self) {
if let Ok(mut c) = self.inner.write() {
if let Some(handle) = c.rocket_shutdown_handle.take() {
handle.notify();
}
if let Ok(mut c) = self.inner.write()
&& let Some(handle) = c.rocket_shutdown_handle.take()
{
handle.notify();
}
}
@@ -1725,11 +1648,11 @@ impl Config {
}
pub fn sso_master_password_policy_value(&self) -> Option<serde_json::Value> {
validate_sso_master_password_policy(&self.sso_master_password_policy()).ok().flatten()
validate_sso_master_password_policy(self.sso_master_password_policy().as_ref()).ok().flatten()
}
pub fn sso_scopes_vec(&self) -> Vec<String> {
self.sso_scopes().split_whitespace().map(str::to_string).collect()
self.sso_scopes().split_whitespace().map(str::to_owned).collect()
}
pub fn sso_authorize_extra_params_vec(&self) -> Vec<(String, String)> {
@@ -1839,7 +1762,7 @@ fn case_helper<'reg, 'rc>(
let value = param.value().clone();
if h.params().iter().skip(1).any(|x| x.value() == &value) {
h.template().map(|t| t.render(r, ctx, rc, out)).unwrap_or_else(|| Ok(()))
h.template().map_or(Ok(()), |t| t.render(r, ctx, rc, out))
} else {
Ok(())
}
+7
View File
@@ -113,3 +113,10 @@ pub fn ct_eq<T: AsRef<[u8]>, U: AsRef<[u8]>>(a: T, b: U) -> bool {
use subtle::ConstantTimeEq;
a.as_ref().ct_eq(b.as_ref()).into()
}
//
// SHA256
//
pub fn sha256_hex(data: &[u8]) -> String {
HEXLOWER.encode(digest::digest(&digest::SHA256, data).as_ref())
}
+38 -21
View File
@@ -6,25 +6,23 @@ use std::{
};
use diesel::{
Connection, RunQueryDsl,
connection::SimpleConnection,
r2d2::{CustomizeConnection, Pool, PooledConnection},
Connection, RunQueryDsl,
};
use rocket::{
Request,
http::Status,
request::{FromRequest, Outcome},
Request,
};
use tokio::{
sync::{Mutex, OwnedSemaphorePermit, Semaphore},
time::timeout,
};
use crate::{
error::{Error, MapResult},
CONFIG,
error::{Error, MapResult},
};
// These changes are based on Rocket 0.5-rc wrapper of Diesel: https://github.com/SergioBenitez/Rocket/blob/v0.5-rc/contrib/sync_db_pools
@@ -62,7 +60,7 @@ pub struct DbConnManager {
impl DbConnManager {
pub fn new(database_url: &str) -> Self {
Self {
database_url: database_url.to_string(),
database_url: database_url.to_owned(),
}
}
@@ -224,7 +222,7 @@ impl DbPool {
// Set a global to determine the database more easily throughout the rest of the code
if ACTIVE_DB_TYPE.set(conn_type).is_err() {
error!("Tried to set the active database connection type more than once.")
error!("Tried to set the active database connection type more than once.");
}
Ok(DbPool {
@@ -272,22 +270,40 @@ impl DbConnType {
#[cfg(not(postgresql))]
err!("`DATABASE_URL` is a PostgreSQL URL, but the 'postgresql' feature is not enabled")
//Sqlite
} else {
// Sqlite (explicit)
} else if url.len() > 7 && &url[..7] == "sqlite:" {
#[cfg(sqlite)]
return Ok(DbConnType::Sqlite);
#[cfg(not(sqlite))]
err!("`DATABASE_URL` looks like a SQLite URL, but 'sqlite' feature is not enabled")
err!("`DATABASE_URL` is a SQLite URL, but the 'sqlite' feature is not enabled")
}
// No recognized scheme — assume legacy bare-path SQLite, but the database file must already exist.
// This prevents misconfigured URLs (typos, quoted strings) from silently creating a new empty SQLite database.
#[cfg(sqlite)]
{
if std::path::Path::new(url).exists() {
return Ok(DbConnType::Sqlite);
}
err!(format!(
"`DATABASE_URL` does not match any known database scheme (mysql://, postgresql://, sqlite://) \
and no existing SQLite database was found at '{url}'. \
If you intend to use SQLite, use an explicit `sqlite://` scheme in your `DATABASE_URL`. \
Otherwise, check your DATABASE_URL for typos or quoting issues."
))
}
#[cfg(not(sqlite))]
err!("`DATABASE_URL` does not match any known database scheme (mysql://, postgresql://, sqlite://)")
}
pub fn get_init_stmts(&self) -> String {
let init_stmts = CONFIG.database_conn_init();
if !init_stmts.is_empty() {
init_stmts
} else {
if init_stmts.is_empty() {
self.default_init_stmts()
} else {
init_stmts
}
}
@@ -298,7 +314,7 @@ impl DbConnType {
#[cfg(postgresql)]
Self::Postgresql => String::new(),
#[cfg(sqlite)]
Self::Sqlite => "PRAGMA busy_timeout = 5000; PRAGMA synchronous = NORMAL;".to_string(),
Self::Sqlite => "PRAGMA busy_timeout = 5000; PRAGMA synchronous = NORMAL;".to_owned(),
}
}
}
@@ -389,12 +405,13 @@ pub fn backup_sqlite() -> Result<String, Error> {
use diesel::Connection;
let db_url = CONFIG.database_url();
if DbConnType::from_url(&CONFIG.database_url()).map(|t| t == DbConnType::Sqlite).unwrap_or(false) {
// Since we do not allow any schema for sqlite database_url's like `file:` or `sqlite:` to be set, we can assume here it isn't
// This way we can set a readonly flag on the opening mode without issues.
let mut conn = diesel::sqlite::SqliteConnection::establish(&format!("sqlite://{db_url}?mode=ro"))?;
if DbConnType::from_url(&CONFIG.database_url()).is_ok_and(|t| t == DbConnType::Sqlite) {
// Strip the sqlite:// prefix if present to get the raw file path
let file_path = db_url.strip_prefix("sqlite://").unwrap_or(&db_url);
// Open a read-only connection for the backup
let mut conn = diesel::sqlite::SqliteConnection::establish(&format!("sqlite://{file_path}?mode=ro"))?;
let db_path = std::path::Path::new(&db_url).parent().unwrap();
let db_path = std::path::Path::new(file_path).parent().unwrap();
let backup_file = db_path
.join(format!("db_{}.sqlite3", chrono::Utc::now().format("%Y%m%d_%H%M%S")))
.to_string_lossy()
@@ -423,12 +440,12 @@ pub async fn get_sql_server_version(conn: &DbConn) -> String {
postgresql,mysql {
diesel::select(diesel::dsl::sql::<diesel::sql_types::Text>("version();"))
.get_result::<String>(conn)
.unwrap_or_else(|_| "Unknown".to_string())
.unwrap_or_else(|_| "Unknown".to_owned())
}
sqlite {
diesel::select(diesel::dsl::sql::<diesel::sql_types::Text>("sqlite_version();"))
.get_result::<String>(conn)
.unwrap_or_else(|_| "Unknown".to_string())
.unwrap_or_else(|_| "Unknown".to_owned())
}
}
}
+95
View File
@@ -0,0 +1,95 @@
use chrono::NaiveDateTime;
use diesel::prelude::*;
use crate::{
api::EmptyResult,
db::{DbConn, schema::archives},
error::MapResult,
};
use super::{CipherId, User, UserId};
#[derive(Identifiable, Queryable, Insertable)]
#[diesel(table_name = archives)]
#[diesel(primary_key(user_uuid, cipher_uuid))]
pub struct Archive {
pub user_uuid: UserId,
pub cipher_uuid: CipherId,
pub archived_at: NaiveDateTime,
}
impl Archive {
// Returns the date the specified cipher was archived
pub async fn get_archived_at(cipher_uuid: &CipherId, user_uuid: &UserId, conn: &DbConn) -> Option<NaiveDateTime> {
conn.run(move |conn| {
archives::table
.filter(archives::cipher_uuid.eq(cipher_uuid))
.filter(archives::user_uuid.eq(user_uuid))
.select(archives::archived_at)
.first::<NaiveDateTime>(conn)
.ok()
})
.await
}
// Saves (inserts or updates) an archive record with the provided timestamp
pub async fn save(
user_uuid: &UserId,
cipher_uuid: &CipherId,
archived_at: NaiveDateTime,
conn: &DbConn,
) -> EmptyResult {
User::update_uuid_revision(user_uuid, conn).await;
db_run! { conn:
sqlite, mysql {
diesel::replace_into(archives::table)
.values((
archives::user_uuid.eq(user_uuid),
archives::cipher_uuid.eq(cipher_uuid),
archives::archived_at.eq(archived_at),
))
.execute(conn)
.map_res("Error saving archive")
}
postgresql {
diesel::insert_into(archives::table)
.values((
archives::user_uuid.eq(user_uuid),
archives::cipher_uuid.eq(cipher_uuid),
archives::archived_at.eq(archived_at),
))
.on_conflict((archives::user_uuid, archives::cipher_uuid))
.do_update()
.set(archives::archived_at.eq(archived_at))
.execute(conn)
.map_res("Error saving archive")
}
}
}
// Deletes an archive record for a specific cipher
pub async fn delete_by_cipher(user_uuid: &UserId, cipher_uuid: &CipherId, conn: &DbConn) -> EmptyResult {
User::update_uuid_revision(user_uuid, conn).await;
conn.run(move |conn| {
diesel::delete(
archives::table.filter(archives::user_uuid.eq(user_uuid)).filter(archives::cipher_uuid.eq(cipher_uuid)),
)
.execute(conn)
.map_res("Error deleting archive")
})
.await
}
/// Return a vec with (cipher_uuid, archived_at)
/// This is used during a full sync so we only need one query for all archive matches
pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<(CipherId, NaiveDateTime)> {
conn.run(move |conn| {
archives::table
.filter(archives::user_uuid.eq(user_uuid))
.select((archives::cipher_uuid, archives::archived_at))
.load::<(CipherId, NaiveDateTime)>(conn)
.unwrap_or_default()
})
.await
}
}
+44 -37
View File
@@ -1,13 +1,24 @@
use std::time::Duration;
use bigdecimal::{BigDecimal, ToPrimitive};
use derive_more::{AsRef, Deref, Display};
use diesel::prelude::*;
use serde_json::Value;
use std::time::Duration;
use crate::{
CONFIG,
api::EmptyResult,
auth::{encode_jwt, generate_file_download_claims},
config::PathType,
db::{
DbConn,
schema::{attachments, ciphers},
},
error::MapResult,
};
use macros::IdFromParam;
use super::{CipherId, OrganizationId, UserId};
use crate::db::schema::{attachments, ciphers};
use crate::{config::PathType, CONFIG};
use macros::IdFromParam;
#[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = attachments)]
@@ -46,11 +57,11 @@ impl Attachment {
pub async fn get_url(&self, host: &str) -> Result<String, crate::Error> {
let operator = CONFIG.opendal_operator_for_path_type(&PathType::Attachments)?;
if operator.info().scheme() == <&'static str>::from(opendal::Scheme::Fs) {
if crate::storage::is_fs_operator(&operator) {
let token = encode_jwt(&generate_file_download_claims(self.cipher_uuid.clone(), self.id.clone()));
Ok(format!("{host}/attachments/{}/{}?token={token}", self.cipher_uuid, self.id))
} else {
Ok(operator.presign_read(&self.get_file_path(), Duration::from_secs(5 * 60)).await?.uri().to_string())
Ok(operator.presign_read(&self.get_file_path(), Duration::from_mins(5)).await?.uri().to_string())
}
}
@@ -67,12 +78,6 @@ impl Attachment {
}
}
use crate::auth::{encode_jwt, generate_file_download_claims};
use crate::db::DbConn;
use crate::api::EmptyResult;
use crate::error::MapResult;
/// Database methods
impl Attachment {
pub async fn save(&self, conn: &DbConn) -> EmptyResult {
@@ -107,15 +112,15 @@ impl Attachment {
}
pub async fn delete(&self, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
crate::util::retry(||
diesel::delete(attachments::table.filter(attachments::id.eq(&self.id)))
.execute(conn),
conn.run(move |conn| {
crate::util::retry(
|| diesel::delete(attachments::table.filter(attachments::id.eq(&self.id))).execute(conn),
10,
)
.map(|_| ())
.map_res("Error deleting attachment")
}}?;
})
.await?;
let operator = CONFIG.opendal_operator_for_path_type(&PathType::Attachments)?;
let file_path = self.get_file_path();
@@ -139,25 +144,22 @@ impl Attachment {
}
pub async fn find_by_id(id: &AttachmentId, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
attachments::table
.filter(attachments::id.eq(id.to_lowercase()))
.first::<Self>(conn)
.ok()
}}
conn.run(move |conn| attachments::table.filter(attachments::id.eq(id.to_lowercase())).first::<Self>(conn).ok())
.await
}
pub async fn find_by_cipher(cipher_uuid: &CipherId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
conn.run(move |conn| {
attachments::table
.filter(attachments::cipher_uuid.eq(cipher_uuid))
.load::<Self>(conn)
.expect("Error loading attachments")
}}
})
.await
}
pub async fn size_by_user(user_uuid: &UserId, conn: &DbConn) -> i64 {
db_run! { conn: {
conn.run(move |conn| {
let result: Option<BigDecimal> = attachments::table
.left_join(ciphers::table.on(ciphers::uuid.eq(attachments::cipher_uuid)))
.filter(ciphers::user_uuid.eq(user_uuid))
@@ -168,24 +170,26 @@ impl Attachment {
match result.map(|r| r.to_i64()) {
Some(Some(r)) => r,
Some(None) => i64::MAX,
None => 0
None => 0,
}
}}
})
.await
}
pub async fn count_by_user(user_uuid: &UserId, conn: &DbConn) -> i64 {
db_run! { conn: {
conn.run(move |conn| {
attachments::table
.left_join(ciphers::table.on(ciphers::uuid.eq(attachments::cipher_uuid)))
.filter(ciphers::user_uuid.eq(user_uuid))
.count()
.first(conn)
.unwrap_or(0)
}}
})
.await
}
pub async fn size_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> i64 {
db_run! { conn: {
conn.run(move |conn| {
let result: Option<BigDecimal> = attachments::table
.left_join(ciphers::table.on(ciphers::uuid.eq(attachments::cipher_uuid)))
.filter(ciphers::organization_uuid.eq(org_uuid))
@@ -196,20 +200,22 @@ impl Attachment {
match result.map(|r| r.to_i64()) {
Some(Some(r)) => r,
Some(None) => i64::MAX,
None => 0
None => 0,
}
}}
})
.await
}
pub async fn count_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> i64 {
db_run! { conn: {
conn.run(move |conn| {
attachments::table
.left_join(ciphers::table.on(ciphers::uuid.eq(attachments::cipher_uuid)))
.filter(ciphers::organization_uuid.eq(org_uuid))
.count()
.first(conn)
.unwrap_or(0)
}}
})
.await
}
// This will return all attachments linked to the user or org
@@ -220,7 +226,7 @@ impl Attachment {
org_uuids: &Vec<OrganizationId>,
conn: &DbConn,
) -> Vec<Self> {
db_run! { conn: {
conn.run(move |conn| {
attachments::table
.left_join(ciphers::table.on(ciphers::uuid.eq(attachments::cipher_uuid)))
.filter(ciphers::user_uuid.eq(user_uuid))
@@ -228,7 +234,8 @@ impl Attachment {
.select(attachments::all_columns)
.load::<Self>(conn)
.expect("Error loading attachments")
}}
})
.await
}
}
+27 -25
View File
@@ -1,12 +1,19 @@
use super::{DeviceId, OrganizationId, UserId};
use crate::db::schema::auth_requests;
use crate::{crypto::ct_eq, util::format_date};
use chrono::{NaiveDateTime, Utc};
use derive_more::{AsRef, Deref, Display, From};
use diesel::prelude::*;
use macros::UuidFromParam;
use serde_json::Value;
use crate::{
api::EmptyResult,
crypto::ct_eq,
db::{DbConn, schema::auth_requests},
error::MapResult,
util::format_date,
};
use macros::UuidFromParam;
use super::{DeviceId, OrganizationId, UserId};
#[derive(Identifiable, Queryable, Insertable, AsChangeset, Deserialize, Serialize)]
#[diesel(table_name = auth_requests)]
#[diesel(treat_none_as_null = true)]
@@ -74,11 +81,6 @@ impl AuthRequest {
}
}
use crate::db::DbConn;
use crate::api::EmptyResult;
use crate::error::MapResult;
impl AuthRequest {
pub async fn save(&mut self, conn: &DbConn) -> EmptyResult {
db_run! { conn:
@@ -112,31 +114,28 @@ impl AuthRequest {
}
pub async fn find_by_uuid(uuid: &AuthRequestId, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
auth_requests::table
.filter(auth_requests::uuid.eq(uuid))
.first::<Self>(conn)
.ok()
}}
conn.run(move |conn| auth_requests::table.filter(auth_requests::uuid.eq(uuid)).first::<Self>(conn).ok()).await
}
pub async fn find_by_uuid_and_user(uuid: &AuthRequestId, user_uuid: &UserId, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
conn.run(move |conn| {
auth_requests::table
.filter(auth_requests::uuid.eq(uuid))
.filter(auth_requests::user_uuid.eq(user_uuid))
.first::<Self>(conn)
.ok()
}}
})
.await
}
pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
conn.run(move |conn| {
auth_requests::table
.filter(auth_requests::user_uuid.eq(user_uuid))
.load::<Self>(conn)
.expect("Error loading auth_requests")
}}
})
.await
}
pub async fn find_by_user_and_requested_device(
@@ -144,7 +143,7 @@ impl AuthRequest {
device_uuid: &DeviceId,
conn: &DbConn,
) -> Option<Self> {
db_run! { conn: {
conn.run(move |conn| {
auth_requests::table
.filter(auth_requests::user_uuid.eq(user_uuid))
.filter(auth_requests::request_device_identifier.eq(device_uuid))
@@ -152,24 +151,27 @@ impl AuthRequest {
.order_by(auth_requests::creation_date.desc())
.first::<Self>(conn)
.ok()
}}
})
.await
}
pub async fn find_created_before(dt: &NaiveDateTime, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
conn.run(move |conn| {
auth_requests::table
.filter(auth_requests::creation_date.lt(dt))
.load::<Self>(conn)
.expect("Error loading auth_requests")
}}
})
.await
}
pub async fn delete(&self, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
conn.run(move |conn| {
diesel::delete(auth_requests::table.filter(auth_requests::uuid.eq(&self.uuid)))
.execute(conn)
.map_res("Error deleting auth request")
}}
})
.await
}
pub fn check_access_code(&self, access_code: &str) -> bool {
+364 -341
View File
File diff suppressed because it is too large Load Diff
+402 -309
View File
@@ -1,16 +1,25 @@
use derive_more::{AsRef, Deref, Display, From};
use diesel::prelude::*;
use serde_json::Value;
use crate::{
CONFIG,
api::EmptyResult,
db::{
DbConn,
schema::{
ciphers_collections, collections, collections_groups, groups, groups_users, users_collections,
users_organizations,
},
},
error::MapResult,
};
use macros::UuidFromParam;
use super::{
CipherId, CollectionGroup, GroupUser, Membership, MembershipId, MembershipStatus, MembershipType, OrganizationId,
User, UserId,
};
use crate::db::schema::{
ciphers_collections, collections, collections_groups, groups, groups_users, users_collections, users_organizations,
};
use crate::CONFIG;
use diesel::prelude::*;
use macros::UuidFromParam;
#[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = collections)]
@@ -74,7 +83,7 @@ impl Collection {
if external_id.is_empty() {
self.external_id = None;
} else {
self.external_id = Some(external_id)
self.external_id = Some(external_id);
}
}
None => self.external_id = None,
@@ -147,11 +156,6 @@ impl Collection {
}
}
use crate::db::DbConn;
use crate::api::EmptyResult;
use crate::error::MapResult;
/// Database methods
impl Collection {
pub async fn save(&self, conn: &DbConn) -> EmptyResult {
@@ -193,11 +197,12 @@ impl Collection {
CollectionUser::delete_all_by_collection(&self.uuid, conn).await?;
CollectionGroup::delete_all_by_collection(&self.uuid, &self.org_uuid, conn).await?;
db_run! { conn: {
conn.run(move |conn| {
diesel::delete(collections::table.filter(collections::uuid.eq(self.uuid)))
.execute(conn)
.map_res("Error deleting collection")
}}
})
.await
}
pub async fn delete_all_by_organization(org_uuid: &OrganizationId, conn: &DbConn) -> EmptyResult {
@@ -208,90 +213,90 @@ impl Collection {
}
pub async fn update_users_revision(&self, conn: &DbConn) {
for member in Membership::find_by_collection_and_org(&self.uuid, &self.org_uuid, conn).await.iter() {
for member in &Membership::find_by_collection_and_org(&self.uuid, &self.org_uuid, conn).await {
User::update_uuid_revision(&member.user_uuid, conn).await;
}
}
pub async fn find_by_uuid(uuid: &CollectionId, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
collections::table
.filter(collections::uuid.eq(uuid))
.first::<Self>(conn)
.ok()
}}
conn.run(move |conn| collections::table.filter(collections::uuid.eq(uuid)).first::<Self>(conn).ok()).await
}
pub async fn find_by_user_uuid(user_uuid: UserId, conn: &DbConn) -> Vec<Self> {
if CONFIG.org_groups_enabled() {
db_run! { conn: {
conn.run(move |conn| {
collections::table
.left_join(users_collections::table.on(
users_collections::collection_uuid.eq(collections::uuid).and(
users_collections::user_uuid.eq(user_uuid.clone())
.left_join(
users_collections::table.on(users_collections::collection_uuid
.eq(collections::uuid)
.and(users_collections::user_uuid.eq(user_uuid.clone()))),
)
))
.left_join(users_organizations::table.on(
collections::org_uuid.eq(users_organizations::org_uuid).and(
users_organizations::user_uuid.eq(user_uuid.clone())
.left_join(
users_organizations::table.on(collections::org_uuid
.eq(users_organizations::org_uuid)
.and(users_organizations::user_uuid.eq(user_uuid.clone()))),
)
))
.left_join(groups_users::table.on(
groups_users::users_organizations_uuid.eq(users_organizations::uuid)
))
.left_join(groups::table.on(groups::uuid.eq(groups_users::groups_uuid)
.and(groups::organizations_uuid.eq(users_organizations::org_uuid))
))
.left_join(collections_groups::table.on(
collections_groups::groups_uuid.eq(groups_users::groups_uuid).and(
collections_groups::collections_uuid.eq(collections::uuid)
.left_join(
groups_users::table.on(groups_users::users_organizations_uuid.eq(users_organizations::uuid)),
)
))
.filter(
users_organizations::status.eq(MembershipStatus::Confirmed as i32)
)
.filter(
users_collections::user_uuid.eq(user_uuid).or( // Directly accessed collection
users_organizations::access_all.eq(true) // access_all in Organization
).or(
groups::access_all.eq(true) // access_all in groups
).or( // access via groups
groups_users::users_organizations_uuid.eq(users_organizations::uuid).and(
collections_groups::collections_uuid.is_not_null()
)
.left_join(
groups::table.on(groups::uuid
.eq(groups_users::groups_uuid)
.and(groups::organizations_uuid.eq(users_organizations::org_uuid))),
)
)
.select(collections::all_columns)
.distinct()
.load::<Self>(conn)
.expect("Error loading collections")
}}
.left_join(
collections_groups::table.on(collections_groups::groups_uuid
.eq(groups_users::groups_uuid)
.and(collections_groups::collections_uuid.eq(collections::uuid))),
)
.filter(users_organizations::status.eq(MembershipStatus::Confirmed as i32))
.filter(
users_collections::user_uuid
.eq(user_uuid)
.or(
// Directly accessed collection
users_organizations::access_all.eq(true), // access_all in Organization
)
.or(
groups::access_all.eq(true), // access_all in groups
)
.or(
// access via groups
groups_users::users_organizations_uuid
.eq(users_organizations::uuid)
.and(collections_groups::collections_uuid.is_not_null()),
),
)
.select(collections::all_columns)
.distinct()
.load::<Self>(conn)
.expect("Error loading collections")
})
.await
} else {
db_run! { conn: {
conn.run(move |conn| {
collections::table
.left_join(users_collections::table.on(
users_collections::collection_uuid.eq(collections::uuid).and(
users_collections::user_uuid.eq(user_uuid.clone())
.left_join(
users_collections::table.on(users_collections::collection_uuid
.eq(collections::uuid)
.and(users_collections::user_uuid.eq(user_uuid.clone()))),
)
))
.left_join(users_organizations::table.on(
collections::org_uuid.eq(users_organizations::org_uuid).and(
users_organizations::user_uuid.eq(user_uuid.clone())
.left_join(
users_organizations::table.on(collections::org_uuid
.eq(users_organizations::org_uuid)
.and(users_organizations::user_uuid.eq(user_uuid.clone()))),
)
))
.filter(
users_organizations::status.eq(MembershipStatus::Confirmed as i32)
)
.filter(
users_collections::user_uuid.eq(user_uuid).or( // Directly accessed collection
users_organizations::access_all.eq(true) // access_all in Organization
)
)
.select(collections::all_columns)
.distinct()
.load::<Self>(conn)
.expect("Error loading collections")
}}
.filter(users_organizations::status.eq(MembershipStatus::Confirmed as i32))
.filter(users_collections::user_uuid.eq(user_uuid).or(
// Directly accessed collection
users_organizations::access_all.eq(true), // access_all in Organization
))
.select(collections::all_columns)
.distinct()
.load::<Self>(conn)
.expect("Error loading collections")
})
.await
}
}
@@ -308,256 +313,311 @@ impl Collection {
}
pub async fn find_by_organization(org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
conn.run(move |conn| {
collections::table
.filter(collections::org_uuid.eq(org_uuid))
.load::<Self>(conn)
.expect("Error loading collections")
}}
})
.await
}
pub async fn count_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> i64 {
db_run! { conn: {
collections::table
.filter(collections::org_uuid.eq(org_uuid))
.count()
.first::<i64>(conn)
.ok()
.unwrap_or(0)
}}
conn.run(move |conn| {
collections::table.filter(collections::org_uuid.eq(org_uuid)).count().first::<i64>(conn).ok().unwrap_or(0)
})
.await
}
pub async fn find_by_uuid_and_org(uuid: &CollectionId, org_uuid: &OrganizationId, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
conn.run(move |conn| {
collections::table
.filter(collections::uuid.eq(uuid))
.filter(collections::org_uuid.eq(org_uuid))
.select(collections::all_columns)
.first::<Self>(conn)
.ok()
}}
})
.await
}
pub async fn find_by_uuid_and_user(uuid: &CollectionId, user_uuid: UserId, conn: &DbConn) -> Option<Self> {
if CONFIG.org_groups_enabled() {
db_run! { conn: {
conn.run(move |conn| {
collections::table
.left_join(users_collections::table.on(
users_collections::collection_uuid.eq(collections::uuid).and(
users_collections::user_uuid.eq(user_uuid.clone())
.left_join(
users_collections::table.on(users_collections::collection_uuid
.eq(collections::uuid)
.and(users_collections::user_uuid.eq(user_uuid.clone()))),
)
))
.left_join(users_organizations::table.on(
collections::org_uuid.eq(users_organizations::org_uuid).and(
users_organizations::user_uuid.eq(user_uuid)
.left_join(
users_organizations::table.on(collections::org_uuid
.eq(users_organizations::org_uuid)
.and(users_organizations::user_uuid.eq(user_uuid))),
)
))
.left_join(groups_users::table.on(
groups_users::users_organizations_uuid.eq(users_organizations::uuid)
))
.left_join(groups::table.on(groups::uuid.eq(groups_users::groups_uuid)
.and(groups::organizations_uuid.eq(users_organizations::org_uuid))
))
.left_join(collections_groups::table.on(
collections_groups::groups_uuid.eq(groups_users::groups_uuid).and(
collections_groups::collections_uuid.eq(collections::uuid)
.left_join(
groups_users::table.on(groups_users::users_organizations_uuid.eq(users_organizations::uuid)),
)
))
.filter(collections::uuid.eq(uuid))
.filter(
users_collections::collection_uuid.eq(uuid).or( // Directly accessed collection
users_organizations::access_all.eq(true).or( // access_all in Organization
users_organizations::atype.le(MembershipType::Admin as i32) // Org admin or owner
)).or(
groups::access_all.eq(true) // access_all in groups
).or( // access via groups
groups_users::users_organizations_uuid.eq(users_organizations::uuid).and(
collections_groups::collections_uuid.is_not_null()
)
.left_join(
groups::table.on(groups::uuid
.eq(groups_users::groups_uuid)
.and(groups::organizations_uuid.eq(users_organizations::org_uuid))),
)
).select(collections::all_columns)
.first::<Self>(conn)
.ok()
}}
.left_join(
collections_groups::table.on(collections_groups::groups_uuid
.eq(groups_users::groups_uuid)
.and(collections_groups::collections_uuid.eq(collections::uuid))),
)
.filter(collections::uuid.eq(uuid))
.filter(
users_collections::collection_uuid
.eq(uuid)
.or(
// Directly accessed collection
users_organizations::access_all.eq(true).or(
// access_all in Organization
users_organizations::atype.le(MembershipType::Admin as i32), // Org admin or owner
),
)
.or(
groups::access_all.eq(true), // access_all in groups
)
.or(
// access via groups
groups_users::users_organizations_uuid
.eq(users_organizations::uuid)
.and(collections_groups::collections_uuid.is_not_null()),
),
)
.select(collections::all_columns)
.first::<Self>(conn)
.ok()
})
.await
} else {
db_run! { conn: {
conn.run(move |conn| {
collections::table
.left_join(users_collections::table.on(
users_collections::collection_uuid.eq(collections::uuid).and(
users_collections::user_uuid.eq(user_uuid.clone())
.left_join(
users_collections::table.on(users_collections::collection_uuid
.eq(collections::uuid)
.and(users_collections::user_uuid.eq(user_uuid.clone()))),
)
))
.left_join(users_organizations::table.on(
collections::org_uuid.eq(users_organizations::org_uuid).and(
users_organizations::user_uuid.eq(user_uuid)
.left_join(
users_organizations::table.on(collections::org_uuid
.eq(users_organizations::org_uuid)
.and(users_organizations::user_uuid.eq(user_uuid))),
)
))
.filter(collections::uuid.eq(uuid))
.filter(
users_collections::collection_uuid.eq(uuid).or( // Directly accessed collection
users_organizations::access_all.eq(true).or( // access_all in Organization
users_organizations::atype.le(MembershipType::Admin as i32) // Org admin or owner
.filter(collections::uuid.eq(uuid))
.filter(users_collections::collection_uuid.eq(uuid).or(
// Directly accessed collection
users_organizations::access_all.eq(true).or(
// access_all in Organization
users_organizations::atype.le(MembershipType::Admin as i32), // Org admin or owner
),
))
).select(collections::all_columns)
.first::<Self>(conn)
.ok()
}}
.select(collections::all_columns)
.first::<Self>(conn)
.ok()
})
.await
}
}
pub async fn is_writable_by_user(&self, user_uuid: &UserId, conn: &DbConn) -> bool {
let user_uuid = user_uuid.to_string();
if CONFIG.org_groups_enabled() {
db_run! { conn: {
conn.run(move |conn| {
collections::table
.filter(collections::uuid.eq(&self.uuid))
.inner_join(users_organizations::table.on(
collections::org_uuid.eq(users_organizations::org_uuid)
.and(users_organizations::user_uuid.eq(user_uuid.clone()))
))
.left_join(users_collections::table.on(
users_collections::collection_uuid.eq(collections::uuid)
.and(users_collections::user_uuid.eq(user_uuid))
))
.left_join(groups_users::table.on(
groups_users::users_organizations_uuid.eq(users_organizations::uuid)
))
.left_join(groups::table.on(groups::uuid.eq(groups_users::groups_uuid)
.and(groups::organizations_uuid.eq(users_organizations::org_uuid))
))
.left_join(collections_groups::table.on(
collections_groups::groups_uuid.eq(groups_users::groups_uuid)
.and(collections_groups::collections_uuid.eq(collections::uuid))
))
.filter(users_organizations::atype.le(MembershipType::Admin as i32) // Org admin or owner
.or(users_organizations::access_all.eq(true)) // access_all via membership
.or(users_collections::collection_uuid.eq(&self.uuid) // write access given to collection
.and(users_collections::read_only.eq(false)))
.or(groups::access_all.eq(true)) // access_all via group
.or(collections_groups::collections_uuid.is_not_null() // write access given via group
.and(collections_groups::read_only.eq(false)))
.inner_join(
users_organizations::table.on(collections::org_uuid
.eq(users_organizations::org_uuid)
.and(users_organizations::user_uuid.eq(user_uuid.clone()))),
)
.left_join(
users_collections::table.on(users_collections::collection_uuid
.eq(collections::uuid)
.and(users_collections::user_uuid.eq(user_uuid))),
)
.left_join(
groups_users::table.on(groups_users::users_organizations_uuid.eq(users_organizations::uuid)),
)
.left_join(
groups::table.on(groups::uuid
.eq(groups_users::groups_uuid)
.and(groups::organizations_uuid.eq(users_organizations::org_uuid))),
)
.left_join(
collections_groups::table.on(collections_groups::groups_uuid
.eq(groups_users::groups_uuid)
.and(collections_groups::collections_uuid.eq(collections::uuid))),
)
.filter(
users_organizations::atype
.le(MembershipType::Admin as i32) // Org admin or owner
.or(users_organizations::access_all.eq(true)) // access_all via membership
.or(users_collections::collection_uuid
.eq(&self.uuid) // write access given to collection
.and(users_collections::read_only.eq(false)))
.or(groups::access_all.eq(true)) // access_all via group
.or(collections_groups::collections_uuid
.is_not_null() // write access given via group
.and(collections_groups::read_only.eq(false))),
)
.count()
.first::<i64>(conn)
.ok()
.unwrap_or(0) != 0
}}
.unwrap_or(0)
!= 0
})
.await
} else {
db_run! { conn: {
conn.run(move |conn| {
collections::table
.filter(collections::uuid.eq(&self.uuid))
.inner_join(users_organizations::table.on(
collections::org_uuid.eq(users_organizations::org_uuid)
.and(users_organizations::user_uuid.eq(user_uuid.clone()))
))
.left_join(users_collections::table.on(
users_collections::collection_uuid.eq(collections::uuid)
.and(users_collections::user_uuid.eq(user_uuid))
))
.filter(users_organizations::atype.le(MembershipType::Admin as i32) // Org admin or owner
.or(users_organizations::access_all.eq(true)) // access_all via membership
.or(users_collections::collection_uuid.eq(&self.uuid) // write access given to collection
.and(users_collections::read_only.eq(false)))
.inner_join(
users_organizations::table.on(collections::org_uuid
.eq(users_organizations::org_uuid)
.and(users_organizations::user_uuid.eq(user_uuid.clone()))),
)
.left_join(
users_collections::table.on(users_collections::collection_uuid
.eq(collections::uuid)
.and(users_collections::user_uuid.eq(user_uuid))),
)
.filter(
users_organizations::atype
.le(MembershipType::Admin as i32) // Org admin or owner
.or(users_organizations::access_all.eq(true)) // access_all via membership
.or(users_collections::collection_uuid
.eq(&self.uuid) // write access given to collection
.and(users_collections::read_only.eq(false))),
)
.count()
.first::<i64>(conn)
.ok()
.unwrap_or(0) != 0
}}
.unwrap_or(0)
!= 0
})
.await
}
}
pub async fn hide_passwords_for_user(&self, user_uuid: &UserId, conn: &DbConn) -> bool {
let user_uuid = user_uuid.to_string();
db_run! { conn: {
conn.run(move |conn| {
collections::table
.left_join(users_collections::table.on(
users_collections::collection_uuid.eq(collections::uuid).and(
users_collections::user_uuid.eq(user_uuid.clone())
.left_join(
users_collections::table.on(users_collections::collection_uuid
.eq(collections::uuid)
.and(users_collections::user_uuid.eq(user_uuid.clone()))),
)
))
.left_join(users_organizations::table.on(
collections::org_uuid.eq(users_organizations::org_uuid).and(
users_organizations::user_uuid.eq(user_uuid)
.left_join(
users_organizations::table.on(collections::org_uuid
.eq(users_organizations::org_uuid)
.and(users_organizations::user_uuid.eq(user_uuid))),
)
))
.left_join(groups_users::table.on(
groups_users::users_organizations_uuid.eq(users_organizations::uuid)
))
.left_join(groups::table.on(groups::uuid.eq(groups_users::groups_uuid)
.and(groups::organizations_uuid.eq(users_organizations::org_uuid))
))
.left_join(collections_groups::table.on(
collections_groups::groups_uuid.eq(groups_users::groups_uuid).and(
collections_groups::collections_uuid.eq(collections::uuid)
.left_join(groups_users::table.on(groups_users::users_organizations_uuid.eq(users_organizations::uuid)))
.left_join(
groups::table.on(groups::uuid
.eq(groups_users::groups_uuid)
.and(groups::organizations_uuid.eq(users_organizations::org_uuid))),
)
))
.filter(collections::uuid.eq(&self.uuid))
.filter(
users_collections::collection_uuid.eq(&self.uuid).and(users_collections::hide_passwords.eq(true)).or(// Directly accessed collection
users_organizations::access_all.eq(true).or( // access_all in Organization
users_organizations::atype.le(MembershipType::Admin as i32) // Org admin or owner
)).or(
groups::access_all.eq(true) // access_all in groups
).or( // access via groups
groups_users::users_organizations_uuid.eq(users_organizations::uuid).and(
collections_groups::collections_uuid.is_not_null().and(
collections_groups::hide_passwords.eq(true))
)
.left_join(
collections_groups::table.on(collections_groups::groups_uuid
.eq(groups_users::groups_uuid)
.and(collections_groups::collections_uuid.eq(collections::uuid))),
)
)
.count()
.first::<i64>(conn)
.ok()
.unwrap_or(0) != 0
}}
.filter(collections::uuid.eq(&self.uuid))
.filter(
users_collections::collection_uuid
.eq(&self.uuid)
.and(users_collections::hide_passwords.eq(true))
.or(
// Directly accessed collection
users_organizations::access_all.eq(true).or(
// access_all in Organization
users_organizations::atype.le(MembershipType::Admin as i32), // Org admin or owner
),
)
.or(
groups::access_all.eq(true), // access_all in groups
)
.or(
// access via groups
groups_users::users_organizations_uuid.eq(users_organizations::uuid).and(
collections_groups::collections_uuid
.is_not_null()
.and(collections_groups::hide_passwords.eq(true)),
),
),
)
.count()
.first::<i64>(conn)
.ok()
.unwrap_or(0)
!= 0
})
.await
}
pub async fn is_coll_manageable_by_user(uuid: &CollectionId, user_uuid: &UserId, conn: &DbConn) -> bool {
let uuid = uuid.to_string();
let user_uuid = user_uuid.to_string();
db_run! { conn: {
conn.run(move |conn| {
collections::table
.left_join(users_collections::table.on(
users_collections::collection_uuid.eq(collections::uuid).and(
users_collections::user_uuid.eq(user_uuid.clone())
.left_join(
users_collections::table.on(users_collections::collection_uuid
.eq(collections::uuid)
.and(users_collections::user_uuid.eq(user_uuid.clone()))),
)
))
.left_join(users_organizations::table.on(
collections::org_uuid.eq(users_organizations::org_uuid).and(
users_organizations::user_uuid.eq(user_uuid)
.left_join(
users_organizations::table.on(collections::org_uuid
.eq(users_organizations::org_uuid)
.and(users_organizations::user_uuid.eq(user_uuid))),
)
))
.left_join(groups_users::table.on(
groups_users::users_organizations_uuid.eq(users_organizations::uuid)
))
.left_join(groups::table.on(groups::uuid.eq(groups_users::groups_uuid)
.and(groups::organizations_uuid.eq(users_organizations::org_uuid))
))
.left_join(collections_groups::table.on(
collections_groups::groups_uuid.eq(groups_users::groups_uuid).and(
collections_groups::collections_uuid.eq(collections::uuid)
.left_join(groups_users::table.on(groups_users::users_organizations_uuid.eq(users_organizations::uuid)))
.left_join(
groups::table.on(groups::uuid
.eq(groups_users::groups_uuid)
.and(groups::organizations_uuid.eq(users_organizations::org_uuid))),
)
))
.filter(collections::uuid.eq(&uuid))
.filter(
users_collections::collection_uuid.eq(&uuid).and(users_collections::manage.eq(true)).or(// Directly accessed collection
users_organizations::access_all.eq(true).or( // access_all in Organization
users_organizations::atype.le(MembershipType::Admin as i32) // Org admin or owner
)).or(
groups::access_all.eq(true) // access_all in groups
).or( // access via groups
groups_users::users_organizations_uuid.eq(users_organizations::uuid).and(
collections_groups::collections_uuid.is_not_null().and(
collections_groups::manage.eq(true))
)
.left_join(
collections_groups::table.on(collections_groups::groups_uuid
.eq(groups_users::groups_uuid)
.and(collections_groups::collections_uuid.eq(collections::uuid))),
)
)
.count()
.first::<i64>(conn)
.ok()
.unwrap_or(0) != 0
}}
.filter(collections::uuid.eq(&uuid))
.filter(
users_collections::collection_uuid
.eq(&uuid)
.and(users_collections::manage.eq(true))
.or(
// Directly accessed collection
users_organizations::access_all.eq(true).or(
// access_all in Organization
users_organizations::atype.le(MembershipType::Admin as i32), // Org admin or owner
),
)
.or(
groups::access_all.eq(true), // access_all in groups
)
.or(
// access via groups
groups_users::users_organizations_uuid.eq(users_organizations::uuid).and(
collections_groups::collections_uuid
.is_not_null()
.and(collections_groups::manage.eq(true)),
),
),
)
.count()
.first::<i64>(conn)
.ok()
.unwrap_or(0)
!= 0
})
.await
}
pub async fn is_manageable_by_user(&self, user_uuid: &UserId, conn: &DbConn) -> bool {
@@ -572,7 +632,7 @@ impl CollectionUser {
user_uuid: &UserId,
conn: &DbConn,
) -> Vec<Self> {
db_run! { conn: {
conn.run(move |conn| {
users_collections::table
.filter(users_collections::user_uuid.eq(user_uuid))
.inner_join(collections::table.on(collections::uuid.eq(users_collections::collection_uuid)))
@@ -580,24 +640,35 @@ impl CollectionUser {
.select(users_collections::all_columns)
.load::<Self>(conn)
.expect("Error loading users_collections")
}}
})
.await
}
pub async fn find_by_organization_swap_user_uuid_with_member_uuid(
org_uuid: &OrganizationId,
conn: &DbConn,
) -> Vec<CollectionMembership> {
let col_users = db_run! { conn: {
users_collections::table
.inner_join(collections::table.on(collections::uuid.eq(users_collections::collection_uuid)))
.filter(collections::org_uuid.eq(org_uuid))
.inner_join(users_organizations::table.on(users_organizations::user_uuid.eq(users_collections::user_uuid)))
.filter(users_organizations::org_uuid.eq(org_uuid))
.select((users_organizations::uuid, users_collections::collection_uuid, users_collections::read_only, users_collections::hide_passwords, users_collections::manage))
.load::<Self>(conn)
.expect("Error loading users_collections")
}};
col_users.into_iter().map(|c| c.into()).collect()
let col_users = conn
.run(move |conn| {
users_collections::table
.inner_join(collections::table.on(collections::uuid.eq(users_collections::collection_uuid)))
.filter(collections::org_uuid.eq(org_uuid))
.inner_join(
users_organizations::table.on(users_organizations::user_uuid.eq(users_collections::user_uuid)),
)
.filter(users_organizations::org_uuid.eq(org_uuid))
.select((
users_organizations::uuid,
users_collections::collection_uuid,
users_collections::read_only,
users_collections::hide_passwords,
users_collections::manage,
))
.load::<Self>(conn)
.expect("Error loading users_collections")
})
.await;
col_users.into_iter().map(Into::into).collect()
}
pub async fn save(
@@ -666,7 +737,7 @@ impl CollectionUser {
pub async fn delete(self, conn: &DbConn) -> EmptyResult {
User::update_uuid_revision(&self.user_uuid, conn).await;
db_run! { conn: {
conn.run(move |conn| {
diesel::delete(
users_collections::table
.filter(users_collections::user_uuid.eq(&self.user_uuid))
@@ -674,17 +745,19 @@ impl CollectionUser {
)
.execute(conn)
.map_res("Error removing user from collection")
}}
})
.await
}
pub async fn find_by_collection(collection_uuid: &CollectionId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
conn.run(move |conn| {
users_collections::table
.filter(users_collections::collection_uuid.eq(collection_uuid))
.select(users_collections::all_columns)
.load::<Self>(conn)
.expect("Error loading users_collections")
}}
})
.await
}
pub async fn find_by_org_and_coll_swap_user_uuid_with_member_uuid(
@@ -692,16 +765,26 @@ impl CollectionUser {
collection_uuid: &CollectionId,
conn: &DbConn,
) -> Vec<CollectionMembership> {
let col_users = db_run! { conn: {
users_collections::table
.filter(users_collections::collection_uuid.eq(collection_uuid))
.filter(users_organizations::org_uuid.eq(org_uuid))
.inner_join(users_organizations::table.on(users_organizations::user_uuid.eq(users_collections::user_uuid)))
.select((users_organizations::uuid, users_collections::collection_uuid, users_collections::read_only, users_collections::hide_passwords, users_collections::manage))
.load::<Self>(conn)
.expect("Error loading users_collections")
}};
col_users.into_iter().map(|c| c.into()).collect()
let col_users = conn
.run(move |conn| {
users_collections::table
.filter(users_collections::collection_uuid.eq(collection_uuid))
.filter(users_organizations::org_uuid.eq(org_uuid))
.inner_join(
users_organizations::table.on(users_organizations::user_uuid.eq(users_collections::user_uuid)),
)
.select((
users_organizations::uuid,
users_collections::collection_uuid,
users_collections::read_only,
users_collections::hide_passwords,
users_collections::manage,
))
.load::<Self>(conn)
.expect("Error loading users_collections")
})
.await;
col_users.into_iter().map(Into::into).collect()
}
pub async fn find_by_collection_and_user(
@@ -709,36 +792,39 @@ impl CollectionUser {
user_uuid: &UserId,
conn: &DbConn,
) -> Option<Self> {
db_run! { conn: {
conn.run(move |conn| {
users_collections::table
.filter(users_collections::collection_uuid.eq(collection_uuid))
.filter(users_collections::user_uuid.eq(user_uuid))
.select(users_collections::all_columns)
.first::<Self>(conn)
.ok()
}}
})
.await
}
pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
conn.run(move |conn| {
users_collections::table
.filter(users_collections::user_uuid.eq(user_uuid))
.select(users_collections::all_columns)
.load::<Self>(conn)
.expect("Error loading users_collections")
}}
})
.await
}
pub async fn delete_all_by_collection(collection_uuid: &CollectionId, conn: &DbConn) -> EmptyResult {
for collection in CollectionUser::find_by_collection(collection_uuid, conn).await.iter() {
for collection in &CollectionUser::find_by_collection(collection_uuid, conn).await {
User::update_uuid_revision(&collection.user_uuid, conn).await;
}
db_run! { conn: {
conn.run(move |conn| {
diesel::delete(users_collections::table.filter(users_collections::collection_uuid.eq(collection_uuid)))
.execute(conn)
.map_res("Error deleting users from collection")
}}
})
.await
}
pub async fn delete_all_by_user_and_org(
@@ -748,17 +834,21 @@ impl CollectionUser {
) -> EmptyResult {
let collectionusers = Self::find_by_organization_and_user_uuid(org_uuid, user_uuid, conn).await;
db_run! { conn: {
conn.run(move |conn| {
for user in collectionusers {
let _: () = diesel::delete(users_collections::table.filter(
users_collections::user_uuid.eq(user_uuid)
.and(users_collections::collection_uuid.eq(user.collection_uuid))
))
.execute(conn)
.map_res("Error removing user from collections")?;
let _: () = diesel::delete(
users_collections::table.filter(
users_collections::user_uuid
.eq(user_uuid)
.and(users_collections::collection_uuid.eq(user.collection_uuid)),
),
)
.execute(conn)
.map_res("Error removing user from collections")?;
}
Ok(())
}}
})
.await
}
pub async fn has_access_to_collection_by_user(col_id: &CollectionId, user_uuid: &UserId, conn: &DbConn) -> bool {
@@ -801,7 +891,7 @@ impl CollectionCipher {
pub async fn delete(cipher_uuid: &CipherId, collection_uuid: &CollectionId, conn: &DbConn) -> EmptyResult {
Self::update_users_revision(collection_uuid, conn).await;
db_run! { conn: {
conn.run(move |conn| {
diesel::delete(
ciphers_collections::table
.filter(ciphers_collections::cipher_uuid.eq(cipher_uuid))
@@ -809,23 +899,26 @@ impl CollectionCipher {
)
.execute(conn)
.map_res("Error deleting cipher from collection")
}}
})
.await
}
pub async fn delete_all_by_cipher(cipher_uuid: &CipherId, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
conn.run(move |conn| {
diesel::delete(ciphers_collections::table.filter(ciphers_collections::cipher_uuid.eq(cipher_uuid)))
.execute(conn)
.map_res("Error removing cipher from collections")
}}
})
.await
}
pub async fn delete_all_by_collection(collection_uuid: &CollectionId, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
conn.run(move |conn| {
diesel::delete(ciphers_collections::table.filter(ciphers_collections::collection_uuid.eq(collection_uuid)))
.execute(conn)
.map_res("Error removing ciphers from collection")
}}
})
.await
}
pub async fn update_users_revision(collection_uuid: &CollectionId, conn: &DbConn) {
+43 -45
View File
@@ -1,18 +1,20 @@
use chrono::{NaiveDateTime, Utc};
use data_encoding::BASE64URL;
use derive_more::{Display, From};
use diesel::prelude::*;
use serde_json::Value;
use super::{AuthRequest, UserId};
use crate::db::schema::devices;
use crate::{
api::EmptyResult,
crypto,
db::{DbConn, schema::devices},
error::MapResult,
util::{format_date, get_uuid},
};
use diesel::prelude::*;
use macros::{IdFromParam, UuidFromParam};
use super::{AuthRequest, UserId};
#[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = devices)]
#[diesel(treat_none_as_null = true)]
@@ -25,7 +27,7 @@ pub struct Device {
pub user_uuid: UserId,
pub name: String,
pub atype: i32, // https://github.com/bitwarden/server/blob/9ebe16587175b1c0e9208f84397bb75d0d595510/src/Core/Enums/DeviceType.cs
pub atype: i32, // https://github.com/bitwarden/server/blob/8d547dcc280babab70dd4a3c94ced6a34b12dfbf/src/Core/Enums/DeviceType.cs
pub push_uuid: Option<PushId>,
pub push_token: Option<String>,
@@ -135,10 +137,6 @@ impl DeviceWithAuthRequest {
}
}
}
use crate::db::DbConn;
use crate::api::EmptyResult;
use crate::error::MapResult;
/// Database methods
impl Device {
@@ -171,21 +169,23 @@ impl Device {
}
pub async fn delete_all_by_user(user_uuid: &UserId, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
conn.run(move |conn| {
diesel::delete(devices::table.filter(devices::user_uuid.eq(user_uuid)))
.execute(conn)
.map_res("Error removing devices for user")
}}
})
.await
}
pub async fn find_by_uuid_and_user(uuid: &DeviceId, user_uuid: &UserId, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
conn.run(move |conn| {
devices::table
.filter(devices::uuid.eq(uuid))
.filter(devices::user_uuid.eq(user_uuid))
.first::<Self>(conn)
.ok()
}}
})
.await
}
pub async fn find_with_auth_request_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<DeviceWithAuthRequest> {
@@ -199,71 +199,65 @@ impl Device {
}
pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
devices::table
.filter(devices::user_uuid.eq(user_uuid))
.load::<Self>(conn)
.expect("Error loading devices")
}}
conn.run(move |conn| {
devices::table.filter(devices::user_uuid.eq(user_uuid)).load::<Self>(conn).expect("Error loading devices")
})
.await
}
pub async fn find_by_uuid(uuid: &DeviceId, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
devices::table
.filter(devices::uuid.eq(uuid))
.first::<Self>(conn)
.ok()
}}
conn.run(move |conn| devices::table.filter(devices::uuid.eq(uuid)).first::<Self>(conn).ok()).await
}
pub async fn clear_push_token_by_uuid(uuid: &DeviceId, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
conn.run(move |conn| {
diesel::update(devices::table)
.filter(devices::uuid.eq(uuid))
.set(devices::push_token.eq::<Option<String>>(None))
.execute(conn)
.map_res("Error removing push token")
}}
})
.await
}
pub async fn find_by_refresh_token(refresh_token: &str, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
devices::table
.filter(devices::refresh_token.eq(refresh_token))
.first::<Self>(conn)
.ok()
}}
conn.run(move |conn| devices::table.filter(devices::refresh_token.eq(refresh_token)).first::<Self>(conn).ok())
.await
}
pub async fn find_latest_active_by_user(user_uuid: &UserId, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
conn.run(move |conn| {
devices::table
.filter(devices::user_uuid.eq(user_uuid))
.order(devices::updated_at.desc())
.first::<Self>(conn)
.ok()
}}
})
.await
}
pub async fn find_push_devices_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
conn.run(move |conn| {
devices::table
.filter(devices::user_uuid.eq(user_uuid))
.filter(devices::push_token.is_not_null())
.load::<Self>(conn)
.expect("Error loading push devices")
}}
})
.await
}
pub async fn check_user_has_push_device(user_uuid: &UserId, conn: &DbConn) -> bool {
db_run! { conn: {
conn.run(move |conn| {
devices::table
.filter(devices::user_uuid.eq(user_uuid))
.filter(devices::push_token.is_not_null())
.count()
.first::<i64>(conn)
.ok()
.unwrap_or(0) != 0
}}
.filter(devices::user_uuid.eq(user_uuid))
.filter(devices::push_token.is_not_null())
.count()
.first::<i64>(conn)
.ok()
.unwrap_or(0)
!= 0
})
.await
}
pub async fn rotate_refresh_tokens_by_user(user_uuid: &UserId, conn: &DbConn) -> EmptyResult {
@@ -332,9 +326,12 @@ pub enum DeviceType {
MacOsCLI = 24,
#[display("Linux CLI")]
LinuxCLI = 25,
#[display("DuckDuckGo")]
DuckDuckGoBrowser = 26,
}
impl DeviceType {
#[expect(clippy::match_same_arms, reason = "Specifically define 14 and have a fallback for new types")]
pub fn from_i32(value: i32) -> DeviceType {
match value {
0 => DeviceType::Android,
@@ -363,6 +360,7 @@ impl DeviceType {
23 => DeviceType::WindowsCLI,
24 => DeviceType::MacOsCLI,
25 => DeviceType::LinuxCLI,
26 => DeviceType::DuckDuckGoBrowser,
_ => DeviceType::UnknownBrowser,
}
}
+71 -50
View File
@@ -1,13 +1,17 @@
use chrono::{NaiveDateTime, Utc};
use derive_more::{AsRef, Deref, Display, From};
use diesel::prelude::*;
use serde_json::Value;
use super::{User, UserId};
use crate::db::schema::emergency_access;
use crate::{api::EmptyResult, db::DbConn, error::MapResult};
use diesel::prelude::*;
use crate::{
api::EmptyResult,
db::{DbConn, schema::emergency_access},
error::MapResult,
};
use macros::UuidFromParam;
use super::{User, UserId};
#[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = emergency_access)]
#[diesel(treat_none_as_null = true)]
@@ -85,17 +89,15 @@ impl EmergencyAccess {
pub async fn to_json_grantee_details(&self, conn: &DbConn) -> Option<Value> {
let grantee_user = if let Some(grantee_uuid) = &self.grantee_uuid {
User::find_by_uuid(grantee_uuid, conn).await.expect("Grantee user not found.")
} else if let Some(email) = self.email.as_deref() {
match User::find_by_mail(email, conn).await {
Some(user) => user,
None => {
// remove outstanding invitations which should not exist
Self::delete_all_by_grantee_email(email, conn).await.ok();
return None;
}
}
} else {
return None;
let email = self.email.as_deref()?;
if let Some(user) = User::find_by_mail(email, conn).await {
user
} else {
// remove outstanding invitations which should not exist
Self::delete_all_by_grantee_email(email, conn).await.ok();
return None;
}
};
Some(json!({
@@ -184,28 +186,36 @@ impl EmergencyAccess {
self.status = status;
date.clone_into(&mut self.updated_at);
db_run! { conn: {
crate::util::retry(|| {
diesel::update(emergency_access::table.filter(emergency_access::uuid.eq(&self.uuid)))
.set((emergency_access::status.eq(status), emergency_access::updated_at.eq(date)))
.execute(conn)
}, 10)
conn.run(move |conn| {
crate::util::retry(
|| {
diesel::update(emergency_access::table.filter(emergency_access::uuid.eq(&self.uuid)))
.set((emergency_access::status.eq(status), emergency_access::updated_at.eq(date)))
.execute(conn)
},
10,
)
.map_res("Error updating emergency access status")
}}
})
.await
}
pub async fn update_last_notification_date_and_save(&mut self, date: &NaiveDateTime, conn: &DbConn) -> EmptyResult {
self.last_notification_at = Some(date.to_owned());
date.clone_into(&mut self.updated_at);
db_run! { conn: {
crate::util::retry(|| {
diesel::update(emergency_access::table.filter(emergency_access::uuid.eq(&self.uuid)))
.set((emergency_access::last_notification_at.eq(date), emergency_access::updated_at.eq(date)))
.execute(conn)
}, 10)
conn.run(move |conn| {
crate::util::retry(
|| {
diesel::update(emergency_access::table.filter(emergency_access::uuid.eq(&self.uuid)))
.set((emergency_access::last_notification_at.eq(date), emergency_access::updated_at.eq(date)))
.execute(conn)
},
10,
)
.map_res("Error updating emergency access status")
}}
})
.await
}
pub async fn delete_all_by_user(user_uuid: &UserId, conn: &DbConn) -> EmptyResult {
@@ -228,11 +238,12 @@ impl EmergencyAccess {
pub async fn delete(self, conn: &DbConn) -> EmptyResult {
User::update_uuid_revision(&self.grantor_uuid, conn).await;
db_run! { conn: {
conn.run(move |conn| {
diesel::delete(emergency_access::table.filter(emergency_access::uuid.eq(self.uuid)))
.execute(conn)
.map_res("Error removing user from emergency access")
}}
})
.await
}
pub async fn find_by_grantor_uuid_and_grantee_uuid_or_email(
@@ -241,23 +252,25 @@ impl EmergencyAccess {
email: &str,
conn: &DbConn,
) -> Option<Self> {
db_run! { conn: {
conn.run(move |conn| {
emergency_access::table
.filter(emergency_access::grantor_uuid.eq(grantor_uuid))
.filter(emergency_access::grantee_uuid.eq(grantee_uuid).or(emergency_access::email.eq(email)))
.first::<Self>(conn)
.ok()
}}
})
.await
}
pub async fn find_all_recoveries_initiated(conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
conn.run(move |conn| {
emergency_access::table
.filter(emergency_access::status.eq(EmergencyAccessStatus::RecoveryInitiated as i32))
.filter(emergency_access::recovery_initiated_at.is_not_null())
.load::<Self>(conn)
.expect("Error loading emergency_access")
}}
})
.await
}
pub async fn find_by_uuid_and_grantor_uuid(
@@ -265,13 +278,14 @@ impl EmergencyAccess {
grantor_uuid: &UserId,
conn: &DbConn,
) -> Option<Self> {
db_run! { conn: {
conn.run(move |conn| {
emergency_access::table
.filter(emergency_access::uuid.eq(uuid))
.filter(emergency_access::grantor_uuid.eq(grantor_uuid))
.first::<Self>(conn)
.ok()
}}
})
.await
}
pub async fn find_by_uuid_and_grantee_uuid(
@@ -279,13 +293,14 @@ impl EmergencyAccess {
grantee_uuid: &UserId,
conn: &DbConn,
) -> Option<Self> {
db_run! { conn: {
conn.run(move |conn| {
emergency_access::table
.filter(emergency_access::uuid.eq(uuid))
.filter(emergency_access::grantee_uuid.eq(grantee_uuid))
.first::<Self>(conn)
.ok()
}}
})
.await
}
pub async fn find_by_uuid_and_grantee_email(
@@ -293,61 +308,67 @@ impl EmergencyAccess {
grantee_email: &str,
conn: &DbConn,
) -> Option<Self> {
db_run! { conn: {
conn.run(move |conn| {
emergency_access::table
.filter(emergency_access::uuid.eq(uuid))
.filter(emergency_access::email.eq(grantee_email))
.first::<Self>(conn)
.ok()
}}
})
.await
}
pub async fn find_all_by_grantee_uuid(grantee_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
conn.run(move |conn| {
emergency_access::table
.filter(emergency_access::grantee_uuid.eq(grantee_uuid))
.load::<Self>(conn)
.expect("Error loading emergency_access")
}}
})
.await
}
pub async fn find_invited_by_grantee_email(grantee_email: &str, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
conn.run(move |conn| {
emergency_access::table
.filter(emergency_access::email.eq(grantee_email))
.filter(emergency_access::status.eq(EmergencyAccessStatus::Invited as i32))
.first::<Self>(conn)
.ok()
}}
})
.await
}
pub async fn find_all_invited_by_grantee_email(grantee_email: &str, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
conn.run(move |conn| {
emergency_access::table
.filter(emergency_access::email.eq(grantee_email))
.filter(emergency_access::status.eq(EmergencyAccessStatus::Invited as i32))
.load::<Self>(conn)
.expect("Error loading emergency_access")
}}
})
.await
}
pub async fn find_all_by_grantor_uuid(grantor_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
conn.run(move |conn| {
emergency_access::table
.filter(emergency_access::grantor_uuid.eq(grantor_uuid))
.load::<Self>(conn)
.expect("Error loading emergency_access")
}}
})
.await
}
pub async fn find_all_confirmed_by_grantor_uuid(grantor_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
conn.run(move |conn| {
emergency_access::table
.filter(emergency_access::grantor_uuid.eq(grantor_uuid))
.filter(emergency_access::status.ge(EmergencyAccessStatus::Confirmed as i32))
.load::<Self>(conn)
.expect("Error loading emergency_access")
}}
})
.await
}
pub async fn accept_invite(&mut self, grantee_uuid: &UserId, grantee_email: &str, conn: &DbConn) -> EmptyResult {
+38 -28
View File
@@ -1,11 +1,18 @@
use chrono::{NaiveDateTime, TimeDelta, Utc};
//use derive_more::{AsRef, Deref, Display, From};
use diesel::prelude::*;
use serde_json::Value;
use crate::{
CONFIG,
api::EmptyResult,
db::{
DbConn,
schema::{event, users_organizations},
},
error::MapResult,
};
use super::{CipherId, CollectionId, GroupId, MembershipId, OrgPolicyId, OrganizationId, UserId};
use crate::db::schema::{event, users_organizations};
use crate::{api::EmptyResult, db::DbConn, error::MapResult, CONFIG};
use diesel::prelude::*;
// https://bitwarden.com/help/event-logs/
@@ -249,11 +256,10 @@ impl Event {
}
pub async fn delete(self, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
diesel::delete(event::table.filter(event::uuid.eq(self.uuid)))
.execute(conn)
.map_res("Error deleting event")
}}
conn.run(move |conn| {
diesel::delete(event::table.filter(event::uuid.eq(self.uuid))).execute(conn).map_res("Error deleting event")
})
.await
}
/// ##############
@@ -264,7 +270,7 @@ impl Event {
end: &NaiveDateTime,
conn: &DbConn,
) -> Vec<Self> {
db_run! { conn: {
conn.run(move |conn| {
event::table
.filter(event::org_uuid.eq(org_uuid))
.filter(event::event_date.between(start, end))
@@ -272,18 +278,15 @@ impl Event {
.limit(Self::PAGE_SIZE)
.load::<Self>(conn)
.expect("Error filtering events")
}}
})
.await
}
pub async fn count_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> i64 {
db_run! { conn: {
event::table
.filter(event::org_uuid.eq(org_uuid))
.count()
.first::<i64>(conn)
.ok()
.unwrap_or(0)
}}
conn.run(move |conn| {
event::table.filter(event::org_uuid.eq(org_uuid)).count().first::<i64>(conn).ok().unwrap_or(0)
})
.await
}
pub async fn find_by_org_and_member(
@@ -293,18 +296,23 @@ impl Event {
end: &NaiveDateTime,
conn: &DbConn,
) -> Vec<Self> {
db_run! { conn: {
conn.run(move |conn| {
event::table
.inner_join(users_organizations::table.on(users_organizations::uuid.eq(member_uuid)))
.filter(event::org_uuid.eq(org_uuid))
.filter(event::event_date.between(start, end))
.filter(event::user_uuid.eq(users_organizations::user_uuid.nullable()).or(event::act_user_uuid.eq(users_organizations::user_uuid.nullable())))
.filter(
event::user_uuid
.eq(users_organizations::user_uuid.nullable())
.or(event::act_user_uuid.eq(users_organizations::user_uuid.nullable())),
)
.select(event::all_columns)
.order_by(event::event_date.desc())
.limit(Self::PAGE_SIZE)
.load::<Self>(conn)
.expect("Error filtering events")
}}
})
.await
}
pub async fn find_by_cipher_uuid(
@@ -313,7 +321,7 @@ impl Event {
end: &NaiveDateTime,
conn: &DbConn,
) -> Vec<Self> {
db_run! { conn: {
conn.run(move |conn| {
event::table
.filter(event::cipher_uuid.eq(cipher_uuid))
.filter(event::event_date.between(start, end))
@@ -321,17 +329,19 @@ impl Event {
.limit(Self::PAGE_SIZE)
.load::<Self>(conn)
.expect("Error filtering events")
}}
})
.await
}
pub async fn clean_events(conn: &DbConn) -> EmptyResult {
if let Some(days_to_retain) = CONFIG.events_days_retain() {
let dt = Utc::now().naive_utc() - TimeDelta::try_days(days_to_retain).unwrap();
db_run! { conn: {
conn.run(move |conn| {
diesel::delete(event::table.filter(event::event_date.lt(dt)))
.execute(conn)
.map_res("Error cleaning old events")
}}
.execute(conn)
.map_res("Error cleaning old events")
})
.await
} else {
Ok(())
}
+32 -30
View File
@@ -1,7 +1,13 @@
use super::{CipherId, User, UserId};
use crate::db::schema::favorites;
use diesel::prelude::*;
use crate::{
api::EmptyResult,
db::{DbConn, schema::favorites},
error::MapResult,
};
use super::{CipherId, User, UserId};
#[derive(Identifiable, Queryable, Insertable)]
#[diesel(table_name = favorites)]
#[diesel(primary_key(user_uuid, cipher_uuid))]
@@ -10,24 +16,18 @@ pub struct Favorite {
pub cipher_uuid: CipherId,
}
use crate::db::DbConn;
use crate::api::EmptyResult;
use crate::error::MapResult;
impl Favorite {
// Returns whether the specified cipher is a favorite of the specified user.
pub async fn is_favorite(cipher_uuid: &CipherId, user_uuid: &UserId, conn: &DbConn) -> bool {
db_run! { conn: {
conn.run(move |conn| {
let query = favorites::table
.filter(favorites::cipher_uuid.eq(cipher_uuid))
.filter(favorites::user_uuid.eq(user_uuid))
.count();
query.first::<i64>(conn)
.ok()
.unwrap_or(0) != 0
}}
query.first::<i64>(conn).ok().unwrap_or(0) != 0
})
.await
}
// Sets whether the specified cipher is a favorite of the specified user.
@@ -41,27 +41,26 @@ impl Favorite {
match (old, new) {
(false, true) => {
User::update_uuid_revision(user_uuid, conn).await;
db_run! { conn: {
diesel::insert_into(favorites::table)
.values((
favorites::user_uuid.eq(user_uuid),
favorites::cipher_uuid.eq(cipher_uuid),
))
.execute(conn)
.map_res("Error adding favorite")
}}
conn.run(move |conn| {
diesel::insert_into(favorites::table)
.values((favorites::user_uuid.eq(user_uuid), favorites::cipher_uuid.eq(cipher_uuid)))
.execute(conn)
.map_res("Error adding favorite")
})
.await
}
(true, false) => {
User::update_uuid_revision(user_uuid, conn).await;
db_run! { conn: {
conn.run(move |conn| {
diesel::delete(
favorites::table
.filter(favorites::user_uuid.eq(user_uuid))
.filter(favorites::cipher_uuid.eq(cipher_uuid))
.filter(favorites::cipher_uuid.eq(cipher_uuid)),
)
.execute(conn)
.map_res("Error removing favorite")
}}
})
.await
}
// Otherwise, the favorite status is already what it should be.
_ => Ok(()),
@@ -70,31 +69,34 @@ impl Favorite {
// Delete all favorite entries associated with the specified cipher.
pub async fn delete_all_by_cipher(cipher_uuid: &CipherId, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
conn.run(move |conn| {
diesel::delete(favorites::table.filter(favorites::cipher_uuid.eq(cipher_uuid)))
.execute(conn)
.map_res("Error removing favorites by cipher")
}}
})
.await
}
// Delete all favorite entries associated with the specified user.
pub async fn delete_all_by_user(user_uuid: &UserId, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
conn.run(move |conn| {
diesel::delete(favorites::table.filter(favorites::user_uuid.eq(user_uuid)))
.execute(conn)
.map_res("Error removing favorites by user")
}}
})
.await
}
/// Return a vec with (cipher_uuid) this will only contain favorite flagged ciphers
/// This is used during a full sync so we only need one query for all favorite cipher matches.
pub async fn get_all_cipher_uuid_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<CipherId> {
db_run! { conn: {
conn.run(move |conn| {
favorites::table
.filter(favorites::user_uuid.eq(user_uuid))
.select(favorites::cipher_uuid)
.load::<CipherId>(conn)
.unwrap_or_default()
}}
})
.await
}
}
+40 -31
View File
@@ -1,12 +1,20 @@
use chrono::{NaiveDateTime, Utc};
use derive_more::{AsRef, Deref, Display, From};
use diesel::prelude::*;
use serde_json::Value;
use super::{CipherId, User, UserId};
use crate::db::schema::{folders, folders_ciphers};
use diesel::prelude::*;
use crate::{
api::EmptyResult,
db::{
DbConn,
schema::{folders, folders_ciphers},
},
error::MapResult,
};
use macros::UuidFromParam;
use super::{CipherId, User, UserId};
#[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = folders)]
#[diesel(primary_key(uuid))]
@@ -56,17 +64,12 @@ impl Folder {
impl FolderCipher {
pub fn new(folder_uuid: FolderId, cipher_uuid: CipherId) -> Self {
Self {
folder_uuid,
cipher_uuid,
folder_uuid,
}
}
}
use crate::db::DbConn;
use crate::api::EmptyResult;
use crate::error::MapResult;
/// Database methods
impl Folder {
pub async fn save(&mut self, conn: &DbConn) -> EmptyResult {
@@ -107,11 +110,12 @@ impl Folder {
User::update_uuid_revision(&self.user_uuid, conn).await;
FolderCipher::delete_all_by_folder(&self.uuid, conn).await?;
db_run! { conn: {
conn.run(move |conn| {
diesel::delete(folders::table.filter(folders::uuid.eq(&self.uuid)))
.execute(conn)
.map_res("Error deleting folder")
}}
})
.await
}
pub async fn delete_all_by_user(user_uuid: &UserId, conn: &DbConn) -> EmptyResult {
@@ -122,22 +126,21 @@ impl Folder {
}
pub async fn find_by_uuid_and_user(uuid: &FolderId, user_uuid: &UserId, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
conn.run(move |conn| {
folders::table
.filter(folders::uuid.eq(uuid))
.filter(folders::user_uuid.eq(user_uuid))
.first::<Self>(conn)
.ok()
}}
})
.await
}
pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
folders::table
.filter(folders::user_uuid.eq(user_uuid))
.load::<Self>(conn)
.expect("Error loading folders")
}}
conn.run(move |conn| {
folders::table.filter(folders::user_uuid.eq(user_uuid)).load::<Self>(conn).expect("Error loading folders")
})
.await
}
}
@@ -165,7 +168,7 @@ impl FolderCipher {
}
pub async fn delete(self, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
conn.run(move |conn| {
diesel::delete(
folders_ciphers::table
.filter(folders_ciphers::cipher_uuid.eq(self.cipher_uuid))
@@ -173,23 +176,26 @@ impl FolderCipher {
)
.execute(conn)
.map_res("Error removing cipher from folder")
}}
})
.await
}
pub async fn delete_all_by_cipher(cipher_uuid: &CipherId, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
conn.run(move |conn| {
diesel::delete(folders_ciphers::table.filter(folders_ciphers::cipher_uuid.eq(cipher_uuid)))
.execute(conn)
.map_res("Error removing cipher from folders")
}}
})
.await
}
pub async fn delete_all_by_folder(folder_uuid: &FolderId, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
conn.run(move |conn| {
diesel::delete(folders_ciphers::table.filter(folders_ciphers::folder_uuid.eq(folder_uuid)))
.execute(conn)
.map_res("Error removing ciphers from folder")
}}
})
.await
}
pub async fn find_by_folder_and_cipher(
@@ -197,35 +203,38 @@ impl FolderCipher {
cipher_uuid: &CipherId,
conn: &DbConn,
) -> Option<Self> {
db_run! { conn: {
conn.run(move |conn| {
folders_ciphers::table
.filter(folders_ciphers::folder_uuid.eq(folder_uuid))
.filter(folders_ciphers::cipher_uuid.eq(cipher_uuid))
.first::<Self>(conn)
.ok()
}}
})
.await
}
pub async fn find_by_folder(folder_uuid: &FolderId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
conn.run(move |conn| {
folders_ciphers::table
.filter(folders_ciphers::folder_uuid.eq(folder_uuid))
.load::<Self>(conn)
.expect("Error loading folders")
}}
})
.await
}
/// Return a vec with (cipher_uuid, folder_uuid)
/// This is used during a full sync so we only need one query for all folder matches.
pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<(CipherId, FolderId)> {
db_run! { conn: {
conn.run(move |conn| {
folders_ciphers::table
.inner_join(folders::table)
.filter(folders::user_uuid.eq(user_uuid))
.select(folders_ciphers::all_columns)
.load::<(CipherId, FolderId)>(conn)
.unwrap_or_default()
}}
})
.await
}
}
+143 -122
View File
@@ -1,14 +1,20 @@
use super::{CollectionId, Membership, MembershipId, OrganizationId, User, UserId};
use crate::api::EmptyResult;
use crate::db::schema::{collections, collections_groups, groups, groups_users, users_organizations};
use crate::db::DbConn;
use crate::error::MapResult;
use chrono::{NaiveDateTime, Utc};
use derive_more::{AsRef, Deref, Display, From};
use diesel::prelude::*;
use macros::UuidFromParam;
use serde_json::Value;
use crate::{
api::EmptyResult,
db::{
DbConn,
schema::{collections, collections_groups, groups, groups_users, users_organizations},
},
error::MapResult,
};
use macros::UuidFromParam;
use super::{CollectionId, Membership, MembershipId, OrganizationId, User, UserId};
#[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = groups)]
#[diesel(treat_none_as_null = true)]
@@ -197,33 +203,31 @@ impl Group {
}
pub async fn find_by_organization(org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
conn.run(move |conn| {
groups::table
.filter(groups::organizations_uuid.eq(org_uuid))
.load::<Self>(conn)
.expect("Error loading groups")
}}
})
.await
}
pub async fn count_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> i64 {
db_run! { conn: {
groups::table
.filter(groups::organizations_uuid.eq(org_uuid))
.count()
.first::<i64>(conn)
.ok()
.unwrap_or(0)
}}
conn.run(move |conn| {
groups::table.filter(groups::organizations_uuid.eq(org_uuid)).count().first::<i64>(conn).ok().unwrap_or(0)
})
.await
}
pub async fn find_by_uuid_and_org(uuid: &GroupId, org_uuid: &OrganizationId, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
conn.run(move |conn| {
groups::table
.filter(groups::uuid.eq(uuid))
.filter(groups::organizations_uuid.eq(org_uuid))
.first::<Self>(conn)
.ok()
}}
})
.await
}
pub async fn find_by_external_id_and_org(
@@ -231,77 +235,85 @@ impl Group {
org_uuid: &OrganizationId,
conn: &DbConn,
) -> Option<Self> {
db_run! { conn: {
conn.run(move |conn| {
groups::table
.filter(groups::external_id.eq(external_id))
.filter(groups::organizations_uuid.eq(org_uuid))
.first::<Self>(conn)
.ok()
}}
})
.await
}
//Returns all organizations the user has full access to
pub async fn get_orgs_by_user_with_full_access(user_uuid: &UserId, conn: &DbConn) -> Vec<OrganizationId> {
db_run! { conn: {
conn.run(move |conn| {
groups_users::table
.inner_join(users_organizations::table.on(
users_organizations::uuid.eq(groups_users::users_organizations_uuid)
))
.inner_join(groups::table.on(groups::uuid.eq(groups_users::groups_uuid)
.and(groups::organizations_uuid.eq(users_organizations::org_uuid))
))
.inner_join(
users_organizations::table.on(users_organizations::uuid.eq(groups_users::users_organizations_uuid)),
)
.inner_join(
groups::table.on(groups::uuid
.eq(groups_users::groups_uuid)
.and(groups::organizations_uuid.eq(users_organizations::org_uuid))),
)
.filter(users_organizations::user_uuid.eq(user_uuid))
.filter(groups::access_all.eq(true))
.select(groups::organizations_uuid)
.distinct()
.load::<OrganizationId>(conn)
.expect("Error loading organization group full access information for user")
}}
})
.await
}
pub async fn is_in_full_access_group(user_uuid: &UserId, org_uuid: &OrganizationId, conn: &DbConn) -> bool {
db_run! { conn: {
conn.run(move |conn| {
groups::table
.inner_join(groups_users::table.on(
groups_users::groups_uuid.eq(groups::uuid)
))
.inner_join(users_organizations::table.on(
users_organizations::uuid.eq(groups_users::users_organizations_uuid)
))
.inner_join(groups_users::table.on(groups_users::groups_uuid.eq(groups::uuid)))
.inner_join(
users_organizations::table.on(users_organizations::uuid.eq(groups_users::users_organizations_uuid)),
)
.filter(users_organizations::user_uuid.eq(user_uuid))
.filter(groups::organizations_uuid.eq(org_uuid))
.filter(groups::access_all.eq(true))
.select(groups::access_all)
.first::<bool>(conn)
.unwrap_or_default()
}}
})
.await
}
pub async fn delete(&self, org_uuid: &OrganizationId, conn: &DbConn) -> EmptyResult {
CollectionGroup::delete_all_by_group(&self.uuid, org_uuid, conn).await?;
GroupUser::delete_all_by_group(&self.uuid, org_uuid, conn).await?;
db_run! { conn: {
conn.run(move |conn| {
diesel::delete(groups::table.filter(groups::uuid.eq(&self.uuid)))
.execute(conn)
.map_res("Error deleting group")
}}
})
.await
}
pub async fn update_revision(uuid: &GroupId, conn: &DbConn) {
if let Err(e) = Self::_update_revision(uuid, &Utc::now().naive_utc(), conn).await {
if let Err(e) = Self::update_revision_impl(uuid, &Utc::now().naive_utc(), conn).await {
warn!("Failed to update revision for {uuid}: {e:#?}");
}
}
async fn _update_revision(uuid: &GroupId, date: &NaiveDateTime, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
crate::util::retry(|| {
diesel::update(groups::table.filter(groups::uuid.eq(uuid)))
.set(groups::revision_date.eq(date))
.execute(conn)
}, 10)
async fn update_revision_impl(uuid: &GroupId, date: &NaiveDateTime, conn: &DbConn) -> EmptyResult {
conn.run(move |conn| {
crate::util::retry(
|| {
diesel::update(groups::table.filter(groups::uuid.eq(uuid)))
.set(groups::revision_date.eq(date))
.execute(conn)
},
10,
)
.map_res("Error updating group revision")
}}
})
.await
}
}
@@ -366,60 +378,63 @@ impl CollectionGroup {
}
pub async fn find_by_group(group_uuid: &GroupId, org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
conn.run(move |conn| {
collections_groups::table
.inner_join(groups::table.on(
groups::uuid.eq(collections_groups::groups_uuid)
))
.inner_join(collections::table.on(
collections::uuid.eq(collections_groups::collections_uuid)
.and(collections::org_uuid.eq(groups::organizations_uuid))
))
.inner_join(groups::table.on(groups::uuid.eq(collections_groups::groups_uuid)))
.inner_join(
collections::table.on(collections::uuid
.eq(collections_groups::collections_uuid)
.and(collections::org_uuid.eq(groups::organizations_uuid))),
)
.filter(collections_groups::groups_uuid.eq(group_uuid))
.filter(collections::org_uuid.eq(org_uuid))
.select(collections_groups::all_columns)
.load::<Self>(conn)
.expect("Error loading collection groups")
}}
})
.await
}
pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
conn.run(move |conn| {
collections_groups::table
.inner_join(groups_users::table.on(
groups_users::groups_uuid.eq(collections_groups::groups_uuid)
))
.inner_join(users_organizations::table.on(
users_organizations::uuid.eq(groups_users::users_organizations_uuid)
))
.inner_join(groups::table.on(groups::uuid.eq(collections_groups::groups_uuid)
.and(groups::organizations_uuid.eq(users_organizations::org_uuid))
))
.inner_join(collections::table.on(
collections::uuid.eq(collections_groups::collections_uuid)
.and(collections::org_uuid.eq(groups::organizations_uuid))
))
.inner_join(groups_users::table.on(groups_users::groups_uuid.eq(collections_groups::groups_uuid)))
.inner_join(
users_organizations::table.on(users_organizations::uuid.eq(groups_users::users_organizations_uuid)),
)
.inner_join(
groups::table.on(groups::uuid
.eq(collections_groups::groups_uuid)
.and(groups::organizations_uuid.eq(users_organizations::org_uuid))),
)
.inner_join(
collections::table.on(collections::uuid
.eq(collections_groups::collections_uuid)
.and(collections::org_uuid.eq(groups::organizations_uuid))),
)
.filter(users_organizations::user_uuid.eq(user_uuid))
.select(collections_groups::all_columns)
.load::<Self>(conn)
.expect("Error loading user collection groups")
}}
})
.await
}
pub async fn find_by_collection(collection_uuid: &CollectionId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
conn.run(move |conn| {
collections_groups::table
.filter(collections_groups::collections_uuid.eq(collection_uuid))
.inner_join(collections::table.on(
collections::uuid.eq(collections_groups::collections_uuid)
))
.inner_join(groups::table.on(groups::uuid.eq(collections_groups::groups_uuid)
.and(groups::organizations_uuid.eq(collections::org_uuid))
))
.inner_join(collections::table.on(collections::uuid.eq(collections_groups::collections_uuid)))
.inner_join(
groups::table.on(groups::uuid
.eq(collections_groups::groups_uuid)
.and(groups::organizations_uuid.eq(collections::org_uuid))),
)
.select(collections_groups::all_columns)
.load::<Self>(conn)
.expect("Error loading collection groups")
}}
})
.await
}
pub async fn delete(&self, org_uuid: &OrganizationId, conn: &DbConn) -> EmptyResult {
@@ -428,13 +443,14 @@ impl CollectionGroup {
group_user.update_user_revision(conn).await;
}
db_run! { conn: {
conn.run(move |conn| {
diesel::delete(collections_groups::table)
.filter(collections_groups::collections_uuid.eq(&self.collections_uuid))
.filter(collections_groups::groups_uuid.eq(&self.groups_uuid))
.execute(conn)
.map_res("Error deleting collection group")
}}
})
.await
}
pub async fn delete_all_by_group(group_uuid: &GroupId, org_uuid: &OrganizationId, conn: &DbConn) -> EmptyResult {
@@ -443,12 +459,13 @@ impl CollectionGroup {
group_user.update_user_revision(conn).await;
}
db_run! { conn: {
conn.run(move |conn| {
diesel::delete(collections_groups::table)
.filter(collections_groups::groups_uuid.eq(group_uuid))
.execute(conn)
.map_res("Error deleting collection group")
}}
})
.await
}
pub async fn delete_all_by_collection(
@@ -464,12 +481,13 @@ impl CollectionGroup {
}
}
db_run! { conn: {
conn.run(move |conn| {
diesel::delete(collections_groups::table)
.filter(collections_groups::collections_uuid.eq(collection_uuid))
.execute(conn)
.map_res("Error deleting collection group")
}}
})
.await
}
}
@@ -521,30 +539,31 @@ impl GroupUser {
}
pub async fn find_by_group(group_uuid: &GroupId, org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
conn.run(move |conn| {
groups_users::table
.inner_join(groups::table.on(
groups::uuid.eq(groups_users::groups_uuid)
))
.inner_join(users_organizations::table.on(
users_organizations::uuid.eq(groups_users::users_organizations_uuid)
.and(users_organizations::org_uuid.eq(groups::organizations_uuid))
))
.inner_join(groups::table.on(groups::uuid.eq(groups_users::groups_uuid)))
.inner_join(
users_organizations::table.on(users_organizations::uuid
.eq(groups_users::users_organizations_uuid)
.and(users_organizations::org_uuid.eq(groups::organizations_uuid))),
)
.filter(groups_users::groups_uuid.eq(group_uuid))
.filter(groups::organizations_uuid.eq(org_uuid))
.select(groups_users::all_columns)
.load::<Self>(conn)
.expect("Error loading group users")
}}
})
.await
}
pub async fn find_by_member(member_uuid: &MembershipId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
conn.run(move |conn| {
groups_users::table
.filter(groups_users::users_organizations_uuid.eq(member_uuid))
.load::<Self>(conn)
.expect("Error loading groups for user")
}}
})
.await
}
pub async fn has_access_to_collection_by_member(
@@ -552,24 +571,23 @@ impl GroupUser {
member_uuid: &MembershipId,
conn: &DbConn,
) -> bool {
db_run! { conn: {
conn.run(move |conn| {
groups_users::table
.inner_join(collections_groups::table.on(
collections_groups::groups_uuid.eq(groups_users::groups_uuid)
))
.inner_join(groups::table.on(
groups::uuid.eq(groups_users::groups_uuid)
))
.inner_join(collections::table.on(
collections::uuid.eq(collections_groups::collections_uuid)
.and(collections::org_uuid.eq(groups::organizations_uuid))
))
.inner_join(collections_groups::table.on(collections_groups::groups_uuid.eq(groups_users::groups_uuid)))
.inner_join(groups::table.on(groups::uuid.eq(groups_users::groups_uuid)))
.inner_join(
collections::table.on(collections::uuid
.eq(collections_groups::collections_uuid)
.and(collections::org_uuid.eq(groups::organizations_uuid))),
)
.filter(collections_groups::collections_uuid.eq(collection_uuid))
.filter(groups_users::users_organizations_uuid.eq(member_uuid))
.count()
.first::<i64>(conn)
.unwrap_or(0) != 0
}}
.unwrap_or(0)
!= 0
})
.await
}
pub async fn has_full_access_by_member(
@@ -577,18 +595,18 @@ impl GroupUser {
member_uuid: &MembershipId,
conn: &DbConn,
) -> bool {
db_run! { conn: {
conn.run(move |conn| {
groups_users::table
.inner_join(groups::table.on(
groups::uuid.eq(groups_users::groups_uuid)
))
.inner_join(groups::table.on(groups::uuid.eq(groups_users::groups_uuid)))
.filter(groups::organizations_uuid.eq(org_uuid))
.filter(groups::access_all.eq(true))
.filter(groups_users::users_organizations_uuid.eq(member_uuid))
.count()
.first::<i64>(conn)
.unwrap_or(0) != 0
}}
.unwrap_or(0)
!= 0
})
.await
}
pub async fn update_user_revision(&self, conn: &DbConn) {
@@ -606,15 +624,16 @@ impl GroupUser {
match Membership::find_by_uuid(member_uuid, conn).await {
Some(member) => User::update_uuid_revision(&member.user_uuid, conn).await,
None => warn!("Member could not be found!"),
};
}
db_run! { conn: {
conn.run(move |conn| {
diesel::delete(groups_users::table)
.filter(groups_users::groups_uuid.eq(group_uuid))
.filter(groups_users::users_organizations_uuid.eq(member_uuid))
.execute(conn)
.map_res("Error deleting group users")
}}
})
.await
}
pub async fn delete_all_by_group(group_uuid: &GroupId, org_uuid: &OrganizationId, conn: &DbConn) -> EmptyResult {
@@ -623,12 +642,13 @@ impl GroupUser {
group_user.update_user_revision(conn).await;
}
db_run! { conn: {
conn.run(move |conn| {
diesel::delete(groups_users::table)
.filter(groups_users::groups_uuid.eq(group_uuid))
.execute(conn)
.map_res("Error deleting group users")
}}
})
.await
}
pub async fn delete_all_by_member(member_uuid: &MembershipId, conn: &DbConn) -> EmptyResult {
@@ -637,12 +657,13 @@ impl GroupUser {
None => warn!("Member could not be found!"),
}
db_run! { conn: {
conn.run(move |conn| {
diesel::delete(groups_users::table)
.filter(groups_users::users_organizations_uuid.eq(member_uuid))
.execute(conn)
.map_res("Error deleting user groups")
}}
})
.await
}
}
+5 -3
View File
@@ -1,3 +1,4 @@
mod archive;
mod attachment;
mod auth_request;
mod cipher;
@@ -17,11 +18,12 @@ mod two_factor_duo_context;
mod two_factor_incomplete;
mod user;
pub use self::archive::Archive;
pub use self::attachment::{Attachment, AttachmentId};
pub use self::auth_request::{AuthRequest, AuthRequestId};
pub use self::cipher::{Cipher, CipherId, RepromptType};
pub use self::collection::{Collection, CollectionCipher, CollectionId, CollectionUser};
pub use self::device::{Device, DeviceId, DeviceType, PushId};
pub use self::device::{Device, DeviceId, DeviceType, DeviceWithAuthRequest, PushId};
pub use self::emergency_access::{EmergencyAccess, EmergencyAccessId, EmergencyAccessStatus, EmergencyAccessType};
pub use self::event::{Event, EventType};
pub use self::favorite::Favorite;
@@ -33,10 +35,10 @@ pub use self::organization::{
OrganizationId,
};
pub use self::send::{
id::{SendFileId, SendId},
Send, SendType,
id::{SendFileId, SendId},
};
pub use self::sso_auth::{OIDCAuthenticatedUser, OIDCCodeWrapper, SsoAuth};
pub use self::sso_auth::{OIDCAuthenticatedUser, OIDCCodeResponseError, SsoAuth};
pub use self::two_factor::{TwoFactor, TwoFactorType};
pub use self::two_factor_duo_context::TwoFactorDuoContext;
pub use self::two_factor_incomplete::TwoFactorIncomplete;
+74 -70
View File
@@ -1,14 +1,17 @@
use derive_more::{AsRef, From};
use diesel::prelude::*;
use serde::Deserialize;
use serde_json::Value;
use crate::api::core::two_factor;
use crate::api::EmptyResult;
use crate::db::schema::{org_policies, users_organizations};
use crate::db::DbConn;
use crate::error::MapResult;
use crate::CONFIG;
use diesel::prelude::*;
use crate::{
CONFIG,
api::{EmptyResult, core::two_factor},
db::{
DbConn,
schema::{org_policies, users_organizations},
},
error::MapResult,
};
use super::{Membership, MembershipId, MembershipStatus, MembershipType, OrganizationId, TwoFactor, UserId};
@@ -148,37 +151,38 @@ impl OrgPolicy {
}
pub async fn delete(self, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
conn.run(move |conn| {
diesel::delete(org_policies::table.filter(org_policies::uuid.eq(self.uuid)))
.execute(conn)
.map_res("Error deleting org_policy")
}}
})
.await
}
pub async fn find_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
conn.run(move |conn| {
org_policies::table
.filter(org_policies::org_uuid.eq(org_uuid))
.load::<Self>(conn)
.expect("Error loading org_policy")
}}
})
.await
}
pub async fn find_confirmed_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
conn.run(move |conn| {
org_policies::table
.inner_join(
users_organizations::table.on(
users_organizations::org_uuid.eq(org_policies::org_uuid)
.and(users_organizations::user_uuid.eq(user_uuid)))
)
.filter(
users_organizations::status.eq(MembershipStatus::Confirmed as i32)
users_organizations::table.on(users_organizations::org_uuid
.eq(org_policies::org_uuid)
.and(users_organizations::user_uuid.eq(user_uuid))),
)
.filter(users_organizations::status.eq(MembershipStatus::Confirmed as i32))
.select(org_policies::all_columns)
.load::<Self>(conn)
.expect("Error loading org_policy")
}}
})
.await
}
pub async fn find_by_org_and_type(
@@ -186,21 +190,23 @@ impl OrgPolicy {
policy_type: OrgPolicyType,
conn: &DbConn,
) -> Option<Self> {
db_run! { conn: {
conn.run(move |conn| {
org_policies::table
.filter(org_policies::org_uuid.eq(org_uuid))
.filter(org_policies::atype.eq(policy_type as i32))
.first::<Self>(conn)
.ok()
}}
})
.await
}
pub async fn delete_all_by_organization(org_uuid: &OrganizationId, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
conn.run(move |conn| {
diesel::delete(org_policies::table.filter(org_policies::org_uuid.eq(org_uuid)))
.execute(conn)
.map_res("Error deleting org_policy")
}}
})
.await
}
pub async fn find_accepted_and_confirmed_by_user_and_active_policy(
@@ -208,25 +214,22 @@ impl OrgPolicy {
policy_type: OrgPolicyType,
conn: &DbConn,
) -> Vec<Self> {
db_run! { conn: {
conn.run(move |conn| {
org_policies::table
.inner_join(
users_organizations::table.on(
users_organizations::org_uuid.eq(org_policies::org_uuid)
.and(users_organizations::user_uuid.eq(user_uuid)))
)
.filter(
users_organizations::status.eq(MembershipStatus::Accepted as i32)
)
.or_filter(
users_organizations::status.eq(MembershipStatus::Confirmed as i32)
users_organizations::table.on(users_organizations::org_uuid
.eq(org_policies::org_uuid)
.and(users_organizations::user_uuid.eq(user_uuid))),
)
.filter(users_organizations::status.eq(MembershipStatus::Accepted as i32))
.or_filter(users_organizations::status.eq(MembershipStatus::Confirmed as i32))
.filter(org_policies::atype.eq(policy_type as i32))
.filter(org_policies::enabled.eq(true))
.select(org_policies::all_columns)
.load::<Self>(conn)
.expect("Error loading org_policy")
}}
})
.await
}
pub async fn find_confirmed_by_user_and_active_policy(
@@ -234,22 +237,21 @@ impl OrgPolicy {
policy_type: OrgPolicyType,
conn: &DbConn,
) -> Vec<Self> {
db_run! { conn: {
conn.run(move |conn| {
org_policies::table
.inner_join(
users_organizations::table.on(
users_organizations::org_uuid.eq(org_policies::org_uuid)
.and(users_organizations::user_uuid.eq(user_uuid)))
)
.filter(
users_organizations::status.eq(MembershipStatus::Confirmed as i32)
users_organizations::table.on(users_organizations::org_uuid
.eq(org_policies::org_uuid)
.and(users_organizations::user_uuid.eq(user_uuid))),
)
.filter(users_organizations::status.eq(MembershipStatus::Confirmed as i32))
.filter(org_policies::atype.eq(policy_type as i32))
.filter(org_policies::enabled.eq(true))
.select(org_policies::all_columns)
.load::<Self>(conn)
.expect("Error loading org_policy")
}}
})
.await
}
/// Returns true if the user belongs to an org that has enabled the specified policy type,
@@ -269,10 +271,10 @@ impl OrgPolicy {
continue;
}
if let Some(user) = Membership::find_confirmed_by_user_and_org(user_uuid, &policy.org_uuid, conn).await {
if user.atype < MembershipType::Admin {
return true;
}
if let Some(user) = Membership::find_confirmed_by_user_and_org(user_uuid, &policy.org_uuid, conn).await
&& user.atype < MembershipType::Admin
{
return true;
}
}
false
@@ -282,13 +284,13 @@ impl OrgPolicy {
if m.atype < MembershipType::Admin && m.status > (MembershipStatus::Invited as i32) {
// Enforce TwoFactor/TwoStep login
if let Some(p) = Self::find_by_org_and_type(&m.org_uuid, OrgPolicyType::TwoFactorAuthentication, conn).await
&& p.enabled
&& TwoFactor::find_by_user(&m.user_uuid, conn).await.is_empty()
{
if p.enabled && TwoFactor::find_by_user(&m.user_uuid, conn).await.is_empty() {
if CONFIG.email_2fa_auto_fallback() {
two_factor::email::find_and_activate_email_2fa(&m.user_uuid, conn).await?;
} else {
err!(format!("Cannot {} because 2FA is required (membership {})", action, m.uuid));
}
if CONFIG.email_2fa_auto_fallback() {
two_factor::email::find_and_activate_email_2fa(&m.user_uuid, conn).await?;
} else {
err!(format!("Cannot {} because 2FA is required (membership {})", action, m.uuid));
}
}
@@ -300,12 +302,14 @@ impl OrgPolicy {
));
}
if let Some(p) = Self::find_by_org_and_type(&m.org_uuid, OrgPolicyType::SingleOrg, conn).await {
if p.enabled
&& Membership::count_accepted_and_confirmed_by_user(&m.user_uuid, &m.org_uuid, conn).await > 0
{
err!(format!("Cannot {} because the organization policy forbids being part of other organization (membership {})", action, m.uuid));
}
if let Some(p) = Self::find_by_org_and_type(&m.org_uuid, OrgPolicyType::SingleOrg, conn).await
&& p.enabled
&& Membership::count_accepted_and_confirmed_by_user(&m.user_uuid, &m.org_uuid, conn).await > 0
{
err!(format!(
"Cannot {} because the organization policy forbids being part of other organization (membership {})",
action, m.uuid
));
}
}
@@ -332,16 +336,16 @@ impl OrgPolicy {
for policy in
OrgPolicy::find_confirmed_by_user_and_active_policy(user_uuid, OrgPolicyType::SendOptions, conn).await
{
if let Some(user) = Membership::find_confirmed_by_user_and_org(user_uuid, &policy.org_uuid, conn).await {
if user.atype < MembershipType::Admin {
match serde_json::from_str::<SendOptionsPolicyData>(&policy.data) {
Ok(opts) => {
if opts.disable_hide_email {
return true;
}
if let Some(user) = Membership::find_confirmed_by_user_and_org(user_uuid, &policy.org_uuid, conn).await
&& user.atype < MembershipType::Admin
{
match serde_json::from_str::<SendOptionsPolicyData>(&policy.data) {
Ok(opts) => {
if opts.disable_hide_email {
return true;
}
_ => error!("Failed to deserialize SendOptionsPolicyData: {}", policy.data),
}
_ => error!("Failed to deserialize SendOptionsPolicyData: {}", policy.data),
}
}
}
@@ -349,10 +353,10 @@ impl OrgPolicy {
}
pub async fn is_enabled_for_member(member_uuid: &MembershipId, policy_type: OrgPolicyType, conn: &DbConn) -> bool {
if let Some(member) = Membership::find_by_uuid(member_uuid, conn).await {
if let Some(policy) = OrgPolicy::find_by_org_and_type(&member.org_uuid, policy_type, conn).await {
return policy.enabled;
}
if let Some(member) = Membership::find_by_uuid(member_uuid, conn).await
&& let Some(policy) = OrgPolicy::find_by_org_and_type(&member.org_uuid, policy_type, conn).await
{
return policy.enabled;
}
false
}
+202 -185
View File
@@ -1,23 +1,32 @@
use chrono::{NaiveDateTime, Utc};
use derive_more::{AsRef, Deref, Display, From};
use diesel::prelude::*;
use num_traits::FromPrimitive;
use serde_json::Value;
use std::{
cmp::Ordering,
collections::{HashMap, HashSet},
};
use chrono::{NaiveDateTime, Utc};
use derive_more::{AsRef, Deref, Display, From};
use diesel::prelude::*;
use num_traits::FromPrimitive;
use serde_json::Value;
use crate::{
CONFIG,
api::EmptyResult,
db::{
DbConn,
schema::{
ciphers, ciphers_collections, collections_groups, groups, groups_users, org_policies, organization_api_key,
organizations, users, users_collections, users_organizations,
},
},
error::MapResult,
};
use macros::UuidFromParam;
use super::{
CipherId, Collection, CollectionGroup, CollectionId, CollectionUser, Group, GroupId, GroupUser, OrgPolicy,
Cipher, CipherId, Collection, CollectionGroup, CollectionId, CollectionUser, Group, GroupId, GroupUser, OrgPolicy,
OrgPolicyType, TwoFactor, User, UserId,
};
use crate::db::schema::{
ciphers, ciphers_collections, collections_groups, groups, groups_users, org_policies, organization_api_key,
organizations, users, users_collections, users_organizations,
};
use crate::CONFIG;
use macros::UuidFromParam;
#[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = organizations)]
@@ -93,6 +102,10 @@ pub enum MembershipType {
impl MembershipType {
pub fn from_str(s: &str) -> Option<Self> {
#[expect(
clippy::match_same_arms,
reason = "Specifically define `4|Custom` since this is a hack, not a default"
)]
match s {
"0" | "Owner" => Some(MembershipType::Owner),
"1" | "Admin" => Some(MembershipType::Admin),
@@ -321,11 +334,6 @@ impl OrganizationApiKey {
}
}
use crate::db::DbConn;
use crate::api::EmptyResult;
use crate::error::MapResult;
/// Database methods
impl Organization {
pub async fn save(&self, conn: &DbConn) -> EmptyResult {
@@ -333,7 +341,7 @@ impl Organization {
err!(format!("BillingEmail {} is not a valid email address", self.billing_email))
}
for member in Membership::find_by_org(&self.uuid, conn).await.iter() {
for member in &Membership::find_by_org(&self.uuid, conn).await {
User::update_uuid_revision(&member.user_uuid, conn).await;
}
@@ -369,8 +377,6 @@ impl Organization {
}
pub async fn delete(self, conn: &DbConn) -> EmptyResult {
use super::{Cipher, Collection};
Cipher::delete_all_by_organization(&self.uuid, conn).await?;
Collection::delete_all_by_organization(&self.uuid, conn).await?;
Membership::delete_all_by_organization(&self.uuid, conn).await?;
@@ -378,43 +384,30 @@ impl Organization {
Group::delete_all_by_organization(&self.uuid, conn).await?;
OrganizationApiKey::delete_all_by_organization(&self.uuid, conn).await?;
db_run! { conn: {
conn.run(move |conn| {
diesel::delete(organizations::table.filter(organizations::uuid.eq(self.uuid)))
.execute(conn)
.map_res("Error saving organization")
}}
})
.await
}
pub async fn find_by_uuid(uuid: &OrganizationId, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
organizations::table
.filter(organizations::uuid.eq(uuid))
.first::<Self>(conn)
.ok()
}}
conn.run(move |conn| organizations::table.filter(organizations::uuid.eq(uuid)).first::<Self>(conn).ok()).await
}
pub async fn find_by_name(name: &str, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
organizations::table
.filter(organizations::name.eq(name))
.first::<Self>(conn)
.ok()
}}
conn.run(move |conn| organizations::table.filter(organizations::name.eq(name)).first::<Self>(conn).ok()).await
}
pub async fn get_all(conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
organizations::table
.load::<Self>(conn)
.expect("Error loading organizations")
}}
conn.run(move |conn| organizations::table.load::<Self>(conn).expect("Error loading organizations")).await
}
pub async fn find_main_org_user_email(user_email: &str, conn: &DbConn) -> Option<Self> {
let lower_mail = user_email.to_lowercase();
db_run! { conn: {
conn.run(move |conn| {
organizations::table
.inner_join(users_organizations::table.on(users_organizations::org_uuid.eq(organizations::uuid)))
.inner_join(users::table.on(users::uuid.eq(users_organizations::user_uuid)))
@@ -424,13 +417,14 @@ impl Organization {
.select(organizations::all_columns)
.first::<Self>(conn)
.ok()
}}
})
.await
}
pub async fn find_org_user_email(user_email: &str, conn: &DbConn) -> Vec<Self> {
let lower_mail = user_email.to_lowercase();
db_run! { conn: {
conn.run(move |conn| {
organizations::table
.inner_join(users_organizations::table.on(users_organizations::org_uuid.eq(organizations::uuid)))
.inner_join(users::table.on(users::uuid.eq(users_organizations::user_uuid)))
@@ -440,7 +434,8 @@ impl Organization {
.select(organizations::all_columns)
.load::<Self>(conn)
.expect("Error loading user orgs")
}}
})
.await
}
}
@@ -780,11 +775,12 @@ impl Membership {
CollectionUser::delete_all_by_user_and_org(&self.user_uuid, &self.org_uuid, conn).await?;
GroupUser::delete_all_by_member(&self.uuid, conn).await?;
db_run! { conn: {
conn.run(move |conn| {
diesel::delete(users_organizations::table.filter(users_organizations::uuid.eq(self.uuid)))
.execute(conn)
.map_res("Error removing user from organization")
}}
})
.await
}
pub async fn delete_all_by_organization(org_uuid: &OrganizationId, conn: &DbConn) -> EmptyResult {
@@ -802,10 +798,10 @@ impl Membership {
}
pub async fn find_by_email_and_org(email: &str, org_uuid: &OrganizationId, conn: &DbConn) -> Option<Membership> {
if let Some(user) = User::find_by_mail(email, conn).await {
if let Some(member) = Membership::find_by_user_and_org(&user.uuid, org_uuid, conn).await {
return Some(member);
}
if let Some(user) = User::find_by_mail(email, conn).await
&& let Some(member) = Membership::find_by_user_and_org(&user.uuid, org_uuid, conn).await
{
return Some(member);
}
None
@@ -824,64 +820,67 @@ impl Membership {
}
pub async fn find_by_uuid(uuid: &MembershipId, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
users_organizations::table
.filter(users_organizations::uuid.eq(uuid))
.first::<Self>(conn)
.ok()
}}
conn.run(move |conn| {
users_organizations::table.filter(users_organizations::uuid.eq(uuid)).first::<Self>(conn).ok()
})
.await
}
pub async fn find_by_uuid_and_org(uuid: &MembershipId, org_uuid: &OrganizationId, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
conn.run(move |conn| {
users_organizations::table
.filter(users_organizations::uuid.eq(uuid))
.filter(users_organizations::org_uuid.eq(org_uuid))
.first::<Self>(conn)
.ok()
}}
})
.await
}
pub async fn find_confirmed_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
conn.run(move |conn| {
users_organizations::table
.filter(users_organizations::user_uuid.eq(user_uuid))
.filter(users_organizations::status.eq(MembershipStatus::Confirmed as i32))
.load::<Self>(conn)
.unwrap_or_default()
}}
})
.await
}
pub async fn find_invited_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
conn.run(move |conn| {
users_organizations::table
.filter(users_organizations::user_uuid.eq(user_uuid))
.filter(users_organizations::status.eq(MembershipStatus::Invited as i32))
.load::<Self>(conn)
.unwrap_or_default()
}}
})
.await
}
// Should be used only when email are disabled.
// In Organizations::send_invite status is set to Accepted only if the user has a password.
pub async fn accept_user_invitations(user_uuid: &UserId, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
conn.run(move |conn| {
diesel::update(users_organizations::table)
.filter(users_organizations::user_uuid.eq(user_uuid))
.filter(users_organizations::status.eq(MembershipStatus::Invited as i32))
.set(users_organizations::status.eq(MembershipStatus::Accepted as i32))
.execute(conn)
.map_res("Error confirming invitations")
}}
})
.await
}
pub async fn find_any_state_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
conn.run(move |conn| {
users_organizations::table
.filter(users_organizations::user_uuid.eq(user_uuid))
.load::<Self>(conn)
.unwrap_or_default()
}}
})
.await
}
pub async fn count_accepted_and_confirmed_by_user(
@@ -889,70 +888,83 @@ impl Membership {
excluded_org: &OrganizationId,
conn: &DbConn,
) -> i64 {
db_run! { conn: {
conn.run(move |conn| {
users_organizations::table
.filter(users_organizations::user_uuid.eq(user_uuid))
.filter(users_organizations::org_uuid.ne(excluded_org))
.filter(users_organizations::status.eq(MembershipStatus::Accepted as i32).or(users_organizations::status.eq(MembershipStatus::Confirmed as i32)))
.filter(
users_organizations::status
.eq(MembershipStatus::Accepted as i32)
.or(users_organizations::status.eq(MembershipStatus::Confirmed as i32)),
)
.count()
.first::<i64>(conn)
.unwrap_or(0)
}}
})
.await
}
pub async fn find_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
conn.run(move |conn| {
users_organizations::table
.filter(users_organizations::org_uuid.eq(org_uuid))
.load::<Self>(conn)
.expect("Error loading user organizations")
}}
})
.await
}
pub async fn find_confirmed_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
conn.run(move |conn| {
users_organizations::table
.filter(users_organizations::org_uuid.eq(org_uuid))
.filter(users_organizations::status.eq(MembershipStatus::Confirmed as i32))
.load::<Self>(conn)
.unwrap_or_default()
}}
})
.await
}
// Get all users which are either owner or admin, or a manager which can manage/access all
pub async fn find_confirmed_and_manage_all_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
conn.run(move |conn| {
users_organizations::table
.filter(users_organizations::org_uuid.eq(org_uuid))
.filter(users_organizations::status.eq(MembershipStatus::Confirmed as i32))
.filter(
users_organizations::atype.eq_any(vec![MembershipType::Owner as i32, MembershipType::Admin as i32])
.or(users_organizations::atype.eq(MembershipType::Manager as i32).and(users_organizations::access_all.eq(true)))
users_organizations::atype
.eq_any(vec![MembershipType::Owner as i32, MembershipType::Admin as i32])
.or(users_organizations::atype
.eq(MembershipType::Manager as i32)
.and(users_organizations::access_all.eq(true))),
)
.load::<Self>(conn)
.unwrap_or_default()
}}
})
.await
}
pub async fn count_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> i64 {
db_run! { conn: {
conn.run(move |conn| {
users_organizations::table
.filter(users_organizations::org_uuid.eq(org_uuid))
.count()
.first::<i64>(conn)
.ok()
.unwrap_or(0)
}}
})
.await
}
pub async fn find_by_org_and_type(org_uuid: &OrganizationId, atype: MembershipType, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
conn.run(move |conn| {
users_organizations::table
.filter(users_organizations::org_uuid.eq(org_uuid))
.filter(users_organizations::atype.eq(atype as i32))
.load::<Self>(conn)
.expect("Error loading user organizations")
}}
})
.await
}
pub async fn count_confirmed_by_org_and_type(
@@ -960,7 +972,7 @@ impl Membership {
atype: MembershipType,
conn: &DbConn,
) -> i64 {
db_run! { conn: {
conn.run(move |conn| {
users_organizations::table
.filter(users_organizations::org_uuid.eq(org_uuid))
.filter(users_organizations::atype.eq(atype as i32))
@@ -968,17 +980,19 @@ impl Membership {
.count()
.first::<i64>(conn)
.unwrap_or(0)
}}
})
.await
}
pub async fn find_by_user_and_org(user_uuid: &UserId, org_uuid: &OrganizationId, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
conn.run(move |conn| {
users_organizations::table
.filter(users_organizations::user_uuid.eq(user_uuid))
.filter(users_organizations::org_uuid.eq(org_uuid))
.first::<Self>(conn)
.ok()
}}
})
.await
}
pub async fn find_confirmed_by_user_and_org(
@@ -986,78 +1000,76 @@ impl Membership {
org_uuid: &OrganizationId,
conn: &DbConn,
) -> Option<Self> {
db_run! { conn: {
conn.run(move |conn| {
users_organizations::table
.filter(users_organizations::user_uuid.eq(user_uuid))
.filter(users_organizations::org_uuid.eq(org_uuid))
.filter(
users_organizations::status.eq(MembershipStatus::Confirmed as i32)
)
.filter(users_organizations::status.eq(MembershipStatus::Confirmed as i32))
.first::<Self>(conn)
.ok()
}}
})
.await
}
pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
conn.run(move |conn| {
users_organizations::table
.filter(users_organizations::user_uuid.eq(user_uuid))
.load::<Self>(conn)
.expect("Error loading user organizations")
}}
})
.await
}
pub async fn get_orgs_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<OrganizationId> {
db_run! { conn: {
conn.run(move |conn| {
users_organizations::table
.filter(users_organizations::user_uuid.eq(user_uuid))
.select(users_organizations::org_uuid)
.load::<OrganizationId>(conn)
.unwrap_or_default()
}}
})
.await
}
pub async fn find_by_user_and_policy(user_uuid: &UserId, policy_type: OrgPolicyType, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
conn.run(move |conn| {
users_organizations::table
.inner_join(
org_policies::table.on(
org_policies::org_uuid.eq(users_organizations::org_uuid)
.and(users_organizations::user_uuid.eq(user_uuid))
.and(org_policies::atype.eq(policy_type as i32))
.and(org_policies::enabled.eq(true)))
)
.filter(
users_organizations::status.eq(MembershipStatus::Confirmed as i32)
org_policies::table.on(org_policies::org_uuid
.eq(users_organizations::org_uuid)
.and(users_organizations::user_uuid.eq(user_uuid))
.and(org_policies::atype.eq(policy_type as i32))
.and(org_policies::enabled.eq(true))),
)
.filter(users_organizations::status.eq(MembershipStatus::Confirmed as i32))
.select(users_organizations::all_columns)
.load::<Self>(conn)
.unwrap_or_default()
}}
})
.await
}
pub async fn find_by_cipher_and_org(cipher_uuid: &CipherId, org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
conn.run(move |conn| {
users_organizations::table
.filter(users_organizations::org_uuid.eq(org_uuid))
.left_join(users_collections::table.on(
users_collections::user_uuid.eq(users_organizations::user_uuid)
))
.left_join(ciphers_collections::table.on(
ciphers_collections::collection_uuid.eq(users_collections::collection_uuid).and(
ciphers_collections::cipher_uuid.eq(&cipher_uuid)
.filter(users_organizations::org_uuid.eq(org_uuid))
.left_join(users_collections::table.on(users_collections::user_uuid.eq(users_organizations::user_uuid)))
.left_join(
ciphers_collections::table.on(ciphers_collections::collection_uuid
.eq(users_collections::collection_uuid)
.and(ciphers_collections::cipher_uuid.eq(&cipher_uuid))),
)
))
.filter(
users_organizations::access_all.eq(true).or( // AccessAll..
ciphers_collections::cipher_uuid.eq(&cipher_uuid) // ..or access to collection with cipher
)
)
.select(users_organizations::all_columns)
.distinct()
.load::<Self>(conn)
.expect("Error loading user organizations")
}}
.filter(users_organizations::access_all.eq(true).or(
// AccessAll..
ciphers_collections::cipher_uuid.eq(&cipher_uuid), // ..or access to collection with cipher
))
.select(users_organizations::all_columns)
.distinct()
.load::<Self>(conn)
.expect("Error loading user organizations")
})
.await
}
pub async fn find_by_cipher_and_org_with_group(
@@ -1065,45 +1077,54 @@ impl Membership {
org_uuid: &OrganizationId,
conn: &DbConn,
) -> Vec<Self> {
db_run! { conn: {
conn.run(move |conn| {
users_organizations::table
.filter(users_organizations::org_uuid.eq(org_uuid))
.inner_join(groups_users::table.on(
groups_users::users_organizations_uuid.eq(users_organizations::uuid)
))
.left_join(collections_groups::table.on(
collections_groups::groups_uuid.eq(groups_users::groups_uuid)
))
.left_join(groups::table.on(groups::uuid.eq(groups_users::groups_uuid)
.and(groups::organizations_uuid.eq(users_organizations::org_uuid))
))
.left_join(ciphers_collections::table.on(
ciphers_collections::collection_uuid.eq(collections_groups::collections_uuid).and(ciphers_collections::cipher_uuid.eq(&cipher_uuid))
))
.filter(
groups::access_all.eq(true).or( // AccessAll via groups
ciphers_collections::cipher_uuid.eq(&cipher_uuid) // ..or access to collection via group
)
.filter(users_organizations::org_uuid.eq(org_uuid))
.inner_join(
groups_users::table.on(groups_users::users_organizations_uuid.eq(users_organizations::uuid)),
)
.left_join(collections_groups::table.on(collections_groups::groups_uuid.eq(groups_users::groups_uuid)))
.left_join(
groups::table.on(groups::uuid
.eq(groups_users::groups_uuid)
.and(groups::organizations_uuid.eq(users_organizations::org_uuid))),
)
.left_join(
ciphers_collections::table.on(ciphers_collections::collection_uuid
.eq(collections_groups::collections_uuid)
.and(ciphers_collections::cipher_uuid.eq(&cipher_uuid))),
)
.filter(groups::access_all.eq(true).or(
// AccessAll via groups
ciphers_collections::cipher_uuid.eq(&cipher_uuid), // ..or access to collection via group
))
.select(users_organizations::all_columns)
.distinct()
.load::<Self>(conn)
.expect("Error loading user organizations with groups")
}}
.load::<Self>(conn)
.expect("Error loading user organizations with groups")
})
.await
}
pub async fn user_has_ge_admin_access_to_cipher(user_uuid: &UserId, cipher_uuid: &CipherId, conn: &DbConn) -> bool {
db_run! { conn: {
conn.run(move |conn| {
users_organizations::table
.inner_join(ciphers::table.on(ciphers::uuid.eq(cipher_uuid).and(ciphers::organization_uuid.eq(users_organizations::org_uuid.nullable()))))
.filter(users_organizations::user_uuid.eq(user_uuid))
.filter(users_organizations::atype.eq_any(vec![MembershipType::Owner as i32, MembershipType::Admin as i32]))
.count()
.first::<i64>(conn)
.ok()
.unwrap_or(0) != 0
}}
.inner_join(
ciphers::table.on(ciphers::uuid
.eq(cipher_uuid)
.and(ciphers::organization_uuid.eq(users_organizations::org_uuid.nullable()))),
)
.filter(users_organizations::user_uuid.eq(user_uuid))
.filter(
users_organizations::atype.eq_any(vec![MembershipType::Owner as i32, MembershipType::Admin as i32]),
)
.count()
.first::<i64>(conn)
.ok()
.unwrap_or(0)
!= 0
})
.await
}
pub async fn find_by_collection_and_org(
@@ -1111,44 +1132,41 @@ impl Membership {
org_uuid: &OrganizationId,
conn: &DbConn,
) -> Vec<Self> {
db_run! { conn: {
conn.run(move |conn| {
users_organizations::table
.filter(users_organizations::org_uuid.eq(org_uuid))
.left_join(users_collections::table.on(
users_collections::user_uuid.eq(users_organizations::user_uuid)
))
.filter(
users_organizations::access_all.eq(true).or( // AccessAll..
users_collections::collection_uuid.eq(&collection_uuid) // ..or access to collection with cipher
)
)
.select(users_organizations::all_columns)
.load::<Self>(conn)
.expect("Error loading user organizations")
}}
.filter(users_organizations::org_uuid.eq(org_uuid))
.left_join(users_collections::table.on(users_collections::user_uuid.eq(users_organizations::user_uuid)))
.filter(users_organizations::access_all.eq(true).or(
// AccessAll..
users_collections::collection_uuid.eq(&collection_uuid), // ..or access to collection with cipher
))
.select(users_organizations::all_columns)
.load::<Self>(conn)
.expect("Error loading user organizations")
})
.await
}
pub async fn find_by_external_id_and_org(ext_id: &str, org_uuid: &OrganizationId, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
conn.run(move |conn| {
users_organizations::table
.filter(
users_organizations::external_id.eq(ext_id)
.and(users_organizations::org_uuid.eq(org_uuid))
)
.first::<Self>(conn)
.ok()
}}
.filter(users_organizations::external_id.eq(ext_id).and(users_organizations::org_uuid.eq(org_uuid)))
.first::<Self>(conn)
.ok()
})
.await
}
pub async fn find_main_user_org(user_uuid: &str, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
conn.run(move |conn| {
users_organizations::table
.filter(users_organizations::user_uuid.eq(user_uuid))
.filter(users_organizations::status.ne(MembershipStatus::Revoked as i32))
.order(users_organizations::atype.asc())
.first::<Self>(conn)
.ok()
}}
})
.await
}
}
@@ -1186,20 +1204,19 @@ impl OrganizationApiKey {
}
pub async fn find_by_org_uuid(org_uuid: &OrganizationId, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
organization_api_key::table
.filter(organization_api_key::org_uuid.eq(org_uuid))
.first::<Self>(conn)
.ok()
}}
conn.run(move |conn| {
organization_api_key::table.filter(organization_api_key::org_uuid.eq(org_uuid)).first::<Self>(conn).ok()
})
.await
}
pub async fn delete_all_by_organization(org_uuid: &OrganizationId, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
conn.run(move |conn| {
diesel::delete(organization_api_key::table.filter(organization_api_key::org_uuid.eq(org_uuid)))
.execute(conn)
.map_res("Error removing organization api key from organization")
}}
})
.await
}
}
+55 -82
View File
@@ -1,11 +1,19 @@
use chrono::{NaiveDateTime, Utc};
use data_encoding::BASE64URL_NOPAD;
use diesel::prelude::*;
use serde_json::Value;
use uuid::Uuid;
use crate::{config::PathType, util::LowerCase, CONFIG};
use crate::{
CONFIG,
api::EmptyResult,
config::PathType,
db::{DbConn, schema::sends},
error::MapResult,
util::{LowerCase, NumberOrString, format_date},
};
use super::{OrganizationId, User, UserId};
use crate::db::schema::sends;
use diesel::prelude::*;
use id::SendId;
#[derive(Identifiable, Queryable, Insertable, AsChangeset)]
@@ -107,37 +115,33 @@ impl Send {
pub fn check_password(&self, password: &str) -> bool {
match (&self.password_hash, &self.password_salt, self.password_iter) {
(Some(hash), Some(salt), Some(iter)) => {
crate::crypto::verify_password_hash(password.as_bytes(), salt, hash, iter as u32)
crate::crypto::verify_password_hash(password.as_bytes(), salt, hash, iter.cast_unsigned())
}
_ => false,
}
}
pub async fn creator_identifier(&self, conn: &DbConn) -> Option<String> {
if let Some(hide_email) = self.hide_email {
if hide_email {
return None;
}
if let Some(hide_email) = self.hide_email
&& hide_email
{
return None;
}
if let Some(user_uuid) = &self.user_uuid {
if let Some(user) = User::find_by_uuid(user_uuid, conn).await {
return Some(user.email);
}
if let Some(user_uuid) = &self.user_uuid
&& let Some(user) = User::find_by_uuid(user_uuid, conn).await
{
return Some(user.email);
}
None
}
pub fn to_json(&self) -> Value {
use crate::util::format_date;
use data_encoding::BASE64URL_NOPAD;
use uuid::Uuid;
let mut data = serde_json::from_str::<LowerCase<Value>>(&self.data).map(|d| d.data).unwrap_or_default();
// Mobile clients expect size to be a string instead of a number
if let Some(size) = data.get("size").and_then(|v| v.as_i64()) {
if let Some(size) = data.get("size").and_then(Value::as_i64) {
data["size"] = Value::String(size.to_string());
}
@@ -167,12 +171,10 @@ impl Send {
}
pub async fn to_json_access(&self, conn: &DbConn) -> Value {
use crate::util::format_date;
let mut data = serde_json::from_str::<LowerCase<Value>>(&self.data).map(|d| d.data).unwrap_or_default();
// Mobile clients expect size to be a string instead of a number
if let Some(size) = data.get("size").and_then(|v| v.as_i64()) {
if let Some(size) = data.get("size").and_then(Value::as_i64) {
data["size"] = Value::String(size.to_string());
}
@@ -191,12 +193,6 @@ impl Send {
}
}
use crate::db::DbConn;
use crate::api::EmptyResult;
use crate::error::MapResult;
use crate::util::NumberOrString;
impl Send {
pub async fn save(&mut self, conn: &DbConn) -> EmptyResult {
self.update_users_revision(conn).await;
@@ -237,14 +233,13 @@ impl Send {
if self.atype == SendType::File as i32 {
let operator = CONFIG.opendal_operator_for_path_type(&PathType::Sends)?;
operator.remove_all(&self.uuid).await.ok();
operator.delete_with(&self.uuid).recursive(true).await.ok();
}
db_run! { conn: {
diesel::delete(sends::table.filter(sends::uuid.eq(&self.uuid)))
.execute(conn)
.map_res("Error deleting send")
}}
conn.run(move |conn| {
diesel::delete(sends::table.filter(sends::uuid.eq(&self.uuid))).execute(conn).map_res("Error deleting send")
})
.await
}
/// Purge all sends that are past their deletion date.
@@ -256,15 +251,12 @@ impl Send {
pub async fn update_users_revision(&self, conn: &DbConn) -> Vec<UserId> {
let mut user_uuids = Vec::new();
match &self.user_uuid {
Some(user_uuid) => {
User::update_uuid_revision(user_uuid, conn).await;
user_uuids.push(user_uuid.clone())
}
None => {
// Belongs to Organization, not implemented
}
};
if let Some(user_uuid) = &self.user_uuid {
User::update_uuid_revision(user_uuid, conn).await;
user_uuids.push(user_uuid.clone());
} else {
// Belongs to Organization, not implemented
}
user_uuids
}
@@ -276,9 +268,6 @@ impl Send {
}
pub async fn find_by_access_id(access_id: &str, conn: &DbConn) -> Option<Self> {
use data_encoding::BASE64URL_NOPAD;
use uuid::Uuid;
let Ok(uuid_vec) = BASE64URL_NOPAD.decode(access_id.as_bytes()) else {
return None;
};
@@ -292,50 +281,38 @@ impl Send {
}
pub async fn find_by_uuid(uuid: &SendId, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
sends::table
.filter(sends::uuid.eq(uuid))
.first::<Self>(conn)
.ok()
}}
conn.run(move |conn| sends::table.filter(sends::uuid.eq(uuid)).first::<Self>(conn).ok()).await
}
pub async fn find_by_uuid_and_user(uuid: &SendId, user_uuid: &UserId, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
sends::table
.filter(sends::uuid.eq(uuid))
.filter(sends::user_uuid.eq(user_uuid))
.first::<Self>(conn)
.ok()
}}
conn.run(move |conn| {
sends::table.filter(sends::uuid.eq(uuid)).filter(sends::user_uuid.eq(user_uuid)).first::<Self>(conn).ok()
})
.await
}
pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
sends::table
.filter(sends::user_uuid.eq(user_uuid))
.load::<Self>(conn)
.expect("Error loading sends")
}}
conn.run(move |conn| {
sends::table.filter(sends::user_uuid.eq(user_uuid)).load::<Self>(conn).expect("Error loading sends")
})
.await
}
pub async fn size_by_user(user_uuid: &UserId, conn: &DbConn) -> Option<i64> {
let sends = Self::find_by_user(user_uuid, conn).await;
#[derive(serde::Deserialize)]
struct FileData {
#[serde(rename = "size", alias = "Size")]
size: NumberOrString,
}
let sends = Self::find_by_user(user_uuid, conn).await;
let mut total: i64 = 0;
for send in sends {
if send.atype == SendType::File as i32 {
if let Ok(size) =
if send.atype == SendType::File as i32
&& let Ok(size) =
serde_json::from_str::<FileData>(&send.data).map_err(Into::into).and_then(|d| d.size.into_i64())
{
total = total.checked_add(size)?;
};
{
total = total.checked_add(size)?;
}
}
@@ -343,22 +320,18 @@ impl Send {
}
pub async fn find_by_org(org_uuid: &OrganizationId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
sends::table
.filter(sends::organization_uuid.eq(org_uuid))
.load::<Self>(conn)
.expect("Error loading sends")
}}
conn.run(move |conn| {
sends::table.filter(sends::organization_uuid.eq(org_uuid)).load::<Self>(conn).expect("Error loading sends")
})
.await
}
pub async fn find_by_past_deletion_date(conn: &DbConn) -> Vec<Self> {
let now = Utc::now().naive_utc();
db_run! { conn: {
sends::table
.filter(sends::deletion_date.lt(now))
.load::<Self>(conn)
.expect("Error loading sends")
}}
conn.run(move |conn| {
sends::table.filter(sends::deletion_date.lt(now)).load::<Self>(conn).expect("Error loading sends")
})
.await
}
}
+49 -27
View File
@@ -1,31 +1,29 @@
use chrono::{NaiveDateTime, Utc};
use std::time::Duration;
use crate::api::EmptyResult;
use crate::db::schema::sso_auth;
use crate::db::{DbConn, DbPool};
use crate::error::MapResult;
use crate::sso::{OIDCCode, OIDCCodeChallenge, OIDCIdentifier, OIDCState, SSO_AUTH_EXPIRATION};
use chrono::{NaiveDateTime, Utc};
use diesel::{
deserialize::FromSql,
expression::AsExpression,
prelude::*,
serialize::{Output, ToSql},
sql_types::Text,
};
use diesel::deserialize::FromSql;
use diesel::expression::AsExpression;
use diesel::prelude::*;
use diesel::serialize::{Output, ToSql};
use diesel::sql_types::Text;
use crate::{
api::EmptyResult,
db::{DbConn, DbPool, schema::sso_auth},
error::MapResult,
sso::{OIDCCode, OIDCCodeChallenge, OIDCIdentifier, OIDCState, SSO_AUTH_EXPIRATION},
};
#[derive(AsExpression, Clone, Debug, Serialize, Deserialize, FromSqlRow)]
#[diesel(sql_type = Text)]
pub enum OIDCCodeWrapper {
Ok {
code: OIDCCode,
},
Error {
error: String,
error_description: Option<String>,
},
pub struct OIDCCodeResponseError {
pub error: String,
pub error_description: Option<String>,
}
impl_FromToSqlText!(OIDCCodeWrapper);
impl_FromToSqlText!(OIDCCodeResponseError);
#[derive(AsExpression, Clone, Debug, Serialize, Deserialize, FromSqlRow)]
#[diesel(sql_type = Text)]
@@ -50,15 +48,23 @@ pub struct SsoAuth {
pub client_challenge: OIDCCodeChallenge,
pub nonce: String,
pub redirect_uri: String,
pub code_response: Option<OIDCCodeWrapper>,
pub code_response: Option<OIDCCode>,
pub code_response_error: Option<OIDCCodeResponseError>,
pub auth_response: Option<OIDCAuthenticatedUser>,
pub created_at: NaiveDateTime,
pub updated_at: NaiveDateTime,
pub binding_hash: Option<String>,
}
/// Local methods
impl SsoAuth {
pub fn new(state: OIDCState, client_challenge: OIDCCodeChallenge, nonce: String, redirect_uri: String) -> Self {
pub fn new(
state: OIDCState,
client_challenge: OIDCCodeChallenge,
nonce: String,
redirect_uri: String,
binding_hash: Option<String>,
) -> Self {
let now = Utc::now().naive_utc();
SsoAuth {
@@ -69,7 +75,9 @@ impl SsoAuth {
created_at: now,
updated_at: now,
code_response: None,
code_response_error: None,
auth_response: None,
binding_hash,
}
}
}
@@ -101,32 +109,46 @@ impl SsoAuth {
pub async fn find(state: &OIDCState, conn: &DbConn) -> Option<Self> {
let oldest = Utc::now().naive_utc() - *SSO_AUTH_EXPIRATION;
db_run! { conn: {
conn.run(move |conn| {
sso_auth::table
.filter(sso_auth::state.eq(state))
.filter(sso_auth::created_at.ge(oldest))
.first::<Self>(conn)
.ok()
})
.await
}
pub async fn find_by_code(code: &OIDCCode, conn: &DbConn) -> Option<Self> {
let oldest = Utc::now().naive_utc() - *SSO_AUTH_EXPIRATION;
db_run! { conn: {
sso_auth::table
.filter(sso_auth::code_response.eq(code))
.filter(sso_auth::created_at.ge(oldest))
.first::<Self>(conn)
.ok()
}}
}
pub async fn delete(self, conn: &DbConn) -> EmptyResult {
db_run! {conn: {
conn.run(move |conn| {
diesel::delete(sso_auth::table.filter(sso_auth::state.eq(self.state)))
.execute(conn)
.map_res("Error deleting sso_auth")
}}
})
.await
}
pub async fn delete_expired(pool: DbPool) -> EmptyResult {
debug!("Purging expired sso_auth");
if let Ok(conn) = pool.get().await {
let oldest = Utc::now().naive_utc() - *SSO_AUTH_EXPIRATION;
db_run! { conn: {
conn.run(move |conn| {
diesel::delete(sso_auth::table.filter(sso_auth::created_at.lt(oldest)))
.execute(conn)
.map_res("Error deleting expired SSO nonce")
}}
})
.await
} else {
err!("Failed to get DB connection while purging expired sso_auth")
}
+39 -28
View File
@@ -1,13 +1,17 @@
use super::UserId;
use crate::api::core::two_factor::webauthn::WebauthnRegistration;
use crate::db::schema::twofactor;
use crate::{api::EmptyResult, db::DbConn, error::MapResult};
use diesel::prelude::*;
use serde_json::Value;
use webauthn_rs::prelude::{Credential, ParsedAttestation};
use webauthn_rs_core::proto::CredentialV3;
use webauthn_rs_proto::{AttestationFormat, RegisteredExtensions};
use crate::{
api::{EmptyResult, core::two_factor::webauthn::WebauthnRegistration},
db::{DbConn, schema::twofactor},
error::MapResult,
};
use super::UserId;
#[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = twofactor)]
#[diesel(primary_key(uuid))]
@@ -114,54 +118,59 @@ impl TwoFactor {
}
pub async fn delete(self, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
conn.run(move |conn| {
diesel::delete(twofactor::table.filter(twofactor::uuid.eq(self.uuid)))
.execute(conn)
.map_res("Error deleting twofactor")
}}
})
.await
}
pub async fn find_by_user(user_uuid: &UserId, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
conn.run(move |conn| {
twofactor::table
.filter(twofactor::user_uuid.eq(user_uuid))
.filter(twofactor::atype.lt(1000)) // Filter implementation types
.load::<Self>(conn)
.expect("Error loading twofactor")
}}
})
.await
}
pub async fn find_by_user_and_type(user_uuid: &UserId, atype: i32, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
conn.run(move |conn| {
twofactor::table
.filter(twofactor::user_uuid.eq(user_uuid))
.filter(twofactor::atype.eq(atype))
.first::<Self>(conn)
.ok()
}}
})
.await
}
pub async fn delete_all_by_user(user_uuid: &UserId, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
conn.run(move |conn| {
diesel::delete(twofactor::table.filter(twofactor::user_uuid.eq(user_uuid)))
.execute(conn)
.map_res("Error deleting twofactors")
}}
})
.await
}
pub async fn migrate_u2f_to_webauthn(conn: &DbConn) -> EmptyResult {
let u2f_factors = db_run! { conn: {
twofactor::table
.filter(twofactor::atype.eq(TwoFactorType::U2f as i32))
.load::<Self>(conn)
.expect("Error loading twofactor")
}};
use crate::api::core::two_factor::webauthn::U2FRegistration;
use crate::api::core::two_factor::webauthn::{get_webauthn_registrations, WebauthnRegistration};
use crate::api::core::two_factor::webauthn::{U2FRegistration, get_webauthn_registrations};
use webauthn_rs::prelude::{COSEEC2Key, COSEKey, COSEKeyType, ECDSACurve};
use webauthn_rs_proto::{COSEAlgorithm, UserVerificationPolicy};
let u2f_factors = conn
.run(move |conn| {
twofactor::table
.filter(twofactor::atype.eq(TwoFactorType::U2f as i32))
.load::<Self>(conn)
.expect("Error loading twofactor")
})
.await;
for mut u2f in u2f_factors {
let mut regs: Vec<U2FRegistration> = serde_json::from_str(&u2f.data)?;
// If there are no registrations or they are migrated (we do the migration in batch so we can consider them all migrated when the first one is)
@@ -227,12 +236,14 @@ impl TwoFactor {
}
pub async fn migrate_credential_to_passkey(conn: &DbConn) -> EmptyResult {
let webauthn_factors = db_run! { conn: {
twofactor::table
.filter(twofactor::atype.eq(TwoFactorType::Webauthn as i32))
.load::<Self>(conn)
.expect("Error loading twofactor")
}};
let webauthn_factors = conn
.run(move |conn| {
twofactor::table
.filter(twofactor::atype.eq(TwoFactorType::Webauthn as i32))
.load::<Self>(conn)
.expect("Error loading twofactor")
})
.await;
for webauthn_factor in webauthn_factors {
// assume that a failure to parse into the old struct, means that it was already converted
@@ -241,7 +252,7 @@ impl TwoFactor {
continue;
};
let regs = regs.into_iter().map(|r| r.into()).collect::<Vec<WebauthnRegistration>>();
let regs = regs.into_iter().map(Into::into).collect::<Vec<WebauthnRegistration>>();
TwoFactor::new(webauthn_factor.user_uuid.clone(), TwoFactorType::Webauthn, serde_json::to_string(&regs)?)
.save(conn)
+25 -23
View File
@@ -1,9 +1,12 @@
use chrono::Utc;
use crate::db::schema::twofactor_duo_ctx;
use crate::{api::EmptyResult, db::DbConn, error::MapResult};
use diesel::prelude::*;
use crate::{
api::EmptyResult,
db::{DbConn, schema::twofactor_duo_ctx},
error::MapResult,
};
#[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = twofactor_duo_ctx)]
#[diesel(primary_key(state))]
@@ -16,12 +19,10 @@ pub struct TwoFactorDuoContext {
impl TwoFactorDuoContext {
pub async fn find_by_state(state: &str, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
twofactor_duo_ctx::table
.filter(twofactor_duo_ctx::state.eq(state))
.first::<Self>(conn)
.ok()
}}
conn.run(move |conn| {
twofactor_duo_ctx::table.filter(twofactor_duo_ctx::state.eq(state)).first::<Self>(conn).ok()
})
.await
}
pub async fn save(state: &str, user_email: &str, nonce: &str, ttl: i64, conn: &DbConn) -> EmptyResult {
@@ -29,41 +30,42 @@ impl TwoFactorDuoContext {
let exists = Self::find_by_state(state, conn).await;
if exists.is_some() {
return Ok(());
};
}
let exp = Utc::now().timestamp() + ttl;
db_run! { conn: {
conn.run(move |conn| {
diesel::insert_into(twofactor_duo_ctx::table)
.values((
twofactor_duo_ctx::state.eq(state),
twofactor_duo_ctx::user_email.eq(user_email),
twofactor_duo_ctx::nonce.eq(nonce),
twofactor_duo_ctx::exp.eq(exp)
))
.execute(conn)
.map_res("Error saving context to twofactor_duo_ctx")
}}
twofactor_duo_ctx::exp.eq(exp),
))
.execute(conn)
.map_res("Error saving context to twofactor_duo_ctx")
})
.await
}
pub async fn find_expired(conn: &DbConn) -> Vec<Self> {
let now = Utc::now().timestamp();
db_run! { conn: {
conn.run(move |conn| {
twofactor_duo_ctx::table
.filter(twofactor_duo_ctx::exp.lt(now))
.load::<Self>(conn)
.expect("Error finding expired contexts in twofactor_duo_ctx")
}}
})
.await
}
pub async fn delete(&self, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
diesel::delete(
twofactor_duo_ctx::table
.filter(twofactor_duo_ctx::state.eq(&self.state)))
conn.run(move |conn| {
diesel::delete(twofactor_duo_ctx::table.filter(twofactor_duo_ctx::state.eq(&self.state)))
.execute(conn)
.map_res("Error deleting from twofactor_duo_ctx")
}}
})
.await
}
pub async fn purge_expired_duo_contexts(conn: &DbConn) {
+26 -19
View File
@@ -1,17 +1,17 @@
use chrono::{NaiveDateTime, Utc};
use diesel::prelude::*;
use crate::db::schema::twofactor_incomplete;
use crate::{
CONFIG,
api::EmptyResult,
auth::ClientIp,
db::{
models::{DeviceId, UserId},
DbConn,
models::{DeviceId, UserId},
schema::twofactor_incomplete,
},
error::MapResult,
CONFIG,
};
use diesel::prelude::*;
#[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = twofactor_incomplete)]
@@ -49,7 +49,7 @@ impl TwoFactorIncomplete {
return Ok(());
}
db_run! { conn: {
conn.run(move |conn| {
diesel::insert_into(twofactor_incomplete::table)
.values((
twofactor_incomplete::user_uuid.eq(user_uuid),
@@ -61,7 +61,8 @@ impl TwoFactorIncomplete {
))
.execute(conn)
.map_res("Error adding twofactor_incomplete record")
}}
})
.await
}
pub async fn mark_complete(user_uuid: &UserId, device_uuid: &DeviceId, conn: &DbConn) -> EmptyResult {
@@ -73,22 +74,24 @@ impl TwoFactorIncomplete {
}
pub async fn find_by_user_and_device(user_uuid: &UserId, device_uuid: &DeviceId, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
conn.run(move |conn| {
twofactor_incomplete::table
.filter(twofactor_incomplete::user_uuid.eq(user_uuid))
.filter(twofactor_incomplete::device_uuid.eq(device_uuid))
.first::<Self>(conn)
.ok()
}}
})
.await
}
pub async fn find_logins_before(dt: &NaiveDateTime, conn: &DbConn) -> Vec<Self> {
db_run! { conn: {
conn.run(move |conn| {
twofactor_incomplete::table
.filter(twofactor_incomplete::login_time.lt(dt))
.load::<Self>(conn)
.expect("Error loading twofactor_incomplete")
}}
})
.await
}
pub async fn delete(self, conn: &DbConn) -> EmptyResult {
@@ -96,20 +99,24 @@ impl TwoFactorIncomplete {
}
pub async fn delete_by_user_and_device(user_uuid: &UserId, device_uuid: &DeviceId, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
diesel::delete(twofactor_incomplete::table
.filter(twofactor_incomplete::user_uuid.eq(user_uuid))
.filter(twofactor_incomplete::device_uuid.eq(device_uuid)))
.execute(conn)
.map_res("Error in twofactor_incomplete::delete_by_user_and_device()")
}}
conn.run(move |conn| {
diesel::delete(
twofactor_incomplete::table
.filter(twofactor_incomplete::user_uuid.eq(user_uuid))
.filter(twofactor_incomplete::device_uuid.eq(device_uuid)),
)
.execute(conn)
.map_res("Error in twofactor_incomplete::delete_by_user_and_device()")
})
.await
}
pub async fn delete_all_by_user(user_uuid: &UserId, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
conn.run(move |conn| {
diesel::delete(twofactor_incomplete::table.filter(twofactor_incomplete::user_uuid.eq(user_uuid)))
.execute(conn)
.map_res("Error in twofactor_incomplete::delete_all_by_user()")
}}
})
.await
}
}
+71 -72
View File
@@ -1,23 +1,27 @@
use crate::db::schema::{invitations, sso_users, twofactor_incomplete, users};
use chrono::{NaiveDateTime, TimeDelta, Utc};
use derive_more::{AsRef, Deref, Display, From};
use diesel::prelude::*;
use serde_json::Value;
use super::{
Cipher, Device, EmergencyAccess, Favorite, Folder, Membership, MembershipType, TwoFactor, TwoFactorIncomplete,
};
use crate::{
CONFIG,
api::EmptyResult,
crypto,
db::{models::DeviceId, DbConn},
db::{
DbConn,
models::DeviceId,
schema::{invitations, sso_users, twofactor_incomplete, users},
},
error::MapResult,
sso::OIDCIdentifier,
util::{format_date, get_uuid, retry},
CONFIG,
};
use macros::UuidFromParam;
use super::{
Cipher, Device, EmergencyAccess, Favorite, Folder, Membership, MembershipType, TwoFactor, TwoFactorIncomplete,
};
#[derive(Identifiable, Queryable, Insertable, AsChangeset, Selectable)]
#[diesel(table_name = users)]
#[diesel(treat_none_as_null = true)]
@@ -137,8 +141,8 @@ impl User {
_totp_secret: None,
totp_recover: None,
equivalent_domains: "[]".to_string(),
excluded_globals: "[]".to_string(),
equivalent_domains: "[]".to_owned(),
excluded_globals: "[]".to_owned(),
client_kdf_type: Self::CLIENT_KDF_TYPE_DEFAULT,
client_kdf_iter: Self::CLIENT_KDF_ITER_DEFAULT,
@@ -158,7 +162,7 @@ impl User {
password.as_bytes(),
&self.salt,
&self.password_hash,
self.password_iterations as u32,
self.password_iterations.cast_unsigned(),
)
}
@@ -193,7 +197,8 @@ impl User {
allow_next_route: Option<Vec<String>>,
conn: &DbConn,
) -> EmptyResult {
self.password_hash = crypto::hash_password(password.as_bytes(), &self.salt, self.password_iterations as u32);
self.password_hash =
crypto::hash_password(password.as_bytes(), &self.salt, self.password_iterations.cast_unsigned());
if let Some(route) = allow_next_route {
self.set_stamp_exception(route);
@@ -238,10 +243,10 @@ impl User {
pub fn display_name(&self) -> &str {
// default to email if name is empty
if !&self.name.is_empty() {
&self.name
} else {
if self.name.is_empty() {
&self.email
} else {
&self.name
}
}
}
@@ -337,15 +342,14 @@ impl User {
TwoFactorIncomplete::delete_all_by_user(&self.uuid, conn).await?;
Invitation::take(&self.email, conn).await; // Delete invitation if any
db_run! { conn: {
diesel::delete(users::table.filter(users::uuid.eq(self.uuid)))
.execute(conn)
.map_res("Error deleting user")
}}
conn.run(move |conn| {
diesel::delete(users::table.filter(users::uuid.eq(self.uuid))).execute(conn).map_res("Error deleting user")
})
.await
}
pub async fn update_uuid_revision(uuid: &UserId, conn: &DbConn) {
if let Err(e) = Self::_update_revision(uuid, &Utc::now().naive_utc(), conn).await {
if let Err(e) = Self::update_revision_impl(uuid, &Utc::now().naive_utc(), conn).await {
warn!("Failed to update revision for {uuid}: {e:#?}");
}
}
@@ -353,68 +357,62 @@ impl User {
pub async fn update_all_revisions(conn: &DbConn) -> EmptyResult {
let updated_at = Utc::now().naive_utc();
db_run! { conn: {
retry(|| {
diesel::update(users::table)
.set(users::updated_at.eq(updated_at))
.execute(conn)
}, 10)
.map_res("Error updating revision date for all users")
}}
conn.run(move |conn| {
retry(|| diesel::update(users::table).set(users::updated_at.eq(updated_at)).execute(conn), 10)
.map_res("Error updating revision date for all users")
})
.await
}
pub async fn update_revision(&mut self, conn: &DbConn) -> EmptyResult {
self.updated_at = Utc::now().naive_utc();
Self::_update_revision(&self.uuid, &self.updated_at, conn).await
Self::update_revision_impl(&self.uuid, &self.updated_at, conn).await
}
async fn _update_revision(uuid: &UserId, date: &NaiveDateTime, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
retry(|| {
diesel::update(users::table.filter(users::uuid.eq(uuid)))
.set(users::updated_at.eq(date))
.execute(conn)
}, 10)
async fn update_revision_impl(uuid: &UserId, date: &NaiveDateTime, conn: &DbConn) -> EmptyResult {
conn.run(move |conn| {
retry(
|| {
diesel::update(users::table.filter(users::uuid.eq(uuid)))
.set(users::updated_at.eq(date))
.execute(conn)
},
10,
)
.map_res("Error updating user revision")
}}
})
.await
}
pub async fn find_by_mail(mail: &str, conn: &DbConn) -> Option<Self> {
let lower_mail = mail.to_lowercase();
db_run! { conn: {
users::table
.filter(users::email.eq(lower_mail))
.first::<Self>(conn)
.ok()
}}
conn.run(move |conn| users::table.filter(users::email.eq(lower_mail)).first::<Self>(conn).ok()).await
}
pub async fn find_by_uuid(uuid: &UserId, conn: &DbConn) -> Option<Self> {
db_run! { conn: {
users::table
.filter(users::uuid.eq(uuid))
.first::<Self>(conn)
.ok()
}}
conn.run(move |conn| users::table.filter(users::uuid.eq(uuid)).first::<Self>(conn).ok()).await
}
pub async fn find_by_device_for_email2fa(device_uuid: &DeviceId, conn: &DbConn) -> Option<Self> {
if let Some(user_uuid) = db_run! ( conn: {
twofactor_incomplete::table
.filter(twofactor_incomplete::device_uuid.eq(device_uuid))
.order_by(twofactor_incomplete::login_time.desc())
.select(twofactor_incomplete::user_uuid)
.first::<UserId>(conn)
.ok()
}) {
if let Some(user_uuid) = conn
.run(move |conn| {
twofactor_incomplete::table
.filter(twofactor_incomplete::device_uuid.eq(device_uuid))
.order_by(twofactor_incomplete::login_time.desc())
.select(twofactor_incomplete::user_uuid)
.first::<UserId>(conn)
.ok()
})
.await
{
return Self::find_by_uuid(&user_uuid, conn).await;
}
None
}
pub async fn get_all(conn: &DbConn) -> Vec<(Self, Option<SsoUser>)> {
db_run! { conn: {
conn.run(move |conn| {
users::table
.left_join(sso_users::table)
.select(<(Self, Option<SsoUser>)>::as_select())
@@ -422,7 +420,8 @@ impl User {
.expect("Error loading groups for user")
.into_iter()
.collect()
}}
})
.await
}
pub async fn last_active(&self, conn: &DbConn) -> Option<NaiveDateTime> {
@@ -467,21 +466,18 @@ impl Invitation {
}
pub async fn delete(self, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
conn.run(move |conn| {
diesel::delete(invitations::table.filter(invitations::email.eq(self.email)))
.execute(conn)
.map_res("Error deleting invitation")
}}
})
.await
}
pub async fn find_by_mail(mail: &str, conn: &DbConn) -> Option<Self> {
let lower_mail = mail.to_lowercase();
db_run! { conn: {
invitations::table
.filter(invitations::email.eq(lower_mail))
.first::<Self>(conn)
.ok()
}}
conn.run(move |conn| invitations::table.filter(invitations::email.eq(lower_mail)).first::<Self>(conn).ok())
.await
}
pub async fn take(mail: &str, conn: &DbConn) -> bool {
@@ -531,34 +527,37 @@ impl SsoUser {
}
pub async fn find_by_identifier(identifier: &str, conn: &DbConn) -> Option<(User, Self)> {
db_run! { conn: {
conn.run(move |conn| {
users::table
.inner_join(sso_users::table)
.select(<(User, Self)>::as_select())
.filter(sso_users::identifier.eq(identifier))
.first::<(User, Self)>(conn)
.ok()
}}
})
.await
}
pub async fn find_by_mail(mail: &str, conn: &DbConn) -> Option<(User, Option<Self>)> {
let lower_mail = mail.to_lowercase();
db_run! { conn: {
conn.run(move |conn| {
users::table
.left_join(sso_users::table)
.select(<(User, Option<Self>)>::as_select())
.filter(users::email.eq(lower_mail))
.first::<(User, Option<Self>)>(conn)
.ok()
}}
})
.await
}
pub async fn delete(user_uuid: &UserId, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
conn.run(move |conn| {
diesel::delete(sso_users::table.filter(sso_users::user_uuid.eq(user_uuid)))
.execute(conn)
.map_res("Error deleting sso user")
}}
})
.await
}
}
+6 -5
View File
@@ -1,6 +1,7 @@
use diesel::connection::{Instrumentation, InstrumentationEvent};
use std::{cell::RefCell, collections::HashMap, time::Instant};
use diesel::connection::{Instrumentation, InstrumentationEvent};
thread_local! {
static QUERY_PERF_TRACKER: RefCell<HashMap<String, Instant>> = RefCell::new(HashMap::new());
}
@@ -11,7 +12,7 @@ pub fn simple_logger() -> Option<Box<dyn Instrumentation>> {
url,
..
} => {
debug!("Establishing connection: {url}")
debug!("Establishing connection: {url}");
}
InstrumentationEvent::FinishEstablishConnection {
url,
@@ -19,9 +20,9 @@ pub fn simple_logger() -> Option<Box<dyn Instrumentation>> {
..
} => {
if let Some(e) = error {
error!("Error during establishing a connection with {url}: {e:?}")
error!("Error during establishing a connection with {url}: {e:?}");
} else {
debug!("Connection established: {url}")
debug!("Connection established: {url}");
}
}
InstrumentationEvent::StartQuery {
@@ -47,7 +48,7 @@ pub fn simple_logger() -> Option<Box<dyn Instrumentation>> {
} else if duration.as_secs() >= 1 {
info!("SLOW QUERY [{:.2}s]: {}", duration.as_secs_f32(), query_string);
} else {
debug!("QUERY [{:?}]: {}", duration, query_string);
debug!("QUERY [{duration:?}]: {query_string}");
}
}
});
+13
View File
@@ -262,9 +262,11 @@ table! {
nonce -> Text,
redirect_uri -> Text,
code_response -> Nullable<Text>,
code_response_error -> Nullable<Text>,
auth_response -> Nullable<Text>,
created_at -> Timestamp,
updated_at -> Timestamp,
binding_hash -> Nullable<Text>,
}
}
@@ -341,6 +343,16 @@ table! {
}
}
table! {
archives (user_uuid, cipher_uuid) {
user_uuid -> Text,
cipher_uuid -> Text,
archived_at -> Timestamp,
}
}
joinable!(archives -> users (user_uuid));
joinable!(archives -> ciphers (cipher_uuid));
joinable!(attachments -> ciphers (cipher_uuid));
joinable!(ciphers -> organizations (organization_uuid));
joinable!(ciphers -> users (user_uuid));
@@ -372,6 +384,7 @@ joinable!(auth_requests -> users (user_uuid));
joinable!(sso_users -> users (user_uuid));
allow_tables_to_appear_in_same_query!(
archives,
attachments,
ciphers,
ciphers_collections,
+48 -44
View File
@@ -1,10 +1,11 @@
//
// Error generator macro
//
use std::error::Error as StdError;
use crate::db::models::EventType;
use crate::http_client::CustomHttpClientError;
use serde::ser::{Serialize, SerializeStruct, Serializer};
use std::error::Error as StdError;
macro_rules! make_error {
( $( $name:ident ( $ty:ty ): $src_fn:expr, $usr_msg_fun:expr ),+ $(,)? ) => {
@@ -14,24 +15,24 @@ macro_rules! make_error {
#[derive(Debug)]
pub struct ErrorEvent { pub event: EventType }
pub struct Error { message: String, error: ErrorKind, error_code: u16, event: Option<ErrorEvent> }
pub struct Error { message: String, kind: ErrorKind, code: u16, event: Option<ErrorEvent> }
$(impl From<$ty> for Error {
fn from(err: $ty) -> Self { Error::from((stringify!($name), err)) }
})+
$(impl<S: Into<String>> From<(S, $ty)> for Error {
fn from(val: (S, $ty)) -> Self {
Error { message: val.0.into(), error: ErrorKind::$name(val.1), error_code: BAD_REQUEST, event: None }
Error { message: val.0.into(), kind: ErrorKind::$name(val.1), code: BAD_REQUEST, event: None }
}
})+
impl StdError for Error {
fn source(&self) -> Option<&(dyn StdError + 'static)> {
match &self.error {$( ErrorKind::$name(e) => $src_fn(e), )+}
match &self.kind {$( ErrorKind::$name(e) => $src_fn(e), )+}
}
}
impl std::fmt::Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match &self.error {$(
match &self.kind {$(
ErrorKind::$name(e) => f.write_str(&$usr_msg_fun(e, &self.message)),
)+}
}
@@ -39,10 +40,10 @@ macro_rules! make_error {
};
}
use diesel::ConnectionError as DieselConErr;
use diesel::r2d2::Error as R2d2Err;
use diesel::r2d2::PoolError as R2d2PoolErr;
use diesel::result::Error as DieselErr;
use diesel::ConnectionError as DieselConErr;
use handlebars::RenderError as HbErr;
use jsonwebtoken::errors::Error as JwtErr;
use lettre::address::AddressError as AddrErr;
@@ -71,46 +72,46 @@ pub struct Compact {}
// The second one contains the function used to obtain the response sent to the client
make_error! {
// Just an empty error
Empty(Empty): _no_source, _serialize,
Empty(Empty): no_source, serialize,
// Used to represent err! calls
Simple(String): _no_source, _api_error,
Compact(Compact): _no_source, _compact_api_error,
Simple(String): no_source, api_error,
Compact(Compact): no_source, compact_api_error,
// Used in our custom http client to handle non-global IPs and blocked domains
CustomHttpClient(CustomHttpClientError): _has_source, _api_error,
CustomHttpClient(CustomHttpClientError): has_source, api_error,
// Used for special return values, like 2FA errors
Json(Value): _no_source, _serialize,
Db(DieselErr): _has_source, _api_error,
R2d2(R2d2Err): _has_source, _api_error,
R2d2Pool(R2d2PoolErr): _has_source, _api_error,
Serde(SerdeErr): _has_source, _api_error,
JWt(JwtErr): _has_source, _api_error,
Handlebars(HbErr): _has_source, _api_error,
Json(Value): no_source, serialize,
Db(DieselErr): has_source, api_error,
R2d2(R2d2Err): has_source, api_error,
R2d2Pool(R2d2PoolErr): has_source, api_error,
Serde(SerdeErr): has_source, api_error,
JWt(JwtErr): has_source, api_error,
Handlebars(HbErr): has_source, api_error,
Io(IoErr): _has_source, _api_error,
Time(TimeErr): _has_source, _api_error,
Req(ReqErr): _has_source, _api_error,
Regex(RegexErr): _has_source, _api_error,
Yubico(YubiErr): _has_source, _api_error,
Io(IoErr): has_source, api_error,
Time(TimeErr): has_source, api_error,
Req(ReqErr): has_source, api_error,
Regex(RegexErr): has_source, api_error,
Yubico(YubiErr): has_source, api_error,
Lettre(LettreErr): _has_source, _api_error,
Address(AddrErr): _has_source, _api_error,
Smtp(SmtpErr): _has_source, _api_error,
OpenSSL(SSLErr): _has_source, _api_error,
Rocket(RocketErr): _has_source, _api_error,
Lettre(LettreErr): has_source, api_error,
Address(AddrErr): has_source, api_error,
Smtp(SmtpErr): has_source, api_error,
OpenSSL(SSLErr): has_source, api_error,
Rocket(RocketErr): has_source, api_error,
DieselCon(DieselConErr): _has_source, _api_error,
Webauthn(WebauthnErr): _has_source, _api_error,
DieselCon(DieselConErr): has_source, api_error,
Webauthn(WebauthnErr): has_source, api_error,
OpenDAL(OpenDALErr): _has_source, _api_error,
OpenDAL(OpenDALErr): has_source, api_error,
}
impl std::fmt::Debug for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self.source() {
Some(e) => write!(f, "{}.\n[CAUSE] {:#?}", self.message, e),
None => match self.error {
None => match self.kind {
ErrorKind::Empty(_) => Ok(()),
ErrorKind::Simple(ref s) => {
if &self.message == s {
@@ -135,6 +136,7 @@ impl Error {
(usr_msg.clone(), usr_msg.into()).into()
}
#[must_use]
pub fn empty() -> Self {
Empty {}.into()
}
@@ -147,13 +149,13 @@ impl Error {
#[must_use]
pub fn with_kind(mut self, kind: ErrorKind) -> Self {
self.error = kind;
self.kind = kind;
self
}
#[must_use]
pub const fn with_code(mut self, code: u16) -> Self {
self.error_code = code;
self.code = code;
self
}
@@ -194,14 +196,14 @@ impl<S> MapResult<S> for Option<S> {
}
}
const fn _has_source<T>(e: T) -> Option<T> {
const fn has_source<T>(e: T) -> Option<T> {
Some(e)
}
fn _no_source<T, S>(_: T) -> Option<S> {
fn no_source<T, S>(_: T) -> Option<S> {
None
}
fn _serialize(e: &impl Serialize, _msg: &str) -> String {
fn serialize(e: &impl Serialize, _msg: &str) -> String {
serde_json::to_string(e).unwrap()
}
@@ -280,14 +282,14 @@ struct ApiErrorResponse<'a>(ApiErrorMsg<'a>);
/// The custom serialization adds all other needed fields
struct CompactApiErrorResponse<'a>(ApiErrorMsg<'a>);
fn _api_error(_: &impl std::any::Any, msg: &str) -> String {
fn api_error(_: &impl std::any::Any, msg: &str) -> String {
let response = ApiErrorMsg {
message: msg,
};
serde_json::to_string(&ApiErrorResponse(response)).unwrap()
}
fn _compact_api_error(_: &impl std::any::Any, msg: &str) -> String {
fn compact_api_error(_: &impl std::any::Any, msg: &str) -> String {
let response = ApiErrorMsg {
message: msg,
};
@@ -299,18 +301,20 @@ fn _compact_api_error(_: &impl std::any::Any, msg: &str) -> String {
//
use std::io::Cursor;
use rocket::http::{ContentType, Status};
use rocket::request::Request;
use rocket::response::{self, Responder, Response};
use rocket::{
http::{ContentType, Status},
request::Request,
response::{self, Responder, Response},
};
impl Responder<'_, 'static> for Error {
fn respond_to(self, _: &Request<'_>) -> response::Result<'static> {
match self.error {
match self.kind {
ErrorKind::Empty(_) | ErrorKind::Simple(_) | ErrorKind::Compact(_) => {} // Don't print the error in this situation
_ => error!(target: "error", "{self:#?}"),
};
}
let code = Status::from_code(self.error_code).unwrap_or(Status::BadRequest);
let code = Status::from_code(self.code).unwrap_or(Status::BadRequest);
let body = self.to_string();
Response::build().status(code).header(ContentType::JSON).sized_body(Some(body.len()), Cursor::new(body)).ok()
}
+290 -39
View File
@@ -1,22 +1,25 @@
use std::{
fmt,
net::{IpAddr, SocketAddr},
str::FromStr,
sync::{Arc, LazyLock, Mutex},
time::Duration,
};
use hickory_resolver::{net::runtime::TokioRuntimeProvider, TokioResolver};
use hickory_resolver::{TokioResolver, net::runtime::TokioRuntimeProvider};
use regex::Regex;
use reqwest::{
Client, ClientBuilder,
dns::{Name, Resolve, Resolving},
header, Client, ClientBuilder,
header,
};
use url::Host;
use crate::{util::is_global, CONFIG};
use crate::{CONFIG, util::is_global};
pub fn make_http_request(method: reqwest::Method, url: &str) -> Result<reqwest::RequestBuilder, crate::Error> {
static INSTANCE: LazyLock<Client> =
LazyLock::new(|| get_reqwest_client_builder().build().expect("Failed to build client"));
let Ok(url) = url::Url::parse(url) else {
err!("Invalid URL");
};
@@ -26,9 +29,6 @@ pub fn make_http_request(method: reqwest::Method, url: &str) -> Result<reqwest::
should_block_host(&host)?;
static INSTANCE: LazyLock<Client> =
LazyLock::new(|| get_reqwest_client_builder().build().expect("Failed to build client"));
Ok(INSTANCE.request(method, url))
}
@@ -59,16 +59,6 @@ pub fn get_reqwest_client_builder() -> ClientBuilder {
.timeout(Duration::from_secs(10))
}
pub fn should_block_address(domain_or_ip: &str) -> bool {
if let Ok(ip) = IpAddr::from_str(domain_or_ip) {
if should_block_ip(ip) {
return true;
}
}
should_block_address_regex(domain_or_ip)
}
fn should_block_ip(ip: IpAddr) -> bool {
if !CONFIG.http_request_block_non_global_ips() {
return false;
@@ -78,18 +68,19 @@ fn should_block_ip(ip: IpAddr) -> bool {
}
fn should_block_address_regex(domain_or_ip: &str) -> bool {
static COMPILED_REGEX: Mutex<Option<(String, Regex)>> = Mutex::new(None);
let Some(block_regex) = CONFIG.http_request_block_regex() else {
return false;
};
static COMPILED_REGEX: Mutex<Option<(String, Regex)>> = Mutex::new(None);
let mut guard = COMPILED_REGEX.lock().unwrap();
// If the stored regex is up to date, use it
if let Some((value, regex)) = &*guard {
if value == &block_regex {
return regex.is_match(domain_or_ip);
}
if let Some((value, regex)) = &*guard
&& value == &block_regex
{
return regex.is_match(domain_or_ip);
}
// If we don't have a regex stored, or it's not up to date, recreate it
@@ -100,20 +91,63 @@ fn should_block_address_regex(domain_or_ip: &str) -> bool {
is_match
}
fn should_block_host(host: &Host<&str>) -> Result<(), CustomHttpClientError> {
pub fn get_valid_host(host: &str) -> Result<Host, CustomHttpClientError> {
let Ok(host) = Host::parse(host) else {
return Err(CustomHttpClientError::Invalid {
domain: host.to_owned(),
});
};
// Some extra checks to validate hosts
match host {
Host::Domain(ref domain) => {
// Host::parse() does not verify length or all possible invalid characters
// We do some extra checks here to prevent issues
if domain.len() > 253 {
debug!("Domain validation error: '{domain}' exceeds 253 characters");
return Err(CustomHttpClientError::Invalid {
domain: host.to_string(),
});
}
if !domain.split('.').all(|label| {
!label.is_empty()
// Labels can't be longer than 63 chars
&& label.len() <= 63
// Labels are not allowed to start or end with a hyphen `-`
&& !label.starts_with('-')
&& !label.ends_with('-')
// Only ASCII Alphanumeric characters are allowed
// We already received a punycoded domain back, so no unicode should exists here
&& label.chars().all(|c| c.is_ascii_alphanumeric() || c == '-')
}) {
debug!(
"Domain validation error: '{domain}' labels contain invalid characters or exceed the maximum length"
);
return Err(CustomHttpClientError::Invalid {
domain: host.to_string(),
});
}
}
Host::Ipv4(_) | Host::Ipv6(_) => {}
}
Ok(host)
}
pub fn should_block_host<S: AsRef<str>>(host: &Host<S>) -> Result<(), CustomHttpClientError> {
let (ip, host_str): (Option<IpAddr>, String) = match host {
Host::Ipv4(ip) => (Some(IpAddr::V4(*ip)), ip.to_string()),
Host::Ipv6(ip) => (Some(IpAddr::V6(*ip)), ip.to_string()),
Host::Domain(d) => (None, (*d).to_string()),
Host::Domain(d) => (None, d.as_ref().to_owned()),
};
if let Some(ip) = ip {
if should_block_ip(ip) {
return Err(CustomHttpClientError::NonGlobalIp {
domain: None,
ip,
});
}
if let Some(ip) = ip
&& should_block_ip(ip)
{
return Err(CustomHttpClientError::NonGlobalIp {
domain: None,
ip,
});
}
if should_block_address_regex(&host_str) {
@@ -134,6 +168,9 @@ pub enum CustomHttpClientError {
domain: Option<String>,
ip: IpAddr,
},
Invalid {
domain: String,
},
}
impl CustomHttpClientError {
@@ -155,7 +192,7 @@ impl fmt::Display for CustomHttpClientError {
match self {
Self::Blocked {
domain,
} => write!(f, "Blocked domain: {domain} matched HTTP_REQUEST_BLOCK_REGEX"),
} => write!(f, "Blocked domain: '{domain}' matched HTTP_REQUEST_BLOCK_REGEX"),
Self::NonGlobalIp {
domain: Some(domain),
ip,
@@ -163,7 +200,10 @@ impl fmt::Display for CustomHttpClientError {
Self::NonGlobalIp {
domain: None,
ip,
} => write!(f, "IP {ip} is not a global IP!"),
} => write!(f, "IP '{ip}' is not a global IP!"),
Self::Invalid {
domain,
} => write!(f, "Invalid host: '{domain}' contains invalid characters or exceeds the maximum length"),
}
}
}
@@ -195,8 +235,7 @@ impl CustomDnsResolver {
builder.build()
})
.inspect_err(|e| warn!("Error creating Hickory resolver, falling back to default: {e:?}"))
.map(|resolver| Arc::new(Self::Hickory(Arc::new(resolver))))
.unwrap_or_else(|_| Arc::new(Self::Default()))
.map_or_else(|_| Arc::new(Self::Default()), |resolver| Arc::new(Self::Hickory(Arc::new(resolver))))
}
// Note that we get an iterator of addresses, but we only grab the first one for convenience
@@ -217,9 +256,15 @@ impl CustomDnsResolver {
}
fn pre_resolve(name: &str) -> Result<(), CustomHttpClientError> {
if should_block_address(name) {
let Ok(host) = get_valid_host(name) else {
return Err(CustomHttpClientError::Invalid {
domain: name.to_owned(),
});
};
if should_block_host(&host).is_err() {
return Err(CustomHttpClientError::Blocked {
domain: name.to_string(),
domain: name.to_owned(),
});
}
@@ -229,7 +274,7 @@ fn pre_resolve(name: &str) -> Result<(), CustomHttpClientError> {
fn post_resolve(name: &str, ip: IpAddr) -> Result<(), CustomHttpClientError> {
if should_block_ip(ip) {
Err(CustomHttpClientError::NonGlobalIp {
domain: Some(name.to_string()),
domain: Some(name.to_owned()),
ip,
})
} else {
@@ -274,7 +319,7 @@ pub(crate) mod aws {
let future = async move {
let method = reqwest::Method::from_bytes(request.method().as_bytes())
.map_err(|e| ConnectorError::user(Box::new(e)))?;
let mut req_builder = client.request(method, request.uri().to_string());
let mut req_builder = client.request(method, request.uri().to_owned());
for (name, value) in request.headers() {
req_builder = req_builder.header(name, value);
@@ -308,3 +353,209 @@ pub(crate) mod aws {
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::util::is_global_hardcoded;
use std::net::Ipv4Addr;
use url::Host;
// ===
// IPv4 numeric-format normalization
fn parse_to_ip(s: &str) -> Option<IpAddr> {
match Host::parse(s).ok()? {
Host::Ipv4(v4) => Some(IpAddr::V4(v4)),
Host::Ipv6(v6) => Some(IpAddr::V6(v6)),
Host::Domain(_) => None,
}
}
#[test]
fn dotted_decimal_loopback_normalizes() {
let ip = parse_to_ip("127.0.0.1").unwrap();
assert_eq!(ip, IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)));
assert!(!is_global_hardcoded(ip));
}
#[test]
fn single_decimal_loopback_normalizes() {
// 127.0.0.1 == 2130706433
let ip = parse_to_ip("2130706433").unwrap();
assert_eq!(ip, IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)));
assert!(!is_global_hardcoded(ip));
}
#[test]
fn hex_loopback_normalizes() {
let ip = parse_to_ip("0x7f000001").unwrap();
assert_eq!(ip, IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)));
assert!(!is_global_hardcoded(ip));
}
#[test]
fn dotted_hex_loopback_normalizes() {
let ip = parse_to_ip("0x7f.0.0.1").unwrap();
assert_eq!(ip, IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)));
assert!(!is_global_hardcoded(ip));
}
#[test]
fn octal_loopback_normalizes() {
// 017700000001 == 127.0.0.1
let ip = parse_to_ip("017700000001").unwrap();
assert_eq!(ip, IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)));
assert!(!is_global_hardcoded(ip));
}
#[test]
fn dotted_octal_loopback_normalizes() {
let ip = parse_to_ip("0177.0.0.01").unwrap();
assert_eq!(ip, IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)));
assert!(!is_global_hardcoded(ip));
}
#[test]
fn aws_metadata_decimal_blocked() {
// 169.254.169.254 == 2852039166 (link-local, AWS IMDS)
let ip = parse_to_ip("2852039166").unwrap();
assert_eq!(ip, IpAddr::V4(Ipv4Addr::new(169, 254, 169, 254)));
assert!(!is_global_hardcoded(ip));
}
#[test]
fn rfc1918_hex_blocked() {
// 10.0.0.1
let ip = parse_to_ip("0x0a000001").unwrap();
assert!(!is_global_hardcoded(ip));
}
#[test]
fn public_ip_decimal_allowed() {
// 8.8.8.8 == 134744072
let ip = parse_to_ip("134744072").unwrap();
assert_eq!(ip, IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)));
assert!(is_global_hardcoded(ip));
}
// ===
// get_valid_host integration: numeric forms become Host::Ipv4
#[test]
fn get_valid_host_normalizes_decimal_int() {
let h = get_valid_host("2130706433").expect("valid");
assert!(matches!(h, Host::Ipv4(ip) if ip == Ipv4Addr::new(127, 0, 0, 1)));
}
#[test]
fn get_valid_host_normalizes_hex() {
let h = get_valid_host("0x7f000001").expect("valid");
assert!(matches!(h, Host::Ipv4(ip) if ip == Ipv4Addr::new(127, 0, 0, 1)));
}
#[test]
fn get_valid_host_normalizes_octal() {
let h = get_valid_host("017700000001").expect("valid");
assert!(matches!(h, Host::Ipv4(ip) if ip == Ipv4Addr::new(127, 0, 0, 1)));
}
// ===
// IPv6 formats
#[test]
fn ipv6_loopback_blocked() {
let h = get_valid_host("[::1]").expect("valid");
let Host::Ipv6(ip) = h else {
panic!("expected v6")
};
assert!(!is_global_hardcoded(IpAddr::V6(ip)));
}
#[test]
fn ipv4_mapped_in_ipv6_loopback_blocked() {
// ::ffff:127.0.0.1 — v4-mapped form; is_global_hardcoded blocks via ::ffff:0:0/96
let h = get_valid_host("[::ffff:127.0.0.1]").expect("valid");
let Host::Ipv6(ip) = h else {
panic!("expected v6")
};
assert!(!is_global_hardcoded(IpAddr::V6(ip)));
}
#[test]
fn ipv6_unique_local_blocked() {
let h = get_valid_host("[fc00::1]").expect("valid");
let Host::Ipv6(ip) = h else {
panic!("expected v6")
};
assert!(!is_global_hardcoded(IpAddr::V6(ip)));
}
// ===
// Punycode / IDN
#[test]
fn punycode_passthrough() {
let h = get_valid_host("xn--deadbeafcaf-lbb.test").expect("valid");
match h {
Host::Domain(d) => assert_eq!(d, "xn--deadbeafcaf-lbb.test"),
_ => panic!("expected domain"),
}
}
#[test]
fn idn_unicode_gets_punycoded() {
let h = get_valid_host("deadbeafcafé.test").expect("valid");
match h {
Host::Domain(d) => assert_eq!(d, "xn--deadbeafcaf-lbb.test"),
_ => panic!("expected domain"),
}
}
#[test]
fn idn_unicode_gets_punycoded_tld() {
let h = get_valid_host("deadbeaf.café").expect("valid");
match h {
Host::Domain(d) => assert_eq!(d, "deadbeaf.xn--caf-dma"),
_ => panic!("expected domain"),
}
}
#[test]
fn idn_emoji_gets_punycoded() {
let h = get_valid_host("xn--t88h.test").expect("valid"); // 🛡️.test
match h {
Host::Domain(d) => assert_eq!(d, "xn--t88h.test"),
_ => panic!("expected domain"),
}
}
#[test]
fn idn_unicode_to_punycode_roundtrip() {
let from_unicode = get_valid_host("🛡️.test").expect("valid");
let from_puny = get_valid_host("xn--t88h.test").expect("valid");
match (from_unicode, from_puny) {
(Host::Domain(a), Host::Domain(b)) => assert_eq!(a, b),
_ => panic!("expected domains"),
}
}
#[test]
fn invalid_punycode_rejected() {
// bare invalid punycode
assert!(get_valid_host("xn--").is_err());
}
#[test]
fn underscore_in_label_rejected() {
assert!(get_valid_host("dead_beaf.cafe").is_err());
}
#[test]
fn label_too_long_rejected() {
let label = "a".repeat(64);
assert!(get_valid_host(&format!("{label}.test")).is_err());
}
#[test]
fn domain_too_long_rejected() {
let big = "a.".repeat(130) + "test"; // > 253
assert!(get_valid_host(&big).is_err());
}
}
+26 -30
View File
@@ -1,16 +1,17 @@
use chrono::NaiveDateTime;
use percent_encoding::{percent_encode, NON_ALPHANUMERIC};
use std::{env::consts::EXE_SUFFIX, str::FromStr};
use chrono::NaiveDateTime;
use lettre::{
Address, AsyncSendmailTransport, AsyncSmtpTransport, AsyncTransport, Tokio1Executor,
message::{Attachment, Body, Mailbox, Message, MultiPart, SinglePart},
transport::smtp::authentication::{Credentials, Mechanism as SmtpAuthMechanism},
transport::smtp::client::{Tls, TlsParameters},
transport::smtp::extension::ClientId,
Address, AsyncSendmailTransport, AsyncSmtpTransport, AsyncTransport, Tokio1Executor,
};
use percent_encoding::{NON_ALPHANUMERIC, percent_encode};
use crate::{
CONFIG,
api::EmptyResult,
auth::{
encode_jwt, generate_delete_claims, generate_emergency_access_invite_claims, generate_invite_claims,
@@ -18,7 +19,7 @@ use crate::{
},
db::models::{Device, DeviceType, EmergencyAccessId, MembershipId, OrganizationId, User, UserId},
error::Error,
CONFIG,
util::upcase_first,
};
fn sendmail_transport() -> AsyncSendmailTransport<Tokio1Executor> {
@@ -38,7 +39,9 @@ fn smtp_transport() -> AsyncSmtpTransport<Tokio1Executor> {
.timeout(Some(Duration::from_secs(CONFIG.smtp_timeout())));
// Determine security
let smtp_client = if CONFIG.smtp_security() != *"off" {
let smtp_client = if CONFIG.smtp_security() == *"off" {
smtp_client
} else {
let mut tls_parameters = TlsParameters::builder(host);
if CONFIG.smtp_accept_invalid_hostnames() {
tls_parameters = tls_parameters.dangerous_accept_invalid_hostnames(true);
@@ -53,8 +56,6 @@ fn smtp_transport() -> AsyncSmtpTransport<Tokio1Executor> {
} else {
smtp_client.tls(Tls::Required(tls_parameters))
}
} else {
smtp_client
};
let smtp_client = match (CONFIG.smtp_username(), CONFIG.smtp_password()) {
@@ -81,12 +82,12 @@ fn smtp_transport() -> AsyncSmtpTransport<Tokio1Executor> {
}
}
if !selected_mechanisms.is_empty() {
smtp_client.authentication(selected_mechanisms)
} else {
if selected_mechanisms.is_empty() {
// Only show a warning, and return without setting an actual authentication mechanism
warn!("No valid SMTP Auth mechanism found for '{mechanism}', using default values");
smtp_client
} else {
smtp_client.authentication(selected_mechanisms)
}
}
_ => smtp_client,
@@ -129,14 +130,16 @@ fn get_template(template_name: &str, data: &serde_json::Value) -> Result<(String
let text = CONFIG.render_template(template_name, data)?;
let mut text_split = text.split("<!---------------->");
let subject = match text_split.next() {
Some(s) => s.trim().to_string(),
None => err!("Template doesn't contain subject"),
let subject = if let Some(s) = text_split.next() {
s.trim().to_owned()
} else {
err!("Template doesn't contain subject")
};
let body = match text_split.next() {
Some(s) => s.trim().to_string(),
None => err!("Template doesn't contain body"),
let body = if let Some(s) = text_split.next() {
s.trim().to_owned()
} else {
err!("Template doesn't contain body")
};
if text_split.next().is_some() {
@@ -204,9 +207,8 @@ pub async fn send_verify_email(address: &str, user_id: &UserId) -> EmptyResult {
pub async fn send_register_verify_email(email: &str, token: &str) -> EmptyResult {
let mut query = url::Url::parse("https://query.builder").unwrap();
query.query_pairs_mut().append_pair("email", email).append_pair("token", token);
let query_string = match query.query() {
None => err!("Failed to build verify URL query parameters"),
Some(query) => query,
let Some(query_string) = query.query() else {
err!("Failed to build verify URL query parameters")
};
let (subject, body_html, body_text) = get_text(
@@ -504,8 +506,6 @@ pub async fn send_invite_confirmed(address: &str, org_name: &str) -> EmptyResult
}
pub async fn send_new_device_logged_in(address: &str, ip: &str, dt: &NaiveDateTime, device: &Device) -> EmptyResult {
use crate::util::upcase_first;
let fmt = "%A, %B %_d, %Y at %r %Z";
let (subject, body_html, body_text) = get_text(
"email/new_device_logged_in",
@@ -529,8 +529,6 @@ pub async fn send_incomplete_2fa_login(
device_name: &str,
device_type: &str,
) -> EmptyResult {
use crate::util::upcase_first;
let fmt = "%A, %B %_d, %Y at %r %Z";
let (subject, body_html, body_text) = get_text(
"email/incomplete_2fa_login",
@@ -655,7 +653,7 @@ pub async fn send_protected_action_token(address: &str, token: &str) -> EmptyRes
async fn send_with_selected_transport(email: Message) -> EmptyResult {
if CONFIG.use_sendmail() {
match sendmail_transport().send(email).await {
Ok(_) => Ok(()),
Ok(()) => Ok(()),
// Match some common errors and make them more user friendly
Err(e) => {
if e.is_client() {
@@ -664,10 +662,9 @@ async fn send_with_selected_transport(email: Message) -> EmptyResult {
} else if e.is_response() {
debug!("Sendmail response error: {e:?}");
err!(format!("Sendmail response error: {e}"));
} else {
debug!("Sendmail error: {e:?}");
err!(format!("Sendmail error: {e}"));
}
debug!("Sendmail error: {e:?}");
err!(format!("Sendmail error: {e}"));
}
}
} else {
@@ -695,10 +692,9 @@ async fn send_with_selected_transport(email: Message) -> EmptyResult {
} else if e.is_tls() {
debug!("SMTP encryption error: {e:#?}");
err!(format!("SMTP encryption error: {e}"));
} else {
debug!("SMTP error: {e:#?}");
err!(format!("SMTP error: {e}"));
}
debug!("SMTP error: {e:#?}");
err!(format!("SMTP error: {e}"));
}
}
}
+39 -31
View File
@@ -33,6 +33,7 @@ use std::{
path::Path,
process::exit,
str::FromStr,
sync::{Arc, atomic::Ordering},
thread,
};
@@ -44,6 +45,8 @@ use tokio::{
#[cfg(unix)]
use tokio::signal::unix::SignalKind;
use rocket::data::{Limits, ToByteUnit};
#[macro_use]
mod error;
mod api;
@@ -57,19 +60,19 @@ mod mail;
mod ratelimit;
mod sso;
mod sso_client;
mod storage;
mod util;
use crate::api::core::two_factor::duo_oidc::purge_duo_contexts;
use crate::api::purge_auth_requests;
use crate::api::{WS_ANONYMOUS_SUBSCRIPTIONS, WS_USERS};
pub use config::{PathType, CONFIG};
use crate::api::{
WS_ANONYMOUS_SUBSCRIPTIONS, WS_USERS, core::two_factor::duo_oidc::purge_duo_contexts, purge_auth_requests,
};
pub use config::{CONFIG, PathType};
pub use error::{Error, MapResult};
use rocket::data::{Limits, ToByteUnit};
use std::sync::{atomic::Ordering, Arc};
pub use util::is_running_in_container;
#[rocket::main]
async fn main() -> Result<(), Error> {
install_rustls_crypto_provider();
parse_args();
launch_info();
@@ -135,26 +138,23 @@ fn parse_args() {
if let Some(command) = pargs.subcommand().unwrap_or_default() {
if command == "hash" {
use argon2::{
password_hash::SaltString, Algorithm::Argon2id, Argon2, ParamsBuilder, PasswordHasher, Version::V0x13,
Algorithm::Argon2id, Argon2, ParamsBuilder, PasswordHasher, Version::V0x13, password_hash::SaltString,
};
let mut argon2_params = ParamsBuilder::new();
let preset: Option<String> = pargs.opt_value_from_str(["-p", "--preset"]).unwrap_or_default();
let selected_preset;
match preset.as_deref() {
Some("owasp") => {
selected_preset = "owasp";
argon2_params.m_cost(19456);
argon2_params.t_cost(2);
argon2_params.p_cost(1);
}
_ => {
// Bitwarden preset is the default
selected_preset = "bitwarden";
argon2_params.m_cost(65540);
argon2_params.t_cost(3);
argon2_params.p_cost(4);
}
if preset.as_deref() == Some("owasp") {
selected_preset = "owasp";
argon2_params.m_cost(19456);
argon2_params.t_cost(2);
argon2_params.p_cost(1);
} else {
// Bitwarden preset is the default
selected_preset = "bitwarden";
argon2_params.m_cost(65540);
argon2_params.t_cost(3);
argon2_params.p_cost(4);
}
println!("Generate an Argon2id PHC string using the '{selected_preset}' preset:\n");
@@ -202,6 +202,14 @@ fn parse_args() {
}
}
fn install_rustls_crypto_provider() {
if rustls::crypto::CryptoProvider::get_default().is_none() {
rustls::crypto::ring::default_provider()
.install_default()
.expect("failed to install rustls ring crypto provider");
}
}
fn launch_info() {
println!(
"\
@@ -237,7 +245,7 @@ fn init_logging() -> Result<log::LevelFilter, Error> {
let level = caps
.get(1)
.and_then(|m| log::LevelFilter::from_str(m.as_str()).ok())
.ok_or(Error::new("Failed to parse global log level".to_string(), ""))?;
.ok_or(Error::new("Failed to parse global log level".to_owned(), ""))?;
let levels_override: Vec<(&str, log::LevelFilter)> = caps
.get(2)
@@ -246,13 +254,13 @@ fn init_logging() -> Result<log::LevelFilter, Error> {
.split(',')
.collect::<Vec<&str>>()
.into_iter()
.flat_map(|s| match s.split_once('=') {
.filter_map(|s| match s.split_once('=') {
Some((log, lvl_str)) => log::LevelFilter::from_str(lvl_str).ok().map(|lvl| (log, lvl)),
_ => None,
})
.collect()
})
.ok_or(Error::new("Failed to parse overrides".to_string(), ""))?;
.ok_or(Error::new("Failed to parse overrides".to_owned(), ""))?;
(level, levels_override)
} else {
@@ -328,7 +336,7 @@ fn init_logging() -> Result<log::LevelFilter, Error> {
("vaultwarden::db::query_logger", log::LevelFilter::Off),
]);
for (path, level) in levels_override.into_iter() {
for (path, level) in levels_override {
let _ = default_levels.insert(path, level);
}
@@ -342,7 +350,7 @@ fn init_logging() -> Result<log::LevelFilter, Error> {
let mut logger = fern::Dispatch::new().level(level).chain(std::io::stdout());
for (path, level) in default_levels {
logger = logger.level_for(path.to_string(), level);
logger = logger.level_for(path.to_owned(), level);
}
if CONFIG.extended_logging() {
@@ -353,7 +361,7 @@ fn init_logging() -> Result<log::LevelFilter, Error> {
record.target(),
record.level(),
message
))
));
});
} else {
logger = logger.format(|out, message, _| out.finish(format_args!("{message}")));
@@ -599,9 +607,7 @@ async fn launch_rocket(pool: db::DbPool, extra_debug: bool) -> Result<(), Error>
#[cfg(all(unix, sqlite))]
{
if db::ACTIVE_DB_TYPE.get() != Some(&db::DbConnType::Sqlite) {
debug!("PostgreSQL and MySQL/MariaDB do not support this backup feature, skip adding USR1 signal.");
} else {
if db::ACTIVE_DB_TYPE.get() == Some(&db::DbConnType::Sqlite) {
tokio::spawn(async move {
let mut signal_user1 = tokio::signal::unix::signal(SignalKind::user_defined1()).unwrap();
loop {
@@ -614,6 +620,8 @@ async fn launch_rocket(pool: db::DbPool, extra_debug: bool) -> Result<(), Error>
}
}
});
} else {
debug!("PostgreSQL and MySQL/MariaDB do not support this backup feature, skip adding USR1 signal.");
}
}
@@ -661,7 +669,7 @@ fn schedule_jobs(pool: db::DbPool) {
let runtime = tokio::runtime::Runtime::new().unwrap();
thread::Builder::new()
.name("job-scheduler".to_string())
.name("job-scheduler".to_owned())
.spawn(move || {
use job_scheduler_ng::{Job, JobScheduler};
let _runtime_guard = runtime.enter();
+4 -4
View File
@@ -1,8 +1,8 @@
use std::{net::IpAddr, num::NonZeroU32, sync::LazyLock, time::Duration};
use governor::{clock::DefaultClock, state::keyed::DashMapStateStore, Quota, RateLimiter};
use governor::{Quota, RateLimiter, clock::DefaultClock, state::keyed::DashMapStateStore};
use crate::{Error, CONFIG};
use crate::{CONFIG, Error};
type Limiter<T = IpAddr> = RateLimiter<T, DashMapStateStore<T>, DefaultClock>;
@@ -20,7 +20,7 @@ static LIMITER_ADMIN: LazyLock<Limiter> = LazyLock::new(|| {
pub fn check_limit_login(ip: &IpAddr) -> Result<(), Error> {
match LIMITER_LOGIN.check_key(ip) {
Ok(_) => Ok(()),
Ok(()) => Ok(()),
Err(_e) => {
err_code!("Too many login requests", 429);
}
@@ -29,7 +29,7 @@ pub fn check_limit_login(ip: &IpAddr) -> Result<(), Error> {
pub fn check_limit_admin(ip: &IpAddr) -> Result<(), Error> {
match LIMITER_ADMIN.check_key(ip) {
Ok(_) => Ok(()),
Ok(()) => Ok(()),
Err(_e) => {
err_code!("Too many admin requests", 429);
}
+48 -43
View File
@@ -6,18 +6,18 @@ use regex::Regex;
use url::Url;
use crate::{
CONFIG,
api::ApiResult,
auth,
auth::{AuthMethod, AuthTokens, TokenWrapper, BW_EXPIRATION, DEFAULT_REFRESH_VALIDITY},
auth::{AuthMethod, AuthTokens, BW_EXPIRATION, DEFAULT_REFRESH_VALIDITY, TokenWrapper},
db::{
models::{Device, OIDCAuthenticatedUser, OIDCCodeWrapper, SsoAuth, SsoUser, User},
DbConn,
models::{Device, OIDCAuthenticatedUser, SsoAuth, SsoUser, User},
},
sso_client::Client,
CONFIG,
};
pub static FAKE_SSO_IDENTIFIER: &str = "vaultwarden-dummy-oidc-identifier";
pub static FAKE_SSO_IDENTIFIER: &str = "00000000-01DC-01DC-01DC-000000000000";
static SSO_JWT_ISSUER: LazyLock<String> = LazyLock::new(|| format!("{}|sso", CONFIG.domain_origin()));
@@ -123,7 +123,7 @@ pub fn encode_ssotoken_claims() -> String {
nbf: time_now.timestamp(),
exp: (time_now + chrono::TimeDelta::try_minutes(2).unwrap()).timestamp(),
iss: SSO_JWT_ISSUER.to_string(),
sub: "vaultwarden".to_string(),
sub: "vaultwarden".to_owned(),
};
auth::encode_jwt(&claims)
@@ -171,12 +171,14 @@ fn decode_token_claims(token_name: &str, token: &str) -> ApiResult<BasicTokenCla
}
pub fn decode_state(base64_state: &str) -> ApiResult<OIDCState> {
let state = match data_encoding::BASE64.decode(base64_state.as_bytes()) {
Ok(vec) => match String::from_utf8(vec) {
Ok(valid) => OIDCState(valid),
Err(_) => err!(format!("Invalid utf8 chars in {base64_state} after base64 decoding")),
},
Err(_) => err!(format!("Failed to decode {base64_state} using base64")),
let state = if let Ok(vec) = data_encoding::BASE64.decode(base64_state.as_bytes()) {
if let Ok(valid) = String::from_utf8(vec) {
OIDCState(valid)
} else {
err!(format!("Invalid utf8 chars in {base64_state} after base64 decoding"))
}
} else {
err!(format!("Failed to decode {base64_state} using base64"))
};
Ok(state)
@@ -188,22 +190,26 @@ pub async fn authorize_url(
client_challenge: OIDCCodeChallenge,
client_id: &str,
raw_redirect_uri: &str,
binding_hash: Option<String>,
conn: DbConn,
) -> ApiResult<Url> {
let redirect_uri = match client_id {
"web" | "browser" => format!("{}/sso-connector.html", CONFIG.domain()),
"desktop" | "mobile" => "bitwarden://sso-callback".to_string(),
"desktop" | "mobile" => "bitwarden://sso-callback".to_owned(),
"cli" => {
let port_regex = Regex::new(r"^http://localhost:([0-9]{4})$").unwrap();
match port_regex.captures(raw_redirect_uri).and_then(|captures| captures.get(1).map(|c| c.as_str())) {
Some(port) => format!("http://localhost:{port}"),
None => err!("Failed to extract port number"),
if let Some(port) =
port_regex.captures(raw_redirect_uri).and_then(|captures| captures.get(1).map(|c| c.as_str()))
{
format!("http://localhost:{port}")
} else {
err!("Failed to extract port number")
}
}
_ => err!(format!("Unsupported client {client_id}")),
};
let (auth_url, sso_auth) = Client::authorize_url(state, client_challenge, redirect_uri).await?;
let (auth_url, sso_auth) = Client::authorize_url(state, client_challenge, redirect_uri, binding_hash).await?;
sso_auth.save(&conn).await?;
Ok(auth_url)
}
@@ -239,33 +245,32 @@ impl OIDCIdentifier {
// - second time we will rely on `SsoAuth.auth_response` since the `code` has already been exchanged.
// The `SsoAuth` will ensure that the user is authorized only once.
pub async fn exchange_code(
state: &OIDCState,
code: &OIDCCode,
client_verifier: OIDCCodeVerifier,
conn: &DbConn,
) -> ApiResult<(SsoAuth, OIDCAuthenticatedUser)> {
use openidconnect::OAuth2TokenResponse;
let mut sso_auth = match SsoAuth::find(state, conn).await {
None => err!(format!("Invalid state cannot retrieve sso auth")),
Some(sso_auth) => sso_auth,
let Some(mut sso_auth) = SsoAuth::find_by_code(code, conn).await else {
err!("Invalid code cannot retrieve sso auth")
};
if let Some(authenticated_user) = sso_auth.auth_response.clone() {
return Ok((sso_auth, authenticated_user));
}
let code = match sso_auth.code_response.clone() {
Some(OIDCCodeWrapper::Ok {
code,
}) => code.clone(),
Some(OIDCCodeWrapper::Error {
error,
error_description,
}) => {
let code = match (sso_auth.code_response.clone(), sso_auth.code_response_error.as_ref()) {
(Some(code), None) => code,
(_, Some(re)) => {
let error_msg = format!(
"SSO authorization failed: {}, {}",
re.error,
re.error_description.as_ref().unwrap_or(&String::new())
);
sso_auth.delete(conn).await?;
err!(format!("SSO authorization failed: {error}, {}", error_description.as_ref().unwrap_or(&String::new())))
err!(error_msg);
}
None => {
(None, _) => {
sso_auth.delete(conn).await?;
err!("Missing authorization provider return");
}
@@ -283,10 +288,10 @@ pub async fn exchange_code(
let email_verified = id_claims.email_verified().or(user_info.email_verified());
let user_name = id_claims.preferred_username().map(|un| un.to_string());
let user_name = id_claims.preferred_username().or(user_info.preferred_username()).map(|un| un.to_string());
let refresh_token = token_response.refresh_token().map(|t| t.secret());
if refresh_token.is_none() && CONFIG.sso_scopes_vec().contains(&"offline_access".to_string()) {
let refresh_token = token_response.refresh_token().map(openidconnect::RefreshToken::secret);
if refresh_token.is_none() && CONFIG.sso_scopes_vec().contains(&"offline_access".to_owned()) {
error!("Scope offline_access is present but response contain no refresh_token");
}
@@ -330,7 +335,9 @@ pub async fn redeem(
user_sso.save(conn).await?;
}
if !CONFIG.sso_auth_only_not_session() {
if CONFIG.sso_auth_only_not_session() {
Ok(AuthTokens::new(device, user, AuthMethod::Sso, client_id))
} else {
let now = Utc::now();
let (ap_nbf, ap_exp) =
@@ -343,9 +350,7 @@ pub async fn redeem(
let access_claims =
auth::LoginJwtClaims::new(device, user, ap_nbf, ap_exp, AuthMethod::Sso.scope_vec(), client_id, now);
_create_auth_tokens(device, auth_user.refresh_token, access_claims, auth_user.access_token)
} else {
Ok(AuthTokens::new(device, user, AuthMethod::Sso, client_id))
create_auth_tokens_impl(device, auth_user.refresh_token, access_claims, auth_user.access_token)
}
}
@@ -359,7 +364,9 @@ pub fn create_auth_tokens(
access_token: String,
expires_in: Option<Duration>,
) -> ApiResult<AuthTokens> {
if !CONFIG.sso_auth_only_not_session() {
if CONFIG.sso_auth_only_not_session() {
Ok(AuthTokens::new(device, user, AuthMethod::Sso, client_id))
} else {
let now = Utc::now();
let (ap_nbf, ap_exp) = match (decode_token_claims("access_token", &access_token), expires_in) {
@@ -371,13 +378,11 @@ pub fn create_auth_tokens(
let access_claims =
auth::LoginJwtClaims::new(device, user, ap_nbf, ap_exp, AuthMethod::Sso.scope_vec(), client_id, now);
_create_auth_tokens(device, refresh_token, access_claims, access_token)
} else {
Ok(AuthTokens::new(device, user, AuthMethod::Sso, client_id))
create_auth_tokens_impl(device, refresh_token, access_claims, access_token)
}
}
fn _create_auth_tokens(
fn create_auth_tokens_impl(
device: &Device,
refresh_token: Option<String>,
access_claims: auth::LoginJwtClaims,
@@ -461,7 +466,7 @@ pub async fn exchange_refresh_token(
now,
);
_create_auth_tokens(device, None, access_claims, access_token)
create_auth_tokens_impl(device, None, access_claims, access_token)
}
None => err!("No token present while in SSO"),
}
+70 -22
View File
@@ -1,17 +1,31 @@
use std::{borrow::Cow, sync::LazyLock, time::Duration};
use std::{borrow::Cow, future::Future, pin::Pin, sync::LazyLock, time::Duration};
use openidconnect::{core::*, reqwest, *};
use openidconnect::{
AccessToken, AsyncHttpClient, AuthDisplay, AuthPrompt, AuthenticationFlow, AuthorizationCode, AuthorizationRequest,
ClientId, ClientSecret, CsrfToken, EmptyAdditionalClaims, EmptyExtraTokenFields, EndpointNotSet, EndpointSet,
HttpClientError, HttpRequest, HttpResponse, IdTokenClaims, IdTokenFields, Nonce, OAuth2TokenResponse,
PkceCodeChallenge, PkceCodeVerifier, RefreshToken, ResponseType, Scope, StandardErrorResponse,
StandardTokenResponse,
core::{
CoreAuthDisplay, CoreAuthPrompt, CoreClient, CoreErrorResponseType, CoreGenderClaim, CoreIdTokenVerifier,
CoreJsonWebKey, CoreJweContentEncryptionAlgorithm, CoreJwsSigningAlgorithm, CoreProviderMetadata,
CoreResponseType, CoreRevocableToken, CoreRevocationErrorResponse, CoreTokenIntrospectionResponse,
CoreTokenResponse, CoreTokenType, CoreUserInfoClaims,
},
http, url,
};
use regex::Regex;
use url::Url;
use crate::{
CONFIG,
api::{ApiResult, EmptyResult},
db::models::SsoAuth,
http_client::get_reqwest_client_builder,
sso::{OIDCCode, OIDCCodeChallenge, OIDCCodeVerifier, OIDCState},
CONFIG,
};
static CLIENT_CACHE_KEY: LazyLock<String> = LazyLock::new(|| "sso-client".to_string());
static CLIENT_CACHE_KEY: LazyLock<String> = LazyLock::new(|| "sso-client".to_owned());
static CLIENT_CACHE: LazyLock<moka::sync::Cache<String, Client>> = LazyLock::new(|| {
moka::sync::Cache::builder()
.max_capacity(1)
@@ -46,19 +60,51 @@ pub type RefreshTokenResponse = (Option<String>, String, Option<Duration>);
#[derive(Clone)]
pub struct Client {
pub http_client: reqwest::Client,
pub http_client: OidcHttpClient,
pub core_client: CustomClient,
}
#[derive(Clone)]
pub struct OidcHttpClient {
client: reqwest::Client,
}
impl OidcHttpClient {
fn new() -> Result<Self, reqwest::Error> {
get_reqwest_client_builder().redirect(reqwest::redirect::Policy::none()).build().map(|client| Self {
client,
})
}
}
impl<'c> AsyncHttpClient<'c> for OidcHttpClient {
type Error = HttpClientError<reqwest::Error>;
type Future = Pin<Box<dyn Future<Output = Result<HttpResponse, Self::Error>> + Send + Sync + 'c>>;
fn call(&'c self, request: HttpRequest) -> Self::Future {
Box::pin(async move {
let response = self.client.execute(request.try_into().map_err(Box::new)?).await.map_err(Box::new)?;
let mut builder = http::Response::builder().status(response.status()).version(response.version());
for (name, value) in response.headers() {
builder = builder.header(name, value);
}
builder.body(response.bytes().await.map_err(Box::new)?.to_vec()).map_err(HttpClientError::Http)
})
}
}
impl Client {
// Call the OpenId discovery endpoint to retrieve configuration
async fn _get_client() -> ApiResult<Self> {
async fn get_client() -> ApiResult<Self> {
let client_id = ClientId::new(CONFIG.sso_client_id());
let client_secret = ClientSecret::new(CONFIG.sso_client_secret());
let issuer_url = CONFIG.sso_issuer_url()?;
let http_client = match reqwest::ClientBuilder::new().redirect(reqwest::redirect::Policy::none()).build() {
let http_client = match OidcHttpClient::new() {
Err(err) => err!(format!("Failed to build http client: {err}")),
Ok(client) => client,
};
@@ -70,14 +116,16 @@ impl Client {
let base_client = CoreClient::from_provider_metadata(provider_metadata, client_id, Some(client_secret));
let token_uri = match base_client.token_uri() {
Some(uri) => uri.clone(),
None => err!("Failed to discover token_url, cannot proceed"),
let token_uri = if let Some(uri) = base_client.token_uri() {
uri.clone()
} else {
err!("Failed to discover token_url, cannot proceed")
};
let user_info_url = match base_client.user_info_url() {
Some(url) => url.clone(),
None => err!("Failed to discover user_info url, cannot proceed"),
let user_info_url = if let Some(url) = base_client.user_info_url() {
url.clone()
} else {
err!("Failed to discover user_info url, cannot proceed")
};
let core_client = base_client
@@ -96,13 +144,13 @@ impl Client {
if CONFIG.sso_client_cache_expiration() > 0 {
match CLIENT_CACHE.get(&*CLIENT_CACHE_KEY) {
Some(client) => Ok(client),
None => Self::_get_client().await.inspect(|client| {
None => Self::get_client().await.inspect(|client| {
debug!("Inserting new client in cache");
CLIENT_CACHE.insert(CLIENT_CACHE_KEY.clone(), client.clone());
}),
}
} else {
Self::_get_client().await
Self::get_client().await
}
}
@@ -117,6 +165,7 @@ impl Client {
state: OIDCState,
client_challenge: OIDCCodeChallenge,
redirect_uri: String,
binding_hash: Option<String>,
) -> ApiResult<(Url, SsoAuth)> {
let scopes = CONFIG.sso_scopes_vec().into_iter().map(Scope::new);
let base64_state = data_encoding::BASE64.encode(state.to_string().as_bytes());
@@ -139,7 +188,7 @@ impl Client {
}
let (auth_url, _, nonce) = auth_req.url();
Ok((auth_url, SsoAuth::new(state, client_challenge, nonce.secret().clone(), redirect_uri)))
Ok((auth_url, SsoAuth::new(state, client_challenge, nonce.secret().clone(), redirect_uri, binding_hash)))
}
pub async fn exchange_code(
@@ -180,15 +229,14 @@ impl Client {
Ok(token_response) => {
let oidc_nonce = Nonce::new(sso_auth.nonce.clone());
let id_token = match token_response.extra_fields().id_token() {
None => err!("Token response did not contain an id_token"),
Some(token) => token,
let Some(id_token) = token_response.extra_fields().id_token() else {
err!("Token response did not contain an id_token")
};
if CONFIG.sso_debug_tokens() {
debug!("Id token: {}", id_token.to_string());
debug!("Access token: {}", token_response.access_token().secret());
debug!("Refresh token: {:?}", token_response.refresh_token().map(|t| t.secret()));
debug!("Refresh token: {:?}", token_response.refresh_token().map(RefreshToken::secret));
debug!("Expiration time: {:?}", token_response.expires_in());
}
@@ -241,12 +289,12 @@ impl Client {
let client = Client::cached().await?;
REFRESH_CACHE
.get_with(refresh_token.clone(), async move { client._exchange_refresh_token(refresh_token).await })
.get_with(refresh_token.clone(), async move { client.exchange_refresh_token_impl(refresh_token).await })
.await
.map_err(Into::into)
}
async fn _exchange_refresh_token(&self, refresh_token: String) -> Result<RefreshTokenResponse, String> {
async fn exchange_refresh_token_impl(&self, refresh_token: String) -> Result<RefreshTokenResponse, String> {
let rt = RefreshToken::new(refresh_token);
match self.core_client.exchange_refresh_token(&rt).request_async(&self.http_client).await {
+13 -2
View File
@@ -1,6 +1,17 @@
body {
padding-top: 75px;
}
/* Some extra width's for the main layout */
@media (min-width: 1600px) {
.container-xxl {
max-width: 1520px;
}
}
@media (min-width: 1800px) {
.container-xxl {
max-width: 1720px;
}
}
img {
width: 48px;
height: 48px;
@@ -38,8 +49,8 @@ img {
max-width: 130px;
}
#users-table .vw-actions, #orgs-table .vw-actions {
min-width: 155px;
max-width: 160px;
min-width: 170px;
max-width: 180px;
}
#users-table .vw-org-cell {
max-height: 120px;
+2 -2
View File
@@ -4,10 +4,10 @@
*
* To rebuild or modify this file with the latest versions of the included
* software please visit:
* https://datatables.net/download/#bs5/dt-2.3.7
* https://datatables.net/download/#bs5/dt-2.3.8
*
* Included libraries:
* DataTables 2.3.7
* DataTables 2.3.8
*/
:root {
+41 -11
View File
@@ -4,13 +4,13 @@
*
* To rebuild or modify this file with the latest versions of the included
* software please visit:
* https://datatables.net/download/#bs5/dt-2.3.7
* https://datatables.net/download/#bs5/dt-2.3.8
*
* Included libraries:
* DataTables 2.3.7
* DataTables 2.3.8
*/
/*! DataTables 2.3.7
/*! DataTables 2.3.8
* © SpryMedia Ltd - datatables.net/license
*/
@@ -525,7 +525,7 @@
*
* @type string
*/
builder: "bs5/dt-2.3.7",
builder: "bs5/dt-2.3.8",
/**
* Buttons. For use with the Buttons extension for DataTables. This is
@@ -3607,6 +3607,11 @@
if ( holdPosition !== true ) {
settings._iDisplayStart = 0;
}
else {
// Keep position, but make sure that there is actually data to display,
// otherwise we need to rewind a bit (e.g. if rows were deleted)
_fnLengthOverflow(settings);
}
// Let any modules know about the draw hold position state (used by
// scrolling internally)
@@ -4920,6 +4925,12 @@
var args = [settings, settings.json];
// If the footer element is empty after initialisation, then remove it
let tfoot = $(settings.tfoot);
if (tfoot.children().length === 0) {
tfoot.remove();
}
settings._bInitComplete = true;
// Table is fully set up and we have data, so calculate the
@@ -5376,12 +5387,12 @@
// the content of the cell so that the width applied to the header and body
// both match, but we want to hide it completely.
$('th, td', headerCopy).each(function () {
$(this.childNodes).wrapAll('<div class="dt-scroll-sizing">');
$(this.childNodes).wrapAll('<div class="dt-scroll-sizing" />');
});
if ( footer ) {
$('th, td', footerCopy).each(function () {
$(this.childNodes).wrapAll('<div class="dt-scroll-sizing">');
$(this.childNodes).wrapAll('<div class="dt-scroll-sizing" />');
});
}
@@ -5409,6 +5420,10 @@
// Correct DOM ordering for colgroup - comes before the thead
table.children('colgroup').prependTo(table);
// Remove tabindex from the hidden row elements
table.find('thead, tfoot').find('[tabindex]').removeAttr('tabindex');
table.find('thead, tfoot').find('role').removeAttr('role');
// Adjust the position of the header in case we loose the y-scrollbar
divBody.trigger('scroll');
@@ -5732,8 +5747,12 @@
.replace(/id=".*?"/g, '')
.replace(/name=".*?"/g, '');
// Don't want Javascript at all in these calculation cells.
cellString = cellString.replace(/<script.*?<\/script>/gi, ' ');
// Don't want script, dialog or template tags in the width
// calculations as they are hidden content
cellString = cellString
.replace(/<script[\s\S]*?<\/script>/gi, ' ')
.replace(/<dialog[\s\S]*?<\/dialog>/gi, ' ')
.replace(/<template[\s\S]*?<\/template>/gi, ' ');
var noHtml = _stripHtml(cellString, ' ')
.replace( /&nbsp;/g, ' ' );
@@ -10304,7 +10323,7 @@
* @type string
* @default Version number
*/
DataTable.version = "2.3.7";
DataTable.version = "2.3.8";
/**
* Private data store, containing all of the settings objects that are
@@ -12586,6 +12605,7 @@
var __mlWarning = false;
var __luxon; // Can be assigned in DateTable.use()
var __moment; // Can be assigned in DateTable.use()
var __reIsoTimezone = /[T\s]\d{2}.*?(Z|[+-]\d{2}(?::?\d{2})?)$/;
/**
*
@@ -12606,7 +12626,7 @@
resolveWindowLibs();
if (__moment) {
dt = __moment.utc( d, format, locale, true );
dt = __moment( d, format, locale, true );
if (! dt.isValid()) {
return null;
@@ -12716,6 +12736,16 @@
return d;
}
// Determine if there is a timezone. If there is, we want to reuse
// it for the output, so the timezone doesn't change between the
// input and output.
let options = {};
let tzMatch = typeof d === 'string' ? d.match(__reIsoTimezone) : null;
if (tzMatch) {
options.timeZone = tzMatch[1] === 'Z' ? 'UTC' : tzMatch[1];
}
var dt = __mldObj(d, from, locale);
if (dt === null) {
@@ -12729,7 +12759,7 @@
var formatted = to === null
? __mld(dt, 'toDate', 'toJSDate', '')[localeString](
navigator.language,
{ timeZone: "UTC" }
options
)
: __mld(dt, 'format', 'toFormat', 'toISOString', to);
+1 -1
View File
@@ -27,7 +27,7 @@
</symbol>
</svg>
<nav class="navbar navbar-expand-md navbar-dark bg-dark mb-4 shadow fixed-top">
<div class="container-xl">
<div class="container-xxl">
<a class="navbar-brand" href="{{urlpath}}/admin"><img class="vaultwarden-icon" src="{{urlpath}}/vw_static/vaultwarden-icon.png" alt="V">aultwarden Admin</a>
<button class="navbar-toggler" type="button" data-bs-toggle="collapse" data-bs-target="#navbarCollapse"
aria-controls="navbarCollapse" aria-expanded="false" aria-label="Toggle navigation">
+1 -1
View File
@@ -1,4 +1,4 @@
<main class="container-xl">
<main class="container-xxl">
<div id="diagnostics-block" class="my-3 p-3 rounded shadow">
<h6 class="border-bottom pb-2 mb-2">Diagnostics</h6>
+1 -1
View File
@@ -1,4 +1,4 @@
<main class="container-xl">
<main class="container-xxl">
{{#if error}}
<div class="align-items-center p-3 mb-3 text-opacity-50 text-dark bg-warning rounded shadow">
<div>
+1 -1
View File
@@ -1,4 +1,4 @@
<main class="container-xl">
<main class="container-xxl">
<div id="organizations-block" class="my-3 p-3 rounded shadow">
<h6 class="border-bottom pb-2 mb-3">Organizations</h6>
<div class="table-responsive-xl small">
+1 -1
View File
@@ -1,4 +1,4 @@
<main class="container-xl">
<main class="container-xxl">
<div id="admin_token_warning" class="alert alert-warning alert-dismissible fade show d-none">
<button type="button" class="btn-close" data-bs-target="admin_token_warning" data-bs-dismiss="alert" aria-label="Close"></button>
You are using a plain text `ADMIN_TOKEN` which is insecure.<br>
+2 -2
View File
@@ -1,4 +1,4 @@
<main class="container-xl">
<main class="container-xxl">
<div id="users-block" class="my-3 p-3 rounded shadow">
<h6 class="border-bottom pb-2 mb-3">Registered Users</h6>
<div class="table-responsive-xl small">
@@ -43,7 +43,7 @@
</td>
{{#if ../sso_enabled}}
<td>
<span class="d-block">{{sso_identifier}}</span>
<span class="d-block text-break text-wrap">{{sso_identifier}}</span>
</td>
{{/if}}
<td>

Some files were not shown because too many files have changed in this diff Show More