diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f01e865d..ef1c2b2e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -38,8 +38,11 @@ env: # Custom nix binary cache if fork is being used ATTIC_ENDPOINT: ${{ vars.ATTIC_ENDPOINT }} ATTIC_PUBLIC_KEY: ${{ vars.ATTIC_PUBLIC_KEY }} - # Get error output from nix that we can actually use - NIX_CONFIG: show-trace = true + # Get error output from nix that we can actually use, and use our binary caches for the earlier CI steps + NIX_CONFIG: | + show-trace = true + extra-substituters = https://attic.kennel.juneis.dog/conduit https://attic.kennel.juneis.dog/conduwuit https://cache.lix.systems https://conduwuit.cachix.org + extra-trusted-public-keys = conduit:eEKoUwlQGDdYmAI/Q/0slVlegqh/QmAvQd7HBSm21Wk= conduwuit:BbycGUgTISsltcmH0qNjFR9dbrQNYgdIAcmViSGoVTE= cache.lix.systems:aBnZUw8zA7H35Cz2RyKFVs3H4PlGTLawyY5KRbvJR8o= conduwuit.cachix.org-1:MFRm6jcnfTf0jSAbmvLfhO3KBMt4px+1xaereWXp8Xg= permissions: packages: write @@ -57,7 +60,7 @@ jobs: uses: actions/checkout@v4 - name: Tag comparison check - if: startsWith(github.ref, 'refs/tags/v') + if: ${{ startsWith(github.ref, 'refs/tags/v') && !endsWith(github.ref, '-rc') }} run: | # Tag mismatch with latest repo tag check to prevent potential downgrades LATEST_TAG=$(git describe --tags `git rev-list --tags --max-count=1`) @@ -115,7 +118,7 @@ jobs: - name: Prepare build environment run: | echo 'source $HOME/.nix-profile/share/nix-direnv/direnvrc' > "$HOME/.direnvrc" - nix profile install --impure --inputs-from . nixpkgs#direnv nixpkgs#nix-direnv + nix profile install --inputs-from . nixpkgs#direnv nixpkgs#nix-direnv direnv allow nix develop .#all-features --command true @@ -129,15 +132,10 @@ jobs: run: | direnv exec . engage > >(tee -a test_output.log) - - name: Sync Complement repository - uses: actions/checkout@v4 - with: - repository: 'matrix-org/complement' - path: complement_src - - name: Run Complement tests run: | - direnv exec . bin/complement 'complement_src' 'complement_test_logs.jsonl' 'complement_test_results.jsonl' + # the nix devshell sets $COMPLEMENT_SRC, so "/dev/null" is no-op + direnv exec . bin/complement "/dev/null" complement_test_logs.jsonl complement_test_results.jsonl > >(tee -a test_output.log) cp -v -f result complement_oci_image.tar.gz - name: Upload Complement OCI image @@ -163,11 +161,7 @@ jobs: - name: Diff Complement results with checked-in repo results run: | - diff -u --color=always tests/test_results/complement/test_results.jsonl complement_test_results.jsonl > >(tee -a complement_test_output.log) - echo '# Complement diff results' >> $GITHUB_STEP_SUMMARY - echo '```diff' >> $GITHUB_STEP_SUMMARY - tail -n 100 complement_test_output.log | sed 's/\x1b\[[0-9;]*m//g' >> $GITHUB_STEP_SUMMARY - echo '```' >> $GITHUB_STEP_SUMMARY + diff -u --color=always tests/test_results/complement/test_results.jsonl complement_test_results.jsonl > >(tee -a complement_diff_output.log) - name: Update Job Summary if: success() || failure() @@ -175,9 +169,15 @@ jobs: if [ ${{ job.status }} == 'success' ]; then echo '# ✅ completed suwuccessfully' >> $GITHUB_STEP_SUMMARY else + echo '# CI failure' >> $GITHUB_STEP_SUMMARY echo '```' >> $GITHUB_STEP_SUMMARY tail -n 40 test_output.log | sed 's/\x1b\[[0-9;]*m//g' >> $GITHUB_STEP_SUMMARY echo '```' >> $GITHUB_STEP_SUMMARY + + echo '# Complement diff results' >> $GITHUB_STEP_SUMMARY + echo '```diff' >> $GITHUB_STEP_SUMMARY + tail -n 100 complement_diff_output.log | sed 's/\x1b\[[0-9;]*m//g' >> $GITHUB_STEP_SUMMARY + echo '```' >> $GITHUB_STEP_SUMMARY fi build: @@ -240,7 +240,7 @@ jobs: - name: Prepare build environment run: | echo 'source $HOME/.nix-profile/share/nix-direnv/direnvrc' > "$HOME/.direnvrc" - nix profile install --impure --inputs-from . nixpkgs#direnv nixpkgs#nix-direnv + nix profile install --inputs-from . nixpkgs#direnv nixpkgs#nix-direnv direnv allow nix develop .#all-features --command true @@ -249,7 +249,7 @@ jobs: CARGO_DEB_TARGET_TUPLE=$(echo ${{ matrix.target }} | grep -o -E '^([^-]*-){3}[^-]*') SOURCE_DATE_EPOCH=$(git log -1 --pretty=%ct) - bin/nix-build-and-cache just .#static-${{ matrix.target }} + bin/nix-build-and-cache just .#static-${{ matrix.target }}-all-features mkdir -v -p target/release/ mkdir -v -p target/$CARGO_DEB_TARGET_TUPLE/release/ cp -v -f result/bin/conduit target/release/conduwuit @@ -276,7 +276,7 @@ jobs: - name: Build OCI image ${{ matrix.target }} run: | - bin/nix-build-and-cache just .#oci-image-${{ matrix.target }} + bin/nix-build-and-cache just .#oci-image-${{ matrix.target }}-all-features cp -v -f result oci-image-${{ matrix.target }}.tar.gz - name: Upload OCI image ${{ matrix.target }} @@ -296,15 +296,15 @@ jobs: DOCKER_ARM64: docker.io/${{ github.repository }}:${{ (github.head_ref != '' && format('merge-{0}-{1}', github.event.number, github.event.pull_request.user.login)) || github.ref_name }}-${{ github.sha }}-arm64v8 DOCKER_AMD64: docker.io/${{ github.repository }}:${{ (github.head_ref != '' && format('merge-{0}-{1}', github.event.number, github.event.pull_request.user.login)) || github.ref_name }}-${{ github.sha }}-amd64 DOCKER_TAG: docker.io/${{ github.repository }}:${{ (github.head_ref != '' && format('merge-{0}-{1}', github.event.number, github.event.pull_request.user.login)) || github.ref_name }}-${{ github.sha }} - DOCKER_BRANCH: docker.io/${{ github.repository }}:${{ (startsWith(github.ref, 'refs/tags/v') && 'latest') || (github.head_ref != '' && format('merge-{0}-{1}', github.event.number, github.event.pull_request.user.login)) || github.ref_name }} + DOCKER_BRANCH: docker.io/${{ github.repository }}:${{ (startsWith(github.ref, 'refs/tags/v') && !endsWith(github.ref, '-rc') && 'latest') || (github.head_ref != '' && format('merge-{0}-{1}', github.event.number, github.event.pull_request.user.login)) || github.ref_name }} GHCR_ARM64: ghcr.io/${{ github.repository }}:${{ (github.head_ref != '' && format('merge-{0}-{1}', github.event.number, github.event.pull_request.user.login)) || github.ref_name }}-${{ github.sha }}-arm64v8 GHCR_AMD64: ghcr.io/${{ github.repository }}:${{ (github.head_ref != '' && format('merge-{0}-{1}', github.event.number, github.event.pull_request.user.login)) || github.ref_name }}-${{ github.sha }}-amd64 GHCR_TAG: ghcr.io/${{ github.repository }}:${{ (github.head_ref != '' && format('merge-{0}-{1}', github.event.number, github.event.pull_request.user.login)) || github.ref_name }}-${{ github.sha }} - GHCR_BRANCH: ghcr.io/${{ github.repository }}:${{ (startsWith(github.ref, 'refs/tags/v') && 'latest') || (github.head_ref != '' && format('merge-{0}-{1}', github.event.number, github.event.pull_request.user.login)) || github.ref_name }} + GHCR_BRANCH: ghcr.io/${{ github.repository }}:${{ (startsWith(github.ref, 'refs/tags/v') && !endsWith(github.ref, '-rc') && 'latest') || (github.head_ref != '' && format('merge-{0}-{1}', github.event.number, github.event.pull_request.user.login)) || github.ref_name }} GLCR_ARM64: registry.gitlab.com/conduwuit/conduwuit:${{ (github.head_ref != '' && format('merge-{0}-{1}', github.event.number, github.event.pull_request.user.login)) || github.ref_name }}-${{ github.sha }}-arm64v8 GLCR_AMD64: registry.gitlab.com/conduwuit/conduwuit:${{ (github.head_ref != '' && format('merge-{0}-{1}', github.event.number, github.event.pull_request.user.login)) || github.ref_name }}-${{ github.sha }}-amd64 GLCR_TAG: registry.gitlab.com/conduwuit/conduwuit:${{ (github.head_ref != '' && format('merge-{0}-{1}', github.event.number, github.event.pull_request.user.login)) || github.ref_name }}-${{ github.sha }} - GLCR_BRANCH: registry.gitlab.com/conduwuit/conduwuit:${{ (startsWith(github.ref, 'refs/tags/v') && 'latest') || (github.head_ref != '' && format('merge-{0}-{1}', github.event.number, github.event.pull_request.user.login)) || github.ref_name }} + GLCR_BRANCH: registry.gitlab.com/conduwuit/conduwuit:${{ (startsWith(github.ref, 'refs/tags/v') && !endsWith(github.ref, '-rc') && 'latest') || (github.head_ref != '' && format('merge-{0}-{1}', github.event.number, github.event.pull_request.user.login)) || github.ref_name }} DOCKERHUB_TOKEN: ${{ secrets.DOCKERHUB_TOKEN }} GITLAB_TOKEN: ${{ secrets.GITLAB_TOKEN }} diff --git a/.github/workflows/documentation.yml b/.github/workflows/documentation.yml index ed0e898d..c6f9133a 100644 --- a/.github/workflows/documentation.yml +++ b/.github/workflows/documentation.yml @@ -21,8 +21,11 @@ env: # Custom nix binary cache if fork is being used ATTIC_ENDPOINT: ${{ vars.ATTIC_ENDPOINT }} ATTIC_PUBLIC_KEY: ${{ vars.ATTIC_PUBLIC_KEY }} - # Get error output from nix that we can actually use - NIX_CONFIG: show-trace = true + # Get error output from nix that we can actually use, and use our binary caches for the earlier CI steps + NIX_CONFIG: | + show-trace = true + extra-substituters = https://attic.kennel.juneis.dog/conduit https://attic.kennel.juneis.dog/conduwuit https://cache.lix.systems https://conduwuit.cachix.org + extra-trusted-public-keys = conduit:eEKoUwlQGDdYmAI/Q/0slVlegqh/QmAvQd7HBSm21Wk= conduwuit:BbycGUgTISsltcmH0qNjFR9dbrQNYgdIAcmViSGoVTE= cache.lix.systems:aBnZUw8zA7H35Cz2RyKFVs3H4PlGTLawyY5KRbvJR8o= conduwuit.cachix.org-1:MFRm6jcnfTf0jSAbmvLfhO3KBMt4px+1xaereWXp8Xg= # Allow only one concurrent deployment, skipping runs queued between the run in-progress and latest queued. # However, do NOT cancel in-progress runs as we want to allow these production deployments to complete. @@ -98,7 +101,7 @@ jobs: - name: Prepare build environment run: | echo 'source $HOME/.nix-profile/share/nix-direnv/direnvrc' > "$HOME/.direnvrc" - nix profile install --impure --inputs-from . nixpkgs#direnv nixpkgs#nix-direnv + nix profile install --inputs-from . nixpkgs#direnv nixpkgs#nix-direnv direnv allow nix develop --command true diff --git a/.github/workflows/trivy.yml b/.github/workflows/trivy.yml index e9cc9a30..e0871324 100644 --- a/.github/workflows/trivy.yml +++ b/.github/workflows/trivy.yml @@ -26,7 +26,7 @@ jobs: uses: actions/checkout@v4 - name: Run Trivy code and vulnerability scanner on repo - uses: aquasecurity/trivy-action@0.23.0 + uses: aquasecurity/trivy-action@0.24.0 with: scan-type: repo format: sarif @@ -34,7 +34,7 @@ jobs: severity: CRITICAL,HIGH,MEDIUM,LOW - name: Run Trivy code and vulnerability scanner on filesystem - uses: aquasecurity/trivy-action@0.23.0 + uses: aquasecurity/trivy-action@0.24.0 with: scan-type: fs format: sarif diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 33d0c6ed..78449e36 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -58,7 +58,7 @@ before_script: ci: stage: ci - image: nixos/nix:2.23.1 + image: nixos/nix:2.23.3 script: # Cache CI dependencies - ./bin/nix-build-and-cache ci @@ -83,7 +83,7 @@ ci: artifacts: stage: artifacts - image: nixos/nix:2.23.1 + image: nixos/nix:2.23.3 script: - ./bin/nix-build-and-cache just .#static-x86_64-unknown-linux-musl - cp result/bin/conduit x86_64-unknown-linux-musl diff --git a/Cargo.lock b/Cargo.lock index f89c9f35..61cf48bd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -118,18 +118,18 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.71", ] [[package]] name = "async-trait" -version = "0.1.80" +version = "0.1.81" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca" +checksum = "6e0c28dcc82d7c8ead5cb13beb15405b57b8546e93215673ff8ca0349a028107" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.71", ] [[package]] @@ -166,7 +166,7 @@ dependencies = [ "futures-util", "http 0.2.12", "http-body 0.4.6", - "hyper 0.14.29", + "hyper 0.14.30", "itoa", "matchit", "memchr", @@ -192,9 +192,9 @@ dependencies = [ "bytes", "futures-util", "http 1.1.0", - "http-body 1.0.0", + "http-body 1.0.1", "http-body-util", - "hyper 1.4.0", + "hyper 1.4.1", "hyper-util", "itoa", "matchit", @@ -253,7 +253,7 @@ dependencies = [ "bytes", "futures-util", "http 1.1.0", - "http-body 1.0.0", + "http-body 1.0.1", "http-body-util", "mime", "pin-project-lite", @@ -276,7 +276,7 @@ dependencies = [ "futures-util", "headers", "http 1.1.0", - "http-body 1.0.0", + "http-body 1.0.1", "http-body-util", "mime", "pin-project-lite", @@ -296,9 +296,9 @@ dependencies = [ "bytes", "futures-util", "http 1.1.0", - "http-body 1.0.0", + "http-body 1.0.1", "http-body-util", - "hyper 1.4.0", + "hyper 1.4.1", "hyper-util", "pin-project-lite", "rustls 0.21.12", @@ -377,7 +377,7 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "syn 2.0.68", + "syn 2.0.71", ] [[package]] @@ -457,9 +457,9 @@ checksum = "8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495" [[package]] name = "bytes" -version = "1.6.0" +version = "1.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "514de17de45fdb8dc022b1a7975556c53c86f9f0aa5f534b98977b171857c2c9" +checksum = "a12916984aab3fa6e39d655a33e09c0071eb36d6ab3aea5c2d78551f1df6d952" [[package]] name = "bzip2-sys" @@ -474,9 +474,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.0.104" +version = "1.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74b6a57f98764a267ff415d50a25e6e166f3831a5071af4995296ea97d210490" +checksum = "18e2d530f35b40a84124146478cd16f34225306a8441998836466a2e2961c950" dependencies = [ "jobserver", "libc", @@ -504,6 +504,15 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" +[[package]] +name = "checked_ops" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b491d76efc1d99d74de3c8529bee64c62312c275c7eb124f9185291de45801d5" +dependencies = [ + "num-traits", +] + [[package]] name = "chrono" version = "0.4.38" @@ -526,9 +535,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.8" +version = "4.5.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "84b3edb18336f4df585bc9aa31dd99c036dfa5dc5e9a2939a722a188f3a8970d" +checksum = "64acc1846d54c1fe936a78dc189c34e28d3f5afc348403f28ecf53660b9b8462" dependencies = [ "clap_builder", "clap_derive", @@ -536,9 +545,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.8" +version = "4.5.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1c09dd5ada6c6c78075d6fd0da3f90d8080651e2d6cc8eb2f1aaa4034ced708" +checksum = "6fb8393d67ba2e7bfaf28a23458e4e2b543cc73a99595511eb207fdb8aede942" dependencies = [ "anstyle", "clap_lex", @@ -553,7 +562,7 @@ dependencies = [ "heck 0.5.0", "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.71", ] [[package]] @@ -570,7 +579,7 @@ checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b" [[package]] name = "conduit" -version = "0.4.4" +version = "0.4.5" dependencies = [ "clap", "conduit_admin", @@ -580,6 +589,7 @@ dependencies = [ "conduit_router", "conduit_service", "console-subscriber", + "const-str", "hardened_malloc-rs", "log", "opentelemetry", @@ -598,19 +608,16 @@ dependencies = [ [[package]] name = "conduit_admin" -version = "0.4.4" +version = "0.4.5" dependencies = [ "clap", "conduit_api", "conduit_core", - "conduit_database", "conduit_service", + "const-str", "futures-util", "log", - "loole", - "regex", "ruma", - "serde", "serde_json", "serde_yaml", "tokio", @@ -620,7 +627,7 @@ dependencies = [ [[package]] name = "conduit_api" -version = "0.4.4" +version = "0.4.5" dependencies = [ "axum 0.7.5", "axum-client-ip", @@ -630,10 +637,12 @@ dependencies = [ "conduit_core", "conduit_database", "conduit_service", + "const-str", "futures-util", "hmac", "http 1.1.0", - "hyper 1.4.0", + "http-body-util", + "hyper 1.4.1", "image", "ipaddress", "jsonwebtoken", @@ -645,7 +654,6 @@ dependencies = [ "serde_html_form", "serde_json", "sha-1", - "thiserror", "tokio", "tracing", "webpage", @@ -653,12 +661,14 @@ dependencies = [ [[package]] name = "conduit_core" -version = "0.4.4" +version = "0.4.5" dependencies = [ "argon2", "axum 0.7.5", "bytes", + "checked_ops", "chrono", + "const-str", "either", "figment", "hardened_malloc-rs", @@ -684,6 +694,7 @@ dependencies = [ "tikv-jemalloc-sys", "tikv-jemallocator", "tokio", + "tokio-metrics", "tracing", "tracing-core", "tracing-subscriber", @@ -692,11 +703,11 @@ dependencies = [ [[package]] name = "conduit_database" -version = "0.4.4" +version = "0.4.5" dependencies = [ "conduit_core", + "const-str", "log", - "ruma", "rust-rocksdb-uwu", "tokio", "tracing", @@ -704,25 +715,23 @@ dependencies = [ [[package]] name = "conduit_router" -version = "0.4.4" +version = "0.4.5" dependencies = [ "axum 0.7.5", "axum-client-ip", "axum-server", "axum-server-dual-protocol", "bytes", - "clap", "conduit_admin", "conduit_api", "conduit_core", - "conduit_database", "conduit_service", + "const-str", "http 1.1.0", "http-body-util", - "hyper 1.4.0", + "hyper 1.4.1", "hyper-util", "log", - "regex", "ruma", "sd-notify", "sentry", @@ -737,13 +746,14 @@ dependencies = [ [[package]] name = "conduit_service" -version = "0.4.4" +version = "0.4.5" dependencies = [ "async-trait", "base64 0.22.1", "bytes", "conduit_core", "conduit_database", + "const-str", "cyborgtime", "futures-util", "hickory-resolver", @@ -759,7 +769,7 @@ dependencies = [ "regex", "reqwest", "ruma", - "ruma-identifiers-validation 0.9.5 (git+https://github.com/girlbossceo/ruwuma?rev=fd686e77950680462377c9105dfb4136dd49c7a0)", + "ruma-identifiers-validation", "rustyline-async", "serde", "serde_json", @@ -815,6 +825,12 @@ version = "0.9.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" +[[package]] +name = "const-str" +version = "0.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3618cccc083bb987a415d85c02ca6c9994ea5b44731ec28b9ecf09658655fba9" + [[package]] name = "const_panic" version = "0.2.8" @@ -1006,7 +1022,7 @@ checksum = "f46882e17999c6cc590af592290432be3bce0428cb0d5f8b6715e4dc7b383eb3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.71", ] [[package]] @@ -1110,7 +1126,7 @@ dependencies = [ "heck 0.4.1", "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.71", ] [[package]] @@ -1246,7 +1262,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.71", ] [[package]] @@ -1527,7 +1543,7 @@ dependencies = [ "markup5ever", "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.71", ] [[package]] @@ -1574,9 +1590,9 @@ dependencies = [ [[package]] name = "http-body" -version = "1.0.0" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1cac85db508abc24a2e48553ba12a996e87244a0395ce011e62b37158745d643" +checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" dependencies = [ "bytes", "http 1.1.0", @@ -1591,7 +1607,7 @@ dependencies = [ "bytes", "futures-util", "http 1.1.0", - "http-body 1.0.0", + "http-body 1.0.1", "pin-project-lite", ] @@ -1615,9 +1631,9 @@ checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" [[package]] name = "hyper" -version = "0.14.29" +version = "0.14.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f361cde2f109281a220d4307746cdfd5ee3f410da58a70377762396775634b33" +checksum = "a152ddd61dfaec7273fe8419ab357f33aee0d914c5f4efbf0d96fa749eea5ec9" dependencies = [ "bytes", "futures-channel", @@ -1639,16 +1655,16 @@ dependencies = [ [[package]] name = "hyper" -version = "1.4.0" +version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4fe55fb7a772d59a5ff1dfbff4fe0258d19b89fec4b233e75d35d5d2316badc" +checksum = "50dfd22e0e76d0f662d429a5f80fcaf3855009297eab6a0a9f8543834744ba05" dependencies = [ "bytes", "futures-channel", "futures-util", "h2 0.4.5", "http 1.1.0", - "http-body 1.0.0", + "http-body 1.0.1", "httparse", "httpdate", "itoa", @@ -1666,9 +1682,9 @@ checksum = "5ee4be2c948921a1a5320b629c4193916ed787a7f7f293fd3f7f5a6c9de74155" dependencies = [ "futures-util", "http 1.1.0", - "hyper 1.4.0", + "hyper 1.4.1", "hyper-util", - "rustls 0.23.10", + "rustls 0.23.11", "rustls-native-certs", "rustls-pki-types", "tokio", @@ -1683,7 +1699,7 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbb958482e8c7be4bc3cf272a766a2b0bf1a6755e7a6ae777f017a31d11b13b1" dependencies = [ - "hyper 0.14.29", + "hyper 0.14.30", "pin-project-lite", "tokio", "tokio-io-timeout", @@ -1699,8 +1715,8 @@ dependencies = [ "futures-channel", "futures-util", "http 1.1.0", - "http-body 1.0.0", - "hyper 1.4.0", + "http-body 1.0.1", + "hyper 1.4.1", "pin-project-lite", "socket2", "tokio", @@ -1936,7 +1952,7 @@ dependencies = [ "proc-macro2", "quote", "regex", - "syn 2.0.68", + "syn 2.0.71", ] [[package]] @@ -1964,7 +1980,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e310b3a6b5907f99202fcdb4960ff45b93735d7c7d96b760fcff8db2dc0e103d" dependencies = [ "cfg-if", - "windows-targets 0.52.5", + "windows-targets 0.52.6", ] [[package]] @@ -2405,7 +2421,7 @@ dependencies = [ "libc", "redox_syscall", "smallvec", - "windows-targets 0.52.5", + "windows-targets 0.52.6", ] [[package]] @@ -2445,7 +2461,7 @@ dependencies = [ "proc-macro2", "proc-macro2-diagnostics", "quote", - "syn 2.0.68", + "syn 2.0.71", ] [[package]] @@ -2538,7 +2554,7 @@ checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.71", ] [[package]] @@ -2626,7 +2642,7 @@ checksum = "af066a9c399a26e020ada66a034357a868728e72cd426f3adcd35f80d88d88c8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.71", "version_check", "yansi", ] @@ -2651,7 +2667,7 @@ dependencies = [ "itertools 0.12.1", "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.71", ] [[package]] @@ -2698,7 +2714,7 @@ dependencies = [ "quinn-proto", "quinn-udp", "rustc-hash", - "rustls 0.23.10", + "rustls 0.23.11", "thiserror", "tokio", "tracing", @@ -2714,7 +2730,7 @@ dependencies = [ "rand", "ring", "rustc-hash", - "rustls 0.23.10", + "rustls 0.23.11", "slab", "thiserror", "tinyvec", @@ -2841,9 +2857,9 @@ dependencies = [ "h2 0.4.5", "hickory-resolver", "http 1.1.0", - "http-body 1.0.0", + "http-body 1.0.1", "http-body-util", - "hyper 1.4.0", + "hyper 1.4.1", "hyper-rustls", "hyper-util", "ipnet", @@ -2854,7 +2870,7 @@ dependencies = [ "percent-encoding", "pin-project-lite", "quinn", - "rustls 0.23.10", + "rustls 0.23.11", "rustls-native-certs", "rustls-pemfile", "rustls-pki-types", @@ -2972,7 +2988,7 @@ dependencies = [ "percent-encoding", "rand", "regex", - "ruma-identifiers-validation 0.9.5 (git+https://github.com/girlbossceo/ruwuma?rev=c51ccb2c68d2e3557eb12b1a49036531711ec0e5)", + "ruma-identifiers-validation", "ruma-macros", "serde", "serde_html_form", @@ -2999,7 +3015,7 @@ dependencies = [ "pulldown-cmark", "regex", "ruma-common", - "ruma-identifiers-validation 0.9.5 (git+https://github.com/girlbossceo/ruwuma?rev=c51ccb2c68d2e3557eb12b1a49036531711ec0e5)", + "ruma-identifiers-validation", "ruma-macros", "serde", "serde_json", @@ -3030,15 +3046,6 @@ dependencies = [ "thiserror", ] -[[package]] -name = "ruma-identifiers-validation" -version = "0.9.5" -source = "git+https://github.com/girlbossceo/ruwuma?rev=fd686e77950680462377c9105dfb4136dd49c7a0#fd686e77950680462377c9105dfb4136dd49c7a0" -dependencies = [ - "js_int", - "thiserror", -] - [[package]] name = "ruma-identity-service-api" version = "0.9.0" @@ -3058,9 +3065,9 @@ dependencies = [ "proc-macro-crate", "proc-macro2", "quote", - "ruma-identifiers-validation 0.9.5 (git+https://github.com/girlbossceo/ruwuma?rev=c51ccb2c68d2e3557eb12b1a49036531711ec0e5)", + "ruma-identifiers-validation", "serde", - "syn 2.0.68", + "syn 2.0.71", "toml", ] @@ -3122,8 +3129,8 @@ dependencies = [ [[package]] name = "rust-librocksdb-sys" -version = "0.23.0+9.3.1" -source = "git+https://github.com/zaidoon1/rust-rocksdb?rev=b4887edfb84771336930855727390edec07d63fa#b4887edfb84771336930855727390edec07d63fa" +version = "0.24.0+9.4.0" +source = "git+https://github.com/zaidoon1/rust-rocksdb?rev=4056a3b0f823013fec49f6d0b3e5698856e6476a#4056a3b0f823013fec49f6d0b3e5698856e6476a" dependencies = [ "bindgen", "bzip2-sys", @@ -3139,8 +3146,8 @@ dependencies = [ [[package]] name = "rust-rocksdb" -version = "0.27.0" -source = "git+https://github.com/zaidoon1/rust-rocksdb?rev=b4887edfb84771336930855727390edec07d63fa#b4887edfb84771336930855727390edec07d63fa" +version = "0.28.0" +source = "git+https://github.com/zaidoon1/rust-rocksdb?rev=4056a3b0f823013fec49f6d0b3e5698856e6476a#4056a3b0f823013fec49f6d0b3e5698856e6476a" dependencies = [ "libc", "rust-librocksdb-sys", @@ -3196,30 +3203,31 @@ dependencies = [ "log", "ring", "rustls-pki-types", - "rustls-webpki 0.102.4", + "rustls-webpki 0.102.5", "subtle", "zeroize", ] [[package]] name = "rustls" -version = "0.23.10" +version = "0.23.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05cff451f60db80f490f3c182b77c35260baace73209e9cdbbe526bfe3a4d402" +checksum = "4828ea528154ae444e5a642dbb7d5623354030dc9822b83fd9bb79683c7399d0" dependencies = [ + "log", "once_cell", "ring", "rustls-pki-types", - "rustls-webpki 0.102.4", + "rustls-webpki 0.102.5", "subtle", "zeroize", ] [[package]] name = "rustls-native-certs" -version = "0.7.0" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f1fb85efa936c42c6d5fc28d2629bb51e4b2f4b8a5211e297d599cc5a093792" +checksum = "a88d6d420651b496bdd98684116959239430022a115c1240e6c3993be0b15fba" dependencies = [ "openssl-probe", "rustls-pemfile", @@ -3256,9 +3264,9 @@ dependencies = [ [[package]] name = "rustls-webpki" -version = "0.102.4" +version = "0.102.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff448f7e92e913c4b7d4c6d8e4540a1724b319b4152b8aef6d4cf8339712b33e" +checksum = "f9a6fccd794a42c2c105b513a2f62bc3fd8f3ba57a4593677ceb0bd035164d78" dependencies = [ "ring", "rustls-pki-types", @@ -3274,8 +3282,7 @@ checksum = "955d28af4278de8121b7ebeb796b6a45735dc01436d898801014aced2773a3d6" [[package]] name = "rustyline-async" version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b6eb06391513b2184f0a5405c11a4a0a5302e8be442f4c5c35267187c2b37d5" +source = "git+https://github.com/girlbossceo/rustyline-async?rev=de26100b0db03e419a3d8e1dd26895d170d1fe50#de26100b0db03e419a3d8e1dd26895d170d1fe50" dependencies = [ "crossterm", "futures-channel", @@ -3500,22 +3507,22 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.203" +version = "1.0.204" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7253ab4de971e72fb7be983802300c30b5a7f0c2e56fab8abfc6a214307c0094" +checksum = "bc76f558e0cbb2a839d37354c575f1dc3fdc6546b5be373ba43d95f231bf7c12" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.203" +version = "1.0.204" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "500cbc0ebeb6f46627f50f3f5811ccf6bf00643be300b4c3eabc0ef55dc5b5ba" +checksum = "e0cd7e117be63d3c3678776753929474f3b04a43a080c744d6b0ae2a8c28e222" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.71", ] [[package]] @@ -3808,9 +3815,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.68" +version = "2.0.71" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "901fa70d88b9d6c98022e23b4136f9f3e54e4662c3bc1bd1d84a42a9a0f0c1e9" +checksum = "b146dcf730474b4bcd16c311627b31ede9ab149045db4d6088b3becaea046462" dependencies = [ "proc-macro2", "quote", @@ -3868,22 +3875,22 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.61" +version = "1.0.62" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c546c80d6be4bc6a00c0f01730c08df82eaa7a7a61f11d656526506112cc1709" +checksum = "f2675633b1499176c2dff06b0856a27976a8f9d436737b4cf4f312d4d91d8bbb" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.61" +version = "1.0.62" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533" +checksum = "d20468752b09f49e909e55a5d338caa8bedf615594e9d80bc4c565d30faf798c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.71", ] [[package]] @@ -3982,9 +3989,9 @@ dependencies = [ [[package]] name = "tinyvec" -version = "1.6.1" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c55115c6fbe2d2bef26eb09ad74bde02d8255476fc0c7b515ef09fbb35742d82" +checksum = "445e881f4f6d382d5f27c034e25eb92edd7c784ceab92a0937db7f2e9471b938" dependencies = [ "tinyvec_macros", ] @@ -4032,7 +4039,7 @@ checksum = "5f5ae998a069d4b5aba8ee9dad856af7d520c3699e6159b185c2acd48155d39a" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.71", ] [[package]] @@ -4043,6 +4050,7 @@ checksum = "eace09241d62c98b7eeb1107d4c5c64ca3bd7da92e8c218c153ab3a78f9be112" dependencies = [ "futures-util", "pin-project-lite", + "tokio", "tokio-stream", ] @@ -4062,7 +4070,7 @@ version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" dependencies = [ - "rustls 0.23.10", + "rustls 0.23.11", "rustls-pki-types", "tokio", ] @@ -4112,7 +4120,7 @@ dependencies = [ "serde", "serde_spanned", "toml_datetime", - "toml_edit 0.22.14", + "toml_edit 0.22.15", ] [[package]] @@ -4137,9 +4145,9 @@ dependencies = [ [[package]] name = "toml_edit" -version = "0.22.14" +version = "0.22.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f21c7aaf97f1bd9ca9d4f9e73b0a6c74bd5afef56f2bc931943a6e1c37e04e38" +checksum = "d59a3a72298453f564e2b111fa896f8d07fabb36f51f06d7e875fc5e0b5a3ef1" dependencies = [ "indexmap 2.2.6", "serde", @@ -4162,7 +4170,7 @@ dependencies = [ "h2 0.3.26", "http 0.2.12", "http-body 0.4.6", - "hyper 0.14.29", + "hyper 0.14.30", "hyper-timeout", "percent-encoding", "pin-project", @@ -4207,7 +4215,7 @@ dependencies = [ "futures-core", "futures-util", "http 1.1.0", - "http-body 1.0.0", + "http-body 1.0.1", "http-body-util", "pin-project-lite", "tokio", @@ -4233,7 +4241,7 @@ checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52" [[package]] name = "tracing" version = "0.1.40" -source = "git+https://github.com/girlbossceo/tracing?rev=b348dca742af641c47bc390261f60711c2af573c#b348dca742af641c47bc390261f60711c2af573c" +source = "git+https://github.com/girlbossceo/tracing?rev=4d78a14a5e03f539b8c6b475aefa08bb14e4de91#4d78a14a5e03f539b8c6b475aefa08bb14e4de91" dependencies = [ "log", "pin-project-lite", @@ -4244,17 +4252,17 @@ dependencies = [ [[package]] name = "tracing-attributes" version = "0.1.27" -source = "git+https://github.com/girlbossceo/tracing?rev=b348dca742af641c47bc390261f60711c2af573c#b348dca742af641c47bc390261f60711c2af573c" +source = "git+https://github.com/girlbossceo/tracing?rev=4d78a14a5e03f539b8c6b475aefa08bb14e4de91#4d78a14a5e03f539b8c6b475aefa08bb14e4de91" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.71", ] [[package]] name = "tracing-core" version = "0.1.32" -source = "git+https://github.com/girlbossceo/tracing?rev=b348dca742af641c47bc390261f60711c2af573c#b348dca742af641c47bc390261f60711c2af573c" +source = "git+https://github.com/girlbossceo/tracing?rev=4d78a14a5e03f539b8c6b475aefa08bb14e4de91#4d78a14a5e03f539b8c6b475aefa08bb14e4de91" dependencies = [ "once_cell", "valuable", @@ -4274,7 +4282,7 @@ dependencies = [ [[package]] name = "tracing-log" version = "0.2.0" -source = "git+https://github.com/girlbossceo/tracing?rev=b348dca742af641c47bc390261f60711c2af573c#b348dca742af641c47bc390261f60711c2af573c" +source = "git+https://github.com/girlbossceo/tracing?rev=4d78a14a5e03f539b8c6b475aefa08bb14e4de91#4d78a14a5e03f539b8c6b475aefa08bb14e4de91" dependencies = [ "log", "once_cell", @@ -4302,7 +4310,7 @@ dependencies = [ [[package]] name = "tracing-subscriber" version = "0.3.18" -source = "git+https://github.com/girlbossceo/tracing?rev=b348dca742af641c47bc390261f60711c2af573c#b348dca742af641c47bc390261f60711c2af573c" +source = "git+https://github.com/girlbossceo/tracing?rev=4d78a14a5e03f539b8c6b475aefa08bb14e4de91#4d78a14a5e03f539b8c6b475aefa08bb14e4de91" dependencies = [ "matchers", "nu-ansi-term", @@ -4417,16 +4425,15 @@ checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" [[package]] name = "ureq" -version = "2.9.7" +version = "2.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d11a831e3c0b56e438a28308e7c810799e3c118417f342d30ecec080105395cd" +checksum = "72139d247e5f97a3eff96229a7ae85ead5328a39efe76f8bf5a06313d505b6ea" dependencies = [ "base64 0.22.1", "log", "once_cell", - "rustls 0.22.4", + "rustls 0.23.11", "rustls-pki-types", - "rustls-webpki 0.102.4", "url", "webpki-roots", ] @@ -4457,9 +4464,9 @@ checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" [[package]] name = "uuid" -version = "1.9.1" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5de17fd2f7da591098415cff336e12965a28061ddace43b59cb3c430179c9439" +checksum = "81dfa00651efa65069b0b6b651f4aaa31ba9e3c3ce0137aaad053604ee7e0314" dependencies = [ "getrandom", "serde", @@ -4519,7 +4526,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.71", "wasm-bindgen-shared", ] @@ -4553,7 +4560,7 @@ checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.71", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -4662,7 +4669,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e48a53791691ab099e5e2ad123536d0fff50652600abaf43bbf952894110d0be" dependencies = [ "windows-core", - "windows-targets 0.52.5", + "windows-targets 0.52.6", ] [[package]] @@ -4671,7 +4678,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" dependencies = [ - "windows-targets 0.52.5", + "windows-targets 0.52.6", ] [[package]] @@ -4689,7 +4696,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ - "windows-targets 0.52.5", + "windows-targets 0.52.6", ] [[package]] @@ -4709,18 +4716,18 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" dependencies = [ - "windows_aarch64_gnullvm 0.52.5", - "windows_aarch64_msvc 0.52.5", - "windows_i686_gnu 0.52.5", + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", "windows_i686_gnullvm", - "windows_i686_msvc 0.52.5", - "windows_x86_64_gnu 0.52.5", - "windows_x86_64_gnullvm 0.52.5", - "windows_x86_64_msvc 0.52.5", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", ] [[package]] @@ -4731,9 +4738,9 @@ checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" [[package]] name = "windows_aarch64_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" [[package]] name = "windows_aarch64_msvc" @@ -4743,9 +4750,9 @@ checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" [[package]] name = "windows_aarch64_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" [[package]] name = "windows_i686_gnu" @@ -4755,15 +4762,15 @@ checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" [[package]] name = "windows_i686_gnu" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" [[package]] name = "windows_i686_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" [[package]] name = "windows_i686_msvc" @@ -4773,9 +4780,9 @@ checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" [[package]] name = "windows_i686_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" [[package]] name = "windows_x86_64_gnu" @@ -4785,9 +4792,9 @@ checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" [[package]] name = "windows_x86_64_gnu" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" [[package]] name = "windows_x86_64_gnullvm" @@ -4797,9 +4804,9 @@ checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" [[package]] name = "windows_x86_64_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" [[package]] name = "windows_x86_64_msvc" @@ -4809,9 +4816,9 @@ checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] name = "windows_x86_64_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "winnow" @@ -4876,27 +4883,27 @@ checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" [[package]] name = "zstd" -version = "0.13.1" +version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2d789b1514203a1120ad2429eae43a7bd32b90976a7bb8a05f7ec02fa88cc23a" +checksum = "fcf2b778a664581e31e389454a7072dab1647606d44f7feea22cd5abb9c9f3f9" dependencies = [ "zstd-safe", ] [[package]] name = "zstd-safe" -version = "7.1.0" +version = "7.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1cd99b45c6bc03a018c8b8a86025678c87e55526064e38f9df301989dce7ec0a" +checksum = "fa556e971e7b568dc775c136fc9de8c779b1c2fc3a63defaafadffdbd3181afa" dependencies = [ "zstd-sys", ] [[package]] name = "zstd-sys" -version = "2.0.11+zstd.1.5.6" +version = "2.0.12+zstd.1.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75652c55c0b6f3e6f12eb786fe1bc960396bf05a1eb3bf1f3691c3610ac2e6d4" +checksum = "0a4e40c320c3cb459d9a9ff6de98cff88f4751ee9275d140e2be94a2b74e4c13" dependencies = [ "cc", "pkg-config", diff --git a/Cargo.toml b/Cargo.toml index a52b4138..7c17d0f7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,11 +20,14 @@ license = "Apache-2.0" readme = "README.md" repository = "https://github.com/girlbossceo/conduwuit" rust-version = "1.77.0" -version = "0.4.4" +version = "0.4.5" [workspace.metadata.crane] name = "conduit" +[workspace.dependencies.const-str] +version = "0.5.7" + [workspace.dependencies.sanitize-filename] version = "0.5.0" @@ -50,7 +53,7 @@ version = "0.8.5" # Used for the http request / response body type for Ruma endpoints used with reqwest [workspace.dependencies.bytes] -version = "1.6.0" +version = "1.6.1" [workspace.dependencies.http-body-util] version = "0.1.1" @@ -113,7 +116,7 @@ features = [ ] [workspace.dependencies.serde] -version = "1.0.203" +version = "1.0.204" features = ["rc"] [workspace.dependencies.serde_json] @@ -169,7 +172,7 @@ default-features = false # used for conduit's CLI and admin room command parsing [workspace.dependencies.clap] -version = "4.5.4" +version = "4.5.9" default-features = false features = [ "std", @@ -197,6 +200,9 @@ features = [ "io-util", ] +[workspace.dependencies.tokio-metrics] +version = "0.3.1" + [workspace.dependencies.libloading] version = "0.8.3" @@ -208,7 +214,7 @@ features = ["serde"] # standard date and time tools [workspace.dependencies.chrono] version = "0.4.38" -features = ["alloc"] +features = ["alloc", "std"] default-features = false [workspace.dependencies.hyper] @@ -245,7 +251,7 @@ default-features = false # Used for conduit::Error type [workspace.dependencies.thiserror] -version = "1.0.61" +version = "1.0.62" # Used when hashing the state [workspace.dependencies.ring] @@ -265,7 +271,7 @@ version = "2.1.1" version = "0.3.1" [workspace.dependencies.async-trait] -version = "0.1.80" +version = "0.1.81" [workspace.dependencies.lru-cache] version = "0.1.2" @@ -303,7 +309,7 @@ features = [ [workspace.dependencies.ruma-identifiers-validation] git = "https://github.com/girlbossceo/ruwuma" -rev = "fd686e77950680462377c9105dfb4136dd49c7a0" +rev = "c51ccb2c68d2e3557eb12b1a49036531711ec0e5" [workspace.dependencies.rust-rocksdb] path = "deps/rust-rocksdb" @@ -311,7 +317,6 @@ package = "rust-rocksdb-uwu" features = [ "multi-threaded-cf", "mt_static", - "snappy", "lz4", "zstd", "zlib", @@ -380,10 +385,6 @@ version = "0.5.4" default-features = false features = ["use_std"] -[workspace.dependencies.tokio-metrics] -version = "0.3.1" -default-features = false - [workspace.dependencies.console-subscriber] version = "0.3" @@ -405,11 +406,15 @@ features = [ [workspace.dependencies.rustyline-async] version = "0.4.2" +default-features = false [workspace.dependencies.termimad] version = "0.29.4" default-features = false +[workspace.dependencies.checked_ops] +version = "0.1" + # # Patches @@ -420,16 +425,16 @@ default-features = false # https://github.com/girlbossceo/tracing/commit/b348dca742af641c47bc390261f60711c2af573c [patch.crates-io.tracing-subscriber] git = "https://github.com/girlbossceo/tracing" -rev = "b348dca742af641c47bc390261f60711c2af573c" +rev = "4d78a14a5e03f539b8c6b475aefa08bb14e4de91" [patch.crates-io.tracing] git = "https://github.com/girlbossceo/tracing" -rev = "b348dca742af641c47bc390261f60711c2af573c" +rev = "4d78a14a5e03f539b8c6b475aefa08bb14e4de91" [patch.crates-io.tracing-core] git = "https://github.com/girlbossceo/tracing" -rev = "b348dca742af641c47bc390261f60711c2af573c" +rev = "4d78a14a5e03f539b8c6b475aefa08bb14e4de91" [patch.crates-io.tracing-log] git = "https://github.com/girlbossceo/tracing" -rev = "b348dca742af641c47bc390261f60711c2af573c" +rev = "4d78a14a5e03f539b8c6b475aefa08bb14e4de91" # fixes hyper graceful shutdowns [https://github.com/programatik29/axum-server/issues/114] # https://github.com/girlbossceo/axum-server/commit/8e3368d899079818934e61cc9c839abcbbcada8a @@ -437,6 +442,12 @@ rev = "b348dca742af641c47bc390261f60711c2af573c" git = "https://github.com/girlbossceo/axum-server" rev = "8e3368d899079818934e61cc9c839abcbbcada8a" +# adds a tab completion callback: https://github.com/girlbossceo/rustyline-async/commit/de26100b0db03e419a3d8e1dd26895d170d1fe50 +# adds event for CTRL+\: https://github.com/girlbossceo/rustyline-async/commit/67d8c49aeac03a5ef4e818f663eaa94dd7bf339b +[patch.crates-io.rustyline-async] +git = "https://github.com/girlbossceo/rustyline-async" +rev = "de26100b0db03e419a3d8e1dd26895d170d1fe50" + # # Our crates # @@ -726,7 +737,6 @@ nursery = "warn" ## some sadness missing_const_for_fn = { level = "allow", priority = 1 } # TODO -needless_collect = { level = "allow", priority = 1 } # TODO option_if_let_else = { level = "allow", priority = 1 } # TODO redundant_pub_crate = { level = "allow", priority = 1 } # TODO significant_drop_in_scrutinee = { level = "allow", priority = 1 } # TODO @@ -736,21 +746,14 @@ significant_drop_tightening = { level = "allow", priority = 1 } # TODO pedantic = "warn" ## some sadness -cast_possible_truncation = { level = "allow", priority = 1 } -cast_precision_loss = { level = "allow", priority = 1 } -cast_sign_loss = { level = "allow", priority = 1 } doc_markdown = { level = "allow", priority = 1 } -error_impl_error = { level = "allow", priority = 1 } -expect_used = { level = "allow", priority = 1 } +enum_glob_use = { level = "allow", priority = 1 } if_not_else = { level = "allow", priority = 1 } if_then_some_else_none = { level = "allow", priority = 1 } -implicit_return = { level = "allow", priority = 1 } inline_always = { level = "allow", priority = 1 } -map_err_ignore = { level = "allow", priority = 1 } missing_docs_in_private_items = { level = "allow", priority = 1 } missing_errors_doc = { level = "allow", priority = 1 } missing_panics_doc = { level = "allow", priority = 1 } -mod_module_files = { level = "allow", priority = 1 } module_name_repetitions = { level = "allow", priority = 1 } no_effect_underscore_binding = { level = "allow", priority = 1 } similar_names = { level = "allow", priority = 1 } @@ -764,8 +767,10 @@ perf = "warn" ################### #restriction = "warn" -#arithmetic_side_effects = "warn" # TODO -#as_conversions = "warn" # TODO +allow_attributes = "warn" +arithmetic_side_effects = "warn" +as_conversions = "warn" +as_underscore = "warn" assertions_on_result_states = "warn" dbg_macro = "warn" default_union_representation = "warn" @@ -779,7 +784,6 @@ fn_to_numeric_cast_any = "warn" format_push_string = "warn" get_unwrap = "warn" impl_trait_in_params = "warn" -let_underscore_must_use = "warn" let_underscore_untyped = "warn" lossy_float_literal = "warn" mem_forget = "warn" @@ -793,6 +797,7 @@ rest_pat_in_fully_bound_structs = "warn" semicolon_outside_block = "warn" str_to_string = "warn" string_lit_chars_any = "warn" +string_slice = "warn" string_to_string = "warn" suspicious_xor_used_as_pow = "warn" tests_outside_test_module = "warn" @@ -803,6 +808,7 @@ unnecessary_safety_doc = "warn" unnecessary_self_imports = "warn" unneeded_field_pattern = "warn" unseparated_literal_suffix = "warn" +#unwrap_used = "warn" # TODO verbose_file_reads = "warn" ################### diff --git a/bin/complement b/bin/complement index c679035a..a715b14b 100755 --- a/bin/complement +++ b/bin/complement @@ -7,7 +7,7 @@ set -euo pipefail # The `COMPLEMENT_SRC` environment variable is set in the Nix dev shell, which # points to a store path containing the Complement source code. It's likely you # want to just pass that as the first argument to use it here. -COMPLEMENT_SRC="$1" +COMPLEMENT_SRC="${COMPLEMENT_SRC:-$1}" # A `.jsonl` file to write test logs to LOG_FILE="$2" @@ -17,12 +17,19 @@ RESULTS_FILE="$3" OCI_IMAGE="complement-conduit:main" -# Complement tests that are skipped due to flakiness/reliability issues (likely -# Complement itself induced based on various open issues) -# -# According to Go docs, these are separated by forward slashes and not pipes (why) +# Complement tests that are skipped due to flakiness/reliability issues SKIPPED_COMPLEMENT_TESTS='-skip=TestClientSpacesSummary.*|TestJoinFederatedRoomFromApplicationServiceBridgeUser.*|TestJumpToDateEndpoint.*' +# $COMPLEMENT_SRC needs to be a directory to Complement source code +if [ -f "$COMPLEMENT_SRC" ]; then + echo "\$COMPLEMENT_SRC must be a directory/path to Complement source code" + exit 1 +fi + +# quick test to make sure we can actually write to $LOG_FILE and $RESULTS_FILE +touch $LOG_FILE && rm -v $LOG_FILE +touch $RESULTS_FILE && rm -v $RESULTS_FILE + toplevel="$(git rev-parse --show-toplevel)" pushd "$toplevel" > /dev/null diff --git a/conduwuit-example.toml b/conduwuit-example.toml index 7e5c9710..57093128 100644 --- a/conduwuit-example.toml +++ b/conduwuit-example.toml @@ -57,6 +57,16 @@ # Defaults to 0.15 #sentry_traces_sample_rate = 0.15 +# Whether to attach a stacktrace to Sentry reports. +#sentry_attach_stacktrace = false + +# Send panics to sentry. This is true by default, but sentry has to be enabled. +#sentry_send_panic = true + +# Send errors to sentry. This is true by default, but sentry has to be enabled. This option is +# only effective in release-mode; forced to false in debug-mode. +#sentry_send_error = true + ### Database configuration @@ -411,8 +421,11 @@ allow_profile_lookup_federation_requests = true # Set this to any float value to multiply conduwuit's in-memory LRU caches with. # May be useful if you have significant memory to spare to increase performance. +# +# This was previously called `conduit_cache_capacity_modifier` +# # Defaults to 1.0. -#conduit_cache_capacity_modifier = 1.0 +#cache_capacity_modifier = 1.0 # Set this to any float value in megabytes for conduwuit to tell the database engine that this much memory is available for database-related caches. # May be useful if you have significant memory to spare to increase performance. diff --git a/deps/rust-rocksdb/Cargo.toml b/deps/rust-rocksdb/Cargo.toml index 03efc61c..cf49e1ad 100644 --- a/deps/rust-rocksdb/Cargo.toml +++ b/deps/rust-rocksdb/Cargo.toml @@ -10,7 +10,7 @@ repository.workspace = true version = "0.0.1" [features] -default = ["snappy", "lz4", "zstd", "zlib", "bzip2"] +default = ["lz4", "zstd", "zlib", "bzip2"] jemalloc = ["rust-rocksdb/jemalloc"] io-uring = ["rust-rocksdb/io-uring"] valgrind = ["rust-rocksdb/valgrind"] @@ -27,7 +27,7 @@ malloc-usable-size = ["rust-rocksdb/malloc-usable-size"] [dependencies.rust-rocksdb] git = "https://github.com/zaidoon1/rust-rocksdb" -rev = "b4887edfb84771336930855727390edec07d63fa" +rev = "4056a3b0f823013fec49f6d0b3e5698856e6476a" #branch = "master" default-features = false diff --git a/docs/deploying/docker-compose.for-traefik.yml b/docs/deploying/docker-compose.for-traefik.yml index ec3e720c..d10e5815 100644 --- a/docs/deploying/docker-compose.for-traefik.yml +++ b/docs/deploying/docker-compose.for-traefik.yml @@ -1,5 +1,4 @@ # conduwuit - Behind Traefik Reverse Proxy -version: '2.4' # uses '2.4' for cpuset services: homeserver: @@ -24,7 +23,7 @@ services: CONDUWUIT_TRUSTED_SERVERS: '["matrix.org"]' #CONDUWUIT_LOG: warn,state_res=warn CONDUWUIT_ADDRESS: 0.0.0.0 - #CONDUWUIT_CONFIG: './conduwuit.toml' # Uncomment if you mapped config toml above + #CONDUWUIT_CONFIG: '/etc/conduwuit.toml' # Uncomment if you mapped config toml above #cpuset: "0-4" # Uncomment to limit to specific CPU cores # We need some way to server the client and server .well-known json. The simplest way is to use a nginx container diff --git a/docs/deploying/docker-compose.override.yml b/docs/deploying/docker-compose.override.yml index 2e937e75..23d6a90b 100644 --- a/docs/deploying/docker-compose.override.yml +++ b/docs/deploying/docker-compose.override.yml @@ -1,5 +1,4 @@ # conduwuit - Traefik Reverse Proxy Labels -version: '2.4' # uses '2.4' for cpuset services: homeserver: diff --git a/docs/deploying/docker-compose.with-caddy.yml b/docs/deploying/docker-compose.with-caddy.yml new file mode 100644 index 00000000..899f4d67 --- /dev/null +++ b/docs/deploying/docker-compose.with-caddy.yml @@ -0,0 +1,54 @@ +services: + caddy: + # This compose file uses caddy-docker-proxy as the reverse proxy for conduwuit! + # For more info, visit https://github.com/lucaslorentz/caddy-docker-proxy + image: lucaslorentz/caddy-docker-proxy:ci-alpine + ports: + - 80:80 + - 443:443 + environment: + - CADDY_INGRESS_NETWORKS=caddy + networks: + - caddy + volumes: + - /var/run/docker.sock:/var/run/docker.sock + - ./data:/data + restart: unless-stopped + labels: + caddy: example.com + caddy.0_respond: /.well-known/matrix/server {"m.server":"matrix.example.com:443"} + caddy.1_respond: /.well-known/matrix/client {"m.server":{"base_url":"https://matrix.example.com"},"m.homeserver":{"base_url":"https://matrix.example.com"},"org.matrix.msc3575.proxy":{"url":"https://matrix.example.com"}} + + homeserver: + ### If you already built the conduwuit image with 'docker build' or want to use a registry image, + ### then you are ready to go. + image: girlbossceo/conduwuit:latest + restart: unless-stopped + volumes: + - db:/var/lib/conduwuit + #- ./conduwuit.toml:/etc/conduwuit.toml + environment: + CONDUWUIT_SERVER_NAME: example.com # EDIT THIS + CONDUWUIT_DATABASE_PATH: /var/lib/conduwuit + CONDUWUIT_DATABASE_BACKEND: rocksdb + CONDUWUIT_PORT: 6167 + CONDUWUIT_MAX_REQUEST_SIZE: 20_000_000 # in bytes, ~20 MB + CONDUWUIT_ALLOW_REGISTRATION: 'true' + CONDUWUIT_ALLOW_FEDERATION: 'true' + CONDUWUIT_ALLOW_CHECK_FOR_UPDATES: 'true' + CONDUWUIT_TRUSTED_SERVERS: '["matrix.org"]' + #CONDUWUIT_LOG: warn,state_res=warn + CONDUWUIT_ADDRESS: 0.0.0.0 + #CONDUWUIT_CONFIG: '/etc/conduwuit.toml' # Uncomment if you mapped config toml above + networks: + - caddy + labels: + caddy: matrix.example.com + caddy.reverse_proxy: "{{upstreams 6167}}" + +volumes: + db: + +networks: + caddy: + external: true diff --git a/docs/deploying/docker-compose.with-traefik.yml b/docs/deploying/docker-compose.with-traefik.yml index c93f5414..79d20051 100644 --- a/docs/deploying/docker-compose.with-traefik.yml +++ b/docs/deploying/docker-compose.with-traefik.yml @@ -1,5 +1,4 @@ # conduwuit - Behind Traefik Reverse Proxy -version: '2.4' # uses '2.4' for cpuset services: homeserver: @@ -16,7 +15,7 @@ services: CONDUWUIT_SERVER_NAME: your.server.name # EDIT THIS CONDUWUIT_TRUSTED_SERVERS: '["matrix.org"]' CONDUWUIT_ALLOW_REGISTRATION : 'true' - #CONDUWUIT_CONFIG: './conduwuit.toml' # Uncomment if you mapped config toml above + #CONDUWUIT_CONFIG: '/etc/conduwuit.toml' # Uncomment if you mapped config toml above ### Uncomment and change values as desired # CONDUWUIT_ADDRESS: 0.0.0.0 # CONDUWUIT_PORT: 6167 @@ -28,7 +27,6 @@ services: # CONDUWUIT_DATABASE_PATH: /srv/conduwuit/.local/share/conduwuit # CONDUWUIT_WORKERS: 10 # CONDUWUIT_MAX_REQUEST_SIZE: 20000000 # in bytes, ~20 MB - #cpuset: "0-4" # Uncomment to limit to specific CPU cores # We need some way to server the client and server .well-known json. The simplest way is to use a nginx container # to serve those two as static files. If you want to use a different way, delete or comment the below service, here diff --git a/docs/deploying/docker-compose.yml b/docs/deploying/docker-compose.yml index 066c8fe1..bc9f2477 100644 --- a/docs/deploying/docker-compose.yml +++ b/docs/deploying/docker-compose.yml @@ -1,5 +1,4 @@ # conduwuit -version: '2.4' # uses '2.4' for cpuset services: homeserver: @@ -24,8 +23,7 @@ services: CONDUWUIT_TRUSTED_SERVERS: '["matrix.org"]' #CONDUWUIT_LOG: warn,state_res=warn CONDUWUIT_ADDRESS: 0.0.0.0 - #CONDUWUIT_CONFIG: './conduwuit.toml' # Uncomment if you mapped config toml above - #cpuset: "0-4" # Uncomment to limit to specific CPU cores + #CONDUWUIT_CONFIG: '/etc/conduwuit.toml' # Uncomment if you mapped config toml above # ### Uncomment if you want to use your own Element-Web App. ### Note: You need to provide a config.json for Element and you also need a second diff --git a/docs/deploying/docker.md b/docs/deploying/docker.md index 5cb89308..ffbad5b4 100644 --- a/docs/deploying/docker.md +++ b/docs/deploying/docker.md @@ -59,13 +59,22 @@ If the `docker run` command is not for you or your setup, you can also use one o Depending on your proxy setup, you can use one of the following files; - If you already have a `traefik` instance set up, use [`docker-compose.for-traefik.yml`](docker-compose.for-traefik.yml) -- If you don't have a `traefik` instance set up (or any other reverse proxy), use [`docker-compose.with-traefik.yml`](docker-compose.with-traefik.yml) +- If you don't have a `traefik` instance set up and would like to use it, use [`docker-compose.with-traefik.yml`](docker-compose.with-traefik.yml) +- If you want a setup that works out of the box with `caddy-docker-proxy`, use [`docker-compose.with-caddy.yml`](docker-compose.with-caddy.yml) and replace all `example.com` placeholders with your own domain - For any other reverse proxy, use [`docker-compose.yml`](docker-compose.yml) When picking the traefik-related compose file, rename it so it matches `docker-compose.yml`, and rename the override file to `docker-compose.override.yml`. Edit the latter with the values you want for your server. +When picking the `caddy-docker-proxy` compose file, it's important to first create the `caddy` network before spinning up the containers: + +```bash +docker network create caddy +``` + +After that, you can rename it so it matches `docker-compose.yml` and spin up the containers! + Additional info about deploying conduwuit can be found [here](generic.md). ### Build diff --git a/docs/deploying/generic.md b/docs/deploying/generic.md index 75b285d1..0cbd8ef8 100644 --- a/docs/deploying/generic.md +++ b/docs/deploying/generic.md @@ -23,7 +23,7 @@ Otherwise, follow standard Rust project build guides (installing git and cloning While conduwuit can run as any user it is better to use dedicated users for different services. This also allows you to make sure that the file permissions are correctly set up. -In Debian or RHEL, you can use this command to create a conduwuit user: +In Debian or Fedora/RHEL, you can use this command to create a conduwuit user: ```bash sudo adduser --system conduwuit --group --disabled-login --no-create-home @@ -53,13 +53,11 @@ RocksDB is the only supported database backend. ## Setting the correct file permissions -If you are using a dedicated user for conduwuit, you will need to allow it to read the config. To do that you can run this command on - -Debian or RHEL: +If you are using a dedicated user for conduwuit, you will need to allow it to read the config. To do that you can run this: ```bash sudo chown -R root:root /etc/conduwuit -sudo chmod 755 /etc/conduwuit +sudo chmod -R 755 /etc/conduwuit ``` If you use the default database path you also need to run this: diff --git a/engage.toml b/engage.toml index 365c0179..8bba0eaf 100644 --- a/engage.toml +++ b/engage.toml @@ -184,5 +184,10 @@ cargo test \ name = "nix-default" group = "tests" script = """ -nix run .#default -- --help +env DIRENV_DEVSHELL=dynamic \ + direnv exec . \ + bin/nix-build-and-cache just .#default +env DIRENV_DEVSHELL=dynamic \ + direnv exec . \ + nix run -L .#default -- --help """ diff --git a/flake.lock b/flake.lock index 1a064510..9d7a162e 100644 --- a/flake.lock +++ b/flake.lock @@ -9,11 +9,11 @@ "nixpkgs-stable": "nixpkgs-stable" }, "locked": { - "lastModified": 1717279440, - "narHash": "sha256-kH04ReTjxOpQumgWnqy40vvQLSnLGxWP6RF3nq5Esrk=", + "lastModified": 1720542474, + "narHash": "sha256-aKjJ/4l2I9+wNGTaOGRsuS3M1+IoTibqgEMPDikXm04=", "owner": "zhaofengli", "repo": "attic", - "rev": "717cc95983cdc357bc347d70be20ced21f935843", + "rev": "6139576a3ce6bb992e0f6c3022528ec233e45f00", "type": "github" }, "original": { @@ -81,11 +81,11 @@ "complement": { "flake": false, "locked": { - "lastModified": 1719903368, - "narHash": "sha256-PPzgxM4Bir+Zh9FUV/v+RBxEYeJxYVmi/BYo3uqt268=", + "lastModified": 1720637557, + "narHash": "sha256-oZz6nCmFmdJZpC+K1iOG2KkzTI6rlAmndxANPDVU7X0=", "owner": "matrix-org", "repo": "complement", - "rev": "bc97f1ddc1cd7485faf80c8935ee2641f3e1b57c", + "rev": "0d14432e010482ea9e13a6f7c47c1533c0c9d62f", "type": "github" }, "original": { @@ -123,11 +123,11 @@ ] }, "locked": { - "lastModified": 1716569590, - "narHash": "sha256-5eDbq8TuXFGGO3mqJFzhUbt5zHVTf5zilQoyW5jnJwo=", + "lastModified": 1720546058, + "narHash": "sha256-iU2yVaPIZm5vMGdlT0+57vdB/aPq/V5oZFBRwYw+HBM=", "owner": "ipetkov", "repo": "crane", - "rev": "109987da061a1bf452f435f1653c47511587d919", + "rev": "2d83156f23c43598cf44e152c33a59d3892f8b29", "type": "github" }, "original": { @@ -209,11 +209,11 @@ "rust-analyzer-src": "rust-analyzer-src" }, "locked": { - "lastModified": 1716359173, - "narHash": "sha256-pYcjP6Gy7i6jPWrjiWAVV0BCQp+DdmGaI/k65lBb/kM=", + "lastModified": 1720852044, + "narHash": "sha256-3NBYz8VuXuKU+8ONd9NFafCNjPEGHIZQ2Mdoam1a4mY=", "owner": "nix-community", "repo": "fenix", - "rev": "b6fc5035b28e36a98370d0eac44f4ef3fd323df6", + "rev": "5087b12a595ee73131a944d922f24d81dae05725", "type": "github" }, "original": { @@ -381,11 +381,11 @@ "liburing": { "flake": false, "locked": { - "lastModified": 1719025212, - "narHash": "sha256-kD0yhjNStqC6uFqC1AxBwUpc/HlSFtiKrV+gwDyroDc=", + "lastModified": 1720798442, + "narHash": "sha256-gtPppAoksMLW4GuruQ36nf4EAqIA1Bs6V9Xcx8dBxrQ=", "owner": "axboe", "repo": "liburing", - "rev": "7b3245583069bd481190c9da18f22e9fc8c3a805", + "rev": "1d674f83b7d0f07553ac44d99a401b05853d9dbe", "type": "github" }, "original": { @@ -606,11 +606,11 @@ }, "nixpkgs_4": { "locked": { - "lastModified": 1716330097, - "narHash": "sha256-8BO3B7e3BiyIDsaKA0tY8O88rClYRTjvAp66y+VBUeU=", + "lastModified": 1720768451, + "narHash": "sha256-EYekUHJE2gxeo2pM/zM9Wlqw1Uw2XTJXOSAO79ksc4Y=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "5710852ba686cc1fd0d3b8e22b3117d43ba374c2", + "rev": "7e7c39ea35c5cdd002cd4588b03a3fb9ece6fad9", "type": "github" }, "original": { @@ -673,16 +673,16 @@ "rocksdb": { "flake": false, "locked": { - "lastModified": 1719949653, - "narHash": "sha256-DYx7XHH2GEh17GukKhXs6laM6l+eugCmRkF0adpi9wk=", + "lastModified": 1720900786, + "narHash": "sha256-Vta9Um/RRuWwZ46BjXftV06iWLm/j/9MX39emXUvSAY=", "owner": "girlbossceo", "repo": "rocksdb", - "rev": "a935c0273e1ba44eacf88ce3685a9b9831486155", + "rev": "911f4243e69c2e320a7a209bf1f5f3ff5f825495", "type": "github" }, "original": { "owner": "girlbossceo", - "ref": "v9.3.1", + "ref": "v9.4.0", "repo": "rocksdb", "type": "github" } @@ -705,11 +705,11 @@ "rust-analyzer-src": { "flake": false, "locked": { - "lastModified": 1716107283, - "narHash": "sha256-NJgrwLiLGHDrCia5AeIvZUHUY7xYGVryee0/9D3Ir1I=", + "lastModified": 1720717809, + "narHash": "sha256-6I+fm+nTLF/iaj7ffiFGlSY7POmubwUaPA/Wq0Bm53M=", "owner": "rust-lang", "repo": "rust-analyzer", - "rev": "21ec8f523812b88418b2bfc64240c62b3dd967bd", + "rev": "ffbc5ad993d5cd2f3b8bcf9a511165470944ab91", "type": "github" }, "original": { diff --git a/flake.nix b/flake.nix index 499d828d..a56d1bc0 100644 --- a/flake.nix +++ b/flake.nix @@ -9,8 +9,7 @@ flake-utils.url = "github:numtide/flake-utils?ref=main"; nix-filter.url = "github:numtide/nix-filter?ref=main"; nixpkgs.url = "github:NixOS/nixpkgs?ref=nixos-unstable"; - # https://github.com/girlbossceo/rocksdb/commit/db6df0b185774778457dabfcbd822cb81760cade - rocksdb = { url = "github:girlbossceo/rocksdb?ref=v9.3.1"; flake = false; }; + rocksdb = { url = "github:girlbossceo/rocksdb?ref=v9.4.0"; flake = false; }; liburing = { url = "github:axboe/liburing?ref=master"; flake = false; }; }; @@ -42,6 +41,37 @@ "v" (builtins.fromJSON (builtins.readFile ./flake.lock)) .nodes.rocksdb.original.ref; + # we have this already at https://github.com/girlbossceo/rocksdb/commit/a935c0273e1ba44eacf88ce3685a9b9831486155 + # unsetting this so i don't have to revert it and make this nix exclusive + patches = []; + cmakeFlags = pkgs.lib.subtractLists + [ + # no real reason to have snappy, no one uses this + "-DWITH_SNAPPY=1" + # we dont need to use ldb or sst_dump (core_tools) + "-DWITH_CORE_TOOLS=1" + # we dont need to build rocksdb tests + "-DWITH_TESTS=1" + # we use rust-rocksdb via C interface and dont need C++ RTTI + "-DUSE_RTTI=1" + ] + old.cmakeFlags + ++ [ + # we dont need to use ldb or sst_dump (core_tools) + "-DWITH_CORE_TOOLS=0" + # we dont need trace tools + "-DWITH_TRACE_TOOLS=0" + # we dont need to build rocksdb tests + "-DWITH_TESTS=0" + # we use rust-rocksdb via C interface and dont need C++ RTTI + "-DUSE_RTTI=0" + ]; + + # outputs has "tools" which we dont need or use + outputs = [ "out" ]; + + # preInstall hooks has stuff for messing with ldb/sst_dump which we dont need or use + preInstall = ""; }); # TODO: remove once https://github.com/NixOS/nixpkgs/pull/314945 is available liburing = pkgs.liburing.overrideAttrs (old: { @@ -50,16 +80,6 @@ configureFlags = pkgs.lib.subtractLists [ "--enable-static" "--disable-shared" ] old.configureFlags; - - postInstall = old.postInstall + '' - # we remove the extra outputs - # - # we need to do this to prevent rocksdb from trying to link the - # static library in a dynamic stdenv - rm $out/lib/liburing*${ - if pkgs.stdenv.hostPlatform.isStatic then ".so*" else ".a" - } - ''; }); }); @@ -124,9 +144,29 @@ { packages = { default = scopeHost.main; + all-features = scopeHost.main.override { + all_features = true; + disable_features = [ + # this is non-functional on nix for some reason + "hardened_malloc" + # dont include experimental features + "experimental" + ]; + }; hmalloc = scopeHost.main.override { features = ["hardened_malloc"]; }; oci-image = scopeHost.oci-image; + oci-image-all-features = scopeHost.oci-image.override { + main = scopeHost.main.override { + all_features = true; + disable_features = [ + # this is non-functional on nix for some reason + "hardened_malloc" + # dont include experimental features + "experimental" + ]; + }; + }; oci-image-hmalloc = scopeHost.oci-image.override { main = scopeHost.main.override { features = ["hardened_malloc"]; @@ -161,6 +201,20 @@ value = scopeCrossStatic.main; } + # An output for a statically-linked binary with `--all-features` + { + name = "${binaryName}-all-features"; + value = scopeCrossStatic.main.override { + all_features = true; + disable_features = [ + # this is non-functional on nix for some reason + "hardened_malloc" + # dont include experimental features + "experimental" + ]; + }; + } + # An output for a statically-linked binary with hardened_malloc { name = "${binaryName}-hmalloc"; @@ -175,6 +229,22 @@ value = scopeCrossStatic.oci-image; } + # An output for an OCI image based on that binary with `--all-features` + { + name = "oci-image-${crossSystem}-all-features"; + value = scopeCrossStatic.oci-image.override { + main = scopeCrossStatic.main.override { + all_features = true; + disable_features = [ + # this is non-functional on nix for some reason + "hardened_malloc" + # dont include experimental features + "experimental" + ]; + }; + }; + } + # An output for an OCI image based on that binary with hardened_malloc { name = "oci-image-${crossSystem}-hmalloc"; @@ -196,7 +266,15 @@ devShells.default = mkDevShell scopeHostStatic; devShells.all-features = mkDevShell (scopeHostStatic.overrideScope (final: prev: { - main = prev.main.override { all_features = true; }; + main = prev.main.override { + all_features = true; + disable_features = [ + # this is non-functional on nix for some reason + "hardened_malloc" + # dont include experimental features + "experimental" + ]; + }; })); devShells.no-features = mkDevShell (scopeHostStatic.overrideScope (final: prev: { diff --git a/nix/pkgs/complement/config.toml b/nix/pkgs/complement/config.toml index db1f2d81..15f939bb 100644 --- a/nix/pkgs/complement/config.toml +++ b/nix/pkgs/complement/config.toml @@ -5,13 +5,17 @@ allow_guest_registration = true allow_public_room_directory_over_federation = true allow_public_room_directory_without_auth = true allow_registration = true -allow_unstable_room_versions = true -database_backend = "rocksdb" database_path = "/database" -log = "trace" +log = "trace,h2=warn,hyper=warn" port = [8008, 8448] trusted_servers = [] +query_trusted_key_servers_first = false yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse = true +ip_range_denylist = [] +url_preview_domain_contains_allowlist = ["*"] +media_compat_file_link = false +media_statup_check = false +rocksdb_direct_io = false [global.tls] certs = "/certificate.crt" diff --git a/nix/pkgs/main/cross-compilation-env.nix b/nix/pkgs/main/cross-compilation-env.nix index 83fe6ed6..75b5862f 100644 --- a/nix/pkgs/main/cross-compilation-env.nix +++ b/nix/pkgs/main/cross-compilation-env.nix @@ -13,6 +13,12 @@ lib.optionalAttrs stdenv.hostPlatform.isStatic { lib.concatStringsSep " " ([] + ++ lib.optionals + stdenv.targetPlatform.isx86_64 + [ "-C" "target-cpu=x86-64-v2" ] + ++ lib.optionals + stdenv.targetPlatform.isAarch64 + [ "-C" "target-cpu=cortex-a55" ] # cortex-a55 == ARMv8.2-a # This disables PIE for static builds, which isn't great in terms # of security. Unfortunately, my hand is forced because nixpkgs' # `libstdc++.a` is built without `-fPIE`, which precludes us from diff --git a/nix/pkgs/main/default.nix b/nix/pkgs/main/default.nix index f91a9cdd..629957c8 100644 --- a/nix/pkgs/main/default.nix +++ b/nix/pkgs/main/default.nix @@ -25,11 +25,7 @@ let # on the nix side depend on feature values. crateFeatures = path: let manifest = lib.importTOML "${path}/Cargo.toml"; in - lib.remove "default" (lib.attrNames manifest.features) ++ - lib.attrNames - (lib.filterAttrs - (_: dependency: dependency.optional or false) - manifest.dependencies); + lib.remove "default" (lib.attrNames manifest.features); crateDefaultFeatures = path: (lib.importTOML "${path}/Cargo.toml").features.default; allDefaultFeatures = crateDefaultFeatures "${inputs.self}/src/main"; @@ -43,7 +39,7 @@ features'' = lib.subtractLists disable_features' features'; featureEnabled = feature : builtins.elem feature features''; -enableLiburing = featureEnabled "io_uring" && stdenv.isLinux; +enableLiburing = featureEnabled "io_uring" && !stdenv.isDarwin; # This derivation will set the JEMALLOC_OVERRIDE variable, causing the # tikv-jemalloc-sys crate to use the nixpkgs jemalloc instead of building it's @@ -70,12 +66,34 @@ buildDepsOnlyEnv = # # [1]: https://github.com/tikv/jemallocator/blob/ab0676d77e81268cd09b059260c75b38dbef2d51/jemalloc-sys/src/env.rs#L17 enableJemalloc = featureEnabled "jemalloc" && !stdenv.isDarwin; + + # for some reason enableLiburing in nixpkgs rocksdb is default true + # which breaks Darwin entirely + enableLiburing = enableLiburing; }).overrideAttrs (old: { - # TODO: static rocksdb fails to build on darwin + # TODO: static rocksdb fails to build on darwin, also see # build log at meta.broken = stdenv.hostPlatform.isStatic && stdenv.isDarwin; - # TODO: switch to enableUring option once https://github.com/NixOS/nixpkgs/pull/314945 is available - buildInputs = old.buildInputs ++ lib.optional enableLiburing liburing; + + enableLiburing = enableLiburing; + + sse42Support = stdenv.targetPlatform.isx86_64; + + cmakeFlags = if stdenv.targetPlatform.isx86_64 + then lib.subtractLists [ "-DPORTABLE=1" ] old.cmakeFlags + ++ lib.optionals stdenv.targetPlatform.isx86_64 [ + "-DPORTABLE=x86-64-v2" + "-DUSE_SSE=1" + "-DHAVE_SSE=1" + "-DHAVE_SSE42=1" + ] + else if stdenv.targetPlatform.isAarch64 + then lib.subtractLists [ "-DPORTABLE=1" ] old.cmakeFlags + ++ lib.optionals stdenv.targetPlatform.isAarch64 [ + # cortex-a55 == ARMv8.2-a + "-DPORTABLE=armv8.2-a" + ] + else old.cmakeFlags; }); in { @@ -102,7 +120,11 @@ buildPackageEnv = { # Only needed in static stdenv because these are transitive dependencies of rocksdb CARGO_BUILD_RUSTFLAGS = buildDepsOnlyEnv.CARGO_BUILD_RUSTFLAGS + lib.optionalString (enableLiburing && stdenv.hostPlatform.isStatic) - " -L${lib.getLib liburing}/lib -luring"; + " -L${lib.getLib liburing}/lib -luring" + + lib.optionalString stdenv.targetPlatform.isx86_64 + " -Ctarget-cpu=x86-64-v2" + + lib.optionalString stdenv.targetPlatform.isAarch64 + " -Ctarget-cpu=cortex-a55"; # cortex-a55 == ARMv8.2-a }; @@ -127,6 +149,8 @@ commonAttrs = { ]; }; + dontStrip = profile == "dev"; + buildInputs = lib.optional (featureEnabled "jemalloc") rust-jemalloc-sys'; nativeBuildInputs = [ diff --git a/src/admin/Cargo.toml b/src/admin/Cargo.toml index 97de55ff..1e13fb7a 100644 --- a/src/admin/Cargo.toml +++ b/src/admin/Cargo.toml @@ -29,15 +29,12 @@ release_max_log_level = [ clap.workspace = true conduit-api.workspace = true conduit-core.workspace = true -conduit-database.workspace = true conduit-service.workspace = true +const-str.workspace = true futures-util.workspace = true log.workspace = true -loole.workspace = true -regex.workspace = true ruma.workspace = true serde_json.workspace = true -serde.workspace = true serde_yaml.workspace = true tokio.workspace = true tracing-subscriber.workspace = true diff --git a/src/admin/debug/commands.rs b/src/admin/debug/commands.rs index c5b944bf..46f71622 100644 --- a/src/admin/debug/commands.rs +++ b/src/admin/debug/commands.rs @@ -15,7 +15,7 @@ use ruma::{ events::room::message::RoomMessageEventContent, CanonicalJsonObject, EventId, OwnedRoomOrAliasId, RoomId, RoomVersionId, ServerName, }; -use service::{rooms::event_handler::parse_incoming_pdu, sending::resolve::resolve_actual_dest, services, PduEvent}; +use service::{rooms::event_handler::parse_incoming_pdu, sending::resolve_actual_dest, services, PduEvent}; use tokio::sync::RwLock; use tracing_subscriber::EnvFilter; @@ -58,7 +58,7 @@ pub(super) async fn parse_pdu(body: Vec<&str>) -> Result match ruma::signatures::reference_hash(&value, &RoomVersionId::V6) { Ok(hash) => { @@ -314,6 +314,8 @@ pub(super) async fn force_device_list_updates(_body: Vec<&str>) -> Result, filter: Option, reset: bool, ) -> Result { + let handles = &["console"]; + if reset { let old_filter_layer = match EnvFilter::try_new(&services().globals.config.log) { Ok(s) => s, @@ -324,7 +326,12 @@ pub(super) async fn change_log_level( }, }; - match services().server.log.reload.reload(&old_filter_layer) { + match services() + .server + .log + .reload + .reload(&old_filter_layer, Some(handles)) + { Ok(()) => { return Ok(RoomMessageEventContent::text_plain(format!( "Successfully changed log level back to config value {}", @@ -349,7 +356,12 @@ pub(super) async fn change_log_level( }, }; - match services().server.log.reload.reload(&new_filter_layer) { + match services() + .server + .log + .reload + .reload(&new_filter_layer, Some(handles)) + { Ok(()) => { return Ok(RoomMessageEventContent::text_plain("Successfully changed log level")); }, @@ -570,7 +582,7 @@ pub(super) async fn force_set_room_state_from_server( .state_compressor .save_state(room_id.clone().as_ref(), new_room_state)?; - let state_lock = services().globals.roomid_mutex_state.lock(&room_id).await; + let state_lock = services().rooms.state.mutex.lock(&room_id).await; services() .rooms .state @@ -614,15 +626,16 @@ pub(super) async fn resolve_true_destination( let state = &services().server.log.capture; let logs = Arc::new(Mutex::new(String::new())); let capture = Capture::new(state, Some(filter), capture::fmt_markdown(logs.clone())); - let (actual_dest, hostname_uri); - { - let _capture_scope = capture.start(); - (actual_dest, hostname_uri) = resolve_actual_dest(&server_name, !no_cache).await?; - }; + + let capture_scope = capture.start(); + let actual = resolve_actual_dest(&server_name, !no_cache).await?; + drop(capture_scope); let msg = format!( - "{}\nDestination: {actual_dest}\nHostname URI: {hostname_uri}", - logs.lock().expect("locked") + "{}\nDestination: {}\nHostname URI: {}", + logs.lock().expect("locked"), + actual.dest, + actual.host, ); Ok(RoomMessageEventContent::text_markdown(msg)) } @@ -631,12 +644,46 @@ pub(super) async fn resolve_true_destination( pub(super) fn memory_stats() -> RoomMessageEventContent { let html_body = conduit::alloc::memory_stats(); - if html_body.is_empty() { + if html_body.is_none() { return RoomMessageEventContent::text_plain("malloc stats are not supported on your compiled malloc."); } RoomMessageEventContent::text_html( "This command's output can only be viewed by clients that render HTML.".to_owned(), - html_body, + html_body.expect("string result"), ) } + +#[cfg(tokio_unstable)] +pub(super) async fn runtime_metrics(_body: Vec<&str>) -> Result { + let out = services().server.metrics.runtime_metrics().map_or_else( + || "Runtime metrics are not available.".to_owned(), + |metrics| format!("```rs\n{metrics:#?}\n```"), + ); + + Ok(RoomMessageEventContent::text_markdown(out)) +} + +#[cfg(not(tokio_unstable))] +pub(super) async fn runtime_metrics(_body: Vec<&str>) -> Result { + Ok(RoomMessageEventContent::text_markdown( + "Runtime metrics require building with `tokio_unstable`.", + )) +} + +#[cfg(tokio_unstable)] +pub(super) async fn runtime_interval(_body: Vec<&str>) -> Result { + let out = services().server.metrics.runtime_interval().map_or_else( + || "Runtime metrics are not available.".to_owned(), + |metrics| format!("```rs\n{metrics:#?}\n```"), + ); + + Ok(RoomMessageEventContent::text_markdown(out)) +} + +#[cfg(not(tokio_unstable))] +pub(super) async fn runtime_interval(_body: Vec<&str>) -> Result { + Ok(RoomMessageEventContent::text_markdown( + "Runtime metrics require building with `tokio_unstable`.", + )) +} diff --git a/src/admin/debug/mod.rs b/src/admin/debug/mod.rs index eed3b633..7d6cafa7 100644 --- a/src/admin/debug/mod.rs +++ b/src/admin/debug/mod.rs @@ -160,6 +160,13 @@ pub(super) enum DebugCommand { /// - Print extended memory usage MemoryStats, + /// - Print general tokio runtime metric totals. + RuntimeMetrics, + + /// - Print detailed tokio runtime metrics accumulated since last command + /// invocation. + RuntimeInterval, + /// - Developer test stubs #[command(subcommand)] Tester(TesterCommand), @@ -213,6 +220,8 @@ pub(super) async fn process(command: DebugCommand, body: Vec<&str>) -> Result resolve_true_destination(body, server_name, no_cache).await?, DebugCommand::MemoryStats => memory_stats(), + DebugCommand::RuntimeMetrics => runtime_metrics(body).await?, + DebugCommand::RuntimeInterval => runtime_interval(body).await?, DebugCommand::Tester(command) => tester::process(command, body).await?, }) } diff --git a/src/admin/federation/commands.rs b/src/admin/federation/commands.rs index 293efdc5..a97e7582 100644 --- a/src/admin/federation/commands.rs +++ b/src/admin/federation/commands.rs @@ -15,14 +15,19 @@ pub(super) async fn enable_room(_body: Vec<&str>, room_id: Box) -> Resul } pub(super) async fn incoming_federation(_body: Vec<&str>) -> Result { - let map = services().globals.roomid_federationhandletime.read().await; + let map = services() + .rooms + .event_handler + .federation_handletime + .read() + .expect("locked"); let mut msg = format!("Handling {} incoming pdus:\n", map.len()); for (r, (e, i)) in map.iter() { let elapsed = i.elapsed(); - writeln!(msg, "{} {}: {}m{}s", r, e, elapsed.as_secs() / 60, elapsed.as_secs() % 60,) - .expect("should be able to write to string buffer"); + writeln!(msg, "{} {}: {}m{}s", r, e, elapsed.as_secs() / 60, elapsed.as_secs() % 60)?; } + Ok(RoomMessageEventContent::text_plain(&msg)) } diff --git a/src/admin/handler.rs b/src/admin/handler.rs index 0bd8af9f..95c7ed41 100644 --- a/src/admin/handler.rs +++ b/src/admin/handler.rs @@ -1,15 +1,19 @@ -use std::time::Instant; +use std::{panic::AssertUnwindSafe, time::Instant}; -use clap::Parser; -use conduit::trace; -use ruma::events::{ - relation::InReplyTo, - room::message::{Relation::Reply, RoomMessageEventContent}, +use clap::{CommandFactory, Parser}; +use conduit::{error, trace, Error}; +use futures_util::future::FutureExt; +use ruma::{ + events::{ + relation::InReplyTo, + room::message::{Relation::Reply, RoomMessageEventContent}, + }, + OwnedEventId, }; extern crate conduit_service as service; -use conduit::Result; +use conduit::{utils::string::common_prefix, Result}; pub(crate) use service::admin::{Command, Service}; use service::admin::{CommandOutput, CommandResult, HandlerResult}; @@ -20,7 +24,6 @@ use crate::{ }; pub(crate) const PAGE_SIZE: usize = 100; -#[cfg_attr(test, derive(Debug))] #[derive(Parser)] #[command(name = "admin", version = env!("CARGO_PKG_VERSION"))] pub(crate) enum AdminCommand { @@ -62,25 +65,46 @@ pub(crate) enum AdminCommand { } #[must_use] -pub fn handle(command: Command) -> HandlerResult { Box::pin(handle_command(command)) } +pub(crate) fn handle(command: Command) -> HandlerResult { Box::pin(handle_command(command)) } + +#[must_use] +pub(crate) fn complete(line: &str) -> String { complete_admin_command(AdminCommand::command(), line) } #[tracing::instrument(skip_all, name = "admin")] async fn handle_command(command: Command) -> CommandResult { - let Some(mut content) = process_admin_message(command.command).await else { - return Ok(None); - }; + AssertUnwindSafe(process_command(&command)) + .catch_unwind() + .await + .map_err(Error::from_panic) + .or_else(|error| handle_panic(&error, command)) +} - content.relates_to = command.reply_id.map(|event_id| Reply { +async fn process_command(command: &Command) -> CommandOutput { + process_admin_message(&command.command) + .await + .and_then(|content| reply(content, command.reply_id.clone())) +} + +fn handle_panic(error: &Error, command: Command) -> CommandResult { + let link = "Please submit a [bug report](https://github.com/girlbossceo/conduwuit/issues/new). 🥺"; + let msg = format!("Panic occurred while processing command:\n```\n{error:#?}\n```\n{link}"); + let content = RoomMessageEventContent::notice_markdown(msg); + error!("Panic while processing command: {error:?}"); + Ok(reply(content, command.reply_id)) +} + +fn reply(mut content: RoomMessageEventContent, reply_id: Option) -> Option { + content.relates_to = reply_id.map(|event_id| Reply { in_reply_to: InReplyTo { event_id, }, }); - Ok(Some(content)) + Some(content) } // Parse and process a message from the admin room -async fn process_admin_message(msg: String) -> CommandOutput { +async fn process_admin_message(msg: &str) -> CommandOutput { let mut lines = msg.lines().filter(|l| !l.trim().is_empty()); let command = lines.next().expect("each string has at least one line"); let body = lines.collect::>(); @@ -100,59 +124,11 @@ async fn process_admin_message(msg: String) -> CommandOutput { match result { Ok(reply) => Some(reply), Err(error) => Some(RoomMessageEventContent::notice_markdown(format!( - "Encountered an error while handling the command:\n```\n{error}\n```" + "Encountered an error while handling the command:\n```\n{error:#?}\n```" ))), } } -// Parse chat messages from the admin room into an AdminCommand object -fn parse_admin_command(command_line: &str) -> Result { - let mut argv = command_line.split_whitespace().collect::>(); - - // Remove any escapes that came with a server-side escape command - if !argv.is_empty() && argv[0].ends_with("admin") { - argv[0] = argv[0].trim_start_matches('\\'); - } - - // First indice has to be "admin" but for console convenience we add it here - let server_user = services().globals.server_user.as_str(); - if !argv.is_empty() && !argv[0].ends_with("admin") && !argv[0].starts_with(server_user) { - argv.insert(0, "admin"); - } - - // Replace `help command` with `command --help` - // Clap has a help subcommand, but it omits the long help description. - if argv.len() > 1 && argv[1] == "help" { - argv.remove(1); - argv.push("--help"); - } - - // Backwards compatibility with `register_appservice`-style commands - let command_with_dashes_argv1; - if argv.len() > 1 && argv[1].contains('_') { - command_with_dashes_argv1 = argv[1].replace('_', "-"); - argv[1] = &command_with_dashes_argv1; - } - - // Backwards compatibility with `register_appservice`-style commands - let command_with_dashes_argv2; - if argv.len() > 2 && argv[2].contains('_') { - command_with_dashes_argv2 = argv[2].replace('_', "-"); - argv[2] = &command_with_dashes_argv2; - } - - // if the user is using the `query` command (argv[1]), replace the database - // function/table calls with underscores to match the codebase - let command_with_dashes_argv3; - if argv.len() > 3 && argv[1].eq("query") { - command_with_dashes_argv3 = argv[3].replace('_', "-"); - argv[3] = &command_with_dashes_argv3; - } - - trace!(?command_line, ?argv, "parse"); - AdminCommand::try_parse_from(argv).map_err(|error| error.to_string()) -} - #[tracing::instrument(skip_all, name = "command")] async fn process_admin_command(command: AdminCommand, body: Vec<&str>) -> Result { let reply_message_content = match command { @@ -169,3 +145,97 @@ async fn process_admin_command(command: AdminCommand, body: Vec<&str>) -> Result Ok(reply_message_content) } + +// Parse chat messages from the admin room into an AdminCommand object +fn parse_admin_command(command_line: &str) -> Result { + let argv = parse_command_line(command_line); + AdminCommand::try_parse_from(argv).map_err(|error| error.to_string()) +} + +fn complete_admin_command(mut cmd: clap::Command, line: &str) -> String { + let argv = parse_command_line(line); + let mut ret = Vec::::with_capacity(argv.len().saturating_add(1)); + + 'token: for token in argv.into_iter().skip(1) { + let cmd_ = cmd.clone(); + let mut choice = Vec::new(); + + for sub in cmd_.get_subcommands() { + let name = sub.get_name(); + if *name == token { + // token already complete; recurse to subcommand + ret.push(token); + cmd.clone_from(sub); + continue 'token; + } else if name.starts_with(&token) { + // partial match; add to choices + choice.push(name); + } + } + + if choice.len() == 1 { + // One choice. Add extra space because it's complete + let choice = *choice.first().expect("only choice"); + ret.push(choice.to_owned()); + ret.push(String::new()); + } else if choice.is_empty() { + // Nothing found, return original string + ret.push(token); + } else { + // Find the common prefix + ret.push(common_prefix(&choice).into()); + } + + // Return from completion + return ret.join(" "); + } + + // Return from no completion. Needs a space though. + ret.push(String::new()); + ret.join(" ") +} + +// Parse chat messages from the admin room into an AdminCommand object +fn parse_command_line(command_line: &str) -> Vec { + let mut argv = command_line + .split_whitespace() + .map(str::to_owned) + .collect::>(); + + // Remove any escapes that came with a server-side escape command + if !argv.is_empty() && argv[0].ends_with("admin") { + argv[0] = argv[0].trim_start_matches('\\').into(); + } + + // First indice has to be "admin" but for console convenience we add it here + let server_user = services().globals.server_user.as_str(); + if !argv.is_empty() && !argv[0].ends_with("admin") && !argv[0].starts_with(server_user) { + argv.insert(0, "admin".to_owned()); + } + + // Replace `help command` with `command --help` + // Clap has a help subcommand, but it omits the long help description. + if argv.len() > 1 && argv[1] == "help" { + argv.remove(1); + argv.push("--help".to_owned()); + } + + // Backwards compatibility with `register_appservice`-style commands + if argv.len() > 1 && argv[1].contains('_') { + argv[1] = argv[1].replace('_', "-"); + } + + // Backwards compatibility with `register_appservice`-style commands + if argv.len() > 2 && argv[2].contains('_') { + argv[2] = argv[2].replace('_', "-"); + } + + // if the user is using the `query` command (argv[1]), replace the database + // function/table calls with underscores to match the codebase + if argv.len() > 3 && argv[1].eq("query") { + argv[3] = argv[3].replace('_', "-"); + } + + trace!(?command_line, ?argv, "parse"); + argv +} diff --git a/src/admin/mod.rs b/src/admin/mod.rs index f2034724..14856811 100644 --- a/src/admin/mod.rs +++ b/src/admin/mod.rs @@ -9,6 +9,7 @@ pub(crate) mod media; pub(crate) mod query; pub(crate) mod room; pub(crate) mod server; +mod tests; pub(crate) mod user; pub(crate) mod utils; @@ -17,7 +18,6 @@ extern crate conduit_core as conduit; extern crate conduit_service as service; pub(crate) use conduit::{mod_ctor, mod_dtor, Result}; -pub use handler::handle; pub(crate) use service::{services, user_is_local}; pub(crate) use crate::{ @@ -28,29 +28,29 @@ pub(crate) use crate::{ mod_ctor! {} mod_dtor! {} -#[cfg(test)] -mod test { - use clap::Parser; - - use crate::handler::AdminCommand; - - #[test] - fn get_help_short() { get_help_inner("-h"); } - - #[test] - fn get_help_long() { get_help_inner("--help"); } - - #[test] - fn get_help_subcommand() { get_help_inner("help"); } - - fn get_help_inner(input: &str) { - let error = AdminCommand::try_parse_from(["argv[0] doesn't matter", input]) - .unwrap_err() - .to_string(); - - // Search for a handful of keywords that suggest the help printed properly - assert!(error.contains("Usage:")); - assert!(error.contains("Commands:")); - assert!(error.contains("Options:")); - } +/// Install the admin command handler +pub async fn init() { + _ = services() + .admin + .complete + .write() + .expect("locked for writing") + .insert(handler::complete); + _ = services() + .admin + .handle + .write() + .await + .insert(handler::handle); +} + +/// Uninstall the admin command handler +pub async fn fini() { + _ = services().admin.handle.write().await.take(); + _ = services() + .admin + .complete + .write() + .expect("locked for writing") + .take(); } diff --git a/src/admin/query/globals.rs b/src/admin/query/globals.rs index 2e22d688..9bdd38fc 100644 --- a/src/admin/query/globals.rs +++ b/src/admin/query/globals.rs @@ -26,7 +26,7 @@ pub(super) async fn globals(subcommand: Globals) -> Result { let timer = tokio::time::Instant::now(); - let results = services().globals.db.last_check_for_updates_id(); + let results = services().updates.last_check_for_updates_id(); let query_time = timer.elapsed(); Ok(RoomMessageEventContent::notice_markdown(format!( diff --git a/src/admin/query/mod.rs b/src/admin/query/mod.rs index 946e6ec8..ea7036d0 100644 --- a/src/admin/query/mod.rs +++ b/src/admin/query/mod.rs @@ -2,6 +2,7 @@ mod account_data; mod appservice; mod globals; mod presence; +mod resolver; mod room_alias; mod room_state_cache; mod sending; @@ -12,12 +13,12 @@ use conduit::Result; use room_state_cache::room_state_cache; use ruma::{ events::{room::message::RoomMessageEventContent, RoomAccountDataEventType}, - RoomAliasId, RoomId, ServerName, UserId, + OwnedServerName, RoomAliasId, RoomId, ServerName, UserId, }; use self::{ - account_data::account_data, appservice::appservice, globals::globals, presence::presence, room_alias::room_alias, - sending::sending, users::users, + account_data::account_data, appservice::appservice, globals::globals, presence::presence, resolver::resolver, + room_alias::room_alias, sending::sending, users::users, }; #[cfg_attr(test, derive(Debug))] @@ -55,6 +56,10 @@ pub(super) enum QueryCommand { /// - users.rs iterators and getters #[command(subcommand)] Users(Users), + + /// - resolver service + #[command(subcommand)] + Resolver(Resolver), } #[cfg_attr(test, derive(Debug))] @@ -287,6 +292,21 @@ pub(super) enum Users { Iter, } +#[cfg_attr(test, derive(Debug))] +#[derive(Subcommand)] +/// Resolver service and caches +pub(super) enum Resolver { + /// Query the destinations cache + DestinationsCache { + server_name: Option, + }, + + /// Query the overrides cache + OverridesCache { + name: Option, + }, +} + /// Processes admin query commands pub(super) async fn process(command: QueryCommand, _body: Vec<&str>) -> Result { Ok(match command { @@ -298,5 +318,6 @@ pub(super) async fn process(command: QueryCommand, _body: Vec<&str>) -> Result globals(command).await?, QueryCommand::Sending(command) => sending(command).await?, QueryCommand::Users(command) => users(command).await?, + QueryCommand::Resolver(command) => resolver(command).await?, }) } diff --git a/src/admin/query/resolver.rs b/src/admin/query/resolver.rs new file mode 100644 index 00000000..2a2554b5 --- /dev/null +++ b/src/admin/query/resolver.rs @@ -0,0 +1,87 @@ +use std::fmt::Write; + +use conduit::{utils::time, Result}; +use ruma::{events::room::message::RoomMessageEventContent, OwnedServerName}; + +use super::Resolver; +use crate::services; + +/// All the getters and iterators in key_value/users.rs +pub(super) async fn resolver(subcommand: Resolver) -> Result { + match subcommand { + Resolver::DestinationsCache { + server_name, + } => destinations_cache(server_name).await, + Resolver::OverridesCache { + name, + } => overrides_cache(name).await, + } +} + +async fn destinations_cache(server_name: Option) -> Result { + use service::sending::CachedDest; + + let mut out = String::new(); + writeln!(out, "| Server Name | Destination | Hostname | Expires |")?; + writeln!(out, "| ----------- | ----------- | -------- | ------- |")?; + let row = |( + name, + &CachedDest { + ref dest, + ref host, + expire, + }, + )| { + let expire = time::format(expire, "%+"); + writeln!(out, "| {name} | {dest} | {host} | {expire} |").expect("wrote line"); + }; + + let map = services() + .globals + .resolver + .destinations + .read() + .expect("locked"); + + if let Some(server_name) = server_name.as_ref() { + map.get_key_value(server_name).map(row); + } else { + map.iter().for_each(row); + } + + Ok(RoomMessageEventContent::notice_markdown(out)) +} + +async fn overrides_cache(server_name: Option) -> Result { + use service::sending::CachedOverride; + + let mut out = String::new(); + writeln!(out, "| Server Name | IP | Port | Expires |")?; + writeln!(out, "| ----------- | --- | ----:| ------- |")?; + let row = |( + name, + &CachedOverride { + ref ips, + port, + expire, + }, + )| { + let expire = time::format(expire, "%+"); + writeln!(out, "| {name} | {ips:?} | {port} | {expire} |").expect("wrote line"); + }; + + let map = services() + .globals + .resolver + .overrides + .read() + .expect("locked"); + + if let Some(server_name) = server_name.as_ref() { + map.get_key_value(server_name).map(row); + } else { + map.iter().for_each(row); + } + + Ok(RoomMessageEventContent::notice_markdown(out)) +} diff --git a/src/admin/room/mod.rs b/src/admin/room/mod.rs index 72e72793..b4fa15bd 100644 --- a/src/admin/room/mod.rs +++ b/src/admin/room/mod.rs @@ -16,6 +16,14 @@ pub(super) enum RoomCommand { /// - List all rooms the server knows about List { page: Option, + + /// Excludes rooms that we have federation disabled with + #[arg(long)] + exclude_disabled: bool, + + /// Excludes rooms that we have banned + #[arg(long)] + exclude_banned: bool, }, #[command(subcommand)] @@ -179,6 +187,8 @@ pub(super) async fn process(command: RoomCommand, body: Vec<&str>) -> Result list(body, page).await?, + exclude_disabled, + exclude_banned, + } => list(body, page, exclude_disabled, exclude_banned).await?, }) } diff --git a/src/admin/room/room_commands.rs b/src/admin/room/room_commands.rs index f64ccf30..d47edce2 100644 --- a/src/admin/room/room_commands.rs +++ b/src/admin/room/room_commands.rs @@ -1,18 +1,46 @@ use std::fmt::Write; -use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomId}; +use ruma::events::room::message::RoomMessageEventContent; use crate::{escape_html, get_room_info, handler::PAGE_SIZE, services, Result}; -pub(super) async fn list(_body: Vec<&str>, page: Option) -> Result { +pub(super) async fn list( + _body: Vec<&str>, page: Option, exclude_disabled: bool, exclude_banned: bool, +) -> Result { // TODO: i know there's a way to do this with clap, but i can't seem to find it let page = page.unwrap_or(1); let mut rooms = services() .rooms .metadata .iter_ids() - .filter_map(Result::ok) - .map(|id: OwnedRoomId| get_room_info(&id)) + .filter_map(|room_id| { + room_id + .ok() + .filter(|room_id| { + if exclude_disabled + && services() + .rooms + .metadata + .is_disabled(room_id) + .unwrap_or(false) + { + return false; + } + + if exclude_banned + && services() + .rooms + .metadata + .is_banned(room_id) + .unwrap_or(false) + { + return false; + } + + true + }) + .map(|room_id| get_room_info(&room_id)) + }) .collect::>(); rooms.sort_by_key(|r| r.1); rooms.reverse(); diff --git a/src/admin/room/room_moderation_commands.rs b/src/admin/room/room_moderation_commands.rs index 39eb7c47..30c30c6e 100644 --- a/src/admin/room/room_moderation_commands.rs +++ b/src/admin/room/room_moderation_commands.rs @@ -1,7 +1,5 @@ use api::client::leave_room; -use ruma::{ - events::room::message::RoomMessageEventContent, OwnedRoomId, OwnedUserId, RoomAliasId, RoomId, RoomOrAliasId, -}; +use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomId, RoomAliasId, RoomId, RoomOrAliasId}; use tracing::{debug, error, info, warn}; use super::{super::Service, RoomModerationCommand}; @@ -124,9 +122,7 @@ async fn ban_room( .is_admin(local_user) .unwrap_or(true)) }) - }) - .collect::>() - { + }) { debug!( "Attempting leave for user {} in room {} (forced, ignoring all errors, evicting admins too)", &local_user, &room_id @@ -153,9 +149,7 @@ async fn ban_room( .is_admin(local_user) .unwrap_or(false)) }) - }) - .collect::>() - { + }) { debug!("Attempting leave for user {} in room {}", &local_user, &room_id); if let Err(e) = leave_room(&local_user, &room_id, None).await { error!( @@ -191,7 +185,10 @@ async fn ban_list_of_rooms(body: Vec<&str>, force: bool, disable_federation: boo )); } - let rooms_s = body.clone().drain(1..body.len() - 1).collect::>(); + let rooms_s = body + .clone() + .drain(1..body.len().saturating_sub(1)) + .collect::>(); let admin_room_alias = &services().globals.admin_alias; @@ -332,9 +329,7 @@ async fn ban_list_of_rooms(body: Vec<&str>, force: bool, disable_federation: boo .is_admin(local_user) .unwrap_or(true)) }) - }) - .collect::>() - { + }) { debug!( "Attempting leave for user {} in room {} (forced, ignoring all errors, evicting admins too)", &local_user, room_id @@ -361,9 +356,7 @@ async fn ban_list_of_rooms(body: Vec<&str>, force: bool, disable_federation: boo .is_admin(local_user) .unwrap_or(false)) }) - }) - .collect::>() - { + }) { debug!("Attempting leave for user {} in room {}", &local_user, &room_id); if let Err(e) = leave_room(&local_user, &room_id, None).await { error!( diff --git a/src/admin/server/commands.rs b/src/admin/server/commands.rs index 77b20377..e4503736 100644 --- a/src/admin/server/commands.rs +++ b/src/admin/server/commands.rs @@ -1,24 +1,17 @@ -use conduit::{warn, Error, Result}; +use conduit::{utils::time, warn, Err, Result}; use ruma::events::room::message::RoomMessageEventContent; use crate::services; pub(super) async fn uptime(_body: Vec<&str>) -> Result { - let seconds = services() + let elapsed = services() .server .started .elapsed() - .expect("standard duration") - .as_secs(); - let result = format!( - "up {} days, {} hours, {} minutes, {} seconds.", - seconds / 86400, - (seconds % 86400) / 60 / 60, - (seconds % 3600) / 60, - seconds % 60, - ); + .expect("standard duration"); - Ok(RoomMessageEventContent::notice_plain(result)) + let result = time::pretty(elapsed); + Ok(RoomMessageEventContent::notice_plain(format!("{result}."))) } pub(super) async fn show_config(_body: Vec<&str>) -> Result { @@ -27,28 +20,17 @@ pub(super) async fn show_config(_body: Vec<&str>) -> Result) -> Result { - let response0 = services().memory_usage().await; - let response1 = services().globals.db.memory_usage(); - let response2 = conduit::alloc::memory_usage(); + let services_usage = services().memory_usage().await?; + let database_usage = services().db.db.memory_usage()?; + let allocator_usage = conduit::alloc::memory_usage().map_or(String::new(), |s| format!("\nAllocator:\n{s}")); Ok(RoomMessageEventContent::text_plain(format!( - "Services:\n{response0}\n\nDatabase:\n{response1}\n{}", - if !response2.is_empty() { - format!("Allocator:\n {response2}") - } else { - String::new() - } + "Services:\n{services_usage}\nDatabase:\n{database_usage}{allocator_usage}", ))) } -pub(super) async fn clear_database_caches(_body: Vec<&str>, amount: u32) -> Result { - services().globals.db.clear_caches(amount); - - Ok(RoomMessageEventContent::text_plain("Done.")) -} - -pub(super) async fn clear_service_caches(_body: Vec<&str>, amount: u32) -> Result { - services().clear_caches(amount).await; +pub(super) async fn clear_caches(_body: Vec<&str>) -> Result { + services().clear_cache().await; Ok(RoomMessageEventContent::text_plain("Done.")) } @@ -106,11 +88,10 @@ pub(super) async fn restart(_body: Vec<&str>, force: bool) -> Result) -> Result uptime(body).await?, ServerCommand::ShowConfig => show_config(body).await?, ServerCommand::MemoryUsage => memory_usage(body).await?, - ServerCommand::ClearDatabaseCaches { - amount, - } => clear_database_caches(body, amount).await?, - ServerCommand::ClearServiceCaches { - amount, - } => clear_service_caches(body, amount).await?, + ServerCommand::ClearCaches => clear_caches(body).await?, ServerCommand::ListBackups => list_backups(body).await?, ServerCommand::BackupDatabase => backup_database(body).await?, ServerCommand::ListDatabaseFiles => list_database_files(body).await?, diff --git a/src/admin/tests.rs b/src/admin/tests.rs new file mode 100644 index 00000000..69ccd896 --- /dev/null +++ b/src/admin/tests.rs @@ -0,0 +1,26 @@ +#![cfg(test)] + +#[test] +fn get_help_short() { get_help_inner("-h"); } + +#[test] +fn get_help_long() { get_help_inner("--help"); } + +#[test] +fn get_help_subcommand() { get_help_inner("help"); } + +fn get_help_inner(input: &str) { + use clap::Parser; + + use crate::handler::AdminCommand; + + let Err(error) = AdminCommand::try_parse_from(["argv[0] doesn't matter", input]) else { + panic!("no error!"); + }; + + let error = error.to_string(); + // Search for a handful of keywords that suggest the help printed properly + assert!(error.contains("Usage:")); + assert!(error.contains("Commands:")); + assert!(error.contains("Options:")); +} diff --git a/src/admin/user/commands.rs b/src/admin/user/commands.rs index 6dc60713..9e4b348b 100644 --- a/src/admin/user/commands.rs +++ b/src/admin/user/commands.rs @@ -8,7 +8,7 @@ use ruma::{ tag::{TagEvent, TagEventContent, TagInfo}, RoomAccountDataEventType, }, - OwnedRoomId, OwnedUserId, RoomId, + OwnedRoomId, OwnedRoomOrAliasId, OwnedUserId, RoomId, }; use tracing::{error, info, warn}; @@ -23,7 +23,7 @@ pub(super) async fn list(_body: Vec<&str>) -> Result { match services().users.list_local_users() { Ok(users) => { let mut plain_msg = format!("Found {} local user account(s):\n```\n", users.len()); - plain_msg += &users.join("\n"); + plain_msg += users.join("\n").as_str(); plain_msg += "\n```"; Ok(RoomMessageEventContent::notice_markdown(plain_msg)) @@ -95,7 +95,7 @@ pub(super) async fn create( if let Some(room_id_server_name) = room.server_name() { match join_room_by_id_helper( - Some(&user_id), + &user_id, room, Some("Automatically joining this room upon registration".to_owned()), &[room_id_server_name.to_owned(), services().globals.server_name().to_owned()], @@ -195,7 +195,10 @@ pub(super) async fn deactivate_all( )); } - let usernames = body.clone().drain(1..body.len() - 1).collect::>(); + let usernames = body + .clone() + .drain(1..body.len().saturating_sub(1)) + .collect::>(); let mut user_ids: Vec = Vec::with_capacity(usernames.len()); let mut admins = Vec::new(); @@ -331,6 +334,35 @@ pub(super) async fn list_joined_rooms(_body: Vec<&str>, user_id: String) -> Resu Ok(RoomMessageEventContent::text_html(output_plain, output_html)) } +pub(super) async fn force_join_room( + _body: Vec<&str>, user_id: String, room_id: OwnedRoomOrAliasId, +) -> Result { + let user_id = parse_local_user_id(&user_id)?; + let room_id = services().rooms.alias.resolve(&room_id).await?; + + assert!(service::user_is_local(&user_id), "Parsed user_id must be a local user"); + join_room_by_id_helper(&user_id, &room_id, None, &[], None).await?; + + Ok(RoomMessageEventContent::notice_markdown(format!( + "{user_id} has been joined to {room_id}.", + ))) +} + +pub(super) async fn make_user_admin(_body: Vec<&str>, user_id: String) -> Result { + let user_id = parse_local_user_id(&user_id)?; + let displayname = services() + .users + .displayname(&user_id)? + .unwrap_or_else(|| user_id.to_string()); + + assert!(service::user_is_local(&user_id), "Parsed user_id must be a local user"); + service::admin::make_user_admin(&user_id, displayname).await?; + + Ok(RoomMessageEventContent::notice_markdown(format!( + "{user_id} has been granted admin privileges.", + ))) +} + pub(super) async fn put_room_tag( _body: Vec<&str>, user_id: String, room_id: Box, tag: String, ) -> Result { diff --git a/src/admin/user/mod.rs b/src/admin/user/mod.rs index 31bf57d6..cdb5fa5e 100644 --- a/src/admin/user/mod.rs +++ b/src/admin/user/mod.rs @@ -2,7 +2,7 @@ mod commands; use clap::Subcommand; use conduit::Result; -use ruma::{events::room::message::RoomMessageEventContent, RoomId}; +use ruma::{events::room::message::RoomMessageEventContent, OwnedRoomOrAliasId, RoomId}; use self::commands::*; @@ -49,7 +49,7 @@ pub(super) enum UserCommand { /// Markdown code block below the command. DeactivateAll { #[arg(short, long)] - /// Remove users from their joined rooms + /// Does not leave any rooms the user is in on deactivation no_leave_rooms: bool, #[arg(short, long)] /// Also deactivate admin accounts and will assume leave all rooms too @@ -65,6 +65,17 @@ pub(super) enum UserCommand { user_id: String, }, + /// - Manually join a local user to a room. + ForceJoinRoom { + user_id: String, + room_id: OwnedRoomOrAliasId, + }, + + /// - Grant server-admin privileges to a user. + MakeUserAdmin { + user_id: String, + }, + /// - Puts a room tag for the specified user and room ID. /// /// This is primarily useful if you'd like to set your admin room @@ -113,6 +124,13 @@ pub(super) async fn process(command: UserCommand, body: Vec<&str>) -> Result list_joined_rooms(body, user_id).await?, + UserCommand::ForceJoinRoom { + user_id, + room_id, + } => force_join_room(body, user_id, room_id).await?, + UserCommand::MakeUserAdmin { + user_id, + } => make_user_admin(body, user_id).await?, UserCommand::PutRoomTag { user_id, room_id, diff --git a/src/admin/utils.rs b/src/admin/utils.rs index 91982fe3..fda42e6e 100644 --- a/src/admin/utils.rs +++ b/src/admin/utils.rs @@ -1,4 +1,4 @@ -use conduit_core::Error; +use conduit_core::{err, Err}; use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId}; use service::user_is_local; @@ -33,7 +33,7 @@ pub(crate) fn get_room_info(id: &RoomId) -> (OwnedRoomId, u64, String) { /// Parses user ID pub(crate) fn parse_user_id(user_id: &str) -> Result { UserId::parse_with_server_name(user_id.to_lowercase(), services().globals.server_name()) - .map_err(|e| Error::Err(format!("The supplied username is not a valid username: {e}"))) + .map_err(|e| err!("The supplied username is not a valid username: {e}")) } /// Parses user ID as our local user @@ -41,7 +41,7 @@ pub(crate) fn parse_local_user_id(user_id: &str) -> Result { let user_id = parse_user_id(user_id)?; if !user_is_local(&user_id) { - return Err(Error::Err(String::from("User does not belong to our server."))); + return Err!("User {user_id:?} does not belong to our server."); } Ok(user_id) @@ -52,11 +52,11 @@ pub(crate) fn parse_active_local_user_id(user_id: &str) -> Result { let user_id = parse_local_user_id(user_id)?; if !services().users.exists(&user_id)? { - return Err(Error::Err(String::from("User does not exist on this server."))); + return Err!("User {user_id:?} does not exist on this server."); } if services().users.is_deactivated(&user_id)? { - return Err(Error::Err(String::from("User is deactivated."))); + return Err!("User {user_id:?} is deactivated."); } Ok(user_id) diff --git a/src/api/Cargo.toml b/src/api/Cargo.toml index a27924c6..356adc1f 100644 --- a/src/api/Cargo.toml +++ b/src/api/Cargo.toml @@ -41,9 +41,11 @@ bytes.workspace = true conduit-core.workspace = true conduit-database.workspace = true conduit-service.workspace = true +const-str.workspace = true futures-util.workspace = true hmac.workspace = true http.workspace = true +http-body-util.workspace = true hyper.workspace = true image.workspace = true ipaddress.workspace = true @@ -56,7 +58,6 @@ serde_html_form.workspace = true serde_json.workspace = true serde.workspace = true sha-1.workspace = true -thiserror.workspace = true tokio.workspace = true tracing.workspace = true webpage.workspace = true diff --git a/src/api/client/account.rs b/src/api/client/account.rs index 71324c1b..d34211bf 100644 --- a/src/api/client/account.rs +++ b/src/api/client/account.rs @@ -309,7 +309,7 @@ pub(crate) async fn register_route( // log in conduit admin channel if a guest registered if body.appservice_info.is_none() && is_guest && services().globals.log_guest_registrations() { - info!("New guest user \"{user_id}\" registered on this server from IP."); + info!("New guest user \"{user_id}\" registered on this server."); if let Some(device_display_name) = &body.initial_device_display_name { if body @@ -376,7 +376,7 @@ pub(crate) async fn register_route( if let Some(room_id_server_name) = room.server_name() { if let Err(e) = join_room_by_id_helper( - Some(&user_id), + &user_id, room, Some("Automatically joining this room upon registration".to_owned()), &[room_id_server_name.to_owned(), services().globals.server_name().to_owned()], @@ -423,7 +423,12 @@ pub(crate) async fn register_route( pub(crate) async fn change_password_route( InsecureClientIp(client): InsecureClientIp, body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + // Authentication for this endpoint was made optional, but we need + // authentication currently + let sender_user = body + .sender_user + .as_ref() + .ok_or_else(|| Error::BadRequest(ErrorKind::MissingToken, "Missing access token."))?; let sender_device = body.sender_device.as_ref().expect("user is authenticated"); let mut uiaainfo = UiaaInfo { @@ -512,7 +517,12 @@ pub(crate) async fn whoami_route(body: Ruma) -> Result, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + // Authentication for this endpoint was made optional, but we need + // authentication currently + let sender_user = body + .sender_user + .as_ref() + .ok_or_else(|| Error::BadRequest(ErrorKind::MissingToken, "Missing access token."))?; let sender_device = body.sender_device.as_ref().expect("user is authenticated"); let mut uiaainfo = UiaaInfo { diff --git a/src/api/client/directory.rs b/src/api/client/directory.rs index 512dface..8e12c034 100644 --- a/src/api/client/directory.rs +++ b/src/api/client/directory.rs @@ -1,4 +1,5 @@ use axum_client_ip::InsecureClientIp; +use conduit::{err, info, warn, Error, Result}; use ruma::{ api::{ client::{ @@ -10,14 +11,16 @@ use ruma::{ }, directory::{Filter, PublicRoomJoinRule, PublicRoomsChunk, RoomNetwork}, events::{ - room::join_rules::{JoinRule, RoomJoinRulesEventContent}, + room::{ + join_rules::{JoinRule, RoomJoinRulesEventContent}, + power_levels::{RoomPowerLevels, RoomPowerLevelsEventContent}, + }, StateEventType, }, - uint, ServerName, UInt, + uint, RoomId, ServerName, UInt, UserId, }; -use tracing::{error, info, warn}; -use crate::{service::server_is_ours, services, Error, Result, Ruma}; +use crate::{service::server_is_ours, services, Ruma}; /// # `POST /_matrix/client/v3/publicRooms` /// @@ -103,8 +106,6 @@ pub(crate) async fn get_public_rooms_route( /// # `PUT /_matrix/client/r0/directory/list/room/{roomId}` /// /// Sets the visibility of a given room in the room directory. -/// -/// - TODO: Access control checks #[tracing::instrument(skip_all, fields(%client), name = "room_directory")] pub(crate) async fn set_room_visibility_route( InsecureClientIp(client): InsecureClientIp, body: Ruma, @@ -116,6 +117,13 @@ pub(crate) async fn set_room_visibility_route( return Err(Error::BadRequest(ErrorKind::NotFound, "Room not found")); } + if !user_can_publish_room(sender_user, &body.room_id)? { + return Err(Error::BadRequest( + ErrorKind::forbidden(), + "User is not allowed to publish this room", + )); + } + match &body.visibility { room::Visibility::Public => { if services().globals.config.lockdown_public_room_directory && !services().users.is_admin(sender_user)? { @@ -268,8 +276,7 @@ pub(crate) async fn get_public_rooms_filtered_helper( _ => None, }) .map_err(|e| { - error!("Invalid room join rule event in database: {}", e); - Error::BadDatabase("Invalid room join rule event in database.") + err!(Database(error!("Invalid room join rule event in database: {e}"))) }) }) .transpose()? @@ -351,3 +358,32 @@ pub(crate) async fn get_public_rooms_filtered_helper( total_room_count_estimate: Some(total_room_count_estimate), }) } + +/// Check whether the user can publish to the room directory via power levels of +/// room history visibility event or room creator +fn user_can_publish_room(user_id: &UserId, room_id: &RoomId) -> Result { + if let Some(event) = + services() + .rooms + .state_accessor + .room_state_get(room_id, &StateEventType::RoomPowerLevels, "")? + { + serde_json::from_str(event.content.get()) + .map_err(|_| Error::bad_database("Invalid event content for m.room.power_levels")) + .map(|content: RoomPowerLevelsEventContent| { + RoomPowerLevels::from(content).user_can_send_state(user_id, StateEventType::RoomHistoryVisibility) + }) + } else if let Some(event) = + services() + .rooms + .state_accessor + .room_state_get(room_id, &StateEventType::RoomCreate, "")? + { + Ok(event.sender == user_id) + } else { + return Err(Error::BadRequest( + ErrorKind::forbidden(), + "User is not allowed to publish this room", + )); + } +} diff --git a/src/api/client/keys.rs b/src/api/client/keys.rs index 0cfb0138..7bb02a60 100644 --- a/src/api/client/keys.rs +++ b/src/api/client/keys.rs @@ -1,9 +1,9 @@ use std::{ - cmp, collections::{hash_map, BTreeMap, HashMap, HashSet}, - time::{Duration, Instant}, + time::Instant, }; +use conduit::{utils, utils::math::continue_exponential_backoff_secs, Error, Result}; use futures_util::{stream::FuturesUnordered, StreamExt}; use ruma::{ api::{ @@ -18,15 +18,11 @@ use ruma::{ DeviceKeyAlgorithm, OwnedDeviceId, OwnedUserId, UserId, }; use serde_json::json; +use service::user_is_local; use tracing::debug; use super::SESSION_ID_LENGTH; -use crate::{ - service::user_is_local, - services, - utils::{self}, - Error, Result, Ruma, -}; +use crate::{services, Ruma}; /// # `POST /_matrix/client/r0/keys/upload` /// @@ -334,7 +330,7 @@ pub(crate) async fn get_keys_helper bool + Send>( .globals .bad_query_ratelimiter .write() - .await + .expect("locked") .entry(id) { hash_map::Entry::Vacant(e) => { @@ -353,15 +349,14 @@ pub(crate) async fn get_keys_helper bool + Send>( .globals .bad_query_ratelimiter .read() - .await + .expect("locked") .get(server) { // Exponential backoff - const MAX_DURATION: Duration = Duration::from_secs(60 * 60 * 24); - let min_elapsed_duration = cmp::min(MAX_DURATION, Duration::from_secs(5 * 60) * (*tries) * (*tries)); - - if time.elapsed() < min_elapsed_duration { - debug!("Backing off query from {:?}", server); + const MIN: u64 = 5 * 60; + const MAX: u64 = 60 * 60 * 24; + if continue_exponential_backoff_secs(MIN, MAX, time.elapsed(), *tries) { + debug!("Backing off query from {server:?}"); return (server, Err(Error::BadServerResponse("bad query, still backing off"))); } } diff --git a/src/api/client/media.rs b/src/api/client/media.rs index cfb10c3d..39640b23 100644 --- a/src/api/client/media.rs +++ b/src/api/client/media.rs @@ -2,6 +2,8 @@ use std::{io::Cursor, sync::Arc, time::Duration}; +use axum_client_ip::InsecureClientIp; +use conduit::{debug, error, utils::math::ruma_from_usize, warn}; use image::io::Reader as ImgReader; use ipaddress::IPAddress; use reqwest::Url; @@ -12,7 +14,6 @@ use ruma::api::client::{ get_media_preview, }, }; -use tracing::{debug, error, warn}; use webpage::HTML; use crate::{ @@ -44,7 +45,7 @@ pub(crate) async fn get_media_config_route( _body: Ruma, ) -> Result { Ok(get_media_config::v3::Response { - upload_size: services().globals.max_request_size().into(), + upload_size: ruma_from_usize(services().globals.config.max_request_size), }) } @@ -64,18 +65,22 @@ pub(crate) async fn get_media_config_v1_route( /// # `GET /_matrix/media/v3/preview_url` /// /// Returns URL preview. +#[tracing::instrument(skip_all, fields(%client), name = "url_preview")] pub(crate) async fn get_media_preview_route( - body: Ruma, + InsecureClientIp(client): InsecureClientIp, body: Ruma, ) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let url = &body.url; if !url_preview_allowed(url) { + warn!(%sender_user, "URL is not allowed to be previewed: {url}"); return Err(Error::BadRequest(ErrorKind::forbidden(), "URL is not allowed to be previewed")); } match get_url_preview(url).await { Ok(preview) => { let res = serde_json::value::to_raw_value(&preview).map_err(|e| { - error!("Failed to convert UrlPreviewData into a serde json value: {}", e); + error!(%sender_user, "Failed to convert UrlPreviewData into a serde json value: {e}"); Error::BadRequest( ErrorKind::LimitExceeded { retry_after: Some(RetryAfter::Delay(Duration::from_secs(5))), @@ -87,7 +92,7 @@ pub(crate) async fn get_media_preview_route( Ok(get_media_preview::v3::Response::from_raw_value(res)) }, Err(e) => { - warn!("Failed to generate a URL preview: {e}"); + warn!(%sender_user, "Failed to generate a URL preview: {e}"); // there doesn't seem to be an agreed-upon error code in the spec. // the only response codes in the preview_url spec page are 200 and 429. @@ -108,10 +113,13 @@ pub(crate) async fn get_media_preview_route( /// See /// /// Returns URL preview. +#[tracing::instrument(skip_all, fields(%client), name = "url_preview")] pub(crate) async fn get_media_preview_v1_route( - body: Ruma, + InsecureClientIp(client): InsecureClientIp, body: Ruma, ) -> Result> { - get_media_preview_route(body).await.map(RumaResponse) + get_media_preview_route(InsecureClientIp(client), body) + .await + .map(RumaResponse) } /// # `POST /_matrix/media/v3/upload` @@ -120,8 +128,9 @@ pub(crate) async fn get_media_preview_v1_route( /// /// - Some metadata will be saved in the database /// - Media will be saved in the media/ directory +#[tracing::instrument(skip_all, fields(%client), name = "media_upload")] pub(crate) async fn create_content_route( - body: Ruma, + InsecureClientIp(client): InsecureClientIp, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -167,10 +176,13 @@ pub(crate) async fn create_content_route( /// /// - Some metadata will be saved in the database /// - Media will be saved in the media/ directory +#[tracing::instrument(skip_all, fields(%client), name = "media_upload")] pub(crate) async fn create_content_v1_route( - body: Ruma, + InsecureClientIp(client): InsecureClientIp, body: Ruma, ) -> Result> { - create_content_route(body).await.map(RumaResponse) + create_content_route(InsecureClientIp(client), body) + .await + .map(RumaResponse) } /// # `GET /_matrix/media/v3/download/{serverName}/{mediaId}` @@ -181,16 +193,20 @@ pub(crate) async fn create_content_v1_route( /// - Only redirects if `allow_redirect` is true /// - Uses client-provided `timeout_ms` if available, else defaults to 20 /// seconds -pub(crate) async fn get_content_route(body: Ruma) -> Result { +#[tracing::instrument(skip_all, fields(%client), name = "media_get")] +pub(crate) async fn get_content_route( + InsecureClientIp(client): InsecureClientIp, body: Ruma, +) -> Result { let mxc = format!("mxc://{}/{}", body.server_name, body.media_id); if let Some(FileMeta { + content, content_type, - file, content_disposition, }) = services().media.get(&mxc).await? { let content_disposition = Some(make_content_disposition(&content_type, content_disposition, None)); + let file = content.expect("content"); Ok(get_content::v3::Response { file, @@ -243,10 +259,13 @@ pub(crate) async fn get_content_route(body: Ruma) -> R /// - Only redirects if `allow_redirect` is true /// - Uses client-provided `timeout_ms` if available, else defaults to 20 /// seconds +#[tracing::instrument(skip_all, fields(%client), name = "media_get")] pub(crate) async fn get_content_v1_route( - body: Ruma, + InsecureClientIp(client): InsecureClientIp, body: Ruma, ) -> Result> { - get_content_route(body).await.map(RumaResponse) + get_content_route(InsecureClientIp(client), body) + .await + .map(RumaResponse) } /// # `GET /_matrix/media/v3/download/{serverName}/{mediaId}/{fileName}` @@ -257,14 +276,15 @@ pub(crate) async fn get_content_v1_route( /// - Only redirects if `allow_redirect` is true /// - Uses client-provided `timeout_ms` if available, else defaults to 20 /// seconds +#[tracing::instrument(skip_all, fields(%client), name = "media_get")] pub(crate) async fn get_content_as_filename_route( - body: Ruma, + InsecureClientIp(client): InsecureClientIp, body: Ruma, ) -> Result { let mxc = format!("mxc://{}/{}", body.server_name, body.media_id); if let Some(FileMeta { + content, content_type, - file, content_disposition, }) = services().media.get(&mxc).await? { @@ -274,6 +294,7 @@ pub(crate) async fn get_content_as_filename_route( Some(body.filename.clone()), )); + let file = content.expect("content"); Ok(get_content_as_filename::v3::Response { file, content_type, @@ -328,10 +349,13 @@ pub(crate) async fn get_content_as_filename_route( /// - Only redirects if `allow_redirect` is true /// - Uses client-provided `timeout_ms` if available, else defaults to 20 /// seconds +#[tracing::instrument(skip_all, fields(%client), name = "media_get")] pub(crate) async fn get_content_as_filename_v1_route( - body: Ruma, + InsecureClientIp(client): InsecureClientIp, body: Ruma, ) -> Result> { - get_content_as_filename_route(body).await.map(RumaResponse) + get_content_as_filename_route(InsecureClientIp(client), body) + .await + .map(RumaResponse) } /// # `GET /_matrix/media/v3/thumbnail/{serverName}/{mediaId}` @@ -342,14 +366,15 @@ pub(crate) async fn get_content_as_filename_v1_route( /// - Only redirects if `allow_redirect` is true /// - Uses client-provided `timeout_ms` if available, else defaults to 20 /// seconds +#[tracing::instrument(skip_all, fields(%client), name = "media_thumbnail_get")] pub(crate) async fn get_content_thumbnail_route( - body: Ruma, + InsecureClientIp(client): InsecureClientIp, body: Ruma, ) -> Result { let mxc = format!("mxc://{}/{}", body.server_name, body.media_id); if let Some(FileMeta { + content, content_type, - file, content_disposition, }) = services() .media @@ -365,6 +390,7 @@ pub(crate) async fn get_content_thumbnail_route( .await? { let content_disposition = Some(make_content_disposition(&content_type, content_disposition, None)); + let file = content.expect("content"); Ok(get_content_thumbnail::v3::Response { file, @@ -453,10 +479,13 @@ pub(crate) async fn get_content_thumbnail_route( /// - Only redirects if `allow_redirect` is true /// - Uses client-provided `timeout_ms` if available, else defaults to 20 /// seconds +#[tracing::instrument(skip_all, fields(%client), name = "media_thumbnail_get")] pub(crate) async fn get_content_thumbnail_v1_route( - body: Ruma, + InsecureClientIp(client): InsecureClientIp, body: Ruma, ) -> Result> { - get_content_thumbnail_route(body).await.map(RumaResponse) + get_content_thumbnail_route(InsecureClientIp(client), body) + .await + .map(RumaResponse) } async fn get_remote_content( diff --git a/src/api/client/membership.rs b/src/api/client/membership.rs index 8529a9df..3adee631 100644 --- a/src/api/client/membership.rs +++ b/src/api/client/membership.rs @@ -1,13 +1,15 @@ use std::{ - cmp, collections::{hash_map::Entry, BTreeMap, HashMap, HashSet}, net::IpAddr, sync::Arc, - time::{Duration, Instant}, + time::Instant, }; use axum_client_ip::InsecureClientIp; -use conduit::utils::mutex_map; +use conduit::{ + debug, debug_warn, error, info, trace, utils, utils::math::continue_exponential_backoff_secs, warn, Error, + PduEvent, Result, +}; use ruma::{ api::{ client::{ @@ -34,15 +36,16 @@ use ruma::{ }; use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; use tokio::sync::RwLock; -use tracing::{debug, error, info, trace, warn}; use crate::{ client::{update_avatar_url, update_displayname}, service::{ pdu::{gen_event_id_canonical_json, PduBuilder}, + rooms::state::RoomMutexGuard, + sending::convert_to_outgoing_federation_event, server_is_ours, user_is_local, }, - services, utils, Error, PduEvent, Result, Ruma, + services, Ruma, }; /// Checks if the room is banned in any way possible and the sender user is not @@ -199,7 +202,7 @@ pub(crate) async fn join_room_by_id_route( } join_room_by_id_helper( - body.sender_user.as_deref(), + sender_user, &body.room_id, body.reason.clone(), &servers, @@ -298,7 +301,7 @@ pub(crate) async fn join_room_by_id_or_alias_route( }; let join_room_response = join_room_by_id_helper( - Some(sender_user), + sender_user, &room_id, body.reason.clone(), &servers, @@ -363,6 +366,8 @@ pub(crate) async fn invite_user_route( pub(crate) async fn kick_user_route(body: Ruma) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let state_lock = services().rooms.state.mutex.lock(&body.room_id).await; + let mut event: RoomMemberEventContent = serde_json::from_str( services() .rooms @@ -380,12 +385,6 @@ pub(crate) async fn kick_user_route(body: Ruma) -> Resul event.membership = MembershipState::Leave; event.reason.clone_from(&body.reason); - let state_lock = services() - .globals - .roomid_mutex_state - .lock(&body.room_id) - .await; - services() .rooms .timeline @@ -414,6 +413,8 @@ pub(crate) async fn kick_user_route(body: Ruma) -> Resul pub(crate) async fn ban_user_route(body: Ruma) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let state_lock = services().rooms.state.mutex.lock(&body.room_id).await; + let event = services() .rooms .state_accessor @@ -444,12 +445,6 @@ pub(crate) async fn ban_user_route(body: Ruma) -> Result< }, )?; - let state_lock = services() - .globals - .roomid_mutex_state - .lock(&body.room_id) - .await; - services() .rooms .timeline @@ -478,6 +473,8 @@ pub(crate) async fn ban_user_route(body: Ruma) -> Result< pub(crate) async fn unban_user_route(body: Ruma) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let state_lock = services().rooms.state.mutex.lock(&body.room_id).await; + let mut event: RoomMemberEventContent = serde_json::from_str( services() .rooms @@ -493,12 +490,6 @@ pub(crate) async fn unban_user_route(body: Ruma) -> Res event.reason.clone_from(&body.reason); event.join_authorized_via_users_server = None; - let state_lock = services() - .globals - .roomid_mutex_state - .lock(&body.room_id) - .await; - services() .rooms .timeline @@ -650,35 +641,36 @@ pub(crate) async fn joined_members_route( } pub async fn join_room_by_id_helper( - sender_user: Option<&UserId>, room_id: &RoomId, reason: Option, servers: &[OwnedServerName], + sender_user: &UserId, room_id: &RoomId, reason: Option, servers: &[OwnedServerName], third_party_signed: Option<&ThirdPartySigned>, ) -> Result { - let sender_user = sender_user.expect("user is authenticated"); + let state_lock = services().rooms.state.mutex.lock(room_id).await; if matches!(services().rooms.state_cache.is_joined(sender_user, room_id), Ok(true)) { - info!("{sender_user} is already joined in {room_id}"); + debug_warn!("{sender_user} is already joined in {room_id}"); return Ok(join_room_by_id::v3::Response { room_id: room_id.into(), }); } - let state_lock = services().globals.roomid_mutex_state.lock(room_id).await; - - // Ask a remote server if we are not participating in this room - if !services() + if services() .rooms .state_cache .server_in_room(services().globals.server_name(), room_id)? + || servers.is_empty() + || (servers.len() == 1 && server_is_ours(&servers[0])) { - join_room_by_id_helper_remote(sender_user, room_id, reason, servers, third_party_signed, state_lock).await - } else { join_room_by_id_helper_local(sender_user, room_id, reason, servers, third_party_signed, state_lock).await + } else { + // Ask a remote server if we are not participating in this room + join_room_by_id_helper_remote(sender_user, room_id, reason, servers, third_party_signed, state_lock).await } } +#[tracing::instrument(skip_all, fields(%sender_user, %room_id), name = "join_remote")] async fn join_room_by_id_helper_remote( sender_user: &UserId, room_id: &RoomId, reason: Option, servers: &[OwnedServerName], - _third_party_signed: Option<&ThirdPartySigned>, state_lock: mutex_map::Guard<()>, + _third_party_signed: Option<&ThirdPartySigned>, state_lock: RoomMutexGuard, ) -> Result { info!("Joining {room_id} over federation."); @@ -779,7 +771,7 @@ async fn join_room_by_id_helper_remote( federation::membership::create_join_event::v2::Request { room_id: room_id.to_owned(), event_id: event_id.to_owned(), - pdu: PduEvent::convert_to_outgoing_federation_event(join_event.clone()), + pdu: convert_to_outgoing_federation_event(join_event.clone()), omit_members: false, }, ) @@ -788,14 +780,9 @@ async fn join_room_by_id_helper_remote( info!("send_join finished"); if join_authorized_via_users_server.is_some() { + use RoomVersionId::*; match &room_version_id { - RoomVersionId::V1 - | RoomVersionId::V2 - | RoomVersionId::V3 - | RoomVersionId::V4 - | RoomVersionId::V5 - | RoomVersionId::V6 - | RoomVersionId::V7 => { + V1 | V2 | V3 | V4 | V5 | V6 | V7 => { warn!( "Found `join_authorised_via_users_server` but room {} is version {}. Ignoring.", room_id, &room_version_id @@ -803,7 +790,7 @@ async fn join_room_by_id_helper_remote( }, // only room versions 8 and above using `join_authorized_via_users_server` (restricted joins) need to // validate and send signatures - RoomVersionId::V8 | RoomVersionId::V9 | RoomVersionId::V10 | RoomVersionId::V11 => { + V8 | V9 | V10 | V11 => { if let Some(signed_raw) = &send_join_response.room_state.event { info!( "There is a signed event. This room is probably using restricted joins. Adding signature to \ @@ -1011,11 +998,12 @@ async fn join_room_by_id_helper_remote( Ok(join_room_by_id::v3::Response::new(room_id.to_owned())) } +#[tracing::instrument(skip_all, fields(%sender_user, %room_id), name = "join_local")] async fn join_room_by_id_helper_local( sender_user: &UserId, room_id: &RoomId, reason: Option, servers: &[OwnedServerName], - _third_party_signed: Option<&ThirdPartySigned>, state_lock: mutex_map::Guard<()>, + _third_party_signed: Option<&ThirdPartySigned>, state_lock: RoomMutexGuard, ) -> Result { - info!("We can join locally"); + debug!("We can join locally"); let join_rules_event = services() @@ -1115,7 +1103,7 @@ async fn join_room_by_id_helper_local( .iter() .any(|server_name| !server_is_ours(server_name)) { - info!("We couldn't do the join locally, maybe federation can help to satisfy the restricted join requirements"); + warn!("We couldn't do the join locally, maybe federation can help to satisfy the restricted join requirements"); let (make_join_response, remote_server) = make_join_request(sender_user, room_id, servers).await?; let room_version_id = match make_join_response.room_version { @@ -1207,7 +1195,7 @@ async fn join_room_by_id_helper_local( federation::membership::create_join_event::v2::Request { room_id: room_id.to_owned(), event_id: event_id.to_owned(), - pdu: PduEvent::convert_to_outgoing_federation_event(join_event.clone()), + pdu: convert_to_outgoing_federation_event(join_event.clone()), omit_members: false, }, ) @@ -1280,16 +1268,12 @@ async fn make_join_request( make_join_counter = make_join_counter.saturating_add(1); if let Err(ref e) = make_join_response { - trace!("make_join ErrorKind string: {:?}", e.error_code().to_string()); + trace!("make_join ErrorKind string: {:?}", e.kind().to_string()); // converting to a string is necessary (i think) because ruma is forcing us to // fill in the struct for M_INCOMPATIBLE_ROOM_VERSION - if e.error_code() - .to_string() - .contains("M_INCOMPATIBLE_ROOM_VERSION") - || e.error_code() - .to_string() - .contains("M_UNSUPPORTED_ROOM_VERSION") + if e.kind().to_string().contains("M_INCOMPATIBLE_ROOM_VERSION") + || e.kind().to_string().contains("M_UNSUPPORTED_ROOM_VERSION") { incompatible_room_version_count = incompatible_room_version_count.saturating_add(1); } @@ -1342,7 +1326,7 @@ pub async fn validate_and_add_event_id( .globals .bad_event_ratelimiter .write() - .await + .expect("locked") .entry(id) { Entry::Vacant(e) => { @@ -1358,15 +1342,14 @@ pub async fn validate_and_add_event_id( .globals .bad_event_ratelimiter .read() - .await + .expect("locked") .get(&event_id) { // Exponential backoff - const MAX_DURATION: Duration = Duration::from_secs(60 * 60 * 24); - let min_elapsed_duration = cmp::min(MAX_DURATION, Duration::from_secs(5 * 60) * (*tries) * (*tries)); - - if time.elapsed() < min_elapsed_duration { - debug!("Backing off from {}", event_id); + const MIN: u64 = 60 * 5; + const MAX: u64 = 60 * 60 * 24; + if continue_exponential_backoff_secs(MIN, MAX, time.elapsed(), *tries) { + debug!("Backing off from {event_id}"); return Err(Error::BadServerResponse("bad event, still backing off")); } } @@ -1395,7 +1378,7 @@ pub(crate) async fn invite_helper( if !user_is_local(user_id) { let (pdu, pdu_json, invite_room_state) = { - let state_lock = services().globals.roomid_mutex_state.lock(room_id).await; + let state_lock = services().rooms.state.mutex.lock(room_id).await; let content = to_raw_value(&RoomMemberEventContent { avatar_url: services().users.avatar_url(user_id)?, displayname: None, @@ -1438,7 +1421,7 @@ pub(crate) async fn invite_helper( room_id: room_id.to_owned(), event_id: (*pdu.event_id).to_owned(), room_version: room_version_id.clone(), - event: PduEvent::convert_to_outgoing_federation_event(pdu_json.clone()), + event: convert_to_outgoing_federation_event(pdu_json.clone()), invite_room_state, via: services().rooms.state_cache.servers_route_via(room_id).ok(), }, @@ -1507,7 +1490,7 @@ pub(crate) async fn invite_helper( )); } - let state_lock = services().globals.roomid_mutex_state.lock(room_id).await; + let state_lock = services().rooms.state.mutex.lock(room_id).await; services() .rooms @@ -1601,7 +1584,7 @@ pub async fn leave_room(user_id: &UserId, room_id: &RoomId, reason: Option Result<()> { .filter_map(|event: serde_json::Value| event.get("sender").cloned()) .filter_map(|sender| sender.as_str().map(ToOwned::to_owned)) .filter_map(|sender| UserId::parse(sender).ok()) - .map(|user| user.server_name().to_owned()) - .collect::>(), + .map(|user| user.server_name().to_owned()), ); debug!("servers in remote_leave_room: {servers:?}"); @@ -1775,7 +1757,7 @@ async fn remote_leave_room(user_id: &UserId, room_id: &RoomId) -> Result<()> { federation::membership::create_leave_event::v2::Request { room_id: room_id.to_owned(), event_id, - pdu: PduEvent::convert_to_outgoing_federation_event(leave_event.clone()), + pdu: convert_to_outgoing_federation_event(leave_event.clone()), }, ) .await?; diff --git a/src/api/client/message.rs b/src/api/client/message.rs index 9548f084..c376ee52 100644 --- a/src/api/client/message.rs +++ b/src/api/client/message.rs @@ -29,11 +29,7 @@ pub(crate) async fn send_message_event_route( let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_deref(); - let state_lock = services() - .globals - .roomid_mutex_state - .lock(&body.room_id) - .await; + let state_lock = services().rooms.state.mutex.lock(&body.room_id).await; // Forbid m.room.encrypted if encryption is disabled if MessageLikeEventType::RoomEncrypted == body.event_type && !services().globals.allow_encryption() { diff --git a/src/api/client/profile.rs b/src/api/client/profile.rs index 2b23344e..b0305939 100644 --- a/src/api/client/profile.rs +++ b/src/api/client/profile.rs @@ -353,7 +353,7 @@ pub async fn update_avatar_url( pub async fn update_all_rooms(all_joined_rooms: Vec<(PduBuilder, &OwnedRoomId)>, user_id: OwnedUserId) { for (pdu_builder, room_id) in all_joined_rooms { - let state_lock = services().globals.roomid_mutex_state.lock(room_id).await; + let state_lock = services().rooms.state.mutex.lock(room_id).await; if let Err(e) = services() .rooms .timeline diff --git a/src/api/client/redact.rs b/src/api/client/redact.rs index 4cb24c33..308d12e5 100644 --- a/src/api/client/redact.rs +++ b/src/api/client/redact.rs @@ -15,11 +15,7 @@ pub(crate) async fn redact_event_route(body: Ruma) -> let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let body = body.body; - let state_lock = services() - .globals - .roomid_mutex_state - .lock(&body.room_id) - .await; + let state_lock = services().rooms.state.mutex.lock(&body.room_id).await; let event_id = services() .rooms diff --git a/src/api/client/room.rs b/src/api/client/room.rs index 7090fdc8..adf58b04 100644 --- a/src/api/client/room.rs +++ b/src/api/client/room.rs @@ -90,7 +90,7 @@ pub(crate) async fn create_room_route(body: Ruma) -> R } let _short_id = services().rooms.short.get_or_create_shortroomid(&room_id)?; - let state_lock = services().globals.roomid_mutex_state.lock(&room_id).await; + let state_lock = services().rooms.state.mutex.lock(&room_id).await; let alias: Option = if let Some(alias) = &body.room_alias_name { Some(room_alias_check(alias, &body.appservice_info).await?) @@ -118,6 +118,8 @@ pub(crate) async fn create_room_route(body: Ruma) -> R let content = match &body.creation_content { Some(content) => { + use RoomVersionId::*; + let mut content = content .deserialize_as::() .map_err(|e| { @@ -125,16 +127,7 @@ pub(crate) async fn create_room_route(body: Ruma) -> R Error::bad_database("Failed to deserialise content as canonical JSON.") })?; match room_version { - RoomVersionId::V1 - | RoomVersionId::V2 - | RoomVersionId::V3 - | RoomVersionId::V4 - | RoomVersionId::V5 - | RoomVersionId::V6 - | RoomVersionId::V7 - | RoomVersionId::V8 - | RoomVersionId::V9 - | RoomVersionId::V10 => { + V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => { content.insert( "creator".into(), json!(&sender_user).try_into().map_err(|e| { @@ -143,7 +136,7 @@ pub(crate) async fn create_room_route(body: Ruma) -> R })?, ); }, - RoomVersionId::V11 => {}, // V11 removed the "creator" key + V11 => {}, // V11 removed the "creator" key _ => { warn!("Unexpected or unsupported room version {room_version}"); return Err(Error::BadRequest( @@ -152,7 +145,6 @@ pub(crate) async fn create_room_route(body: Ruma) -> R )); }, } - content.insert( "room_version".into(), json!(room_version.as_str()) @@ -162,18 +154,11 @@ pub(crate) async fn create_room_route(body: Ruma) -> R content }, None => { + use RoomVersionId::*; + let content = match room_version { - RoomVersionId::V1 - | RoomVersionId::V2 - | RoomVersionId::V3 - | RoomVersionId::V4 - | RoomVersionId::V5 - | RoomVersionId::V6 - | RoomVersionId::V7 - | RoomVersionId::V8 - | RoomVersionId::V9 - | RoomVersionId::V10 => RoomCreateEventContent::new_v1(sender_user.clone()), - RoomVersionId::V11 => RoomCreateEventContent::new_v11(), + V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => RoomCreateEventContent::new_v1(sender_user.clone()), + V11 => RoomCreateEventContent::new_v11(), _ => { warn!("Unexpected or unsupported room version {room_version}"); return Err(Error::BadRequest( @@ -573,11 +558,7 @@ pub(crate) async fn upgrade_room_route(body: Ruma) -> .short .get_or_create_shortroomid(&replacement_room)?; - let state_lock = services() - .globals - .roomid_mutex_state - .lock(&body.room_id) - .await; + let state_lock = services().rooms.state.mutex.lock(&body.room_id).await; // Send a m.room.tombstone event to the old room to indicate that it is not // intended to be used any further Fail if the sender does not have the required @@ -605,11 +586,7 @@ pub(crate) async fn upgrade_room_route(body: Ruma) -> // Change lock to replacement room drop(state_lock); - let state_lock = services() - .globals - .roomid_mutex_state - .lock(&replacement_room) - .await; + let state_lock = services().rooms.state.mutex.lock(&replacement_room).await; // Get the old room creation event let mut create_event_content = serde_json::from_str::( @@ -631,36 +608,30 @@ pub(crate) async fn upgrade_room_route(body: Ruma) -> // Send a m.room.create event containing a predecessor field and the applicable // room_version - match body.new_version { - RoomVersionId::V1 - | RoomVersionId::V2 - | RoomVersionId::V3 - | RoomVersionId::V4 - | RoomVersionId::V5 - | RoomVersionId::V6 - | RoomVersionId::V7 - | RoomVersionId::V8 - | RoomVersionId::V9 - | RoomVersionId::V10 => { - create_event_content.insert( - "creator".into(), - json!(&sender_user).try_into().map_err(|e| { - info!("Error forming creation event: {e}"); - Error::BadRequest(ErrorKind::BadJson, "Error forming creation event") - })?, - ); - }, - RoomVersionId::V11 => { - // "creator" key no longer exists in V11 rooms - create_event_content.remove("creator"); - }, - _ => { - warn!("Unexpected or unsupported room version {}", body.new_version); - return Err(Error::BadRequest( - ErrorKind::BadJson, - "Unexpected or unsupported room version found", - )); - }, + { + use RoomVersionId::*; + match body.new_version { + V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => { + create_event_content.insert( + "creator".into(), + json!(&sender_user).try_into().map_err(|e| { + info!("Error forming creation event: {e}"); + Error::BadRequest(ErrorKind::BadJson, "Error forming creation event") + })?, + ); + }, + V11 => { + // "creator" key no longer exists in V11 rooms + create_event_content.remove("creator"); + }, + _ => { + warn!("Unexpected or unsupported room version {}", body.new_version); + return Err(Error::BadRequest( + ErrorKind::BadJson, + "Unexpected or unsupported room version found", + )); + }, + } } create_event_content.insert( diff --git a/src/api/client/space.rs b/src/api/client/space.rs index e00171c3..0cf1b107 100644 --- a/src/api/client/space.rs +++ b/src/api/client/space.rs @@ -47,7 +47,7 @@ pub(crate) async fn get_hierarchy_route(body: Ruma) &body.room_id, limit.try_into().unwrap_or(10), key.map_or(vec![], |token| token.short_room_ids), - max_depth.try_into().unwrap_or(3), + max_depth.into(), body.suggested_only, ) .await diff --git a/src/api/client/state.rs b/src/api/client/state.rs index 9247b123..25b77fe3 100644 --- a/src/api/client/state.rs +++ b/src/api/client/state.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use conduit::{error, warn}; +use conduit::{debug_info, error}; use ruma::{ api::client::{ error::ErrorKind, @@ -36,18 +36,16 @@ pub(crate) async fn send_state_event_for_key_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let event_id = send_state_event_for_key_helper( - sender_user, - &body.room_id, - &body.event_type, - &body.body.body, - body.state_key.clone(), - ) - .await?; - - let event_id = (*event_id).to_owned(); Ok(send_state_event::v3::Response { - event_id, + event_id: send_state_event_for_key_helper( + sender_user, + &body.room_id, + &body.event_type, + &body.body.body, + body.state_key.clone(), + ) + .await? + .into(), }) } @@ -128,7 +126,7 @@ pub(crate) async fn get_state_events_for_key_route( .state_accessor .room_state_get(&body.room_id, &body.event_type, &body.state_key)? .ok_or_else(|| { - warn!("State event {:?} not found in room {:?}", &body.event_type, &body.room_id); + debug_info!("State event {:?} not found in room {:?}", &body.event_type, &body.room_id); Error::BadRequest(ErrorKind::NotFound, "State event not found.") })?; if body @@ -172,7 +170,7 @@ async fn send_state_event_for_key_helper( sender: &UserId, room_id: &RoomId, event_type: &StateEventType, json: &Raw, state_key: String, ) -> Result> { allowed_to_send_state_event(room_id, event_type, json).await?; - let state_lock = services().globals.roomid_mutex_state.lock(room_id).await; + let state_lock = services().rooms.state.mutex.lock(room_id).await; let event_id = services() .rooms .timeline diff --git a/src/api/client/sync.rs b/src/api/client/sync.rs index 2ea766a4..e425616b 100644 --- a/src/api/client/sync.rs +++ b/src/api/client/sync.rs @@ -1,10 +1,15 @@ use std::{ + cmp, cmp::Ordering, collections::{hash_map::Entry, BTreeMap, BTreeSet, HashMap, HashSet}, time::Duration, }; -use conduit::PduCount; +use conduit::{ + error, + utils::math::{ruma_from_u64, ruma_from_usize, usize_from_ruma, usize_from_u64_truncated}, + Err, PduCount, +}; use ruma::{ api::client::{ filter::{FilterDefinition, LazyLoadOptions}, @@ -27,7 +32,7 @@ use ruma::{ serde::Raw, uint, DeviceId, EventId, OwnedUserId, RoomId, UInt, UserId, }; -use tracing::{error, Instrument as _, Span}; +use tracing::{Instrument as _, Span}; use crate::{service::pdu::EventHash, services, utils, Error, PduEvent, Result, Ruma, RumaResponse}; @@ -194,7 +199,7 @@ pub(crate) async fn sync_events_route( let (room_id, invite_state_events) = result?; // Get and drop the lock to wait for remaining operations to finish - let insert_lock = services().globals.roomid_mutex_insert.lock(&room_id).await; + let insert_lock = services().rooms.timeline.mutex_insert.lock(&room_id).await; drop(insert_lock); let invite_count = services() @@ -298,15 +303,9 @@ pub(crate) async fn sync_events_route( { // Hang a few seconds so requests are not spammed // Stop hanging if new info arrives - let mut duration = body.timeout.unwrap_or_default(); - if duration.as_secs() > 30 { - duration = Duration::from_secs(30); - } - - #[allow(clippy::let_underscore_must_use)] - { - _ = tokio::time::timeout(duration, watcher).await; - } + let default = Duration::from_secs(30); + let duration = cmp::min(body.timeout.unwrap_or(default), default); + _ = tokio::time::timeout(duration, watcher).await; } Ok(response) @@ -318,7 +317,7 @@ async fn handle_left_room( next_batch_string: &str, full_state: bool, lazy_load_enabled: bool, ) -> Result<()> { // Get and drop the lock to wait for remaining operations to finish - let insert_lock = services().globals.roomid_mutex_insert.lock(room_id).await; + let insert_lock = services().rooms.timeline.mutex_insert.lock(room_id).await; drop(insert_lock); let left_count = services() @@ -520,7 +519,7 @@ async fn load_joined_room( ) -> Result { // Get and drop the lock to wait for remaining operations to finish // This will make sure the we have all events until next_batch - let insert_lock = services().globals.roomid_mutex_insert.lock(room_id).await; + let insert_lock = services().rooms.timeline.mutex_insert.lock(room_id).await; drop(insert_lock); let (timeline_pdus, limited) = load_timeline(sender_user, room_id, sincecount, 10)?; @@ -546,8 +545,7 @@ async fn load_joined_room( // Database queries: let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? else { - error!("Room {} has no state", room_id); - return Err(Error::BadDatabase("Room has no state")); + return Err!(Database(error!("Room {room_id} has no state"))); }; let since_shortstatehash = services() @@ -975,8 +973,8 @@ async fn load_joined_room( }, summary: RoomSummary { heroes, - joined_member_count: joined_member_count.map(|n| (n as u32).into()), - invited_member_count: invited_member_count.map(|n| (n as u32).into()), + joined_member_count: joined_member_count.map(ruma_from_u64), + invited_member_count: invited_member_count.map(ruma_from_u64), }, unread_notifications: UnreadNotificationsCount { highlight_count, @@ -1026,7 +1024,7 @@ fn load_timeline( // Take the last events for the timeline timeline_pdus = non_timeline_pdus .by_ref() - .take(limit as usize) + .take(usize_from_u64_truncated(limit)) .collect::>() .into_iter() .rev() @@ -1300,7 +1298,7 @@ pub(crate) async fn sync_events_v4_route( r.0, UInt::try_from(all_joined_rooms.len().saturating_sub(1)).unwrap_or(UInt::MAX), ); - let room_ids = all_joined_rooms[(u64::from(r.0) as usize)..=(u64::from(r.1) as usize)].to_vec(); + let room_ids = all_joined_rooms[usize_from_ruma(r.0)..=usize_from_ruma(r.1)].to_vec(); new_known_rooms.extend(room_ids.iter().cloned()); for room_id in &room_ids { let todo_room = todo_rooms @@ -1333,7 +1331,7 @@ pub(crate) async fn sync_events_v4_route( } }) .collect(), - count: UInt::from(all_joined_rooms.len() as u32), + count: ruma_from_usize(all_joined_rooms.len()), }, ); @@ -1529,20 +1527,22 @@ pub(crate) async fn sync_events_v4_route( prev_batch, limited, joined_count: Some( - (services() + services() .rooms .state_cache .room_joined_count(room_id)? - .unwrap_or(0) as u32) - .into(), + .unwrap_or(0) + .try_into() + .unwrap_or_else(|_| uint!(0)), ), invited_count: Some( - (services() + services() .rooms .state_cache .room_invited_count(room_id)? - .unwrap_or(0) as u32) - .into(), + .unwrap_or(0) + .try_into() + .unwrap_or_else(|_| uint!(0)), ), num_live: None, // Count events in timeline greater than global sync counter timestamp: None, @@ -1557,14 +1557,9 @@ pub(crate) async fn sync_events_v4_route( { // Hang a few seconds so requests are not spammed // Stop hanging if new info arrives - let mut duration = body.timeout.unwrap_or(Duration::from_secs(30)); - if duration.as_secs() > 30 { - duration = Duration::from_secs(30); - } - #[allow(clippy::let_underscore_must_use)] - { - _ = tokio::time::timeout(duration, watcher).await; - } + let default = Duration::from_secs(30); + let duration = cmp::min(body.timeout.unwrap_or(default), default); + _ = tokio::time::timeout(duration, watcher).await; } Ok(sync_events::v4::Response { diff --git a/src/api/client/thirdparty.rs b/src/api/client/thirdparty.rs index de24c0ec..f6af8729 100644 --- a/src/api/client/thirdparty.rs +++ b/src/api/client/thirdparty.rs @@ -2,7 +2,7 @@ use std::collections::BTreeMap; use ruma::api::client::thirdparty::get_protocols; -use crate::{Result, Ruma}; +use crate::{Result, Ruma, RumaResponse}; /// # `GET /_matrix/client/r0/thirdparty/protocols` /// @@ -15,3 +15,13 @@ pub(crate) async fn get_protocols_route( protocols: BTreeMap::new(), }) } + +/// # `GET /_matrix/client/unstable/thirdparty/protocols` +/// +/// Same as `get_protocols_route`, except for some reason Element Android legacy +/// calls this +pub(crate) async fn get_protocols_route_unstable( + body: Ruma, +) -> Result> { + get_protocols_route(body).await.map(RumaResponse) +} diff --git a/src/api/client/unstable.rs b/src/api/client/unstable.rs index 77cac0fa..e39db94e 100644 --- a/src/api/client/unstable.rs +++ b/src/api/client/unstable.rs @@ -1,12 +1,12 @@ use axum_client_ip::InsecureClientIp; -use conduit::{warn, RumaResponse}; +use conduit::warn; use ruma::{ api::client::{error::ErrorKind, membership::mutual_rooms, room::get_summary}, events::room::member::MembershipState, OwnedRoomId, }; -use crate::{services, Error, Result, Ruma}; +use crate::{services, Error, Result, Ruma, RumaResponse}; /// # `GET /_matrix/client/unstable/uk.half-shot.msc2666/user/mutual_rooms` /// diff --git a/src/api/mod.rs b/src/api/mod.rs index 8e30a518..79382934 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,14 +1,14 @@ pub mod client; -mod router; -pub mod routes; +pub mod router; pub mod server; extern crate conduit_core as conduit; extern crate conduit_service as service; -pub(crate) use conduit::{debug_info, debug_warn, utils, Error, Result}; -pub(crate) use service::{pdu::PduEvent, services, user_is_local}; +pub(crate) use conduit::{debug_info, debug_warn, pdu::PduEvent, utils, Error, Result}; +pub(crate) use service::{services, user_is_local}; +pub use crate::router::State; pub(crate) use crate::router::{Ruma, RumaResponse}; conduit::mod_ctor! {} diff --git a/src/api/routes.rs b/src/api/router.rs similarity index 90% rename from src/api/routes.rs rename to src/api/router.rs index 94951aec..761c173c 100644 --- a/src/api/routes.rs +++ b/src/api/router.rs @@ -1,15 +1,24 @@ +mod args; +mod auth; +mod handler; +mod request; +mod response; + use axum::{ response::IntoResponse, routing::{any, get, post}, Router, }; -use conduit::{Error, Server}; +use conduit::{err, Server}; use http::Uri; -use ruma::api::client::error::ErrorKind; -use crate::{client, router::RouterExt, server}; +use self::handler::RouterExt; +pub(super) use self::{args::Args as Ruma, response::RumaResponse}; +use crate::{client, server}; -pub fn build(router: Router, server: &Server) -> Router { +pub type State = &'static service::Services; + +pub fn build(router: Router, server: &Server) -> Router { let config = &server.config; let router = router .ruma_route(client::get_supported_versions_route) @@ -94,6 +103,8 @@ pub fn build(router: Router, server: &Server) -> Router { .ruma_route(client::search_users_route) .ruma_route(client::get_member_events_route) .ruma_route(client::get_protocols_route) + .route("/_matrix/client/unstable/thirdparty/protocols", + get(client::get_protocols_route_unstable)) .ruma_route(client::send_message_event_route) .ruma_route(client::send_state_event_for_key_route) .ruma_route(client::get_state_events_route) @@ -178,15 +189,15 @@ pub fn build(router: Router, server: &Server) -> Router { .ruma_route(client::get_relating_events_with_rel_type_route) .ruma_route(client::get_relating_events_route) .ruma_route(client::get_hierarchy_route) - .ruma_route(client::get_mutual_rooms_route) - .ruma_route(client::get_room_summary) - .route( - "/_matrix/client/unstable/im.nheko.summary/rooms/:room_id_or_alias/summary", - get(client::get_room_summary_legacy) - ) - .ruma_route(client::well_known_support) - .ruma_route(client::well_known_client) - .route("/_conduwuit/server_version", get(client::conduwuit_server_version)) + .ruma_route(client::get_mutual_rooms_route) + .ruma_route(client::get_room_summary) + .route( + "/_matrix/client/unstable/im.nheko.summary/rooms/:room_id_or_alias/summary", + get(client::get_room_summary_legacy) + ) + .ruma_route(client::well_known_support) + .ruma_route(client::well_known_client) + .route("/_conduwuit/server_version", get(client::conduwuit_server_version)) .route("/_matrix/client/r0/rooms/:room_id/initialSync", get(initial_sync)) .route("/_matrix/client/v3/rooms/:room_id/initialSync", get(initial_sync)) .route("/client/server.json", get(client::syncv3_client_server_json)); @@ -231,7 +242,7 @@ pub fn build(router: Router, server: &Server) -> Router { } async fn initial_sync(_uri: Uri) -> impl IntoResponse { - Error::BadRequest(ErrorKind::GuestAccessForbidden, "Guest access not implemented") + err!(Request(GuestAccessForbidden("Guest access not implemented"))) } -async fn federation_disabled() -> impl IntoResponse { Error::bad_config("Federation is disabled.") } +async fn federation_disabled() -> impl IntoResponse { err!(Config("allow_federation", "Federation is disabled.")) } diff --git a/src/api/router/mod.rs b/src/api/router/args.rs similarity index 71% rename from src/api/router/mod.rs rename to src/api/router/args.rs index 2c439d65..776ce4f4 100644 --- a/src/api/router/mod.rs +++ b/src/api/router/args.rs @@ -1,24 +1,15 @@ -mod auth; -mod handler; -mod request; - use std::{mem, ops::Deref}; use axum::{async_trait, body::Body, extract::FromRequest}; use bytes::{BufMut, BytesMut}; -pub(super) use conduit::error::RumaResponse; -use conduit::{debug, debug_warn, trace, warn}; -use ruma::{ - api::{client::error::ErrorKind, IncomingRequest}, - CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId, UserId, -}; +use conduit::{debug, err, trace, Error, Result}; +use ruma::{api::IncomingRequest, CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId, UserId}; -pub(super) use self::handler::RouterExt; -use self::{auth::Auth, request::Request}; -use crate::{service::appservice::RegistrationInfo, services, Error, Result}; +use super::{auth, auth::Auth, request, request::Request}; +use crate::{service::appservice::RegistrationInfo, services}; /// Extractor for Ruma request structs -pub(crate) struct Ruma { +pub(crate) struct Args { /// Request struct body pub(crate) body: T, @@ -44,7 +35,7 @@ pub(crate) struct Ruma { } #[async_trait] -impl FromRequest for Ruma +impl FromRequest for Args where T: IncomingRequest, { @@ -65,7 +56,7 @@ where } } -impl Deref for Ruma { +impl Deref for Args { type Target = T; fn deref(&self) -> &Self::Target { &self.body } @@ -109,21 +100,14 @@ where let mut http_request = hyper::Request::builder() .uri(request.parts.uri.clone()) .method(request.parts.method.clone()); - *http_request.headers_mut().unwrap() = request.parts.headers.clone(); - let http_request = http_request.body(body).unwrap(); - debug!( - "{:?} {:?} {:?}", - http_request.method(), - http_request.uri(), - http_request.headers() - ); + *http_request.headers_mut().expect("mutable http headers") = request.parts.headers.clone(); + let http_request = http_request.body(body).expect("http request body"); - trace!("{:?} {:?} {:?}", http_request.method(), http_request.uri(), json_body); - let body = T::try_from_http_request(http_request, &request.path).map_err(|e| { - warn!("try_from_http_request failed: {e:?}",); - debug_warn!("JSON body: {:?}", json_body); - Error::BadRequest(ErrorKind::BadJson, "Failed to deserialize request.") - })?; + let headers = http_request.headers(); + let method = http_request.method(); + let uri = http_request.uri(); + debug!("{method:?} {uri:?} {headers:?}"); + trace!("{method:?} {uri:?} {json_body:?}"); - Ok(body) + T::try_from_http_request(http_request, &request.path).map_err(|e| err!(Request(BadJson(debug_warn!("{e}"))))) } diff --git a/src/api/router/auth.rs b/src/api/router/auth.rs index 08a08e08..6c2922b9 100644 --- a/src/api/router/auth.rs +++ b/src/api/router/auth.rs @@ -6,6 +6,7 @@ use axum_extra::{ typed_header::TypedHeaderRejectionReason, TypedHeader, }; +use conduit::Err; use http::uri::PathAndQuery; use ruma::{ api::{client::error::ErrorKind, AuthScheme, Metadata}, @@ -183,7 +184,7 @@ fn auth_appservice(request: &Request, info: Box) -> Result) -> Result { if !services().globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); + return Err!(Config("allow_federation", "Federation is disabled.")); } let TypedHeader(Authorization(x_matrix)) = request diff --git a/src/api/router/handler.rs b/src/api/router/handler.rs index 73361989..d112ec58 100644 --- a/src/api/router/handler.rs +++ b/src/api/router/handler.rs @@ -10,7 +10,7 @@ use conduit::Result; use http::Method; use ruma::api::IncomingRequest; -use super::{Ruma, RumaResponse}; +use super::{Ruma, RumaResponse, State}; pub(in super::super) trait RouterExt { fn ruma_route(self, handler: H) -> Self @@ -18,7 +18,7 @@ pub(in super::super) trait RouterExt { H: RumaHandler; } -impl RouterExt for Router { +impl RouterExt for Router { fn ruma_route(self, handler: H) -> Self where H: RumaHandler, @@ -28,9 +28,9 @@ impl RouterExt for Router { } pub(in super::super) trait RumaHandler { - fn add_routes(&self, router: Router) -> Router; + fn add_routes(&self, router: Router) -> Router; - fn add_route(&self, router: Router, path: &str) -> Router; + fn add_route(&self, router: Router, path: &str) -> Router; } macro_rules! ruma_handler { @@ -41,17 +41,17 @@ macro_rules! ruma_handler { Req: IncomingRequest + Send + 'static, Ret: IntoResponse, Fut: Future> + Send, - Fun: FnOnce($($tx,)* Ruma) -> Fut + Clone + Send + Sync + 'static, - $( $tx: FromRequestParts<()> + Send + 'static, )* + Fun: FnOnce($($tx,)* Ruma,) -> Fut + Clone + Send + Sync + 'static, + $( $tx: FromRequestParts + Send + 'static, )* { - fn add_routes(&self, router: Router) -> Router { + fn add_routes(&self, router: Router) -> Router { Req::METADATA .history .all_paths() .fold(router, |router, path| self.add_route(router, path)) } - fn add_route(&self, router: Router, path: &str) -> Router { + fn add_route(&self, router: Router, path: &str) -> Router { let handle = self.clone(); let method = method_to_filter(&Req::METADATA.method); let action = |$($tx,)* req| async { handle($($tx,)* req).await.map(RumaResponse) }; diff --git a/src/api/router/request.rs b/src/api/router/request.rs index 59639eaa..bed8d057 100644 --- a/src/api/router/request.rs +++ b/src/api/router/request.rs @@ -2,11 +2,11 @@ use std::str; use axum::{extract::Path, RequestExt, RequestPartsExt}; use bytes::Bytes; +use conduit::err; use http::request::Parts; -use ruma::api::client::error::ErrorKind; use serde::Deserialize; -use crate::{services, Error, Result}; +use crate::{services, Result}; #[derive(Deserialize)] pub(super) struct QueryParams { @@ -26,19 +26,15 @@ pub(super) async fn from(request: hyper::Request) -> Result> = parts.extract().await?; - let query = serde_html_form::from_str(parts.uri.query().unwrap_or_default()) - .map_err(|_| Error::BadRequest(ErrorKind::Unknown, "Failed to read query parameters"))?; + let query = parts.uri.query().unwrap_or_default(); + let query = + serde_html_form::from_str(query).map_err(|e| err!(Request(Unknown("Failed to read query parameters: {e}"))))?; - let max_body_size = services() - .globals - .config - .max_request_size - .try_into() - .expect("failed to convert max request size"); + let max_body_size = services().globals.config.max_request_size; let body = axum::body::to_bytes(body, max_body_size) .await - .map_err(|_| Error::BadRequest(ErrorKind::TooLarge, "Request body too large"))?; + .map_err(|e| err!(Request(TooLarge("Request body too large: {e}"))))?; Ok(Request { path, diff --git a/src/api/router/response.rs b/src/api/router/response.rs new file mode 100644 index 00000000..2aaa79fa --- /dev/null +++ b/src/api/router/response.rs @@ -0,0 +1,24 @@ +use axum::response::{IntoResponse, Response}; +use bytes::BytesMut; +use conduit::{error, Error}; +use http::StatusCode; +use http_body_util::Full; +use ruma::api::{client::uiaa::UiaaResponse, OutgoingResponse}; + +pub(crate) struct RumaResponse(pub(crate) T); + +impl From for RumaResponse { + fn from(t: Error) -> Self { Self(t.into()) } +} + +impl IntoResponse for RumaResponse { + fn into_response(self) -> Response { + self.0 + .try_into_http_response::() + .inspect_err(|e| error!("response error: {e}")) + .map_or_else( + |_| StatusCode::INTERNAL_SERVER_ERROR.into_response(), + |r| r.map(BytesMut::freeze).map(Full::new).into_response(), + ) + } +} diff --git a/src/api/server/backfill.rs b/src/api/server/backfill.rs index e3ff82e4..b432ae20 100644 --- a/src/api/server/backfill.rs +++ b/src/api/server/backfill.rs @@ -1,9 +1,11 @@ +use conduit::{Error, Result}; use ruma::{ api::{client::error::ErrorKind, federation::backfill::get_backfill}, uint, user_id, MilliSecondsSinceUnixEpoch, }; +use service::{sending::convert_to_outgoing_federation_event, services}; -use crate::{services, Error, PduEvent, Result, Ruma}; +use crate::Ruma; /// # `GET /_matrix/federation/v1/backfill/` /// @@ -62,7 +64,7 @@ pub(crate) async fn get_backfill_route(body: Ruma) -> }) .map(|(_, pdu)| services().rooms.timeline.get_pdu_json(&pdu.event_id)) .filter_map(|r| r.ok().flatten()) - .map(PduEvent::convert_to_outgoing_federation_event) + .map(convert_to_outgoing_federation_event) .collect(); Ok(get_backfill::v1::Response { diff --git a/src/api/server/event.rs b/src/api/server/event.rs index 29f538b4..f4c9d145 100644 --- a/src/api/server/event.rs +++ b/src/api/server/event.rs @@ -1,9 +1,11 @@ +use conduit::{Error, Result}; use ruma::{ api::{client::error::ErrorKind, federation::event::get_event}, MilliSecondsSinceUnixEpoch, RoomId, }; +use service::{sending::convert_to_outgoing_federation_event, services}; -use crate::{services, Error, PduEvent, Result, Ruma}; +use crate::Ruma; /// # `GET /_matrix/federation/v1/event/{eventId}` /// @@ -48,6 +50,6 @@ pub(crate) async fn get_event_route(body: Ruma) -> Resul Ok(get_event::v1::Response { origin: services().globals.server_name().to_owned(), origin_server_ts: MilliSecondsSinceUnixEpoch::now(), - pdu: PduEvent::convert_to_outgoing_federation_event(event), + pdu: convert_to_outgoing_federation_event(event), }) } diff --git a/src/api/server/event_auth.rs b/src/api/server/event_auth.rs index 1ddf2ce3..bef5116b 100644 --- a/src/api/server/event_auth.rs +++ b/src/api/server/event_auth.rs @@ -1,11 +1,13 @@ use std::sync::Arc; +use conduit::{Error, Result}; use ruma::{ api::{client::error::ErrorKind, federation::authorization::get_event_authorization}, RoomId, }; +use service::{sending::convert_to_outgoing_federation_event, services}; -use crate::{services, Error, PduEvent, Result, Ruma}; +use crate::Ruma; /// # `GET /_matrix/federation/v1/event_auth/{roomId}/{eventId}` /// @@ -57,7 +59,7 @@ pub(crate) async fn get_event_authorization_route( Ok(get_event_authorization::v1::Response { auth_chain: auth_chain_ids .filter_map(|id| services().rooms.timeline.get_pdu_json(&id).ok()?) - .map(PduEvent::convert_to_outgoing_federation_event) + .map(convert_to_outgoing_federation_event) .collect(), }) } diff --git a/src/api/server/get_missing_events.rs b/src/api/server/get_missing_events.rs index 1c9a6a38..5ab9abf8 100644 --- a/src/api/server/get_missing_events.rs +++ b/src/api/server/get_missing_events.rs @@ -1,9 +1,11 @@ +use conduit::{Error, Result}; use ruma::{ api::{client::error::ErrorKind, federation::event::get_missing_events}, OwnedEventId, RoomId, }; +use service::{sending::convert_to_outgoing_federation_event, services}; -use crate::{services, Error, PduEvent, Result, Ruma}; +use crate::Ruma; /// # `POST /_matrix/federation/v1/get_missing_events/{roomId}` /// @@ -79,7 +81,7 @@ pub(crate) async fn get_missing_events_route( ) .map_err(|_| Error::bad_database("Invalid prev_events in event in database."))?, ); - events.push(PduEvent::convert_to_outgoing_federation_event(pdu)); + events.push(convert_to_outgoing_federation_event(pdu)); } i = i.saturating_add(1); } diff --git a/src/api/server/invite.rs b/src/api/server/invite.rs index 333ebee6..89b90058 100644 --- a/src/api/server/invite.rs +++ b/src/api/server/invite.rs @@ -1,18 +1,14 @@ use axum_client_ip::InsecureClientIp; +use conduit::{utils, warn, Error, PduEvent, Result}; use ruma::{ api::{client::error::ErrorKind, federation::membership::create_invite}, events::room::member::{MembershipState, RoomMemberEventContent}, serde::JsonObject, CanonicalJsonValue, EventId, OwnedUserId, }; -use tracing::warn; +use service::{sending::convert_to_outgoing_federation_event, server_is_ours, services}; -use crate::{ - service::server_is_ours, - services, - utils::{self}, - Error, PduEvent, Result, Ruma, -}; +use crate::Ruma; /// # `PUT /_matrix/federation/v2/invite/{roomId}/{eventId}` /// @@ -176,6 +172,6 @@ pub(crate) async fn create_invite_route( } Ok(create_invite::v2::Response { - event: PduEvent::convert_to_outgoing_federation_event(signed_event), + event: convert_to_outgoing_federation_event(signed_event), }) } diff --git a/src/api/server/make_join.rs b/src/api/server/make_join.rs index ca50dcbe..b5dadf7f 100644 --- a/src/api/server/make_join.rs +++ b/src/api/server/make_join.rs @@ -7,7 +7,7 @@ use ruma::{ }, StateEventType, TimelineEventType, }, - RoomId, RoomVersionId, UserId, + CanonicalJsonObject, RoomId, RoomVersionId, UserId, }; use serde_json::value::to_raw_value; use tracing::warn; @@ -71,11 +71,7 @@ pub(crate) async fn create_join_event_template_route( let room_version_id = services().rooms.state.get_room_version(&body.room_id)?; - let state_lock = services() - .globals - .roomid_mutex_state - .lock(&body.room_id) - .await; + let state_lock = services().rooms.state.mutex.lock(&body.room_id).await; let join_authorized_via_users_server = if (services() .rooms @@ -148,27 +144,7 @@ pub(crate) async fn create_join_event_template_route( drop(state_lock); // room v3 and above removed the "event_id" field from remote PDU format - match room_version_id { - RoomVersionId::V1 | RoomVersionId::V2 => {}, - RoomVersionId::V3 - | RoomVersionId::V4 - | RoomVersionId::V5 - | RoomVersionId::V6 - | RoomVersionId::V7 - | RoomVersionId::V8 - | RoomVersionId::V9 - | RoomVersionId::V10 - | RoomVersionId::V11 => { - pdu_json.remove("event_id"); - }, - _ => { - warn!("Unexpected or unsupported room version {room_version_id}"); - return Err(Error::BadRequest( - ErrorKind::BadJson, - "Unexpected or unsupported room version found", - )); - }, - }; + maybe_strip_event_id(&mut pdu_json, &room_version_id)?; Ok(prepare_join_event::v1::Response { room_version: Some(room_version_id), @@ -183,6 +159,8 @@ pub(crate) async fn create_join_event_template_route( pub(crate) fn user_can_perform_restricted_join( user_id: &UserId, room_id: &RoomId, room_version_id: &RoomVersionId, ) -> Result { + use RoomVersionId::*; + let join_rules_event = services() .rooms @@ -202,16 +180,7 @@ pub(crate) fn user_can_perform_restricted_join( return Ok(false); }; - if matches!( - room_version_id, - RoomVersionId::V1 - | RoomVersionId::V2 - | RoomVersionId::V3 - | RoomVersionId::V4 - | RoomVersionId::V5 - | RoomVersionId::V6 - | RoomVersionId::V7 - ) { + if matches!(room_version_id, V1 | V2 | V3 | V4 | V5 | V6 | V7) { return Ok(false); } @@ -243,3 +212,23 @@ pub(crate) fn user_can_perform_restricted_join( )) } } + +pub(crate) fn maybe_strip_event_id(pdu_json: &mut CanonicalJsonObject, room_version_id: &RoomVersionId) -> Result<()> { + use RoomVersionId::*; + + match room_version_id { + V1 | V2 => {}, + V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 | V11 => { + pdu_json.remove("event_id"); + }, + _ => { + warn!("Unexpected or unsupported room version {room_version_id}"); + return Err(Error::BadRequest( + ErrorKind::BadJson, + "Unexpected or unsupported room version found", + )); + }, + }; + + Ok(()) +} diff --git a/src/api/server/make_leave.rs b/src/api/server/make_leave.rs index 62c09717..63fc2b2e 100644 --- a/src/api/server/make_leave.rs +++ b/src/api/server/make_leave.rs @@ -1,14 +1,15 @@ +use conduit::{Error, Result}; use ruma::{ api::{client::error::ErrorKind, federation::membership::prepare_leave_event}, events::{ room::member::{MembershipState, RoomMemberEventContent}, TimelineEventType, }, - RoomVersionId, }; use serde_json::value::to_raw_value; -use crate::{service::pdu::PduBuilder, services, Error, Result, Ruma}; +use super::make_join::maybe_strip_event_id; +use crate::{service::pdu::PduBuilder, services, Ruma}; /// # `PUT /_matrix/federation/v1/make_leave/{roomId}/{eventId}` /// @@ -35,11 +36,7 @@ pub(crate) async fn create_leave_event_template_route( .acl_check(origin, &body.room_id)?; let room_version_id = services().rooms.state.get_room_version(&body.room_id)?; - let state_lock = services() - .globals - .roomid_mutex_state - .lock(&body.room_id) - .await; + let state_lock = services().rooms.state.mutex.lock(&body.room_id).await; let content = to_raw_value(&RoomMemberEventContent { avatar_url: None, blurhash: None, @@ -68,26 +65,7 @@ pub(crate) async fn create_leave_event_template_route( drop(state_lock); // room v3 and above removed the "event_id" field from remote PDU format - match room_version_id { - RoomVersionId::V1 | RoomVersionId::V2 => {}, - RoomVersionId::V3 - | RoomVersionId::V4 - | RoomVersionId::V5 - | RoomVersionId::V6 - | RoomVersionId::V7 - | RoomVersionId::V8 - | RoomVersionId::V9 - | RoomVersionId::V10 - | RoomVersionId::V11 => { - pdu_json.remove("event_id"); - }, - _ => { - return Err(Error::BadRequest( - ErrorKind::BadJson, - "Unexpected or unsupported room version found", - )); - }, - }; + maybe_strip_event_id(&mut pdu_json, &room_version_id)?; Ok(prepare_leave_event::v1::Response { room_version: Some(room_version_id), diff --git a/src/api/server/send.rs b/src/api/server/send.rs index 90225a14..7c699e95 100644 --- a/src/api/server/send.rs +++ b/src/api/server/send.rs @@ -1,7 +1,8 @@ use std::{collections::BTreeMap, net::IpAddr, time::Instant}; +use axum::extract::State; use axum_client_ip::InsecureClientIp; -use conduit::debug_warn; +use conduit::{debug, debug_warn, err, trace, warn, Err}; use ruma::{ api::{ client::error::ErrorKind, @@ -18,11 +19,10 @@ use ruma::{ OwnedEventId, ServerName, }; use tokio::sync::RwLock; -use tracing::{debug, error, trace, warn}; use crate::{ service::rooms::event_handler::parse_incoming_pdu, - services, + services::Services, utils::{self}, Error, Result, Ruma, }; @@ -34,29 +34,23 @@ type ResolvedMap = BTreeMap>; /// Push EDUs and PDUs to this server. #[tracing::instrument(skip_all, fields(%client), name = "send")] pub(crate) async fn send_transaction_message_route( - InsecureClientIp(client): InsecureClientIp, body: Ruma, + State(services): State<&Services>, InsecureClientIp(client): InsecureClientIp, + body: Ruma, ) -> Result { let origin = body.origin.as_ref().expect("server is authenticated"); if *origin != body.body.origin { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Not allowed to send transactions on behalf of other servers", - )); + return Err!(Request(Forbidden( + "Not allowed to send transactions on behalf of other servers" + ))); } if body.pdus.len() > 50_usize { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Not allowed to send more than 50 PDUs in one transaction", - )); + return Err!(Request(Forbidden("Not allowed to send more than 50 PDUs in one transaction"))); } if body.edus.len() > 100_usize { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "Not allowed to send more than 100 EDUs in one transaction", - )); + return Err!(Request(Forbidden("Not allowed to send more than 100 EDUs in one transaction"))); } let txn_start_time = Instant::now(); @@ -69,8 +63,8 @@ pub(crate) async fn send_transaction_message_route( "Starting txn", ); - let resolved_map = handle_pdus(&client, &body, origin, &txn_start_time).await?; - handle_edus(&client, &body, origin).await?; + let resolved_map = handle_pdus(services, &client, &body, origin, &txn_start_time).await?; + handle_edus(services, &client, &body, origin).await?; debug!( pdus = ?body.pdus.len(), @@ -84,13 +78,14 @@ pub(crate) async fn send_transaction_message_route( Ok(send_transaction_message::v1::Response { pdus: resolved_map .into_iter() - .map(|(e, r)| (e, r.map_err(|e| e.sanitized_error()))) + .map(|(e, r)| (e, r.map_err(|e| e.sanitized_string()))) .collect(), }) } async fn handle_pdus( - _client: &IpAddr, body: &Ruma, origin: &ServerName, txn_start_time: &Instant, + services: &Services, _client: &IpAddr, body: &Ruma, origin: &ServerName, + txn_start_time: &Instant, ) -> Result { let mut parsed_pdus = Vec::with_capacity(body.pdus.len()); for pdu in &body.pdus { @@ -110,7 +105,7 @@ async fn handle_pdus( // corresponding signing keys let pub_key_map = RwLock::new(BTreeMap::new()); if !parsed_pdus.is_empty() { - services() + services .rooms .event_handler .fetch_required_signing_keys(parsed_pdus.iter().map(|(_event_id, event, _room_id)| event), &pub_key_map) @@ -126,14 +121,15 @@ async fn handle_pdus( let mut resolved_map = BTreeMap::new(); for (event_id, value, room_id) in parsed_pdus { let pdu_start_time = Instant::now(); - let mutex_lock = services() - .globals - .roomid_mutex_federation + let mutex_lock = services + .rooms + .event_handler + .mutex_federation .lock(&room_id) .await; resolved_map.insert( event_id.clone(), - services() + services .rooms .event_handler .handle_incoming_pdu(origin, &room_id, &event_id, value, true, &pub_key_map) @@ -161,7 +157,7 @@ async fn handle_pdus( } async fn handle_edus( - client: &IpAddr, body: &Ruma, origin: &ServerName, + services: &Services, client: &IpAddr, body: &Ruma, origin: &ServerName, ) -> Result<()> { for edu in body .edus @@ -169,12 +165,12 @@ async fn handle_edus( .filter_map(|edu| serde_json::from_str::(edu.json().get()).ok()) { match edu { - Edu::Presence(presence) => handle_edu_presence(client, origin, presence).await?, - Edu::Receipt(receipt) => handle_edu_receipt(client, origin, receipt).await?, - Edu::Typing(typing) => handle_edu_typing(client, origin, typing).await?, - Edu::DeviceListUpdate(content) => handle_edu_device_list_update(client, origin, content).await?, - Edu::DirectToDevice(content) => handle_edu_direct_to_device(client, origin, content).await?, - Edu::SigningKeyUpdate(content) => handle_edu_signing_key_update(client, origin, content).await?, + Edu::Presence(presence) => handle_edu_presence(services, client, origin, presence).await?, + Edu::Receipt(receipt) => handle_edu_receipt(services, client, origin, receipt).await?, + Edu::Typing(typing) => handle_edu_typing(services, client, origin, typing).await?, + Edu::DeviceListUpdate(content) => handle_edu_device_list_update(services, client, origin, content).await?, + Edu::DirectToDevice(content) => handle_edu_direct_to_device(services, client, origin, content).await?, + Edu::SigningKeyUpdate(content) => handle_edu_signing_key_update(services, client, origin, content).await?, Edu::_Custom(ref _custom) => { debug_warn!(?body.edus, "received custom/unknown EDU"); }, @@ -184,8 +180,10 @@ async fn handle_edus( Ok(()) } -async fn handle_edu_presence(_client: &IpAddr, origin: &ServerName, presence: PresenceContent) -> Result<()> { - if !services().globals.allow_incoming_presence() { +async fn handle_edu_presence( + services: &Services, _client: &IpAddr, origin: &ServerName, presence: PresenceContent, +) -> Result<()> { + if !services.globals.allow_incoming_presence() { return Ok(()); } @@ -198,7 +196,7 @@ async fn handle_edu_presence(_client: &IpAddr, origin: &ServerName, presence: Pr continue; } - services().presence.set_presence( + services.presence.set_presence( &update.user_id, &update.presence, Some(update.currently_active), @@ -210,13 +208,15 @@ async fn handle_edu_presence(_client: &IpAddr, origin: &ServerName, presence: Pr Ok(()) } -async fn handle_edu_receipt(_client: &IpAddr, origin: &ServerName, receipt: ReceiptContent) -> Result<()> { - if !services().globals.allow_incoming_read_receipts() { +async fn handle_edu_receipt( + services: &Services, _client: &IpAddr, origin: &ServerName, receipt: ReceiptContent, +) -> Result<()> { + if !services.globals.allow_incoming_read_receipts() { return Ok(()); } for (room_id, room_updates) in receipt.receipts { - if services() + if services .rooms .event_handler .acl_check(origin, &room_id) @@ -238,7 +238,7 @@ async fn handle_edu_receipt(_client: &IpAddr, origin: &ServerName, receipt: Rece continue; } - if services() + if services .rooms .state_cache .room_members(&room_id) @@ -254,7 +254,7 @@ async fn handle_edu_receipt(_client: &IpAddr, origin: &ServerName, receipt: Rece room_id: room_id.clone(), }; - services() + services .rooms .read_receipt .readreceipt_update(&user_id, &room_id, &event)?; @@ -272,8 +272,10 @@ async fn handle_edu_receipt(_client: &IpAddr, origin: &ServerName, receipt: Rece Ok(()) } -async fn handle_edu_typing(_client: &IpAddr, origin: &ServerName, typing: TypingContent) -> Result<()> { - if !services().globals.config.allow_incoming_typing { +async fn handle_edu_typing( + services: &Services, _client: &IpAddr, origin: &ServerName, typing: TypingContent, +) -> Result<()> { + if !services.globals.config.allow_incoming_typing { return Ok(()); } @@ -285,7 +287,7 @@ async fn handle_edu_typing(_client: &IpAddr, origin: &ServerName, typing: Typing return Ok(()); } - if services() + if services .rooms .event_handler .acl_check(typing.user_id.server_name(), &typing.room_id) @@ -298,26 +300,26 @@ async fn handle_edu_typing(_client: &IpAddr, origin: &ServerName, typing: Typing return Ok(()); } - if services() + if services .rooms .state_cache .is_joined(&typing.user_id, &typing.room_id)? { if typing.typing { let timeout = utils::millis_since_unix_epoch().saturating_add( - services() + services .globals .config .typing_federation_timeout_s .saturating_mul(1000), ); - services() + services .rooms .typing .typing_add(&typing.user_id, &typing.room_id, timeout) .await?; } else { - services() + services .rooms .typing .typing_remove(&typing.user_id, &typing.room_id) @@ -335,7 +337,7 @@ async fn handle_edu_typing(_client: &IpAddr, origin: &ServerName, typing: Typing } async fn handle_edu_device_list_update( - _client: &IpAddr, origin: &ServerName, content: DeviceListUpdateContent, + services: &Services, _client: &IpAddr, origin: &ServerName, content: DeviceListUpdateContent, ) -> Result<()> { let DeviceListUpdateContent { user_id, @@ -350,13 +352,13 @@ async fn handle_edu_device_list_update( return Ok(()); } - services().users.mark_device_key_update(&user_id)?; + services.users.mark_device_key_update(&user_id)?; Ok(()) } async fn handle_edu_direct_to_device( - _client: &IpAddr, origin: &ServerName, content: DirectDeviceContent, + services: &Services, _client: &IpAddr, origin: &ServerName, content: DirectDeviceContent, ) -> Result<()> { let DirectDeviceContent { sender, @@ -374,7 +376,7 @@ async fn handle_edu_direct_to_device( } // Check if this is a new transaction id - if services() + if services .transaction_ids .existing_txnid(&sender, None, &message_id)? .is_some() @@ -386,28 +388,27 @@ async fn handle_edu_direct_to_device( for (target_device_id_maybe, event) in map { match target_device_id_maybe { DeviceIdOrAllDevices::DeviceId(target_device_id) => { - services().users.add_to_device_event( + services.users.add_to_device_event( &sender, target_user_id, target_device_id, &ev_type.to_string(), - event.deserialize_as().map_err(|e| { - error!("To-Device event is invalid: {event:?} {e}"); - Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid") - })?, + event + .deserialize_as() + .map_err(|e| err!(Request(InvalidParam(error!("To-Device event is invalid: {e}")))))?, )?; }, DeviceIdOrAllDevices::AllDevices => { - for target_device_id in services().users.all_device_ids(target_user_id) { - services().users.add_to_device_event( + for target_device_id in services.users.all_device_ids(target_user_id) { + services.users.add_to_device_event( &sender, target_user_id, &target_device_id?, &ev_type.to_string(), event .deserialize_as() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid"))?, + .map_err(|e| err!(Request(InvalidParam("Event is invalid: {e}"))))?, )?; } }, @@ -416,7 +417,7 @@ async fn handle_edu_direct_to_device( } // Save transaction id with empty data - services() + services .transaction_ids .add_txnid(&sender, None, &message_id, &[])?; @@ -424,7 +425,7 @@ async fn handle_edu_direct_to_device( } async fn handle_edu_signing_key_update( - _client: &IpAddr, origin: &ServerName, content: SigningKeyUpdateContent, + services: &Services, _client: &IpAddr, origin: &ServerName, content: SigningKeyUpdateContent, ) -> Result<()> { let SigningKeyUpdateContent { user_id, @@ -441,7 +442,7 @@ async fn handle_edu_signing_key_update( } if let Some(master_key) = master_key { - services() + services .users .add_cross_signing_keys(&user_id, &master_key, &self_signing_key, &None, true)?; } diff --git a/src/api/server/send_join.rs b/src/api/server/send_join.rs index 82cfc45e..577833d5 100644 --- a/src/api/server/send_join.rs +++ b/src/api/server/send_join.rs @@ -2,6 +2,7 @@ use std::collections::BTreeMap; +use conduit::{Error, Result}; use ruma::{ api::{client::error::ErrorKind, federation::membership::create_join_event}, events::{ @@ -11,11 +12,13 @@ use ruma::{ CanonicalJsonValue, OwnedServerName, OwnedUserId, RoomId, ServerName, }; use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; -use service::user_is_local; +use service::{ + pdu::gen_event_id_canonical_json, sending::convert_to_outgoing_federation_event, services, user_is_local, +}; use tokio::sync::RwLock; use tracing::warn; -use crate::{service::pdu::gen_event_id_canonical_json, services, Error, PduEvent, Result, Ruma}; +use crate::Ruma; /// helper method for /send_join v1 and v2 async fn create_join_event( @@ -153,8 +156,9 @@ async fn create_join_event( .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "origin is not a server name."))?; let mutex_lock = services() - .globals - .roomid_mutex_federation + .rooms + .event_handler + .mutex_federation .lock(room_id) .await; let pdu_id: Vec = services() @@ -181,12 +185,12 @@ async fn create_join_event( Ok(create_join_event::v1::RoomState { auth_chain: auth_chain_ids .filter_map(|id| services().rooms.timeline.get_pdu_json(&id).ok().flatten()) - .map(PduEvent::convert_to_outgoing_federation_event) + .map(convert_to_outgoing_federation_event) .collect(), state: state_ids .iter() .filter_map(|(_, id)| services().rooms.timeline.get_pdu_json(id).ok().flatten()) - .map(PduEvent::convert_to_outgoing_federation_event) + .map(convert_to_outgoing_federation_event) .collect(), // Event field is required if the room version supports restricted join rules. event: Some( diff --git a/src/api/server/send_leave.rs b/src/api/server/send_leave.rs index 4fdde515..c4e17bbc 100644 --- a/src/api/server/send_leave.rs +++ b/src/api/server/send_leave.rs @@ -152,8 +152,9 @@ async fn create_leave_event(origin: &ServerName, room_id: &RoomId, pdu: &RawJson .await?; let mutex_lock = services() - .globals - .roomid_mutex_federation + .rooms + .event_handler + .mutex_federation .lock(room_id) .await; let pdu_id: Vec = services() diff --git a/src/api/server/state.rs b/src/api/server/state.rs index c858f6fd..22044840 100644 --- a/src/api/server/state.rs +++ b/src/api/server/state.rs @@ -1,8 +1,10 @@ use std::sync::Arc; +use conduit::{Error, Result}; use ruma::api::{client::error::ErrorKind, federation::event::get_room_state}; +use service::{sending::convert_to_outgoing_federation_event, services}; -use crate::{services, Error, PduEvent, Result, Ruma}; +use crate::Ruma; /// # `GET /_matrix/federation/v1/state/{roomId}` /// @@ -42,7 +44,7 @@ pub(crate) async fn get_room_state_route( .await? .into_values() .map(|id| { - PduEvent::convert_to_outgoing_federation_event( + convert_to_outgoing_federation_event( services() .rooms .timeline @@ -67,7 +69,7 @@ pub(crate) async fn get_room_state_route( .timeline .get_pdu_json(&id) .ok()? - .map(PduEvent::convert_to_outgoing_federation_event) + .map(convert_to_outgoing_federation_event) }) .collect(), pdus, diff --git a/src/core/Cargo.toml b/src/core/Cargo.toml index 2c0fc47d..453d7b13 100644 --- a/src/core/Cargo.toml +++ b/src/core/Cargo.toml @@ -53,7 +53,9 @@ sha256_media = [] argon2.workspace = true axum.workspace = true bytes.workspace = true +checked_ops.workspace = true chrono.workspace = true +const-str.workspace = true either.workspace = true figment.workspace = true http-body-util.workspace = true @@ -80,6 +82,7 @@ tikv-jemalloc-ctl.workspace = true tikv-jemalloc-sys.optional = true tikv-jemalloc-sys.workspace = true tokio.workspace = true +tokio-metrics.workspace = true tracing-core.workspace = true tracing-subscriber.workspace = true tracing.workspace = true diff --git a/src/core/alloc/default.rs b/src/core/alloc/default.rs index 4e2f8d7e..83bfca7d 100644 --- a/src/core/alloc/default.rs +++ b/src/core/alloc/default.rs @@ -1,9 +1,9 @@ //! Default allocator with no special features -/// Always returns the empty string +/// Always returns None #[must_use] -pub fn memory_stats() -> String { String::default() } +pub fn memory_stats() -> Option { None } -/// Always returns the empty string +/// Always returns None #[must_use] -pub fn memory_usage() -> String { String::default() } +pub fn memory_usage() -> Option { None } diff --git a/src/core/alloc/hardened.rs b/src/core/alloc/hardened.rs index 6727407f..335a3307 100644 --- a/src/core/alloc/hardened.rs +++ b/src/core/alloc/hardened.rs @@ -4,9 +4,10 @@ static HMALLOC: hardened_malloc_rs::HardenedMalloc = hardened_malloc_rs::HardenedMalloc; #[must_use] -pub fn memory_usage() -> String { - String::default() //TODO: get usage -} +//TODO: get usage +pub fn memory_usage() -> Option { None } #[must_use] -pub fn memory_stats() -> String { "Extended statistics are not available from hardened_malloc.".to_owned() } +pub fn memory_stats() -> Option { + Some("Extended statistics are not available from hardened_malloc.".to_owned()) +} diff --git a/src/core/alloc/je.rs b/src/core/alloc/je.rs index e0628ef3..08bfc49a 100644 --- a/src/core/alloc/je.rs +++ b/src/core/alloc/je.rs @@ -10,22 +10,31 @@ use tikv_jemallocator as jemalloc; static JEMALLOC: jemalloc::Jemalloc = jemalloc::Jemalloc; #[must_use] -pub fn memory_usage() -> String { +pub fn memory_usage() -> Option { use mallctl::stats; - let allocated = stats::allocated::read().unwrap_or_default() as f64 / 1024.0 / 1024.0; - let active = stats::active::read().unwrap_or_default() as f64 / 1024.0 / 1024.0; - let mapped = stats::mapped::read().unwrap_or_default() as f64 / 1024.0 / 1024.0; - let metadata = stats::metadata::read().unwrap_or_default() as f64 / 1024.0 / 1024.0; - let resident = stats::resident::read().unwrap_or_default() as f64 / 1024.0 / 1024.0; - let retained = stats::retained::read().unwrap_or_default() as f64 / 1024.0 / 1024.0; - format!( - " allocated: {allocated:.2} MiB\n active: {active:.2} MiB\n mapped: {mapped:.2} MiB\n metadata: {metadata:.2} \ - MiB\n resident: {resident:.2} MiB\n retained: {retained:.2} MiB\n " - ) + + let mibs = |input: Result| { + let input = input.unwrap_or_default(); + let kibs = input / 1024; + let kibs = u32::try_from(kibs).unwrap_or_default(); + let kibs = f64::from(kibs); + kibs / 1024.0 + }; + + let allocated = mibs(stats::allocated::read()); + let active = mibs(stats::active::read()); + let mapped = mibs(stats::mapped::read()); + let metadata = mibs(stats::metadata::read()); + let resident = mibs(stats::resident::read()); + let retained = mibs(stats::retained::read()); + Some(format!( + "allocated: {allocated:.2} MiB\nactive: {active:.2} MiB\nmapped: {mapped:.2} MiB\nmetadata: {metadata:.2} \ + MiB\nresident: {resident:.2} MiB\nretained: {retained:.2} MiB\n" + )) } #[must_use] -pub fn memory_stats() -> String { +pub fn memory_stats() -> Option { const MAX_LENGTH: usize = 65536 - 4096; let opts_s = "d"; @@ -42,7 +51,7 @@ pub fn memory_stats() -> String { unsafe { ffi::malloc_stats_print(Some(malloc_stats_cb), opaque, opts_p) }; str.truncate(MAX_LENGTH); - format!("
{str}
") + Some(format!("
{str}
")) } extern "C" fn malloc_stats_cb(opaque: *mut c_void, msg: *const c_char) { diff --git a/src/core/config/check.rs b/src/core/config/check.rs index 795f0c43..b36b9c5e 100644 --- a/src/core/config/check.rs +++ b/src/core/config/check.rs @@ -1,110 +1,126 @@ -#[cfg(unix)] -use std::path::Path; // not unix specific, just only for UNIX sockets stuff and *nix container checks +use figment::Figment; -use tracing::{debug, error, info, warn}; - -use crate::{error::Error, Config}; - -pub fn check(config: &Config) -> Result<(), Error> { - #[cfg(feature = "rocksdb")] - warn!( - "Note the rocksdb feature was deleted from conduwuit, sqlite was deleted and RocksDB is the only supported \ - backend now. Please update your build script to remove this feature." - ); - #[cfg(feature = "sha256_media")] - warn!( - "Note the sha256_media feature was deleted from conduwuit, it is now fully integrated in a \ - forwards-compatible way. Please update your build script to remove this feature." - ); - - config.warn_deprecated(); - config.warn_unknown_key(); - - if config.sentry && config.sentry_endpoint.is_none() { - return Err(Error::bad_config("Sentry cannot be enabled without an endpoint set")); - } - - if cfg!(feature = "hardened_malloc") && cfg!(feature = "jemalloc") { - warn!("hardened_malloc and jemalloc are both enabled, this causes jemalloc to be used."); - } - - if config.unix_socket_path.is_some() && !cfg!(unix) { - return Err(Error::bad_config( - "UNIX socket support is only available on *nix platforms. Please remove \"unix_socket_path\" from your \ - config.", - )); - } - - config.get_bind_addrs().iter().for_each(|addr| { - if addr.ip().is_loopback() && cfg!(unix) { - debug!("Found loopback listening address {addr}, running checks if we're in a container.",); - - #[cfg(unix)] - if Path::new("/proc/vz").exists() /* Guest */ && !Path::new("/proc/bz").exists() - /* Host */ - { - error!( - "You are detected using OpenVZ with a loopback/localhost listening address of {addr}. If you are \ - using OpenVZ for containers and you use NAT-based networking to communicate with the host and \ - guest, this will NOT work. Please change this to \"0.0.0.0\". If this is expected, you can \ - ignore.", - ); - } - - #[cfg(unix)] - if Path::new("/.dockerenv").exists() { - error!( - "You are detected using Docker with a loopback/localhost listening address of {addr}. If you are \ - using a reverse proxy on the host and require communication to conduwuit in the Docker container \ - via NAT-based networking, this will NOT work. Please change this to \"0.0.0.0\". If this is \ - expected, you can ignore.", - ); - } - - #[cfg(unix)] - if Path::new("/run/.containerenv").exists() { - error!( - "You are detected using Podman with a loopback/localhost listening address of {addr}. If you are \ - using a reverse proxy on the host and require communication to conduwuit in the Podman container \ - via NAT-based networking, this will NOT work. Please change this to \"0.0.0.0\". If this is \ - expected, you can ignore.", - ); - } - } - }); - - // rocksdb does not allow max_log_files to be 0 - if config.rocksdb_max_log_files == 0 { - return Err(Error::bad_config( - "When using RocksDB, rocksdb_max_log_files cannot be 0. Please set a value at least 1.", - )); - } - - // yeah, unless the user built a debug build hopefully for local testing only - if config.server_name == "your.server.name" && !cfg!(debug_assertions) { - return Err(Error::bad_config( - "You must specify a valid server name for production usage of conduwuit.", - )); - } +use super::DEPRECATED_KEYS; +use crate::{debug, debug_info, error, info, warn, Config, Err, Result}; +#[allow(clippy::cognitive_complexity)] +pub fn check(config: &Config) -> Result<()> { if cfg!(debug_assertions) { info!("Note: conduwuit was built without optimisations (i.e. debug build)"); } + // prevents catching this in `--all-features` + if cfg!(all(feature = "rocksdb", not(feature = "sha256_media"))) { + warn!( + "Note the rocksdb feature was deleted from conduwuit. SQLite support was removed and RocksDB is the only \ + supported backend now. Please update your build script to remove this feature." + ); + } + + // prevents catching this in `--all-features` + if cfg!(all(feature = "sha256_media", not(feature = "rocksdb"))) { + warn!( + "Note the sha256_media feature was deleted from conduwuit, it is now fully integrated in a \ + forwards-compatible way. Please update your build script to remove this feature." + ); + } + + warn_deprecated(config); + warn_unknown_key(config); + + if config.sentry && config.sentry_endpoint.is_none() { + return Err!(Config("sentry_endpoint", "Sentry cannot be enabled without an endpoint set")); + } + + if cfg!(all(feature = "hardened_malloc", feature = "jemalloc")) { + warn!( + "hardened_malloc and jemalloc are both enabled, this causes jemalloc to be used. If using --all-features, \ + this is harmless." + ); + } + + if cfg!(not(unix)) && config.unix_socket_path.is_some() { + return Err!(Config( + "unix_socket_path", + "UNIX socket support is only available on *nix platforms. Please remove 'unix_socket_path' from your \ + config." + )); + } + + if cfg!(unix) && config.unix_socket_path.is_none() { + config.get_bind_addrs().iter().for_each(|addr| { + use std::path::Path; + + if addr.ip().is_loopback() { + debug_info!("Found loopback listening address {addr}, running checks if we're in a container."); + + if Path::new("/proc/vz").exists() /* Guest */ && !Path::new("/proc/bz").exists() + /* Host */ + { + error!( + "You are detected using OpenVZ with a loopback/localhost listening address of {addr}. If you \ + are using OpenVZ for containers and you use NAT-based networking to communicate with the \ + host and guest, this will NOT work. Please change this to \"0.0.0.0\". If this is expected, \ + you can ignore.", + ); + } + + if Path::new("/.dockerenv").exists() { + error!( + "You are detected using Docker with a loopback/localhost listening address of {addr}. If you \ + are using a reverse proxy on the host and require communication to conduwuit in the Docker \ + container via NAT-based networking, this will NOT work. Please change this to \"0.0.0.0\". \ + If this is expected, you can ignore.", + ); + } + + if Path::new("/run/.containerenv").exists() { + error!( + "You are detected using Podman with a loopback/localhost listening address of {addr}. If you \ + are using a reverse proxy on the host and require communication to conduwuit in the Podman \ + container via NAT-based networking, this will NOT work. Please change this to \"0.0.0.0\". \ + If this is expected, you can ignore.", + ); + } + } + }); + } + + // rocksdb does not allow max_log_files to be 0 + if config.rocksdb_max_log_files == 0 { + return Err!(Config( + "max_log_files", + "rocksdb_max_log_files cannot be 0. Please set a value at least 1." + )); + } + + // yeah, unless the user built a debug build hopefully for local testing only + if cfg!(not(debug_assertions)) && config.server_name == "your.server.name" { + return Err!(Config( + "server_name", + "You must specify a valid server name for production usage of conduwuit." + )); + } + // check if the user specified a registration token as `""` if config.registration_token == Some(String::new()) { - return Err(Error::bad_config("Registration token was specified but is empty (\"\")")); + return Err!(Config( + "registration_token", + "Registration token was specified but is empty (\"\")" + )); } if config.max_request_size < 5_120_000 { - return Err(Error::bad_config("Max request size is less than 5MB. Please increase it.")); + return Err!(Config( + "max_request_size", + "Max request size is less than 5MB. Please increase it." + )); } // check if user specified valid IP CIDR ranges on startup for cidr in &config.ip_range_denylist { if let Err(e) = ipaddress::IPAddress::parse(cidr) { - error!("Error parsing specified IP CIDR range from string: {e}"); - return Err(Error::bad_config("Error parsing specified IP CIDR ranges from strings")); + return Err!(Config("ip_range_denylist", "Parsing specified IP CIDR range from string: {e}.")); } } @@ -112,13 +128,14 @@ pub fn check(config: &Config) -> Result<(), Error> { && !config.yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse && config.registration_token.is_none() { - return Err(Error::bad_config( + return Err!(Config( + "registration_token", "!! You have `allow_registration` enabled without a token configured in your config which means you are \ allowing ANYONE to register on your conduwuit instance without any 2nd-step (e.g. registration token).\n If this is not the intended behaviour, please set a registration token with the `registration_token` config option.\n For security and safety reasons, conduwuit will shut down. If you are extra sure this is the desired behaviour you \ want, please set the following config option to true: -`yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse`", +`yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse`" )); } @@ -135,8 +152,9 @@ For security and safety reasons, conduwuit will shut down. If you are extra sure } if config.allow_outgoing_presence && !config.allow_local_presence { - return Err(Error::bad_config( - "Outgoing presence requires allowing local presence. Please enable \"allow_local_presence\".", + return Err!(Config( + "allow_local_presence", + "Outgoing presence requires allowing local presence. Please enable 'allow_local_presence'." )); } @@ -173,3 +191,52 @@ For security and safety reasons, conduwuit will shut down. If you are extra sure Ok(()) } + +/// Iterates over all the keys in the config file and warns if there is a +/// deprecated key specified +fn warn_deprecated(config: &Config) { + debug!("Checking for deprecated config keys"); + let mut was_deprecated = false; + for key in config + .catchall + .keys() + .filter(|key| DEPRECATED_KEYS.iter().any(|s| s == key)) + { + warn!("Config parameter \"{}\" is deprecated, ignoring.", key); + was_deprecated = true; + } + + if was_deprecated { + warn!( + "Read conduwuit config documentation at https://conduwuit.puppyirl.gay/configuration.html and check your \ + configuration if any new configuration parameters should be adjusted" + ); + } +} + +/// iterates over all the catchall keys (unknown config options) and warns +/// if there are any. +fn warn_unknown_key(config: &Config) { + debug!("Checking for unknown config keys"); + for key in config + .catchall + .keys() + .filter(|key| "config".to_owned().ne(key.to_owned()) /* "config" is expected */) + { + warn!("Config parameter \"{}\" is unknown to conduwuit, ignoring.", key); + } +} + +/// Checks the presence of the `address` and `unix_socket_path` keys in the +/// raw_config, exiting the process if both keys were detected. +pub(super) fn is_dual_listening(raw_config: &Figment) -> Result<()> { + let contains_address = raw_config.contains("address"); + let contains_unix_socket = raw_config.contains("unix_socket_path"); + if contains_address && contains_unix_socket { + return Err!( + "TOML keys \"address\" and \"unix_socket_path\" were both defined. Please specify only one option." + ); + } + + Ok(()) +} diff --git a/src/core/config/mod.rs b/src/core/config/mod.rs index b808f196..b40ebb65 100644 --- a/src/core/config/mod.rs +++ b/src/core/config/mod.rs @@ -1,6 +1,6 @@ use std::{ collections::BTreeMap, - fmt::{self, Write as _}, + fmt, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, path::PathBuf, }; @@ -19,30 +19,15 @@ use ruma::{ api::client::discovery::discover_support::ContactRole, OwnedRoomId, OwnedServerName, OwnedUserId, RoomVersionId, }; use serde::{de::IgnoredAny, Deserialize}; -use tracing::{debug, error, warn}; use url::Url; pub use self::check::check; use self::proxy::ProxyConfig; -use crate::error::Error; +use crate::{error::Error, Err, Result}; pub mod check; pub mod proxy; -#[derive(Deserialize, Clone, Debug)] -#[serde(transparent)] -struct ListeningPort { - #[serde(with = "either::serde_untagged")] - ports: Either>, -} - -#[derive(Deserialize, Clone, Debug)] -#[serde(transparent)] -struct ListeningAddr { - #[serde(with = "either::serde_untagged")] - addrs: Either>, -} - /// all the config options for conduwuit #[derive(Clone, Debug, Deserialize)] #[allow(clippy::struct_excessive_bools)] @@ -73,8 +58,8 @@ pub struct Config { #[serde(default = "default_pdu_cache_capacity")] pub pdu_cache_capacity: u32, - #[serde(default = "default_conduit_cache_capacity_modifier")] - pub conduit_cache_capacity_modifier: f64, + #[serde(default = "default_cache_capacity_modifier", alias = "conduit_cache_capacity_modifier")] + pub cache_capacity_modifier: f64, #[serde(default = "default_auth_chain_cache_capacity")] pub auth_chain_cache_capacity: u32, #[serde(default = "default_shorteventid_cache_capacity")] @@ -114,7 +99,7 @@ pub struct Config { pub ip_lookup_strategy: u8, #[serde(default = "default_max_request_size")] - pub max_request_size: u32, + pub max_request_size: usize, #[serde(default = "default_max_fetch_prev_events")] pub max_fetch_prev_events: u16, @@ -181,16 +166,14 @@ pub struct Config { #[serde(default)] pub well_known: WellKnownConfig, #[serde(default)] - #[cfg(feature = "perf_measurements")] pub allow_jaeger: bool, + #[serde(default = "default_jaeger_filter")] + pub jaeger_filter: String, #[serde(default)] - #[cfg(feature = "perf_measurements")] pub tracing_flame: bool, #[serde(default = "default_tracing_flame_filter")] - #[cfg(feature = "perf_measurements")] pub tracing_flame_filter: String, #[serde(default = "default_tracing_flame_output_path")] - #[cfg(feature = "perf_measurements")] pub tracing_flame_output_path: String, #[serde(default)] pub proxy: ProxyConfig, @@ -356,6 +339,14 @@ pub struct Config { pub sentry_send_server_name: bool, #[serde(default = "default_sentry_traces_sample_rate")] pub sentry_traces_sample_rate: f32, + #[serde(default)] + pub sentry_attach_stacktrace: bool, + #[serde(default = "true_fn")] + pub sentry_send_panic: bool, + #[serde(default = "true_fn")] + pub sentry_send_error: bool, + #[serde(default = "default_sentry_filter")] + pub sentry_filter: String, #[serde(default)] pub tokio_console: bool, @@ -386,8 +377,23 @@ pub struct WellKnownConfig { pub support_mxid: Option, } -const DEPRECATED_KEYS: &[&str] = &[ +#[derive(Deserialize, Clone, Debug)] +#[serde(transparent)] +struct ListeningPort { + #[serde(with = "either::serde_untagged")] + ports: Either>, +} + +#[derive(Deserialize, Clone, Debug)] +#[serde(transparent)] +struct ListeningAddr { + #[serde(with = "either::serde_untagged")] + addrs: Either>, +} + +const DEPRECATED_KEYS: &[&str; 9] = &[ "cache_capacity", + "conduit_cache_capacity_modifier", "max_concurrent_requests", "well_known_client", "well_known_server", @@ -399,7 +405,7 @@ const DEPRECATED_KEYS: &[&str] = &[ impl Config { /// Initialize config - pub fn new(path: Option) -> Result { + pub fn new(path: Option) -> Result { let raw_config = if let Some(config_file_env) = Env::var("CONDUIT_CONFIG") { Figment::new() .merge(Toml::file(config_file_env).nested()) @@ -422,69 +428,16 @@ impl Config { }; let config = match raw_config.extract::() { - Err(e) => return Err(Error::BadConfig(format!("{e}"))), + Err(e) => return Err!("There was a problem with your configuration file: {e}"), Ok(config) => config, }; // don't start if we're listening on both UNIX sockets and TCP at same time - if Self::is_dual_listening(&raw_config) { - return Err(Error::bad_config("dual listening on UNIX and TCP sockets not allowed.")); - }; + check::is_dual_listening(&raw_config)?; Ok(config) } - /// Iterates over all the keys in the config file and warns if there is a - /// deprecated key specified - pub(crate) fn warn_deprecated(&self) { - debug!("Checking for deprecated config keys"); - let mut was_deprecated = false; - for key in self - .catchall - .keys() - .filter(|key| DEPRECATED_KEYS.iter().any(|s| s == key)) - { - warn!("Config parameter \"{}\" is deprecated, ignoring.", key); - was_deprecated = true; - } - - if was_deprecated { - warn!( - "Read conduwuit config documentation at https://conduwuit.puppyirl.gay/configuration.html and check \ - your configuration if any new configuration parameters should be adjusted" - ); - } - } - - /// iterates over all the catchall keys (unknown config options) and warns - /// if there are any. - pub(crate) fn warn_unknown_key(&self) { - debug!("Checking for unknown config keys"); - for key in self - .catchall - .keys() - .filter(|key| "config".to_owned().ne(key.to_owned()) /* "config" is expected */) - { - warn!("Config parameter \"{}\" is unknown to conduwuit, ignoring.", key); - } - } - - /// Checks the presence of the `address` and `unix_socket_path` keys in the - /// raw_config, exiting the process if both keys were detected. - fn is_dual_listening(raw_config: &Figment) -> bool { - let check_address = raw_config.find_value("address"); - let check_unix_socket = raw_config.find_value("unix_socket_path"); - - // are the check_address and check_unix_socket keys both Ok (specified) at the - // same time? - if check_address.is_ok() && check_unix_socket.is_ok() { - error!("TOML keys \"address\" and \"unix_socket_path\" were both defined. Please specify only one option."); - return true; - } - - false - } - #[must_use] pub fn get_bind_addrs(&self) -> Vec { let mut addrs = Vec::new(); @@ -516,361 +469,358 @@ impl Config { impl fmt::Display for Config { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - // Prepare a list of config values to show - let lines = [ - ("Server name", self.server_name.host()), - ("Database backend", &self.database_backend), - ("Database path", &self.database_path.to_string_lossy()), - ( - "Database backup path", - self.database_backup_path - .as_ref() - .map_or("", |path| path.to_str().unwrap_or("")), - ), - ("Database backups to keep", &self.database_backups_to_keep.to_string()), - ("Database cache capacity (MB)", &self.db_cache_capacity_mb.to_string()), - ("Cache capacity modifier", &self.conduit_cache_capacity_modifier.to_string()), - ("PDU cache capacity", &self.pdu_cache_capacity.to_string()), - ("Auth chain cache capacity", &self.auth_chain_cache_capacity.to_string()), - ("Short eventid cache capacity", &self.shorteventid_cache_capacity.to_string()), - ("Eventid short cache capacity", &self.eventidshort_cache_capacity.to_string()), - ("Short statekey cache capacity", &self.shortstatekey_cache_capacity.to_string()), - ("Statekey short cache capacity", &self.statekeyshort_cache_capacity.to_string()), - ( - "Server visibility cache capacity", - &self.server_visibility_cache_capacity.to_string(), - ), - ( - "User visibility cache capacity", - &self.user_visibility_cache_capacity.to_string(), - ), - ("Stateinfo cache capacity", &self.stateinfo_cache_capacity.to_string()), - ( - "Roomid space hierarchy cache capacity", - &self.roomid_spacehierarchy_cache_capacity.to_string(), - ), - ("DNS cache entry limit", &self.dns_cache_entries.to_string()), - ("DNS minimum TTL", &self.dns_min_ttl.to_string()), - ("DNS minimum NXDOMAIN TTL", &self.dns_min_ttl_nxdomain.to_string()), - ("DNS attempts", &self.dns_attempts.to_string()), - ("DNS timeout", &self.dns_timeout.to_string()), - ("DNS fallback to TCP", &self.dns_tcp_fallback.to_string()), - ("DNS query over TCP only", &self.query_over_tcp_only.to_string()), - ("Query all nameservers", &self.query_all_nameservers.to_string()), - ("Maximum request size (bytes)", &self.max_request_size.to_string()), - ("Sender retry backoff limit", &self.sender_retry_backoff_limit.to_string()), - ("Request connect timeout", &self.request_conn_timeout.to_string()), - ("Request timeout", &self.request_timeout.to_string()), - ("Request total timeout", &self.request_total_timeout.to_string()), - ("Idle connections per host", &self.request_idle_per_host.to_string()), - ("Request pool idle timeout", &self.request_idle_timeout.to_string()), - ("Well_known connect timeout", &self.well_known_conn_timeout.to_string()), - ("Well_known timeout", &self.well_known_timeout.to_string()), - ("Federation timeout", &self.federation_timeout.to_string()), - ("Federation pool idle per host", &self.federation_idle_per_host.to_string()), - ("Federation pool idle timeout", &self.federation_idle_timeout.to_string()), - ("Sender timeout", &self.sender_timeout.to_string()), - ("Sender pool idle timeout", &self.sender_idle_timeout.to_string()), - ("Appservice timeout", &self.appservice_timeout.to_string()), - ("Appservice pool idle timeout", &self.appservice_idle_timeout.to_string()), - ("Pusher pool idle timeout", &self.pusher_idle_timeout.to_string()), - ("Allow registration", &self.allow_registration.to_string()), - ( - "Registration token", - if self.registration_token.is_some() { - "set" - } else { - "not set (open registration!)" - }, - ), - ( - "Allow guest registration (inherently false if allow registration is false)", - &self.allow_guest_registration.to_string(), - ), - ( - "Log guest registrations in admin room", - &self.log_guest_registrations.to_string(), - ), - ( - "Allow guests to auto join rooms", - &self.allow_guests_auto_join_rooms.to_string(), - ), - ("New user display name suffix", &self.new_user_displayname_suffix), - ("Allow encryption", &self.allow_encryption.to_string()), - ("Allow federation", &self.allow_federation.to_string()), - ( - "Allow incoming federated presence requests (updates)", - &self.allow_incoming_presence.to_string(), - ), - ( - "Allow outgoing federated presence requests (updates)", - &self.allow_outgoing_presence.to_string(), - ), - ( - "Allow local presence requests (updates)", - &self.allow_local_presence.to_string(), - ), - ( - "Allow incoming remote read receipts", - &self.allow_incoming_read_receipts.to_string(), - ), - ( - "Allow outgoing remote read receipts", - &self.allow_outgoing_read_receipts.to_string(), - ), - ( - "Block non-admin room invites (local and remote, admins can still send and receive invites)", - &self.block_non_admin_invites.to_string(), - ), - ("Enable admin escape commands", &self.admin_escape_commands.to_string()), - ("Allow outgoing federated typing", &self.allow_outgoing_typing.to_string()), - ("Allow incoming federated typing", &self.allow_incoming_typing.to_string()), - ( - "Incoming federated typing timeout", - &self.typing_federation_timeout_s.to_string(), - ), - ("Client typing timeout minimum", &self.typing_client_timeout_min_s.to_string()), - ("Client typing timeout maxmimum", &self.typing_client_timeout_max_s.to_string()), - ("Allow device name federation", &self.allow_device_name_federation.to_string()), - ( - "Allow incoming profile lookup federation requests", - &self.allow_profile_lookup_federation_requests.to_string(), - ), - ( - "Auto deactivate banned room join attempts", - &self.auto_deactivate_banned_room_attempts.to_string(), - ), - ("Notification push path", &self.notification_push_path), - ("Allow room creation", &self.allow_room_creation.to_string()), - ( - "Allow public room directory over federation", - &self.allow_public_room_directory_over_federation.to_string(), - ), - ( - "Allow public room directory without authentication", - &self.allow_public_room_directory_without_auth.to_string(), - ), - ( - "Lockdown public room directory (only allow admins to publish)", - &self.lockdown_public_room_directory.to_string(), - ), - ( - "JWT secret", - match self.jwt_secret { - Some(_) => "set", - None => "not set", - }, - ), - ( - "Trusted key servers", - &self - .trusted_servers - .iter() - .map(|server| server.host()) - .join(", "), - ), - ( - "Query Trusted Key Servers First", - &self.query_trusted_key_servers_first.to_string(), - ), - ("OpenID Token TTL", &self.openid_token_ttl.to_string()), - ( - "TURN username", - if self.turn_username.is_empty() { - "not set" - } else { - &self.turn_username - }, - ), - ("TURN password", { - if self.turn_password.is_empty() { - "not set" - } else { - "set" - } - }), - ("TURN secret", { - if self.turn_secret.is_empty() { - "not set" - } else { - "set" - } - }), - ("Turn TTL", &self.turn_ttl.to_string()), - ("Turn URIs", { - let mut lst = vec![]; - for item in self.turn_uris.iter().cloned().enumerate() { - let (_, uri): (usize, String) = item; - lst.push(uri); - } - &lst.join(", ") - }), - ("Auto Join Rooms", { - let mut lst = vec![]; - for room in &self.auto_join_rooms { - lst.push(room); - } - &lst.into_iter().join(", ") - }), - #[cfg(feature = "zstd_compression")] - ("Zstd HTTP Compression", &self.zstd_compression.to_string()), - #[cfg(feature = "gzip_compression")] - ("Gzip HTTP Compression", &self.gzip_compression.to_string()), - #[cfg(feature = "brotli_compression")] - ("Brotli HTTP Compression", &self.brotli_compression.to_string()), - ("RocksDB database LOG level", &self.rocksdb_log_level), - ("RocksDB database LOG to stderr", &self.rocksdb_log_stderr.to_string()), - ("RocksDB database LOG time-to-roll", &self.rocksdb_log_time_to_roll.to_string()), - ("RocksDB Max LOG Files", &self.rocksdb_max_log_files.to_string()), - ( - "RocksDB database max LOG file size", - &self.rocksdb_max_log_file_size.to_string(), - ), - ( - "RocksDB database optimize for spinning disks", - &self.rocksdb_optimize_for_spinning_disks.to_string(), - ), - ("RocksDB Direct-IO", &self.rocksdb_direct_io.to_string()), - ("RocksDB Parallelism Threads", &self.rocksdb_parallelism_threads.to_string()), - ("RocksDB Compression Algorithm", &self.rocksdb_compression_algo), - ("RocksDB Compression Level", &self.rocksdb_compression_level.to_string()), - ( - "RocksDB Bottommost Compression Level", - &self.rocksdb_bottommost_compression_level.to_string(), - ), - ( - "RocksDB Bottommost Level Compression", - &self.rocksdb_bottommost_compression.to_string(), - ), - ("RocksDB Recovery Mode", &self.rocksdb_recovery_mode.to_string()), - ("RocksDB Repair Mode", &self.rocksdb_repair.to_string()), - ("RocksDB Read-only Mode", &self.rocksdb_read_only.to_string()), - ( - "RocksDB Compaction Idle Priority", - &self.rocksdb_compaction_prio_idle.to_string(), - ), - ( - "RocksDB Compaction Idle IOPriority", - &self.rocksdb_compaction_ioprio_idle.to_string(), - ), - ("Media integrity checks on startup", &self.media_startup_check.to_string()), - ("Media compatibility filesystem links", &self.media_compat_file_link.to_string()), - ("Prevent Media Downloads From", { - let mut lst = vec![]; - for domain in &self.prevent_media_downloads_from { - lst.push(domain.host()); - } - &lst.join(", ") - }), - ("Forbidden Remote Server Names (\"Global\" ACLs)", { - let mut lst = vec![]; - for domain in &self.forbidden_remote_server_names { - lst.push(domain.host()); - } - &lst.join(", ") - }), - ("Forbidden Remote Room Directory Server Names", { - let mut lst = vec![]; - for domain in &self.forbidden_remote_room_directory_server_names { - lst.push(domain.host()); - } - &lst.join(", ") - }), - ("Outbound Request IP Range Denylist", { - let mut lst = vec![]; - for item in self.ip_range_denylist.iter().cloned().enumerate() { - let (_, ip): (usize, String) = item; - lst.push(ip); - } - &lst.join(", ") - }), - ("Forbidden usernames", { - &self.forbidden_usernames.patterns().iter().join(", ") - }), - ("Forbidden room aliases", { - &self.forbidden_alias_names.patterns().iter().join(", ") - }), - ( - "URL preview domain contains allowlist", - &self.url_preview_domain_contains_allowlist.join(", "), - ), - ( - "URL preview domain explicit allowlist", - &self.url_preview_domain_explicit_allowlist.join(", "), - ), - ( - "URL preview domain explicit denylist", - &self.url_preview_domain_explicit_denylist.join(", "), - ), - ( - "URL preview URL contains allowlist", - &self.url_preview_url_contains_allowlist.join(", "), - ), - ("URL preview maximum spider size", &self.url_preview_max_spider_size.to_string()), - ("URL preview check root domain", &self.url_preview_check_root_domain.to_string()), - ( - "Allow check for updates / announcements check", - &self.allow_check_for_updates.to_string(), - ), - ("Enable netburst on startup", &self.startup_netburst.to_string()), - #[cfg(feature = "sentry_telemetry")] - ("Sentry.io reporting and tracing", &self.sentry.to_string()), - #[cfg(feature = "sentry_telemetry")] - ("Sentry.io send server_name in logs", &self.sentry_send_server_name.to_string()), - #[cfg(feature = "sentry_telemetry")] - ("Sentry.io tracing sample rate", &self.sentry_traces_sample_rate.to_string()), - ( - "Well-known server name", - self.well_known - .server - .as_ref() - .map_or("", |server| server.as_str()), - ), - ( - "Well-known client URL", - self.well_known - .client - .as_ref() - .map_or("", |url| url.as_str()), - ), - ( - "Well-known support email", - self.well_known - .support_email - .as_ref() - .map_or("", |str| str.as_ref()), - ), - ( - "Well-known support Matrix ID", - self.well_known - .support_mxid - .as_ref() - .map_or("", |mxid| mxid.as_str()), - ), - ( - "Well-known support role", - self.well_known - .support_role - .as_ref() - .map_or("", |role| role.as_str()), - ), - ( - "Well-known support page/URL", - self.well_known - .support_page - .as_ref() - .map_or("", |url| url.as_str()), - ), - ("Enable the tokio-console", &self.tokio_console.to_string()), - ]; + writeln!(f, "Active config values:\n\n").expect("wrote line to formatter stream"); + let mut line = |key: &str, val: &str| { + writeln!(f, "{key}: {val}").expect("wrote line to formatter stream"); + }; - let mut msg: String = "Active config values:\n\n".to_owned(); + line("Server name", self.server_name.host()); + line("Database backend", &self.database_backend); + line("Database path", &self.database_path.to_string_lossy()); + line( + "Database backup path", + self.database_backup_path + .as_ref() + .map_or("", |path| path.to_str().unwrap_or("")), + ); + line("Database backups to keep", &self.database_backups_to_keep.to_string()); + line("Database cache capacity (MB)", &self.db_cache_capacity_mb.to_string()); + line("Cache capacity modifier", &self.cache_capacity_modifier.to_string()); + line("PDU cache capacity", &self.pdu_cache_capacity.to_string()); + line("Auth chain cache capacity", &self.auth_chain_cache_capacity.to_string()); + line("Short eventid cache capacity", &self.shorteventid_cache_capacity.to_string()); + line("Eventid short cache capacity", &self.eventidshort_cache_capacity.to_string()); + line("Short statekey cache capacity", &self.shortstatekey_cache_capacity.to_string()); + line("Statekey short cache capacity", &self.statekeyshort_cache_capacity.to_string()); + line( + "Server visibility cache capacity", + &self.server_visibility_cache_capacity.to_string(), + ); + line( + "User visibility cache capacity", + &self.user_visibility_cache_capacity.to_string(), + ); + line("Stateinfo cache capacity", &self.stateinfo_cache_capacity.to_string()); + line( + "Roomid space hierarchy cache capacity", + &self.roomid_spacehierarchy_cache_capacity.to_string(), + ); + line("DNS cache entry limit", &self.dns_cache_entries.to_string()); + line("DNS minimum TTL", &self.dns_min_ttl.to_string()); + line("DNS minimum NXDOMAIN TTL", &self.dns_min_ttl_nxdomain.to_string()); + line("DNS attempts", &self.dns_attempts.to_string()); + line("DNS timeout", &self.dns_timeout.to_string()); + line("DNS fallback to TCP", &self.dns_tcp_fallback.to_string()); + line("DNS query over TCP only", &self.query_over_tcp_only.to_string()); + line("Query all nameservers", &self.query_all_nameservers.to_string()); + line("Maximum request size (bytes)", &self.max_request_size.to_string()); + line("Sender retry backoff limit", &self.sender_retry_backoff_limit.to_string()); + line("Request connect timeout", &self.request_conn_timeout.to_string()); + line("Request timeout", &self.request_timeout.to_string()); + line("Request total timeout", &self.request_total_timeout.to_string()); + line("Idle connections per host", &self.request_idle_per_host.to_string()); + line("Request pool idle timeout", &self.request_idle_timeout.to_string()); + line("Well_known connect timeout", &self.well_known_conn_timeout.to_string()); + line("Well_known timeout", &self.well_known_timeout.to_string()); + line("Federation timeout", &self.federation_timeout.to_string()); + line("Federation pool idle per host", &self.federation_idle_per_host.to_string()); + line("Federation pool idle timeout", &self.federation_idle_timeout.to_string()); + line("Sender timeout", &self.sender_timeout.to_string()); + line("Sender pool idle timeout", &self.sender_idle_timeout.to_string()); + line("Appservice timeout", &self.appservice_timeout.to_string()); + line("Appservice pool idle timeout", &self.appservice_idle_timeout.to_string()); + line("Pusher pool idle timeout", &self.pusher_idle_timeout.to_string()); + line("Allow registration", &self.allow_registration.to_string()); + line( + "Registration token", + if self.registration_token.is_some() { + "set" + } else { + "not set (open registration!)" + }, + ); + line( + "Allow guest registration (inherently false if allow registration is false)", + &self.allow_guest_registration.to_string(), + ); + line( + "Log guest registrations in admin room", + &self.log_guest_registrations.to_string(), + ); + line( + "Allow guests to auto join rooms", + &self.allow_guests_auto_join_rooms.to_string(), + ); + line("New user display name suffix", &self.new_user_displayname_suffix); + line("Allow encryption", &self.allow_encryption.to_string()); + line("Allow federation", &self.allow_federation.to_string()); + line( + "Allow incoming federated presence requests (updates)", + &self.allow_incoming_presence.to_string(), + ); + line( + "Allow outgoing federated presence requests (updates)", + &self.allow_outgoing_presence.to_string(), + ); + line( + "Allow local presence requests (updates)", + &self.allow_local_presence.to_string(), + ); + line( + "Allow incoming remote read receipts", + &self.allow_incoming_read_receipts.to_string(), + ); + line( + "Allow outgoing remote read receipts", + &self.allow_outgoing_read_receipts.to_string(), + ); + line( + "Block non-admin room invites (local and remote, admins can still send and receive invites)", + &self.block_non_admin_invites.to_string(), + ); + line("Enable admin escape commands", &self.admin_escape_commands.to_string()); + line("Allow outgoing federated typing", &self.allow_outgoing_typing.to_string()); + line("Allow incoming federated typing", &self.allow_incoming_typing.to_string()); + line( + "Incoming federated typing timeout", + &self.typing_federation_timeout_s.to_string(), + ); + line("Client typing timeout minimum", &self.typing_client_timeout_min_s.to_string()); + line("Client typing timeout maxmimum", &self.typing_client_timeout_max_s.to_string()); + line("Allow device name federation", &self.allow_device_name_federation.to_string()); + line( + "Allow incoming profile lookup federation requests", + &self.allow_profile_lookup_federation_requests.to_string(), + ); + line( + "Auto deactivate banned room join attempts", + &self.auto_deactivate_banned_room_attempts.to_string(), + ); + line("Notification push path", &self.notification_push_path); + line("Allow room creation", &self.allow_room_creation.to_string()); + line( + "Allow public room directory over federation", + &self.allow_public_room_directory_over_federation.to_string(), + ); + line( + "Allow public room directory without authentication", + &self.allow_public_room_directory_without_auth.to_string(), + ); + line( + "Lockdown public room directory (only allow admins to publish)", + &self.lockdown_public_room_directory.to_string(), + ); + line( + "JWT secret", + match self.jwt_secret { + Some(_) => "set", + None => "not set", + }, + ); + line( + "Trusted key servers", + &self + .trusted_servers + .iter() + .map(|server| server.host()) + .join(", "), + ); + line( + "Query Trusted Key Servers First", + &self.query_trusted_key_servers_first.to_string(), + ); + line("OpenID Token TTL", &self.openid_token_ttl.to_string()); + line( + "TURN username", + if self.turn_username.is_empty() { + "not set" + } else { + &self.turn_username + }, + ); + line("TURN password", { + if self.turn_password.is_empty() { + "not set" + } else { + "set" + } + }); + line("TURN secret", { + if self.turn_secret.is_empty() { + "not set" + } else { + "set" + } + }); + line("Turn TTL", &self.turn_ttl.to_string()); + line("Turn URIs", { + let mut lst = vec![]; + for item in self.turn_uris.iter().cloned().enumerate() { + let (_, uri): (usize, String) = item; + lst.push(uri); + } + &lst.join(", ") + }); + line("Auto Join Rooms", { + let mut lst = vec![]; + for room in &self.auto_join_rooms { + lst.push(room); + } + &lst.into_iter().join(", ") + }); + line("Zstd HTTP Compression", &self.zstd_compression.to_string()); + line("Gzip HTTP Compression", &self.gzip_compression.to_string()); + line("Brotli HTTP Compression", &self.brotli_compression.to_string()); + line("RocksDB database LOG level", &self.rocksdb_log_level); + line("RocksDB database LOG to stderr", &self.rocksdb_log_stderr.to_string()); + line("RocksDB database LOG time-to-roll", &self.rocksdb_log_time_to_roll.to_string()); + line("RocksDB Max LOG Files", &self.rocksdb_max_log_files.to_string()); + line( + "RocksDB database max LOG file size", + &self.rocksdb_max_log_file_size.to_string(), + ); + line( + "RocksDB database optimize for spinning disks", + &self.rocksdb_optimize_for_spinning_disks.to_string(), + ); + line("RocksDB Direct-IO", &self.rocksdb_direct_io.to_string()); + line("RocksDB Parallelism Threads", &self.rocksdb_parallelism_threads.to_string()); + line("RocksDB Compression Algorithm", &self.rocksdb_compression_algo); + line("RocksDB Compression Level", &self.rocksdb_compression_level.to_string()); + line( + "RocksDB Bottommost Compression Level", + &self.rocksdb_bottommost_compression_level.to_string(), + ); + line( + "RocksDB Bottommost Level Compression", + &self.rocksdb_bottommost_compression.to_string(), + ); + line("RocksDB Recovery Mode", &self.rocksdb_recovery_mode.to_string()); + line("RocksDB Repair Mode", &self.rocksdb_repair.to_string()); + line("RocksDB Read-only Mode", &self.rocksdb_read_only.to_string()); + line( + "RocksDB Compaction Idle Priority", + &self.rocksdb_compaction_prio_idle.to_string(), + ); + line( + "RocksDB Compaction Idle IOPriority", + &self.rocksdb_compaction_ioprio_idle.to_string(), + ); + line("Media integrity checks on startup", &self.media_startup_check.to_string()); + line("Media compatibility filesystem links", &self.media_compat_file_link.to_string()); + line("Prevent Media Downloads From", { + let mut lst = vec![]; + for domain in &self.prevent_media_downloads_from { + lst.push(domain.host()); + } + &lst.join(", ") + }); + line("Forbidden Remote Server Names (\"Global\" ACLs)", { + let mut lst = vec![]; + for domain in &self.forbidden_remote_server_names { + lst.push(domain.host()); + } + &lst.join(", ") + }); + line("Forbidden Remote Room Directory Server Names", { + let mut lst = vec![]; + for domain in &self.forbidden_remote_room_directory_server_names { + lst.push(domain.host()); + } + &lst.join(", ") + }); + line("Outbound Request IP Range Denylist", { + let mut lst = vec![]; + for item in self.ip_range_denylist.iter().cloned().enumerate() { + let (_, ip): (usize, String) = item; + lst.push(ip); + } + &lst.join(", ") + }); + line("Forbidden usernames", { + &self.forbidden_usernames.patterns().iter().join(", ") + }); + line("Forbidden room aliases", { + &self.forbidden_alias_names.patterns().iter().join(", ") + }); + line( + "URL preview domain contains allowlist", + &self.url_preview_domain_contains_allowlist.join(", "), + ); + line( + "URL preview domain explicit allowlist", + &self.url_preview_domain_explicit_allowlist.join(", "), + ); + line( + "URL preview domain explicit denylist", + &self.url_preview_domain_explicit_denylist.join(", "), + ); + line( + "URL preview URL contains allowlist", + &self.url_preview_url_contains_allowlist.join(", "), + ); + line("URL preview maximum spider size", &self.url_preview_max_spider_size.to_string()); + line("URL preview check root domain", &self.url_preview_check_root_domain.to_string()); + line( + "Allow check for updates / announcements check", + &self.allow_check_for_updates.to_string(), + ); + line("Enable netburst on startup", &self.startup_netburst.to_string()); + #[cfg(feature = "sentry_telemetry")] + line("Sentry.io reporting and tracing", &self.sentry.to_string()); + #[cfg(feature = "sentry_telemetry")] + line("Sentry.io send server_name in logs", &self.sentry_send_server_name.to_string()); + #[cfg(feature = "sentry_telemetry")] + line("Sentry.io tracing sample rate", &self.sentry_traces_sample_rate.to_string()); + line("Sentry.io attach stacktrace", &self.sentry_attach_stacktrace.to_string()); + line("Sentry.io send panics", &self.sentry_send_panic.to_string()); + line("Sentry.io send errors", &self.sentry_send_error.to_string()); + line("Sentry.io tracing filter", &self.sentry_filter); + line( + "Well-known server name", + self.well_known + .server + .as_ref() + .map_or("", |server| server.as_str()), + ); + line( + "Well-known client URL", + self.well_known + .client + .as_ref() + .map_or("", |url| url.as_str()), + ); + line( + "Well-known support email", + self.well_known + .support_email + .as_ref() + .map_or("", |str| str.as_ref()), + ); + line( + "Well-known support Matrix ID", + self.well_known + .support_mxid + .as_ref() + .map_or("", |mxid| mxid.as_str()), + ); + line( + "Well-known support role", + self.well_known + .support_role + .as_ref() + .map_or("", |role| role.as_str()), + ); + line( + "Well-known support page/URL", + self.well_known + .support_page + .as_ref() + .map_or("", |url| url.as_str()), + ); + line("Enable the tokio-console", &self.tokio_console.to_string()); - for line in lines.into_iter().enumerate() { - writeln!(msg, "{}: {}", line.1 .0, line.1 .1).expect("should be able to write to string buffer"); - } - - write!(f, "{msg}") + Ok(()) } } @@ -898,7 +848,7 @@ fn default_db_cache_capacity_mb() -> f64 { 256.0 } fn default_pdu_cache_capacity() -> u32 { 150_000 } -fn default_conduit_cache_capacity_modifier() -> f64 { 1.0 } +fn default_cache_capacity_modifier() -> f64 { 1.0 } fn default_auth_chain_cache_capacity() -> u32 { 100_000 } @@ -930,7 +880,7 @@ fn default_dns_timeout() -> u64 { 10 } fn default_ip_lookup_strategy() -> u8 { 5 } -fn default_max_request_size() -> u32 { +fn default_max_request_size() -> usize { 20 * 1024 * 1024 // Default to 20 MB } @@ -968,10 +918,20 @@ fn default_pusher_idle_timeout() -> u64 { 15 } fn default_max_fetch_prev_events() -> u16 { 100_u16 } -#[cfg(feature = "perf_measurements")] -fn default_tracing_flame_filter() -> String { "trace,h2=off".to_owned() } +fn default_tracing_flame_filter() -> String { + cfg!(debug_assertions) + .then_some("trace,h2=off") + .unwrap_or("info") + .to_owned() +} + +fn default_jaeger_filter() -> String { + cfg!(debug_assertions) + .then_some("trace,h2=off") + .unwrap_or("info") + .to_owned() +} -#[cfg(feature = "perf_measurements")] fn default_tracing_flame_output_path() -> String { "./tracing.folded".to_owned() } fn default_trusted_servers() -> Vec { vec![OwnedServerName::try_from("matrix.org").unwrap()] } @@ -1070,4 +1030,6 @@ fn default_sentry_endpoint() -> Option { fn default_sentry_traces_sample_rate() -> f32 { 0.15 } +fn default_sentry_filter() -> String { "info".to_owned() } + fn default_startup_netburst_keep() -> i64 { 50 } diff --git a/src/core/config/proxy.rs b/src/core/config/proxy.rs index d823e5e4..48f883c6 100644 --- a/src/core/config/proxy.rs +++ b/src/core/config/proxy.rs @@ -127,6 +127,7 @@ impl WildCardedDomain { impl std::str::FromStr for WildCardedDomain { type Err = std::convert::Infallible; + #[allow(clippy::string_slice)] fn from_str(s: &str) -> Result { // maybe do some domain validation? Ok(if s.starts_with("*.") { diff --git a/src/core/debug.rs b/src/core/debug.rs index 7522aa77..14d0be87 100644 --- a/src/core/debug.rs +++ b/src/core/debug.rs @@ -1,6 +1,4 @@ -#![allow(dead_code)] // this is a developer's toolbox - -use std::panic; +use std::{any::Any, panic}; /// Export all of the ancillary tools from here as well. pub use crate::utils::debug::*; @@ -14,9 +12,9 @@ pub use crate::utils::debug::*; macro_rules! debug_event { ( $level:expr, $($x:tt)+ ) => { if cfg!(debug_assertions) && cfg!(not(feature = "dev_release_log_level")) { - ::tracing::event!( $level, $($x)+ ); + ::tracing::event!( $level, $($x)+ ) } else { - ::tracing::debug!( $($x)+ ); + ::tracing::debug!( $($x)+ ) } } } @@ -27,7 +25,7 @@ macro_rules! debug_event { #[macro_export] macro_rules! debug_error { ( $($x:tt)+ ) => { - $crate::debug_event!(::tracing::Level::ERROR, $($x)+ ); + $crate::debug_event!(::tracing::Level::ERROR, $($x)+ ) } } @@ -37,7 +35,7 @@ macro_rules! debug_error { #[macro_export] macro_rules! debug_warn { ( $($x:tt)+ ) => { - $crate::debug_event!(::tracing::Level::WARN, $($x)+ ); + $crate::debug_event!(::tracing::Level::WARN, $($x)+ ) } } @@ -47,7 +45,7 @@ macro_rules! debug_warn { #[macro_export] macro_rules! debug_info { ( $($x:tt)+ ) => { - $crate::debug_event!(::tracing::Level::INFO, $($x)+ ); + $crate::debug_event!(::tracing::Level::INFO, $($x)+ ) } } @@ -79,3 +77,6 @@ pub fn trap() { std::arch::asm!("int3"); } } + +#[must_use] +pub fn panic_str(p: &Box) -> &'static str { p.downcast_ref::<&str>().copied().unwrap_or_default() } diff --git a/src/core/error.rs b/src/core/error.rs deleted file mode 100644 index ac2e176d..00000000 --- a/src/core/error.rs +++ /dev/null @@ -1,209 +0,0 @@ -use std::{convert::Infallible, fmt}; - -use axum::response::{IntoResponse, Response}; -use bytes::BytesMut; -use http::StatusCode; -use http_body_util::Full; -use ruma::{ - api::{ - client::{ - error::ErrorKind::{ - Forbidden, GuestAccessForbidden, LimitExceeded, MissingToken, NotFound, ThreepidAuthFailed, - ThreepidDenied, TooLarge, Unauthorized, Unknown, UnknownToken, Unrecognized, UserDeactivated, - WrongRoomKeysVersion, - }, - uiaa::{UiaaInfo, UiaaResponse}, - }, - OutgoingResponse, - }, - OwnedServerName, -}; -use thiserror::Error; -use tracing::error; - -#[derive(Error)] -pub enum Error { - // std - #[error("{0}")] - Fmt(#[from] fmt::Error), - #[error("I/O error: {0}")] - Io(#[from] std::io::Error), - #[error("{0}")] - Utf8Error(#[from] std::str::Utf8Error), - #[error("{0}")] - FromUtf8Error(#[from] std::string::FromUtf8Error), - #[error("{0}")] - TryFromSliceError(#[from] std::array::TryFromSliceError), - - // third-party - #[error("Regex error: {0}")] - Regex(#[from] regex::Error), - #[error("Tracing filter error: {0}")] - TracingFilter(#[from] tracing_subscriber::filter::ParseError), - #[error("Image error: {0}")] - Image(#[from] image::error::ImageError), - #[error("Request error: {0}")] - Reqwest(#[from] reqwest::Error), - #[error("{0}")] - Extension(#[from] axum::extract::rejection::ExtensionRejection), - #[error("{0}")] - Path(#[from] axum::extract::rejection::PathRejection), - - // ruma - #[error("{0}")] - Mxid(#[from] ruma::IdParseError), - #[error("{0}: {1}")] - BadRequest(ruma::api::client::error::ErrorKind, &'static str), - #[error("from {0}: {1}")] - Redaction(OwnedServerName, ruma::canonical_json::RedactionError), - #[error("Remote server {0} responded with: {1}")] - Federation(OwnedServerName, ruma::api::client::error::Error), - #[error("{0} in {1}")] - InconsistentRoomState(&'static str, ruma::OwnedRoomId), - - // conduwuit - #[error("There was a problem with your configuration: {0}")] - BadConfig(String), - #[error("{0}")] - BadDatabase(&'static str), - #[error("{0}")] - Database(String), - #[error("{0}")] - BadServerResponse(&'static str), - #[error("{0}")] - Conflict(&'static str), // This is only needed for when a room alias already exists - #[error("uiaa")] - Uiaa(UiaaInfo), - - // unique / untyped - #[error("{0}")] - Err(String), -} - -impl Error { - pub fn bad_database(message: &'static str) -> Self { - error!("BadDatabase: {}", message); - Self::BadDatabase(message) - } - - pub fn bad_config(message: &str) -> Self { - error!("BadConfig: {}", message); - Self::BadConfig(message.to_owned()) - } - - /// Returns the Matrix error code / error kind - pub fn error_code(&self) -> ruma::api::client::error::ErrorKind { - if let Self::Federation(_, error) = self { - return error.error_kind().unwrap_or_else(|| &Unknown).clone(); - } - - match self { - Self::BadRequest(kind, _) => kind.clone(), - _ => Unknown, - } - } - - /// Sanitizes public-facing errors that can leak sensitive information. - pub fn sanitized_error(&self) -> String { - match self { - Self::Database { - .. - } => String::from("Database error occurred."), - Self::Io { - .. - } => String::from("I/O error occurred."), - _ => self.to_string(), - } - } -} - -impl From for Error { - fn from(i: Infallible) -> Self { match i {} } -} - -impl fmt::Debug for Error { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{self}") } -} - -#[derive(Clone)] -pub struct RumaResponse(pub T); - -impl From for RumaResponse { - fn from(t: T) -> Self { Self(t) } -} - -impl From for RumaResponse { - fn from(t: Error) -> Self { t.to_response() } -} - -impl Error { - pub fn to_response(&self) -> RumaResponse { - use ruma::api::client::error::{Error as RumaError, ErrorBody}; - - if let Self::Uiaa(uiaainfo) = self { - return RumaResponse(UiaaResponse::AuthResponse(uiaainfo.clone())); - } - - if let Self::Federation(origin, error) = self { - let mut error = error.clone(); - error.body = ErrorBody::Standard { - kind: error.error_kind().unwrap_or_else(|| &Unknown).clone(), - message: format!("Answer from {origin}: {error}"), - }; - return RumaResponse(UiaaResponse::MatrixError(error)); - } - - let message = format!("{self}"); - let (kind, status_code) = match self { - Self::BadRequest(kind, _) => ( - kind.clone(), - match kind { - WrongRoomKeysVersion { - .. - } - | Forbidden { - .. - } - | GuestAccessForbidden - | ThreepidAuthFailed - | UserDeactivated - | ThreepidDenied => StatusCode::FORBIDDEN, - Unauthorized - | UnknownToken { - .. - } - | MissingToken => StatusCode::UNAUTHORIZED, - NotFound | Unrecognized => StatusCode::NOT_FOUND, - LimitExceeded { - .. - } => StatusCode::TOO_MANY_REQUESTS, - TooLarge => StatusCode::PAYLOAD_TOO_LARGE, - _ => StatusCode::BAD_REQUEST, - }, - ), - Self::Conflict(_) => (Unknown, StatusCode::CONFLICT), - _ => (Unknown, StatusCode::INTERNAL_SERVER_ERROR), - }; - - RumaResponse(UiaaResponse::MatrixError(RumaError { - body: ErrorBody::Standard { - kind, - message, - }, - status_code, - })) - } -} - -impl ::axum::response::IntoResponse for Error { - fn into_response(self) -> ::axum::response::Response { self.to_response().into_response() } -} - -impl IntoResponse for RumaResponse { - fn into_response(self) -> Response { - match self.0.try_into_http_response::() { - Ok(res) => res.map(BytesMut::freeze).map(Full::new).into_response(), - Err(_) => StatusCode::INTERNAL_SERVER_ERROR.into_response(), - } - } -} diff --git a/src/core/error/err.rs b/src/core/error/err.rs new file mode 100644 index 00000000..55e38f18 --- /dev/null +++ b/src/core/error/err.rs @@ -0,0 +1,100 @@ +//! Error construction macros +//! +//! These are specialized macros specific to this project's patterns for +//! throwing Errors; they make Error construction succinct and reduce clutter. +//! They are developed from folding existing patterns into the macro while +//! fixing several anti-patterns in the codebase. +//! +//! - The primary macros `Err!` and `err!` are provided. `Err!` simply wraps +//! `err!` in the Result variant to reduce `Err(err!(...))` boilerplate, thus +//! `err!` can be used in any case. +//! +//! 1. The macro makes the general Error construction easy: `return +//! Err!("something went wrong")` replaces the prior `return +//! Err(Error::Err("something went wrong".to_owned()))`. +//! +//! 2. The macro integrates format strings automatically: `return +//! Err!("something bad: {msg}")` replaces the prior `return +//! Err(Error::Err(format!("something bad: {msg}")))`. +//! +//! 3. The macro scopes variants of Error: `return Err!(Database("problem with +//! bad database."))` replaces the prior `return Err(Error::Database("problem +//! with bad database."))`. +//! +//! 4. The macro matches and scopes some special-case sub-variants, for example +//! with ruma ErrorKind: `return Err!(Request(MissingToken("you must provide +//! an access token")))`. +//! +//! 5. The macro fixes the anti-pattern of repeating messages in an error! log +//! and then again in an Error construction, often slightly different due to +//! the Error variant not supporting a format string. Instead `return +//! Err(Database(error!("problem with db: {msg}")))` logs the error at the +//! callsite and then returns the error with the same string. Caller has the +//! option of replacing `error!` with `debug_error!`. + +#[macro_export] +macro_rules! Err { + ($($args:tt)*) => { + Err($crate::err!($($args)*)) + }; +} + +#[macro_export] +macro_rules! err { + (Config($item:literal, $($args:expr),*)) => {{ + $crate::error!(config = %$item, $($args),*); + $crate::error::Error::Config($item, $crate::format_maybe!($($args),*)) + }}; + + (Request(Forbidden($level:ident!($($args:expr),*)))) => {{ + $crate::$level!($($args),*); + $crate::error::Error::Request( + ::ruma::api::client::error::ErrorKind::forbidden(), + $crate::format_maybe!($($args),*), + ::http::StatusCode::BAD_REQUEST + ) + }}; + + (Request(Forbidden($($args:expr),*))) => { + $crate::error::Error::Request( + ::ruma::api::client::error::ErrorKind::forbidden(), + $crate::format_maybe!($($args),*), + ::http::StatusCode::BAD_REQUEST + ) + }; + + (Request($variant:ident($level:ident!($($args:expr),*)))) => {{ + $crate::$level!($($args),*); + $crate::error::Error::Request( + ::ruma::api::client::error::ErrorKind::$variant, + $crate::format_maybe!($($args),*), + ::http::StatusCode::BAD_REQUEST + ) + }}; + + (Request($variant:ident($($args:expr),*))) => { + $crate::error::Error::Request( + ::ruma::api::client::error::ErrorKind::$variant, + $crate::format_maybe!($($args),*), + ::http::StatusCode::BAD_REQUEST + ) + }; + + ($variant:ident($level:ident!($($args:expr),*))) => {{ + $crate::$level!($($args),*); + $crate::error::Error::$variant($crate::format_maybe!($($args),*)) + }}; + + ($variant:ident($($args:expr),*)) => { + $crate::error::Error::$variant($crate::format_maybe!($($args),*)) + }; + + ($level:ident!($($args:expr),*)) => {{ + $crate::$level!($($args),*); + $crate::error::Error::Err($crate::format_maybe!($($args),*)) + }}; + + ($($args:expr),*) => { + $crate::error::Error::Err($crate::format_maybe!($($args),*)) + }; +} diff --git a/src/core/error/log.rs b/src/core/error/log.rs new file mode 100644 index 00000000..c272bf73 --- /dev/null +++ b/src/core/error/log.rs @@ -0,0 +1,74 @@ +use std::{convert::Infallible, fmt}; + +use super::Error; +use crate::{debug_error, error}; + +#[inline] +pub fn else_log(error: E) -> Result +where + T: Default, + Error: From, +{ + Ok(default_log(error)) +} + +#[inline] +pub fn else_debug_log(error: E) -> Result +where + T: Default, + Error: From, +{ + Ok(default_debug_log(error)) +} + +#[inline] +pub fn default_log(error: E) -> T +where + T: Default, + Error: From, +{ + let error = Error::from(error); + inspect_log(&error); + T::default() +} + +#[inline] +pub fn default_debug_log(error: E) -> T +where + T: Default, + Error: From, +{ + let error = Error::from(error); + inspect_debug_log(&error); + T::default() +} + +#[inline] +pub fn map_log(error: E) -> Error +where + Error: From, +{ + let error = Error::from(error); + inspect_log(&error); + error +} + +#[inline] +pub fn map_debug_log(error: E) -> Error +where + Error: From, +{ + let error = Error::from(error); + inspect_debug_log(&error); + error +} + +#[inline] +pub fn inspect_log(error: &E) { + error!("{error}"); +} + +#[inline] +pub fn inspect_debug_log(error: &E) { + debug_error!("{error:?}"); +} diff --git a/src/core/error/mod.rs b/src/core/error/mod.rs new file mode 100644 index 00000000..9439261e --- /dev/null +++ b/src/core/error/mod.rs @@ -0,0 +1,156 @@ +mod err; +mod log; +mod panic; +mod response; + +use std::{any::Any, borrow::Cow, convert::Infallible, fmt}; + +pub use log::*; + +use crate::error; + +#[derive(thiserror::Error)] +pub enum Error { + #[error("PANIC!")] + PanicAny(Box), + #[error("PANIC! {0}")] + Panic(&'static str, Box), + + // std + #[error("{0}")] + Fmt(#[from] fmt::Error), + #[error("I/O error: {0}")] + Io(#[from] std::io::Error), + #[error("{0}")] + Utf8Error(#[from] std::str::Utf8Error), + #[error("{0}")] + FromUtf8Error(#[from] std::string::FromUtf8Error), + #[error("{0}")] + TryFromSliceError(#[from] std::array::TryFromSliceError), + #[error("{0}")] + TryFromIntError(#[from] std::num::TryFromIntError), + #[error("{0}")] + ParseIntError(#[from] std::num::ParseIntError), + #[error("{0}")] + ParseFloatError(#[from] std::num::ParseFloatError), + + // third-party + #[error("Join error: {0}")] + JoinError(#[from] tokio::task::JoinError), + #[error("Regex error: {0}")] + Regex(#[from] regex::Error), + #[error("Tracing filter error: {0}")] + TracingFilter(#[from] tracing_subscriber::filter::ParseError), + #[error("Tracing reload error: {0}")] + TracingReload(#[from] tracing_subscriber::reload::Error), + #[error("Image error: {0}")] + Image(#[from] image::error::ImageError), + #[error("Request error: {0}")] + Reqwest(#[from] reqwest::Error), + #[error("{0}")] + Extension(#[from] axum::extract::rejection::ExtensionRejection), + #[error("{0}")] + Path(#[from] axum::extract::rejection::PathRejection), + #[error("{0}")] + Http(#[from] http::Error), + #[error("{0}")] + HttpHeader(#[from] http::header::InvalidHeaderValue), + + // ruma + #[error("{0}")] + IntoHttpError(#[from] ruma::api::error::IntoHttpError), + #[error("{0}")] + RumaError(#[from] ruma::api::client::error::Error), + #[error("uiaa")] + Uiaa(ruma::api::client::uiaa::UiaaInfo), + #[error("{0}")] + Mxid(#[from] ruma::IdParseError), + #[error("{0}: {1}")] + BadRequest(ruma::api::client::error::ErrorKind, &'static str), //TODO: remove + #[error("{0}: {1}")] + Request(ruma::api::client::error::ErrorKind, Cow<'static, str>, http::StatusCode), + #[error("from {0}: {1}")] + Redaction(ruma::OwnedServerName, ruma::canonical_json::RedactionError), + #[error("Remote server {0} responded with: {1}")] + Federation(ruma::OwnedServerName, ruma::api::client::error::Error), + #[error("{0} in {1}")] + InconsistentRoomState(&'static str, ruma::OwnedRoomId), + + // conduwuit + #[error("Arithmetic operation failed: {0}")] + Arithmetic(Cow<'static, str>), + #[error("There was a problem with the '{0}' directive in your configuration: {1}")] + Config(&'static str, Cow<'static, str>), + #[error("{0}")] + Database(Cow<'static, str>), + #[error("{0}")] + BadServerResponse(&'static str), + #[error("{0}")] + Conflict(&'static str), // This is only needed for when a room alias already exists + + // unique / untyped + #[error("{0}")] + Err(Cow<'static, str>), +} + +impl Error { + pub fn bad_database(message: &'static str) -> Self { crate::err!(Database(error!("{message}"))) } + + /// Sanitizes public-facing errors that can leak sensitive information. + pub fn sanitized_string(&self) -> String { + match self { + Self::Database(..) => String::from("Database error occurred."), + Self::Io(..) => String::from("I/O error occurred."), + _ => self.to_string(), + } + } + + pub fn message(&self) -> String { + match self { + Self::Federation(ref origin, ref error) => format!("Answer from {origin}: {error}"), + Self::RumaError(ref error) => response::ruma_error_message(error), + _ => format!("{self}"), + } + } + + /// Returns the Matrix error code / error kind + #[inline] + pub fn kind(&self) -> ruma::api::client::error::ErrorKind { + use ruma::api::client::error::ErrorKind::Unknown; + + match self { + Self::Federation(_, error) => response::ruma_error_kind(error).clone(), + Self::BadRequest(kind, ..) | Self::Request(kind, ..) => kind.clone(), + _ => Unknown, + } + } + + pub fn status_code(&self) -> http::StatusCode { + match self { + Self::Federation(_, ref error) | Self::RumaError(ref error) => error.status_code, + Self::Request(ref kind, _, code) => response::status_code(kind, *code), + Self::BadRequest(ref kind, ..) => response::bad_request_code(kind), + Self::Conflict(_) => http::StatusCode::CONFLICT, + _ => http::StatusCode::INTERNAL_SERVER_ERROR, + } + } +} + +impl fmt::Debug for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{self}") } +} + +#[allow(clippy::fallible_impl_from)] +impl From for Error { + #[cold] + #[inline(never)] + fn from(_e: Infallible) -> Self { + panic!("infallible error should never exist"); + } +} + +#[cold] +#[inline(never)] +pub fn infallible(_e: &Infallible) { + panic!("infallible error should never exist"); +} diff --git a/src/core/error/panic.rs b/src/core/error/panic.rs new file mode 100644 index 00000000..c070f786 --- /dev/null +++ b/src/core/error/panic.rs @@ -0,0 +1,41 @@ +use std::{ + any::Any, + panic::{panic_any, RefUnwindSafe, UnwindSafe}, +}; + +use super::Error; +use crate::debug; + +impl UnwindSafe for Error {} +impl RefUnwindSafe for Error {} + +impl Error { + pub fn panic(self) -> ! { panic_any(self.into_panic()) } + + #[must_use] + pub fn from_panic(e: Box) -> Self { Self::Panic(debug::panic_str(&e), e) } + + pub fn into_panic(self) -> Box { + match self { + Self::Panic(_, e) | Self::PanicAny(e) => e, + Self::JoinError(e) => e.into_panic(), + _ => Box::new(self), + } + } + + /// Get the panic message string. + pub fn panic_str(self) -> Option<&'static str> { + self.is_panic() + .then_some(debug::panic_str(&self.into_panic())) + } + + /// Check if the Error is trafficking a panic object. + #[inline] + pub fn is_panic(&self) -> bool { + match &self { + Self::Panic(..) | Self::PanicAny(..) => true, + Self::JoinError(e) => e.is_panic(), + _ => false, + } + } +} diff --git a/src/core/error/response.rs b/src/core/error/response.rs new file mode 100644 index 00000000..7568a1c0 --- /dev/null +++ b/src/core/error/response.rs @@ -0,0 +1,110 @@ +use bytes::BytesMut; +use http::StatusCode; +use http_body_util::Full; +use ruma::api::{ + client::{ + error::{ErrorBody, ErrorKind}, + uiaa::UiaaResponse, + }, + OutgoingResponse, +}; + +use super::Error; +use crate::error; + +impl axum::response::IntoResponse for Error { + fn into_response(self) -> axum::response::Response { + let response: UiaaResponse = self.into(); + response + .try_into_http_response::() + .inspect_err(|e| error!("error response error: {e}")) + .map_or_else( + |_| StatusCode::INTERNAL_SERVER_ERROR.into_response(), + |r| r.map(BytesMut::freeze).map(Full::new).into_response(), + ) + } +} + +impl From for UiaaResponse { + fn from(error: Error) -> Self { + if let Error::Uiaa(uiaainfo) = error { + return Self::AuthResponse(uiaainfo); + } + + let body = ErrorBody::Standard { + kind: error.kind(), + message: error.message(), + }; + + Self::MatrixError(ruma::api::client::error::Error { + status_code: error.status_code(), + body, + }) + } +} + +pub(super) fn status_code(kind: &ErrorKind, hint: StatusCode) -> StatusCode { + if hint == StatusCode::BAD_REQUEST { + bad_request_code(kind) + } else { + hint + } +} + +pub(super) fn bad_request_code(kind: &ErrorKind) -> StatusCode { + use ErrorKind::*; + + match kind { + // 429 + LimitExceeded { + .. + } => StatusCode::TOO_MANY_REQUESTS, + + // 413 + TooLarge => StatusCode::PAYLOAD_TOO_LARGE, + + // 405 + Unrecognized => StatusCode::METHOD_NOT_ALLOWED, + + // 404 + NotFound => StatusCode::NOT_FOUND, + + // 403 + GuestAccessForbidden + | ThreepidAuthFailed + | UserDeactivated + | ThreepidDenied + | WrongRoomKeysVersion { + .. + } + | Forbidden { + .. + } => StatusCode::FORBIDDEN, + + // 401 + UnknownToken { + .. + } + | MissingToken + | Unauthorized => StatusCode::UNAUTHORIZED, + + // 400 + _ => StatusCode::BAD_REQUEST, + } +} + +pub(super) fn ruma_error_message(error: &ruma::api::client::error::Error) -> String { + if let ErrorBody::Standard { + message, + .. + } = &error.body + { + return message.to_string(); + } + + format!("{error}") +} + +pub(super) fn ruma_error_kind(e: &ruma::api::client::error::Error) -> &ErrorKind { + e.error_kind().unwrap_or(&ErrorKind::Unknown) +} diff --git a/src/core/log/capture/guard.rs b/src/core/log/capture/guard.rs index a126fedc..b5a6d8b3 100644 --- a/src/core/log/capture/guard.rs +++ b/src/core/log/capture/guard.rs @@ -8,5 +8,6 @@ pub struct Guard { } impl Drop for Guard { + #[inline] fn drop(&mut self) { self.capture.stop(); } } diff --git a/src/core/log/capture/layer.rs b/src/core/log/capture/layer.rs index 19d76771..57f22dc7 100644 --- a/src/core/log/capture/layer.rs +++ b/src/core/log/capture/layer.rs @@ -17,6 +17,7 @@ struct Visitor { } impl Layer { + #[inline] pub fn new(state: &Arc) -> Self { Self { state: state.clone(), @@ -25,6 +26,7 @@ impl Layer { } impl fmt::Debug for Layer { + #[inline] fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { formatter.debug_struct("capture::Layer").finish() } diff --git a/src/core/log/mod.rs b/src/core/log/mod.rs index daa6b8e8..04d250a6 100644 --- a/src/core/log/mod.rs +++ b/src/core/log/mod.rs @@ -29,25 +29,25 @@ pub struct Log { #[macro_export] macro_rules! error { - ( $($x:tt)+ ) => { ::tracing::error!( $($x)+ ); } + ( $($x:tt)+ ) => { ::tracing::error!( $($x)+ ) } } #[macro_export] macro_rules! warn { - ( $($x:tt)+ ) => { ::tracing::warn!( $($x)+ ); } + ( $($x:tt)+ ) => { ::tracing::warn!( $($x)+ ) } } #[macro_export] macro_rules! info { - ( $($x:tt)+ ) => { ::tracing::info!( $($x)+ ); } + ( $($x:tt)+ ) => { ::tracing::info!( $($x)+ ) } } #[macro_export] macro_rules! debug { - ( $($x:tt)+ ) => { ::tracing::debug!( $($x)+ ); } + ( $($x:tt)+ ) => { ::tracing::debug!( $($x)+ ) } } #[macro_export] macro_rules! trace { - ( $($x:tt)+ ) => { ::tracing::trace!( $($x)+ ); } + ( $($x:tt)+ ) => { ::tracing::trace!( $($x)+ ) } } diff --git a/src/core/log/reload.rs b/src/core/log/reload.rs index 7646254e..6d651065 100644 --- a/src/core/log/reload.rs +++ b/src/core/log/reload.rs @@ -1,7 +1,12 @@ -use std::sync::Arc; +use std::{ + collections::HashMap, + sync::{Arc, Mutex}, +}; use tracing_subscriber::{reload, EnvFilter}; +use crate::{error, Result}; + /// We need to store a reload::Handle value, but can't name it's type explicitly /// because the S type parameter depends on the subscriber's previous layers. In /// our case, this includes unnameable 'impl Trait' types. @@ -17,39 +22,60 @@ use tracing_subscriber::{reload, EnvFilter}; /// /// [1]: pub trait ReloadHandle { + fn current(&self) -> Option; + fn reload(&self, new_value: L) -> Result<(), reload::Error>; } -impl ReloadHandle for reload::Handle { +impl ReloadHandle for reload::Handle { + fn current(&self) -> Option { Self::clone_current(self) } + fn reload(&self, new_value: L) -> Result<(), reload::Error> { Self::reload(self, new_value) } } -struct LogLevelReloadHandlesInner { - handles: Vec + Send + Sync>>, -} - -/// Wrapper to allow reloading the filter on several several -/// [`tracing_subscriber::reload::Handle`]s at once, with the same value. #[derive(Clone)] pub struct LogLevelReloadHandles { - inner: Arc, + handles: Arc>, } +type HandleMap = HashMap; +type Handle = Box + Send + Sync>; + impl LogLevelReloadHandles { - #[must_use] - pub fn new(handles: Vec + Send + Sync>>) -> Self { - Self { - inner: Arc::new(LogLevelReloadHandlesInner { - handles, - }), - } + pub fn add(&self, name: &str, handle: Handle) { + self.handles + .lock() + .expect("locked") + .insert(name.into(), handle); } - pub fn reload(&self, new_value: &EnvFilter) -> Result<(), reload::Error> { - for handle in &self.inner.handles { - handle.reload(new_value.clone())?; - } + pub fn reload(&self, new_value: &EnvFilter, names: Option<&[&str]>) -> Result<()> { + self.handles + .lock() + .expect("locked") + .iter() + .filter(|(name, _)| names.map_or(false, |names| names.contains(&name.as_str()))) + .for_each(|(_, handle)| { + _ = handle.reload(new_value.clone()).or_else(error::else_log); + }); Ok(()) } + + #[must_use] + pub fn current(&self, name: &str) -> Option { + self.handles + .lock() + .expect("locked") + .get(name) + .map(|handle| handle.current())? + } +} + +impl Default for LogLevelReloadHandles { + fn default() -> Self { + Self { + handles: Arc::new(HandleMap::new().into()), + } + } } diff --git a/src/core/log/suppress.rs b/src/core/log/suppress.rs index 6e883086..b13ee99e 100644 --- a/src/core/log/suppress.rs +++ b/src/core/log/suppress.rs @@ -10,16 +10,21 @@ pub struct Suppress { impl Suppress { pub fn new(server: &Arc) -> Self { + let handle = "console"; let config = &server.config.log; - Self::from_filters(server, EnvFilter::try_new(config).unwrap_or_default(), &EnvFilter::default()) - } + let suppress = EnvFilter::default(); + let restore = server + .log + .reload + .current(handle) + .unwrap_or_else(|| EnvFilter::try_new(config).unwrap_or_default()); - fn from_filters(server: &Arc, restore: EnvFilter, suppress: &EnvFilter) -> Self { server .log .reload - .reload(suppress) + .reload(&suppress, Some(&[handle])) .expect("log filter reloaded"); + Self { server: server.clone(), restore, @@ -32,7 +37,7 @@ impl Drop for Suppress { self.server .log .reload - .reload(&self.restore) + .reload(&self.restore, Some(&["console"])) .expect("log filter reloaded"); } } diff --git a/src/core/metrics/mod.rs b/src/core/metrics/mod.rs new file mode 100644 index 00000000..3ae139a8 --- /dev/null +++ b/src/core/metrics/mod.rs @@ -0,0 +1,72 @@ +use std::sync::atomic::AtomicU32; + +use tokio::runtime; +use tokio_metrics::TaskMonitor; +#[cfg(tokio_unstable)] +use tokio_metrics::{RuntimeIntervals, RuntimeMonitor}; + +pub struct Metrics { + _runtime: Option, + + runtime_metrics: Option, + + task_monitor: Option, + + #[cfg(tokio_unstable)] + _runtime_monitor: Option, + + #[cfg(tokio_unstable)] + runtime_intervals: std::sync::Mutex>, + + // TODO: move stats + pub requests_spawn_active: AtomicU32, + pub requests_spawn_finished: AtomicU32, + pub requests_handle_active: AtomicU32, + pub requests_handle_finished: AtomicU32, + pub requests_panic: AtomicU32, +} + +impl Metrics { + #[must_use] + pub fn new(runtime: Option) -> Self { + #[cfg(tokio_unstable)] + let runtime_monitor = runtime.as_ref().map(RuntimeMonitor::new); + + #[cfg(tokio_unstable)] + let runtime_intervals = runtime_monitor.as_ref().map(RuntimeMonitor::intervals); + + Self { + _runtime: runtime.clone(), + + runtime_metrics: runtime.as_ref().map(runtime::Handle::metrics), + + task_monitor: runtime.map(|_| TaskMonitor::new()), + + #[cfg(tokio_unstable)] + _runtime_monitor: runtime_monitor, + + #[cfg(tokio_unstable)] + runtime_intervals: std::sync::Mutex::new(runtime_intervals), + + requests_spawn_active: AtomicU32::new(0), + requests_spawn_finished: AtomicU32::new(0), + requests_handle_active: AtomicU32::new(0), + requests_handle_finished: AtomicU32::new(0), + requests_panic: AtomicU32::new(0), + } + } + + #[cfg(tokio_unstable)] + pub fn runtime_interval(&self) -> Option { + self.runtime_intervals + .lock() + .expect("locked") + .as_mut() + .map(Iterator::next) + .expect("next interval") + } + + pub fn task_root(&self) -> Option<&TaskMonitor> { self.task_monitor.as_ref() } + + pub fn runtime_metrics(&self) -> Option<&runtime::RuntimeMetrics> { self.runtime_metrics.as_ref() } +} diff --git a/src/core/mod.rs b/src/core/mod.rs index 5ffe4cb9..9716b46e 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -3,15 +3,16 @@ pub mod config; pub mod debug; pub mod error; pub mod log; +pub mod metrics; pub mod mods; -pub mod pducount; +pub mod pdu; pub mod server; pub mod utils; pub mod version; pub use config::Config; -pub use error::{Error, RumaResponse}; -pub use pducount::PduCount; +pub use error::Error; +pub use pdu::{PduBuilder, PduCount, PduEvent}; pub use server::Server; pub use version::version; diff --git a/src/core/pdu/builder.rs b/src/core/pdu/builder.rs new file mode 100644 index 00000000..a8bad677 --- /dev/null +++ b/src/core/pdu/builder.rs @@ -0,0 +1,16 @@ +use std::{collections::BTreeMap, sync::Arc}; + +use ruma::{events::TimelineEventType, EventId}; +use serde::Deserialize; +use serde_json::value::RawValue as RawJsonValue; + +/// Build the start of a PDU in order to add it to the Database. +#[derive(Debug, Deserialize)] +pub struct PduBuilder { + #[serde(rename = "type")] + pub event_type: TimelineEventType, + pub content: Box, + pub unsigned: Option>, + pub state_key: Option, + pub redacts: Option>, +} diff --git a/src/core/pducount.rs b/src/core/pdu/count.rs similarity index 100% rename from src/core/pducount.rs rename to src/core/pdu/count.rs diff --git a/src/service/pdu.rs b/src/core/pdu/mod.rs similarity index 85% rename from src/service/pdu.rs rename to src/core/pdu/mod.rs index b5650c0a..a4254b08 100644 --- a/src/service/pdu.rs +++ b/src/core/pdu/mod.rs @@ -1,6 +1,10 @@ +mod builder; +mod count; + use std::{cmp::Ordering, collections::BTreeMap, sync::Arc}; -use conduit::{warn, Error}; +pub use builder::PduBuilder; +pub use count::PduCount; use ruma::{ canonical_json::redact_content_in_place, events::{ @@ -19,7 +23,7 @@ use serde_json::{ value::{to_raw_value, RawValue as RawJsonValue}, }; -use crate::services; +use crate::{warn, Error}; #[derive(Deserialize)] struct ExtractRedactedBecause { @@ -60,7 +64,7 @@ pub struct PduEvent { } impl PduEvent { - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn redact(&mut self, room_version_id: RoomVersionId, reason: &Self) -> crate::Result<()> { self.unsigned = None; @@ -112,7 +116,11 @@ impl PduEvent { .map_or_else(|| Ok(BTreeMap::new()), |u| serde_json::from_str(u.get())) .map_err(|_| Error::bad_database("Invalid unsigned in pdu event"))?; - unsigned.insert("age".to_owned(), to_raw_value(&1).unwrap()); + let now: u64 = MilliSecondsSinceUnixEpoch::now().get().into(); + let then: u64 = self.origin_server_ts.into(); + let this_age: u64 = now - then; + + unsigned.insert("age".to_owned(), to_raw_value(&this_age).unwrap()); self.unsigned = Some(to_raw_value(&unsigned).expect("unsigned is valid")); Ok(()) @@ -154,7 +162,7 @@ impl PduEvent { (self.redacts.clone(), self.content.clone()) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn to_sync_room_event(&self) -> Raw { let (redacts, content) = self.copy_redacts(); let mut json = json!({ @@ -179,7 +187,7 @@ impl PduEvent { } /// This only works for events that are also AnyRoomEvents. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn to_any_event(&self) -> Raw { let (redacts, content) = self.copy_redacts(); let mut json = json!({ @@ -204,7 +212,7 @@ impl PduEvent { serde_json::from_value(json).expect("Raw::from_value always works") } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn to_room_event(&self) -> Raw { let (redacts, content) = self.copy_redacts(); let mut json = json!({ @@ -229,7 +237,7 @@ impl PduEvent { serde_json::from_value(json).expect("Raw::from_value always works") } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn to_message_like_event(&self) -> Raw { let (redacts, content) = self.copy_redacts(); let mut json = json!({ @@ -254,7 +262,7 @@ impl PduEvent { serde_json::from_value(json).expect("Raw::from_value always works") } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn to_state_event(&self) -> Raw { let mut json = json!({ "content": self.content, @@ -273,7 +281,7 @@ impl PduEvent { serde_json::from_value(json).expect("Raw::from_value always works") } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn to_sync_state_event(&self) -> Raw { let mut json = json!({ "content": self.content, @@ -291,7 +299,7 @@ impl PduEvent { serde_json::from_value(json).expect("Raw::from_value always works") } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn to_stripped_state_event(&self) -> Raw { let json = json!({ "content": self.content, @@ -303,7 +311,7 @@ impl PduEvent { serde_json::from_value(json).expect("Raw::from_value always works") } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn to_stripped_spacechild_state_event(&self) -> Raw { let json = json!({ "content": self.content, @@ -316,7 +324,7 @@ impl PduEvent { serde_json::from_value(json).expect("Raw::from_value always works") } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn to_member_event(&self) -> Raw> { let mut json = json!({ "content": self.content, @@ -336,42 +344,6 @@ impl PduEvent { serde_json::from_value(json).expect("Raw::from_value always works") } - /// This does not return a full `Pdu` it is only to satisfy ruma's types. - #[tracing::instrument] - pub fn convert_to_outgoing_federation_event(mut pdu_json: CanonicalJsonObject) -> Box { - if let Some(unsigned) = pdu_json - .get_mut("unsigned") - .and_then(|val| val.as_object_mut()) - { - unsigned.remove("transaction_id"); - } - - // room v3 and above removed the "event_id" field from remote PDU format - if let Some(room_id) = pdu_json - .get("room_id") - .and_then(|val| RoomId::parse(val.as_str()?).ok()) - { - match services().rooms.state.get_room_version(&room_id) { - Ok(room_version_id) => match room_version_id { - RoomVersionId::V1 | RoomVersionId::V2 => {}, - _ => _ = pdu_json.remove("event_id"), - }, - Err(_) => _ = pdu_json.remove("event_id"), - } - } else { - pdu_json.remove("event_id"); - } - - // TODO: another option would be to convert it to a canonical string to validate - // size and return a Result> - // serde_json::from_str::>( - // ruma::serde::to_canonical_json_string(pdu_json).expect("CanonicalJson is - // valid serde_json::Value"), ) - // .expect("Raw::from_value always works") - - to_raw_value(&pdu_json).expect("CanonicalJson is valid serde_json::Value") - } - pub fn from_id_val(event_id: &EventId, mut json: CanonicalJsonObject) -> Result { json.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.as_str().to_owned())); @@ -438,14 +410,3 @@ pub fn gen_event_id_canonical_json( Ok((event_id, value)) } - -/// Build the start of a PDU in order to add it to the Database. -#[derive(Debug, Deserialize)] -pub struct PduBuilder { - #[serde(rename = "type")] - pub event_type: TimelineEventType, - pub content: Box, - pub unsigned: Option>, - pub state_key: Option, - pub redacts: Option>, -} diff --git a/src/core/server.rs b/src/core/server.rs index e76e4d57..bf0ab99d 100644 --- a/src/core/server.rs +++ b/src/core/server.rs @@ -1,11 +1,11 @@ use std::{ - sync::atomic::{AtomicBool, AtomicU32, Ordering}, + sync::atomic::{AtomicBool, Ordering}, time::SystemTime, }; use tokio::{runtime, sync::broadcast}; -use crate::{config::Config, log, Error, Result}; +use crate::{config::Config, log::Log, metrics::Metrics, Err, Result}; /// Server runtime state; public portion pub struct Server { @@ -33,71 +33,68 @@ pub struct Server { pub signal: broadcast::Sender<&'static str>, /// Logging subsystem state - pub log: log::Log, + pub log: Log, - /// TODO: move stats - pub requests_spawn_active: AtomicU32, - pub requests_spawn_finished: AtomicU32, - pub requests_handle_active: AtomicU32, - pub requests_handle_finished: AtomicU32, - pub requests_panic: AtomicU32, + /// Metrics subsystem state + pub metrics: Metrics, } impl Server { #[must_use] - pub fn new(config: Config, runtime: Option, log: log::Log) -> Self { + pub fn new(config: Config, runtime: Option, log: Log) -> Self { Self { config, started: SystemTime::now(), stopping: AtomicBool::new(false), reloading: AtomicBool::new(false), restarting: AtomicBool::new(false), - runtime, + runtime: runtime.clone(), signal: broadcast::channel::<&'static str>(1).0, log, - requests_spawn_active: AtomicU32::new(0), - requests_spawn_finished: AtomicU32::new(0), - requests_handle_active: AtomicU32::new(0), - requests_handle_finished: AtomicU32::new(0), - requests_panic: AtomicU32::new(0), + metrics: Metrics::new(runtime), } } pub fn reload(&self) -> Result<()> { if cfg!(not(conduit_mods)) { - return Err(Error::Err("Reloading not enabled".into())); + return Err!("Reloading not enabled"); } if self.reloading.swap(true, Ordering::AcqRel) { - return Err(Error::Err("Reloading already in progress".into())); + return Err!("Reloading already in progress"); } if self.stopping.swap(true, Ordering::AcqRel) { - return Err(Error::Err("Shutdown already in progress".into())); + return Err!("Shutdown already in progress"); } - self.signal("SIGINT") + self.signal("SIGINT").inspect_err(|_| { + self.stopping.store(false, Ordering::Release); + self.reloading.store(false, Ordering::Release); + }) } pub fn restart(&self) -> Result<()> { if self.restarting.swap(true, Ordering::AcqRel) { - return Err(Error::Err("Restart already in progress".into())); + return Err!("Restart already in progress"); } self.shutdown() + .inspect_err(|_| self.restarting.store(false, Ordering::Release)) } pub fn shutdown(&self) -> Result<()> { if self.stopping.swap(true, Ordering::AcqRel) { - return Err(Error::Err("Shutdown already in progress".into())); + return Err!("Shutdown already in progress"); } self.signal("SIGTERM") + .inspect_err(|_| self.stopping.store(false, Ordering::Release)) } pub fn signal(&self, sig: &'static str) -> Result<()> { if let Err(e) = self.signal.send(sig) { - return Err(Error::Err(format!("Failed to send signal: {e}"))); + return Err!("Failed to send signal: {e}"); } Ok(()) diff --git a/src/core/utils/bytes.rs b/src/core/utils/bytes.rs index 6ebf4c4f..e8975a49 100644 --- a/src/core/utils/bytes.rs +++ b/src/core/utils/bytes.rs @@ -3,27 +3,28 @@ use crate::Result; #[inline] #[must_use] pub fn increment(old: Option<&[u8]>) -> [u8; 8] { - old.map(TryInto::try_into) - .map_or(0_u64, |val| val.map_or(0_u64, u64::from_be_bytes)) + old.map_or(0_u64, u64_from_bytes_or_zero) .wrapping_add(1) .to_be_bytes() } -/// Parses the big-endian bytes into an u64. -#[inline] -pub fn u64_from_bytes(bytes: &[u8]) -> Result { - let array: [u8; 8] = bytes.try_into()?; - Ok(u64_from_u8x8(array)) -} - -/// Parses the 8 big-endian bytes into an u64. +/// Parses 8 big-endian bytes into an u64; panic on invalid argument #[inline] #[must_use] -pub fn u64_from_u8(bytes: &[u8]) -> u64 { - let bytes: &[u8; 8] = bytes.try_into().expect("must slice at least 8 bytes"); - u64_from_u8x8(*bytes) -} +pub fn u64_from_u8(bytes: &[u8]) -> u64 { u64_from_bytes(bytes).expect("must slice at least 8 bytes") } + +/// Parses the big-endian bytes into an u64. +#[inline] +#[must_use] +pub fn u64_from_bytes_or_zero(bytes: &[u8]) -> u64 { u64_from_bytes(bytes).unwrap_or(0) } + +/// Parses the big-endian bytes into an u64. +#[inline] +pub fn u64_from_bytes(bytes: &[u8]) -> Result { Ok(u64_from_u8x8(*u8x8_from_bytes(bytes)?)) } #[inline] #[must_use] pub fn u64_from_u8x8(bytes: [u8; 8]) -> u64 { u64::from_be_bytes(bytes) } + +#[inline] +pub fn u8x8_from_bytes(bytes: &[u8]) -> Result<&[u8; 8]> { Ok(bytes.try_into()?) } diff --git a/src/core/utils/content_disposition.rs b/src/core/utils/content_disposition.rs index 1c2b066d..be17a731 100644 --- a/src/core/utils/content_disposition.rs +++ b/src/core/utils/content_disposition.rs @@ -66,7 +66,7 @@ pub fn content_disposition_type(content_type: &Option) -> &'static str { /// sanitises the file name for the Content-Disposition using /// `sanitize_filename` crate -#[tracing::instrument] +#[tracing::instrument(level = "debug")] pub fn sanitise_filename(filename: String) -> String { let options = sanitize_filename::Options { truncate: false, diff --git a/src/core/utils/defer.rs b/src/core/utils/defer.rs index 2762d4fa..9d42e679 100644 --- a/src/core/utils/defer.rs +++ b/src/core/utils/defer.rs @@ -1,17 +1,11 @@ #[macro_export] macro_rules! defer { ($body:block) => { - struct _Defer_ - where - F: FnMut(), - { + struct _Defer_ { closure: F, } - impl Drop for _Defer_ - where - F: FnMut(), - { + impl Drop for _Defer_ { fn drop(&mut self) { (self.closure)(); } } diff --git a/src/core/utils/hash.rs b/src/core/utils/hash.rs index 9469321f..5a3664cb 100644 --- a/src/core/utils/hash.rs +++ b/src/core/utils/hash.rs @@ -3,14 +3,11 @@ mod sha256; use crate::Result; -#[inline] pub fn password(password: &str) -> Result { argon::password(password) } -#[inline] pub fn verify_password(password: &str, password_hash: &str) -> Result<()> { argon::verify_password(password, password_hash) } -#[inline] #[must_use] pub fn calculate_hash(keys: &[&[u8]]) -> Vec { sha256::hash(keys) } diff --git a/src/core/utils/hash/argon.rs b/src/core/utils/hash/argon.rs index 98cef00e..0a1e1e14 100644 --- a/src/core/utils/hash/argon.rs +++ b/src/core/utils/hash/argon.rs @@ -5,7 +5,7 @@ use argon2::{ PasswordVerifier, Version, }; -use crate::{Error, Result}; +use crate::{err, Error, Result}; const M_COST: u32 = Params::DEFAULT_M_COST; // memory size in 1 KiB blocks const T_COST: u32 = Params::DEFAULT_T_COST; // nr of iterations @@ -44,7 +44,7 @@ pub(super) fn verify_password(password: &str, password_hash: &str) -> Result<()> .map_err(map_err) } -fn map_err(e: password_hash::Error) -> Error { Error::Err(e.to_string()) } +fn map_err(e: password_hash::Error) -> Error { err!("{e}") } #[cfg(test)] mod tests { diff --git a/src/core/utils/hash/sha256.rs b/src/core/utils/hash/sha256.rs index 6a1f1879..b2e5a94c 100644 --- a/src/core/utils/hash/sha256.rs +++ b/src/core/utils/hash/sha256.rs @@ -1,6 +1,6 @@ use ring::{digest, digest::SHA256}; -#[tracing::instrument(skip_all)] +#[tracing::instrument(skip_all, level = "debug")] pub(super) fn hash(keys: &[&[u8]]) -> Vec { // We only hash the pdu's event ids, not the whole pdu let bytes = keys.join(&0xFF); diff --git a/src/core/utils/html.rs b/src/core/utils/html.rs index 3b44a31b..fe07b2dd 100644 --- a/src/core/utils/html.rs +++ b/src/core/utils/html.rs @@ -6,6 +6,7 @@ pub struct Escape<'a>(pub &'a str); /// Copied from librustdoc: /// * +#[allow(clippy::string_slice)] impl fmt::Display for Escape<'_> { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { // Because the internet is always right, turns out there's not that many @@ -26,7 +27,7 @@ impl fmt::Display for Escape<'_> { fmt.write_str(s)?; // NOTE: we only expect single byte characters here - which is fine as long as // we only match single byte characters - last = i + 1; + last = i.saturating_add(1); } if last < s.len() { diff --git a/src/core/utils/math.rs b/src/core/utils/math.rs new file mode 100644 index 00000000..96ac6dc2 --- /dev/null +++ b/src/core/utils/math.rs @@ -0,0 +1,88 @@ +use std::{cmp, time::Duration}; + +pub use checked_ops::checked_ops; + +use crate::{Err, Error, Result}; + +/// Checked arithmetic expression. Returns a Result +#[macro_export] +macro_rules! checked { + ($($input:tt)*) => { + $crate::utils::math::checked_ops!($($input)*) + .ok_or_else(|| $crate::err!(Arithmetic("operation overflowed or result invalid"))) + } +} + +/// in release-mode. Use for performance when the expression is obviously safe. +/// The check remains in debug-mode for regression analysis. +#[cfg(not(debug_assertions))] +#[macro_export] +macro_rules! validated { + ($($input:tt)*) => { + //#[allow(clippy::arithmetic_side_effects)] { + //Some($($input)*) + // .ok_or_else(|| $crate::err!(Arithmetic("this error should never been seen"))) + //} + + //NOTE: remove me when stmt_expr_attributes is stable + $crate::checked!($($input)*) + } +} + +#[cfg(debug_assertions)] +#[macro_export] +macro_rules! validated { + ($($input:tt)*) => { $crate::checked!($($input)*) } +} + +/// Returns false if the exponential backoff has expired based on the inputs +#[inline] +#[must_use] +pub fn continue_exponential_backoff_secs(min: u64, max: u64, elapsed: Duration, tries: u32) -> bool { + let min = Duration::from_secs(min); + let max = Duration::from_secs(max); + continue_exponential_backoff(min, max, elapsed, tries) +} + +/// Returns false if the exponential backoff has expired based on the inputs +#[inline] +#[must_use] +pub fn continue_exponential_backoff(min: Duration, max: Duration, elapsed: Duration, tries: u32) -> bool { + let min = min.saturating_mul(tries).saturating_mul(tries); + let min = cmp::min(min, max); + elapsed < min +} + +#[inline] +#[allow(clippy::as_conversions)] +pub fn usize_from_f64(val: f64) -> Result { + if val < 0.0 { + return Err!(Arithmetic("Converting negative float to unsigned integer")); + } + + //SAFETY: + Ok(unsafe { val.to_int_unchecked::() }) +} + +#[inline] +#[must_use] +pub fn usize_from_ruma(val: ruma::UInt) -> usize { + usize::try_from(val).expect("failed conversion from ruma::UInt to usize") +} + +#[inline] +#[must_use] +pub fn ruma_from_u64(val: u64) -> ruma::UInt { + ruma::UInt::try_from(val).expect("failed conversion from u64 to ruma::UInt") +} + +#[inline] +#[must_use] +pub fn ruma_from_usize(val: usize) -> ruma::UInt { + ruma::UInt::try_from(val).expect("failed conversion from usize to ruma::UInt") +} + +#[inline] +#[must_use] +#[allow(clippy::as_conversions, clippy::cast_possible_truncation)] +pub fn usize_from_u64_truncated(val: u64) -> usize { val as usize } diff --git a/src/core/utils/mod.rs b/src/core/utils/mod.rs index 401775f7..bbd52829 100644 --- a/src/core/utils/mod.rs +++ b/src/core/utils/mod.rs @@ -5,7 +5,9 @@ pub mod defer; pub mod hash; pub mod html; pub mod json; +pub mod math; pub mod mutex_map; +pub mod rand; pub mod string; pub mod sys; mod tests; @@ -18,30 +20,17 @@ pub use debug::slice_truncated as debug_slice_truncated; pub use hash::calculate_hash; pub use html::Escape as HtmlEscape; pub use json::{deserialize_from_str, to_canonical_object}; -pub use mutex_map::MutexMap; -pub use string::{random_string, str_from_bytes, string_from_bytes}; +pub use mutex_map::{Guard as MutexMapGuard, MutexMap}; +pub use rand::string as random_string; +pub use string::{str_from_bytes, string_from_bytes}; pub use sys::available_parallelism; -pub use time::millis_since_unix_epoch; - -use crate::Result; +pub use time::now_millis as millis_since_unix_epoch; pub fn clamp(val: T, min: T, max: T) -> T { cmp::min(cmp::max(val, min), max) } -/// Boilerplate for wraps which are typed to never error. -/// -/// * -#[must_use] -#[inline(always)] -pub fn unwrap_infallible(result: Result) -> T { - match result { - Ok(val) => val, - Err(err) => match err {}, - } -} - #[must_use] pub fn generate_keypair() -> Vec { - let mut value = random_string(8).as_bytes().to_vec(); + let mut value = rand::string(8).as_bytes().to_vec(); value.push(0xFF); value.extend_from_slice( &ruma::signatures::Ed25519KeyPair::generate().expect("Ed25519KeyPair generation always works (?)"), diff --git a/src/core/utils/mutex_map.rs b/src/core/utils/mutex_map.rs index f102487c..c3c51798 100644 --- a/src/core/utils/mutex_map.rs +++ b/src/core/utils/mutex_map.rs @@ -1,20 +1,22 @@ -use std::{hash::Hash, sync::Arc}; +use std::{fmt::Debug, hash::Hash, sync::Arc}; -type Value = tokio::sync::Mutex; -type ArcMutex = Arc>; -type HashMap = std::collections::HashMap>; -type MapMutex = std::sync::Mutex>; -type Map = MapMutex; +use tokio::sync::OwnedMutexGuard as Omg; /// Map of Mutexes pub struct MutexMap { map: Map, } -pub struct Guard { - _guard: tokio::sync::OwnedMutexGuard, +pub struct Guard { + map: Map, + val: Omg, } +type Map = Arc>; +type MapMutex = std::sync::Mutex>; +type HashMap = std::collections::HashMap>; +type Value = Arc>; + impl MutexMap where Key: Send + Hash + Eq + Clone, @@ -23,28 +25,38 @@ where #[must_use] pub fn new() -> Self { Self { - map: Map::::new(HashMap::::new()), + map: Map::new(MapMutex::new(HashMap::new())), } } - pub async fn lock(&self, k: &K) -> Guard + #[tracing::instrument(skip(self), level = "debug")] + pub async fn lock(&self, k: &K) -> Guard where - K: ?Sized + Send + Sync, + K: ?Sized + Send + Sync + Debug, Key: for<'a> From<&'a K>, { let val = self .map .lock() - .expect("map mutex locked") + .expect("locked") .entry(k.into()) .or_default() .clone(); - let guard = val.lock_owned().await; - Guard:: { - _guard: guard, + Guard:: { + map: Arc::clone(&self.map), + val: val.lock_owned().await, } } + + #[must_use] + pub fn contains(&self, k: &Key) -> bool { self.map.lock().expect("locked").contains_key(k) } + + #[must_use] + pub fn is_empty(&self) -> bool { self.map.lock().expect("locked").is_empty() } + + #[must_use] + pub fn len(&self) -> usize { self.map.lock().expect("locked").len() } } impl Default for MutexMap @@ -54,3 +66,14 @@ where { fn default() -> Self { Self::new() } } + +impl Drop for Guard { + fn drop(&mut self) { + if Arc::strong_count(Omg::mutex(&self.val)) <= 2 { + self.map + .lock() + .expect("locked") + .retain(|_, val| !Arc::ptr_eq(val, Omg::mutex(&self.val)) || Arc::strong_count(val) > 2); + } + } +} diff --git a/src/core/utils/rand.rs b/src/core/utils/rand.rs new file mode 100644 index 00000000..b80671eb --- /dev/null +++ b/src/core/utils/rand.rs @@ -0,0 +1,28 @@ +use std::{ + ops::Range, + time::{Duration, SystemTime}, +}; + +use rand::{thread_rng, Rng}; + +pub fn string(length: usize) -> String { + thread_rng() + .sample_iter(&rand::distributions::Alphanumeric) + .take(length) + .map(char::from) + .collect() +} + +#[inline] +#[must_use] +pub fn timepoint_secs(range: Range) -> SystemTime { + SystemTime::now() + .checked_add(secs(range)) + .expect("range does not overflow SystemTime") +} + +#[must_use] +pub fn secs(range: Range) -> Duration { + let mut rng = thread_rng(); + Duration::from_secs(rng.gen_range(range)) +} diff --git a/src/core/utils/string.rs b/src/core/utils/string.rs index 7f6f6531..106d0cb7 100644 --- a/src/core/utils/string.rs +++ b/src/core/utils/string.rs @@ -1,17 +1,58 @@ -use rand::prelude::*; - use crate::Result; -pub fn random_string(length: usize) -> String { - thread_rng() - .sample_iter(&rand::distributions::Alphanumeric) - .take(length) - .map(char::from) - .collect() +pub const EMPTY: &str = ""; + +/// Constant expression to bypass format! if the argument is a string literal +/// but not a format string. If the literal is a format string then String is +/// returned otherwise the input (i.e. &'static str) is returned. If multiple +/// arguments are provided the first is assumed to be a format string. +#[macro_export] +macro_rules! format_maybe { + ($s:literal) => { + if $crate::is_format!($s) { std::format!($s).into() } else { $s.into() } + }; + + ($($args:expr),*) => { + std::format!($($args),*).into() + }; +} + +/// Constant expression to decide if a literal is a format string. Note: could +/// use some improvement. +#[macro_export] +macro_rules! is_format { + ($s:literal) => { + ::const_str::contains!($s, "{") && ::const_str::contains!($s, "}") + }; +} + +/// Find the common prefix from a collection of strings and return a slice +/// ``` +/// use conduit_core::utils::string::common_prefix; +/// let input = ["conduwuit", "conduit", "construct"]; +/// common_prefix(&input) == "con"; +/// ``` +#[must_use] +#[allow(clippy::string_slice)] +pub fn common_prefix<'a>(choice: &'a [&str]) -> &'a str { + choice.first().map_or(EMPTY, move |best| { + choice.iter().skip(1).fold(*best, |best, choice| { + &best[0..choice + .char_indices() + .zip(best.char_indices()) + .take_while(|&(a, b)| a == b) + .count()] + }) + }) +} + +#[inline] +#[must_use] +pub fn split_once_infallible<'a>(input: &'a str, delim: &'_ str) -> (&'a str, &'a str) { + input.split_once(delim).unwrap_or((input, EMPTY)) } /// Parses the bytes into a string. -#[inline] pub fn string_from_bytes(bytes: &[u8]) -> Result { let str: &str = str_from_bytes(bytes)?; Ok(str.to_owned()) diff --git a/src/core/utils/tests.rs b/src/core/utils/tests.rs index b226bd41..43968947 100644 --- a/src/core/utils/tests.rs +++ b/src/core/utils/tests.rs @@ -35,3 +35,102 @@ fn increment_wrap() { let res = u64::from_be_bytes(bytes); assert_eq!(res, 0); } + +#[test] +fn common_prefix() { + use utils::string; + + let input = ["conduwuit", "conduit", "construct"]; + let output = string::common_prefix(&input); + assert_eq!(output, "con"); +} + +#[test] +fn common_prefix_empty() { + use utils::string; + + let input = ["abcdefg", "hijklmn", "opqrstu"]; + let output = string::common_prefix(&input); + assert_eq!(output, ""); +} + +#[test] +fn common_prefix_none() { + use utils::string; + + let input = []; + let output = string::common_prefix(&input); + assert_eq!(output, ""); +} + +#[test] +fn checked_add() { + use crate::checked; + + let a = 1234; + let res = checked!(a + 1).unwrap(); + assert_eq!(res, 1235); +} + +#[test] +#[should_panic(expected = "overflow")] +fn checked_add_overflow() { + use crate::checked; + + let a = u64::MAX; + let res = checked!(a + 1).expect("overflow"); + assert_eq!(res, 0); +} + +#[tokio::test] +async fn mutex_map_cleanup() { + use crate::utils::MutexMap; + + let map = MutexMap::::new(); + + let lock = map.lock("foo").await; + assert!(!map.is_empty(), "map must not be empty"); + + drop(lock); + assert!(map.is_empty(), "map must be empty"); +} + +#[tokio::test] +async fn mutex_map_contend() { + use std::sync::Arc; + + use tokio::sync::Barrier; + + use crate::utils::MutexMap; + + let map = Arc::new(MutexMap::::new()); + let seq = Arc::new([Barrier::new(2), Barrier::new(2)]); + let str = "foo".to_owned(); + + let seq_ = seq.clone(); + let map_ = map.clone(); + let str_ = str.clone(); + let join_a = tokio::spawn(async move { + let _lock = map_.lock(&str_).await; + assert!(!map_.is_empty(), "A0 must not be empty"); + seq_[0].wait().await; + assert!(map_.contains(&str_), "A1 must contain key"); + }); + + let seq_ = seq.clone(); + let map_ = map.clone(); + let str_ = str.clone(); + let join_b = tokio::spawn(async move { + let _lock = map_.lock(&str_).await; + assert!(!map_.is_empty(), "B0 must not be empty"); + seq_[1].wait().await; + assert!(map_.contains(&str_), "B1 must contain key"); + }); + + seq[0].wait().await; + assert!(map.contains(&str), "Must contain key"); + seq[1].wait().await; + + tokio::try_join!(join_b, join_a).expect("joined"); + assert!(map.is_empty(), "Must be empty"); +} diff --git a/src/core/utils/time.rs b/src/core/utils/time.rs index d7cab1fb..9a31632e 100644 --- a/src/core/utils/time.rs +++ b/src/core/utils/time.rs @@ -1,12 +1,12 @@ -use std::time::{SystemTime, UNIX_EPOCH}; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; #[inline] #[must_use] -#[allow(clippy::as_conversions)] -pub fn millis_since_unix_epoch() -> u64 { - SystemTime::now() - .duration_since(UNIX_EPOCH) - .expect("time is valid") +#[allow(clippy::as_conversions, clippy::cast_possible_truncation)] +pub fn now_millis() -> u64 { + UNIX_EPOCH + .elapsed() + .expect("positive duration after epoch") .as_millis() as u64 } @@ -18,3 +18,88 @@ pub fn rfc2822_from_seconds(epoch: i64) -> String { .unwrap_or_default() .to_rfc2822() } + +#[must_use] +pub fn format(ts: SystemTime, str: &str) -> String { + use chrono::{DateTime, Utc}; + + let dt: DateTime = ts.into(); + dt.format(str).to_string() +} + +#[must_use] +#[allow(clippy::as_conversions, clippy::cast_possible_truncation, clippy::cast_sign_loss)] +pub fn pretty(d: Duration) -> String { + use Unit::*; + + let fmt = |w, f, u| format!("{w}.{f} {u}"); + let gen64 = |w, f, u| fmt(w, (f * 100.0) as u32, u); + let gen128 = |w, f, u| gen64(u64::try_from(w).expect("u128 to u64"), f, u); + match whole_and_frac(d) { + (Days(whole), frac) => gen64(whole, frac, "days"), + (Hours(whole), frac) => gen64(whole, frac, "hours"), + (Mins(whole), frac) => gen64(whole, frac, "minutes"), + (Secs(whole), frac) => gen64(whole, frac, "seconds"), + (Millis(whole), frac) => gen128(whole, frac, "milliseconds"), + (Micros(whole), frac) => gen128(whole, frac, "microseconds"), + (Nanos(whole), frac) => gen128(whole, frac, "nanoseconds"), + } +} + +/// Return a pair of (whole part, frac part) from a duration where. The whole +/// part is the largest Unit containing a non-zero value, the frac part is a +/// rational remainder left over. +#[must_use] +#[allow(clippy::as_conversions, clippy::cast_precision_loss)] +pub fn whole_and_frac(d: Duration) -> (Unit, f64) { + use Unit::*; + + let whole = whole_unit(d); + ( + whole, + match whole { + Days(_) => (d.as_secs() % 86_400) as f64 / 86_400.0, + Hours(_) => (d.as_secs() % 3_600) as f64 / 3_600.0, + Mins(_) => (d.as_secs() % 60) as f64 / 60.0, + Secs(_) => f64::from(d.subsec_millis()) / 1000.0, + Millis(_) => f64::from(d.subsec_micros()) / 1000.0, + Micros(_) => f64::from(d.subsec_nanos()) / 1000.0, + Nanos(_) => 0.0, + }, + ) +} + +/// Return the largest Unit which represents the duration. The value is +/// rounded-down, but never zero. +#[must_use] +pub fn whole_unit(d: Duration) -> Unit { + use Unit::*; + + match d.as_secs() { + 86_400.. => Days(d.as_secs() / 86_400), + 3_600..=86_399 => Hours(d.as_secs() / 3_600), + 60..=3_599 => Mins(d.as_secs() / 60), + + _ => match d.as_micros() { + 1_000_000.. => Secs(d.as_secs()), + 1_000..=999_999 => Millis(d.subsec_millis().into()), + + _ => match d.as_nanos() { + 1_000.. => Micros(d.subsec_micros().into()), + + _ => Nanos(d.subsec_nanos().into()), + }, + }, + } +} + +#[derive(Eq, PartialEq, Clone, Copy, Debug)] +pub enum Unit { + Days(u64), + Hours(u64), + Mins(u64), + Secs(u64), + Millis(u128), + Micros(u128), + Nanos(u128), +} diff --git a/src/core/version.rs b/src/core/version.rs index bd89bee2..2876cea8 100644 --- a/src/core/version.rs +++ b/src/core/version.rs @@ -27,5 +27,11 @@ fn init_user_agent() -> String { format!("{}/{}", name(), version()) } fn init_version() -> String { option_env!("CONDUWUIT_VERSION_EXTRA") .or(option_env!("CONDUIT_VERSION_EXTRA")) - .map_or(SEMANTIC.to_owned(), |extra| format!("{SEMANTIC} ({extra})")) + .map_or(SEMANTIC.to_owned(), |extra| { + if extra.is_empty() { + SEMANTIC.to_owned() + } else { + format!("{SEMANTIC} ({extra})") + } + }) } diff --git a/src/database/Cargo.toml b/src/database/Cargo.toml index 6e95236b..34d98416 100644 --- a/src/database/Cargo.toml +++ b/src/database/Cargo.toml @@ -36,8 +36,8 @@ zstd_compression = [ [dependencies] conduit-core.workspace = true +const-str.workspace = true log.workspace = true -ruma.workspace = true rust-rocksdb.workspace = true tokio.workspace = true tracing.workspace = true diff --git a/src/database/cork.rs b/src/database/cork.rs index b7e1f60f..26c520a2 100644 --- a/src/database/cork.rs +++ b/src/database/cork.rs @@ -9,6 +9,7 @@ pub struct Cork { } impl Cork { + #[inline] pub(super) fn new(db: &Arc, flush: bool, sync: bool) -> Self { db.cork(); Self { diff --git a/src/database/database.rs b/src/database/database.rs index 2c6c6808..44bb655c 100644 --- a/src/database/database.rs +++ b/src/database/database.rs @@ -6,7 +6,7 @@ use crate::{cork::Cork, maps, maps::Maps, Engine, Map}; pub struct Database { pub db: Arc, - pub map: Maps, + map: Maps, } impl Database { @@ -19,12 +19,15 @@ impl Database { }) } + #[inline] #[must_use] pub fn cork(&self) -> Cork { Cork::new(&self.db, false, false) } + #[inline] #[must_use] pub fn cork_and_flush(&self) -> Cork { Cork::new(&self.db, true, false) } + #[inline] #[must_use] pub fn cork_and_sync(&self) -> Cork { Cork::new(&self.db, true, true) } } diff --git a/src/database/engine.rs b/src/database/engine.rs index fe16d29f..d02ecf58 100644 --- a/src/database/engine.rs +++ b/src/database/engine.rs @@ -30,6 +30,7 @@ pub struct Engine { pub(crate) type Db = DBWithThreadMode; impl Engine { + #[tracing::instrument(skip_all)] pub(crate) fn open(server: &Arc) -> Result> { let config = &server.config; let cache_capacity_bytes = config.db_cache_capacity_mb * 1024.0 * 1024.0; @@ -51,7 +52,7 @@ impl Engine { if config.rocksdb_repair { warn!("Starting database repair. This may take a long time..."); if let Err(e) = Db::repair(&db_opts, &config.database_path) { - error!("Repair failed: {:?}", e); + error!("Repair failed: {e:?}"); } } @@ -76,9 +77,9 @@ impl Engine { let db = res.or_else(or_else)?; info!( - "Opened database at sequence number {} in {:?}", - db.latest_sequence_number(), - load_time.elapsed() + sequence = %db.latest_sequence_number(), + time = ?load_time.elapsed(), + "Opened database." ); Ok(Arc::new(Self { @@ -93,15 +94,16 @@ impl Engine { })) } + #[tracing::instrument(skip(self))] pub(crate) fn open_cf(&self, name: &str) -> Result>> { let mut cfs = self.cfs.lock().expect("locked"); if !cfs.contains(name) { - debug!("Creating new column family in database: {}", name); + debug!("Creating new column family in database: {name}"); let mut col_cache = self.col_cache.write().expect("locked"); let opts = cf_options(&self.server.config, name, self.opts.clone(), &mut col_cache); if let Err(e) = self.db.create_cf(name, &opts) { - error!("Failed to create new column family: {e}"); + error!(?name, "Failed to create new column family: {e}"); return or_else(e); } @@ -121,6 +123,7 @@ impl Engine { pub fn sync(&self) -> Result<()> { result(DBCommon::flush_wal(&self.db, true)) } + #[inline] pub fn corked(&self) -> bool { self.corks.load(std::sync::atomic::Ordering::Relaxed) > 0 } pub(crate) fn cork(&self) { @@ -133,34 +136,34 @@ impl Engine { .fetch_sub(1, std::sync::atomic::Ordering::Relaxed); } - #[allow(clippy::as_conversions, clippy::cast_sign_loss, clippy::cast_possible_truncation)] pub fn memory_usage(&self) -> Result { let mut res = String::new(); let stats = get_memory_usage_stats(Some(&[&self.db]), Some(&[&self.row_cache])).or_else(or_else)?; + let mibs = |input| f64::from(u32::try_from(input / 1024).unwrap_or(0)) / 1024.0; writeln!( res, "Memory buffers: {:.2} MiB\nPending write: {:.2} MiB\nTable readers: {:.2} MiB\nRow cache: {:.2} MiB", - stats.mem_table_total as f64 / 1024.0 / 1024.0, - stats.mem_table_unflushed as f64 / 1024.0 / 1024.0, - stats.mem_table_readers_total as f64 / 1024.0 / 1024.0, - self.row_cache.get_usage() as f64 / 1024.0 / 1024.0, - ) - .expect("should be able to write to string buffer"); + mibs(stats.mem_table_total), + mibs(stats.mem_table_unflushed), + mibs(stats.mem_table_readers_total), + mibs(u64::try_from(self.row_cache.get_usage())?), + )?; for (name, cache) in &*self.col_cache.read().expect("locked") { - writeln!(res, "{} cache: {:.2} MiB", name, cache.get_usage() as f64 / 1024.0 / 1024.0,) - .expect("should be able to write to string buffer"); + writeln!(res, "{name} cache: {:.2} MiB", mibs(u64::try_from(cache.get_usage())?))?; } Ok(res) } + #[tracing::instrument(skip(self), level = "debug")] pub fn cleanup(&self) -> Result<()> { debug!("Running flush_opt"); let flushoptions = rocksdb::FlushOptions::default(); result(DBCommon::flush_opt(&self.db, &flushoptions)) } + #[tracing::instrument(skip(self))] pub fn backup(&self) -> Result<(), Box> { let config = &self.server.config; let path = config.database_backup_path.as_ref(); @@ -213,8 +216,7 @@ impl Engine { rfc2822_from_seconds(info.timestamp), info.size, info.num_files, - ) - .expect("should be able to write to string buffer"); + )?; } Ok(res) @@ -225,16 +227,16 @@ impl Engine { Err(e) => Ok(String::from(e)), Ok(files) => { let mut res = String::new(); - writeln!(res, "| lev | sst | keys | dels | size | column |").expect("written to string buffer"); - writeln!(res, "| ---: | :--- | ---: | ---: | ---: | :--- |").expect("written to string buffer"); + writeln!(res, "| lev | sst | keys | dels | size | column |")?; + writeln!(res, "| ---: | :--- | ---: | ---: | ---: | :--- |")?; for file in files { writeln!( res, "| {} | {:<13} | {:7}+ | {:4}- | {:9} | {} |", file.level, file.name, file.num_entries, file.num_deletions, file.size, file.column_family_name, - ) - .expect("should be able to writeln to string buffer"); + )?; } + Ok(res) }, } @@ -242,6 +244,7 @@ impl Engine { } impl Drop for Engine { + #[cold] fn drop(&mut self) { const BLOCKING: bool = true; diff --git a/src/database/handle.rs b/src/database/handle.rs index 787a5cea..0b45a75f 100644 --- a/src/database/handle.rs +++ b/src/database/handle.rs @@ -17,9 +17,11 @@ impl<'a> From> for Handle<'a> { impl Deref for Handle<'_> { type Target = [u8]; + #[inline] fn deref(&self) -> &Self::Target { &self.val } } impl AsRef<[u8]> for Handle<'_> { + #[inline] fn as_ref(&self) -> &[u8] { &self.val } } diff --git a/src/database/map.rs b/src/database/map.rs index 73cf5107..1b35a72a 100644 --- a/src/database/map.rs +++ b/src/database/map.rs @@ -199,6 +199,7 @@ impl<'a> IntoIterator for &'a Map { type IntoIter = Box + Send + 'a>; type Item = OwnedKeyValPair; + #[inline] fn into_iter(self) -> Self::IntoIter { self.iter() } } @@ -232,7 +233,7 @@ fn open(db: &Arc, name: &str) -> Result> { // closing the database (dropping `Engine`). Since `Arc` is a sibling // member along with this handle in `Map`, that is prevented. Ok(unsafe { - Arc::decrement_strong_count(cf_ptr); + Arc::increment_strong_count(cf_ptr); Arc::from_raw(cf_ptr) }) } diff --git a/src/database/maps.rs b/src/database/maps.rs index 1e09041c..de78eaed 100644 --- a/src/database/maps.rs +++ b/src/database/maps.rs @@ -8,6 +8,7 @@ pub type Maps = BTreeMap>; pub(crate) fn open(db: &Arc) -> Result { open_list(db, MAPS) } +#[tracing::instrument(skip_all, level = "debug")] pub(crate) fn open_list(db: &Arc, maps: &[&str]) -> Result { Ok(maps .iter() diff --git a/src/database/opts.rs b/src/database/opts.rs index 1b9c163f..d2236454 100644 --- a/src/database/opts.rs +++ b/src/database/opts.rs @@ -310,7 +310,7 @@ fn set_table_with_shared_cache( } fn cache_size(config: &Config, base_size: u32, entity_size: usize) -> usize { - let ents = f64::from(base_size) * config.conduit_cache_capacity_modifier; + let ents = f64::from(base_size) * config.cache_capacity_modifier; #[allow(clippy::as_conversions, clippy::cast_sign_loss, clippy::cast_possible_truncation)] (ents as usize) diff --git a/src/database/slice.rs b/src/database/slice.rs index 770f1957..448d969d 100644 --- a/src/database/slice.rs +++ b/src/database/slice.rs @@ -11,21 +11,19 @@ pub type Key = [Byte]; pub(crate) type Byte = u8; impl OwnedKeyVal { - #[inline] #[must_use] pub fn as_slice(&self) -> KeyVal<'_> { KeyVal(&self.0, &self.1) } - #[inline] #[must_use] pub fn to_tuple(self) -> OwnedKeyValPair { (self.0, self.1) } } impl From for OwnedKeyVal { - #[inline] fn from((key, val): OwnedKeyValPair) -> Self { Self(key, val) } } impl From<&KeyVal<'_>> for OwnedKeyVal { + #[inline] fn from(slice: &KeyVal<'_>) -> Self { slice.to_owned() } } @@ -34,7 +32,6 @@ impl From> for OwnedKeyVal { } impl From for OwnedKeyValPair { - #[inline] fn from(val: OwnedKeyVal) -> Self { val.to_tuple() } } @@ -43,22 +40,18 @@ impl KeyVal<'_> { #[must_use] pub fn to_owned(&self) -> OwnedKeyVal { OwnedKeyVal::from(self) } - #[inline] #[must_use] pub fn as_tuple(&self) -> KeyValPair<'_> { (self.0, self.1) } } impl<'a> From<&'a OwnedKeyVal> for KeyVal<'a> { - #[inline] fn from(owned: &'a OwnedKeyVal) -> Self { owned.as_slice() } } impl<'a> From<&'a OwnedKeyValPair> for KeyVal<'a> { - #[inline] fn from((key, val): &'a OwnedKeyValPair) -> Self { KeyVal(key.as_slice(), val.as_slice()) } } impl<'a> From> for KeyVal<'a> { - #[inline] fn from((key, val): KeyValPair<'a>) -> Self { KeyVal(key, val) } } diff --git a/src/database/util.rs b/src/database/util.rs index 513cedc8..f0ccbcbe 100644 --- a/src/database/util.rs +++ b/src/database/util.rs @@ -1,4 +1,4 @@ -use conduit::Result; +use conduit::{err, Result}; #[inline] pub(crate) fn result(r: std::result::Result) -> Result { @@ -10,4 +10,7 @@ pub(crate) fn and_then(t: T) -> Result { Ok(t) } pub(crate) fn or_else(e: rocksdb::Error) -> Result { Err(map_err(e)) } -pub(crate) fn map_err(e: rocksdb::Error) -> conduit::Error { conduit::Error::Database(e.into_string()) } +pub(crate) fn map_err(e: rocksdb::Error) -> conduit::Error { + let string = e.into_string(); + err!(Database(error!("{string}"))) +} diff --git a/src/main/Cargo.toml b/src/main/Cargo.toml index fa0e5874..8dc2a34d 100644 --- a/src/main/Cargo.toml +++ b/src/main/Cargo.toml @@ -147,6 +147,7 @@ log.workspace = true tracing.workspace = true tracing-subscriber.workspace = true clap.workspace = true +const-str.workspace = true opentelemetry.workspace = true opentelemetry.optional = true diff --git a/src/main/main.rs b/src/main/main.rs index 23f53b4e..959e8610 100644 --- a/src/main/main.rs +++ b/src/main/main.rs @@ -20,7 +20,7 @@ use tokio::runtime; const WORKER_NAME: &str = "conduwuit:worker"; const WORKER_MIN: usize = 2; -const WORKER_KEEPALIVE_MS: u64 = 2500; +const WORKER_KEEPALIVE: u64 = 36; fn main() -> Result<(), Error> { let args = clap::parse(); @@ -29,7 +29,7 @@ fn main() -> Result<(), Error> { .enable_time() .thread_name(WORKER_NAME) .worker_threads(cmp::max(WORKER_MIN, available_parallelism())) - .thread_keep_alive(Duration::from_millis(WORKER_KEEPALIVE_MS)) + .thread_keep_alive(Duration::from_secs(WORKER_KEEPALIVE)) .build() .expect("built runtime"); diff --git a/src/main/restart.rs b/src/main/restart.rs index 00de9f21..009e5228 100644 --- a/src/main/restart.rs +++ b/src/main/restart.rs @@ -4,6 +4,7 @@ use std::{env, os::unix::process::CommandExt, process::Command}; use conduit::{debug, info, utils}; +#[cold] pub(super) fn restart() -> ! { // SAFETY: We have allowed an override for the case where the current_exe() has // been replaced or removed. By default the server will fail to restart if the diff --git a/src/main/sentry.rs b/src/main/sentry.rs index 6ed4bb8a..04ad8654 100644 --- a/src/main/sentry.rs +++ b/src/main/sentry.rs @@ -1,18 +1,34 @@ #![cfg(feature = "sentry_telemetry")] -use std::{str::FromStr, sync::Arc}; - -use conduit::{config::Config, trace}; -use sentry::{ - types::{protocol::v7::Event, Dsn}, - Breadcrumb, ClientOptions, +use std::{ + str::FromStr, + sync::{Arc, OnceLock}, }; +use conduit::{config::Config, debug, trace}; +use sentry::{ + types::{ + protocol::v7::{Context, Event}, + Dsn, + }, + Breadcrumb, ClientOptions, Level, +}; + +static SEND_PANIC: OnceLock = OnceLock::new(); +static SEND_ERROR: OnceLock = OnceLock::new(); + pub(crate) fn init(config: &Config) -> Option { config.sentry.then(|| sentry::init(options(config))) } fn options(config: &Config) -> ClientOptions { + SEND_PANIC + .set(config.sentry_send_panic) + .expect("SEND_PANIC was not previously set"); + SEND_ERROR + .set(config.sentry_send_error) + .expect("SEND_ERROR was not previously set"); + let dsn = config .sentry_endpoint .as_ref() @@ -28,6 +44,7 @@ fn options(config: &Config) -> ClientOptions { debug: cfg!(debug_assertions), release: sentry::release_name!(), user_agent: conduit::version::user_agent().into(), + attach_stacktrace: config.sentry_attach_stacktrace, before_send: Some(Arc::new(before_send)), before_breadcrumb: Some(Arc::new(before_breadcrumb)), ..Default::default() @@ -35,11 +52,40 @@ fn options(config: &Config) -> ClientOptions { } fn before_send(event: Event<'static>) -> Option> { - trace!("Sending sentry event: {event:?}"); + if event.exception.iter().any(|e| e.ty == "panic") && !SEND_PANIC.get().unwrap_or(&true) { + return None; + } + + if event.level == Level::Error { + if !SEND_ERROR.get().unwrap_or(&true) { + return None; + } + + if cfg!(debug_assertions) { + return None; + } + + //NOTE: we can enable this to specify error!(sentry = true, ...) + if let Some(Context::Other(context)) = event.contexts.get("Rust Tracing Fields") { + if !context.contains_key("sentry") { + //return None; + } + } + } + + if event.level == Level::Fatal { + trace!("{event:#?}"); + } + + debug!("Sending sentry event: {event:?}"); Some(event) } fn before_breadcrumb(crumb: Breadcrumb) -> Option { - trace!("Adding sentry breadcrumb: {crumb:?}"); + if crumb.ty == "log" && crumb.level == Level::Debug { + return None; + } + + trace!("Sentry breadcrumb: {crumb:?}"); Some(crumb) } diff --git a/src/main/server.rs b/src/main/server.rs index f72b3ef3..73c06f0c 100644 --- a/src/main/server.rs +++ b/src/main/server.rs @@ -27,7 +27,7 @@ impl Server { #[cfg(feature = "sentry_telemetry")] let sentry_guard = crate::sentry::init(&config); - let (tracing_reload_handle, tracing_flame_guard, capture) = crate::tracing::init(&config); + let (tracing_reload_handle, tracing_flame_guard, capture) = crate::tracing::init(&config)?; config.check()?; diff --git a/src/main/tracing.rs b/src/main/tracing.rs index bbfe4dc4..0217f38a 100644 --- a/src/main/tracing.rs +++ b/src/main/tracing.rs @@ -1,10 +1,10 @@ use std::sync::Arc; use conduit::{ - config, config::Config, - debug_warn, - log::{capture, LogLevelReloadHandles, ReloadHandle}, + debug_warn, err, + log::{capture, LogLevelReloadHandles}, + Result, }; use tracing_subscriber::{layer::SubscriberExt, reload, EnvFilter, Layer, Registry}; @@ -14,47 +14,38 @@ pub(crate) type TracingFlameGuard = Option (LogLevelReloadHandles, TracingFlameGuard, Arc) { - let fmt_layer = tracing_subscriber::fmt::Layer::new(); - let filter_layer = match EnvFilter::try_new(&config.log) { - Ok(s) => s, - Err(e) => { - eprintln!("It looks like your config is invalid. The following error occured while parsing it: {e}"); - EnvFilter::try_new(config::default_log()).expect("failed to set default EnvFilter") - }, - }; +pub(crate) fn init(config: &Config) -> Result<(LogLevelReloadHandles, TracingFlameGuard, Arc)> { + let reload_handles = LogLevelReloadHandles::default(); - let mut reload_handles = Vec:: + Send + Sync>>::new(); - let (fmt_reload_filter, fmt_reload_handle) = reload::Layer::new(filter_layer.clone()); - reload_handles.push(Box::new(fmt_reload_handle)); - - let subscriber = Registry::default().with(fmt_layer.with_filter(fmt_reload_filter)); + let console_filter = EnvFilter::try_new(&config.log).map_err(|e| err!(Config("log", "{e}.")))?; + let console_layer = tracing_subscriber::fmt::Layer::new(); + let (console_reload_filter, console_reload_handle) = reload::Layer::new(console_filter.clone()); + reload_handles.add("console", Box::new(console_reload_handle)); let cap_state = Arc::new(capture::State::new()); let cap_layer = capture::Layer::new(&cap_state); - let subscriber = subscriber.with(cap_layer); + + let subscriber = Registry::default() + .with(console_layer.with_filter(console_reload_filter)) + .with(cap_layer); #[cfg(feature = "sentry_telemetry")] let subscriber = { + let sentry_filter = + EnvFilter::try_new(&config.sentry_filter).map_err(|e| err!(Config("sentry_filter", "{e}.")))?; let sentry_layer = sentry_tracing::layer(); - let (sentry_reload_filter, sentry_reload_handle) = reload::Layer::new(filter_layer.clone()); - reload_handles.push(Box::new(sentry_reload_handle)); + let (sentry_reload_filter, sentry_reload_handle) = reload::Layer::new(sentry_filter); + reload_handles.add("sentry", Box::new(sentry_reload_handle)); subscriber.with(sentry_layer.with_filter(sentry_reload_filter)) }; #[cfg(feature = "perf_measurements")] let (subscriber, flame_guard) = { let (flame_layer, flame_guard) = if config.tracing_flame { - let flame_filter = match EnvFilter::try_new(&config.tracing_flame_filter) { - Ok(flame_filter) => flame_filter, - Err(e) => panic!("tracing_flame_filter config value is invalid: {e}"), - }; - - let (flame_layer, flame_guard) = - match tracing_flame::FlameLayer::with_file(&config.tracing_flame_output_path) { - Ok(ok) => ok, - Err(e) => panic!("failed to initialize tracing-flame: {e}"), - }; + let flame_filter = EnvFilter::try_new(&config.tracing_flame_filter) + .map_err(|e| err!(Config("tracing_flame_filter", "{e}.")))?; + let (flame_layer, flame_guard) = tracing_flame::FlameLayer::with_file(&config.tracing_flame_output_path) + .map_err(|e| err!(Config("tracing_flame_output_path", "{e}.")))?; let flame_layer = flame_layer .with_empty_samples(false) .with_filter(flame_filter); @@ -63,21 +54,20 @@ pub(crate) fn init(config: &Config) -> (LogLevelReloadHandles, TracingFlameGuard (None, None) }; - let jaeger_layer = if config.allow_jaeger { + let jaeger_filter = + EnvFilter::try_new(&config.jaeger_filter).map_err(|e| err!(Config("jaeger_filter", "{e}.")))?; + let jaeger_layer = config.allow_jaeger.then(|| { opentelemetry::global::set_text_map_propagator(opentelemetry_jaeger::Propagator::new()); let tracer = opentelemetry_jaeger::new_agent_pipeline() .with_auto_split_batch(true) .with_service_name("conduwuit") .install_batch(opentelemetry_sdk::runtime::Tokio) - .unwrap(); + .expect("jaeger agent pipeline"); let telemetry = tracing_opentelemetry::layer().with_tracer(tracer); - - let (jaeger_reload_filter, jaeger_reload_handle) = reload::Layer::new(filter_layer.clone()); - reload_handles.push(Box::new(jaeger_reload_handle)); + let (jaeger_reload_filter, jaeger_reload_handle) = reload::Layer::new(jaeger_filter.clone()); + reload_handles.add("jaeger", Box::new(jaeger_reload_handle)); Some(telemetry.with_filter(jaeger_reload_filter)) - } else { - None - }; + }); let subscriber = subscriber.with(flame_layer).with(jaeger_layer); (subscriber, flame_guard) @@ -87,7 +77,7 @@ pub(crate) fn init(config: &Config) -> (LogLevelReloadHandles, TracingFlameGuard #[cfg_attr(not(feature = "perf_measurements"), allow(clippy::let_unit_value))] let flame_guard = (); - let ret = (LogLevelReloadHandles::new(reload_handles), flame_guard, cap_state); + let ret = (reload_handles, flame_guard, cap_state); // Enable the tokio console. This is slightly kludgy because we're judggling // compile-time and runtime conditions to elide it, each of those changing the @@ -100,7 +90,7 @@ pub(crate) fn init(config: &Config) -> (LogLevelReloadHandles, TracingFlameGuard .spawn(); set_global_default(subscriber.with(console_layer)); - return ret; + return Ok(ret); } set_global_default(subscriber); @@ -111,7 +101,7 @@ pub(crate) fn init(config: &Config) -> (LogLevelReloadHandles, TracingFlameGuard debug_warn!("{console_disabled_reason}"); } - ret + Ok(ret) } fn tokio_console_enabled(config: &Config) -> (bool, &'static str) { diff --git a/src/router/Cargo.toml b/src/router/Cargo.toml index 2fac3858..38e6adc7 100644 --- a/src/router/Cargo.toml +++ b/src/router/Cargo.toml @@ -54,19 +54,17 @@ axum.workspace = true conduit-admin.workspace = true conduit-api.workspace = true conduit-core.workspace = true -conduit-database.workspace = true conduit-service.workspace = true +const-str.workspace = true log.workspace = true tokio.workspace = true tower.workspace = true tracing.workspace = true bytes.workspace = true -clap.workspace = true http-body-util.workspace = true http.workspace = true hyper.workspace = true hyper-util.workspace = true -regex.workspace = true ruma.workspace = true sentry.optional = true sentry-tower.optional = true diff --git a/src/router/layers.rs b/src/router/layers.rs index 4fe35164..67342eb3 100644 --- a/src/router/layers.rs +++ b/src/router/layers.rs @@ -1,11 +1,11 @@ -use std::{any::Any, io, sync::Arc, time::Duration}; +use std::{any::Any, sync::Arc, time::Duration}; use axum::{ extract::{DefaultBodyLimit, MatchedPath}, Router, }; use axum_client_ip::SecureClientIpSource; -use conduit::Server; +use conduit::{error, Result, Server}; use http::{ header::{self, HeaderName}, HeaderValue, Method, StatusCode, @@ -22,11 +22,19 @@ use tracing::Level; use crate::{request, router}; -const CONDUWUIT_CSP: &str = "sandbox; default-src 'none'; font-src 'none'; script-src 'none'; frame-ancestors 'none'; \ - form-action 'none'; base-uri 'none';"; -const CONDUWUIT_PERMISSIONS_POLICY: &str = "interest-cohort=(),browsing-topics=()"; +const CONDUWUIT_CSP: &[&str] = &[ + "sandbox", + "default-src 'none'", + "font-src 'none'", + "script-src 'none'", + "frame-ancestors 'none'", + "form-action 'none'", + "base-uri 'none'", +]; -pub(crate) fn build(server: &Arc) -> io::Result { +const CONDUWUIT_PERMISSIONS_POLICY: &[&str] = &["interest-cohort=()", "browsing-topics=()"]; + +pub(crate) fn build(server: &Arc) -> Result { let layers = ServiceBuilder::new(); #[cfg(feature = "sentry_telemetry")] @@ -65,11 +73,11 @@ pub(crate) fn build(server: &Arc) -> io::Result { )) .layer(SetResponseHeaderLayer::if_not_present( HeaderName::from_static("permissions-policy"), - HeaderValue::from_static(CONDUWUIT_PERMISSIONS_POLICY), + HeaderValue::from_str(&CONDUWUIT_PERMISSIONS_POLICY.join(","))?, )) .layer(SetResponseHeaderLayer::if_not_present( header::CONTENT_SECURITY_POLICY, - HeaderValue::from_static(CONDUWUIT_CSP), + HeaderValue::from_str(&CONDUWUIT_CSP.join("; "))?, )) .layer(cors_layer(server)) .layer(body_limit_layer(server)) @@ -138,21 +146,14 @@ fn cors_layer(_server: &Server) -> CorsLayer { .max_age(Duration::from_secs(86400)) } -fn body_limit_layer(server: &Server) -> DefaultBodyLimit { - DefaultBodyLimit::max( - server - .config - .max_request_size - .try_into() - .expect("failed to convert max request size"), - ) -} +fn body_limit_layer(server: &Server) -> DefaultBodyLimit { DefaultBodyLimit::max(server.config.max_request_size) } #[allow(clippy::needless_pass_by_value)] -#[tracing::instrument(skip_all)] +#[tracing::instrument(skip_all, name = "panic")] fn catch_panic(err: Box) -> http::Response> { conduit_service::services() .server + .metrics .requests_panic .fetch_add(1, std::sync::atomic::Ordering::Release); @@ -164,17 +165,17 @@ fn catch_panic(err: Box) -> http::Response>, req: http::Request, next: axum::middleware::Next, -) -> Result { +) -> Result { if !server.running() { debug_warn!("unavailable pending shutdown"); return Err(StatusCode::SERVICE_UNAVAILABLE); } - let active = server.requests_spawn_active.fetch_add(1, Ordering::Relaxed); + let active = server + .metrics + .requests_spawn_active + .fetch_add(1, Ordering::Relaxed); trace!(active, "enter"); defer! {{ - let active = server.requests_spawn_active.fetch_sub(1, Ordering::Relaxed); - let finished = server.requests_spawn_finished.fetch_add(1, Ordering::Relaxed); + let active = server.metrics.requests_spawn_active.fetch_sub(1, Ordering::Relaxed); + let finished = server.metrics.requests_spawn_finished.fetch_add(1, Ordering::Relaxed); trace!(active, finished, "leave"); }}; @@ -34,7 +35,7 @@ pub(crate) async fn spawn( #[tracing::instrument(skip_all, name = "handle")] pub(crate) async fn handle( State(server): State>, req: http::Request, next: axum::middleware::Next, -) -> Result { +) -> Result { if !server.running() { debug_warn!( method = %req.method(), @@ -46,47 +47,23 @@ pub(crate) async fn handle( } let active = server + .metrics .requests_handle_active .fetch_add(1, Ordering::Relaxed); trace!(active, "enter"); defer! {{ - let active = server.requests_handle_active.fetch_sub(1, Ordering::Relaxed); - let finished = server.requests_handle_finished.fetch_add(1, Ordering::Relaxed); + let active = server.metrics.requests_handle_active.fetch_sub(1, Ordering::Relaxed); + let finished = server.metrics.requests_handle_finished.fetch_add(1, Ordering::Relaxed); trace!(active, finished, "leave"); }}; - let method = req.method().clone(); let uri = req.uri().clone(); + let method = req.method().clone(); let result = next.run(req).await; handle_result(&method, &uri, result) } -fn handle_result( - method: &Method, uri: &Uri, result: axum::response::Response, -) -> Result { - handle_result_log(method, uri, &result); - match result.status() { - StatusCode::METHOD_NOT_ALLOWED => handle_result_403(method, uri, &result), - _ => Ok(result), - } -} - -#[allow(clippy::unnecessary_wraps)] -fn handle_result_403( - _method: &Method, _uri: &Uri, result: &axum::response::Response, -) -> Result { - let error = UiaaResponse::MatrixError(RumaError { - status_code: result.status(), - body: ErrorBody::Standard { - kind: ErrorKind::Unrecognized, - message: "M_UNRECOGNIZED: Method not allowed for endpoint".to_owned(), - }, - }); - - Ok(RumaResponse(error).into_response()) -} - -fn handle_result_log(method: &Method, uri: &Uri, result: &axum::response::Response) { +fn handle_result(method: &Method, uri: &Uri, result: Response) -> Result { let status = result.status(); let reason = status.canonical_reason().unwrap_or("Unknown Reason"); let code = status.as_u16(); @@ -99,4 +76,10 @@ fn handle_result_log(method: &Method, uri: &Uri, result: &axum::response::Respon } else { trace!(method = ?method, uri = ?uri, "{code} {reason}"); } + + if status == StatusCode::METHOD_NOT_ALLOWED { + return Ok(err!(Request(Unrecognized("Method Not Allowed"))).into_response()); + } + + Ok(result) } diff --git a/src/router/router.rs b/src/router/router.rs index da31ffea..7c374b47 100644 --- a/src/router/router.rs +++ b/src/router/router.rs @@ -2,24 +2,23 @@ use std::sync::Arc; use axum::{response::IntoResponse, routing::get, Router}; use conduit::{Error, Server}; -use conduit_service as service; -use http::Uri; +use http::{StatusCode, Uri}; use ruma::api::client::error::ErrorKind; extern crate conduit_api as api; +extern crate conduit_service as service; pub(crate) fn build(server: &Arc) -> Router { - let state = service::services(); - let router = Router::new() + let router = Router::::new(); + + api::router::build(router, server) .route("/", get(it_works)) .fallback(not_found) - .with_state(state); - - api::routes::build(router, server) + .with_state(service::services()) } async fn not_found(_uri: Uri) -> impl IntoResponse { - Error::BadRequest(ErrorKind::Unrecognized, "Unrecognized request") + Error::Request(ErrorKind::Unrecognized, "Not Found".into(), StatusCode::NOT_FOUND) } async fn it_works() -> &'static str { "hewwo from conduwuit woof!" } diff --git a/src/router/run.rs b/src/router/run.rs index fb59c797..91507772 100644 --- a/src/router/run.rs +++ b/src/router/run.rs @@ -1,8 +1,10 @@ use std::{sync::Arc, time::Duration}; use axum_server::Handle as ServerHandle; -use tokio::sync::broadcast::{self, Sender}; -use tracing::{debug, error, info}; +use tokio::{ + sync::broadcast::{self, Sender}, + task::JoinHandle, +}; extern crate conduit_admin as admin; extern crate conduit_core as conduit; @@ -10,19 +12,17 @@ extern crate conduit_service as service; use std::sync::atomic::Ordering; -use conduit::{debug_info, trace, Error, Result, Server}; -use service::services; +use conduit::{debug, debug_info, error, info, trace, Error, Result, Server}; -use crate::{layers, serve}; +use crate::serve; /// Main loop base #[tracing::instrument(skip_all)] -#[allow(clippy::let_underscore_must_use)] // various of these are intended -pub(crate) async fn run(server: Arc) -> Result<(), Error> { - let app = layers::build(&server)?; +pub(crate) async fn run(server: Arc) -> Result<()> { + debug!("Start"); // Install the admin room callback here for now - _ = services().admin.handle.lock().await.insert(admin::handle); + admin::init().await; // Setup shutdown/signal handling let handle = ServerHandle::new(); @@ -31,27 +31,34 @@ pub(crate) async fn run(server: Arc) -> Result<(), Error> { .runtime() .spawn(signal(server.clone(), tx.clone(), handle.clone())); - // Serve clients - let res = serve::serve(&server, app, handle, tx.subscribe()).await; + let mut listener = server + .runtime() + .spawn(serve::serve(server.clone(), handle.clone(), tx.subscribe())); + + // Focal point + debug!("Running"); + let res = tokio::select! { + res = &mut listener => res.map_err(Error::from).unwrap_or_else(Err), + res = service::services().poll() => handle_services_poll(&server, res, listener).await, + }; // Join the signal handler before we leave. sigs.abort(); _ = sigs.await; // Remove the admin room callback - _ = services().admin.handle.lock().await.take(); + admin::fini().await; - debug_info!("Finished"); + debug_info!("Finish"); res } /// Async initializations #[tracing::instrument(skip_all)] -pub(crate) async fn start(server: Arc) -> Result<(), Error> { +pub(crate) async fn start(server: Arc) -> Result<()> { debug!("Starting..."); - service::init(&server).await?; - services().start().await?; + service::start(&server).await?; #[cfg(feature = "systemd")] sd_notify::notify(true, &[sd_notify::NotifyState::Ready]).expect("failed to notify systemd of ready state"); @@ -62,14 +69,12 @@ pub(crate) async fn start(server: Arc) -> Result<(), Error> { /// Async destructions #[tracing::instrument(skip_all)] -pub(crate) async fn stop(_server: Arc) -> Result<(), Error> { +pub(crate) async fn stop(_server: Arc) -> Result<()> { debug!("Shutting down..."); // Wait for all completions before dropping or we'll lose them to the module // unload and explode. - services().stop().await; - // Deactivate services(). Any further use will panic the caller. - service::fini(); + service::stop().await; debug!("Cleaning up..."); @@ -103,7 +108,7 @@ async fn handle_shutdown(server: &Arc, tx: &Sender<()>, handle: &axum_se error!("failed sending shutdown transaction to channel: {e}"); } - let pending = server.requests_spawn_active.load(Ordering::Relaxed); + let pending = server.metrics.requests_spawn_active.load(Ordering::Relaxed); if pending > 0 { let timeout = Duration::from_secs(36); trace!(pending, ?timeout, "Notifying for graceful shutdown"); @@ -113,3 +118,21 @@ async fn handle_shutdown(server: &Arc, tx: &Sender<()>, handle: &axum_se handle.shutdown(); } } + +async fn handle_services_poll( + server: &Arc, result: Result<()>, listener: JoinHandle>, +) -> Result<()> { + debug!("Service manager finished: {result:?}"); + + if server.running() { + if let Err(e) = server.shutdown() { + error!("Failed to send shutdown signal: {e}"); + } + } + + if let Err(e) = listener.await { + error!("Client listener task finished with error: {e}"); + } + + result +} diff --git a/src/router/serve/mod.rs b/src/router/serve/mod.rs index 47f2fd43..4e923444 100644 --- a/src/router/serve/mod.rs +++ b/src/router/serve/mod.rs @@ -4,23 +4,23 @@ mod unix; use std::sync::Arc; -use axum::Router; use axum_server::Handle as ServerHandle; -use conduit::{Error, Result, Server}; +use conduit::{Result, Server}; use tokio::sync::broadcast; +use crate::layers; + /// Serve clients -pub(super) async fn serve( - server: &Arc, app: Router, handle: ServerHandle, shutdown: broadcast::Receiver<()>, -) -> Result<(), Error> { +pub(super) async fn serve(server: Arc, handle: ServerHandle, shutdown: broadcast::Receiver<()>) -> Result<()> { let config = &server.config; let addrs = config.get_bind_addrs(); + let app = layers::build(&server)?; if cfg!(unix) && config.unix_socket_path.is_some() { - unix::serve(server, app, shutdown).await + unix::serve(&server, app, shutdown).await } else if config.tls.is_some() { - tls::serve(server, app, handle, addrs).await + tls::serve(&server, app, handle, addrs).await } else { - plain::serve(server, app, handle, addrs).await + plain::serve(&server, app, handle, addrs).await } } diff --git a/src/router/serve/plain.rs b/src/router/serve/plain.rs index b79d342d..08263353 100644 --- a/src/router/serve/plain.rs +++ b/src/router/serve/plain.rs @@ -21,12 +21,21 @@ pub(super) async fn serve( info!("Listening on {addrs:?}"); while join_set.join_next().await.is_some() {} - let spawn_active = server.requests_spawn_active.load(Ordering::Relaxed); - let handle_active = server.requests_handle_active.load(Ordering::Relaxed); + let spawn_active = server.metrics.requests_spawn_active.load(Ordering::Relaxed); + let handle_active = server + .metrics + .requests_handle_active + .load(Ordering::Relaxed); debug_info!( - spawn_finished = server.requests_spawn_finished.load(Ordering::Relaxed), - handle_finished = server.requests_handle_finished.load(Ordering::Relaxed), - panics = server.requests_panic.load(Ordering::Relaxed), + spawn_finished = server + .metrics + .requests_spawn_finished + .load(Ordering::Relaxed), + handle_finished = server + .metrics + .requests_handle_finished + .load(Ordering::Relaxed), + panics = server.metrics.requests_panic.load(Ordering::Relaxed), spawn_active, handle_active, "Stopped listening on {addrs:?}", diff --git a/src/router/serve/unix.rs b/src/router/serve/unix.rs index 6c406d28..fb011f18 100644 --- a/src/router/serve/unix.rs +++ b/src/router/serve/unix.rs @@ -3,14 +3,14 @@ use std::{ net::{self, IpAddr, Ipv4Addr}, path::Path, - sync::Arc, + sync::{atomic::Ordering, Arc}, }; use axum::{ extract::{connect_info::IntoMakeServiceWithConnectInfo, Request}, Router, }; -use conduit::{debug_error, trace, utils, Error, Result, Server}; +use conduit::{debug, debug_error, error::infallible, info, trace, warn, Err, Result, Server}; use hyper::{body::Incoming, service::service_fn}; use hyper_util::{ rt::{TokioExecutor, TokioIo}, @@ -21,14 +21,14 @@ use tokio::{ net::{unix::SocketAddr, UnixListener, UnixStream}, sync::broadcast::{self}, task::JoinSet, + time::{sleep, Duration}, }; use tower::{Service, ServiceExt}; -use tracing::{debug, info, warn}; -use utils::unwrap_infallible; type MakeService = IntoMakeServiceWithConnectInfo; -static NULL_ADDR: net::SocketAddr = net::SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0); +const NULL_ADDR: net::SocketAddr = net::SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0); +const FINI_POLL_INTERVAL: Duration = Duration::from_millis(750); #[tracing::instrument(skip_all)] pub(super) async fn serve(server: &Arc, app: Router, mut shutdown: broadcast::Receiver<()>) -> Result<()> { @@ -49,12 +49,11 @@ pub(super) async fn serve(server: &Arc, app: Router, mut shutdown: broad } } - fini(listener, tasks).await; + fini(server, listener, tasks).await; Ok(()) } -#[allow(clippy::let_underscore_must_use)] async fn accept( server: &Arc, listener: &UnixListener, tasks: &mut JoinSet<()>, mut app: MakeService, builder: server::conn::auto::Builder, conn: (UnixStream, SocketAddr), @@ -63,9 +62,14 @@ async fn accept( let socket = TokioIo::new(socket); trace!(?listener, ?socket, ?remote, "accepted"); - let called = unwrap_infallible(app.call(NULL_ADDR).await); - let handler = service_fn(move |req: Request| called.clone().oneshot(req)); + let called = app + .call(NULL_ADDR) + .await + .inspect_err(infallible) + .expect("infallible"); + let service = move |req: Request| called.clone().oneshot(req); + let handler = service_fn(service); let task = async move { // bug on darwin causes all results to be errors. do not unwrap this _ = builder.serve_connection(socket, handler).await; @@ -94,30 +98,41 @@ async fn init(server: &Arc) -> Result { let dir = path.parent().unwrap_or_else(|| Path::new("/")); if let Err(e) = fs::create_dir_all(dir).await { - return Err(Error::Err(format!("Failed to create {dir:?} for socket {path:?}: {e}"))); + return Err!("Failed to create {dir:?} for socket {path:?}: {e}"); } let listener = UnixListener::bind(path); if let Err(e) = listener { - return Err(Error::Err(format!("Failed to bind listener {path:?}: {e}"))); + return Err!("Failed to bind listener {path:?}: {e}"); } let socket_perms = config.unix_socket_perms.to_string(); let octal_perms = u32::from_str_radix(&socket_perms, 8).expect("failed to convert octal permissions"); let perms = std::fs::Permissions::from_mode(octal_perms); if let Err(e) = fs::set_permissions(&path, perms).await { - return Err(Error::Err(format!("Failed to set socket {path:?} permissions: {e}"))); + return Err!("Failed to set socket {path:?} permissions: {e}"); } - info!("Listening at {:?}", path); + info!("Listening at {path:?}"); Ok(listener.unwrap()) } -async fn fini(listener: UnixListener, mut tasks: JoinSet<()>) { +async fn fini(server: &Arc, listener: UnixListener, mut tasks: JoinSet<()>) { let local = listener.local_addr(); + debug!("Closing listener at {local:?} ..."); drop(listener); + + debug!("Waiting for requests to finish..."); + while server.metrics.requests_spawn_active.load(Ordering::Relaxed) > 0 { + tokio::select! { + task = tasks.join_next() => if task.is_none() { break; }, + () = sleep(FINI_POLL_INTERVAL) => {}, + } + } + + debug!("Shutting down..."); tasks.shutdown().await; if let Ok(local) = local { diff --git a/src/service/Cargo.toml b/src/service/Cargo.toml index f59a6036..d2c9785f 100644 --- a/src/service/Cargo.toml +++ b/src/service/Cargo.toml @@ -42,6 +42,7 @@ base64.workspace = true bytes.workspace = true conduit-core.workspace = true conduit-database.workspace = true +const-str.workspace = true cyborgtime.workspace = true futures-util.workspace = true hickory-resolver.workspace = true diff --git a/src/service/account_data/mod.rs b/src/service/account_data/mod.rs index 62560500..69d2f799 100644 --- a/src/service/account_data/mod.rs +++ b/src/service/account_data/mod.rs @@ -2,9 +2,8 @@ mod data; use std::{collections::HashMap, sync::Arc}; -use conduit::{Result, Server}; +use conduit::Result; use data::Data; -use database::Database; use ruma::{ events::{AnyEphemeralRoomEvent, RoomAccountDataEventType}, serde::Raw, @@ -15,13 +14,17 @@ pub struct Service { db: Data, } -impl Service { - pub fn build(_server: &Arc, db: &Arc) -> Result { - Ok(Self { - db: Data::new(db), - }) +impl crate::Service for Service { + fn build(args: crate::Args<'_>) -> Result> { + Ok(Arc::new(Self { + db: Data::new(args.db), + })) } + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } +} + +impl Service { /// Places one event in the account data of the user and removes the /// previous entry. #[allow(clippy::needless_pass_by_value)] @@ -41,7 +44,7 @@ impl Service { } /// Returns all changes to the account data that happened after `since`. - #[tracing::instrument(skip_all, name = "since")] + #[tracing::instrument(skip_all, name = "since", level = "debug")] pub fn changes_since( &self, room_id: Option<&RoomId>, user_id: &UserId, since: u64, ) -> Result>> { diff --git a/src/service/admin/console.rs b/src/service/admin/console.rs index c0562c52..c590b928 100644 --- a/src/service/admin/console.rs +++ b/src/service/admin/console.rs @@ -45,7 +45,6 @@ impl Console { } } - #[allow(clippy::let_underscore_must_use)] pub async fn start(self: &Arc) { let mut worker_join = self.worker_join.lock().expect("locked"); if worker_join.is_none() { @@ -54,7 +53,6 @@ impl Console { } } - #[allow(clippy::let_underscore_must_use)] pub async fn close(self: &Arc) { self.interrupt(); let Some(worker_join) = self.worker_join.lock().expect("locked").take() else { @@ -97,6 +95,10 @@ impl Console { ReadlineEvent::Line(string) => self.clone().handle(string).await, ReadlineEvent::Interrupted => continue, ReadlineEvent::Eof => break, + ReadlineEvent::Quit => services() + .server + .shutdown() + .unwrap_or_else(error::default_log), }, Err(error) => match error { ReadlineError::Closed => break, @@ -112,11 +114,11 @@ impl Console { self.worker_join.lock().expect("locked").take(); } - #[allow(clippy::let_underscore_must_use)] async fn readline(self: &Arc) -> Result { let _suppression = log::Suppress::new(&services().server); let (mut readline, _writer) = Readline::new(PROMPT.to_owned())?; + readline.set_tab_completer(Self::tab_complete); self.set_history(&mut readline); let future = readline.readline(); @@ -136,7 +138,6 @@ impl Console { result } - #[allow(clippy::let_underscore_must_use)] async fn handle(self: Arc, line: String) { if line.trim().is_empty() { return; @@ -184,6 +185,13 @@ impl Console { history.push_front(line); history.truncate(HISTORY_LIMIT); } + + fn tab_complete(line: &str) -> String { + services() + .admin + .complete_command(line) + .unwrap_or_else(|| line.to_owned()) + } } fn configure_output(mut output: MadSkin) -> MadSkin { diff --git a/src/service/admin/create.rs b/src/service/admin/create.rs index ad70fe0c..fbb6a078 100644 --- a/src/service/admin/create.rs +++ b/src/service/admin/create.rs @@ -34,32 +34,27 @@ pub async fn create_admin_room() -> Result<()> { let _short_id = services().rooms.short.get_or_create_shortroomid(&room_id)?; - let state_lock = services().globals.roomid_mutex_state.lock(&room_id).await; + let state_lock = services().rooms.state.mutex.lock(&room_id).await; // Create a user for the server let server_user = &services().globals.server_user; services().users.create(server_user, None)?; let room_version = services().globals.default_room_version(); - let mut content = match room_version { - RoomVersionId::V1 - | RoomVersionId::V2 - | RoomVersionId::V3 - | RoomVersionId::V4 - | RoomVersionId::V5 - | RoomVersionId::V6 - | RoomVersionId::V7 - | RoomVersionId::V8 - | RoomVersionId::V9 - | RoomVersionId::V10 => RoomCreateEventContent::new_v1(server_user.clone()), - RoomVersionId::V11 => RoomCreateEventContent::new_v11(), - _ => { - warn!("Unexpected or unsupported room version {}", room_version); - return Err(Error::BadRequest( - ErrorKind::BadJson, - "Unexpected or unsupported room version found", - )); - }, + + let mut content = { + use RoomVersionId::*; + match room_version { + V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => RoomCreateEventContent::new_v1(server_user.clone()), + V11 => RoomCreateEventContent::new_v11(), + _ => { + warn!("Unexpected or unsupported room version {}", room_version); + return Err(Error::BadRequest( + ErrorKind::BadJson, + "Unexpected or unsupported room version found", + )); + }, + } }; content.federate = true; diff --git a/src/service/admin/grant.rs b/src/service/admin/grant.rs index ca48ce0d..9a4ef242 100644 --- a/src/service/admin/grant.rs +++ b/src/service/admin/grant.rs @@ -22,7 +22,7 @@ use crate::{pdu::PduBuilder, services}; /// In conduit, this is equivalent to granting admin privileges. pub async fn make_user_admin(user_id: &UserId, displayname: String) -> Result<()> { if let Some(room_id) = Service::get_admin_room()? { - let state_lock = services().globals.roomid_mutex_state.lock(&room_id).await; + let state_lock = services().rooms.state.mutex.lock(&room_id).await; // Use the server user to grant the new admin's power level let server_user = &services().globals.server_user; diff --git a/src/service/admin/mod.rs b/src/service/admin/mod.rs index b5cb8af4..f63ebf09 100644 --- a/src/service/admin/mod.rs +++ b/src/service/admin/mod.rs @@ -2,11 +2,15 @@ pub mod console; mod create; mod grant; -use std::{future::Future, pin::Pin, sync::Arc}; +use std::{ + future::Future, + pin::Pin, + sync::{Arc, RwLock as StdRwLock}, +}; -use conduit::{error, utils::mutex_map, Error, Result, Server}; +use async_trait::async_trait; +use conduit::{debug, error, error::default_log, Error, Result}; pub use create::create_admin_room; -use database::Database; pub use grant::make_user_admin; use loole::{Receiver, Sender}; use ruma::{ @@ -17,22 +21,15 @@ use ruma::{ OwnedEventId, OwnedRoomId, RoomId, UserId, }; use serde_json::value::to_raw_value; -use tokio::{sync::Mutex, task::JoinHandle}; +use tokio::sync::{Mutex, RwLock}; -use crate::{pdu::PduBuilder, services, user_is_local, PduEvent}; - -const COMMAND_QUEUE_LIMIT: usize = 512; - -pub type CommandOutput = Option; -pub type CommandResult = Result; -pub type HandlerResult = Pin + Send>>; -pub type Handler = fn(Command) -> HandlerResult; +use crate::{pdu::PduBuilder, rooms::state::RoomMutexGuard, services, user_is_local, PduEvent}; pub struct Service { sender: Sender, receiver: Mutex>, - handler_join: Mutex>>, - pub handle: Mutex>, + pub handle: RwLock>, + pub complete: StdRwLock>, #[cfg(feature = "console")] pub console: Arc, } @@ -43,32 +40,52 @@ pub struct Command { pub reply_id: Option, } -impl Service { - pub fn build(_server: &Arc, _db: &Arc) -> Result> { +pub type Completer = fn(&str) -> String; +pub type Handler = fn(Command) -> HandlerResult; +pub type HandlerResult = Pin + Send>>; +pub type CommandResult = Result; +pub type CommandOutput = Option; + +const COMMAND_QUEUE_LIMIT: usize = 512; + +#[async_trait] +impl crate::Service for Service { + fn build(_args: crate::Args<'_>) -> Result> { let (sender, receiver) = loole::bounded(COMMAND_QUEUE_LIMIT); Ok(Arc::new(Self { sender, receiver: Mutex::new(receiver), - handler_join: Mutex::new(None), - handle: Mutex::new(None), + handle: RwLock::new(None), + complete: StdRwLock::new(None), #[cfg(feature = "console")] console: console::Console::new(), })) } - pub async fn start_handler(self: &Arc) { - let self_ = Arc::clone(self); - let handle = services().server.runtime().spawn(async move { - self_ - .handler() - .await - .expect("Failed to initialize admin room handler"); - }); + async fn worker(self: Arc) -> Result<()> { + let receiver = self.receiver.lock().await; + let mut signals = services().server.signal.subscribe(); + loop { + tokio::select! { + command = receiver.recv_async() => match command { + Ok(command) => self.handle_command(command).await, + Err(_) => break, + }, + sig = signals.recv() => match sig { + Ok(sig) => self.handle_signal(sig).await, + Err(_) => continue, + }, + } + } - _ = self.handler_join.lock().await.insert(handle); + //TODO: not unwind safe + #[cfg(feature = "console")] + self.console.close().await; + + Ok(()) } - pub fn interrupt(&self) { + fn interrupt(&self) { #[cfg(feature = "console")] self.console.interrupt(); @@ -77,19 +94,10 @@ impl Service { } } - pub async fn close(&self) { - self.interrupt(); - - #[cfg(feature = "console")] - self.console.close().await; - - if let Some(handler_join) = self.handler_join.lock().await.take() { - if let Err(e) = handler_join.await { - error!("Failed to shutdown: {e:?}"); - } - } - } + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } +} +impl Service { pub async fn send_text(&self, body: &str) { self.send_message(RoomMessageEventContent::text_markdown(body)) .await; @@ -120,41 +128,33 @@ impl Service { .await } + pub fn complete_command(&self, command: &str) -> Option { + self.complete + .read() + .expect("locked for reading") + .map(|complete| complete(command)) + } + async fn send(&self, message: Command) { debug_assert!(!self.sender.is_closed(), "channel closed"); self.sender.send_async(message).await.expect("message sent"); } - async fn handler(self: &Arc) -> Result<()> { - let receiver = self.receiver.lock().await; - let mut signals = services().server.signal.subscribe(); - loop { - tokio::select! { - command = receiver.recv_async() => match command { - Ok(command) => self.handle_command(command).await, - Err(_) => return Ok(()), - }, - sig = signals.recv() => match sig { - Ok(sig) => self.handle_signal(sig).await, - Err(_) => continue, - }, - } - } - } - async fn handle_signal(&self, #[allow(unused_variables)] sig: &'static str) { #[cfg(feature = "console")] self.console.handle_signal(sig).await; } async fn handle_command(&self, command: Command) { - if let Ok(Some(output)) = self.process_command(command).await { - handle_response(output).await; + match self.process_command(command).await { + Ok(Some(output)) => handle_response(output).await, + Ok(None) => debug!("Command successful with no response"), + Err(e) => error!("Command processing error: {e}"), } } async fn process_command(&self, command: Command) -> CommandResult { - if let Some(handle) = self.handle.lock().await.as_ref() { + if let Some(handle) = self.handle.read().await.as_ref() { handle(command).await } else { Err(Error::Err("Admin module is not loaded.".into())) @@ -224,7 +224,7 @@ async fn respond_to_room(content: RoomMessageEventContent, room_id: &RoomId, use "sender is not admin" ); - let state_lock = services().globals.roomid_mutex_state.lock(room_id).await; + let state_lock = services().rooms.state.mutex.lock(room_id).await; let response_pdu = PduBuilder { event_type: TimelineEventType::RoomMessage, content: to_raw_value(&content).expect("event is valid, we just created it"), @@ -239,14 +239,14 @@ async fn respond_to_room(content: RoomMessageEventContent, room_id: &RoomId, use .build_and_append_pdu(response_pdu, user_id, room_id, &state_lock) .await { - if let Err(e) = handle_response_error(&e, room_id, user_id, &state_lock).await { - error!("{e}"); - } + handle_response_error(e, room_id, user_id, &state_lock) + .await + .unwrap_or_else(default_log); } } async fn handle_response_error( - e: &Error, room_id: &RoomId, user_id: &UserId, state_lock: &mutex_map::Guard<()>, + e: Error, room_id: &RoomId, user_id: &UserId, state_lock: &RoomMutexGuard, ) -> Result<()> { error!("Failed to build and append admin room response PDU: \"{e}\""); let error_room_message = RoomMessageEventContent::text_plain(format!( diff --git a/src/service/appservice/mod.rs b/src/service/appservice/mod.rs index 916e7df3..24c9b8b0 100644 --- a/src/service/appservice/mod.rs +++ b/src/service/appservice/mod.rs @@ -2,9 +2,8 @@ mod data; use std::{collections::BTreeMap, sync::Arc}; -use conduit::{Result, Server}; +use conduit::{err, Result}; use data::Data; -use database::Database; use futures_util::Future; use regex::RegexSet; use ruma::{ @@ -24,6 +23,7 @@ pub struct NamespaceRegex { impl NamespaceRegex { /// Checks if this namespace has rights to a namespace + #[inline] #[must_use] pub fn is_match(&self, heystack: &str) -> bool { if self.is_exclusive_match(heystack) { @@ -39,6 +39,7 @@ impl NamespaceRegex { } /// Checks if this namespace has exlusive rights to a namespace + #[inline] #[must_use] pub fn is_exclusive_match(&self, heystack: &str) -> bool { if let Some(exclusive) = &self.exclusive { @@ -56,6 +57,7 @@ impl RegistrationInfo { self.users.is_match(user_id.as_str()) || self.registration.sender_localpart == user_id.localpart() } + #[inline] #[must_use] pub fn is_exclusive_user_match(&self, user_id: &UserId) -> bool { self.users.is_exclusive_match(user_id.as_str()) || self.registration.sender_localpart == user_id.localpart() @@ -119,10 +121,10 @@ pub struct Service { registration_info: RwLock>, } -impl Service { - pub fn build(_server: &Arc, db: &Arc) -> Result { +impl crate::Service for Service { + fn build(args: crate::Args<'_>) -> Result> { let mut registration_info = BTreeMap::new(); - let db = Data::new(db); + let db = Data::new(args.db); // Inserting registrations into cache for appservice in iter_ids(&db)? { registration_info.insert( @@ -134,12 +136,17 @@ impl Service { ); } - Ok(Self { + Ok(Arc::new(Self { db, registration_info: RwLock::new(registration_info), - }) + })) } + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } +} + +impl Service { + #[inline] pub fn all(&self) -> Result> { iter_ids(&self.db) } /// Registers an appservice and returns the ID to the caller @@ -164,7 +171,7 @@ impl Service { .write() .await .remove(service_name) - .ok_or_else(|| crate::Error::Err("Appservice not found".to_owned()))?; + .ok_or(err!("Appservice not found"))?; // remove the appservice from the database self.db.unregister_appservice(service_name)?; diff --git a/src/service/globals/data.rs b/src/service/globals/data.rs index 022762f4..281c2a94 100644 --- a/src/service/globals/data.rs +++ b/src/service/globals/data.rs @@ -14,9 +14,6 @@ use ruma::{ use crate::services; -const COUNTER: &[u8] = b"c"; -const LAST_CHECK_FOR_UPDATES_COUNT: &[u8] = b"u"; - pub struct Data { global: Arc, todeviceid_events: Arc, @@ -35,6 +32,8 @@ pub struct Data { counter: RwLock, } +const COUNTER: &[u8] = b"c"; + impl Data { pub(super) fn new(db: &Arc) -> Self { Self { @@ -57,6 +56,7 @@ impl Data { } pub fn next_count(&self) -> Result { + let _cork = self.db.cork(); let mut lock = self.counter.write().expect("locked"); let counter: &mut u64 = &mut lock; debug_assert!( @@ -64,7 +64,10 @@ impl Data { "counter mismatch" ); - *counter = counter.wrapping_add(1); + *counter = counter + .checked_add(1) + .expect("counter must not overflow u64"); + self.global.insert(COUNTER, &counter.to_be_bytes())?; Ok(*counter) @@ -89,23 +92,7 @@ impl Data { .map_or(Ok(0_u64), utils::u64_from_bytes) } - pub fn last_check_for_updates_id(&self) -> Result { - self.global - .get(LAST_CHECK_FOR_UPDATES_COUNT)? - .map_or(Ok(0_u64), |bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("last check for updates count has invalid bytes.")) - }) - } - - pub fn update_check_for_updates_id(&self, id: u64) -> Result<()> { - self.global - .insert(LAST_CHECK_FOR_UPDATES_COUNT, &id.to_be_bytes())?; - - Ok(()) - } - - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { let userid_bytes = user_id.as_bytes().to_vec(); let mut userid_prefix = userid_bytes.clone(); @@ -206,43 +193,6 @@ impl Data { Ok(()) } - pub fn cleanup(&self) -> Result<()> { self.db.db.cleanup() } - - pub fn memory_usage(&self) -> String { - let (auth_chain_cache, max_auth_chain_cache) = services().rooms.auth_chain.get_cache_usage(); - let (appservice_in_room_cache, max_appservice_in_room_cache) = services() - .rooms - .state_cache - .get_appservice_in_room_cache_usage(); - let (lasttimelinecount_cache, max_lasttimelinecount_cache) = services() - .rooms - .timeline - .get_lasttimelinecount_cache_usage(); - - format!( - "auth_chain_cache: {auth_chain_cache} / {max_auth_chain_cache}\nappservice_in_room_cache: \ - {appservice_in_room_cache} / {max_appservice_in_room_cache}\nlasttimelinecount_cache: \ - {lasttimelinecount_cache} / {max_lasttimelinecount_cache}\n\n{}", - self.db.db.memory_usage().unwrap_or_default() - ) - } - - #[allow(clippy::unused_self)] - pub fn clear_caches(&self, amount: u32) { - if amount > 1 { - services().rooms.auth_chain.clear_cache(); - } - if amount > 2 { - services() - .rooms - .state_cache - .clear_appservice_in_room_cache(); - } - if amount > 3 { - services().rooms.timeline.clear_lasttimelinecount_cache(); - } - } - pub fn load_keypair(&self) -> Result { let keypair_bytes = self.global.get(b"keypair")?.map_or_else( || { @@ -275,8 +225,16 @@ impl Data { }) } + #[inline] pub fn remove_keypair(&self) -> Result<()> { self.global.remove(b"keypair") } + /// TODO: the key valid until timestamp (`valid_until_ts`) is only honored + /// in room version > 4 + /// + /// Remove the outdated keys and insert the new ones. + /// + /// This doesn't actually check that the keys provided are newer than the + /// old set. pub fn add_signing_key( &self, origin: &ServerName, new_keys: ServerSigningKeys, ) -> Result> { @@ -340,14 +298,18 @@ impl Data { }) } + #[inline] pub fn bump_database_version(&self, new_version: u64) -> Result<()> { self.global.insert(b"version", &new_version.to_be_bytes())?; Ok(()) } + #[inline] pub fn backup(&self) -> Result<(), Box> { self.db.db.backup() } + #[inline] pub fn backup_list(&self) -> Result { self.db.db.backup_list() } + #[inline] pub fn file_list(&self) -> Result { self.db.db.file_list() } } diff --git a/src/service/globals/migrations.rs b/src/service/globals/migrations.rs index 5cc25a1d..3948d1f5 100644 --- a/src/service/globals/migrations.rs +++ b/src/service/globals/migrations.rs @@ -50,7 +50,10 @@ pub(crate) async fn migrations(db: &Arc, config: &Config) -> Result<() } async fn fresh(db: &Arc, config: &Config) -> Result<()> { - services().globals.bump_database_version(DATABASE_VERSION)?; + services() + .globals + .db + .bump_database_version(DATABASE_VERSION)?; db["global"].insert(b"fix_bad_double_separator_in_state_cache", &[])?; db["global"].insert(b"retroactively_fix_bad_data_from_roomuserid_joined", &[])?; @@ -68,57 +71,57 @@ async fn fresh(db: &Arc, config: &Config) -> Result<()> { /// Apply any migrations async fn migrate(db: &Arc, config: &Config) -> Result<()> { - if services().globals.database_version()? < 1 { + if services().globals.db.database_version()? < 1 { db_lt_1(db, config).await?; } - if services().globals.database_version()? < 2 { + if services().globals.db.database_version()? < 2 { db_lt_2(db, config).await?; } - if services().globals.database_version()? < 3 { + if services().globals.db.database_version()? < 3 { db_lt_3(db, config).await?; } - if services().globals.database_version()? < 4 { + if services().globals.db.database_version()? < 4 { db_lt_4(db, config).await?; } - if services().globals.database_version()? < 5 { + if services().globals.db.database_version()? < 5 { db_lt_5(db, config).await?; } - if services().globals.database_version()? < 6 { + if services().globals.db.database_version()? < 6 { db_lt_6(db, config).await?; } - if services().globals.database_version()? < 7 { + if services().globals.db.database_version()? < 7 { db_lt_7(db, config).await?; } - if services().globals.database_version()? < 8 { + if services().globals.db.database_version()? < 8 { db_lt_8(db, config).await?; } - if services().globals.database_version()? < 9 { + if services().globals.db.database_version()? < 9 { db_lt_9(db, config).await?; } - if services().globals.database_version()? < 10 { + if services().globals.db.database_version()? < 10 { db_lt_10(db, config).await?; } - if services().globals.database_version()? < 11 { + if services().globals.db.database_version()? < 11 { db_lt_11(db, config).await?; } - if services().globals.database_version()? < 12 { + if services().globals.db.database_version()? < 12 { db_lt_12(db, config).await?; } // This migration can be reused as-is anytime the server-default rules are // updated. - if services().globals.database_version()? < 13 { + if services().globals.db.database_version()? < 13 { db_lt_13(db, config).await?; } @@ -143,10 +146,10 @@ async fn migrate(db: &Arc, config: &Config) -> Result<()> { } assert_eq!( - services().globals.database_version().unwrap(), + services().globals.db.database_version().unwrap(), DATABASE_VERSION, "Failed asserting local database version {} is equal to known latest conduwuit database version {}", - services().globals.database_version().unwrap(), + services().globals.db.database_version().unwrap(), DATABASE_VERSION, ); @@ -225,7 +228,7 @@ async fn db_lt_1(db: &Arc, _config: &Config) -> Result<()> { serverroomids.insert(&serverroomid, &[])?; } - services().globals.bump_database_version(1)?; + services().globals.db.bump_database_version(1)?; info!("Migration: 0 -> 1 finished"); Ok(()) } @@ -242,7 +245,7 @@ async fn db_lt_2(db: &Arc, _config: &Config) -> Result<()> { } } - services().globals.bump_database_version(2)?; + services().globals.db.bump_database_version(2)?; info!("Migration: 1 -> 2 finished"); Ok(()) } @@ -262,7 +265,7 @@ async fn db_lt_3(db: &Arc, _config: &Config) -> Result<()> { mediaid_file.insert(&key, &[])?; } - services().globals.bump_database_version(3)?; + services().globals.db.bump_database_version(3)?; info!("Migration: 2 -> 3 finished"); Ok(()) } @@ -285,7 +288,7 @@ async fn db_lt_4(_db: &Arc, config: &Config) -> Result<()> { } } - services().globals.bump_database_version(4)?; + services().globals.db.bump_database_version(4)?; info!("Migration: 3 -> 4 finished"); Ok(()) } @@ -309,7 +312,7 @@ async fn db_lt_5(db: &Arc, _config: &Config) -> Result<()> { roomusertype_roomuserdataid.insert(&key, &roomuserdataid)?; } - services().globals.bump_database_version(5)?; + services().globals.db.bump_database_version(5)?; info!("Migration: 4 -> 5 finished"); Ok(()) } @@ -323,7 +326,7 @@ async fn db_lt_6(db: &Arc, _config: &Config) -> Result<()> { services().rooms.state_cache.update_joined_count(room_id)?; } - services().globals.bump_database_version(6)?; + services().globals.db.bump_database_version(6)?; info!("Migration: 5 -> 6 finished"); Ok(()) } @@ -448,7 +451,7 @@ async fn db_lt_7(db: &Arc, _config: &Config) -> Result<()> { )?; } - services().globals.bump_database_version(7)?; + services().globals.db.bump_database_version(7)?; info!("Migration: 6 -> 7 finished"); Ok(()) } @@ -514,7 +517,7 @@ async fn db_lt_8(db: &Arc, _config: &Config) -> Result<()> { eventid_pduid.insert_batch(batch2.iter().map(database::KeyVal::from))?; - services().globals.bump_database_version(8)?; + services().globals.db.bump_database_version(8)?; info!("Migration: 7 -> 8 finished"); Ok(()) } @@ -571,7 +574,7 @@ async fn db_lt_9(db: &Arc, _config: &Config) -> Result<()> { tokenids.remove(&key)?; } - services().globals.bump_database_version(9)?; + services().globals.db.bump_database_version(9)?; info!("Migration: 8 -> 9 finished"); Ok(()) } @@ -590,7 +593,7 @@ async fn db_lt_10(db: &Arc, _config: &Config) -> Result<()> { services().users.mark_device_key_update(&user_id)?; } - services().globals.bump_database_version(10)?; + services().globals.db.bump_database_version(10)?; info!("Migration: 9 -> 10 finished"); Ok(()) } @@ -601,7 +604,7 @@ async fn db_lt_11(_db: &Arc, _config: &Config) -> Result<()> { //let userdevicesessionid_uiaarequest = &db["userdevicesessionid_uiaarequest"]; //userdevicesessionid_uiaarequest.clear()?; - services().globals.bump_database_version(11)?; + services().globals.db.bump_database_version(11)?; info!("Migration: 10 -> 11 finished"); Ok(()) } @@ -669,7 +672,7 @@ async fn db_lt_12(_db: &Arc, config: &Config) -> Result<()> { )?; } - services().globals.bump_database_version(12)?; + services().globals.db.bump_database_version(12)?; info!("Migration: 11 -> 12 finished"); Ok(()) } @@ -706,7 +709,7 @@ async fn db_lt_13(_db: &Arc, config: &Config) -> Result<()> { )?; } - services().globals.bump_database_version(13)?; + services().globals.db.bump_database_version(13)?; info!("Migration: 12 -> 13 finished"); Ok(()) } @@ -736,8 +739,8 @@ async fn migrate_sha256_media(db: &Arc, _config: &Config) -> Result<() // Apply fix from when sha256_media was backward-incompat and bumped the schema // version from 13 to 14. For users satisfying these conditions we can go back. - if services().globals.database_version()? == 14 && DATABASE_VERSION == 13 { - services().globals.bump_database_version(13)?; + if services().globals.db.database_version()? == 14 && DATABASE_VERSION == 13 { + services().globals.db.bump_database_version(13)?; } db["global"].insert(b"feat_sha256_media", &[])?; @@ -833,11 +836,14 @@ async fn fix_bad_double_separator_in_state_cache(db: &Arc, _config: &C for (mut key, value) in roomuserid_joined.iter() { iter_count = iter_count.saturating_add(1); debug_info!(%iter_count); - let first_sep_index = key.iter().position(|&i| i == 0xFF).unwrap(); + let first_sep_index = key + .iter() + .position(|&i| i == 0xFF) + .expect("found 0xFF delim"); if key .iter() - .get(first_sep_index..=first_sep_index + 1) + .get(first_sep_index..=first_sep_index.saturating_add(1)) .copied() .collect_vec() == vec![0xFF, 0xFF] diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index 7e1ccb5e..0a0d0d8e 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -1,41 +1,32 @@ mod client; mod data; -pub(super) mod emerg_access; +mod emerg_access; pub(super) mod migrations; -mod resolver; -pub(super) mod updates; +pub(crate) mod resolver; use std::{ collections::{BTreeMap, HashMap}, - sync::Arc, + fmt::Write, + sync::{Arc, RwLock}, time::Instant, }; -use conduit::{error, trace, utils::MutexMap, Config, Result, Server}; +use async_trait::async_trait; +use conduit::{error, trace, Config, Result}; use data::Data; -use database::Database; -use hickory_resolver::TokioAsyncResolver; use ipaddress::IPAddress; use regex::RegexSet; use ruma::{ - api::{ - client::discovery::discover_support::ContactRole, - federation::discovery::{ServerSigningKeys, VerifyKey}, - }, + api::{client::discovery::discover_support::ContactRole, federation::discovery::VerifyKey}, serde::Base64, - DeviceId, OwnedEventId, OwnedRoomAliasId, OwnedRoomId, OwnedServerName, OwnedServerSigningKeyId, OwnedUserId, - RoomAliasId, RoomVersionId, ServerName, UserId, -}; -use tokio::{ - sync::{Mutex, RwLock}, - task::JoinHandle, + DeviceId, OwnedEventId, OwnedRoomAliasId, OwnedServerName, OwnedServerSigningKeyId, OwnedUserId, RoomAliasId, + RoomVersionId, ServerName, UserId, }; +use tokio::sync::Mutex; use url::Url; use crate::services; -type RateLimitState = (Instant, u32); // Time if last failed try, number of failed tries - pub struct Service { pub db: Data, @@ -50,20 +41,18 @@ pub struct Service { pub bad_event_ratelimiter: Arc>>, pub bad_signature_ratelimiter: Arc, RateLimitState>>>, pub bad_query_ratelimiter: Arc>>, - pub roomid_mutex_insert: MutexMap, - pub roomid_mutex_state: MutexMap, - pub roomid_mutex_federation: MutexMap, - pub roomid_federationhandletime: RwLock>, - pub updates_handle: Mutex>>, pub stateres_mutex: Arc>, pub server_user: OwnedUserId, pub admin_alias: OwnedRoomAliasId, } -impl Service { - pub fn build(server: &Arc, db: &Arc) -> Result { - let config = &server.config; - let db = Data::new(db); +type RateLimitState = (Instant, u32); // Time if last failed try, number of failed tries + +#[async_trait] +impl crate::Service for Service { + fn build(args: crate::Args<'_>) -> Result> { + let config = &args.server.config; + let db = Data::new(args.db); let keypair = db.load_keypair(); let keypair = match keypair { @@ -114,11 +103,6 @@ impl Service { bad_event_ratelimiter: Arc::new(RwLock::new(HashMap::new())), bad_signature_ratelimiter: Arc::new(RwLock::new(HashMap::new())), bad_query_ratelimiter: Arc::new(RwLock::new(HashMap::new())), - roomid_mutex_state: MutexMap::::new(), - roomid_mutex_insert: MutexMap::::new(), - roomid_mutex_federation: MutexMap::::new(), - roomid_federationhandletime: RwLock::new(HashMap::new()), - updates_handle: Mutex::new(None), stateres_mutex: Arc::new(Mutex::new(())), admin_alias: RoomAliasId::parse(format!("#admins:{}", &config.server_name)) .expect("#admins:server_name is valid alias name"), @@ -134,9 +118,65 @@ impl Service { s.config.default_room_version = crate::config::default_default_room_version(); }; - Ok(s) + Ok(Arc::new(s)) } + async fn worker(self: Arc) -> Result<()> { + emerg_access::init_emergency_access(); + + Ok(()) + } + + fn memory_usage(&self, out: &mut dyn Write) -> Result<()> { + self.resolver.memory_usage(out)?; + + let bad_event_ratelimiter = self + .bad_event_ratelimiter + .read() + .expect("locked for reading") + .len(); + writeln!(out, "bad_event_ratelimiter: {bad_event_ratelimiter}")?; + + let bad_query_ratelimiter = self + .bad_query_ratelimiter + .read() + .expect("locked for reading") + .len(); + writeln!(out, "bad_query_ratelimiter: {bad_query_ratelimiter}")?; + + let bad_signature_ratelimiter = self + .bad_signature_ratelimiter + .read() + .expect("locked for reading") + .len(); + writeln!(out, "bad_signature_ratelimiter: {bad_signature_ratelimiter}")?; + + Ok(()) + } + + fn clear_cache(&self) { + self.resolver.clear_cache(); + + self.bad_event_ratelimiter + .write() + .expect("locked for writing") + .clear(); + + self.bad_query_ratelimiter + .write() + .expect("locked for writing") + .clear(); + + self.bad_signature_ratelimiter + .write() + .expect("locked for writing") + .clear(); + } + + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } +} + +impl Service { /// Returns this server's keypair. pub fn keypair(&self) -> &ruma::signatures::Ed25519KeyPair { &self.keypair } @@ -146,20 +186,12 @@ impl Service { #[inline] pub fn current_count(&self) -> Result { Ok(self.db.current_count()) } - #[tracing::instrument(skip(self))] - pub fn last_check_for_updates_id(&self) -> Result { self.db.last_check_for_updates_id() } - - #[tracing::instrument(skip(self))] - pub fn update_check_for_updates_id(&self, id: u64) -> Result<()> { self.db.update_check_for_updates_id(id) } - pub async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { self.db.watch(user_id, device_id).await } pub fn server_name(&self) -> &ServerName { self.config.server_name.as_ref() } - pub fn max_request_size(&self) -> u32 { self.config.max_request_size } - pub fn max_fetch_prev_events(&self) -> u16 { self.config.max_fetch_prev_events } pub fn allow_registration(&self) -> bool { self.config.allow_registration } @@ -184,6 +216,7 @@ impl Service { pub fn allow_unstable_room_versions(&self) -> bool { self.config.allow_unstable_room_versions } + #[inline] pub fn default_room_version(&self) -> RoomVersionId { self.config.default_room_version.clone() } pub fn new_user_displayname_suffix(&self) -> &String { &self.config.new_user_displayname_suffix } @@ -194,8 +227,6 @@ impl Service { pub fn query_trusted_key_servers_first(&self) -> bool { self.config.query_trusted_key_servers_first } - pub fn dns_resolver(&self) -> &TokioAsyncResolver { &self.resolver.resolver } - pub fn jwt_decoding_key(&self) -> Option<&jsonwebtoken::DecodingKey> { self.jwt_decoding_key.as_ref() } pub fn turn_password(&self) -> &String { &self.config.turn_password } @@ -273,19 +304,6 @@ impl Service { room_versions } - /// TODO: the key valid until timestamp (`valid_until_ts`) is only honored - /// in room version > 4 - /// - /// Remove the outdated keys and insert the new ones. - /// - /// This doesn't actually check that the keys provided are newer than the - /// old set. - pub fn add_signing_key( - &self, origin: &ServerName, new_keys: ServerSigningKeys, - ) -> Result> { - self.db.add_signing_key(origin, new_keys) - } - /// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found /// for the server. pub fn signing_keys_for(&self, origin: &ServerName) -> Result> { @@ -304,14 +322,11 @@ impl Service { Ok(keys) } - pub fn database_version(&self) -> Result { self.db.database_version() } - - pub fn bump_database_version(&self, new_version: u64) -> Result<()> { self.db.bump_database_version(new_version) } - pub fn well_known_client(&self) -> &Option { &self.config.well_known.client } pub fn well_known_server(&self) -> &Option { &self.config.well_known.server } + #[inline] pub fn valid_cidr_range(&self, ip: &IPAddress) -> bool { for cidr in &self.cidr_range_denylist { if cidr.includes(ip) { diff --git a/src/service/globals/resolver.rs b/src/service/globals/resolver.rs index 86fa6700..3002decf 100644 --- a/src/service/globals/resolver.rs +++ b/src/service/globals/resolver.rs @@ -1,41 +1,39 @@ use std::{ collections::HashMap, + fmt::Write, future, iter, net::{IpAddr, SocketAddr}, sync::{Arc, RwLock}, time::Duration, }; -use conduit::{error, Config, Error}; +use conduit::{error, Config, Result}; use hickory_resolver::TokioAsyncResolver; use reqwest::dns::{Addrs, Name, Resolve, Resolving}; use ruma::OwnedServerName; -use crate::sending::FedDest; +use crate::sending::{CachedDest, CachedOverride}; -type WellKnownMap = HashMap; -type TlsNameMap = HashMap, u16)>; +type WellKnownMap = HashMap; +type TlsNameMap = HashMap; pub struct Resolver { pub destinations: Arc>, // actual_destination, host pub overrides: Arc>, - pub resolver: Arc, - pub hooked: Arc, + pub(crate) resolver: Arc, + pub(crate) hooked: Arc, } -pub struct Hooked { - pub overrides: Arc>, - pub resolver: Arc, +pub(crate) struct Hooked { + overrides: Arc>, + resolver: Arc, } impl Resolver { #[allow(clippy::as_conversions, clippy::cast_sign_loss, clippy::cast_possible_truncation)] - pub fn new(config: &Config) -> Self { + pub(super) fn new(config: &Config) -> Self { let (sys_conf, mut opts) = hickory_resolver::system_conf::read_system_conf() - .map_err(|e| { - error!("Failed to set up hickory dns resolver with system config: {e}"); - Error::bad_config("Failed to set up hickory dns resolver with system config.") - }) + .inspect_err(|e| error!("Failed to set up hickory dns resolver with system config: {e}")) .expect("DNS system config must be valid"); let mut conf = hickory_resolver::config::ResolverConfig::new(); @@ -92,6 +90,22 @@ impl Resolver { }), } } + + pub(super) fn memory_usage(&self, out: &mut dyn Write) -> Result<()> { + let resolver_overrides_cache = self.overrides.read().expect("locked for reading").len(); + writeln!(out, "resolver_overrides_cache: {resolver_overrides_cache}")?; + + let resolver_destinations_cache = self.destinations.read().expect("locked for reading").len(); + writeln!(out, "resolver_destinations_cache: {resolver_destinations_cache}")?; + + Ok(()) + } + + pub(super) fn clear_cache(&self) { + self.overrides.write().expect("write locked").clear(); + self.destinations.write().expect("write locked").clear(); + self.resolver.clear_cache(); + } } impl Resolve for Resolver { @@ -100,15 +114,16 @@ impl Resolve for Resolver { impl Resolve for Hooked { fn resolve(&self, name: Name) -> Resolving { - let addr_port = self + let cached = self .overrides .read() .expect("locked for reading") .get(name.as_str()) + .filter(|cached| cached.valid()) .cloned(); - if let Some((addr, port)) = addr_port { - cached_to_reqwest(&addr, port) + if let Some(cached) = cached { + cached_to_reqwest(&cached.ips, cached.port) } else { resolve_to_reqwest(self.resolver.clone(), name) } diff --git a/src/service/globals/updates.rs b/src/service/globals/updates.rs deleted file mode 100644 index c6ac9fff..00000000 --- a/src/service/globals/updates.rs +++ /dev/null @@ -1,76 +0,0 @@ -use std::time::Duration; - -use ruma::events::room::message::RoomMessageEventContent; -use serde::Deserialize; -use tokio::{task::JoinHandle, time::interval}; -use tracing::{error, warn}; - -use crate::{ - conduit::{Error, Result}, - services, -}; - -const CHECK_FOR_UPDATES_URL: &str = "https://pupbrain.dev/check-for-updates/stable"; -const CHECK_FOR_UPDATES_INTERVAL: u64 = 7200; // 2 hours - -#[derive(Deserialize)] -struct CheckForUpdatesResponseEntry { - id: u64, - date: String, - message: String, -} -#[derive(Deserialize)] -struct CheckForUpdatesResponse { - updates: Vec, -} - -#[tracing::instrument] -pub fn start_check_for_updates_task() -> JoinHandle<()> { - let timer_interval = Duration::from_secs(CHECK_FOR_UPDATES_INTERVAL); - - services().server.runtime().spawn(async move { - let mut i = interval(timer_interval); - - loop { - i.tick().await; - - if let Err(e) = try_handle_updates().await { - warn!(%e, "Failed to check for updates"); - } - } - }) -} - -#[tracing::instrument(skip_all)] -async fn try_handle_updates() -> Result<()> { - let response = services() - .globals - .client - .default - .get(CHECK_FOR_UPDATES_URL) - .send() - .await?; - - let response = serde_json::from_str::(&response.text().await?) - .map_err(|e| Error::Err(format!("Bad check for updates response: {e}")))?; - - let mut last_update_id = services().globals.last_check_for_updates_id()?; - for update in response.updates { - last_update_id = last_update_id.max(update.id); - if update.id > services().globals.last_check_for_updates_id()? { - error!("{}", update.message); - services() - .admin - .send_message(RoomMessageEventContent::text_plain(format!( - "@room: the following is a message from the conduwuit puppy. it was sent on '{}':\n\n{}", - update.date, update.message - ))) - .await; - } - } - services() - .globals - .update_check_for_updates_id(last_update_id)?; - - Ok(()) -} diff --git a/src/service/key_backups/mod.rs b/src/service/key_backups/mod.rs index 650aa6b6..d83d4497 100644 --- a/src/service/key_backups/mod.rs +++ b/src/service/key_backups/mod.rs @@ -2,9 +2,8 @@ mod data; use std::{collections::BTreeMap, sync::Arc}; -use conduit::{Result, Server}; +use conduit::Result; use data::Data; -use database::Database; use ruma::{ api::client::backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup}, serde::Raw, @@ -15,13 +14,17 @@ pub struct Service { db: Data, } -impl Service { - pub fn build(_server: &Arc, db: &Arc) -> Result { - Ok(Self { - db: Data::new(db), - }) +impl crate::Service for Service { + fn build(args: crate::Args<'_>) -> Result> { + Ok(Arc::new(Self { + db: Data::new(args.db), + })) } + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } +} + +impl Service { pub fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw) -> Result { self.db.create_backup(user_id, backup_metadata) } diff --git a/src/service/manager.rs b/src/service/manager.rs new file mode 100644 index 00000000..af59b4a4 --- /dev/null +++ b/src/service/manager.rs @@ -0,0 +1,156 @@ +use std::{panic::AssertUnwindSafe, sync::Arc, time::Duration}; + +use conduit::{debug, debug_warn, error, trace, utils::time, warn, Err, Error, Result, Server}; +use futures_util::FutureExt; +use tokio::{ + sync::{Mutex, MutexGuard}, + task::{JoinHandle, JoinSet}, + time::sleep, +}; + +use crate::{service::Service, Services}; + +pub(crate) struct Manager { + manager: Mutex>>>, + workers: Mutex, + server: Arc, + services: &'static Services, +} + +type Workers = JoinSet; +type WorkerResult = (Arc, Result<()>); +type WorkersLocked<'a> = MutexGuard<'a, Workers>; + +const RESTART_DELAY_MS: u64 = 2500; + +impl Manager { + pub(super) fn new(services: &Services) -> Arc { + Arc::new(Self { + manager: Mutex::new(None), + workers: Mutex::new(JoinSet::new()), + server: services.server.clone(), + services: crate::services(), + }) + } + + pub(super) async fn poll(&self) -> Result<()> { + if let Some(manager) = &mut *self.manager.lock().await { + trace!("Polling service manager..."); + return manager.await?; + } + + Ok(()) + } + + pub(super) async fn start(self: Arc) -> Result<()> { + let mut workers = self.workers.lock().await; + + debug!("Starting service manager..."); + let self_ = self.clone(); + _ = self.manager.lock().await.insert( + self.server + .runtime() + .spawn(async move { self_.worker().await }), + ); + + debug!("Starting service workers..."); + for service in self.services.service.values() { + self.start_worker(&mut workers, service).await?; + } + + Ok(()) + } + + pub(super) async fn stop(&self) { + if let Some(manager) = self.manager.lock().await.take() { + debug!("Waiting for service manager..."); + if let Err(e) = manager.await { + error!("Manager shutdown error: {e:?}"); + } + } + } + + async fn worker(&self) -> Result<()> { + loop { + let mut workers = self.workers.lock().await; + tokio::select! { + result = workers.join_next() => match result { + Some(Ok(result)) => self.handle_result(&mut workers, result).await?, + Some(Err(error)) => self.handle_abort(&mut workers, Error::from(error)).await?, + None => break, + } + } + } + + debug!("Worker manager finished"); + Ok(()) + } + + async fn handle_abort(&self, _workers: &mut WorkersLocked<'_>, error: Error) -> Result<()> { + // not supported until service can be associated with abort + unimplemented!("unexpected worker task abort {error:?}"); + } + + async fn handle_result(&self, workers: &mut WorkersLocked<'_>, result: WorkerResult) -> Result<()> { + let (service, result) = result; + match result { + Ok(()) => self.handle_finished(workers, &service).await, + Err(error) => self.handle_error(workers, &service, error).await, + } + } + + async fn handle_finished(&self, _workers: &mut WorkersLocked<'_>, service: &Arc) -> Result<()> { + debug!("service {:?} worker finished", service.name()); + Ok(()) + } + + async fn handle_error( + &self, workers: &mut WorkersLocked<'_>, service: &Arc, error: Error, + ) -> Result<()> { + let name = service.name(); + error!("service {name:?} aborted: {error}"); + + if !self.server.running() { + debug_warn!("service {name:?} error ignored on shutdown."); + return Ok(()); + } + + if !error.is_panic() { + return Err(error); + } + + let delay = Duration::from_millis(RESTART_DELAY_MS); + warn!("service {name:?} worker restarting after {} delay", time::pretty(delay)); + sleep(delay).await; + + self.start_worker(workers, service).await + } + + /// Start the worker in a task for the service. + async fn start_worker(&self, workers: &mut WorkersLocked<'_>, service: &Arc) -> Result<()> { + if !self.server.running() { + return Err!("Service {:?} worker not starting during server shutdown.", service.name()); + } + + debug!("Service {:?} worker starting...", service.name()); + workers.spawn_on(worker(service.clone()), self.server.runtime()); + + Ok(()) + } +} + +/// Base frame for service worker. This runs in a tokio::task. All errors and +/// panics from the worker are caught and returned cleanly. The JoinHandle +/// should never error with a panic, and if so it should propagate, but it may +/// error with an Abort which the manager should handle along with results to +/// determine if the worker should be restarted. +async fn worker(service: Arc) -> WorkerResult { + let service_ = Arc::clone(&service); + let result = AssertUnwindSafe(service_.worker()) + .catch_unwind() + .await + .map_err(Error::from_panic); + + // flattens JoinError for panic into worker's Error + (service, result.unwrap_or_else(Err)) +} diff --git a/src/service/media/data.rs b/src/service/media/data.rs index 4cabf167..e5856bbf 100644 --- a/src/service/media/data.rs +++ b/src/service/media/data.rs @@ -12,6 +12,13 @@ pub(crate) struct Data { url_previews: Arc, } +#[derive(Debug)] +pub(super) struct Metadata { + pub(super) content_disposition: Option, + pub(super) content_type: Option, + pub(super) key: Vec, +} + impl Data { pub(super) fn new(db: &Arc) -> Self { Self { @@ -104,9 +111,7 @@ impl Data { Ok(keys) } - pub(super) fn search_file_metadata( - &self, mxc: &str, width: u32, height: u32, - ) -> Result<(Option, Option, Vec)> { + pub(super) fn search_file_metadata(&self, mxc: &str, width: u32, height: u32) -> Result { let mut prefix = mxc.as_bytes().to_vec(); prefix.push(0xFF); prefix.extend_from_slice(&width.to_be_bytes()); @@ -141,13 +146,19 @@ impl Data { .map_err(|_| Error::bad_database("Content Disposition in mediaid_file is invalid unicode."))?, ) }; - Ok((content_disposition, content_type, key)) + + Ok(Metadata { + content_disposition, + content_type, + key, + }) } /// Gets all the media keys in our database (this includes all the metadata /// associated with it such as width, height, content-type, etc) pub(crate) fn get_all_media_keys(&self) -> Vec> { self.mediaid_file.iter().map(|(key, _)| key).collect() } + #[inline] pub(super) fn remove_url_preview(&self, url: &str) -> Result<()> { self.url_previews.remove(url.as_bytes()) } pub(super) fn set_url_preview( diff --git a/src/service/media/mod.rs b/src/service/media/mod.rs index 0323bfc5..1638235b 100644 --- a/src/service/media/mod.rs +++ b/src/service/media/mod.rs @@ -1,13 +1,13 @@ mod data; mod tests; +mod thumbnail; -use std::{collections::HashMap, io::Cursor, path::PathBuf, sync::Arc, time::SystemTime}; +use std::{collections::HashMap, path::PathBuf, sync::Arc, time::SystemTime}; +use async_trait::async_trait; use base64::{engine::general_purpose, Engine as _}; -use conduit::{debug, debug_error, error, utils, Error, Result, Server}; -use data::Data; -use database::Database; -use image::imageops::FilterType; +use conduit::{debug, debug_error, err, error, utils, Err, Result, Server}; +use data::{Data, Metadata}; use ruma::{OwnedMxcUri, OwnedUserId}; use serde::Serialize; use tokio::{ @@ -20,10 +20,9 @@ use crate::services; #[derive(Debug)] pub struct FileMeta { - #[allow(dead_code)] - pub content_disposition: Option, + pub content: Option>, pub content_type: Option, - pub file: Vec, + pub content_disposition: Option, } #[derive(Serialize, Default)] @@ -48,15 +47,26 @@ pub struct Service { pub url_preview_mutex: RwLock>>>, } -impl Service { - pub fn build(server: &Arc, db: &Arc) -> Result { - Ok(Self { - server: server.clone(), - db: Data::new(db), +#[async_trait] +impl crate::Service for Service { + fn build(args: crate::Args<'_>) -> Result> { + Ok(Arc::new(Self { + server: args.server.clone(), + db: Data::new(args.db), url_preview_mutex: RwLock::new(HashMap::new()), - }) + })) } + async fn worker(self: Arc) -> Result<()> { + self.create_media_dir().await?; + + Ok(()) + } + + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } +} + +impl Service { /// Uploads a file. pub async fn create( &self, sender_user: Option, mxc: &str, content_disposition: Option<&str>, @@ -90,47 +100,30 @@ impl Service { Ok(()) } else { - error!("Failed to find any media keys for MXC \"{mxc}\" in our database (MXC does not exist)"); - Err(Error::bad_database( - "Failed to find any media keys for the provided MXC in our database (MXC does not exist)", - )) + Err!(Database(error!( + "Failed to find any media keys for MXC {mxc:?} in our database." + ))) } } - /// Uploads or replaces a file thumbnail. - #[allow(clippy::too_many_arguments)] - pub async fn upload_thumbnail( - &self, sender_user: Option, mxc: &str, content_disposition: Option<&str>, - content_type: Option<&str>, width: u32, height: u32, file: &[u8], - ) -> Result<()> { - let key = if let Some(user) = sender_user { - self.db - .create_file_metadata(Some(user.as_str()), mxc, width, height, content_disposition, content_type)? - } else { - self.db - .create_file_metadata(None, mxc, width, height, content_disposition, content_type)? - }; - - //TODO: Dangling metadata in database if creation fails - let mut f = self.create_media_file(&key).await?; - f.write_all(file).await?; - - Ok(()) - } - /// Downloads a file. pub async fn get(&self, mxc: &str) -> Result> { - if let Ok((content_disposition, content_type, key)) = self.db.search_file_metadata(mxc, 0, 0) { - let mut file = Vec::new(); + if let Ok(Metadata { + content_disposition, + content_type, + key, + }) = self.db.search_file_metadata(mxc, 0, 0) + { + let mut content = Vec::new(); let path = self.get_media_file(&key); BufReader::new(fs::File::open(path).await?) - .read_to_end(&mut file) + .read_to_end(&mut content) .await?; Ok(Some(FileMeta { - content_disposition, + content: Some(content), content_type, - file, + content_disposition, })) } else { Ok(None) @@ -143,23 +136,16 @@ impl Service { let all_keys = self.db.get_all_media_keys(); let user_duration: SystemTime = match cyborgtime::parse_duration(&time) { - Ok(duration) => { - debug!("Parsed duration: {:?}", duration); - debug!("System time now: {:?}", SystemTime::now()); - SystemTime::now().checked_sub(duration).ok_or_else(|| { - Error::bad_database("Duration specified is not valid against the current system time") - })? - }, - Err(e) => { - error!("Failed to parse user-specified time duration: {}", e); - return Err(Error::bad_database("Failed to parse user-specified time duration.")); - }, + Err(e) => return Err!(Database(error!("Failed to parse specified time duration: {e}"))), + Ok(duration) => SystemTime::now() + .checked_sub(duration) + .ok_or(err!(Arithmetic("Duration {duration:?} is too large")))?, }; let mut remote_mxcs: Vec = vec![]; for key in all_keys { - debug!("Full MXC key from database: {:?}", key); + debug!("Full MXC key from database: {key:?}"); // we need to get the MXC URL from the first part of the key (the first 0xff / // 255 push). this is all necessary because of conduit using magic keys for @@ -168,24 +154,19 @@ impl Service { let mxc = parts .next() .map(|bytes| { - utils::string_from_bytes(bytes).map_err(|e| { - error!("Failed to parse MXC unicode bytes from our database: {}", e); - Error::bad_database("Failed to parse MXC unicode bytes from our database") - }) + utils::string_from_bytes(bytes) + .map_err(|e| err!(Database(error!("Failed to parse MXC unicode bytes from our database: {e}")))) }) .transpose()?; let Some(mxc_s) = mxc else { - return Err(Error::bad_database( - "Parsed MXC URL unicode bytes from database but still is None", - )); + return Err!(Database("Parsed MXC URL unicode bytes from database but still is None")); }; - debug!("Parsed MXC key to URL: {}", mxc_s); - + debug!("Parsed MXC key to URL: {mxc_s}"); let mxc = OwnedMxcUri::from(mxc_s); if mxc.server_name() == Ok(services().globals.server_name()) { - debug!("Ignoring local media MXC: {}", mxc); + debug!("Ignoring local media MXC: {mxc}"); // ignore our own MXC URLs as this would be local media. continue; } @@ -204,14 +185,14 @@ impl Service { }, Err(err) => { if force { - error!("Could not delete MXC path {:?}: {:?}. Skipping...", path, err); + error!("Could not delete MXC path {path:?}: {err:?}. Skipping..."); continue; } return Err(err.into()); }, }; - debug!("File created at: {:?}", file_created_at); + debug!("File created at: {file_created_at:?}"); if file_created_at <= user_duration { debug!("File is within user duration, pushing to list of file paths and keys to delete."); remote_mxcs.push(mxc.to_string()); @@ -221,15 +202,12 @@ impl Service { debug!( "Finished going through all our media in database for eligible keys to delete, checking if these are empty" ); - if remote_mxcs.is_empty() { - return Err(Error::bad_database("Did not found any eligible MXCs to delete.")); + return Err!(Database("Did not found any eligible MXCs to delete.")); } - debug!("Deleting media now in the past \"{:?}\".", user_duration); - + debug!("Deleting media now in the past {user_duration:?}."); let mut deletion_count: usize = 0; - for mxc in remote_mxcs { debug!("Deleting MXC {mxc} from database and filesystem"); self.delete(&mxc).await?; @@ -239,139 +217,6 @@ impl Service { Ok(deletion_count) } - /// Returns width, height of the thumbnail and whether it should be cropped. - /// Returns None when the server should send the original file. - pub fn thumbnail_properties(&self, width: u32, height: u32) -> Option<(u32, u32, bool)> { - match (width, height) { - (0..=32, 0..=32) => Some((32, 32, true)), - (0..=96, 0..=96) => Some((96, 96, true)), - (0..=320, 0..=240) => Some((320, 240, false)), - (0..=640, 0..=480) => Some((640, 480, false)), - (0..=800, 0..=600) => Some((800, 600, false)), - _ => None, - } - } - - /// Downloads a file's thumbnail. - /// - /// Here's an example on how it works: - /// - /// - Client requests an image with width=567, height=567 - /// - Server rounds that up to (800, 600), so it doesn't have to save too - /// many thumbnails - /// - Server rounds that up again to (958, 600) to fix the aspect ratio - /// (only for width,height>96) - /// - Server creates the thumbnail and sends it to the user - /// - /// For width,height <= 96 the server uses another thumbnailing algorithm - /// which crops the image afterwards. - pub async fn get_thumbnail(&self, mxc: &str, width: u32, height: u32) -> Result> { - let (width, height, crop) = self - .thumbnail_properties(width, height) - .unwrap_or((0, 0, false)); // 0, 0 because that's the original file - - if let Ok((content_disposition, content_type, key)) = self.db.search_file_metadata(mxc, width, height) { - // Using saved thumbnail - let mut file = Vec::new(); - let path = self.get_media_file(&key); - fs::File::open(path).await?.read_to_end(&mut file).await?; - - Ok(Some(FileMeta { - content_disposition, - content_type, - file: file.clone(), - })) - } else if let Ok((content_disposition, content_type, key)) = self.db.search_file_metadata(mxc, 0, 0) { - // Generate a thumbnail - let mut file = Vec::new(); - let path = self.get_media_file(&key); - fs::File::open(path).await?.read_to_end(&mut file).await?; - - if let Ok(image) = image::load_from_memory(&file) { - let original_width = image.width(); - let original_height = image.height(); - if width > original_width || height > original_height { - return Ok(Some(FileMeta { - content_disposition, - content_type, - file: file.clone(), - })); - } - - let thumbnail = if crop { - image.resize_to_fill(width, height, FilterType::CatmullRom) - } else { - let (exact_width, exact_height) = { - // Copied from image::dynimage::resize_dimensions - // - // https://github.com/image-rs/image/blob/6edf8ae492c4bb1dacb41da88681ea74dab1bab3/src/math/utils.rs#L5-L11 - // Calculates the width and height an image should be - // resized to. This preserves aspect ratio, and based - // on the `fill` parameter will either fill the - // dimensions to fit inside the smaller constraint - // (will overflow the specified bounds on one axis to - // preserve aspect ratio), or will shrink so that both - // dimensions are completely contained within the given - // `width` and `height`, with empty space on one axis. - let ratio = u64::from(original_width) * u64::from(height); - let nratio = u64::from(width) * u64::from(original_height); - - let use_width = nratio <= ratio; - let intermediate = if use_width { - u64::from(original_height) * u64::from(width) / u64::from(original_width) - } else { - u64::from(original_width) * u64::from(height) / u64::from(original_height) - }; - if use_width { - if u32::try_from(intermediate).is_ok() { - (width, intermediate as u32) - } else { - ((u64::from(width) * u64::from(u32::MAX) / intermediate) as u32, u32::MAX) - } - } else if u32::try_from(intermediate).is_ok() { - (intermediate as u32, height) - } else { - (u32::MAX, (u64::from(height) * u64::from(u32::MAX) / intermediate) as u32) - } - }; - - image.thumbnail_exact(exact_width, exact_height) - }; - - let mut thumbnail_bytes = Vec::new(); - thumbnail.write_to(&mut Cursor::new(&mut thumbnail_bytes), image::ImageFormat::Png)?; - - // Save thumbnail in database so we don't have to generate it again next time - let thumbnail_key = self.db.create_file_metadata( - None, - mxc, - width, - height, - content_disposition.as_deref(), - content_type.as_deref(), - )?; - - let mut f = self.create_media_file(&thumbnail_key).await?; - f.write_all(&thumbnail_bytes).await?; - - Ok(Some(FileMeta { - content_disposition, - content_type, - file: thumbnail_bytes.clone(), - })) - } else { - // Couldn't parse file to generate thumbnail, send original - Ok(Some(FileMeta { - content_disposition, - content_type, - file: file.clone(), - })) - } - } else { - Ok(None) - } - } - pub async fn get_url_preview(&self, url: &str) -> Option { self.db.get_url_preview(url) } /// TODO: use this? diff --git a/src/service/media/thumbnail.rs b/src/service/media/thumbnail.rs new file mode 100644 index 00000000..01bf73f6 --- /dev/null +++ b/src/service/media/thumbnail.rs @@ -0,0 +1,174 @@ +use std::{cmp, io::Cursor, num::Saturating as Sat}; + +use conduit::{checked, Result}; +use image::{imageops::FilterType, DynamicImage}; +use ruma::OwnedUserId; +use tokio::{ + fs, + io::{AsyncReadExt, AsyncWriteExt}, +}; + +use super::{data::Metadata, FileMeta}; + +impl super::Service { + /// Uploads or replaces a file thumbnail. + #[allow(clippy::too_many_arguments)] + pub async fn upload_thumbnail( + &self, sender_user: Option, mxc: &str, content_disposition: Option<&str>, + content_type: Option<&str>, width: u32, height: u32, file: &[u8], + ) -> Result<()> { + let key = if let Some(user) = sender_user { + self.db + .create_file_metadata(Some(user.as_str()), mxc, width, height, content_disposition, content_type)? + } else { + self.db + .create_file_metadata(None, mxc, width, height, content_disposition, content_type)? + }; + + //TODO: Dangling metadata in database if creation fails + let mut f = self.create_media_file(&key).await?; + f.write_all(file).await?; + + Ok(()) + } + + /// Downloads a file's thumbnail. + /// + /// Here's an example on how it works: + /// + /// - Client requests an image with width=567, height=567 + /// - Server rounds that up to (800, 600), so it doesn't have to save too + /// many thumbnails + /// - Server rounds that up again to (958, 600) to fix the aspect ratio + /// (only for width,height>96) + /// - Server creates the thumbnail and sends it to the user + /// + /// For width,height <= 96 the server uses another thumbnailing algorithm + /// which crops the image afterwards. + #[tracing::instrument(skip(self), name = "thumbnail", level = "debug")] + pub async fn get_thumbnail(&self, mxc: &str, width: u32, height: u32) -> Result> { + // 0, 0 because that's the original file + let (width, height, crop) = thumbnail_properties(width, height).unwrap_or((0, 0, false)); + + if let Ok(metadata) = self.db.search_file_metadata(mxc, width, height) { + self.get_thumbnail_saved(metadata).await + } else if let Ok(metadata) = self.db.search_file_metadata(mxc, 0, 0) { + self.get_thumbnail_generate(mxc, width, height, crop, metadata) + .await + } else { + Ok(None) + } + } + + /// Using saved thumbnail + #[tracing::instrument(skip(self), name = "saved", level = "debug")] + async fn get_thumbnail_saved(&self, data: Metadata) -> Result> { + let mut content = Vec::new(); + let path = self.get_media_file(&data.key); + fs::File::open(path) + .await? + .read_to_end(&mut content) + .await?; + + Ok(Some(into_filemeta(data, content))) + } + + /// Generate a thumbnail + #[tracing::instrument(skip(self), name = "generate", level = "debug")] + async fn get_thumbnail_generate( + &self, mxc: &str, width: u32, height: u32, crop: bool, data: Metadata, + ) -> Result> { + let mut content = Vec::new(); + let path = self.get_media_file(&data.key); + fs::File::open(path) + .await? + .read_to_end(&mut content) + .await?; + + let Ok(image) = image::load_from_memory(&content) else { + // Couldn't parse file to generate thumbnail, send original + return Ok(Some(into_filemeta(data, content))); + }; + + if width > image.width() || height > image.height() { + return Ok(Some(into_filemeta(data, content))); + } + + let mut thumbnail_bytes = Vec::new(); + let thumbnail = thumbnail_generate(&image, width, height, crop)?; + thumbnail.write_to(&mut Cursor::new(&mut thumbnail_bytes), image::ImageFormat::Png)?; + + // Save thumbnail in database so we don't have to generate it again next time + let thumbnail_key = self.db.create_file_metadata( + None, + mxc, + width, + height, + data.content_disposition.as_deref(), + data.content_type.as_deref(), + )?; + + let mut f = self.create_media_file(&thumbnail_key).await?; + f.write_all(&thumbnail_bytes).await?; + + Ok(Some(into_filemeta(data, thumbnail_bytes))) + } +} + +fn thumbnail_generate(image: &DynamicImage, width: u32, height: u32, crop: bool) -> Result { + let thumbnail = if crop { + image.resize_to_fill(width, height, FilterType::CatmullRom) + } else { + let (exact_width, exact_height) = thumbnail_dimension(image, width, height)?; + image.thumbnail_exact(exact_width, exact_height) + }; + + Ok(thumbnail) +} + +fn thumbnail_dimension(image: &DynamicImage, width: u32, height: u32) -> Result<(u32, u32)> { + let image_width = image.width(); + let image_height = image.height(); + + let width = cmp::min(width, image_width); + let height = cmp::min(height, image_height); + + let use_width = Sat(width) * Sat(image_height) < Sat(height) * Sat(image_width); + + let x = if use_width { + let dividend = (Sat(height) * Sat(image_width)).0; + checked!(dividend / image_height)? + } else { + width + }; + + let y = if !use_width { + let dividend = (Sat(width) * Sat(image_height)).0; + checked!(dividend / image_width)? + } else { + height + }; + + Ok((x, y)) +} + +/// Returns width, height of the thumbnail and whether it should be cropped. +/// Returns None when the server should send the original file. +fn thumbnail_properties(width: u32, height: u32) -> Option<(u32, u32, bool)> { + match (width, height) { + (0..=32, 0..=32) => Some((32, 32, true)), + (0..=96, 0..=96) => Some((96, 96, true)), + (0..=320, 0..=240) => Some((320, 240, false)), + (0..=640, 0..=480) => Some((640, 480, false)), + (0..=800, 0..=600) => Some((800, 600, false)), + _ => None, + } +} + +fn into_filemeta(data: Metadata, content: Vec) -> FileMeta { + FileMeta { + content: Some(content), + content_type: data.content_type, + content_disposition: data.content_disposition, + } +} diff --git a/src/service/mod.rs b/src/service/mod.rs index c80c9862..81e0be3b 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -1,4 +1,7 @@ -pub mod pdu; +#![allow(refining_impl_trait)] + +mod manager; +mod service; pub mod services; pub mod account_data; @@ -13,6 +16,7 @@ pub mod rooms; pub mod sending; pub mod transaction_ids; pub mod uiaa; +pub mod updates; pub mod users; extern crate conduit_core as conduit; @@ -20,12 +24,13 @@ extern crate conduit_database as database; use std::sync::{Arc, RwLock}; -pub(crate) use conduit::{config, debug_error, debug_info, debug_warn, utils, Config, Error, PduCount, Result, Server}; +pub(crate) use conduit::{config, debug_error, debug_warn, utils, Config, Error, Result, Server}; +pub use conduit::{pdu, PduBuilder, PduCount, PduEvent}; use database::Database; +pub(crate) use service::{Args, Service}; pub use crate::{ globals::{server_is_ours, user_is_local}, - pdu::PduEvent, services::Services, }; @@ -34,16 +39,17 @@ conduit::mod_dtor! {} static SERVICES: RwLock> = RwLock::new(None); -#[allow(clippy::let_underscore_must_use)] -pub async fn init(server: &Arc) -> Result<()> { +pub async fn start(server: &Arc) -> Result<()> { let d = Arc::new(Database::open(server).await?); - let s = Box::new(Services::build(server.clone(), d.clone()).await?); + let s = Box::new(Services::build(server.clone(), d)?); _ = SERVICES.write().expect("write locked").insert(Box::leak(s)); - Ok(()) + services().start().await } -pub fn fini() { +pub async fn stop() { + services().stop().await; + // Deactivate services(). Any further use will panic the caller. let s = SERVICES .write() diff --git a/src/service/presence/mod.rs b/src/service/presence/mod.rs index 5065efdd..254304ba 100644 --- a/src/service/presence/mod.rs +++ b/src/service/presence/mod.rs @@ -2,9 +2,9 @@ mod data; use std::{sync::Arc, time::Duration}; -use conduit::{debug, error, utils, Error, Result, Server}; +use async_trait::async_trait; +use conduit::{checked, debug, error, utils, Error, Result}; use data::Data; -use database::Database; use futures_util::{stream::FuturesUnordered, StreamExt}; use ruma::{ events::presence::{PresenceEvent, PresenceEventContent}, @@ -12,7 +12,7 @@ use ruma::{ OwnedUserId, UInt, UserId, }; use serde::{Deserialize, Serialize}; -use tokio::{sync::Mutex, task::JoinHandle, time::sleep}; +use tokio::{sync::Mutex, time::sleep}; use crate::{services, user_is_local}; @@ -77,51 +77,58 @@ pub struct Service { pub db: Data, pub timer_sender: loole::Sender<(OwnedUserId, Duration)>, timer_receiver: Mutex>, - handler_join: Mutex>>, timeout_remote_users: bool, + idle_timeout: u64, + offline_timeout: u64, } -impl Service { - pub fn build(server: &Arc, db: &Arc) -> Result> { - let config = &server.config; +#[async_trait] +impl crate::Service for Service { + fn build(args: crate::Args<'_>) -> Result> { + let config = &args.server.config; + let idle_timeout_s = config.presence_idle_timeout_s; + let offline_timeout_s = config.presence_offline_timeout_s; let (timer_sender, timer_receiver) = loole::unbounded(); Ok(Arc::new(Self { - db: Data::new(db), + db: Data::new(args.db), timer_sender, timer_receiver: Mutex::new(timer_receiver), - handler_join: Mutex::new(None), timeout_remote_users: config.presence_timeout_remote_users, + idle_timeout: checked!(idle_timeout_s * 1_000)?, + offline_timeout: checked!(offline_timeout_s * 1_000)?, })) } - pub async fn start_handler(self: &Arc) { - let self_ = Arc::clone(self); - let handle = services().server.runtime().spawn(async move { - self_ - .handler() - .await - .expect("Failed to start presence handler"); - }); - - _ = self.handler_join.lock().await.insert(handle); - } - - pub async fn close(&self) { - self.interrupt(); - if let Some(handler_join) = self.handler_join.lock().await.take() { - if let Err(e) = handler_join.await { - error!("Failed to shutdown: {e:?}"); + async fn worker(self: Arc) -> Result<()> { + let mut presence_timers = FuturesUnordered::new(); + let receiver = self.timer_receiver.lock().await; + loop { + debug_assert!(!receiver.is_closed(), "channel error"); + tokio::select! { + Some(user_id) = presence_timers.next() => self.process_presence_timer(&user_id)?, + event = receiver.recv_async() => match event { + Err(_e) => return Ok(()), + Ok((user_id, timeout)) => { + debug!("Adding timer {}: {user_id} timeout:{timeout:?}", presence_timers.len()); + presence_timers.push(presence_timer(user_id, timeout)); + }, + }, } } } - pub fn interrupt(&self) { + fn interrupt(&self) { if !self.timer_sender.is_closed() { self.timer_sender.close(); } } + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } +} + +impl Service { /// Returns the latest presence event for the given user. + #[inline] pub fn get_presence(&self, user_id: &UserId) -> Result> { if let Some((_, presence)) = self.db.get_presence(user_id)? { Ok(Some(presence)) @@ -198,26 +205,39 @@ impl Service { /// Returns the most recent presence updates that happened after the event /// with id `since`. + #[inline] pub fn presence_since(&self, since: u64) -> Box)> + '_> { self.db.presence_since(since) } - async fn handler(&self) -> Result<()> { - let mut presence_timers = FuturesUnordered::new(); - let receiver = self.timer_receiver.lock().await; - loop { - debug_assert!(!receiver.is_closed(), "channel error"); - tokio::select! { - Some(user_id) = presence_timers.next() => process_presence_timer(&user_id)?, - event = receiver.recv_async() => match event { - Err(_e) => return Ok(()), - Ok((user_id, timeout)) => { - debug!("Adding timer {}: {user_id} timeout:{timeout:?}", presence_timers.len()); - presence_timers.push(presence_timer(user_id, timeout)); - }, - }, - } + fn process_presence_timer(&self, user_id: &OwnedUserId) -> Result<()> { + let mut presence_state = PresenceState::Offline; + let mut last_active_ago = None; + let mut status_msg = None; + + let presence_event = self.get_presence(user_id)?; + + if let Some(presence_event) = presence_event { + presence_state = presence_event.content.presence; + last_active_ago = presence_event.content.last_active_ago; + status_msg = presence_event.content.status_msg; } + + let new_state = match (&presence_state, last_active_ago.map(u64::from)) { + (PresenceState::Online, Some(ago)) if ago >= self.idle_timeout => Some(PresenceState::Unavailable), + (PresenceState::Unavailable, Some(ago)) if ago >= self.offline_timeout => Some(PresenceState::Offline), + _ => None, + }; + + debug!( + "Processed presence timer for user '{user_id}': Old state = {presence_state}, New state = {new_state:?}" + ); + + if let Some(new_state) = new_state { + self.set_presence(user_id, &new_state, Some(false), last_active_ago, status_msg)?; + } + + Ok(()) } } @@ -226,36 +246,3 @@ async fn presence_timer(user_id: OwnedUserId, timeout: Duration) -> OwnedUserId user_id } - -fn process_presence_timer(user_id: &OwnedUserId) -> Result<()> { - let idle_timeout = services().globals.config.presence_idle_timeout_s * 1_000; - let offline_timeout = services().globals.config.presence_offline_timeout_s * 1_000; - - let mut presence_state = PresenceState::Offline; - let mut last_active_ago = None; - let mut status_msg = None; - - let presence_event = services().presence.get_presence(user_id)?; - - if let Some(presence_event) = presence_event { - presence_state = presence_event.content.presence; - last_active_ago = presence_event.content.last_active_ago; - status_msg = presence_event.content.status_msg; - } - - let new_state = match (&presence_state, last_active_ago.map(u64::from)) { - (PresenceState::Online, Some(ago)) if ago >= idle_timeout => Some(PresenceState::Unavailable), - (PresenceState::Unavailable, Some(ago)) if ago >= offline_timeout => Some(PresenceState::Offline), - _ => None, - }; - - debug!("Processed presence timer for user '{user_id}': Old state = {presence_state}, New state = {new_state:?}"); - - if let Some(new_state) = new_state { - services() - .presence - .set_presence(user_id, &new_state, Some(false), last_active_ago, status_msg)?; - } - - Ok(()) -} diff --git a/src/service/pusher/mod.rs b/src/service/pusher/mod.rs index 28280273..ea48ea7c 100644 --- a/src/service/pusher/mod.rs +++ b/src/service/pusher/mod.rs @@ -3,9 +3,8 @@ mod data; use std::{fmt::Debug, mem, sync::Arc}; use bytes::BytesMut; -use conduit::{debug_info, info, trace, warn, Error, Result, Server}; +use conduit::{debug_info, info, trace, warn, Error, Result}; use data::Data; -use database::Database; use ipaddress::IPAddress; use ruma::{ api::{ @@ -30,13 +29,17 @@ pub struct Service { db: Data, } -impl Service { - pub fn build(_server: &Arc, db: &Arc) -> Result { - Ok(Self { - db: Data::new(db), - }) +impl crate::Service for Service { + fn build(args: crate::Args<'_>) -> Result> { + Ok(Arc::new(Self { + db: Data::new(args.db), + })) } + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } +} + +impl Service { pub fn set_pusher(&self, sender: &UserId, pusher: &set_pusher::v3::PusherAction) -> Result<()> { self.db.set_pusher(sender, pusher) } @@ -183,7 +186,7 @@ impl Service { Ok(()) } - #[tracing::instrument(skip(self, user, ruleset, pdu))] + #[tracing::instrument(skip(self, user, ruleset, pdu), level = "debug")] pub fn get_actions<'a>( &self, user: &UserId, ruleset: &'a Ruleset, power_levels: &RoomPowerLevelsEventContent, pdu: &Raw, room_id: &RoomId, diff --git a/src/service/rooms/alias/data.rs b/src/service/rooms/alias/data.rs index eb99e543..302c21ae 100644 --- a/src/service/rooms/alias/data.rs +++ b/src/service/rooms/alias/data.rs @@ -84,7 +84,7 @@ impl Data { pub(super) fn local_aliases_for_room<'a>( &'a self, room_id: &RoomId, - ) -> Box> + 'a> { + ) -> Box> + 'a + Send> { let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xFF); diff --git a/src/service/rooms/alias/mod.rs b/src/service/rooms/alias/mod.rs index 9f66cf88..792f5c98 100644 --- a/src/service/rooms/alias/mod.rs +++ b/src/service/rooms/alias/mod.rs @@ -3,9 +3,8 @@ mod remote; use std::sync::Arc; -use conduit::{Error, Result, Server}; +use conduit::{err, Error, Result}; use data::Data; -use database::Database; use ruma::{ api::{appservice, client::error::ErrorKind}, events::{ @@ -21,13 +20,17 @@ pub struct Service { db: Data, } -impl Service { - pub fn build(_server: &Arc, db: &Arc) -> Result { - Ok(Self { - db: Data::new(db), - }) +impl crate::Service for Service { + fn build(args: crate::Args<'_>) -> Result> { + Ok(Arc::new(Self { + db: Data::new(args.db), + })) } + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } +} + +impl Service { #[tracing::instrument(skip(self))] pub fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId, user_id: &UserId) -> Result<()> { if alias == services().globals.admin_alias && user_id != services().globals.server_user { @@ -86,19 +89,19 @@ impl Service { ) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result> { self.db.resolve_local_alias(alias) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn local_aliases_for_room<'a>( &'a self, room_id: &RoomId, - ) -> Box> + 'a> { + ) -> Box> + 'a + Send> { self.db.local_aliases_for_room(room_id) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn all_local_aliases<'a>(&'a self) -> Box> + 'a> { self.db.all_local_aliases() } @@ -168,7 +171,7 @@ impl Service { .rooms .alias .resolve_local_alias(room_alias)? - .ok_or_else(|| Error::bad_config("Room does not exist."))?, + .ok_or_else(|| err!(Request(NotFound("Room does not exist."))))?, )); } } diff --git a/src/service/rooms/auth_chain/data.rs b/src/service/rooms/auth_chain/data.rs index a5771f4a..4e468234 100644 --- a/src/service/rooms/auth_chain/data.rs +++ b/src/service/rooms/auth_chain/data.rs @@ -3,7 +3,7 @@ use std::{ sync::{Arc, Mutex}, }; -use conduit::{utils, Result, Server}; +use conduit::{utils, utils::math::usize_from_f64, Result, Server}; use database::{Database, Map}; use lru_cache::LruCache; @@ -16,7 +16,7 @@ impl Data { pub(super) fn new(server: &Arc, db: &Arc) -> Self { let config = &server.config; let cache_size = f64::from(config.auth_chain_cache_capacity); - let cache_size = (cache_size * config.conduit_cache_capacity_modifier) as usize; + let cache_size = usize_from_f64(cache_size * config.cache_capacity_modifier).expect("valid cache size"); Self { shorteventid_authchain: db["shorteventid_authchain"].clone(), auth_chain_cache: Mutex::new(LruCache::new(cache_size)), diff --git a/src/service/rooms/auth_chain/mod.rs b/src/service/rooms/auth_chain/mod.rs index 6b14fda0..4e8c7bb2 100644 --- a/src/service/rooms/auth_chain/mod.rs +++ b/src/service/rooms/auth_chain/mod.rs @@ -5,10 +5,9 @@ use std::{ sync::Arc, }; -use conduit::{debug, error, trace, warn, Error, Result, Server}; +use conduit::{debug, error, trace, validated, warn, Err, Result}; use data::Data; -use database::Database; -use ruma::{api::client::error::ErrorKind, EventId, RoomId}; +use ruma::{EventId, RoomId}; use crate::services; @@ -16,13 +15,17 @@ pub struct Service { db: Data, } -impl Service { - pub fn build(server: &Arc, db: &Arc) -> Result { - Ok(Self { - db: Data::new(server, db), - }) +impl crate::Service for Service { + fn build(args: crate::Args<'_>) -> Result> { + Ok(Arc::new(Self { + db: Data::new(args.server, args.db), + })) } + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } +} + +impl Service { pub async fn event_ids_iter<'a>( &self, room_id: &RoomId, starting_events_: Vec>, ) -> Result> + 'a> { @@ -45,15 +48,16 @@ impl Service { let started = std::time::Instant::now(); let mut buckets = [BUCKET; NUM_BUCKETS]; - for (i, short) in services() + for (i, &short) in services() .rooms .short .multi_get_or_create_shorteventid(starting_events)? .iter() .enumerate() { - let bucket = short % NUM_BUCKETS as u64; - buckets[bucket as usize].insert((*short, starting_events[i])); + let bucket: usize = short.try_into()?; + let bucket: usize = validated!(bucket % NUM_BUCKETS)?; + buckets[bucket].insert((short, starting_events[i])); } debug!( @@ -139,8 +143,11 @@ impl Service { match services().rooms.timeline.get_pdu(&event_id) { Ok(Some(pdu)) => { if pdu.room_id != room_id { - error!(?event_id, ?pdu, "auth event for incorrect room_id"); - return Err(Error::BadRequest(ErrorKind::forbidden(), "Evil event in db")); + return Err!(Request(Forbidden( + "auth event {event_id:?} for incorrect room {} which is not {}", + pdu.room_id, + room_id + ))); } for auth_event in &pdu.auth_events { let sauthevent = services() @@ -170,13 +177,13 @@ impl Service { self.db.get_cached_eventid_authchain(key) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn cache_auth_chain(&self, key: Vec, auth_chain: &HashSet) -> Result<()> { self.db .cache_auth_chain(key, auth_chain.iter().copied().collect::>()) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn cache_auth_chain_vec(&self, key: Vec, auth_chain: &Vec) -> Result<()> { self.db .cache_auth_chain(key, auth_chain.iter().copied().collect::>()) diff --git a/src/service/rooms/directory/mod.rs b/src/service/rooms/directory/mod.rs index 87c7cf92..23ec6b6b 100644 --- a/src/service/rooms/directory/mod.rs +++ b/src/service/rooms/directory/mod.rs @@ -2,9 +2,7 @@ mod data; use std::sync::Arc; -use conduit::Server; use data::Data; -use database::Database; use ruma::{OwnedRoomId, RoomId}; use crate::Result; @@ -13,22 +11,26 @@ pub struct Service { db: Data, } -impl Service { - pub fn build(_server: &Arc, db: &Arc) -> Result { - Ok(Self { - db: Data::new(db), - }) +impl crate::Service for Service { + fn build(args: crate::Args<'_>) -> Result> { + Ok(Arc::new(Self { + db: Data::new(args.db), + })) } - #[tracing::instrument(skip(self))] + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } +} + +impl Service { + #[tracing::instrument(skip(self), level = "debug")] pub fn set_public(&self, room_id: &RoomId) -> Result<()> { self.db.set_public(room_id) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn set_not_public(&self, room_id: &RoomId) -> Result<()> { self.db.set_not_public(room_id) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn is_public_room(&self, room_id: &RoomId) -> Result { self.db.is_public_room(room_id) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn public_rooms(&self) -> impl Iterator> + '_ { self.db.public_rooms() } } diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index ed09e2e6..6cb23b9f 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -2,15 +2,18 @@ mod parse_incoming_pdu; mod signing_keys; use std::{ - cmp, collections::{hash_map, BTreeMap, HashMap, HashSet}, + fmt::Write, pin::Pin, - sync::Arc, - time::{Duration, Instant}, + sync::{Arc, RwLock as StdRwLock}, + time::Instant, }; -use conduit::{debug_error, debug_info, Error, Result, Server}; -use database::Database; +use conduit::{ + debug, debug_error, debug_info, err, error, info, trace, + utils::{math::continue_exponential_backoff_secs, MutexMap}, + warn, Error, Result, +}; use futures_util::Future; pub use parse_incoming_pdu::parse_incoming_pdu; use ruma::{ @@ -27,15 +30,21 @@ use ruma::{ int, serde::Base64, state_res::{self, RoomVersion, StateMap}, - uint, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, OwnedUserId, RoomId, RoomVersionId, ServerName, + uint, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedRoomId, OwnedUserId, RoomId, + RoomVersionId, ServerName, }; use tokio::sync::RwLock; -use tracing::{debug, error, info, trace, warn}; use super::state_compressor::CompressedStateEvent; use crate::{pdu, services, PduEvent}; -pub struct Service; +pub struct Service { + pub federation_handletime: StdRwLock, + pub mutex_federation: RoomMutexMap, +} + +type RoomMutexMap = MutexMap; +type HandleTimeMap = HashMap; // We use some AsyncRecursiveType hacks here so we can call async funtion // recursively. @@ -45,9 +54,32 @@ type AsyncRecursiveCanonicalJsonVec<'a> = type AsyncRecursiveCanonicalJsonResult<'a> = AsyncRecursiveType<'a, Result<(Arc, BTreeMap)>>; -impl Service { - pub fn build(_server: &Arc, _db: &Arc) -> Result { Ok(Self {}) } +impl crate::Service for Service { + fn build(_args: crate::Args<'_>) -> Result> { + Ok(Arc::new(Self { + federation_handletime: HandleTimeMap::new().into(), + mutex_federation: RoomMutexMap::new(), + })) + } + fn memory_usage(&self, out: &mut dyn Write) -> Result<()> { + let mutex_federation = self.mutex_federation.len(); + writeln!(out, "federation_mutex: {mutex_federation}")?; + + let federation_handletime = self + .federation_handletime + .read() + .expect("locked for reading") + .len(); + writeln!(out, "federation_handletime: {federation_handletime}")?; + + Ok(()) + } + + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } +} + +impl Service { /// When receiving an event one needs to: /// 0. Check the server is in the room /// 1. Skip the PDU if we already know about it @@ -180,14 +212,14 @@ impl Service { .globals .bad_event_ratelimiter .write() - .await + .expect("locked") .entry((*prev_id).to_owned()) { hash_map::Entry::Vacant(e) => { e.insert((Instant::now(), 1)); }, hash_map::Entry::Occupied(mut e) => { - *e.get_mut() = (Instant::now(), e.get().1 + 1); + *e.get_mut() = (Instant::now(), e.get().1.saturating_add(1)); }, }; }, @@ -196,22 +228,18 @@ impl Service { // Done with prev events, now handling the incoming event let start_time = Instant::now(); - services() - .globals - .roomid_federationhandletime + self.federation_handletime .write() - .await + .expect("locked") .insert(room_id.to_owned(), (event_id.to_owned(), start_time)); let r = self .upgrade_outlier_to_timeline_pdu(incoming_pdu, val, &create_event, origin, room_id, pub_key_map) .await; - services() - .globals - .roomid_federationhandletime + self.federation_handletime .write() - .await + .expect("locked") .remove(&room_id.to_owned()); r @@ -245,18 +273,16 @@ impl Service { .globals .bad_event_ratelimiter .read() - .await + .expect("locked") .get(prev_id) { // Exponential backoff - const MAX_DURATION: Duration = Duration::from_secs(60 * 60 * 24); - let min_duration = cmp::min(MAX_DURATION, Duration::from_secs(5 * 60) * (*tries) * (*tries)); - let duration = time.elapsed(); - - if duration < min_duration { + const MIN_DURATION: u64 = 5 * 60; + const MAX_DURATION: u64 = 60 * 60 * 24; + if continue_exponential_backoff_secs(MIN_DURATION, MAX_DURATION, time.elapsed(), *tries) { debug!( - duration = ?duration, - min_duration = ?min_duration, + ?tries, + duration = ?time.elapsed(), "Backing off from prev_event" ); return Ok(()); @@ -270,21 +296,17 @@ impl Service { } let start_time = Instant::now(); - services() - .globals - .roomid_federationhandletime + self.federation_handletime .write() - .await + .expect("locked") .insert(room_id.to_owned(), ((*prev_id).to_owned(), start_time)); self.upgrade_outlier_to_timeline_pdu(pdu, json, create_event, origin, room_id, pub_key_map) .await?; - services() - .globals - .roomid_federationhandletime + self.federation_handletime .write() - .await + .expect("locked") .remove(&room_id.to_owned()); debug!( @@ -529,55 +551,50 @@ impl Service { // Soft fail check before doing state res debug!("Performing soft-fail check"); - let soft_fail = !state_res::event_auth::auth_check(&room_version, &incoming_pdu, None::, |k, s| { - auth_events.get(&(k.clone(), s.to_owned())) - }) - .map_err(|_e| Error::BadRequest(ErrorKind::forbidden(), "Auth check failed."))? - || incoming_pdu.kind == TimelineEventType::RoomRedaction - && match room_version_id { - RoomVersionId::V1 - | RoomVersionId::V2 - | RoomVersionId::V3 - | RoomVersionId::V4 - | RoomVersionId::V5 - | RoomVersionId::V6 - | RoomVersionId::V7 - | RoomVersionId::V8 - | RoomVersionId::V9 - | RoomVersionId::V10 => { - if let Some(redact_id) = &incoming_pdu.redacts { - !services().rooms.state_accessor.user_can_redact( - redact_id, - &incoming_pdu.sender, - &incoming_pdu.room_id, - true, - )? - } else { - false - } - }, - _ => { - let content = serde_json::from_str::(incoming_pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid content in redaction pdu."))?; + let soft_fail = { + use RoomVersionId::*; - if let Some(redact_id) = &content.redacts { - !services().rooms.state_accessor.user_can_redact( - redact_id, - &incoming_pdu.sender, - &incoming_pdu.room_id, - true, - )? - } else { - false - } - }, - }; + !state_res::event_auth::auth_check(&room_version, &incoming_pdu, None::, |k, s| { + auth_events.get(&(k.clone(), s.to_owned())) + }) + .map_err(|_e| Error::BadRequest(ErrorKind::forbidden(), "Auth check failed."))? + || incoming_pdu.kind == TimelineEventType::RoomRedaction + && match room_version_id { + V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => { + if let Some(redact_id) = &incoming_pdu.redacts { + !services().rooms.state_accessor.user_can_redact( + redact_id, + &incoming_pdu.sender, + &incoming_pdu.room_id, + true, + )? + } else { + false + } + }, + _ => { + let content = serde_json::from_str::(incoming_pdu.content.get()) + .map_err(|_| Error::bad_database("Invalid content in redaction pdu."))?; + + if let Some(redact_id) = &content.redacts { + !services().rooms.state_accessor.user_can_redact( + redact_id, + &incoming_pdu.sender, + &incoming_pdu.room_id, + true, + )? + } else { + false + } + }, + } + }; // 13. Use state resolution to find new room state // We start looking at current room state now, so lets lock the room trace!("Locking the room"); - let state_lock = services().globals.roomid_mutex_state.lock(room_id).await; + let state_lock = services().rooms.state.mutex.lock(room_id).await; // Now we calculate the set of extremities this room has after the incoming // event has been applied. We start with the previous extremities (aka leaves) @@ -951,7 +968,7 @@ impl Service { /// Call /state_ids to find out what the state at this pdu is. We trust the /// server's response to some extend (sic), but we still do a lot of checks /// on the events - #[tracing::instrument(skip_all)] + #[tracing::instrument(skip(self, pub_key_map, create_event, room_version_id))] async fn fetch_state( &self, origin: &ServerName, create_event: &PduEvent, room_id: &RoomId, room_version_id: &RoomVersionId, pub_key_map: &RwLock>>, event_id: &EventId, @@ -1043,7 +1060,7 @@ impl Service { .globals .bad_event_ratelimiter .write() - .await + .expect("locked") .entry(id) { hash_map::Entry::Vacant(e) => { @@ -1070,22 +1087,20 @@ impl Service { let mut todo_auth_events = vec![Arc::clone(id)]; let mut events_in_reverse_order = Vec::with_capacity(todo_auth_events.len()); let mut events_all = HashSet::with_capacity(todo_auth_events.len()); - let mut i = 0; + let mut i: u64 = 0; while let Some(next_id) = todo_auth_events.pop() { if let Some((time, tries)) = services() .globals .bad_event_ratelimiter .read() - .await + .expect("locked") .get(&*next_id) { // Exponential backoff - const MAX_DURATION: Duration = Duration::from_secs(60 * 60 * 24); - let min_elapsed_duration = - cmp::min(MAX_DURATION, Duration::from_secs(5 * 60) * (*tries) * (*tries)); - - if time.elapsed() < min_elapsed_duration { - info!("Backing off from {}", next_id); + const MIN_DURATION: u64 = 5 * 60; + const MAX_DURATION: u64 = 60 * 60 * 24; + if continue_exponential_backoff_secs(MIN_DURATION, MAX_DURATION, time.elapsed(), *tries) { + info!("Backing off from {next_id}"); continue; } } @@ -1094,7 +1109,7 @@ impl Service { continue; } - i += 1; + i = i.saturating_add(1); if i % 100 == 0 { tokio::task::yield_now().await; } @@ -1184,16 +1199,14 @@ impl Service { .globals .bad_event_ratelimiter .read() - .await + .expect("locked") .get(&**next_id) { // Exponential backoff - const MAX_DURATION: Duration = Duration::from_secs(60 * 60 * 24); - let min_elapsed_duration = - cmp::min(MAX_DURATION, Duration::from_secs(5 * 60) * (*tries) * (*tries)); - - if time.elapsed() < min_elapsed_duration { - debug!("Backing off from {}", next_id); + const MIN_DURATION: u64 = 5 * 60; + const MAX_DURATION: u64 = 60 * 60 * 24; + if continue_exponential_backoff_secs(MIN_DURATION, MAX_DURATION, time.elapsed(), *tries) { + debug!("Backing off from {next_id}"); continue; } } @@ -1369,15 +1382,13 @@ impl Service { } fn get_room_version_id(create_event: &PduEvent) -> Result { - let create_event_content: RoomCreateEventContent = - serde_json::from_str(create_event.content.get()).map_err(|e| { - error!("Invalid create event: {}", e); - Error::BadDatabase("Invalid create event in db") - })?; + let create_event_content: RoomCreateEventContent = serde_json::from_str(create_event.content.get()) + .map_err(|e| err!(Database("Invalid create event: {e}")))?; Ok(create_event_content.room_version) } + #[inline] fn to_room_version(room_version_id: &RoomVersionId) -> RoomVersion { RoomVersion::new(room_version_id).expect("room version is supported") } diff --git a/src/service/rooms/event_handler/parse_incoming_pdu.rs b/src/service/rooms/event_handler/parse_incoming_pdu.rs index 4c907e51..8fcd8549 100644 --- a/src/service/rooms/event_handler/parse_incoming_pdu.rs +++ b/src/service/rooms/event_handler/parse_incoming_pdu.rs @@ -1,4 +1,4 @@ -use conduit::{Error, Result}; +use conduit::{Err, Error, Result}; use ruma::{api::client::error::ErrorKind, CanonicalJsonObject, OwnedEventId, OwnedRoomId, RoomId}; use serde_json::value::RawValue as RawJsonValue; use tracing::warn; @@ -17,15 +17,12 @@ pub fn parse_incoming_pdu(pdu: &RawJsonValue) -> Result<(OwnedEventId, Canonical .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Invalid room id in pdu"))?; let Ok(room_version_id) = services().rooms.state.get_room_version(&room_id) else { - return Err(Error::Err(format!("Server is not in room {room_id}"))); + return Err!("Server is not in room {room_id}"); }; let Ok((event_id, value)) = gen_event_id_canonical_json(pdu, &room_version_id) else { // Event could not be converted to canonical json - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Could not convert event to canonical json.", - )); + return Err!(Request(InvalidParam("Could not convert event to canonical json."))); }; Ok((event_id, value, room_id)) diff --git a/src/service/rooms/event_handler/signing_keys.rs b/src/service/rooms/event_handler/signing_keys.rs index d22af9bd..2fa5b0df 100644 --- a/src/service/rooms/event_handler/signing_keys.rs +++ b/src/service/rooms/event_handler/signing_keys.rs @@ -3,6 +3,7 @@ use std::{ time::{Duration, SystemTime}, }; +use conduit::{debug, error, info, trace, warn}; use futures_util::{stream::FuturesUnordered, StreamExt}; use ruma::{ api::federation::{ @@ -19,7 +20,6 @@ use ruma::{ }; use serde_json::value::RawValue as RawJsonValue; use tokio::sync::{RwLock, RwLockWriteGuard}; -use tracing::{debug, error, info, trace, warn}; use crate::{services, Error, Result}; @@ -201,6 +201,7 @@ impl super::Service { let result = services() .globals + .db .add_signing_key(&k.server_name, k.clone())? .into_iter() .map(|(k, v)| (k.to_string(), v.key)) @@ -249,6 +250,7 @@ impl super::Service { if let Ok(key) = get_keys_response.server_key.deserialize() { let result: BTreeMap<_, _> = services() .globals + .db .add_signing_key(&origin, key)? .into_iter() .map(|(k, v)| (k.to_string(), v.key)) @@ -392,7 +394,7 @@ impl super::Service { }) { debug!("Got signing keys: {:?}", server_keys); for k in server_keys { - services().globals.add_signing_key(origin, k.clone())?; + services().globals.db.add_signing_key(origin, k.clone())?; result.extend( k.verify_keys .into_iter() @@ -421,6 +423,7 @@ impl super::Service { { services() .globals + .db .add_signing_key(origin, server_key.clone())?; result.extend( @@ -453,6 +456,7 @@ impl super::Service { { services() .globals + .db .add_signing_key(origin, server_key.clone())?; result.extend( @@ -499,7 +503,7 @@ impl super::Service { }) { debug!("Got signing keys: {:?}", server_keys); for k in server_keys { - services().globals.add_signing_key(origin, k.clone())?; + services().globals.db.add_signing_key(origin, k.clone())?; result.extend( k.verify_keys .into_iter() diff --git a/src/service/rooms/lazy_loading/mod.rs b/src/service/rooms/lazy_loading/mod.rs index 1f2ae6dd..96f623f2 100644 --- a/src/service/rooms/lazy_loading/mod.rs +++ b/src/service/rooms/lazy_loading/mod.rs @@ -2,14 +2,12 @@ mod data; use std::{ collections::{HashMap, HashSet}, - sync::Arc, + fmt::Write, + sync::{Arc, Mutex}, }; -use conduit::Server; use data::Data; -use database::Database; use ruma::{DeviceId, OwnedDeviceId, OwnedRoomId, OwnedUserId, RoomId, UserId}; -use tokio::sync::Mutex; use crate::{PduCount, Result}; @@ -20,15 +18,28 @@ pub struct Service { pub lazy_load_waiting: Mutex>>, } -impl Service { - pub fn build(_server: &Arc, db: &Arc) -> Result { - Ok(Self { - db: Data::new(db), +impl crate::Service for Service { + fn build(args: crate::Args<'_>) -> Result> { + Ok(Arc::new(Self { + db: Data::new(args.db), lazy_load_waiting: Mutex::new(HashMap::new()), - }) + })) } - #[tracing::instrument(skip(self))] + fn memory_usage(&self, out: &mut dyn Write) -> Result<()> { + let lazy_load_waiting = self.lazy_load_waiting.lock().expect("locked").len(); + writeln!(out, "lazy_load_waiting: {lazy_load_waiting}")?; + + Ok(()) + } + + fn clear_cache(&self) { self.lazy_load_waiting.lock().expect("locked").clear(); } + + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } +} + +impl Service { + #[tracing::instrument(skip(self), level = "debug")] pub fn lazy_load_was_sent_before( &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, ll_user: &UserId, ) -> Result { @@ -36,22 +47,22 @@ impl Service { .lazy_load_was_sent_before(user_id, device_id, room_id, ll_user) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub async fn lazy_load_mark_sent( &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, lazy_load: HashSet, count: PduCount, ) { self.lazy_load_waiting .lock() - .await + .expect("locked") .insert((user_id.to_owned(), device_id.to_owned(), room_id.to_owned(), count), lazy_load); } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub async fn lazy_load_confirm_delivery( &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, since: PduCount, ) -> Result<()> { - if let Some(user_ids) = self.lazy_load_waiting.lock().await.remove(&( + if let Some(user_ids) = self.lazy_load_waiting.lock().expect("locked").remove(&( user_id.to_owned(), device_id.to_owned(), room_id.to_owned(), @@ -66,7 +77,7 @@ impl Service { Ok(()) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn lazy_load_reset(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId) -> Result<()> { self.db.lazy_load_reset(user_id, device_id, room_id) } diff --git a/src/service/rooms/metadata/data.rs b/src/service/rooms/metadata/data.rs index df9b6b10..763dd0e8 100644 --- a/src/service/rooms/metadata/data.rs +++ b/src/service/rooms/metadata/data.rs @@ -48,10 +48,12 @@ impl Data { })) } + #[inline] pub(super) fn is_disabled(&self, room_id: &RoomId) -> Result { Ok(self.disabledroomids.get(room_id.as_bytes())?.is_some()) } + #[inline] pub(super) fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()> { if disabled { self.disabledroomids.insert(room_id.as_bytes(), &[])?; @@ -62,10 +64,12 @@ impl Data { Ok(()) } + #[inline] pub(super) fn is_banned(&self, room_id: &RoomId) -> Result { Ok(self.bannedroomids.get(room_id.as_bytes())?.is_some()) } + #[inline] pub(super) fn ban_room(&self, room_id: &RoomId, banned: bool) -> Result<()> { if banned { self.bannedroomids.insert(room_id.as_bytes(), &[])?; diff --git a/src/service/rooms/metadata/mod.rs b/src/service/rooms/metadata/mod.rs index b91fc67e..ec34a82c 100644 --- a/src/service/rooms/metadata/mod.rs +++ b/src/service/rooms/metadata/mod.rs @@ -2,39 +2,47 @@ mod data; use std::sync::Arc; -use conduit::{Result, Server}; +use conduit::Result; use data::Data; -use database::Database; use ruma::{OwnedRoomId, RoomId}; pub struct Service { db: Data, } -impl Service { - pub fn build(_server: &Arc, db: &Arc) -> Result { - Ok(Self { - db: Data::new(db), - }) +impl crate::Service for Service { + fn build(args: crate::Args<'_>) -> Result> { + Ok(Arc::new(Self { + db: Data::new(args.db), + })) } + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } +} + +impl Service { /// Checks if a room exists. - #[tracing::instrument(skip(self))] + #[inline] pub fn exists(&self, room_id: &RoomId) -> Result { self.db.exists(room_id) } #[must_use] pub fn iter_ids<'a>(&'a self) -> Box> + 'a> { self.db.iter_ids() } + #[inline] pub fn is_disabled(&self, room_id: &RoomId) -> Result { self.db.is_disabled(room_id) } + #[inline] pub fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()> { self.db.disable_room(room_id, disabled) } + #[inline] pub fn is_banned(&self, room_id: &RoomId) -> Result { self.db.is_banned(room_id) } + #[inline] pub fn ban_room(&self, room_id: &RoomId, banned: bool) -> Result<()> { self.db.ban_room(room_id, banned) } + #[inline] #[must_use] pub fn list_banned_rooms<'a>(&'a self) -> Box> + 'a> { self.db.list_banned_rooms() diff --git a/src/service/rooms/mod.rs b/src/service/rooms/mod.rs index bef56a25..ef50b094 100644 --- a/src/service/rooms/mod.rs +++ b/src/service/rooms/mod.rs @@ -19,25 +19,27 @@ pub mod timeline; pub mod typing; pub mod user; +use std::sync::Arc; + pub struct Service { - pub alias: alias::Service, - pub auth_chain: auth_chain::Service, - pub directory: directory::Service, - pub event_handler: event_handler::Service, - pub lazy_loading: lazy_loading::Service, - pub metadata: metadata::Service, - pub outlier: outlier::Service, - pub pdu_metadata: pdu_metadata::Service, - pub read_receipt: read_receipt::Service, - pub search: search::Service, - pub short: short::Service, - pub state: state::Service, - pub state_accessor: state_accessor::Service, - pub state_cache: state_cache::Service, - pub state_compressor: state_compressor::Service, - pub timeline: timeline::Service, - pub threads: threads::Service, - pub typing: typing::Service, - pub spaces: spaces::Service, - pub user: user::Service, + pub alias: Arc, + pub auth_chain: Arc, + pub directory: Arc, + pub event_handler: Arc, + pub lazy_loading: Arc, + pub metadata: Arc, + pub outlier: Arc, + pub pdu_metadata: Arc, + pub read_receipt: Arc, + pub search: Arc, + pub short: Arc, + pub state: Arc, + pub state_accessor: Arc, + pub state_cache: Arc, + pub state_compressor: Arc, + pub timeline: Arc, + pub threads: Arc, + pub typing: Arc, + pub spaces: Arc, + pub user: Arc, } diff --git a/src/service/rooms/outlier/mod.rs b/src/service/rooms/outlier/mod.rs index 0da12a14..22bd2092 100644 --- a/src/service/rooms/outlier/mod.rs +++ b/src/service/rooms/outlier/mod.rs @@ -2,9 +2,8 @@ mod data; use std::sync::Arc; -use conduit::{Result, Server}; +use conduit::Result; use data::Data; -use database::Database; use ruma::{CanonicalJsonObject, EventId}; use crate::PduEvent; @@ -13,13 +12,17 @@ pub struct Service { db: Data, } -impl Service { - pub fn build(_server: &Arc, db: &Arc) -> Result { - Ok(Self { - db: Data::new(db), - }) +impl crate::Service for Service { + fn build(args: crate::Args<'_>) -> Result> { + Ok(Arc::new(Self { + db: Data::new(args.db), + })) } + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } +} + +impl Service { /// Returns the pdu from the outlier tree. pub fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result> { self.db.get_outlier_pdu_json(event_id) @@ -32,7 +35,7 @@ impl Service { pub fn get_pdu_outlier(&self, event_id: &EventId) -> Result> { self.db.get_outlier_pdu(event_id) } /// Append the PDU as an outlier. - #[tracing::instrument(skip(self, pdu))] + #[tracing::instrument(skip(self, pdu), level = "debug")] pub fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()> { self.db.add_pdu_outlier(event_id, pdu) } diff --git a/src/service/rooms/pdu_metadata/mod.rs b/src/service/rooms/pdu_metadata/mod.rs index 978a0ec0..05067aa8 100644 --- a/src/service/rooms/pdu_metadata/mod.rs +++ b/src/service/rooms/pdu_metadata/mod.rs @@ -2,9 +2,8 @@ mod data; use std::sync::Arc; -use conduit::{Result, Server}; +use conduit::Result; use data::Data; -use database::Database; use ruma::{ api::{client::relations::get_relating_events, Direction}, events::{relation::RelationType, TimelineEventType}, @@ -28,14 +27,18 @@ struct ExtractRelatesToEventId { relates_to: ExtractRelType, } -impl Service { - pub fn build(_server: &Arc, db: &Arc) -> Result { - Ok(Self { - db: Data::new(db), - }) +impl crate::Service for Service { + fn build(args: crate::Args<'_>) -> Result> { + Ok(Arc::new(Self { + db: Data::new(args.db), + })) } - #[tracing::instrument(skip(self, from, to))] + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } +} + +impl Service { + #[tracing::instrument(skip(self, from, to), level = "debug")] pub fn add_relation(&self, from: PduCount, to: PduCount) -> Result<()> { match (from, to) { (PduCount::Normal(f), PduCount::Normal(t)) => self.db.add_relation(f, t), @@ -202,7 +205,7 @@ impl Service { if let Ok(relations) = self.db.relations_until(user_id, room_id, target, until) { for relation in relations.flatten() { if stack_pdu.1 < max_depth { - stack.push((relation.clone(), stack_pdu.1 + 1)); + stack.push((relation.clone(), stack_pdu.1.saturating_add(1))); } pdus.push(relation); @@ -215,19 +218,19 @@ impl Service { }) } - #[tracing::instrument(skip(self, room_id, event_ids))] + #[tracing::instrument(skip_all, level = "debug")] pub fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc]) -> Result<()> { self.db.mark_as_referenced(room_id, event_ids) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result { self.db.is_event_referenced(room_id, event_id) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()> { self.db.mark_event_soft_failed(event_id) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn is_event_soft_failed(&self, event_id: &EventId) -> Result { self.db.is_event_soft_failed(event_id) } } diff --git a/src/service/rooms/read_receipt/data.rs b/src/service/rooms/read_receipt/data.rs index 17acb0b3..06eaf655 100644 --- a/src/service/rooms/read_receipt/data.rs +++ b/src/service/rooms/read_receipt/data.rs @@ -76,10 +76,12 @@ impl Data { .iter_from(&first_possible_edu, false) .take_while(move |(k, _)| k.starts_with(&prefix2)) .map(move |(k, v)| { - let count = utils::u64_from_bytes(&k[prefix.len()..prefix.len() + size_of::()]) + let count_offset = prefix.len().saturating_add(size_of::()); + let count = utils::u64_from_bytes(&k[prefix.len()..count_offset]) .map_err(|_| Error::bad_database("Invalid readreceiptid count in db."))?; + let user_id_offset = count_offset.saturating_add(1); let user_id = UserId::parse( - utils::string_from_bytes(&k[prefix.len() + size_of::() + 1..]) + utils::string_from_bytes(&k[user_id_offset..]) .map_err(|_| Error::bad_database("Invalid readreceiptid userid bytes in db."))?, ) .map_err(|_| Error::bad_database("Invalid readreceiptid userid in db."))?; diff --git a/src/service/rooms/read_receipt/mod.rs b/src/service/rooms/read_receipt/mod.rs index ccc17d3a..9375276e 100644 --- a/src/service/rooms/read_receipt/mod.rs +++ b/src/service/rooms/read_receipt/mod.rs @@ -2,9 +2,8 @@ mod data; use std::sync::Arc; -use conduit::{Result, Server}; +use conduit::Result; use data::Data; -use database::Database; use ruma::{events::receipt::ReceiptEvent, serde::Raw, OwnedUserId, RoomId, UserId}; use crate::services; @@ -13,13 +12,17 @@ pub struct Service { db: Data, } -impl Service { - pub fn build(_server: &Arc, db: &Arc) -> Result { - Ok(Self { - db: Data::new(db), - }) +impl crate::Service for Service { + fn build(args: crate::Args<'_>) -> Result> { + Ok(Arc::new(Self { + db: Data::new(args.db), + })) } + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } +} + +impl Service { /// Replaces the previous read receipt. pub fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: &ReceiptEvent) -> Result<()> { self.db.readreceipt_update(user_id, room_id, event)?; @@ -30,7 +33,7 @@ impl Service { /// Returns an iterator over the most recent read_receipts in a room that /// happened after the event with id `since`. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn readreceipts_since<'a>( &'a self, room_id: &RoomId, since: u64, ) -> impl Iterator)>> + 'a { @@ -38,13 +41,13 @@ impl Service { } /// Sets a private read marker at `count`. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> { self.db.private_read_set(room_id, user_id, count) } /// Returns the private read marker. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result> { self.db.private_read_get(room_id, user_id) } diff --git a/src/service/rooms/search/mod.rs b/src/service/rooms/search/mod.rs index e2ef0f80..082dd432 100644 --- a/src/service/rooms/search/mod.rs +++ b/src/service/rooms/search/mod.rs @@ -2,33 +2,36 @@ mod data; use std::sync::Arc; -use conduit::{Result, Server}; +use conduit::Result; use data::Data; -use database::Database; use ruma::RoomId; pub struct Service { db: Data, } -impl Service { - pub fn build(_server: &Arc, db: &Arc) -> Result { - Ok(Self { - db: Data::new(db), - }) +impl crate::Service for Service { + fn build(args: crate::Args<'_>) -> Result> { + Ok(Arc::new(Self { + db: Data::new(args.db), + })) } - #[tracing::instrument(skip(self))] + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } +} + +impl Service { + #[tracing::instrument(skip(self), level = "debug")] pub fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> { self.db.index_pdu(shortroomid, pdu_id, message_body) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn deindex_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> { self.db.deindex_pdu(shortroomid, pdu_id, message_body) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn search_pdus<'a>( &'a self, room_id: &RoomId, search_string: &str, ) -> Result> + 'a, Vec)>> { diff --git a/src/service/rooms/short/mod.rs b/src/service/rooms/short/mod.rs index 638b6c1c..0979fb4f 100644 --- a/src/service/rooms/short/mod.rs +++ b/src/service/rooms/short/mod.rs @@ -2,22 +2,25 @@ mod data; use std::sync::Arc; -use conduit::{Result, Server}; +use conduit::Result; use data::Data; -use database::Database; use ruma::{events::StateEventType, EventId, RoomId}; pub struct Service { db: Data, } -impl Service { - pub fn build(_server: &Arc, db: &Arc) -> Result { - Ok(Self { - db: Data::new(db), - }) +impl crate::Service for Service { + fn build(args: crate::Args<'_>) -> Result> { + Ok(Arc::new(Self { + db: Data::new(args.db), + })) } + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } +} + +impl Service { pub fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result { self.db.get_or_create_shorteventid(event_id) } diff --git a/src/service/rooms/spaces/mod.rs b/src/service/rooms/spaces/mod.rs index 6a87ac98..02db7fab 100644 --- a/src/service/rooms/spaces/mod.rs +++ b/src/service/rooms/spaces/mod.rs @@ -7,8 +7,7 @@ use std::{ sync::Arc, }; -use conduit::{debug_info, Server}; -use database::Database; +use conduit::{checked, debug, debug_info, err, utils::math::usize_from_f64, warn, Error, Result}; use lru_cache::LruCache; use ruma::{ api::{ @@ -28,9 +27,8 @@ use ruma::{ OwnedRoomId, OwnedServerName, RoomId, ServerName, UInt, UserId, }; use tokio::sync::Mutex; -use tracing::{debug, error, warn}; -use crate::{services, Error, Result}; +use crate::services; pub struct CachedSpaceHierarchySummary { summary: SpaceHierarchyParentSummary, @@ -159,17 +157,20 @@ impl From for SpaceHierarchyRoomsChunk { } } -impl Service { - pub fn build(server: &Arc, _db: &Arc) -> Result { - let config = &server.config; - Ok(Self { - roomid_spacehierarchy_cache: Mutex::new(LruCache::new( - (f64::from(config.roomid_spacehierarchy_cache_capacity) * config.conduit_cache_capacity_modifier) - as usize, - )), - }) +impl crate::Service for Service { + fn build(args: crate::Args<'_>) -> Result> { + let config = &args.server.config; + let cache_size = f64::from(config.roomid_spacehierarchy_cache_capacity); + let cache_size = cache_size * config.cache_capacity_modifier; + Ok(Arc::new(Self { + roomid_spacehierarchy_cache: Mutex::new(LruCache::new(usize_from_f64(cache_size)?)), + })) } + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } +} + +impl Service { /// Gets the response for the space hierarchy over federation request /// /// Errors if the room does not exist, so a check if the room exists should @@ -378,10 +379,7 @@ impl Service { .map(|s| { serde_json::from_str(s.content.get()) .map(|c: RoomJoinRulesEventContent| c.join_rule) - .map_err(|e| { - error!("Invalid room join rule event in database: {}", e); - Error::BadDatabase("Invalid room join rule event in database.") - }) + .map_err(|e| err!(Database(error!("Invalid room join rule event in database: {e}")))) }) .transpose()? .unwrap_or(JoinRule::Invite); @@ -444,7 +442,7 @@ impl Service { } pub async fn get_client_hierarchy( - &self, sender_user: &UserId, room_id: &RoomId, limit: usize, short_room_ids: Vec, max_depth: usize, + &self, sender_user: &UserId, room_id: &RoomId, limit: usize, short_room_ids: Vec, max_depth: u64, suggested_only: bool, ) -> Result { let mut parents = VecDeque::new(); @@ -505,12 +503,14 @@ impl Service { } // We have reached the room after where we last left off - if parents.len() + 1 == short_room_ids.len() { + let parents_len = parents.len(); + if checked!(parents_len + 1)? == short_room_ids.len() { populate_results = true; } } - if !children.is_empty() && parents.len() < max_depth { + let parents_len: u64 = parents.len().try_into()?; + if !children.is_empty() && parents_len < max_depth { parents.push_back(current_room.clone()); stack.push(children); } @@ -545,9 +545,8 @@ impl Service { Some( PaginationToken { short_room_ids, - limit: UInt::new(max_depth as u64).expect("When sent in request it must have been valid UInt"), - max_depth: UInt::new(max_depth as u64) - .expect("When sent in request it must have been valid UInt"), + limit: UInt::new(max_depth).expect("When sent in request it must have been valid UInt"), + max_depth: UInt::new(max_depth).expect("When sent in request it must have been valid UInt"), suggested_only, } .to_string(), diff --git a/src/service/rooms/state/data.rs b/src/service/rooms/state/data.rs index dad613d2..3c110afc 100644 --- a/src/service/rooms/state/data.rs +++ b/src/service/rooms/state/data.rs @@ -3,7 +3,8 @@ use std::{collections::HashSet, sync::Arc}; use conduit::{utils, Error, Result}; use database::{Database, Map}; use ruma::{EventId, OwnedEventId, RoomId}; -use utils::mutex_map; + +use super::RoomMutexGuard; pub(super) struct Data { shorteventid_shortstatehash: Arc, @@ -30,11 +31,12 @@ impl Data { }) } + #[inline] pub(super) fn set_room_state( &self, room_id: &RoomId, new_shortstatehash: u64, - _mutex_lock: &mutex_map::Guard<()>, // Take mutex guard to make sure users get the room state mutex + _mutex_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex ) -> Result<()> { self.roomid_shortstatehash .insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?; @@ -67,7 +69,7 @@ impl Data { &self, room_id: &RoomId, event_ids: Vec, - _mutex_lock: &mutex_map::Guard<()>, // Take mutex guard to make sure users get the room state mutex + _mutex_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex ) -> Result<()> { let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xFF); diff --git a/src/service/rooms/state/mod.rs b/src/service/rooms/state/mod.rs index 42483c97..a3a317a5 100644 --- a/src/service/rooms/state/mod.rs +++ b/src/service/rooms/state/mod.rs @@ -2,15 +2,15 @@ mod data; use std::{ collections::{HashMap, HashSet}, + fmt::Write, sync::Arc, }; use conduit::{ - utils::{calculate_hash, mutex_map}, - warn, Error, Result, Server, + utils::{calculate_hash, MutexMap, MutexMapGuard}, + warn, Error, Result, }; use data::Data; -use database::Database; use ruma::{ api::client::error::ErrorKind, events::{ @@ -19,7 +19,7 @@ use ruma::{ }, serde::Raw, state_res::{self, StateMap}, - EventId, OwnedEventId, RoomId, RoomVersionId, UserId, + EventId, OwnedEventId, OwnedRoomId, RoomId, RoomVersionId, UserId, }; use super::state_compressor::CompressedStateEvent; @@ -27,15 +27,31 @@ use crate::{services, PduEvent}; pub struct Service { db: Data, + pub mutex: RoomMutexMap, +} + +type RoomMutexMap = MutexMap; +pub type RoomMutexGuard = MutexMapGuard; + +impl crate::Service for Service { + fn build(args: crate::Args<'_>) -> Result> { + Ok(Arc::new(Self { + db: Data::new(args.db), + mutex: RoomMutexMap::new(), + })) + } + + fn memory_usage(&self, out: &mut dyn Write) -> Result<()> { + let mutex = self.mutex.len(); + writeln!(out, "state_mutex: {mutex}")?; + + Ok(()) + } + + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } impl Service { - pub fn build(_server: &Arc, db: &Arc) -> Result { - Ok(Self { - db: Data::new(db), - }) - } - /// Set the room to the given statehash and update caches. pub async fn force_state( &self, @@ -43,7 +59,7 @@ impl Service { shortstatehash: u64, statediffnew: Arc>, _statediffremoved: Arc>, - state_lock: &mutex_map::Guard<()>, // Take mutex guard to make sure users get the room state mutex + state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex ) -> Result<()> { for event_id in statediffnew.iter().filter_map(|new| { services() @@ -113,7 +129,7 @@ impl Service { /// /// This adds all current state events (not including the incoming event) /// to `stateid_pduid` and adds the incoming event to `eventid_statehash`. - #[tracing::instrument(skip(self, state_ids_compressed))] + #[tracing::instrument(skip(self, state_ids_compressed), level = "debug")] pub fn set_event_state( &self, event_id: &EventId, room_id: &RoomId, state_ids_compressed: Arc>, ) -> Result { @@ -181,7 +197,7 @@ impl Service { /// /// This adds all current state events (not including the incoming event) /// to `stateid_pduid` and adds the incoming event to `eventid_statehash`. - #[tracing::instrument(skip(self, new_pdu))] + #[tracing::instrument(skip(self, new_pdu), level = "debug")] pub fn append_to_state(&self, new_pdu: &PduEvent) -> Result { let shorteventid = services() .rooms @@ -197,6 +213,7 @@ impl Service { if let Some(state_key) = &new_pdu.state_key { let states_parents = previous_shortstatehash.map_or_else( || Ok(Vec::new()), + #[inline] |p| { services() .rooms @@ -253,7 +270,7 @@ impl Service { } } - #[tracing::instrument(skip(self, invite_event))] + #[tracing::instrument(skip(self, invite_event), level = "debug")] pub fn calculate_invite_state(&self, invite_event: &PduEvent) -> Result>> { let mut state = Vec::new(); // Add recommended events @@ -309,18 +326,18 @@ impl Service { } /// Set the state hash to a new version, but does not update state_cache. - #[tracing::instrument(skip(self, mutex_lock))] + #[tracing::instrument(skip(self, mutex_lock), level = "debug")] pub fn set_room_state( &self, room_id: &RoomId, shortstatehash: u64, - mutex_lock: &mutex_map::Guard<()>, // Take mutex guard to make sure users get the room state mutex + mutex_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex ) -> Result<()> { self.db.set_room_state(room_id, shortstatehash, mutex_lock) } /// Returns the room's version. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn get_room_version(&self, room_id: &RoomId) -> Result { let create_event = services() .rooms @@ -341,6 +358,7 @@ impl Service { Ok(create_event_content.room_version) } + #[inline] pub fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result> { self.db.get_room_shortstatehash(room_id) } @@ -353,14 +371,14 @@ impl Service { &self, room_id: &RoomId, event_ids: Vec, - state_lock: &mutex_map::Guard<()>, // Take mutex guard to make sure users get the room state mutex + state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex ) -> Result<()> { self.db .set_forward_extremities(room_id, event_ids, state_lock) } /// This fetches auth events from the current state. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn get_auth_events( &self, room_id: &RoomId, kind: &TimelineEventType, sender: &UserId, state_key: Option<&str>, content: &serde_json::value::RawValue, diff --git a/src/service/rooms/state_accessor/mod.rs b/src/service/rooms/state_accessor/mod.rs index 03e22187..bd3eb0a1 100644 --- a/src/service/rooms/state_accessor/mod.rs +++ b/src/service/rooms/state_accessor/mod.rs @@ -2,12 +2,12 @@ mod data; use std::{ collections::HashMap, + fmt::Write, sync::{Arc, Mutex as StdMutex, Mutex}, }; -use conduit::{error, utils::mutex_map, warn, Error, Result, Server}; +use conduit::{err, error, utils::math::usize_from_f64, warn, Error, Result}; use data::Data; -use database::Database; use lru_cache::LruCache; use ruma::{ events::{ @@ -33,7 +33,7 @@ use ruma::{ }; use serde_json::value::to_raw_value; -use crate::{pdu::PduBuilder, services, PduEvent}; +use crate::{pdu::PduBuilder, rooms::state::RoomMutexGuard, services, PduEvent}; pub struct Service { db: Data, @@ -41,23 +41,43 @@ pub struct Service { pub user_visibility_cache: Mutex>, } -impl Service { - pub fn build(server: &Arc, db: &Arc) -> Result { - let config = &server.config; - Ok(Self { - db: Data::new(db), - server_visibility_cache: StdMutex::new(LruCache::new( - (f64::from(config.server_visibility_cache_capacity) * config.conduit_cache_capacity_modifier) as usize, - )), - user_visibility_cache: StdMutex::new(LruCache::new( - (f64::from(config.user_visibility_cache_capacity) * config.conduit_cache_capacity_modifier) as usize, - )), - }) +impl crate::Service for Service { + fn build(args: crate::Args<'_>) -> Result> { + let config = &args.server.config; + let server_visibility_cache_capacity = + f64::from(config.server_visibility_cache_capacity) * config.cache_capacity_modifier; + let user_visibility_cache_capacity = + f64::from(config.user_visibility_cache_capacity) * config.cache_capacity_modifier; + + Ok(Arc::new(Self { + db: Data::new(args.db), + server_visibility_cache: StdMutex::new(LruCache::new(usize_from_f64(server_visibility_cache_capacity)?)), + user_visibility_cache: StdMutex::new(LruCache::new(usize_from_f64(user_visibility_cache_capacity)?)), + })) } + fn memory_usage(&self, out: &mut dyn Write) -> Result<()> { + let server_visibility_cache = self.server_visibility_cache.lock().expect("locked").len(); + writeln!(out, "server_visibility_cache: {server_visibility_cache}")?; + + let user_visibility_cache = self.user_visibility_cache.lock().expect("locked").len(); + writeln!(out, "user_visibility_cache: {user_visibility_cache}")?; + + Ok(()) + } + + fn clear_cache(&self) { + self.server_visibility_cache.lock().expect("locked").clear(); + self.user_visibility_cache.lock().expect("locked").clear(); + } + + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } +} + +impl Service { /// Builds a StateMap by iterating over all keys that start /// with state_hash, this gives the full state for the given state_hash. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub async fn state_full_ids(&self, shortstatehash: u64) -> Result>> { self.db.state_full_ids(shortstatehash).await } @@ -68,7 +88,7 @@ impl Service { /// Returns a single PDU from `room_id` with key (`event_type`, /// `state_key`). - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn state_get_id( &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, ) -> Result>> { @@ -77,6 +97,7 @@ impl Service { /// Returns a single PDU from `room_id` with key (`event_type`, /// `state_key`). + #[inline] pub fn state_get( &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, ) -> Result>> { @@ -94,6 +115,7 @@ impl Service { } /// The user was a joined member at this state (potentially in the past) + #[inline] fn user_was_joined(&self, shortstatehash: u64, user_id: &UserId) -> bool { self.user_membership(shortstatehash, user_id) .is_ok_and(|s| s == MembershipState::Join) @@ -103,6 +125,7 @@ impl Service { /// The user was an invited or joined room member at this state (potentially /// in the past) + #[inline] fn user_was_invited(&self, shortstatehash: u64, user_id: &UserId) -> bool { self.user_membership(shortstatehash, user_id) .is_ok_and(|s| s == MembershipState::Join || s == MembershipState::Invite) @@ -259,14 +282,14 @@ impl Service { pub fn pdu_shortstatehash(&self, event_id: &EventId) -> Result> { self.db.pdu_shortstatehash(event_id) } /// Returns the full room state. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub async fn room_state_full(&self, room_id: &RoomId) -> Result>> { self.db.room_state_full(room_id).await } /// Returns a single PDU from `room_id` with key (`event_type`, /// `state_key`). - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn room_state_get_id( &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, ) -> Result>> { @@ -275,7 +298,7 @@ impl Service { /// Returns a single PDU from `room_id` with key (`event_type`, /// `state_key`). - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn room_state_get( &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, ) -> Result>> { @@ -306,7 +329,7 @@ impl Service { } pub fn user_can_invite( - &self, room_id: &RoomId, sender: &UserId, target_user: &UserId, state_lock: &mutex_map::Guard<()>, + &self, room_id: &RoomId, sender: &UserId, target_user: &UserId, state_lock: &RoomMutexGuard, ) -> Result { let content = to_raw_value(&RoomMemberEventContent::new(MembershipState::Invite)) .expect("Event content always serializes"); @@ -431,10 +454,7 @@ impl Service { .map(|c: RoomJoinRulesEventContent| { (c.join_rule.clone().into(), self.allowed_room_ids(c.join_rule)) }) - .map_err(|e| { - error!("Invalid room join rule event in database: {e}"); - Error::BadDatabase("Invalid room join rule event in database.") - }) + .map_err(|e| err!(Database(error!("Invalid room join rule event in database: {e}")))) }) .transpose()? .unwrap_or((SpaceRoomJoinRule::Invite, vec![]))) @@ -460,10 +480,8 @@ impl Service { Ok(self .room_state_get(room_id, &StateEventType::RoomCreate, "")? .map(|s| { - serde_json::from_str::(s.content.get()).map_err(|e| { - error!("Invalid room create event in database: {e}"); - Error::BadDatabase("Invalid room create event in database.") - }) + serde_json::from_str::(s.content.get()) + .map_err(|e| err!(Database(error!("Invalid room create event in database: {e}")))) }) .transpose()? .and_then(|e| e.room_type)) @@ -476,10 +494,7 @@ impl Service { .map_or(Ok(None), |s| { serde_json::from_str::(s.content.get()) .map(|content| Some(content.algorithm)) - .map_err(|e| { - error!("Invalid room encryption event in database: {e}"); - Error::BadDatabase("Invalid room encryption event in database.") - }) + .map_err(|e| err!(Database(error!("Invalid room encryption event in database: {e}")))) }) } } diff --git a/src/service/rooms/state_cache/data.rs b/src/service/rooms/state_cache/data.rs index f79ee678..2b9fbe94 100644 --- a/src/service/rooms/state_cache/data.rs +++ b/src/service/rooms/state_cache/data.rs @@ -213,7 +213,7 @@ impl Data { Ok(()) } - #[tracing::instrument(skip(self, room_id, appservice))] + #[tracing::instrument(skip(self, room_id, appservice), level = "debug")] pub(super) fn appservice_in_room(&self, room_id: &RoomId, appservice: &RegistrationInfo) -> Result { let maybe = self .appservice_in_room_cache @@ -249,7 +249,7 @@ impl Data { } /// Makes a user forget a room. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub(super) fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()> { let mut userroom_id = user_id.as_bytes().to_vec(); userroom_id.push(0xFF); @@ -266,7 +266,7 @@ impl Data { } /// Returns an iterator of all servers participating in this room. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub(super) fn room_servers<'a>( &'a self, room_id: &RoomId, ) -> Box> + 'a> { @@ -286,7 +286,7 @@ impl Data { })) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub(super) fn server_in_room(&self, server: &ServerName, room_id: &RoomId) -> Result { let mut key = server.as_bytes().to_vec(); key.push(0xFF); @@ -297,7 +297,7 @@ impl Data { /// Returns an iterator of all rooms a server participates in (as far as we /// know). - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub(super) fn server_rooms<'a>( &'a self, server: &ServerName, ) -> Box> + 'a> { @@ -318,8 +318,10 @@ impl Data { } /// Returns an iterator of all joined members of a room. - #[tracing::instrument(skip(self))] - pub(super) fn room_members<'a>(&'a self, room_id: &RoomId) -> Box> + 'a> { + #[tracing::instrument(skip(self), level = "debug")] + pub(super) fn room_members<'a>( + &'a self, room_id: &RoomId, + ) -> Box> + Send + 'a> { let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xFF); @@ -348,7 +350,7 @@ impl Data { /// Returns an iterator of all our local joined users in a room who are /// active (not deactivated, not guest) - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub(super) fn active_local_users_in_room<'a>( &'a self, room_id: &RoomId, ) -> Box + 'a> { @@ -359,7 +361,7 @@ impl Data { } /// Returns the number of users which are currently in a room - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub(super) fn room_joined_count(&self, room_id: &RoomId) -> Result> { self.roomid_joinedcount .get(room_id.as_bytes())? @@ -368,7 +370,7 @@ impl Data { } /// Returns the number of users which are currently invited to a room - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub(super) fn room_invited_count(&self, room_id: &RoomId) -> Result> { self.roomid_invitedcount .get(room_id.as_bytes())? @@ -377,7 +379,7 @@ impl Data { } /// Returns an iterator over all User IDs who ever joined a room. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub(super) fn room_useroncejoined<'a>( &'a self, room_id: &RoomId, ) -> Box> + 'a> { @@ -402,7 +404,7 @@ impl Data { } /// Returns an iterator over all invited members of a room. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub(super) fn room_members_invited<'a>( &'a self, room_id: &RoomId, ) -> Box> + 'a> { @@ -426,7 +428,7 @@ impl Data { ) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub(super) fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result> { let mut key = room_id.as_bytes().to_vec(); key.push(0xFF); @@ -441,7 +443,7 @@ impl Data { }) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub(super) fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result> { let mut key = room_id.as_bytes().to_vec(); key.push(0xFF); @@ -454,7 +456,7 @@ impl Data { } /// Returns an iterator over all rooms this user joined. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub(super) fn rooms_joined(&self, user_id: &UserId) -> Box> + '_> { Box::new( self.userroomid_joined @@ -474,7 +476,7 @@ impl Data { } /// Returns an iterator over all rooms a user was invited to. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub(super) fn rooms_invited<'a>(&'a self, user_id: &UserId) -> StrippedStateEventIter<'a> { let mut prefix = user_id.as_bytes().to_vec(); prefix.push(0xFF); @@ -501,7 +503,7 @@ impl Data { ) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub(super) fn invite_state( &self, user_id: &UserId, room_id: &RoomId, ) -> Result>>> { @@ -520,7 +522,7 @@ impl Data { .transpose() } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub(super) fn left_state( &self, user_id: &UserId, room_id: &RoomId, ) -> Result>>> { @@ -540,7 +542,7 @@ impl Data { } /// Returns an iterator over all rooms a user left. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub(super) fn rooms_left<'a>(&'a self, user_id: &UserId) -> AnySyncStateEventIter<'a> { let mut prefix = user_id.as_bytes().to_vec(); prefix.push(0xFF); @@ -567,7 +569,7 @@ impl Data { ) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub(super) fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result { let mut userroom_id = user_id.as_bytes().to_vec(); userroom_id.push(0xFF); @@ -576,7 +578,7 @@ impl Data { Ok(self.roomuseroncejoinedids.get(&userroom_id)?.is_some()) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub(super) fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result { let mut userroom_id = user_id.as_bytes().to_vec(); userroom_id.push(0xFF); @@ -585,7 +587,7 @@ impl Data { Ok(self.userroomid_joined.get(&userroom_id)?.is_some()) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub(super) fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result { let mut userroom_id = user_id.as_bytes().to_vec(); userroom_id.push(0xFF); @@ -594,7 +596,7 @@ impl Data { Ok(self.userroomid_invitestate.get(&userroom_id)?.is_some()) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub(super) fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result { let mut userroom_id = user_id.as_bytes().to_vec(); userroom_id.push(0xFF); @@ -603,7 +605,7 @@ impl Data { Ok(self.userroomid_leftstate.get(&userroom_id)?.is_some()) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub(super) fn servers_invite_via<'a>( &'a self, room_id: &RoomId, ) -> Box> + 'a> { @@ -629,7 +631,7 @@ impl Data { ) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub(super) fn add_servers_invite_via(&self, room_id: &RoomId, servers: &[OwnedServerName]) -> Result<()> { let mut prev_servers = self .servers_invite_via(room_id) diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index 5038ef1c..48215817 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -2,9 +2,8 @@ mod data; use std::sync::Arc; -use conduit::{error, warn, Error, Result, Server}; +use conduit::{err, error, warn, Error, Result}; use data::Data; -use database::Database; use itertools::Itertools; use ruma::{ events::{ @@ -28,13 +27,17 @@ pub struct Service { db: Data, } -impl Service { - pub fn build(_server: &Arc, db: &Arc) -> Result { - Ok(Self { - db: Data::new(db), - }) +impl crate::Service for Service { + fn build(args: crate::Args<'_>) -> Result> { + Ok(Arc::new(Self { + db: Data::new(args.db), + })) } + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } +} + +impl Service { /// Update current membership data. #[tracing::instrument(skip(self, last_state))] #[allow(clippy::too_many_arguments)] @@ -125,10 +128,8 @@ impl Service { .account_data .get(Some(&predecessor.room_id), user_id, RoomAccountDataEventType::Tag)? .map(|event| { - serde_json::from_str(event.get()).map_err(|e| { - warn!("Invalid account data event in db: {e:?}"); - Error::BadDatabase("Invalid account data event in db.") - }) + serde_json::from_str(event.get()) + .map_err(|e| err!(Database(warn!("Invalid account data event in db: {e:?}")))) }) { services() .account_data @@ -141,10 +142,8 @@ impl Service { .account_data .get(None, user_id, GlobalAccountDataEventType::Direct.to_string().into())? .map(|event| { - serde_json::from_str::(event.get()).map_err(|e| { - warn!("Invalid account data event in db: {e:?}"); - Error::BadDatabase("Invalid account data event in db.") - }) + serde_json::from_str::(event.get()) + .map_err(|e| err!(Database(warn!("Invalid account data event in db: {e:?}")))) }) { let mut direct_event = direct_event?; let mut room_ids_updated = false; @@ -182,10 +181,8 @@ impl Service { .into(), )? .map(|event| { - serde_json::from_str::(event.get()).map_err(|e| { - warn!("Invalid account data event in db: {e:?}"); - Error::BadDatabase("Invalid account data event in db.") - }) + serde_json::from_str::(event.get()) + .map_err(|e| err!(Database(warn!("Invalid account data event in db: {e:?}")))) }) .transpose()? .map_or(false, |ignored| { @@ -216,10 +213,10 @@ impl Service { Ok(()) } - #[tracing::instrument(skip(self, room_id))] + #[tracing::instrument(skip(self, room_id), level = "debug")] pub fn update_joined_count(&self, room_id: &RoomId) -> Result<()> { self.db.update_joined_count(room_id) } - #[tracing::instrument(skip(self, room_id, appservice))] + #[tracing::instrument(skip(self, room_id, appservice), level = "debug")] pub fn appservice_in_room(&self, room_id: &RoomId, appservice: &RegistrationInfo) -> Result { self.db.appservice_in_room(room_id, appservice) } @@ -227,7 +224,7 @@ impl Service { /// Direct DB function to directly mark a user as left. It is not /// recommended to use this directly. You most likely should use /// `update_membership` instead - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { self.db.mark_as_left(user_id, room_id) } @@ -235,35 +232,35 @@ impl Service { /// Direct DB function to directly mark a user as joined. It is not /// recommended to use this directly. You most likely should use /// `update_membership` instead - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { self.db.mark_as_joined(user_id, room_id) } /// Makes a user forget a room. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()> { self.db.forget(room_id, user_id) } /// Returns an iterator of all servers participating in this room. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn room_servers(&self, room_id: &RoomId) -> impl Iterator> + '_ { self.db.room_servers(room_id) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn server_in_room(&self, server: &ServerName, room_id: &RoomId) -> Result { self.db.server_in_room(server, room_id) } /// Returns an iterator of all rooms a server participates in (as far as we /// know). - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn server_rooms(&self, server: &ServerName) -> impl Iterator> + '_ { self.db.server_rooms(server) } /// Returns true if server can see user by sharing at least one room. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn server_sees_user(&self, server: &ServerName, user_id: &UserId) -> Result { Ok(self .server_rooms(server) @@ -272,7 +269,7 @@ impl Service { } /// Returns true if user_a and user_b share at least one room. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn user_sees_user(&self, user_a: &UserId, user_b: &UserId) -> Result { // Minimize number of point-queries by iterating user with least nr rooms let (a, b) = if self.rooms_joined(user_a).count() < self.rooms_joined(user_b).count() { @@ -288,23 +285,23 @@ impl Service { } /// Returns an iterator over all joined members of a room. - #[tracing::instrument(skip(self))] - pub fn room_members(&self, room_id: &RoomId) -> impl Iterator> + '_ { + #[tracing::instrument(skip(self), level = "debug")] + pub fn room_members(&self, room_id: &RoomId) -> impl Iterator> + Send + '_ { self.db.room_members(room_id) } /// Returns the number of users which are currently in a room - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn room_joined_count(&self, room_id: &RoomId) -> Result> { self.db.room_joined_count(room_id) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] /// Returns an iterator of all our local users in the room, even if they're /// deactivated/guests pub fn local_users_in_room<'a>(&'a self, room_id: &RoomId) -> impl Iterator + 'a { self.db.local_users_in_room(room_id) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] /// Returns an iterator of all our local joined users in a room who are /// active (not deactivated, not guest) pub fn active_local_users_in_room<'a>(&'a self, room_id: &RoomId) -> impl Iterator + 'a { @@ -312,80 +309,80 @@ impl Service { } /// Returns the number of users which are currently invited to a room - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn room_invited_count(&self, room_id: &RoomId) -> Result> { self.db.room_invited_count(room_id) } /// Returns an iterator over all User IDs who ever joined a room. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn room_useroncejoined(&self, room_id: &RoomId) -> impl Iterator> + '_ { self.db.room_useroncejoined(room_id) } /// Returns an iterator over all invited members of a room. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn room_members_invited(&self, room_id: &RoomId) -> impl Iterator> + '_ { self.db.room_members_invited(room_id) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result> { self.db.get_invite_count(room_id, user_id) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result> { self.db.get_left_count(room_id, user_id) } /// Returns an iterator over all rooms this user joined. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn rooms_joined(&self, user_id: &UserId) -> impl Iterator> + '_ { self.db.rooms_joined(user_id) } /// Returns an iterator over all rooms a user was invited to. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn rooms_invited( &self, user_id: &UserId, ) -> impl Iterator>)>> + '_ { self.db.rooms_invited(user_id) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn invite_state(&self, user_id: &UserId, room_id: &RoomId) -> Result>>> { self.db.invite_state(user_id, room_id) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn left_state(&self, user_id: &UserId, room_id: &RoomId) -> Result>>> { self.db.left_state(user_id, room_id) } /// Returns an iterator over all rooms a user left. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn rooms_left( &self, user_id: &UserId, ) -> impl Iterator>)>> + '_ { self.db.rooms_left(user_id) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result { self.db.once_joined(user_id, room_id) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result { self.db.is_joined(user_id, room_id) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result { self.db.is_invited(user_id, room_id) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result { self.db.is_left(user_id, room_id) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn servers_invite_via(&self, room_id: &RoomId) -> impl Iterator> + '_ { self.db.servers_invite_via(room_id) } diff --git a/src/service/rooms/state_compressor/data.rs b/src/service/rooms/state_compressor/data.rs index 61c7d6e6..33773001 100644 --- a/src/service/rooms/state_compressor/data.rs +++ b/src/service/rooms/state_compressor/data.rs @@ -1,6 +1,6 @@ use std::{collections::HashSet, mem::size_of, sync::Arc}; -use conduit::{utils, Error, Result}; +use conduit::{checked, utils, Error, Result}; use database::{Database, Map}; use super::CompressedStateEvent; @@ -38,11 +38,12 @@ impl Data { let mut added = HashSet::new(); let mut removed = HashSet::new(); - let mut i = size_of::(); - while let Some(v) = value.get(i..i + 2 * size_of::()) { + let stride = size_of::(); + let mut i = stride; + while let Some(v) = value.get(i..checked!(i + 2 * stride)?) { if add_mode && v.starts_with(&0_u64.to_be_bytes()) { add_mode = false; - i += size_of::(); + i = checked!(i + stride)?; continue; } if add_mode { @@ -50,7 +51,7 @@ impl Data { } else { removed.insert(v.try_into().expect("we checked the size above")); } - i += 2 * size_of::(); + i = checked!(i + 2 * stride)?; } Ok(StateDiff { diff --git a/src/service/rooms/state_compressor/mod.rs b/src/service/rooms/state_compressor/mod.rs index 08112995..422c562b 100644 --- a/src/service/rooms/state_compressor/mod.rs +++ b/src/service/rooms/state_compressor/mod.rs @@ -2,13 +2,13 @@ mod data; use std::{ collections::HashSet, + fmt::Write, mem::size_of, sync::{Arc, Mutex as StdMutex, Mutex}, }; -use conduit::{utils, Result, Server}; +use conduit::{checked, utils, utils::math::usize_from_f64, Result}; use data::Data; -use database::Database; use lru_cache::LruCache; use ruma::{EventId, RoomId}; @@ -52,20 +52,32 @@ pub struct Service { pub stateinfo_cache: StateInfoLruCache, } -impl Service { - pub fn build(server: &Arc, db: &Arc) -> Result { - let config = &server.config; - Ok(Self { - db: Data::new(db), - stateinfo_cache: StdMutex::new(LruCache::new( - (f64::from(config.stateinfo_cache_capacity) * config.conduit_cache_capacity_modifier) as usize, - )), - }) +impl crate::Service for Service { + fn build(args: crate::Args<'_>) -> Result> { + let config = &args.server.config; + let cache_capacity = f64::from(config.stateinfo_cache_capacity) * config.cache_capacity_modifier; + Ok(Arc::new(Self { + db: Data::new(args.db), + stateinfo_cache: StdMutex::new(LruCache::new(usize_from_f64(cache_capacity)?)), + })) } + fn memory_usage(&self, out: &mut dyn Write) -> Result<()> { + let stateinfo_cache = self.stateinfo_cache.lock().expect("locked").len(); + writeln!(out, "stateinfo_cache: {stateinfo_cache}")?; + + Ok(()) + } + + fn clear_cache(&self) { self.stateinfo_cache.lock().expect("locked").clear(); } + + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } +} + +impl Service { /// Returns a stack with info on shortstatehash, full state, added diff and /// removed diff for the selected shortstatehash and each parent layer. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn load_shortstatehash_info(&self, shortstatehash: u64) -> ShortStateInfoResult { if let Some(r) = self .stateinfo_cache @@ -122,6 +134,7 @@ impl Service { } /// Returns shortstatekey, event id + #[inline] pub fn parse_compressed_state_event(&self, compressed_event: &CompressedStateEvent) -> Result<(u64, Arc)> { Ok(( utils::u64_from_bytes(&compressed_event[0..size_of::()]).expect("bytes have right length"), @@ -149,18 +162,20 @@ impl Service { /// for this layer /// * `parent_states` - A stack with info on shortstatehash, full state, /// added diff and removed diff for each parent layer - #[tracing::instrument(skip(self, statediffnew, statediffremoved, diff_to_sibling, parent_states))] + #[tracing::instrument(skip(self, statediffnew, statediffremoved, diff_to_sibling, parent_states), level = "debug")] pub fn save_state_from_diff( &self, shortstatehash: u64, statediffnew: Arc>, statediffremoved: Arc>, diff_to_sibling: usize, mut parent_states: ParentStatesVec, ) -> Result<()> { - let diffsum = statediffnew.len() + statediffremoved.len(); + let statediffnew_len = statediffnew.len(); + let statediffremoved_len = statediffremoved.len(); + let diffsum = checked!(statediffnew_len + statediffremoved_len)?; if parent_states.len() > 3 { // Number of layers // To many layers, we have to go deeper - let parent = parent_states.pop().unwrap(); + let parent = parent_states.pop().expect("parent must have a state"); let mut parent_new = (*parent.2).clone(); let mut parent_removed = (*parent.3).clone(); @@ -212,10 +227,12 @@ impl Service { // 1. We add the current diff on top of the parent layer. // 2. We replace a layer above - let parent = parent_states.pop().unwrap(); - let parent_diff = parent.2.len() + parent.3.len(); + let parent = parent_states.pop().expect("parent must have a state"); + let parent_2_len = parent.2.len(); + let parent_3_len = parent.3.len(); + let parent_diff = checked!(parent_2_len + parent_3_len)?; - if diffsum * diffsum >= 2 * diff_to_sibling * parent_diff { + if checked!(diffsum * diffsum)? >= checked!(2 * diff_to_sibling * parent_diff)? { // Diff too big, we replace above layer(s) let mut parent_new = (*parent.2).clone(); let mut parent_removed = (*parent.3).clone(); diff --git a/src/service/rooms/threads/data.rs b/src/service/rooms/threads/data.rs index 29539847..c4a1a294 100644 --- a/src/service/rooms/threads/data.rs +++ b/src/service/rooms/threads/data.rs @@ -1,6 +1,6 @@ use std::{mem::size_of, sync::Arc}; -use conduit::{utils, Error, Result}; +use conduit::{checked, utils, Error, Result}; use database::{Database, Map}; use ruma::{api::client::threads::get_threads::v1::IncludeThreads, OwnedUserId, RoomId, UserId}; @@ -31,7 +31,7 @@ impl Data { .to_vec(); let mut current = prefix.clone(); - current.extend_from_slice(&(until - 1).to_be_bytes()); + current.extend_from_slice(&(checked!(until - 1)?).to_be_bytes()); Ok(Box::new( self.threadid_userids diff --git a/src/service/rooms/threads/mod.rs b/src/service/rooms/threads/mod.rs index f47fc499..dd2686b0 100644 --- a/src/service/rooms/threads/mod.rs +++ b/src/service/rooms/threads/mod.rs @@ -2,9 +2,8 @@ mod data; use std::{collections::BTreeMap, sync::Arc}; -use conduit::{Error, Result, Server}; +use conduit::{Error, Result}; use data::Data; -use database::Database; use ruma::{ api::client::{error::ErrorKind, threads::get_threads::v1::IncludeThreads}, events::relation::BundledThread, @@ -18,13 +17,17 @@ pub struct Service { db: Data, } -impl Service { - pub fn build(_server: &Arc, db: &Arc) -> Result { - Ok(Self { - db: Data::new(db), - }) +impl crate::Service for Service { + fn build(args: crate::Args<'_>) -> Result> { + Ok(Arc::new(Self { + db: Data::new(args.db), + })) } + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } +} + +impl Service { pub fn threads_until<'a>( &'a self, user_id: &'a UserId, room_id: &'a RoomId, until: u64, include: &'a IncludeThreads, ) -> Result> + 'a> { @@ -61,7 +64,7 @@ impl Service { .and_then(|relations| serde_json::from_value::(relations.clone().into()).ok()) { // Thread already existed - relations.count += uint!(1); + relations.count = relations.count.saturating_add(uint!(1)); relations.latest_event = pdu.to_message_like_event(); let content = serde_json::to_value(relations).expect("to_value always works"); diff --git a/src/service/rooms/timeline/data.rs b/src/service/rooms/timeline/data.rs index 2054caf7..ec975b99 100644 --- a/src/service/rooms/timeline/data.rs +++ b/src/service/rooms/timeline/data.rs @@ -4,7 +4,7 @@ use std::{ sync::{Arc, Mutex}, }; -use conduit::{error, utils, Error, Result}; +use conduit::{checked, error, utils, Error, Result}; use database::{Database, Map}; use ruma::{api::client::error::ErrorKind, CanonicalJsonObject, EventId, OwnedRoomId, OwnedUserId, RoomId, UserId}; @@ -97,6 +97,7 @@ impl Data { } /// Returns the pdu's id. + #[inline] pub(super) fn get_pdu_id(&self, event_id: &EventId) -> Result>> { self.eventid_pduid.get(event_id.as_bytes()) } @@ -280,10 +281,12 @@ impl Data { /// Returns the `count` of this pdu's id. pub(super) fn pdu_count(pdu_id: &[u8]) -> Result { - let last_u64 = utils::u64_from_bytes(&pdu_id[pdu_id.len() - size_of::()..]) + let stride = size_of::(); + let pdu_id_len = pdu_id.len(); + let last_u64 = utils::u64_from_bytes(&pdu_id[checked!(pdu_id_len - stride)?..]) .map_err(|_| Error::bad_database("PDU has invalid count bytes."))?; let second_last_u64 = - utils::u64_from_bytes(&pdu_id[pdu_id.len() - 2 * size_of::()..pdu_id.len() - size_of::()]); + utils::u64_from_bytes(&pdu_id[checked!(pdu_id_len - 2 * stride)?..checked!(pdu_id_len - stride)?]); if matches!(second_last_u64, Ok(0)) { Ok(PduCount::Backfilled(u64::MAX.saturating_sub(last_u64))) diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index b9c0d7ba..0bc5ade1 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -1,15 +1,18 @@ mod data; use std::{ - collections::{BTreeMap, HashMap, HashSet}, + collections::{BTreeMap, HashSet}, + fmt::Write, sync::Arc, }; -use conduit::{debug, error, info, utils, utils::mutex_map, warn, Error, Result, Server}; +use conduit::{ + debug, error, info, utils, + utils::{MutexMap, MutexMapGuard}, + validated, warn, Error, Result, +}; use data::Data; -use database::Database; use itertools::Itertools; -use rand::prelude::SliceRandom; use ruma::{ api::{client::error::ErrorKind, federation}, canonical_json::to_canonical_value, @@ -32,7 +35,7 @@ use ruma::{ }; use serde::Deserialize; use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; -use tokio::sync::{Mutex, RwLock}; +use tokio::sync::RwLock; use crate::{ admin, @@ -66,19 +69,48 @@ struct ExtractBody { pub struct Service { db: Data, + pub mutex_insert: RoomMutexMap, +} - pub lasttimelinecount_cache: Mutex>, +type RoomMutexMap = MutexMap; +pub type RoomMutexGuard = MutexMapGuard; + +impl crate::Service for Service { + fn build(args: crate::Args<'_>) -> Result> { + Ok(Arc::new(Self { + db: Data::new(args.db), + mutex_insert: RoomMutexMap::new(), + })) + } + + fn memory_usage(&self, out: &mut dyn Write) -> Result<()> { + let lasttimelinecount_cache = self + .db + .lasttimelinecount_cache + .lock() + .expect("locked") + .len(); + writeln!(out, "lasttimelinecount_cache: {lasttimelinecount_cache}")?; + + let mutex_insert = self.mutex_insert.len(); + writeln!(out, "insert_mutex: {mutex_insert}")?; + + Ok(()) + } + + fn clear_cache(&self) { + self.db + .lasttimelinecount_cache + .lock() + .expect("locked") + .clear(); + } + + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } impl Service { - pub fn build(_server: &Arc, db: &Arc) -> Result { - Ok(Self { - db: Data::new(db), - lasttimelinecount_cache: Mutex::new(HashMap::new()), - }) - } - - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn first_pdu_in_room(&self, room_id: &RoomId) -> Result>> { self.all_pdus(user_id!("@doesntmatter:conduit.rs"), room_id)? .next() @@ -86,7 +118,7 @@ impl Service { .transpose() } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn latest_pdu_in_room(&self, room_id: &RoomId) -> Result>> { self.all_pdus(user_id!("@placeholder:conduwuit.placeholder"), room_id)? .last() @@ -94,7 +126,7 @@ impl Service { .transpose() } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result { self.db.last_timeline_count(sender_user, room_id) } @@ -125,40 +157,19 @@ impl Service { } */ - /// Returns the version of a room, if known - /// - /// TODO: use this? - #[allow(dead_code)] - pub fn get_room_version(&self, room_id: &RoomId) -> Result> { - let create_event = services() - .rooms - .state_accessor - .room_state_get(room_id, &StateEventType::RoomCreate, "")?; - - let create_event_content: Option = create_event - .as_ref() - .map(|create_event| { - serde_json::from_str(create_event.content.get()).map_err(|e| { - warn!("Invalid create event: {}", e); - Error::bad_database("Invalid create event in db.") - }) - }) - .transpose()?; - - Ok(create_event_content.map(|content| content.room_version)) - } - /// Returns the json of a pdu. pub fn get_pdu_json(&self, event_id: &EventId) -> Result> { self.db.get_pdu_json(event_id) } /// Returns the json of a pdu. + #[inline] pub fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result> { self.db.get_non_outlier_pdu_json(event_id) } /// Returns the pdu's id. + #[inline] pub fn get_pdu_id(&self, event_id: &EventId) -> Result>> { self.db.get_pdu_id(event_id) } @@ -166,9 +177,7 @@ impl Service { /// Returns the pdu. /// /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. - /// - /// TODO: use this? - #[allow(dead_code)] + #[inline] pub fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result> { self.db.get_non_outlier_pdu(event_id) } @@ -189,7 +198,7 @@ impl Service { } /// Removes a pdu and creates a new one with the same id. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn replace_pdu(&self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, pdu: &PduEvent) -> Result<()> { self.db.replace_pdu(pdu_id, pdu_json, pdu) } @@ -206,7 +215,7 @@ impl Service { pdu: &PduEvent, mut pdu_json: CanonicalJsonObject, leaves: Vec, - state_lock: &mutex_map::Guard<()>, // Take mutex guard to make sure users get the room state mutex + state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex ) -> Result> { // Coalesce database writes for the remainder of this scope. let _cork = services().db.cork_and_flush(); @@ -271,11 +280,7 @@ impl Service { .state .set_forward_extremities(&pdu.room_id, leaves, state_lock)?; - let insert_lock = services() - .globals - .roomid_mutex_insert - .lock(&pdu.room_id) - .await; + let insert_lock = self.mutex_insert.lock(&pdu.room_id).await; let count1 = services().globals.next_count()?; // Mark as read first so the sending client doesn't get a notification even if @@ -384,18 +389,11 @@ impl Service { match pdu.kind { TimelineEventType::RoomRedaction => { + use RoomVersionId::*; + let room_version_id = services().rooms.state.get_room_version(&pdu.room_id)?; match room_version_id { - RoomVersionId::V1 - | RoomVersionId::V2 - | RoomVersionId::V3 - | RoomVersionId::V4 - | RoomVersionId::V5 - | RoomVersionId::V6 - | RoomVersionId::V7 - | RoomVersionId::V8 - | RoomVersionId::V9 - | RoomVersionId::V10 => { + V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => { if let Some(redact_id) = &pdu.redacts { if services().rooms.state_accessor.user_can_redact( redact_id, @@ -407,7 +405,7 @@ impl Service { } } }, - RoomVersionId::V11 => { + V11 => { let content = serde_json::from_str::(pdu.content.get()).map_err(|e| { warn!("Invalid content in redaction pdu: {e}"); @@ -596,7 +594,7 @@ impl Service { pdu_builder: PduBuilder, sender: &UserId, room_id: &RoomId, - _mutex_lock: &mutex_map::Guard<()>, // Take mutex guard to make sure users get the room state mutex + _mutex_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex ) -> Result<(PduEvent, CanonicalJsonObject)> { let PduBuilder { event_type, @@ -646,7 +644,7 @@ impl Service { .filter_map(|event_id| Some(self.get_pdu(event_id).ok()??.depth)) .max() .unwrap_or_else(|| uint!(0)) - + uint!(1); + .saturating_add(uint!(1)); let mut unsigned = unsigned.unwrap_or_default(); @@ -783,7 +781,7 @@ impl Service { pdu_builder: PduBuilder, sender: &UserId, room_id: &RoomId, - state_lock: &mutex_map::Guard<()>, // Take mutex guard to make sure users get the room state mutex + state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex ) -> Result> { let (pdu, pdu_json) = self.create_hash_and_sign_event(pdu_builder, sender, room_id, state_lock)?; if let Some(admin_room) = admin::Service::get_admin_room()? { @@ -863,17 +861,9 @@ impl Service { // If redaction event is not authorized, do not append it to the timeline if pdu.kind == TimelineEventType::RoomRedaction { + use RoomVersionId::*; match services().rooms.state.get_room_version(&pdu.room_id)? { - RoomVersionId::V1 - | RoomVersionId::V2 - | RoomVersionId::V3 - | RoomVersionId::V4 - | RoomVersionId::V5 - | RoomVersionId::V6 - | RoomVersionId::V7 - | RoomVersionId::V8 - | RoomVersionId::V9 - | RoomVersionId::V10 => { + V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => { if let Some(redact_id) = &pdu.redacts { if !services().rooms.state_accessor.user_can_redact( redact_id, @@ -966,7 +956,7 @@ impl Service { new_room_leaves: Vec, state_ids_compressed: Arc>, soft_fail: bool, - state_lock: &mutex_map::Guard<()>, // Take mutex guard to make sure users get the room state mutex + state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex ) -> Result>> { // We append to state before appending the pdu, so we don't have a moment in // time with the pdu without it's state. This is okay because append_pdu can't @@ -996,6 +986,7 @@ impl Service { } /// Returns an iterator over all PDUs in a room. + #[inline] pub fn all_pdus<'a>( &'a self, user_id: &UserId, room_id: &RoomId, ) -> Result> + 'a> { @@ -1005,7 +996,7 @@ impl Service { /// Returns an iterator over all events and their tokens in a room that /// happened before the event with id `until` in reverse-chronological /// order. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn pdus_until<'a>( &'a self, user_id: &UserId, room_id: &RoomId, until: PduCount, ) -> Result> + 'a> { @@ -1014,7 +1005,7 @@ impl Service { /// Returns an iterator over all events and their token in a room that /// happened after the event with id `from` in chronological order. - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn pdus_after<'a>( &'a self, user_id: &UserId, room_id: &RoomId, from: PduCount, ) -> Result> + 'a> { @@ -1056,7 +1047,7 @@ impl Service { Ok(()) } - #[tracing::instrument(skip(self, room_id))] + #[tracing::instrument(skip(self))] pub async fn backfill_if_required(&self, room_id: &RoomId, from: PduCount) -> Result<()> { let first_pdu = self .all_pdus(user_id!("@doesntmatter:conduit.rs"), room_id)? @@ -1068,41 +1059,6 @@ impl Service { return Ok(()); } - let mut servers: Vec = vec![]; - - // add server names of any trusted key servers if they're in the room - servers.extend( - services() - .rooms - .state_cache - .room_servers(room_id) - .filter_map(Result::ok) - .filter(|server_name| { - services().globals.trusted_servers().contains(server_name) && !server_is_ours(server_name) - }), - ); - - // add server names from room aliases on the room ID - let room_aliases = services() - .rooms - .alias - .local_aliases_for_room(room_id) - .collect::, _>>(); - if let Ok(aliases) = &room_aliases { - for alias in aliases { - if !server_is_ours(alias.server_name()) { - servers.push(alias.server_name().to_owned()); - } - } - } - - // add room ID server name for backfill server - if let Some(server_name) = room_id.server_name() { - if !server_is_ours(server_name) { - servers.push(server_name.to_owned()); - } - } - let power_levels: RoomPowerLevelsEventContent = services() .rooms .state_accessor @@ -1114,29 +1070,39 @@ impl Service { .transpose()? .unwrap_or_default(); - // add server names of the list of admins in the room for backfill server - servers.extend( - power_levels - .users - .iter() - .filter(|(_, level)| **level > power_levels.users_default) - .map(|(user_id, _)| user_id.server_name()) - .filter(|server_name| !server_is_ours(server_name)) - .map(ToOwned::to_owned), - ); + let room_mods = power_levels.users.iter().filter_map(|(user_id, level)| { + if level > &power_levels.users_default && !server_is_ours(user_id.server_name()) { + Some(user_id.server_name().to_owned()) + } else { + None + } + }); - // don't backfill from ourselves (might be noop if we checked it above already) - if let Some(server_index) = servers - .clone() - .into_iter() - .position(|server_name| server_is_ours(&server_name)) - { - servers.swap_remove(server_index); - } + let room_alias_servers = services() + .rooms + .alias + .local_aliases_for_room(room_id) + .filter_map(|alias| { + alias + .ok() + .filter(|alias| !server_is_ours(alias.server_name())) + .map(|alias| alias.server_name().to_owned()) + }); - servers.sort_unstable(); - servers.dedup(); - servers.shuffle(&mut rand::thread_rng()); + let servers = room_mods + .chain(room_alias_servers) + .chain(services().globals.config.trusted_servers.clone()) + .filter(|server_name| { + if server_is_ours(server_name) { + return false; + } + + services() + .rooms + .state_cache + .server_in_room(server_name, room_id) + .unwrap_or(false) + }); for backfill_server in servers { info!("Asking {backfill_server} for backfill"); @@ -1180,8 +1146,9 @@ impl Service { // Lock so we cannot backfill the same pdu twice at the same time let mutex_lock = services() - .globals - .roomid_mutex_federation + .rooms + .event_handler + .mutex_federation .lock(&room_id) .await; @@ -1213,12 +1180,13 @@ impl Service { .get_shortroomid(&room_id)? .expect("room exists"); - let insert_lock = services().globals.roomid_mutex_insert.lock(&room_id).await; + let insert_lock = self.mutex_insert.lock(&room_id).await; + let max = u64::MAX; let count = services().globals.next_count()?; let mut pdu_id = shortroomid.to_be_bytes().to_vec(); pdu_id.extend_from_slice(&0_u64.to_be_bytes()); - pdu_id.extend_from_slice(&(u64::MAX - count).to_be_bytes()); + pdu_id.extend_from_slice(&(validated!(max - count)?).to_be_bytes()); // Insert pdu self.db.prepend_backfill_pdu(&pdu_id, &event_id, &value)?; @@ -1241,19 +1209,6 @@ impl Service { debug!("Prepended backfill pdu"); Ok(()) } - - pub fn get_lasttimelinecount_cache_usage(&self) -> (usize, usize) { - let cache = self.db.lasttimelinecount_cache.lock().expect("locked"); - (cache.len(), cache.capacity()) - } - - pub fn clear_lasttimelinecount_cache(&self) { - self.db - .lasttimelinecount_cache - .lock() - .expect("locked") - .clear(); - } } #[cfg(test)] diff --git a/src/service/rooms/typing/mod.rs b/src/service/rooms/typing/mod.rs index dd4e7fe2..715e3162 100644 --- a/src/service/rooms/typing/mod.rs +++ b/src/service/rooms/typing/mod.rs @@ -1,7 +1,6 @@ use std::{collections::BTreeMap, sync::Arc}; -use conduit::{debug_info, trace, utils, Result, Server}; -use database::Database; +use conduit::{debug_info, trace, utils, Result}; use ruma::{ api::federation::transactions::edu::{Edu, TypingContent}, events::SyncEphemeralRoomEvent, @@ -19,15 +18,19 @@ pub struct Service { pub typing_update_sender: broadcast::Sender, } -impl Service { - pub fn build(_server: &Arc, _db: &Arc) -> Result { - Ok(Self { +impl crate::Service for Service { + fn build(_args: crate::Args<'_>) -> Result> { + Ok(Arc::new(Self { typing: RwLock::new(BTreeMap::new()), last_typing_update: RwLock::new(BTreeMap::new()), typing_update_sender: broadcast::channel(100).0, - }) + })) } + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } +} + +impl Service { /// Sets a user as typing until the timeout timestamp is reached or /// roomtyping_remove is called. pub async fn typing_add(&self, user_id: &UserId, room_id: &RoomId, timeout: u64) -> Result<()> { diff --git a/src/service/rooms/user/mod.rs b/src/service/rooms/user/mod.rs index 505e0662..12124a57 100644 --- a/src/service/rooms/user/mod.rs +++ b/src/service/rooms/user/mod.rs @@ -2,22 +2,25 @@ mod data; use std::sync::Arc; -use conduit::{Result, Server}; +use conduit::Result; use data::Data; -use database::Database; use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId}; pub struct Service { db: Data, } -impl Service { - pub fn build(_server: &Arc, db: &Arc) -> Result { - Ok(Self { - db: Data::new(db), - }) +impl crate::Service for Service { + fn build(args: crate::Args<'_>) -> Result> { + Ok(Arc::new(Self { + db: Data::new(args.db), + })) } + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } +} + +impl Service { pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { self.db.reset_notification_counts(user_id, room_id) } diff --git a/src/service/sending/data.rs b/src/service/sending/data.rs index 22b69818..65725618 100644 --- a/src/service/sending/data.rs +++ b/src/service/sending/data.rs @@ -27,6 +27,7 @@ impl Data { } } + #[inline] pub fn active_requests(&self) -> OutgoingSendingIter<'_> { Box::new( self.servercurrentevent_data @@ -35,6 +36,7 @@ impl Data { ) } + #[inline] pub fn active_requests_for<'a>(&'a self, destination: &Destination) -> SendingEventIter<'a> { let prefix = destination.get_prefix(); Box::new( @@ -134,7 +136,7 @@ impl Data { } } -#[tracing::instrument(skip(key))] +#[tracing::instrument(skip(key), level = "debug")] fn parse_servercurrentevent(key: &[u8], value: Vec) -> Result<(Destination, SendingEvent)> { // Appservices start with a plus Ok::<_, Error>(if key.starts_with(b"+") { diff --git a/src/service/sending/mod.rs b/src/service/sending/mod.rs index 9bda6cad..88b8b189 100644 --- a/src/service/sending/mod.rs +++ b/src/service/sending/mod.rs @@ -1,31 +1,29 @@ mod appservice; mod data; -pub mod resolve; +mod resolve; mod send; mod sender; -use std::{fmt::Debug, sync::Arc}; +use std::fmt::Debug; -use conduit::{Error, Result, Server}; -use data::Data; -use database::Database; -pub use resolve::FedDest; +use conduit::{err, Result}; +pub use resolve::{resolve_actual_dest, CachedDest, CachedOverride, FedDest}; use ruma::{ api::{appservice::Registration, OutgoingRequest}, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, }; -use tokio::{sync::Mutex, task::JoinHandle}; -use tracing::{error, warn}; +pub use sender::convert_to_outgoing_federation_event; +use tokio::sync::Mutex; +use tracing::warn; use crate::{server_is_ours, services}; pub struct Service { - pub db: Data, + pub db: data::Data, /// The state for a given state hash. sender: loole::Sender, receiver: Mutex>, - handler_join: Mutex>>, startup_netburst: bool, startup_netburst_keep: i64, } @@ -53,35 +51,7 @@ pub enum SendingEvent { } impl Service { - pub fn build(server: &Arc, db: &Arc) -> Result> { - let config = &server.config; - let (sender, receiver) = loole::unbounded(); - Ok(Arc::new(Self { - db: Data::new(db.clone()), - sender, - receiver: Mutex::new(receiver), - handler_join: Mutex::new(None), - startup_netburst: config.startup_netburst, - startup_netburst_keep: config.startup_netburst_keep, - })) - } - - pub async fn close(&self) { - self.interrupt(); - if let Some(handler_join) = self.handler_join.lock().await.take() { - if let Err(e) = handler_join.await { - error!("Failed to shutdown: {e:?}"); - } - } - } - - pub fn interrupt(&self) { - if !self.sender.is_closed() { - self.sender.close(); - } - } - - #[tracing::instrument(skip(self, pdu_id, user, pushkey))] + #[tracing::instrument(skip(self, pdu_id, user, pushkey), level = "debug")] pub fn send_pdu_push(&self, pdu_id: &[u8], user: &UserId, pushkey: String) -> Result<()> { let dest = Destination::Push(user.to_owned(), pushkey); let event = SendingEvent::Pdu(pdu_id.to_owned()); @@ -94,7 +64,7 @@ impl Service { }) } - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn send_pdu_appservice(&self, appservice_id: String, pdu_id: Vec) -> Result<()> { let dest = Destination::Appservice(appservice_id); let event = SendingEvent::Pdu(pdu_id); @@ -107,7 +77,7 @@ impl Service { }) } - #[tracing::instrument(skip(self, room_id, pdu_id))] + #[tracing::instrument(skip(self, room_id, pdu_id), level = "debug")] pub fn send_pdu_room(&self, room_id: &RoomId, pdu_id: &[u8]) -> Result<()> { let servers = services() .rooms @@ -119,7 +89,7 @@ impl Service { self.send_pdu_servers(servers, pdu_id) } - #[tracing::instrument(skip(self, servers, pdu_id))] + #[tracing::instrument(skip(self, servers, pdu_id), level = "debug")] pub fn send_pdu_servers>(&self, servers: I, pdu_id: &[u8]) -> Result<()> { let requests = servers .into_iter() @@ -143,7 +113,7 @@ impl Service { Ok(()) } - #[tracing::instrument(skip(self, server, serialized))] + #[tracing::instrument(skip(self, server, serialized), level = "debug")] pub fn send_edu_server(&self, server: &ServerName, serialized: Vec) -> Result<()> { let dest = Destination::Normal(server.to_owned()); let event = SendingEvent::Edu(serialized); @@ -156,7 +126,7 @@ impl Service { }) } - #[tracing::instrument(skip(self, room_id, serialized))] + #[tracing::instrument(skip(self, room_id, serialized), level = "debug")] pub fn send_edu_room(&self, room_id: &RoomId, serialized: Vec) -> Result<()> { let servers = services() .rooms @@ -168,7 +138,7 @@ impl Service { self.send_edu_servers(servers, serialized) } - #[tracing::instrument(skip(self, servers, serialized))] + #[tracing::instrument(skip(self, servers, serialized), level = "debug")] pub fn send_edu_servers>(&self, servers: I, serialized: Vec) -> Result<()> { let requests = servers .into_iter() @@ -193,7 +163,7 @@ impl Service { Ok(()) } - #[tracing::instrument(skip(self, room_id))] + #[tracing::instrument(skip(self, room_id), level = "debug")] pub fn flush_room(&self, room_id: &RoomId) -> Result<()> { let servers = services() .rooms @@ -205,7 +175,7 @@ impl Service { self.flush_servers(servers) } - #[tracing::instrument(skip(self, servers))] + #[tracing::instrument(skip(self, servers), level = "debug")] pub fn flush_servers>(&self, servers: I) -> Result<()> { let requests = servers.into_iter().map(Destination::Normal); for dest in requests { @@ -243,7 +213,7 @@ impl Service { /// Cleanup event data /// Used for instance after we remove an appservice registration - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), level = "debug")] pub fn cleanup_events(&self, appservice_id: String) -> Result<()> { self.db .delete_all_requests_for(&Destination::Appservice(appservice_id))?; @@ -254,34 +224,52 @@ impl Service { fn dispatch(&self, msg: Msg) -> Result<()> { debug_assert!(!self.sender.is_full(), "channel full"); debug_assert!(!self.sender.is_closed(), "channel closed"); - self.sender.send(msg).map_err(|e| Error::Err(e.to_string())) + self.sender.send(msg).map_err(|e| err!("{e}")) } } impl Destination { - #[tracing::instrument(skip(self))] + #[must_use] pub fn get_prefix(&self) -> Vec { - let mut prefix = match self { - Self::Appservice(server) => { - let mut p = b"+".to_vec(); + match self { + Self::Normal(server) => { + let len = server.as_bytes().len().saturating_add(1); + + let mut p = Vec::with_capacity(len); p.extend_from_slice(server.as_bytes()); + p.push(0xFF); + p + }, + Self::Appservice(server) => { + let sigil = b"+"; + let len = sigil + .len() + .saturating_add(server.as_bytes().len()) + .saturating_add(1); + + let mut p = Vec::with_capacity(len); + p.extend_from_slice(sigil); + p.extend_from_slice(server.as_bytes()); + p.push(0xFF); p }, Self::Push(user, pushkey) => { - let mut p = b"$".to_vec(); + let sigil = b"$"; + let len = sigil + .len() + .saturating_add(user.as_bytes().len()) + .saturating_add(1) + .saturating_add(pushkey.as_bytes().len()) + .saturating_add(1); + + let mut p = Vec::with_capacity(len); + p.extend_from_slice(sigil); p.extend_from_slice(user.as_bytes()); p.push(0xFF); p.extend_from_slice(pushkey.as_bytes()); + p.push(0xFF); p }, - Self::Normal(server) => { - let mut p = Vec::new(); - p.extend_from_slice(server.as_bytes()); - p - }, - }; - prefix.push(0xFF); - - prefix + } } } diff --git a/src/service/sending/resolve.rs b/src/service/sending/resolve.rs index 294ac09f..e7a71af1 100644 --- a/src/service/sending/resolve.rs +++ b/src/service/sending/resolve.rs @@ -2,14 +2,15 @@ use std::{ fmt, fmt::Debug, net::{IpAddr, SocketAddr}, + time::SystemTime, }; +use conduit::{debug, debug_error, debug_info, debug_warn, trace, utils::rand, Err, Error, Result}; use hickory_resolver::{error::ResolveError, lookup::SrvLookup}; use ipaddress::IPAddress; -use ruma::ServerName; -use tracing::{debug, error, trace}; +use ruma::{OwnedServerName, ServerName}; -use crate::{debug_error, debug_info, debug_warn, services, Error, Result}; +use crate::services; /// Wraps either an literal IP address plus port, or a hostname plus complement /// (colon-plus-port if it was specified). @@ -35,6 +36,7 @@ pub enum FedDest { Named(String, String), } +#[derive(Clone, Debug)] pub(crate) struct ActualDest { pub(crate) dest: FedDest, pub(crate) host: String, @@ -42,19 +44,33 @@ pub(crate) struct ActualDest { pub(crate) cached: bool, } +#[derive(Clone, Debug)] +pub struct CachedDest { + pub dest: FedDest, + pub host: String, + pub expire: SystemTime, +} + +#[derive(Clone, Debug)] +pub struct CachedOverride { + pub ips: Vec, + pub port: u16, + pub expire: SystemTime, +} + #[tracing::instrument(skip_all, name = "resolve")] pub(crate) async fn get_actual_dest(server_name: &ServerName) -> Result { let cached; let cached_result = services() .globals .resolver - .destinations - .read() - .expect("locked for reading") - .get(server_name) - .cloned(); + .get_cached_destination(server_name); - let (dest, host) = if let Some(result) = cached_result { + let CachedDest { + dest, + host, + .. + } = if let Some(result) = cached_result { cached = true; result } else { @@ -77,7 +93,7 @@ pub(crate) async fn get_actual_dest(server_name: &ServerName) -> Result Result<(FedDest, String)> { +pub async fn resolve_actual_dest(dest: &ServerName, cache: bool) -> Result { trace!("Finding actual destination for {dest}"); let mut host = dest.as_str().to_owned(); let actual_dest = match get_ip_with_port(dest.as_str()) { @@ -109,7 +125,11 @@ pub async fn resolve_actual_dest(dest: &ServerName, cache: bool) -> Result<(FedD }; debug!("Actual destination: {actual_dest:?} hostname: {host:?}"); - Ok((actual_dest, host.into_uri_string())) + Ok(CachedDest { + dest: actual_dest, + host: host.into_uri_string(), + expire: CachedDest::default_expire(), + }) } fn actual_dest_1(host_port: FedDest) -> Result { @@ -193,14 +213,7 @@ async fn actual_dest_5(dest: &ServerName, cache: bool) -> Result { #[tracing::instrument(skip_all, name = "well-known")] async fn request_well_known(dest: &str) -> Result> { trace!("Requesting well known for {dest}"); - if !services() - .globals - .resolver - .overrides - .read() - .unwrap() - .contains_key(dest) - { + if !services().globals.resolver.has_cached_override(dest) { query_and_cache_override(dest, dest, 8448).await?; } @@ -261,22 +274,25 @@ async fn conditional_query_and_cache_override(overname: &str, hostname: &str, po async fn query_and_cache_override(overname: &'_ str, hostname: &'_ str, port: u16) -> Result<()> { match services() .globals - .dns_resolver() + .resolver + .resolver .lookup_ip(hostname.to_owned()) .await { Err(e) => handle_resolve_error(&e), Ok(override_ip) => { if hostname != overname { - debug_info!("{:?} overriden by {:?}", overname, hostname); + debug_info!("{overname:?} overriden by {hostname:?}"); } - services() - .globals - .resolver - .overrides - .write() - .unwrap() - .insert(overname.to_owned(), (override_ip.iter().collect(), port)); + + services().globals.resolver.set_cached_override( + overname.to_owned(), + CachedOverride { + ips: override_ip.iter().collect(), + port, + expire: CachedOverride::default_expire(), + }, + ); Ok(()) }, @@ -299,7 +315,8 @@ async fn query_srv_record(hostname: &'_ str) -> Result> { let hostname = hostname.trim_end_matches('.'); services() .globals - .dns_resolver() + .resolver + .resolver .srv_lookup(hostname.to_owned()) .await } @@ -328,16 +345,13 @@ fn handle_resolve_error(e: &ResolveError) -> Result<()> { debug!("{e}"); Ok(()) }, - _ => { - error!("DNS {e}"); - Err(Error::Err(e.to_string())) - }, + _ => Err!(error!("DNS {e}")), } } fn validate_dest(dest: &ServerName) -> Result<()> { if dest == services().globals.server_name() { - return Err(Error::bad_config("Won't send federation request to ourselves")); + return Err!("Won't send federation request to ourselves"); } if dest.is_ip_literal() || IPAddress::is_valid(dest.host()) { @@ -390,6 +404,64 @@ fn add_port_to_hostname(dest_str: &str) -> FedDest { FedDest::Named(host.to_owned(), port.to_owned()) } +impl crate::globals::resolver::Resolver { + pub(crate) fn set_cached_destination(&self, name: OwnedServerName, dest: CachedDest) -> Option { + trace!(?name, ?dest, "set cached destination"); + self.destinations + .write() + .expect("locked for writing") + .insert(name, dest) + } + + pub(crate) fn get_cached_destination(&self, name: &ServerName) -> Option { + self.destinations + .read() + .expect("locked for reading") + .get(name) + .filter(|cached| cached.valid()) + .cloned() + } + + pub(crate) fn set_cached_override(&self, name: String, over: CachedOverride) -> Option { + trace!(?name, ?over, "set cached override"); + self.overrides + .write() + .expect("locked for writing") + .insert(name, over) + } + + pub(crate) fn has_cached_override(&self, name: &str) -> bool { + self.overrides + .read() + .expect("locked for reading") + .get(name) + .filter(|cached| cached.valid()) + .is_some() + } +} + +impl CachedDest { + #[inline] + #[must_use] + pub fn valid(&self) -> bool { true } + + //pub fn valid(&self) -> bool { self.expire > SystemTime::now() } + + #[must_use] + pub(crate) fn default_expire() -> SystemTime { rand::timepoint_secs(60 * 60 * 18..60 * 60 * 36) } +} + +impl CachedOverride { + #[inline] + #[must_use] + pub fn valid(&self) -> bool { true } + + //pub fn valid(&self) -> bool { self.expire > SystemTime::now() } + + #[must_use] + pub(crate) fn default_expire() -> SystemTime { rand::timepoint_secs(60 * 60 * 6..60 * 60 * 12) } +} + impl FedDest { fn into_https_string(self) -> String { match self { @@ -412,6 +484,8 @@ impl FedDest { } } + #[inline] + #[allow(clippy::string_slice)] fn port(&self) -> Option { match &self { Self::Literal(addr) => Some(addr.port()), diff --git a/src/service/sending/send.rs b/src/service/sending/send.rs index 57ff9127..18a98828 100644 --- a/src/service/sending/send.rs +++ b/src/service/sending/send.rs @@ -1,5 +1,6 @@ use std::{fmt::Debug, mem}; +use conduit::Err; use http::{header::AUTHORIZATION, HeaderValue}; use ipaddress::IPAddress; use reqwest::{Client, Method, Request, Response, Url}; @@ -8,11 +9,16 @@ use ruma::{ client::error::Error as RumaError, EndpointError, IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken, }, - OwnedServerName, ServerName, + serde::Base64, + server_util::authorization::XMatrix, + ServerName, }; use tracing::{debug, trace}; -use super::{resolve, resolve::ActualDest}; +use super::{ + resolve, + resolve::{ActualDest, CachedDest}, +}; use crate::{debug_error, debug_warn, services, Error, Result}; #[tracing::instrument(skip_all, name = "send")] @@ -21,7 +27,7 @@ where T: OutgoingRequest + Debug + Send, { if !services().globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); + return Err!(Config("allow_federation", "Federation is disabled.")); } let actual = resolve::get_actual_dest(dest).await?; @@ -103,13 +109,14 @@ where let response = T::IncomingResponse::try_from_http_response(http_response); if response.is_ok() && !actual.cached { - services() - .globals - .resolver - .destinations - .write() - .expect("locked for writing") - .insert(OwnedServerName::from(dest), (actual.dest.clone(), actual.host.clone())); + services().globals.resolver.set_cached_destination( + dest.to_owned(), + CachedDest { + dest: actual.dest.clone(), + host: actual.host.clone(), + expire: CachedDest::default_expire(), + }, + ); } match response { @@ -192,16 +199,20 @@ where for signature_server in signatures { for s in signature_server { + let key = + s.0.as_str() + .try_into() + .expect("valid homeserver signing key ID"); + let sig = Base64::parse(s.1).expect("valid base64"); + http_request.headers_mut().insert( AUTHORIZATION, - HeaderValue::from_str(&format!( - "X-Matrix origin=\"{}\",destination=\"{}\",key=\"{}\",sig=\"{}\"", - services().globals.config.server_name, - dest, - s.0, - s.1 - )) - .expect("formatted X-Matrix header"), + HeaderValue::from(&XMatrix::new( + services().globals.config.server_name.clone(), + dest.to_owned(), + key, + sig, + )), ); } } diff --git a/src/service/sending/sender.rs b/src/service/sending/sender.rs index aa2865ec..2f542dfe 100644 --- a/src/service/sending/sender.rs +++ b/src/service/sending/sender.rs @@ -6,7 +6,9 @@ use std::{ time::{Duration, Instant}, }; +use async_trait::async_trait; use base64::{engine::general_purpose, Engine as _}; +use conduit::{debug, debug_warn, error, trace, utils::math::continue_exponential_backoff_secs, warn}; use federation::transactions::send_transaction_message; use futures_util::{future::BoxFuture, stream::FuturesUnordered, StreamExt}; use ruma::{ @@ -18,12 +20,14 @@ use ruma::{ }, device_id, events::{push_rules::PushRulesEvent, receipt::ReceiptType, AnySyncEphemeralRoomEvent, GlobalAccountDataEventType}, - push, uint, MilliSecondsSinceUnixEpoch, OwnedServerName, OwnedUserId, RoomId, ServerName, UInt, + push, uint, CanonicalJsonObject, MilliSecondsSinceUnixEpoch, OwnedServerName, OwnedUserId, RoomId, RoomVersionId, + ServerName, UInt, }; -use tracing::{debug, error, warn}; +use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; +use tokio::{sync::Mutex, time::sleep_until}; -use super::{appservice, send, Destination, Msg, SendingEvent, Service}; -use crate::{presence::Presence, services, user_is_local, utils::calculate_hash, Error, PduEvent, Result}; +use super::{appservice, data::Data, send, Destination, Msg, SendingEvent, Service}; +use crate::{presence::Presence, services, user_is_local, utils::calculate_hash, Error, Result}; #[derive(Debug)] enum TransactionStatus { @@ -40,43 +44,58 @@ type CurTransactionStatus = HashMap; const DEQUEUE_LIMIT: usize = 48; const SELECT_EDU_LIMIT: usize = 16; +const CLEANUP_TIMEOUT_MS: u64 = 3500; -impl Service { - pub async fn start_handler(self: &Arc) { - let self_ = Arc::clone(self); - let handle = services().server.runtime().spawn(async move { - self_ - .handler() - .await - .expect("Failed to start sending handler"); - }); - - _ = self.handler_join.lock().await.insert(handle); +#[async_trait] +impl crate::Service for Service { + fn build(args: crate::Args<'_>) -> Result> { + let config = &args.server.config; + let (sender, receiver) = loole::unbounded(); + Ok(Arc::new(Self { + db: Data::new(args.db.clone()), + sender, + receiver: Mutex::new(receiver), + startup_netburst: config.startup_netburst, + startup_netburst_keep: config.startup_netburst_keep, + })) } #[tracing::instrument(skip_all, name = "sender")] - async fn handler(&self) -> Result<()> { + async fn worker(self: Arc) -> Result<()> { let receiver = self.receiver.lock().await; let mut futures: SendingFutures<'_> = FuturesUnordered::new(); let mut statuses: CurTransactionStatus = CurTransactionStatus::new(); - self.initial_transactions(&futures, &mut statuses); + self.initial_requests(&futures, &mut statuses); loop { debug_assert!(!receiver.is_closed(), "channel error"); tokio::select! { request = receiver.recv_async() => match request { Ok(request) => self.handle_request(request, &futures, &mut statuses), - Err(_) => return Ok(()), + Err(_) => break, }, Some(response) = futures.next() => { - self.handle_response(response, &mut futures, &mut statuses); + self.handle_response(response, &futures, &mut statuses); }, } } + self.finish_responses(&mut futures, &mut statuses).await; + + Ok(()) } + fn interrupt(&self) { + if !self.sender.is_closed() { + self.sender.close(); + } + } + + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } +} + +impl Service { fn handle_response( - &self, response: SendingResult, futures: &mut SendingFutures<'_>, statuses: &mut CurTransactionStatus, + &self, response: SendingResult, futures: &SendingFutures<'_>, statuses: &mut CurTransactionStatus, ) { match response { Ok(dest) => self.handle_response_ok(&dest, futures, statuses), @@ -85,13 +104,13 @@ impl Service { } fn handle_response_err( - dest: Destination, _futures: &mut SendingFutures<'_>, statuses: &mut CurTransactionStatus, e: &Error, + dest: Destination, _futures: &SendingFutures<'_>, statuses: &mut CurTransactionStatus, e: &Error, ) { debug!(dest = ?dest, "{e:?}"); statuses.entry(dest).and_modify(|e| { *e = match e { TransactionStatus::Running => TransactionStatus::Failed(1, Instant::now()), - TransactionStatus::Retrying(n) => TransactionStatus::Failed(*n + 1, Instant::now()), + TransactionStatus::Retrying(ref n) => TransactionStatus::Failed(n.saturating_add(1), Instant::now()), TransactionStatus::Failed(..) => panic!("Request that was not even running failed?!"), } }); @@ -136,7 +155,25 @@ impl Service { } } - fn initial_transactions(&self, futures: &SendingFutures<'_>, statuses: &mut CurTransactionStatus) { + async fn finish_responses(&self, futures: &mut SendingFutures<'_>, statuses: &mut CurTransactionStatus) { + let now = Instant::now(); + let timeout = Duration::from_millis(CLEANUP_TIMEOUT_MS); + let deadline = now.checked_add(timeout).unwrap_or(now); + loop { + trace!("Waiting for {} requests to complete...", futures.len()); + tokio::select! { + () = sleep_until(deadline.into()) => break, + response = futures.next() => match response { + Some(response) => self.handle_response(response, futures, statuses), + None => return, + } + } + } + + debug_warn!("Leaving with {} unfinished requests...", futures.len()); + } + + fn initial_requests(&self, futures: &SendingFutures<'_>, statuses: &mut CurTransactionStatus) { let keep = usize::try_from(self.startup_netburst_keep).unwrap_or(usize::MAX); let mut txns = HashMap::>::new(); for (key, dest, event) in self.db.active_requests().filter_map(Result::ok) { @@ -214,11 +251,9 @@ impl Service { .and_modify(|e| match e { TransactionStatus::Failed(tries, time) => { // Fail if a request has failed recently (exponential backoff) - let max_duration = Duration::from_secs(services().globals.config.sender_retry_backoff_limit); - let min_duration = Duration::from_secs(services().globals.config.sender_timeout); - let min_elapsed_duration = min_duration * (*tries) * (*tries); - let min_elapsed_duration = cmp::min(min_elapsed_duration, max_duration); - if time.elapsed() < min_elapsed_duration { + let min = services().globals.config.sender_timeout; + let max = services().globals.config.sender_retry_backoff_limit; + if continue_exponential_backoff_secs(min, max, time.elapsed(), *tries) { allow = false; } else { retry = true; @@ -322,8 +357,10 @@ fn select_edus_presence( } } - let presence_content = Edu::Presence(PresenceContent::new(presence_updates)); - events.push(serde_json::to_vec(&presence_content).expect("PresenceEvent can be serialized")); + if !presence_updates.is_empty() { + let presence_content = Edu::Presence(PresenceContent::new(presence_updates)); + events.push(serde_json::to_vec(&presence_content).expect("PresenceEvent can be serialized")); + } Ok(true) } @@ -548,24 +585,21 @@ async fn send_events_dest_normal( for event in &events { match event { - SendingEvent::Pdu(pdu_id) => { + SendingEvent::Pdu(pdu_id) => pdu_jsons.push(convert_to_outgoing_federation_event( // TODO: check room version and remove event_id if needed - let raw = PduEvent::convert_to_outgoing_federation_event( - services() - .rooms - .timeline - .get_pdu_json_from_id(pdu_id) - .map_err(|e| (dest.clone(), e))? - .ok_or_else(|| { - error!(?dest, ?server, ?pdu_id, "event not found"); - ( - dest.clone(), - Error::bad_database("[Normal] Event in servernameevent_data not found in db."), - ) - })?, - ); - pdu_jsons.push(raw); - }, + services() + .rooms + .timeline + .get_pdu_json_from_id(pdu_id) + .map_err(|e| (dest.clone(), e))? + .ok_or_else(|| { + error!(?dest, ?server, ?pdu_id, "event not found"); + ( + dest.clone(), + Error::bad_database("[Normal] Event in servernameevent_data not found in db."), + ) + })?, + )), SendingEvent::Edu(edu) => { if let Ok(raw) = serde_json::from_slice(edu) { edu_jsons.push(raw); @@ -611,3 +645,39 @@ async fn send_events_dest_normal( }) .map_err(|e| (dest.clone(), e)) } + +/// This does not return a full `Pdu` it is only to satisfy ruma's types. +#[tracing::instrument] +pub fn convert_to_outgoing_federation_event(mut pdu_json: CanonicalJsonObject) -> Box { + if let Some(unsigned) = pdu_json + .get_mut("unsigned") + .and_then(|val| val.as_object_mut()) + { + unsigned.remove("transaction_id"); + } + + // room v3 and above removed the "event_id" field from remote PDU format + if let Some(room_id) = pdu_json + .get("room_id") + .and_then(|val| RoomId::parse(val.as_str()?).ok()) + { + match services().rooms.state.get_room_version(&room_id) { + Ok(room_version_id) => match room_version_id { + RoomVersionId::V1 | RoomVersionId::V2 => {}, + _ => _ = pdu_json.remove("event_id"), + }, + Err(_) => _ = pdu_json.remove("event_id"), + } + } else { + pdu_json.remove("event_id"); + } + + // TODO: another option would be to convert it to a canonical string to validate + // size and return a Result> + // serde_json::from_str::>( + // ruma::serde::to_canonical_json_string(pdu_json).expect("CanonicalJson is + // valid serde_json::Value"), ) + // .expect("Raw::from_value always works") + + to_raw_value(&pdu_json).expect("CanonicalJson is valid serde_json::Value") +} diff --git a/src/service/service.rs b/src/service/service.rs new file mode 100644 index 00000000..3b8f4231 --- /dev/null +++ b/src/service/service.rs @@ -0,0 +1,46 @@ +use std::{collections::BTreeMap, fmt::Write, sync::Arc}; + +use async_trait::async_trait; +use conduit::{utils::string::split_once_infallible, Result, Server}; +use database::Database; + +#[async_trait] +pub(crate) trait Service: Send + Sync { + /// Implement the construction of the service instance. Services are + /// generally singletons so expect this to only be called once for a + /// service type. Note that it may be called again after a server reload, + /// but the prior instance will have been dropped first. Failure will + /// shutdown the server with an error. + fn build(args: Args<'_>) -> Result> + where + Self: Sized; + + /// Implement the service's worker loop. The service manager spawns a + /// task and calls this function after all services have been built. + async fn worker(self: Arc) -> Result<()> { Ok(()) } + + /// Interrupt the service. This is sent to initiate a graceful shutdown. + /// The service worker should return from its work loop. + fn interrupt(&self) {} + + /// Clear any caches or similar runtime state. + fn clear_cache(&self) {} + + /// Memory usage report in a markdown string. + fn memory_usage(&self, _out: &mut dyn Write) -> Result<()> { Ok(()) } + + /// Return the name of the service. + /// i.e. `crate::service::make_name(std::module_path!())` + fn name(&self) -> &str; +} + +pub(crate) struct Args<'a> { + pub(crate) server: &'a Arc, + pub(crate) db: &'a Arc, + pub(crate) _service: &'a Map, +} + +pub(crate) type Map = BTreeMap>; + +#[inline] +pub(crate) fn make_name(module_path: &str) -> &str { split_once_infallible(module_path, "::").1 } diff --git a/src/service/services.rs b/src/service/services.rs index 70e6a6d5..cc9ec290 100644 --- a/src/service/services.rs +++ b/src/service/services.rs @@ -1,104 +1,153 @@ -use std::sync::Arc; +use std::{collections::BTreeMap, fmt::Write, sync::Arc}; -use conduit::{debug_info, Result, Server}; +use conduit::{debug, debug_info, info, trace, Result, Server}; use database::Database; -use tracing::{debug, info, trace}; +use tokio::sync::Mutex; use crate::{ - account_data, admin, appservice, globals, key_backups, media, presence, pusher, rooms, sending, transaction_ids, - uiaa, users, + account_data, admin, appservice, globals, key_backups, + manager::Manager, + media, presence, pusher, rooms, sending, + service::{Args, Map, Service}, + transaction_ids, uiaa, updates, users, }; pub struct Services { pub rooms: rooms::Service, - pub appservice: appservice::Service, - pub pusher: pusher::Service, - pub transaction_ids: transaction_ids::Service, - pub uiaa: uiaa::Service, - pub users: users::Service, - pub account_data: account_data::Service, + pub appservice: Arc, + pub pusher: Arc, + pub transaction_ids: Arc, + pub uiaa: Arc, + pub users: Arc, + pub account_data: Arc, pub presence: Arc, pub admin: Arc, - pub globals: globals::Service, - pub key_backups: key_backups::Service, - pub media: media::Service, + pub key_backups: Arc, + pub media: Arc, pub sending: Arc, + pub updates: Arc, + pub globals: Arc, + + manager: Mutex>>, + pub(crate) service: Map, pub server: Arc, pub db: Arc, } impl Services { - pub async fn build(server: Arc, db: Arc) -> Result { + pub fn build(server: Arc, db: Arc) -> Result { + let mut service: Map = BTreeMap::new(); + macro_rules! build { + ($tyname:ty) => {{ + let built = <$tyname>::build(Args { + server: &server, + db: &db, + _service: &service, + })?; + service.insert(built.name().to_owned(), built.clone()); + built + }}; + } + Ok(Self { rooms: rooms::Service { - alias: rooms::alias::Service::build(&server, &db)?, - auth_chain: rooms::auth_chain::Service::build(&server, &db)?, - directory: rooms::directory::Service::build(&server, &db)?, - event_handler: rooms::event_handler::Service::build(&server, &db)?, - lazy_loading: rooms::lazy_loading::Service::build(&server, &db)?, - metadata: rooms::metadata::Service::build(&server, &db)?, - outlier: rooms::outlier::Service::build(&server, &db)?, - pdu_metadata: rooms::pdu_metadata::Service::build(&server, &db)?, - read_receipt: rooms::read_receipt::Service::build(&server, &db)?, - search: rooms::search::Service::build(&server, &db)?, - short: rooms::short::Service::build(&server, &db)?, - state: rooms::state::Service::build(&server, &db)?, - state_accessor: rooms::state_accessor::Service::build(&server, &db)?, - state_cache: rooms::state_cache::Service::build(&server, &db)?, - state_compressor: rooms::state_compressor::Service::build(&server, &db)?, - timeline: rooms::timeline::Service::build(&server, &db)?, - threads: rooms::threads::Service::build(&server, &db)?, - typing: rooms::typing::Service::build(&server, &db)?, - spaces: rooms::spaces::Service::build(&server, &db)?, - user: rooms::user::Service::build(&server, &db)?, + alias: build!(rooms::alias::Service), + auth_chain: build!(rooms::auth_chain::Service), + directory: build!(rooms::directory::Service), + event_handler: build!(rooms::event_handler::Service), + lazy_loading: build!(rooms::lazy_loading::Service), + metadata: build!(rooms::metadata::Service), + outlier: build!(rooms::outlier::Service), + pdu_metadata: build!(rooms::pdu_metadata::Service), + read_receipt: build!(rooms::read_receipt::Service), + search: build!(rooms::search::Service), + short: build!(rooms::short::Service), + state: build!(rooms::state::Service), + state_accessor: build!(rooms::state_accessor::Service), + state_cache: build!(rooms::state_cache::Service), + state_compressor: build!(rooms::state_compressor::Service), + timeline: build!(rooms::timeline::Service), + threads: build!(rooms::threads::Service), + typing: build!(rooms::typing::Service), + spaces: build!(rooms::spaces::Service), + user: build!(rooms::user::Service), }, - appservice: appservice::Service::build(&server, &db)?, - pusher: pusher::Service::build(&server, &db)?, - transaction_ids: transaction_ids::Service::build(&server, &db)?, - uiaa: uiaa::Service::build(&server, &db)?, - users: users::Service::build(&server, &db)?, - account_data: account_data::Service::build(&server, &db)?, - presence: presence::Service::build(&server, &db)?, - admin: admin::Service::build(&server, &db)?, - key_backups: key_backups::Service::build(&server, &db)?, - media: media::Service::build(&server, &db)?, - sending: sending::Service::build(&server, &db)?, - globals: globals::Service::build(&server, &db)?, + appservice: build!(appservice::Service), + pusher: build!(pusher::Service), + transaction_ids: build!(transaction_ids::Service), + uiaa: build!(uiaa::Service), + users: build!(users::Service), + account_data: build!(account_data::Service), + presence: build!(presence::Service), + admin: build!(admin::Service), + key_backups: build!(key_backups::Service), + media: build!(media::Service), + sending: build!(sending::Service), + updates: build!(updates::Service), + globals: build!(globals::Service), + manager: Mutex::new(None), + service, server, db, }) } - pub async fn memory_usage(&self) -> String { - let lazy_load_waiting = self.rooms.lazy_loading.lazy_load_waiting.lock().await.len(); - let server_visibility_cache = self - .rooms - .state_accessor - .server_visibility_cache - .lock() - .unwrap() - .len(); - let user_visibility_cache = self - .rooms - .state_accessor - .user_visibility_cache - .lock() - .unwrap() - .len(); - let stateinfo_cache = self - .rooms - .state_compressor - .stateinfo_cache - .lock() - .unwrap() - .len(); - let lasttimelinecount_cache = self - .rooms - .timeline - .lasttimelinecount_cache + pub(super) async fn start(&self) -> Result<()> { + debug_info!("Starting services..."); + + globals::migrations::migrations(&self.db, &self.server.config).await?; + self.manager .lock() .await - .len(); + .insert(Manager::new(self)) + .clone() + .start() + .await?; + + debug_info!("Services startup complete."); + Ok(()) + } + + pub(super) async fn stop(&self) { + info!("Shutting down services..."); + + self.interrupt(); + if let Some(manager) = self.manager.lock().await.as_ref() { + manager.stop().await; + } + + debug_info!("Services shutdown complete."); + } + + pub async fn poll(&self) -> Result<()> { + if let Some(manager) = self.manager.lock().await.as_ref() { + return manager.poll().await; + } + + Ok(()) + } + + pub async fn clear_cache(&self) { + for service in self.service.values() { + service.clear_cache(); + } + + //TODO + self.rooms + .spaces + .roomid_spacehierarchy_cache + .lock() + .await + .clear(); + } + + pub async fn memory_usage(&self) -> Result { + let mut out = String::new(); + for service in self.service.values() { + service.memory_usage(&mut out)?; + } + + //TODO let roomid_spacehierarchy_cache = self .rooms .spaces @@ -106,176 +155,17 @@ impl Services { .lock() .await .len(); - let resolver_overrides_cache = self - .globals - .resolver - .overrides - .read() - .expect("locked for reading") - .len(); - let resolver_destinations_cache = self - .globals - .resolver - .destinations - .read() - .expect("locked for reading") - .len(); - let bad_event_ratelimiter = self.globals.bad_event_ratelimiter.read().await.len(); - let bad_query_ratelimiter = self.globals.bad_query_ratelimiter.read().await.len(); - let bad_signature_ratelimiter = self.globals.bad_signature_ratelimiter.read().await.len(); + writeln!(out, "roomid_spacehierarchy_cache: {roomid_spacehierarchy_cache}")?; - format!( - "\ -lazy_load_waiting: {lazy_load_waiting} -server_visibility_cache: {server_visibility_cache} -user_visibility_cache: {user_visibility_cache} -stateinfo_cache: {stateinfo_cache} -lasttimelinecount_cache: {lasttimelinecount_cache} -roomid_spacehierarchy_cache: {roomid_spacehierarchy_cache} -resolver_overrides_cache: {resolver_overrides_cache} -resolver_destinations_cache: {resolver_destinations_cache} -bad_event_ratelimiter: {bad_event_ratelimiter} -bad_query_ratelimiter: {bad_query_ratelimiter} -bad_signature_ratelimiter: {bad_signature_ratelimiter} -" - ) + Ok(out) } - pub async fn clear_caches(&self, amount: u32) { - if amount > 0 { - self.rooms - .lazy_loading - .lazy_load_waiting - .lock() - .await - .clear(); + fn interrupt(&self) { + debug!("Interrupting services..."); + + for (name, service) in &self.service { + trace!("Interrupting {name}"); + service.interrupt(); } - if amount > 1 { - self.rooms - .state_accessor - .server_visibility_cache - .lock() - .unwrap() - .clear(); - } - if amount > 2 { - self.rooms - .state_accessor - .user_visibility_cache - .lock() - .unwrap() - .clear(); - } - if amount > 3 { - self.rooms - .state_compressor - .stateinfo_cache - .lock() - .unwrap() - .clear(); - } - if amount > 4 { - self.rooms - .timeline - .lasttimelinecount_cache - .lock() - .await - .clear(); - } - if amount > 5 { - self.rooms - .spaces - .roomid_spacehierarchy_cache - .lock() - .await - .clear(); - } - if amount > 6 { - self.globals - .resolver - .overrides - .write() - .expect("locked for writing") - .clear(); - self.globals - .resolver - .destinations - .write() - .expect("locked for writing") - .clear(); - } - if amount > 7 { - self.globals.resolver.resolver.clear_cache(); - } - if amount > 8 { - self.globals.bad_event_ratelimiter.write().await.clear(); - } - if amount > 9 { - self.globals.bad_query_ratelimiter.write().await.clear(); - } - if amount > 10 { - self.globals.bad_signature_ratelimiter.write().await.clear(); - } - } - - pub async fn start(&self) -> Result<()> { - debug_info!("Starting services"); - - self.media.create_media_dir().await?; - globals::migrations::migrations(&self.db, &self.globals.config).await?; - globals::emerg_access::init_emergency_access(); - - self.admin.start_handler().await; - self.sending.start_handler().await; - if self.globals.config.allow_local_presence { - self.presence.start_handler().await; - } - - if self.globals.allow_check_for_updates() { - let handle = globals::updates::start_check_for_updates_task(); - - #[allow(clippy::let_underscore_must_use)] // needed for shutdown - { - _ = self.globals.updates_handle.lock().await.insert(handle); - } - } - - debug_info!("Services startup complete."); - Ok(()) - } - - pub async fn interrupt(&self) { - trace!("Interrupting services..."); - self.sending.interrupt(); - self.presence.interrupt(); - self.admin.interrupt(); - - trace!("Services interrupt complete."); - } - - pub async fn stop(&self) { - info!("Shutting down services"); - self.interrupt().await; - - debug!("Waiting for update worker..."); - if let Some(updates_handle) = self.globals.updates_handle.lock().await.take() { - updates_handle.abort(); - - #[allow(clippy::let_underscore_must_use)] - { - _ = updates_handle.await; - } - } - - debug!("Waiting for admin worker..."); - self.admin.close().await; - - debug!("Waiting for presence worker..."); - self.presence.close().await; - - debug!("Waiting for sender..."); - self.sending.close().await; - - debug_info!("Services shutdown complete."); } } diff --git a/src/service/transaction_ids/mod.rs b/src/service/transaction_ids/mod.rs index f18bbe64..78e6337f 100644 --- a/src/service/transaction_ids/mod.rs +++ b/src/service/transaction_ids/mod.rs @@ -2,22 +2,25 @@ mod data; use std::sync::Arc; -use conduit::{Result, Server}; +use conduit::Result; use data::Data; -use database::Database; use ruma::{DeviceId, TransactionId, UserId}; pub struct Service { pub db: Data, } -impl Service { - pub fn build(_server: &Arc, db: &Arc) -> Result { - Ok(Self { - db: Data::new(db), - }) +impl crate::Service for Service { + fn build(args: crate::Args<'_>) -> Result> { + Ok(Arc::new(Self { + db: Data::new(args.db), + })) } + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } +} + +impl Service { pub fn add_txnid( &self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, data: &[u8], ) -> Result<()> { diff --git a/src/service/uiaa/mod.rs b/src/service/uiaa/mod.rs index 0d2f8bf7..4b953ffb 100644 --- a/src/service/uiaa/mod.rs +++ b/src/service/uiaa/mod.rs @@ -2,9 +2,8 @@ mod data; use std::sync::Arc; -use conduit::{utils, utils::hash, Error, Result, Server}; +use conduit::{utils, utils::hash, Error, Result}; use data::Data; -use database::Database; use ruma::{ api::client::{ error::ErrorKind, @@ -22,13 +21,17 @@ pub struct Service { pub db: Data, } -impl Service { - pub fn build(_server: &Arc, db: &Arc) -> Result { - Ok(Self { - db: Data::new(db), - }) +impl crate::Service for Service { + fn build(args: crate::Args<'_>) -> Result> { + Ok(Arc::new(Self { + db: Data::new(args.db), + })) } + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } +} + +impl Service { /// Creates a new Uiaa session. Make sure the session token is unique. pub fn create( &self, user_id: &UserId, device_id: &DeviceId, uiaainfo: &UiaaInfo, json_body: &CanonicalJsonValue, diff --git a/src/service/updates/mod.rs b/src/service/updates/mod.rs new file mode 100644 index 00000000..3fb680d6 --- /dev/null +++ b/src/service/updates/mod.rs @@ -0,0 +1,112 @@ +use std::{sync::Arc, time::Duration}; + +use async_trait::async_trait; +use conduit::{err, info, utils, warn, Error, Result}; +use database::Map; +use ruma::events::room::message::RoomMessageEventContent; +use serde::Deserialize; +use tokio::{sync::Notify, time::interval}; + +use crate::services; + +pub struct Service { + db: Arc, + interrupt: Notify, + interval: Duration, +} + +#[derive(Deserialize)] +struct CheckForUpdatesResponse { + updates: Vec, +} + +#[derive(Deserialize)] +struct CheckForUpdatesResponseEntry { + id: u64, + date: String, + message: String, +} + +const CHECK_FOR_UPDATES_URL: &str = "https://pupbrain.dev/check-for-updates/stable"; +const CHECK_FOR_UPDATES_INTERVAL: u64 = 7200; // 2 hours +const LAST_CHECK_FOR_UPDATES_COUNT: &[u8] = b"u"; + +#[async_trait] +impl crate::Service for Service { + fn build(args: crate::Args<'_>) -> Result> { + Ok(Arc::new(Self { + db: args.db["global"].clone(), + interrupt: Notify::new(), + interval: Duration::from_secs(CHECK_FOR_UPDATES_INTERVAL), + })) + } + + async fn worker(self: Arc) -> Result<()> { + let mut i = interval(self.interval); + loop { + tokio::select! { + () = self.interrupt.notified() => return Ok(()), + _ = i.tick() => (), + } + + if let Err(e) = self.handle_updates().await { + warn!(%e, "Failed to check for updates"); + } + } + } + + fn interrupt(&self) { self.interrupt.notify_waiters(); } + + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } +} + +impl Service { + #[tracing::instrument(skip_all)] + async fn handle_updates(&self) -> Result<()> { + let response = services() + .globals + .client + .default + .get(CHECK_FOR_UPDATES_URL) + .send() + .await?; + + let response = serde_json::from_str::(&response.text().await?) + .map_err(|e| err!("Bad check for updates response: {e}"))?; + + let mut last_update_id = self.last_check_for_updates_id()?; + for update in response.updates { + last_update_id = last_update_id.max(update.id); + if update.id > self.last_check_for_updates_id()? { + info!("{:#}", update.message); + services() + .admin + .send_message(RoomMessageEventContent::text_markdown(format!( + "### the following is a message from the conduwuit puppy\n\nit was sent on `{}`:\n\n@room: {}", + update.date, update.message + ))) + .await; + } + } + self.update_check_for_updates_id(last_update_id)?; + + Ok(()) + } + + #[inline] + pub fn update_check_for_updates_id(&self, id: u64) -> Result<()> { + self.db + .insert(LAST_CHECK_FOR_UPDATES_COUNT, &id.to_be_bytes())?; + + Ok(()) + } + + pub fn last_check_for_updates_id(&self) -> Result { + self.db + .get(LAST_CHECK_FOR_UPDATES_COUNT)? + .map_or(Ok(0_u64), |bytes| { + utils::u64_from_bytes(&bytes) + .map_err(|_| Error::bad_database("last check for updates count has invalid bytes.")) + }) + } +} diff --git a/src/service/users/data.rs b/src/service/users/data.rs index 14aa8006..5546adb1 100644 --- a/src/service/users/data.rs +++ b/src/service/users/data.rs @@ -1,6 +1,6 @@ use std::{collections::BTreeMap, mem::size_of, sync::Arc}; -use conduit::{debug_info, utils, warn, Error, Result}; +use conduit::{debug_info, err, utils, warn, Err, Error, Result}; use database::{Database, Map}; use ruma::{ api::client::{device::Device, error::ErrorKind, filter::FilterDefinition}, @@ -61,6 +61,7 @@ impl Data { } /// Check if a user has an account on this homeserver. + #[inline] pub(super) fn exists(&self, user_id: &UserId) -> Result { Ok(self.userid_password.get(user_id.as_bytes())?.is_some()) } @@ -75,6 +76,7 @@ impl Data { } /// Returns the number of users registered on this server. + #[inline] pub(super) fn count(&self) -> Result { Ok(self.userid_password.iter().count()) } /// Find out which user an access token belongs to. @@ -85,19 +87,19 @@ impl Data { let mut parts = bytes.split(|&b| b == 0xFF); let user_bytes = parts .next() - .ok_or_else(|| Error::bad_database("User ID in token_userdeviceid is invalid."))?; + .ok_or_else(|| err!(Database("User ID in token_userdeviceid is invalid.")))?; let device_bytes = parts .next() - .ok_or_else(|| Error::bad_database("Device ID in token_userdeviceid is invalid."))?; + .ok_or_else(|| err!(Database("Device ID in token_userdeviceid is invalid.")))?; Ok(Some(( UserId::parse( utils::string_from_bytes(user_bytes) - .map_err(|_| Error::bad_database("User ID in token_userdeviceid is invalid unicode."))?, + .map_err(|e| err!(Database("User ID in token_userdeviceid is invalid unicode. {e}")))?, ) - .map_err(|_| Error::bad_database("User ID in token_userdeviceid is invalid."))?, + .map_err(|e| err!(Database("User ID in token_userdeviceid is invalid. {e}")))?, utils::string_from_bytes(device_bytes) - .map_err(|_| Error::bad_database("Device ID in token_userdeviceid is invalid."))?, + .map_err(|e| err!(Database("Device ID in token_userdeviceid is invalid. {e}")))?, ))) }) } @@ -107,9 +109,9 @@ impl Data { Box::new(self.userid_password.iter().map(|(bytes, _)| { UserId::parse( utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("User ID in userid_password is invalid unicode."))?, + .map_err(|e| err!(Database("User ID in userid_password is invalid unicode. {e}")))?, ) - .map_err(|_| Error::bad_database("User ID in userid_password is invalid.")) + .map_err(|e| err!(Database("User ID in userid_password is invalid. {e}"))) })) } @@ -163,7 +165,7 @@ impl Data { .map_or(Ok(None), |bytes| { Ok(Some( utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Displayname in db is invalid."))?, + .map_err(|e| err!(Database("Displayname in db is invalid. {e}")))?, )) }) } @@ -186,10 +188,8 @@ impl Data { self.userid_avatarurl .get(user_id.as_bytes())? .map(|bytes| { - let s_bytes = utils::string_from_bytes(&bytes).map_err(|e| { - warn!("Avatar URL in db is invalid: {}", e); - Error::bad_database("Avatar URL in db is invalid.") - })?; + let s_bytes = utils::string_from_bytes(&bytes) + .map_err(|e| err!(Database(warn!("Avatar URL in db is invalid: {e}"))))?; let mxc_uri: OwnedMxcUri = s_bytes.into(); Ok(mxc_uri) }) @@ -213,10 +213,7 @@ impl Data { self.userid_blurhash .get(user_id.as_bytes())? .map(|bytes| { - let s = utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Avatar URL in db is invalid."))?; - - Ok(s) + utils::string_from_bytes(&bytes).map_err(|e| err!(Database("Avatar URL in db is invalid. {e}"))) }) .transpose() } @@ -312,9 +309,9 @@ impl Data { bytes .rsplit(|&b| b == 0xFF) .next() - .ok_or_else(|| Error::bad_database("UserDevice ID in db is invalid."))?, + .ok_or_else(|| err!(Database("UserDevice ID in db is invalid.")))?, ) - .map_err(|_| Error::bad_database("Device ID in userdeviceid_metadata is invalid."))? + .map_err(|e| err!(Database("Device ID in userdeviceid_metadata is invalid. {e}")))? .into()) }), ) @@ -328,13 +325,9 @@ impl Data { // should not be None, but we shouldn't assert either lol... if self.userdeviceid_metadata.get(&userdeviceid)?.is_none() { - warn!( - "Called set_token for a non-existent user \"{}\" and/or device ID \"{}\" with no metadata in database", - user_id, device_id - ); - return Err(Error::bad_database( - "User does not exist or device ID has no metadata in database.", - )); + return Err!(Database(error!( + "User {user_id:?} does not exist or device ID {device_id:?} has no metadata." + ))); } // Remove old token @@ -364,14 +357,9 @@ impl Data { // Only existing devices should be able to call this, but we shouldn't assert // either... if self.userdeviceid_metadata.get(&key)?.is_none() { - warn!( - "Called add_one_time_key for a non-existent user \"{}\" and/or device ID \"{}\" with no metadata in \ - database", - user_id, device_id - ); - return Err(Error::bad_database( - "User does not exist or device ID has no metadata in database.", - )); + return Err!(Database(error!( + "User {user_id:?} does not exist or device ID {device_id:?} has no metadata." + ))); } key.push(0xFF); @@ -399,7 +387,7 @@ impl Data { .get(user_id.as_bytes())? .map_or(Ok(0), |bytes| { utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Count in roomid_lastroomactiveupdate is invalid.")) + .map_err(|e| err!(Database("Count in roomid_lastroomactiveupdate is invalid. {e}"))) }) } @@ -427,11 +415,10 @@ impl Data { serde_json::from_slice( key.rsplit(|&b| b == 0xFF) .next() - .ok_or_else(|| Error::bad_database("OneTimeKeyId in db is invalid."))?, + .ok_or_else(|| err!(Database("OneTimeKeyId in db is invalid.")))?, ) - .map_err(|_| Error::bad_database("OneTimeKeyId in db is invalid."))?, - serde_json::from_slice(&value) - .map_err(|_| Error::bad_database("OneTimeKeys in db are invalid."))?, + .map_err(|e| err!(Database("OneTimeKeyId in db is invalid. {e}")))?, + serde_json::from_slice(&value).map_err(|e| err!(Database("OneTimeKeys in db are invalid. {e}")))?, )) }) .transpose() @@ -455,13 +442,14 @@ impl Data { bytes .rsplit(|&b| b == 0xFF) .next() - .ok_or_else(|| Error::bad_database("OneTimeKey ID in db is invalid."))?, + .ok_or_else(|| err!(Database("OneTimeKey ID in db is invalid.")))?, ) - .map_err(|_| Error::bad_database("DeviceKeyId in db is invalid."))? + .map_err(|e| err!(Database("DeviceKeyId in db is invalid. {e}")))? .algorithm(), ) }) { - *counts.entry(algorithm?).or_default() += uint!(1); + let count: &mut UInt = counts.entry(algorithm?).or_default(); + *count = count.saturating_add(uint!(1)); } Ok(counts) @@ -578,19 +566,19 @@ impl Data { .get(&key)? .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Tried to sign nonexistent key."))?, ) - .map_err(|_| Error::bad_database("key in keyid_key is invalid."))?; + .map_err(|e| err!(Database("key in keyid_key is invalid. {e}")))?; let signatures = cross_signing_key .get_mut("signatures") - .ok_or_else(|| Error::bad_database("key in keyid_key has no signatures field."))? + .ok_or_else(|| err!(Database("key in keyid_key has no signatures field.")))? .as_object_mut() - .ok_or_else(|| Error::bad_database("key in keyid_key has invalid signatures field."))? + .ok_or_else(|| err!(Database("key in keyid_key has invalid signatures field.")))? .entry(sender_id.to_string()) .or_insert_with(|| serde_json::Map::new().into()); signatures .as_object_mut() - .ok_or_else(|| Error::bad_database("signatures in keyid_key for a user is invalid."))? + .ok_or_else(|| err!(Database("signatures in keyid_key for a user is invalid.")))? .insert(signature.0, signature.1.into()); self.keyid_key.insert( @@ -637,7 +625,7 @@ impl Data { Error::bad_database("User ID in devicekeychangeid_userid is invalid unicode.") })?, ) - .map_err(|_| Error::bad_database("User ID in devicekeychangeid_userid is invalid.")) + .map_err(|e| err!(Database("User ID in devicekeychangeid_userid is invalid. {e}"))) }), ) } @@ -682,7 +670,7 @@ impl Data { self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| { Ok(Some( - serde_json::from_slice(&bytes).map_err(|_| Error::bad_database("DeviceKeys in db are invalid."))?, + serde_json::from_slice(&bytes).map_err(|e| err!(Database("DeviceKeys in db are invalid. {e}")))?, )) }) } @@ -716,7 +704,7 @@ impl Data { ) -> Result>> { self.keyid_key.get(key)?.map_or(Ok(None), |bytes| { let mut cross_signing_key = serde_json::from_slice::(&bytes) - .map_err(|_| Error::bad_database("CrossSigningKey in db is invalid."))?; + .map_err(|e| err!(Database("CrossSigningKey in db is invalid. {e}")))?; clean_signatures(&mut cross_signing_key, sender_user, user_id, allowed_signatures)?; Ok(Some(Raw::from_json( @@ -748,7 +736,7 @@ impl Data { self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| { Ok(Some( serde_json::from_slice(&bytes) - .map_err(|_| Error::bad_database("CrossSigningKey in db is invalid."))?, + .map_err(|e| err!(Database("CrossSigningKey in db is invalid. {e}")))?, )) }) }) @@ -789,7 +777,7 @@ impl Data { for (_, value) in self.todeviceid_events.scan_prefix(prefix) { events.push( serde_json::from_slice(&value) - .map_err(|_| Error::bad_database("Event in todeviceid_events is invalid."))?, + .map_err(|e| err!(Database("Event in todeviceid_events is invalid. {e}")))?, ); } @@ -812,8 +800,8 @@ impl Data { .map(|(key, _)| { Ok::<_, Error>(( key.clone(), - utils::u64_from_bytes(&key[key.len() - size_of::()..key.len()]) - .map_err(|_| Error::bad_database("ToDeviceId has invalid count bytes."))?, + utils::u64_from_bytes(&key[key.len().saturating_sub(size_of::())..key.len()]) + .map_err(|e| err!(Database("ToDeviceId has invalid count bytes. {e}")))?, )) }) .filter_map(Result::ok) @@ -874,7 +862,7 @@ impl Data { .get(user_id.as_bytes())? .map_or(Ok(None), |bytes| { utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Invalid devicelistversion in db.")) + .map_err(|e| err!(Database("Invalid devicelistversion in db. {e}"))) .map(Some) }) } @@ -890,7 +878,7 @@ impl Data { .scan_prefix(key) .map(|(_, bytes)| { serde_json::from_slice::(&bytes) - .map_err(|_| Error::bad_database("Device in userdeviceid_metadata is invalid.")) + .map_err(|e| err!(Database("Device in userdeviceid_metadata is invalid. {e}"))) }), ) } @@ -917,7 +905,7 @@ impl Data { let raw = self.userfilterid_filter.get(&key)?; if let Some(raw) = raw { - serde_json::from_slice(&raw).map_err(|_| Error::bad_database("Invalid filter event in db.")) + serde_json::from_slice(&raw).map_err(|e| err!(Database("Invalid filter event in db. {e}"))) } else { Ok(None) } @@ -926,10 +914,12 @@ impl Data { /// Creates an OpenID token, which can be used to prove that a user has /// access to an account (primarily for integrations) pub(super) fn create_openid_token(&self, user_id: &UserId, token: &str) -> Result { - let expires_in = services().globals.config.openid_token_ttl; - let expires_at = utils::millis_since_unix_epoch().saturating_add(expires_in * 1000); + use std::num::Saturating as Sat; - let mut value = expires_at.to_be_bytes().to_vec(); + let expires_in = services().globals.config.openid_token_ttl; + let expires_at = Sat(utils::millis_since_unix_epoch()) + Sat(expires_in) * Sat(1000); + + let mut value = expires_at.0.to_be_bytes().to_vec(); value.extend_from_slice(user_id.as_bytes()); self.openidtoken_expiresatuserid @@ -949,7 +939,7 @@ impl Data { let expires_at = u64::from_be_bytes( expires_at_bytes .try_into() - .map_err(|_| Error::bad_database("expires_at in openid_userid is invalid u64."))?, + .map_err(|e| err!(Database("expires_at in openid_userid is invalid u64. {e}")))?, ); if expires_at < utils::millis_since_unix_epoch() { @@ -961,9 +951,9 @@ impl Data { UserId::parse( utils::string_from_bytes(user_bytes) - .map_err(|_| Error::bad_database("User ID in openid_userid is invalid unicode."))?, + .map_err(|e| err!(Database("User ID in openid_userid is invalid unicode. {e}")))?, ) - .map_err(|_| Error::bad_database("User ID in openid_userid is invalid.")) + .map_err(|e| err!(Database("User ID in openid_userid is invalid. {e}"))) } } diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs index 8444d538..0196e1aa 100644 --- a/src/service/users/mod.rs +++ b/src/service/users/mod.rs @@ -6,9 +6,8 @@ use std::{ sync::{Arc, Mutex, Mutex as StdMutex}, }; -use conduit::{Error, Result, Server}; +use conduit::{Error, Result}; use data::Data; -use database::Database; use ruma::{ api::client::{ device::Device, @@ -41,15 +40,20 @@ pub struct Service { pub connections: DbConnections, } -impl Service { - pub fn build(_server: &Arc, db: &Arc) -> Result { - Ok(Self { - db: Data::new(db.clone()), +impl crate::Service for Service { + fn build(args: crate::Args<'_>) -> Result> { + Ok(Arc::new(Self { + db: Data::new(args.db.clone()), connections: StdMutex::new(BTreeMap::new()), - }) + })) } + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } +} + +impl Service { /// Check if a user has an account on this homeserver. + #[inline] pub fn exists(&self, user_id: &UserId) -> Result { self.db.exists(user_id) } pub fn forget_sync_request_connection(&self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String) { @@ -254,12 +258,14 @@ impl Service { } /// Create a new user account on this homeserver. + #[inline] pub fn create(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { self.db.set_password(user_id, password)?; Ok(()) } /// Returns the number of users registered on this server. + #[inline] pub fn count(&self) -> Result { self.db.count() } /// Find out which user an access token belongs to. @@ -280,6 +286,7 @@ impl Service { pub fn password_hash(&self, user_id: &UserId) -> Result> { self.db.password_hash(user_id) } /// Hash and set the user's password to the Argon2 hash + #[inline] pub fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { self.db.set_password(user_id, password) } @@ -328,6 +335,7 @@ impl Service { } /// Replaces the access token of one device. + #[inline] pub fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> { self.db.set_token(user_id, device_id, token) } @@ -382,18 +390,21 @@ impl Service { self.db.keys_changed(user_or_room_id, from, to) } + #[inline] pub fn mark_device_key_update(&self, user_id: &UserId) -> Result<()> { self.db.mark_device_key_update(user_id) } pub fn get_device_keys(&self, user_id: &UserId, device_id: &DeviceId) -> Result>> { self.db.get_device_keys(user_id, device_id) } + #[inline] pub fn parse_master_key( &self, user_id: &UserId, master_key: &Raw, ) -> Result<(Vec, CrossSigningKey)> { Data::parse_master_key(user_id, master_key) } + #[inline] pub fn get_key( &self, key: &[u8], sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, ) -> Result>> { diff --git a/tests/test_results/complement/test_results.jsonl b/tests/test_results/complement/test_results.jsonl index 2f3db95d..873397ba 100644 --- a/tests/test_results/complement/test_results.jsonl +++ b/tests/test_results/complement/test_results.jsonl @@ -138,7 +138,8 @@ {"Action":"fail","Test":"TestKnockingInMSC3787Room/Knocking_on_a_room_with_join_rule_'knock'_should_succeed#01"} {"Action":"fail","Test":"TestKnockingInMSC3787Room/Users_in_the_room_see_a_user's_membership_update_when_they_knock"} {"Action":"fail","Test":"TestKnockingInMSC3787Room/Users_in_the_room_see_a_user's_membership_update_when_they_knock#01"} -{"Action":"pass","Test":"TestLocalPngThumbnail"} +{"Action":"fail","Test":"TestLocalPngThumbnail"} +{"Action":"fail","Test":"TestLocalPngThumbnail/test_/_matrix/client/v1/media_endpoint"} {"Action":"fail","Test":"TestMediaFilenames"} {"Action":"fail","Test":"TestMediaFilenames/Parallel"} {"Action":"fail","Test":"TestMediaFilenames/Parallel/ASCII"} @@ -181,7 +182,8 @@ {"Action":"pass","Test":"TestOutboundFederationProfile/Outbound_federation_can_query_profile_data"} {"Action":"pass","Test":"TestOutboundFederationSend"} {"Action":"pass","Test":"TestRemoteAliasRequestsUnderstandUnicode"} -{"Action":"pass","Test":"TestRemotePngThumbnail"} +{"Action":"fail","Test":"TestRemotePngThumbnail"} +{"Action":"fail","Test":"TestRemotePngThumbnail/test_/_matrix/client/v1/media_endpoint"} {"Action":"fail","Test":"TestRemotePresence"} {"Action":"fail","Test":"TestRemotePresence/Presence_changes_are_also_reported_to_remote_room_members"} {"Action":"fail","Test":"TestRemotePresence/Presence_changes_to_UNAVAILABLE_are_reported_to_remote_room_members"}